2016-11-21 3 views
2

私はScala 2.0.2でUserDefinedAggregateFunctionをScalaで使用しようとしていますが、私はマッチエラーに遭遇しています。私はテストケースとして以下を作成しました。私が書いているコードは、以下のようなものです。Spark UserDefinedAggregateFunction:scala.MatchError 0.0(クラスjava.lang.Doubleの)

集計ウィンドウで値を累積しようとしています。これは単なる累積合計ではありませんが、いくつかの条件に基づいて、保持する数値を計算する必要があります。

テストケースとして、償却テーブルを作成しました。そこでは、毎月の期首残高と口座残高を計算する必要があります。

データは次のようになります。

+------+--------+------------+---------+ 
|Period| Capital|InterestRate|Repayment| 
+------+--------+------------+---------+ 
|201601| 0.00 |  0.10 | 0.00 | 
|201602|1000.00 |  0.00 | 0.00 | 
|201603|2000.00 |  0.10 | 0.00 | 
|201604| 0.00 |  0.10 | -200.00 | 
|201605| 0.00 |  0.10 | -200.00 | 
|201606| 0.00 |  0.10 | -200.00 | 
|201607| 0.00 |  0.10 | -200.00 | 
|201608| 0.00 |  0.00 | -200.00 | 
|201609| 0.00 |  0.10 | -200.00 | 
|201610| 0.00 |  0.10 | -200.00 | 
|201611| 0.00 |  0.10 | -200.00 | 
|201612| 0.00 |  0.10 | -200.00 | 
+------+--------+------------+---------+ 

私はCSVを適切にフォーマットすることができなかったが、私はここに要旨にそれを追加しました:https://gist.github.com/nevi-me/8b2362a5365e73af947fc13bb5836adc

私は、OpeningClosing残高を計算し、集計からClosing残高を返しています。

package me.nevi 

import org.apache.spark.sql._ 
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} 
import org.apache.spark.sql.types.{StructType, DoubleType, DataType} 

object AggregationTest { 

    object amortisedClosingBalance extends UserDefinedAggregateFunction { 
    override def inputSchema: StructType = new StructType().add("Capital", DoubleType).add("InterestRate", DoubleType).add("Repayment", DoubleType) 

    override def bufferSchema: StructType = new StructType().add("Opening", DoubleType).add("Closing", DoubleType) 

    override def dataType: DataType = new StructType().add("Closing", DoubleType) 

    override def deterministic: Boolean = true 

    override def initialize(buffer: MutableAggregationBuffer): Unit = { 
     buffer.update(0, 0.0) 
     buffer.update(1, 0.0) 
    } 

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { 
     if (!input.isNullAt(0)) { 
     println(buffer.get(0)) 
     println(buffer.get(1)) 
     buffer.update(0, buffer.getDouble(1)) 
     // (opening + capital) * interestrate - repayment 
     buffer.update(1, (buffer.getDouble(0) + input.getDouble(0)) * input.getDouble(1) + input.getDouble(2)) 
     } else { 
     // if first record? 
     buffer.update(0, input.getDouble(0)) 
     buffer.update(1, input.getDouble(0)) 
     } 
    } 

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 
     buffer1.update(0, buffer1.getDouble(0)) 
     buffer1.update(1, buffer1.getDouble(1)) 
    } 

    override def evaluate(buffer: Row): Any = { 
     buffer.getDouble(1) 
    } 
    } 

    def main(args: Array[String]): Unit = { 
    System.setProperty("hadoop.home.dir", "C:/spark") 
    System.setProperty("spark.sql.warehouse.dir", "file:///tmp/spark-warehouse") 

    val spark: SparkSession = SparkSession.builder() 
     .master("local[*]") 
     .appName("Aggregation Test") 
     .getOrCreate() 

    import spark.implicits._ 

    val df = spark.read.option("header", true).csv("file:///d:/interest_calc.csv") 

    df.show() 

    val windowSpec = Window.orderBy(df.col("Period")) 

    val calc = df.withColumn("Closing", amortisedClosingBalance($"Capital", $"InterestRate", $"Repayment").over(windowSpec)) 

    calc.show() 

    } 
} 

Scalaは私は例外を取得:

scala.MatchError: 0.0 (of class java.lang.Double) 
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:256) 
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:251) 
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103) 
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:403) 
    at org.apache.spark.sql.execution.aggregate.ScalaUDAF.eval(udaf.scala:440) 
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source) 
    at org.apache.spark.sql.execution.AggregateProcessor.evaluate(WindowExec.scala:1029) 
    at org.apache.spark.sql.execution.UnboundedPrecedingWindowFunctionFrame.write(WindowExec.scala:822) 
    at org.apache.spark.sql.execution.WindowExec$$anonfun$15$$anon$1.next(WindowExec.scala:398) 
    at org.apache.spark.sql.execution.WindowExec$$anonfun$15$$anon$1.next(WindowExec.scala:289) 
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source) 
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) 
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370) 
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246) 
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240) 
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803) 
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803) 
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) 
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319) 
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:283) 
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70) 
    at org.apache.spark.scheduler.Task.run(Task.scala:86) 
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274) 
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) 
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) 
    at java.lang.Thread.run(Thread.java:745) 

誰もが私が間違っているのか知っていますか?最初はSpark 2.0.0を使用しましたが、UDTFの同様の問題を抱えて他の誰かに出くわしましたが、アップグレード後に2.0.1にアップグレードすることを提案しました。私の問題は依然として続きます。


ソリューション:

問題興味がある人のためには受け入れ答えあたりとしての私のスキーマを使用していました。以下は正しく計算されたスニペットです。下記の結果で

package me.nevi 

import org.apache.spark.sql._ 
import org.apache.spark.sql.functions._ 
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} 
import org.apache.spark.sql.types.{DataType, DoubleType, StructType} 

object AggregationTest { 

    object amortisedClosingBalance extends UserDefinedAggregateFunction { 
    override def inputSchema: StructType = new StructType().add("Capital", DoubleType).add("InterestRate", DoubleType).add("Repayment", DoubleType) 

    override def bufferSchema: StructType = new StructType().add("Opening", DoubleType).add("Closing", DoubleType) 

    override def dataType: DataType = new StructType().add("Opening", DoubleType).add("Closing", DoubleType) 

    override def deterministic: Boolean = true 

    override def initialize(buffer: MutableAggregationBuffer): Unit = { 
     buffer.update(0, 0.0) 
     buffer.update(1, 0.0) 
    } 

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { 
     if (!input.isNullAt(0)) { 
     println(buffer.get(0)) 
     println(buffer.get(1)) 
     buffer.update(0, buffer.getDouble(1)) 
     // (opening + capital) * interestrate - repayment 
     buffer.update(1, input.getDouble(0) 
      + buffer.getDouble(0) + input.getDouble(2) + (buffer.getDouble(0) + input.getDouble(0)) * (input.getDouble(1)/12)) 
     } else { 
     // if first record? 
     buffer.update(0, input.getDouble(0)) 
     buffer.update(1, input.getDouble(0)) 
     } 
    } 

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 
     buffer1.update(0, buffer1.getDouble(0)) 
     buffer1.update(1, buffer1.getDouble(1)) 
    } 

    override def evaluate(buffer: Row): Any = { 
     Row(buffer.getDouble(0), buffer.getDouble(1)) 
    } 
    } 

    def main(args: Array[String]): Unit = { 
    System.setProperty("hadoop.home.dir", "C:/spark") 
    System.setProperty("spark.sql.warehouse.dir", "file:///tmp/spark-warehouse") 

    val spark: SparkSession = SparkSession.builder() 
     .master("local[*]") 
     .appName("Aggregation Test") 
     .getOrCreate() 

    import spark.implicits._ 

    val df = spark.read.option("header", true).csv("file:///d:/interest_calc.csv") 

    df.show() 

    val windowSpec = Window.orderBy(df.col("Period").asc) 

    var calc = df.withColumn("Calcs", amortisedClosingBalance($"Capital", $"InterestRate", $"Repayment").over(windowSpec)) 
    calc = calc.withColumn("Opening", round($"Calcs".getField("Opening"), 2)).withColumn("Closing", round($"Calcs".getField("Closing"),2)) 
     .drop("Calcs") 

    calc.show() 

    } 
} 

+------+--------+------------+---------+-------+-------+ 
|Period| Capital|InterestRate|Repayment|Opening|Closing| 
+------+--------+------------+---------+-------+-------+ 
|201601| 0.00 |  0.10 | 0.00 | 0.0| 0.0| 
|201602|1000.00 |  0.00 | 0.00 | 0.0| 1000.0| 
|201603|2000.00 |  0.10 | 0.00 | 1000.0| 3025.0| 
|201604| 0.00 |  0.10 | -200.00 | 3025.0|2850.21| 
|201605| 0.00 |  0.10 | -200.00 |2850.21|2673.96| 
|201606| 0.00 |  0.10 | -200.00 |2673.96|2496.24| 
|201607| 0.00 |  0.10 | -200.00 |2496.24|2317.05| 
|201608| 0.00 |  0.00 | -200.00 |2317.05|2117.05| 
|201609| 0.00 |  0.10 | -200.00 |2117.05|1934.69| 
|201610| 0.00 |  0.10 | -200.00 |1934.69|1750.81| 
|201611| 0.00 |  0.10 | -200.00 |1750.81| 1565.4| 
|201612| 0.00 |  0.10 | -200.00 | 1565.4|1378.44| 
+------+--------+------------+---------+-------+-------+ 

答えて

3

あなたが誤っdataType定義に例外を取得します。実際にあなたがスカラーを返すときにそれを:

StructType(StructField(Closing,DoubleType,true)) 

と宣言します。それはおそらくある

|-- Closing: struct (nullable = true) 
| |-- Closing: double (nullable = true) 

:それはのように定義する必要があります。

override def dataType: DataType = DoubleType 

か、evaluteを再定義する必要があり、例えばに:

override def evaluate(buffer: Row): Any = { 
    Row(buffer.getDouble(1)) 
} 

後者は、ネストされた列を返します。あなたが探しているものではありません。

+0

ありがとう、それは私を助けてくれました。私は将来的に自分のスキーマをチェックします。後者はネストされた構造になっています。これは、同時に開閉のバランスを戻すことができるためです。私がやったことで私の質問を更新します。 –

関連する問題