2017-11-29 38 views
3

私は、次のような構造を持つスパークデータフレームを持っている:ベクトル列の最大値のインデックスを見つける方法は?

root 
|-- distribution: vector (nullable = true) 

+--------------------+ 
| topicDistribution| 
+--------------------+ 
|  [0.1, 0.2]  | 
|  [0.3, 0.2]  | 
|  [0.5, 0.2]  | 
|  [0.1, 0.7]  | 
|  [0.1, 0.8]  | 
|  [0.1, 0.9]  | 
+--------------------+ 

私の質問は:各行の最大値のインデックスを持つ列を追加する方法は?

それはこのようなものでなければなりません:

root 
|-- distribution: vector (nullable = true) 
|-- max_index: integer (nullable = true) 

+--------------------+-----------+ 
| topicDistribution| max_index | 
+--------------------+-----------+ 
|  [0.1, 0.2]  | 1  | 
|  [0.3, 0.2]  | 0  | 
|  [0.5, 0.2]  | 0  | 
|  [0.1, 0.7]  | 1  | 
|  [0.1, 0.8]  | 1  | 
|  [0.1, 0.9]  | 1  | 
+--------------------+-----------+ 

どうもありがとう

- - - - 更新 - - - - - :

私は、次の方法を試してみましたが、私は得ましたエラー:

import org.apache.spark.sql.functions.udf 

val func = udf((x: Vector[Double]) => x.indices.maxBy(x)) 

df.withColumn("max_idx",func(($"topicDistribution"))).show() 

エラー:

Exception in thread "main" org.apache.spark.sql.AnalysisException: 
cannot resolve 'UDF(topicDistribution)' due to data type mismatch: 
argument 1 requires array<double> type, however, '`topicDistribution`' 
is of vector type.;; 

答えて

1
// create some sample data: 
import org.apache.spark.mllib.linalg.{Vectors,Vector} 
case class myrow(topics:Vector) 

val rdd = sc.parallelize(Array(myrow(Vectors.dense(0.1,0.2)),myrow(Vectors.dense(0.6,0.2)))) 
val mydf = sqlContext.createDataFrame(rdd) 
mydf.show() 
+----------+ 
| topics| 
+----------+ 
|[0.1, 0.2]| 
|[0.6, 0.2]| 
+----------+ 

// build the udf 
import org.apache.spark.sql.functions.udf 
val func = udf((x:Vector) => x.toDense.values.toSeq.indices.maxBy(x.toDense.values)) 


mydf.withColumn("max_idx",func($"topics")).show() 
+----------+-------+ 
| topics|max_idx| 
+----------+-------+ 
|[0.1, 0.2]|  1| 
|[0.6, 0.2]|  0| 
+----------+-------+ 

//注:あなたが元の質問やコメントなどの代わりに配列のベクターを使用するように編集// ユースケースあなたの特定のための代わりに、配列のベクトルをするUDFを変更する必要があります尋ねた

+0

ありがとうございます。ただし、データ型の不一致によりUDF(トピック)を解決できません。引数1には配列型が必要ですが、 '' topics' 'はベクトル型です。 – Ippon

1

注:このソリューションは、パフォーマンスには最適ではないかもしれませんが、問題に取り組むための他のアプローチを示しています(また、Spark SQLのDataset APIの豊富さを示しています)。


vectorので、私は、最初のサンプルデータセットを作成してみましょうスパークMLlibのVectorUDTからです。

val input = Seq((0.1, 0.2), (0.3, 0.2)).toDF 
import org.apache.spark.ml.feature.VectorAssembler 
val vecAssembler = new VectorAssembler() 
    .setInputCols(Array("_1", "_2")) 
    .setOutputCol("distribution") 
val ds = vecAssembler.transform(input).select("distribution") 
scala> ds.printSchema 
root 
|-- distribution: vector (nullable = true) 

スキーマはあなたのものと全く同じです。


のはVectorUDTから定期Array[Double]にタイプを変更してみましょう。あなたがインデックスに配列の要素をposexplodeを使用することができarrays

import org.apache.spark.ml.linalg.Vector 
val arrays = ds 
    .map { r => r.getAs[Vector](0).toArray } 
    .withColumnRenamed("value", "distribution") 
scala> arrays.printSchema 
root 
|-- distribution: array (nullable = true) 
| |-- element: double (containsNull = false) 

、位置オーバーmaxからgroupByjoin解決のため。

val pos = arrays.select($"*", posexplode($"distribution")) 
val max_cols = pos 
    .groupBy("distribution") 
    .agg(max("col") as "max_col") 
val solution = pos 
    .join(max_cols, "distribution") 
    .filter($"col" === $"max_col") 
    .select("distribution", "pos") 
scala> solution.show 
+------------+---+ 
|distribution|pos| 
+------------+---+ 
| [0.1, 0.2]| 1| 
| [0.3, 0.2]| 0| 
+------------+---+ 
+0

私は 'pos'と' max_cols'を得ることができます、そして、彼らは正しいようです。ただし、最後のステップでは、空のデータフレームのみが表示されます。 私を助けてくれてありがとう。ほんとうにありがとう。 – Ippon

+0

私のディストリビューションカラムの要素は0.04939380187616082のように見えるので、カラムを2小数点以下に丸める必要があると思います。あなたはそれをする方法を知っていますか? – Ippon

+0

'solution.printSchema'できますか? –

関連する問題