私はScala 2.0.2でUserDefinedAggregateFunction
をScalaで使用しようとしていますが、私はマッチエラーに遭遇しています。私はテストケースとして以下を作成しました。私が書いているコードは、以下のようなものです。Spark UserDefinedAggregateFunction:scala.MatchError 0.0(クラスjava.lang.Doubleの)
集計ウィンドウで値を累積しようとしています。これは単なる累積合計ではありませんが、いくつかの条件に基づいて、保持する数値を計算する必要があります。
テストケースとして、償却テーブルを作成しました。そこでは、毎月の期首残高と口座残高を計算する必要があります。
データは次のようになります。
+------+--------+------------+---------+
|Period| Capital|InterestRate|Repayment|
+------+--------+------------+---------+
|201601| 0.00 | 0.10 | 0.00 |
|201602|1000.00 | 0.00 | 0.00 |
|201603|2000.00 | 0.10 | 0.00 |
|201604| 0.00 | 0.10 | -200.00 |
|201605| 0.00 | 0.10 | -200.00 |
|201606| 0.00 | 0.10 | -200.00 |
|201607| 0.00 | 0.10 | -200.00 |
|201608| 0.00 | 0.00 | -200.00 |
|201609| 0.00 | 0.10 | -200.00 |
|201610| 0.00 | 0.10 | -200.00 |
|201611| 0.00 | 0.10 | -200.00 |
|201612| 0.00 | 0.10 | -200.00 |
+------+--------+------------+---------+
私はCSVを適切にフォーマットすることができなかったが、私はここに要旨にそれを追加しました:https://gist.github.com/nevi-me/8b2362a5365e73af947fc13bb5836adc。
私は、Opening
とClosing
残高を計算し、集計からClosing
残高を返しています。
package me.nevi
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.types.{StructType, DoubleType, DataType}
object AggregationTest {
object amortisedClosingBalance extends UserDefinedAggregateFunction {
override def inputSchema: StructType = new StructType().add("Capital", DoubleType).add("InterestRate", DoubleType).add("Repayment", DoubleType)
override def bufferSchema: StructType = new StructType().add("Opening", DoubleType).add("Closing", DoubleType)
override def dataType: DataType = new StructType().add("Closing", DoubleType)
override def deterministic: Boolean = true
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0.0)
buffer.update(1, 0.0)
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
println(buffer.get(0))
println(buffer.get(1))
buffer.update(0, buffer.getDouble(1))
// (opening + capital) * interestrate - repayment
buffer.update(1, (buffer.getDouble(0) + input.getDouble(0)) * input.getDouble(1) + input.getDouble(2))
} else {
// if first record?
buffer.update(0, input.getDouble(0))
buffer.update(1, input.getDouble(0))
}
}
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getDouble(0))
buffer1.update(1, buffer1.getDouble(1))
}
override def evaluate(buffer: Row): Any = {
buffer.getDouble(1)
}
}
def main(args: Array[String]): Unit = {
System.setProperty("hadoop.home.dir", "C:/spark")
System.setProperty("spark.sql.warehouse.dir", "file:///tmp/spark-warehouse")
val spark: SparkSession = SparkSession.builder()
.master("local[*]")
.appName("Aggregation Test")
.getOrCreate()
import spark.implicits._
val df = spark.read.option("header", true).csv("file:///d:/interest_calc.csv")
df.show()
val windowSpec = Window.orderBy(df.col("Period"))
val calc = df.withColumn("Closing", amortisedClosingBalance($"Capital", $"InterestRate", $"Repayment").over(windowSpec))
calc.show()
}
}
Scalaは私は例外を取得:
scala.MatchError: 0.0 (of class java.lang.Double)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:256)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:251)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:403)
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.eval(udaf.scala:440)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
at org.apache.spark.sql.execution.AggregateProcessor.evaluate(WindowExec.scala:1029)
at org.apache.spark.sql.execution.UnboundedPrecedingWindowFunctionFrame.write(WindowExec.scala:822)
at org.apache.spark.sql.execution.WindowExec$$anonfun$15$$anon$1.next(WindowExec.scala:398)
at org.apache.spark.sql.execution.WindowExec$$anonfun$15$$anon$1.next(WindowExec.scala:289)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:803)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:283)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)
at org.apache.spark.scheduler.Task.run(Task.scala:86)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
誰もが私が間違っているのか知っていますか?最初はSpark 2.0.0を使用しましたが、UDTFの同様の問題を抱えて他の誰かに出くわしましたが、アップグレード後に2.0.1にアップグレードすることを提案しました。私の問題は依然として続きます。
ソリューション:
問題興味がある人のためには受け入れ答えあたりとしての私のスキーマを使用していました。以下は正しく計算されたスニペットです。下記の結果で
package me.nevi
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
object AggregationTest {
object amortisedClosingBalance extends UserDefinedAggregateFunction {
override def inputSchema: StructType = new StructType().add("Capital", DoubleType).add("InterestRate", DoubleType).add("Repayment", DoubleType)
override def bufferSchema: StructType = new StructType().add("Opening", DoubleType).add("Closing", DoubleType)
override def dataType: DataType = new StructType().add("Opening", DoubleType).add("Closing", DoubleType)
override def deterministic: Boolean = true
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0.0)
buffer.update(1, 0.0)
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
println(buffer.get(0))
println(buffer.get(1))
buffer.update(0, buffer.getDouble(1))
// (opening + capital) * interestrate - repayment
buffer.update(1, input.getDouble(0)
+ buffer.getDouble(0) + input.getDouble(2) + (buffer.getDouble(0) + input.getDouble(0)) * (input.getDouble(1)/12))
} else {
// if first record?
buffer.update(0, input.getDouble(0))
buffer.update(1, input.getDouble(0))
}
}
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getDouble(0))
buffer1.update(1, buffer1.getDouble(1))
}
override def evaluate(buffer: Row): Any = {
Row(buffer.getDouble(0), buffer.getDouble(1))
}
}
def main(args: Array[String]): Unit = {
System.setProperty("hadoop.home.dir", "C:/spark")
System.setProperty("spark.sql.warehouse.dir", "file:///tmp/spark-warehouse")
val spark: SparkSession = SparkSession.builder()
.master("local[*]")
.appName("Aggregation Test")
.getOrCreate()
import spark.implicits._
val df = spark.read.option("header", true).csv("file:///d:/interest_calc.csv")
df.show()
val windowSpec = Window.orderBy(df.col("Period").asc)
var calc = df.withColumn("Calcs", amortisedClosingBalance($"Capital", $"InterestRate", $"Repayment").over(windowSpec))
calc = calc.withColumn("Opening", round($"Calcs".getField("Opening"), 2)).withColumn("Closing", round($"Calcs".getField("Closing"),2))
.drop("Calcs")
calc.show()
}
}
:
+------+--------+------------+---------+-------+-------+
|Period| Capital|InterestRate|Repayment|Opening|Closing|
+------+--------+------------+---------+-------+-------+
|201601| 0.00 | 0.10 | 0.00 | 0.0| 0.0|
|201602|1000.00 | 0.00 | 0.00 | 0.0| 1000.0|
|201603|2000.00 | 0.10 | 0.00 | 1000.0| 3025.0|
|201604| 0.00 | 0.10 | -200.00 | 3025.0|2850.21|
|201605| 0.00 | 0.10 | -200.00 |2850.21|2673.96|
|201606| 0.00 | 0.10 | -200.00 |2673.96|2496.24|
|201607| 0.00 | 0.10 | -200.00 |2496.24|2317.05|
|201608| 0.00 | 0.00 | -200.00 |2317.05|2117.05|
|201609| 0.00 | 0.10 | -200.00 |2117.05|1934.69|
|201610| 0.00 | 0.10 | -200.00 |1934.69|1750.81|
|201611| 0.00 | 0.10 | -200.00 |1750.81| 1565.4|
|201612| 0.00 | 0.10 | -200.00 | 1565.4|1378.44|
+------+--------+------------+---------+-------+-------+
ありがとう、それは私を助けてくれました。私は将来的に自分のスキーマをチェックします。後者はネストされた構造になっています。これは、同時に開閉のバランスを戻すことができるためです。私がやったことで私の質問を更新します。 –