2017-12-20 11 views
1

リストをUDFに渡す必要があります。リストは距離のスコア/カテゴリを決定します。今のところ、すべての距離を4番目のスコアにするのは難しいです。PySpark - UDFへのパラメータとしてのリストを渡す

a= spark.createDataFrame([("A", 20), ("B", 30), ("D", 80)],["Letter", "distances"]) 

from pyspark.sql.functions import udf 
def cate(label, feature_list): 
    if feature_list == 0: 
     return label[4] 
label_list = ["Great", "Good", "OK", "Please Move", "Dead"] 
udf_score=udf(cate, StringType()) 
a.withColumn("category", udf_score(label_list,a["distances"])).show(10) 

このような場合、このエラーが発生します。

Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace: 
py4j.Py4JException: Method col([class java.util.ArrayList]) does not exist 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339) 
    at py4j.Gateway.invoke(Gateway.java:274) 
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) 
    at py4j.commands.CallCommand.execute(CallCommand.java:79) 
    at py4j.GatewayConnection.run(GatewayConnection.java:214) 
    at java.lang.Thread.run(Thread.java:745) 

答えて

2

を希望、このことができます!

from pyspark.sql.functions import udf, col 

#sample data 
a= sqlContext.createDataFrame([("A", 20), ("B", 30), ("D", 80)],["Letter", "distances"]) 
label_list = ["Great", "Good", "OK", "Please Move", "Dead"] 

def cate(label, feature_list): 
    if feature_list == 0: 
     return label[4] 
    else: #you may need to add 'else' condition as well otherwise 'null' will be added in this case 
     return 'I am not sure!' 

def udf_score(label_list): 
    return udf(lambda l: cate(l, label_list)) 
a.withColumn("category", udf_score(label_list)(col("distances"))).show() 

出力は次のとおりです。

+------+---------+--------------+ 
|Letter|distances|  category| 
+------+---------+--------------+ 
|  A|  20|I am not sure!| 
|  B|  30|I am not sure!| 
|  D|  80|I am not sure!| 
+------+---------+--------------+ 
0

データフレームの呼び出しで唯一の引数は、関数が行動したい列の名前になるように、機能をカリー化してみてください。

udf_score=udf(lambda x: cate(label_list,x), StringType()) 
a.withColumn("category", udf_score("distances")).show(10) 
関連する問題