3

私はニュース推薦システムをやっているので、読んでいるユーザーやニュースのためのテーブルを作る必要があります。私の生のちょうどこのようなデータ:スパークSQLを使用して、グループベースの後にスパースベクトルを追加する方法はありますか?

001436800277225 ["9161492","9161787","9378531"] 
009092130698762 ["9394697"] 
010003000431538 ["9394697","9426473","9428530"] 
010156461231357 ["9350394","9414181"] 
010216216021063 ["9173862","9247870"] 
010720006581483 ["9018786"] 
011199797794333 ["9017977","9091134","9142852","9325464","9331913"] 
011337201765123 ["9161294","9198693"] 
011414545455156 ["9168185","9178348","9182782","9359776"] 
011425002581540 ["9083446","9161294","9309432"] 

と私は爆発し、1つのホットエンコーディング、その後

df = getdf() 
df1 = df.select('uuid',explode('news').alias('news')) 
stringIndexer = StringIndexer(inputCol="news", outputCol="newsIndex") 
model = stringIndexer.fit(df1) 
indexed = model.transform(df1) 
encoder = OneHotEncoder(inputCol="newsIndex", outputCol="newsVec") 
encoded = encoder.transform(indexed) 
encoded.show(20,False) 

ん火花-SQLを使用するには、私のデータは次のようになります。

+---------------+-------+---------+----------------------+ 
|uuid   |news |newsIndex|newsVec    | 
+---------------+-------+---------+----------------------+ 
|014324000386050|9398253|10415.0 |(105721,[10415],[1.0])| 
|014324000386050|9428530|70.0  |(105721,[70],[1.0]) | 
|014324000631752|654112 |1717.0 |(105721,[1717],[1.0]) | 
|014324000674240|730531 |2282.0 |(105721,[2282],[1.0]) | 
|014324000674240|694306 |1268.0 |(105721,[1268],[1.0]) | 
|014324000674240|712016 |4766.0 |(105721,[4766],[1.0]) | 
|014324000674240|672307 |7318.0 |(105721,[7318],[1.0]) | 
|014324000674240|698073 |1241.0 |(105721,[1241],[1.0]) | 
|014324000674240|728044 |5302.0 |(105721,[5302],[1.0]) | 
|014324000674240|672256 |1619.0 |(105721,[1619],[1.0]) | 
|014324000674240|730236 |2376.0 |(105721,[2376],[1.0]) | 
|014324000674240|730235 |14274.0 |(105721,[14274],[1.0])| 
|014324000674240|728509 |1743.0 |(105721,[1743],[1.0]) | 
|014324000674240|704528 |10310.0 |(105721,[10310],[1.0])| 
|014324000715399|774134 |8876.0 |(105721,[8876],[1.0]) | 
|014324000725836|9357431|3479.0 |(105721,[3479],[1.0]) | 
|014324000725836|9358028|15621.0 |(105721,[15621],[1.0])| 
|014324000730349|812106 |4599.0 |(105721,[4599],[1.0]) | 
|014324000730349|699237 |754.0 |(105721,[754],[1.0]) | 
|014324000730349|748109 |4854.0 |(105721,[4854],[1.0]) | 
+---------------+-------+---------+----------------------+ 

しかし1 idには複数の行があるので、groupBy('uuid')とし、addこれらのベクトルにします。しかし、単にgroupByを使用し、追加するとエラーが発生します。どうすればいい?

答えて

1

indexedから開始して、をリストとして収集し、udfを使用してSparseVectorに変換することができます。

スパースベクトルを宣言するには、フィーチャの数と位置と値を含むタプルのリストが必要です。カテゴリ変数を扱っているので、使用する値は1.0です。そしてインデックスは、列newsIndex次のようになります。

from pyspark.sql.functions import collect_list, max, lit 
from pyspark.ml.linalg import Vectors, VectorUDT 

def encode(arr, length): 

    vec_args = length, [(x,1.0) for x in arr] 

    return Vectors.sparse(*vec_args) 

encode_udf = udf(encode, VectorUDT()) 

機能の数は、(StrinIndexer0.0から始まるため)max(newsIndex) + 1次のとおりです。

feats = indexed.agg(max(indexed["newsIndex"])).take(1)[0][0] + 1 

はすべて一緒にそれを持って来る:

indexed.groupBy("uuid") \ 
     .agg(collect_list("newsIndex") 
     .alias("newsArr")) \ 
     .select("uuid", 
       encode_udf("newsArr", lit(feats)) 
       .alias("OHE")) \ 
     .show(truncate = False) 
+---------------+-----------------------------------------+ 
|uuid   |OHE          | 
+---------------+-----------------------------------------+ 
|009092130698762|(24,[0],[1.0])       | 
|010003000431538|(24,[0,3,15],[1.0,1.0,1.0])    | 
|010720006581483|(24,[11],[1.0])       | 
|010216216021063|(24,[10,22],[1.0,1.0])     | 
|001436800277225|(24,[2,12,23],[1.0,1.0,1.0])    | 
|011425002581540|(24,[1,5,9],[1.0,1.0,1.0])    | 
|010156461231357|(24,[13,18],[1.0,1.0])     | 
|011199797794333|(24,[7,8,17,19,20],[1.0,1.0,1.0,1.0,1.0])| 
|011414545455156|(24,[4,6,14,21],[1.0,1.0,1.0,1.0])  | 
|011337201765123|(24,[1,16],[1.0,1.0])     | 
+---------------+-----------------------------------------+ 
関連する問題