2016-10-25 9 views

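Aggregating the mode (most common element) of a Spark DataFrame

I am using a library to which I am supposed to supply an aggregation. The library then runs a series of joins/groupBys and calls the aggregation at the end. I am trying to avoid violating encapsulation (although I can if necessary), so I want to simply pass an aggregate to this method (normally a sum, a min, and so on).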

In this case I am trying to compute the mode, and I do not know how to express that as an aggregate.

Answers


Here is a Spark (2.1.0) UDAF that computes the statistical mode of a given column:

package org.anish.spark.mostcommonvalue 

import org.apache.spark.sql.Row 
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} 
import org.apache.spark.sql.types._ 

import scalaz.Scalaz._ 

/** 
    * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to 
    * the statistical mode. When two or more values tie for the highest frequency, this function returns an arbitrary 
    * one of them, whereas the statistical mode strictly includes all of the tied values. 
    * 
    * Usage: 
    * 
    * DataFrame/DataSet DSL 
    * val mostCommonValue = new MostCommonValue 
    * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city"))) 
    * 
    * Spark SQL: 
    * sqlContext.udf.register("mode", new MostCommonValue) 
    * %sql 
    * -- Use a group_by statement and call the UDAF. 
    * select group_id, mode(id) from table group by group_id 
    * 
    * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html 
    * 
    * Created by anish on 26/05/17. 
    */ 
class MostCommonValue extends UserDefinedAggregateFunction { 

    // These are the input fields of the aggregate function. 
    // StringType is used because the mode is also meaningful for nominal data. 
    override def inputSchema: StructType = 
        StructType(StructField("value", StringType) :: Nil) 

    // These are the internal fields kept while computing the aggregate. 
    // The frequency of every distinct value seen so far for the given attribute is stored in this map. 
    override def bufferSchema: StructType = StructType(
        StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil 
    ) 

    // This is the output type of your aggregation function. 
    override def dataType: DataType = StringType 

    override def deterministic: Boolean = true 

    // This sets the initial value of the buffer: an empty frequency map. 
    override def initialize(buffer: MutableAggregationBuffer): Unit = { 
        buffer(0) = Map[String, Long]() 
    } 

    // This is how the buffer is updated for one input row. 
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { 
        // Skip nulls: keys of a Spark MapType are not allowed to be null. 
        if (!input.isNullAt(0)) { 
            buffer(0) = buffer.getAs[Map[String, Long]](0) |+| Map(input.getAs[String](0) -> 1L) 
        } 
    } 

    // This is how two buffers are merged: scalaz's |+| sums the counts key by key. 
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 
        buffer1(0) = buffer1.getAs[Map[String, Long]](0) |+| buffer2.getAs[Map[String, Long]](0) 
    } 

    // This is where the final value is produced from the final state of the buffer. 
    override def evaluate(buffer: Row): String = { 
        val frequencies = buffer.getAs[Map[String, Long]](0) 
        // An empty map means no non-null input was seen; return null instead of letting maxBy throw. 
        if (frequencies.isEmpty) null else frequencies.maxBy(_._2)._1 
    } 
} 

Credit/source: https://gist.github.com/anish749/6a815ed281f538068a0d3a20ca9044fa
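
To make the buffer logic of the UDAF above easier to follow, here is a minimal sketch of the same idea using only plain Scala collections, with no Spark or scalaz dependency. The object `ModeSketch` and its function names are hypothetical and chosen only for illustration; they are not part of the gist. `merge` spells out what `|+|` does for a `Map[String, Long]` (a key-by-key sum of the counts), and `modes` shows one possible way to return every tied value, which the UDAF itself does not do.

```scala
object ModeSketch {
  // Hypothetical alias for the frequency map that the UDAF keeps in its buffer.
  type Counts = Map[String, Long]

  // Plays the role of initialize: start with no observations.
  val empty: Counts = Map.empty

  // Plays the role of update: record one more occurrence of value.
  def update(counts: Counts, value: String): Counts =
    counts.updated(value, counts.getOrElse(value, 0L) + 1L)

  // Plays the role of merge: sum the counts of both maps key by key.
  def merge(a: Counts, b: Counts): Counts =
    b.foldLeft(a) { case (acc, (key, n)) =>
      acc.updated(key, acc.getOrElse(key, 0L) + n)
    }

  // Plays the role of evaluate: one most frequent value, or None for an empty map.
  // When several values tie, which one is returned is unspecified.
  def mostCommon(counts: Counts): Option[String] =
    if (counts.isEmpty) None else Some(counts.maxBy(_._2)._1)

  // Variation not present in the UDAF: all values that share the highest count.
  def modes(counts: Counts): Set[String] =
    if (counts.isEmpty) Set.empty
    else {
      val top = counts.values.max
      counts.collect { case (key, n) if n == top => key }.toSet
    }
}
```

Folding two partitions separately with `update` and then combining them with `merge` mirrors how Spark builds partial aggregation buffers and merges them before calling `evaluate`.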
