2017-09-22 9 views
1

470のフィーチャと1,000万のトレーニングインスタンスを持つデータセットのロジスティック回帰モデルに適合しようとしています。ここに私のコードのスニペットがあります。ここで示唆したように、私はまた、spark java.lang.stackoverflow大規模なデータセットとのロジスティック回帰フィット

sc.setCheckpointDir('checkpoint/') 

をチェックポイントディレクトリを作成している

from pyspark.ml import Pipeline 
from pyspark.ml.classification import LogisticRegression 
from pyspark.ml.feature import RFormula 

formula = RFormula(formula = "label ~ .-classWeight") 


bestregLambdaVal = 0.005 
bestregAlphaVal = 0.01 

lr = LogisticRegression(maxIter=1000, regParam=bestregLambdaVal, elasticNetParam=bestregAlphaVal,weightCol="classWeight") 
pipeLineLr = Pipeline(stages = [formula, lr]) 
pipeLineFit = pipeLineLr.fit(mySparkDataFrame[featureColumnNameList + ['classWeight','label']]) 

Spark gives a StackOverflowError when training using ALS

私はここでエラーを取得し、しかし、部分的なトレースです:

File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/base.py", line 64, in fit 
    File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/pipeline.py", line 108, in _fit 
    File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/base.py", line 64, in fit 
    File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/wrapper.py", line 265, in _fit 
    File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/wrapper.py", line 262, in _fit_java 
    File "/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py", line 1133, in __call__ 
    File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py", line 63, in deco 
    File "/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py", line 319, in get_return_value 
py4j.protocol.Py4JJavaError: An error occurred while calling o383361.fit. 
: java.lang.StackOverflowError 
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1189) 
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548) 
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509) 
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432) 
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178) 
    at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348) 
    at scala.collection.immutable.List$SerializationProxy.writeObject(List.scala:468) 
    at sun.reflect.GeneratedMethodAccessor11.invoke(Unknown Source) 

また、470フィーチャの列withcolumn()を使用してデータフレームをスパークするために反復的に追加されました。

答えて

0

だから私は作っていたミスがデータフレームのチェックポイント時に、私はどうなる、ということです。

mySparkDataFrame.checkpoint(eager=True) 

権利がやっていた。

mySparkDataFrame = mySparkDataFrame.checkpoint(eager=True) 

これは、私は別の質問に基づいていますここで尋ねた(とための答えを持って)いました:

pyspark rdd isCheckPointed() is false

また、チェックポイントの前にデータフレームをpersist()し、チェックポイントの後でcount()することをお勧めします。

関連する問題