2016-03-23 9 views
5

内部計算にJavaオブジェクトを使用する、PySpark（Python）から使えるUDFを作成する必要があります。つまり、JavaでUDFを実装し、それをpysparkから呼び出したいのです。

単純なPythonの関数であれば、次のように書くところです:

def f(x):
    """Return the constant 7, ignoring ``x`` (placeholder for real UDF logic)."""
    result = 7
    return result
# Wrap the plain Python function `f` as a Spark SQL UDF that returns integers.
fudf = pyspark.sql.functions.udf(f, returnType=pyspark.sql.types.IntegerType())

そして、次のように呼び出します:

# Apply the UDF to every row. DataFrame.show() prints the rows and returns
# None, so keep the resulting DataFrame in df2 and call show() separately
# (assigning the result of .show() would leave df2 set to None).
df = sqlContext.range(0, 5)
df2 = df.withColumn("a", fudf(df.id))
df2.show()

しかし、必要な関数の実装はPythonではなくJavaにあります。Pythonから同じように呼び出せるよう、何らかの方法でそれをラップする必要があります。

最初の試みは、Javaオブジェクトを実装し、それをPython（pyspark）側でラップしてUDFに変換するというものでした。これはシリアライズエラーで失敗しました。

Javaコード:

package com.test1.test2; 

/**
 * Minimal example object used from PySpark via py4j: it remembers one value
 * and hands it back from {@link #do_something(Integer)}.
 */
public class TestClass1 {

    /** Value supplied to the constructor and returned by do_something. */
    Integer internalVal;

    /**
     * Creates the object and stores the value to return later.
     *
     * @param val1 the value that do_something will return
     */
    public TestClass1(Integer val1) {
        this.internalVal = val1;
    }

    /**
     * Returns the stored value. The argument is ignored in this example.
     *
     * @param val unused input
     * @return the value given to the constructor
     */
    public Integer do_something(Integer val) {
        final Integer stored = this.internalVal;
        return stored;
    }
}

pysparkコード:

from py4j.java_gateway import java_import
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
java_import(sc._gateway.jvm, "com.test1.test2.TestClass1")
a = sc._gateway.jvm.com.test1.test2.TestClass1(7)
audf = udf(a,IntegerType())
# NOTE: this is the attempt that fails (see the traceback that follows).
# `a` is a py4j JavaObject -- a Python-side proxy for an object created in the
# JVM through the gateway -- not a Python function. `udf` pickles what it is
# given (wrapped in a function) with cloudpickle. Pickling the proxy calls
# `__getnewargs__`, which py4j forwards to the JVM object, where no such
# method exists, so it raises Py4JError. Making the Java class Serializable
# does not change this, because the failure happens on the Python pickling side.

エラー:

--------------------------------------------------------------------------- 
Py4JError         Traceback (most recent call last) 
<ipython-input-2-9756772ab14f> in <module>() 
     4 java_import(sc._gateway.jvm, "com.test1.test2.TestClass1") 
     5 a = sc._gateway.jvm.com.test1.test2.TestClass1(7) 
----> 6 audf = udf(a,IntegerType()) 

/usr/local/spark/python/pyspark/sql/functions.py in udf(f, returnType) 
    1595  [Row(slen=5), Row(slen=3)] 
    1596  """ 
-> 1597  return UserDefinedFunction(f, returnType) 
    1598 
    1599 blacklist = ['map', 'since', 'ignore_unicode_prefix'] 

/usr/local/spark/python/pyspark/sql/functions.py in __init__(self, func, returnType, name) 
    1556   self.returnType = returnType 
    1557   self._broadcast = None 
-> 1558   self._judf = self._create_judf(name) 
    1559 
    1560  def _create_judf(self, name): 

/usr/local/spark/python/pyspark/sql/functions.py in _create_judf(self, name) 
    1565   command = (func, None, ser, ser) 
    1566   sc = SparkContext.getOrCreate() 
-> 1567   pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) 
    1568   ctx = SQLContext.getOrCreate(sc) 
    1569   jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) 

/usr/local/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command, obj) 
    2297  # the serialized command will be compressed by broadcast 
    2298  ser = CloudPickleSerializer() 
-> 2299  pickled_command = ser.dumps(command) 
    2300  if len(pickled_command) > (1 << 20): # 1M 
    2301   # The broadcast will have same life cycle as created PythonRDD 

/usr/local/spark/python/pyspark/serializers.py in dumps(self, obj) 
    426 
    427  def dumps(self, obj): 
--> 428   return cloudpickle.dumps(obj, 2) 
    429 
    430 

/usr/local/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol) 
    644 
    645  cp = CloudPickler(file,protocol) 
--> 646  cp.dump(obj) 
    647 
    648  return file.getvalue() 

/usr/local/spark/python/pyspark/cloudpickle.py in dump(self, obj) 
    105   self.inject_addons() 
    106   try: 
--> 107    return Pickler.dump(self, obj) 
    108   except RuntimeError as e: 
    109    if 'recursion' in e.args[0]: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in dump(self, obj) 
    222   if self.proto >= 2: 
    223    self.write(PROTO + chr(self.proto)) 
--> 224   self.save(obj) 
    225   self.write(STOP) 
    226 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj) 
    566   write(MARK) 
    567   for element in obj: 
--> 568    save(element) 
    569 
    570   if id(obj) in memo: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name) 
    191   if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None: 
    192    #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) 
--> 193    self.save_function_tuple(obj) 
    194    return 
    195   else: 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func) 
    234   # create a skeleton function object and memoize it 
    235   save(_make_skel_func) 
--> 236   save((code, closure, base_globals)) 
    237   write(pickle.REDUCE) 
    238   self.memoize(func) 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj) 
    552   if n <= 3 and proto >= 2: 
    553    for element in obj: 
--> 554     save(element) 
    555    # Subtle. Same as in the big comment below. 
    556    if id(obj) in memo: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_list(self, obj) 
    604 
    605   self.memoize(obj) 
--> 606   self._batch_appends(iter(obj)) 
    607 
    608  dispatch[ListType] = save_list 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in _batch_appends(self, items) 
    637     write(MARK) 
    638     for x in tmp: 
--> 639      save(x) 
    640     write(APPENDS) 
    641    elif n: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    304    reduce = getattr(obj, "__reduce_ex__", None) 
    305    if reduce: 
--> 306     rv = reduce(self.proto) 
    307    else: 
    308     reduce = getattr(obj, "__reduce__", None) 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args) 
    811   answer = self.gateway_client.send_command(command) 
    812   return_value = get_return_value(
--> 813    answer, self.gateway_client, self.target_id, self.name) 
    814 
    815   for temp_arg in temp_args: 

/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw) 
    43  def deco(*a, **kw): 
    44   try: 
---> 45    return f(*a, **kw) 
    46   except py4j.protocol.Py4JJavaError as e: 
    47    s = e.java_exception.toString() 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name) 
    310     raise Py4JError(
    311      "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n". 
--> 312      format(target_id, ".", name, value)) 
    313   else: 
    314    raise Py4JError(

Py4JError: An error occurred while calling o18.__getnewargs__. Trace: 
py4j.Py4JException: Method __getnewargs__([]) does not exist 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:335) 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:344) 
    at py4j.Gateway.invoke(Gateway.java:252) 
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133) 
    at py4j.commands.CallCommand.execute(CallCommand.java:79) 
    at py4j.GatewayConnection.run(GatewayConnection.java:209) 
    at java.lang.Thread.run(Thread.java:745) 

EDIT: Javaクラスをシリアライズ可能（Serializable）にしてもみましたが、効果はありませんでした。

2つ目の試みは、最初からJavaでUDFを定義するというものでしたが、それを正しくラップする方法がわからず、失敗しました:

Javaコード:

package com.test1.test2;

import org.apache.spark.sql.api.java.UDF1; 

/**
 * One-argument Spark SQL UDF (via the Java UDF1 interface) that ignores its
 * input and always yields the value given to the constructor.
 */
public class TestClassUdf implements UDF1<Integer, Integer> {

    /** Constant produced for every input row. */
    Integer retval;

    /**
     * @param val the value every call will return
     */
    public TestClassUdf(Integer val) {
        this.retval = val;
    }

    /**
     * @param arg0 the row's input value (unused in this example)
     * @return the constant supplied at construction
     */
    @Override
    public Integer call(Integer arg0) throws Exception {
        final Integer result = this.retval;
        return result;
    }
}

しかし、これをどう使えばよいのでしょうか？ 次のように試しました:

from py4j.java_gateway import java_import
java_import(sc._gateway.jvm, "com.test1.test2.TestClassUdf")
a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
dfint = sqlContext.range(0,15)
df = dfint.withColumn("a",a(dfint.id))
# NOTE: fails with "TypeError: 'JavaObject' object is not callable" (see the
# traceback below). `a` is a py4j proxy for the Java TestClassUdf instance.
# The proxy itself is not a Python callable, and nothing in this snippet
# registers the Java UDF with Spark SQL, so it cannot be applied to a Column
# this way.

しかし、次のエラーになります:

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-5-514811090b5f> in <module>() 
     3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7) 
     4 dfint = sqlContext.range(0,15) 
----> 5 df = dfint.withColumn("a",a(dfint.id)) 

TypeError: 'JavaObject' object is not callable 

また、a の代わりに a.call を使ってみました:

df = dfint.withColumn("a",a.call(dfint.id))
# NOTE: this also fails. `a.call(...)` invokes the Java method immediately
# through py4j instead of building a per-row expression. py4j therefore tries
# to convert the pyspark Column argument into a Java object. Per the traceback
# below, its map converter calls `.keys()` on the Column, which raises
# "TypeError: 'Column' object is not callable".

しかし、次のエラーになりました:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-...> in <module>()
      3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
      4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a.call(dfint.id))

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args) 
    796  def __call__(self, *args): 
    797   if self.converters is not None and len(self.converters) > 0: 
--> 798    (new_args, temp_args) = self._get_args(args) 
    799   else: 
    800    new_args = args 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in _get_args(self, args) 
    783     for converter in self.gateway_client.converters: 
    784      if converter.can_convert(arg): 
--> 785       temp_arg = converter.convert(arg, self.gateway_client) 
    786       temp_args.append(temp_arg) 
    787       new_args.append(temp_arg) 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_collections.py in convert(self, object, gateway_client) 
    510   HashMap = JavaClass("java.util.HashMap", gateway_client) 
    511   java_map = HashMap() 
--> 512   for key in object.keys(): 
    513    java_map[key] = object[key] 
    514   return java_map 

TypeError: 'Column' object is not callable 

どなたか助けていただけると幸いです。

回答

3

UDAFに関するあなた自身の別の質問と回答（another question (and answer) of your own）を参考にして、これを動かすことができました。

Spark は Scala の FunctionN をラップする手段を提供しているので、Java関数を Scala でラップすれば使うことができます。Javaメソッドは static であるか、Serializable を実装した（implements Serializable）クラスのメソッドである必要があります。Scala側のコードは次のとおりです（PySparkでの使い方はその後に示します）:

package com.example 

import org.apache.spark.sql.UserDefinedFunction 
import org.apache.spark.sql.functions.udf 

/** Serializable holder that hands a Spark UDF wrapping a Java method to PySpark (fetched via py4j). */
class MyUdf extends Serializable {
    /** Builds a zero-argument UDF whose value is `MyJavaClass.MyJavaMethod()`. */
    def getUdf: UserDefinedFunction = {
        val wrapped = () => MyJavaClass.MyJavaMethod()
        udf(wrapped)
    }
}

PySparkでの使用法（あなたの別の質問と回答にあるUDAFの場合と同様）:

def my_udf(*cols):
    """Return a Column that applies the JVM-side UDF built by ``com.example.MyUdf``.

    The UDF object is created and applied entirely in the JVM through py4j,
    so Python never has to pickle a Java object.

    Args:
        *cols: column names or ``Column`` objects to pass to the UDF as its
            arguments. With no arguments (the original usage), the
            zero-argument UDF returned by ``MyUdf.getUdf`` is applied. The
            number of columns must match the arity of the Scala function.

    Returns:
        A ``pyspark.sql.column.Column`` wrapping the resulting JVM column.
    """
    from pyspark import SparkContext
    from pyspark.sql.column import Column, _to_java_column, _to_seq

    # Use the active SparkContext instead of a global ``sc``, which only
    # exists in the pyspark shell / notebooks.
    sc = SparkContext.getOrCreate()
    pcls = "com.example.MyUdf"
    # Load the class through the thread's context class loader (presumably so
    # classes from user jars are visible), instantiate it, and take the
    # ``apply`` method of the Scala UserDefinedFunction it returns.
    jc = sc._jvm.java.lang.Thread.currentThread() \
        .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply
    # ``apply`` takes Scala varargs, so the Java columns go in as a Seq.
    return Column(jc(_to_seq(sc, list(cols), _to_java_column)))

# Three-row example DataFrame with a single string column `c1`; then add a
# column computed by the Java-backed UDF.
rdd1 = sc.parallelize(
    [
        {'c1': 'a'},
        {'c1': 'b'},
        {'c1': 'c'},
    ]
)
df1 = rdd1.toDF()
df2 = df1.withColumn('mycol', my_udf())

あなたの別の質問と回答にあるUDAFの場合と同様に、`return Column(jc(_to_seq(sc, ["col1", "col2"], _to_java_column)))` のようにして列を渡すこともできます（その場合、Scala側の関数もその数の引数を取る必要があります）。