(触媒パッケージで)カスタムDeclarativeAggregateためmergeExpressionsを定義する方法。 org.apache.spark.sql.catalyst.expressions.aggregate.Averageのようなもののため mergeExpresssions方法は簡単です:CentralMomentAggアグリゲータのための私は1つの非自明なアグリゲータのためのmergeExpressions機能を決定するのにかかる一般的なアプローチを理解していない
override lazy val mergeExpressions = Seq(
/* sum = */ sum.left + sum.right,
/* count = */ count.left + count.right
)
mergeExpressionsは少し複雑です。 私がしたいのは、CentralMomentAggが起動した後にモデル化されたWeightedStddevSampアグリゲーターを作成することです。 私はそれがほとんど機能していますが、それが作り出す加重標準偏差は、私が手で計算したものから少しずれています。 mergeExpressionsメソッドの正確なロジックをどのように計算できるかわからないため、デバッグに問題があります。 以下は私のコードです。 updateExpressionsメソッドはweighted incremental algorithmに基づいているので、メソッドが正しいことを確信しています。私は自分の問題がmergeExpressionsメソッドにあると信じています。どんなヒントもありがとう。任意のハッシュ集約のため
abstract class WeightedCentralMomentAgg(child: Expression, weight: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = Seq(child, weight)
override def nullable: Boolean = true
override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
protected val wSum = AttributeReference("wSum", DoubleType, nullable = false)()
protected val mean = AttributeReference("mean", DoubleType, nullable = false)()
protected val s = AttributeReference("s", DoubleType, nullable = false)()
override val aggBufferAttributes = Seq(wSum, mean, s)
override val initialValues: Seq[Expression] = Array.fill(3)(Literal(0.0))
// See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm
override val updateExpressions: Seq[Expression] = {
val newWSum = wSum + weight
val newMean = mean + (weight/newWSum) * (child - mean)
val newS = s + weight * (child - mean) * (child - newMean)
Seq(
If(IsNull(child), wSum, newWSum),
If(IsNull(child), mean, newMean),
If(IsNull(child), s, newS)
)
}
override val mergeExpressions: Seq[Expression] = {
val wSum1 = wSum.left
val wSum2 = wSum.right
val newWSum = wSum1 + wSum2
val delta = mean.right - mean.left
val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta/newWSum)
val newMean = mean.left + wSum1/newWSum * delta // ???
val newS = s.left + s.right + wSum1 * wSum2 * delta * deltaN // ???
Seq(newWSum, newMean, newS)
}
}
// Compute the weighted sample standard deviation of a column
case class WeightedStddevSamp(child: Expression, weight: Expression)
extends WeightedCentralMomentAgg(child, weight) {
override val evaluateExpression: Expression = {
If(wSum === Literal(0.0), Literal.create(null, DoubleType),
If(wSum === Literal(1.0), Literal(Double.NaN),
Sqrt(s/wSum)))
}
override def prettyName: String = "wtd_stddev_samp"
}