2017-12-07 18 views
0

冗長なタイトルをお詫び申し上げますが、私は本当に良いものを思いつくことができませんでした。Sparkデータフレーム内のマップ内の構造体の配列に集約

基本的に、私は次のスキーマとデータを持っている:

|-- id: string (nullable = true) 
|-- mainkey: map (nullable = true) 
| |-- key: string 
| |-- value: array (valueContainsNull = true) 
| | |-- element: struct (containsNull = true) 
| | | |-- price: double (nullable = true) 
| | | |-- recordtype: string (nullable = true) 

は、私は、次の例のデータを使用してみましょう:上記二つのレコードのそれぞれについて

{"id":1, "mainkey":{"key1":[{"price":0.01,"recordtype":"BID"}],"key2":[{"price":4.3,"recordtype":"FIXED"}],"key3":[{"price":2.0,"recordtype":"BID"}]}} 
{"id":2, "mainkey":{"key4":[{"price":2.50,"recordtype":"BID"}],"key5":[{"price":2.4,"recordtype":"BID"}],"key6":[{"price":0.19,"recordtype":"BID"}]}} 

、私は、平均計算したいですレコードタイプが「BID」の場合のすべての価格。したがって、最初のレコード(「id」:1)では、価格が0.01と2.0の2つの入札があるため、小数点以下第2位に四捨五入された平均は1.01です。 2番目のレコード( "id":2)には、価格2.5、2.4、0.19の3つの入札があり、平均は1.70です。だから私は、次のような出力をしたい:

+---+---------+ 
| id|meanvalue| 
+---+---------+ 
| 1|  1.01| 
| 2|  1.7| 
+---+---------+ 

次のコードは、それをしない:

val exSchema = (new StructType().add("id", StringType).add("mainkey", MapType(StringType, new ArrayType(new StructType().add("price", DoubleType).add("recordtype", StringType), true)))) 
val exJsonDf = spark.read.schema(exSchema).json("file:///data/json_example") 
var explodeExJson = exJsonDf.select($"id",explode($"mainkey")).explode($"value") { 
    case Row(recordValue: Seq[Row] @unchecked) => recordValue.map{ recordValue => 
    val price = recordValue(0).asInstanceOf[Double] 
    val recordtype = recordValue(1).asInstanceOf[String] 
    RecordValue(price, recordtype) 
    } 
    }.cache() 

val filteredExJson = explodeExJson.filter($"recordtype"==="BID") 

val aggExJson = filteredExJson.groupBy("id").agg(round(mean("price"),2).alias("meanvalue")) 

問題は、それが「高価な」爆発する操作を使用していることであると私はたくさん取り扱っておりますときに問題となります特に地図にたくさんのキーがある場合は、

もっと簡単な解決策が考えられるかどうか、UDFなどを使用して教えてください。私はスパークの初心者であり、あなたには明白なものをいくつか見逃している可能性があることを覚えておいてください。

本当にありがとうございます。前もって感謝します!

答えて

2

集約は、単一の udfに制限されている場合は、この解決します:

import org.apache.spark.util.StatCounter 
import org.apache.spark.sql.functions.udf 
import org.apache.spark.sql.Row 

val meanPrice = udf((map: Map[String, Seq[Row]]) => { 
    val prices = map.values 
    .flatMap(x => x) 
    .filter(_.getAs[String]("recordtype") == "BID") 
    .map(_.getAs[Double]("price")) 
    StatCounter(prices).mean 
}) 

df.select($"id", meanPrice($"mainkey")) 
関連する問題