2017-10-01 11 views
1

私はSpark 2.2を使用しています。 ArrayTypeで作業することについての基本的な質問があります。私は組み込みの集約関数を見つけることができませんでした。データフレーム内の配列列のインデックスで平均を計算する方法

DataFrameの場合はidvaluesの場合はArrayTypeとなります。

idでグループ化し、平均値をインデックスで計算します。

したがって、次の入力

{"id": 1, "values":[1.0, 3.0]} 
{"id": 1, "values":[3.0, 7.0]} 
{"id": 2, "values":[2.0, 4.0]} 

与えられた私たちは、この出力

{"id": 1, "values":[2.0, 5.0]} 
{"id": 2, "values":[2.0, 4.0]} 

をしたい私はUDAF、以下のCFコードを用いて溶液を作ってみました。

パフォーマンス面で(UDAFを使用しないなど)より良い方法がありますか?

val meanByIndex = new UserDefinedAggregateFunction { 
    override def inputSchema: StructType = 
     StructType(
     StructField("values", ArrayType(DoubleType)) :: Nil 
    ) 

    override def dataType: DataType = ArrayType(DoubleType) 

    override def deterministic: Boolean = true 

    override def update(buffer: MutableAggregationBuffer, row: Row): Unit = { 
     buffer.update(0, buffer.getAs[Long](0) + 1) 
     buffer.update(1, sumSeq(buffer.getAs[Seq[Double]](1), row.getAs[Seq[Double]](0)) 
    ) 
    } 

    override def bufferSchema: StructType = 
     StructType(
     StructField("size", LongType) :: 
      StructField("sum", ArrayType(DoubleType)) :: Nil 
    ) 

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 
     buffer1.update(0, buffer1.getAs[Long](0) + buffer2.getAs[Long](0)) 
     buffer1.update(1, sumSeq(buffer1.getAs[Seq[Double]](1), buffer2.getAs[Seq[Double]](1)) 
    ) 
    } 

    override def initialize(buffer: MutableAggregationBuffer): Unit = { 
     buffer.update(0, 0L) 
     buffer.update(1, Seq.empty[Double]) 
    } 

    override def evaluate(buffer: Row): Any = { 
     buffer.getAs[Seq[Double]](1).map(_/buffer.getAs[Long](0)) 
    } 

    private def sumSeq(s1: Seq[Double], s2: Seq[Double]) = { 
     if (s1.isEmpty) 
     s2 
     else { 
     s1.zip(s2).map { case (v1, v2) => v1 + v2 } 
     } 
    } 
    } 

[更新] @ user6910411の回答に関して、私は実行計画を比較しました。 UDAF

なしUDAF

SortAggregate 
+- *Sort [id#1 ASC NULLS FIRST], false, 0 
    +- Exchange hashpartitioning(id#1, 200) 
     +- SortAggregate 
     +- *Sort [id#1 ASC NULLS FIRST], false, 0 
      +- *FileScan json 

*HashAggregate 
+- Exchange hashpartitioning(id#1, 200) 
    +- *HashAggregate 
     +- *FileScan json 

結論:我々は全体のデータセットをソートする必要がないため、UDAFのないソリューションが優れています。私はUserDefinedAggregateFunctionを気にし、標準の集計を使用することはありません固定サイズの配列で

答えて

0

import org.apache.spark.sql.functions._ 

val df = Seq(
    (1, Seq(1.0, 3.0)), 
    (1, Seq(3.0, 7.0)), 
    (2, Seq(2.0, 4.0)) 
).toDF("id", "values") 

df.groupBy("id").agg(array((0 until 2) map (i => avg($"values"(i))): _*)) 
+---+-------------------------------------+ 
| id|array(avg(values[0]), avg(values[1]))| 
+---+-------------------------------------+ 
| 1|       [2.0, 5.0]| 
| 2|       [2.0, 4.0]| 
+---+-------------------------------------+ 
+0

何UDAFバージョン対パフォーマンスについて? –

関連する問題