2017-03-09 7 views
0

(触媒パッケージで)カスタムDeclarativeAggregateためmergeExpressionsを定義する方法。 org.apache.spark.sql.catalyst.expressions.aggregate.Averageのようなもののため mergeExpresssions方法は簡単です:CentralMomentAggアグリゲータのための私は1つの非自明なアグリゲータのためのmergeExpressions機能を決定するのにかかる一般的なアプローチを理解していない

override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right, 
    /* count = */ count.left + count.right 
) 

mergeExpressionsは少し複雑です。 私がしたいのは、CentralMomentAggが起動した後にモデル化されたWeightedStddevSampアグリゲーターを作成することです。 私はそれがほとんど機能していますが、それが作り出す加重標準偏差は、私が手で計算したものから少しずれています。 mergeExpressionsメソッドの正確なロジックをどのように計算できるかわからないため、デバッグに問題があります。 以下は私のコードです。 updateExpressionsメソッドはweighted incremental algorithmに基づいているので、メソッドが正しいことを確信しています。私は自分の問題がmergeExpressionsメソッドにあると信じています。どんなヒントもありがとう。任意のハッシュ集約のため

abstract class WeightedCentralMomentAgg(child: Expression, weight: Expression) extends DeclarativeAggregate { 

    override def children: Seq[Expression] = Seq(child, weight) 
    override def nullable: Boolean = true 
    override def dataType: DataType = DoubleType 
    override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) 

    protected val wSum = AttributeReference("wSum", DoubleType, nullable = false)() 
    protected val mean = AttributeReference("mean", DoubleType, nullable = false)() 
    protected val s = AttributeReference("s", DoubleType, nullable = false)() 
    override val aggBufferAttributes = Seq(wSum, mean, s) 
    override val initialValues: Seq[Expression] = Array.fill(3)(Literal(0.0)) 

    // See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm 
    override val updateExpressions: Seq[Expression] = { 

    val newWSum = wSum + weight 
    val newMean = mean + (weight/newWSum) * (child - mean) 
    val newS = s + weight * (child - mean) * (child - newMean) 

    Seq(
     If(IsNull(child), wSum, newWSum), 
     If(IsNull(child), mean, newMean), 
     If(IsNull(child), s, newS) 
    ) 
    } 

    override val mergeExpressions: Seq[Expression] = { 
    val wSum1 = wSum.left 
    val wSum2 = wSum.right 
    val newWSum = wSum1 + wSum2 
    val delta = mean.right - mean.left 
    val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta/newWSum) 
    val newMean = mean.left + wSum1/newWSum * delta    // ??? 
    val newS = s.left + s.right + wSum1 * wSum2 * delta * deltaN  // ??? 
    Seq(newWSum, newMean, newS) 
    } 
} 


// Compute the weighted sample standard deviation of a column 
case class WeightedStddevSamp(child: Expression, weight: Expression) 
    extends WeightedCentralMomentAgg(child, weight) { 

    override val evaluateExpression: Expression = { 
    If(wSum === Literal(0.0), Literal.create(null, DoubleType), 
     If(wSum === Literal(1.0), Literal(Double.NaN), 
     Sqrt(s/wSum))) 
    } 

    override def prettyName: String = "wtd_stddev_samp" 
} 

答えて

0

を使用して、バッファからの最終的な結果を得る私は、加重標準偏差をmergeExpressions関数を作成する方法を発見しました。私は実際にそれが正しいと思ったが、evaluateExpressionのサンプル分散計算ではなく、母集団の分散を使用していた。以下に示す実装は上記と同じ結果を示しますが、理解しやすくなります。ここで

override val mergeExpressions: Seq[Expression] = { 
    val newN = n.left + n.right 
    val wSum1 = wSum.left 
    val wSum2 = wSum.right 
    val newWSum = wSum1 + wSum2 
    val delta = mean.right - mean.left 

    val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta/newWSum) 
    val newMean = mean.left + deltaN * wSum2 
    val newS = (((wSum1 * s.left) + (wSum2 * s.right))/newWSum) + (wSum1 * wSum2 * deltaN * deltaN) 

    Seq(newN, newWSum, newMean, newS) 
} 

は、いくつかの参照 https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
  • ​​
  • http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
  • https://blog.cordiner.net/2010/06/16/calculating-variance-and-mean-with-mapreduce-python/

    • ある
    (この最後のものは私がmergeExpressions機能のために必要な手がかりを与えました)

    Daviesの記事はアプローチの概要を示していますが、多くの重要でないアグリゲータにとっては、mergeExpressions関数は非常に複雑で、正確で効率的な解を決定する高度な数学を必要とすると思います。幸いにも、この場合、私はそれを練習した人を見つけました。

    このソリューションは、私が手で働くものと一致します。母集団分散の代わりに標本分散が必要な場合は、evaluateExpressionを少し変更する必要があることに注意することが重要です(s /((n-1)* wSum/n))。

  • 2

    、それは4つのステップに分かれています。)

    1キーのバッファを更新し、パーティション内のバッファ(wSum、平均、S)

    2)を初期化しますシャッフルした後、すべての入力(入力のそれぞれについてupdateExpressionを呼び出す)

    3)与えられ、mergeExpressionを使用して、同じキーに対するすべてのバッファをマージします。 wSum.leftは左のバッファにwSumを意味し、wSum.rightは

    4他のバッファにwSumを意味する)のValueExpression

    関連する問題