別の列でグループ化するときに、CSVから単語数を取得しようとしています。私のcsvにはid、message、user_idという3つのカラムがあります。私はこの中を読み、メッセージを分割し、ユニグラムのリストを格納します。私のデータフレームdf
与え、pysparkでgroupByデータに関数を適用する
+-----------------+--------------------+--------------------+
| id| message| user_id|
+-----------------+--------------------+--------------------+
|10100720363468236|[i'm, sad, to, mi...|dceafb541a1b8e894...|
|10100718944611636|[what, does, the,...|dceafb541a1b8e894...|
|10100718890699676|[at, the, oecd, w...|dceafb541a1b8e894...|
+-----------------+--------------------+--------------------+
次へ]を、私はuser_id
によってグループ化すると、その後ユニグラムごとにカウントを取得します。シンプルな最初のパスとして、私はuser_id
によってグループ化しようとしたグループ化されたメッセージフィールドの長さを取得:
from collections import Counter
from pyspark.sql.types import ArrayType, StringType, IntegerType
from pyspark.sql.functions import udf
df = self.session.read.csv(self.corptable, header=True,
mode="DROPMALFORMED",)
# split my messages ....
# message is now ArrayType(StringType())
grouped = df.groupBy(df["user_id"])
counter = udf(lambda l: len(l), ArrayType(StringType()))
grouped.agg(counter(df["message"]))
print(grouped.collect())
私は次のエラーを取得する:
pyspark.sql.utils.AnalysisException: "expression '`message`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;"
このエラーを回避する方法がわからないし。一般的に、別の列をグループ化するときに、どの列に関数を適用しますか?私はいつもユーザー定義関数を作成する必要がありますか? Sparkにとって非常に新しい。
編集:ここでは別のPythonのファイルにトークナイザ与え、私はこれを解決する方法である:
group_field = "user_id"
message_field = "message"
context = SparkContext()
session = SparkSession\
.builder\
.appName("dlastk")\
.getOrCreate()
# add tokenizer
context.addPyFile(tokenizer_path)
from tokenizer import Tokenizer
tokenizer = Tokenizer()
spark_tokenizer = udf(tokenizer.tokenize, ArrayType(StringType()))
df = session.read.csv("myFile.csv", header=True,)
df = df[group_field, message_field]
# tokenize the message field
df = df.withColumn(message_field, spark_tokenizer(df[message_field]))
# create ngrams from tokenized messages
n = 1
grouped = df.rdd.map(lambda row: (row[0], Counter([" ".join(x) for x in zip(*[row[1][i:] for i in range(n)])]))).reduceByKey(add)
# flatten the rdd so that each row contains (group_id, ngram, count, relative frequency
flat = grouped.flatMap(lambda row: [[row[0], x,y, y/sum(row[1].values())] for x,y in row[1].items()])
# rdd -> DF
flat = flat.toDF()
flat.write.csv("myNewCSV.csv")
データは次のようになります。
# after read
+--------------------+--------------------+
| user_id| message|
+--------------------+--------------------+
|00035fb0dcfbeaa8b...|To the douchebag ...|
|00035fb0dcfbeaa8b...| T minus 1 week...|
|00035fb0dcfbeaa8b...|Last full day of ...|
+--------------------+--------------------+
# after tokenize
+--------------------+--------------------+
| user_id| message|
+--------------------+--------------------+
|00035fb0dcfbeaa8b...|[to, the, doucheb...|
|00035fb0dcfbeaa8b...|[t, minus, 1, wee...|
|00035fb0dcfbeaa8b...|[last, full, day,...|
+--------------------+--------------------+
# grouped: after 1grams extracted and Counters added
[('00035fb0dcfbeaa8bb70ffe24d614d4dcee446b803eb4063dccf14dd2a474611', Counter({'!': 545, '.': 373, 'the': 306, '"': 225, ...
# flat: after calculating sum and relative frequency for each 1gram
[['00035fb0dcfbeaa8bb70ffe24d614d4dcee446b803eb4063dccf14dd2a474611', 'face', 3, 0.000320547066994337], ['00035fb0dcfbeaa8bb70ffe24d614d4dcee446b803eb4063dccf14dd2a474611', 'was', 26, 0.002778074580617587] ....
# after flat RDD to DF
+--------------------+---------+---+--------------------+
| _1| _2| _3| _4|
+--------------------+---------+---+--------------------+
|00035fb0dcfbeaa8b...| face| 3| 3.20547066994337E-4|
|00035fb0dcfbeaa8b...| was| 26|0.002778074580617587|
|00035fb0dcfbeaa8b...| how| 22|0.002350678491291...|
+--------------------+---------+---+--------------------+
その答えは、「あなたは私がやっているいくつかの変種を適用する必要があります。 – Sal