2017-05-19 2 views
0

SQLデータフレームの列の要素を乗算するのが難しいです。pyspark SQLデータフレームのスパースベクトルの行を掛ける

sv1 = Vectors.sparse(3, [0, 2], [1.0, 3.0]) 
sv2 = Vectors.sparse(3, [0, 1], [2.0, 4.0]) 

def xByY(x,y): 
    return np.multiply(x,y) 

print(xByY(sv1, sv2)) 

上記は機能します。

ただし、以下は該当しません。

xByY_udf = udf(xByY) 

tempDF = sqlContext.createDataFrame([(sv1, sv2), (sv1, sv2)], ('v1', 'v2')) 
tempDF.show() 

print(tempDF.select(xByY_udf('v1', 'v2')).show()) 

多くの感謝!

+0

あなたが得るエラーが何でありますか? –

答えて

0

あなたのudfSparseVectorを返すようにしたい場合は、私たちが最初にあなたの関数の出力を変更する必要があり、第二にVectorUDT()udfの出力スキーマを設定します:

SparseVectorを宣言するには、私たちが必要サイズ元の配列の、両指数値は非ゼロ要素の

from pyspark.ml.linalg import Vectors, VectorUDT 

def xByY(x,y): 
    res = np.multiply(x,y).tolist() 
    vec_args = len(res), [i for i,x in enumerate(res) if x != 0], [x for x in res if x != 0] 
    return Vectors.sparse(*vec_args) 

今、私たちは私たちのudfを宣言し、それをテストすることができます:私たちは、乗算の中間結果がlistある場合は、これらの使用len()とリスト内包表記を見つけることができます

xByY_udf = udf(xByY, VectorUDT()) 
tempDF.select(xByY_udf('v1', 'v2')).show() 
+-------------+ 
| xByY(v1, v2)| 
+-------------+ 
|(3,[0],[2.0])| 
|(3,[0],[2.0])| 
+-------------+ 
+0

ありがとう!これ(ほとんど)作品!どのように疎ベクトルを戻すのですか? –

+0

更新@ f.gをご覧ください。 – mtoto

+0

完全に輝かしい、ありがとう。 (UDFの出力スキーマの仕様は、常に慎重に考える必要があります)。 –

関連する問題