9

スパークのStringIndexerは非常に便利ですが、生成されたインデックス値と元の文字列の間の対応関係を取得する必要があります。この。私はSpark documentationからこの単純な例を使って説明します:インデックス文字列の対応を保持するスパーク文字列インデクサ

from pyspark.ml.feature import StringIndexer 

df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], 
    ["id", "category"]) 
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") 
indexed_df = indexer.fit(df).transform(df) 

この簡略化された場合は、私たちを与える:すべての罰金とダンディ

+---+--------+-------------+ 
| id|category|categoryIndex| 
+---+--------+-------------+ 
| 0|  a|   0.0| 
| 1|  b|   2.0| 
| 2|  c|   1.0| 
| 3|  a|   0.0| 
| 4|  a|   0.0| 
| 5|  c|   1.0| 
+---+--------+-------------+ 

が、多くのユースケースのために、私は私との間のマッピングを知りたいです元の文字列とインデックスラベル。私は私が望んでいた場合、辞書または類似として保存することができ、その結果

In [8]: indexed.select('category','categoryIndex').distinct().show() 
+--------+-------------+ 
|category|categoryIndex| 
+--------+-------------+ 
|  b|   2.0| 
|  c|   1.0| 
|  a|   0.0| 
+--------+-------------+ 

::私は手を、これをオフに行うために考えることができる最も簡単な方法は、このようなものである

In [12]: mapping = {row.categoryIndex:row.category for row in 
      indexed.select('category','categoryIndex').distinct().collect()} 

In [13]: mapping 
Out[13]: {0.0: u'a', 1.0: u'c', 2.0: u'b'} 

私の質問はこれです:これは一般的な作業なので、文字列インデクサーがこのマッピングをどうにか格納していると推測していますが(もちろん間違っている可能性があります)、上記の作業をより簡単に行う方法はありますか?

私の解決策は多かれ少なかれ簡単ですが、大規模なデータ構造の場合、これは(おそらく)避けることができる余分な計算を必要とします。アイデア?

答えて

6

ラベルマッピングは、列メタデータから抽出することができます

ml_attr.vals位置とラベルの間のマッピングを提供
meta = [ 
    f.metadata for f in indexed_df.schema.fields if f.name == "categoryIndex" 
] 
meta[0] 
## {'ml_attr': {'name': 'category', 'type': 'nominal', 'vals': ['a', 'c', 'b']}} 

:あなたが数値に変換することができます

dict(enumerate(meta[0]["ml_attr"]["vals"])) 
## {0: 'a', 1: 'c', 2: 'b'} 

スパーク1.6+

をラベルに値はIndexToStringを使用します。上記のように列のメタデータが使用されます。

from pyspark.ml.feature import IndexToString 

idx_to_string = IndexToString(
    inputCol="categoryIndex", outputCol="categoryValue") 

idx_to_string.transform(indexed_df).drop("id").distinct().show() 
## +--------+-------------+-------------+ 
## |category|categoryIndex|categoryValue| 
## +--------+-------------+-------------+ 
## |  b|   2.0|   b| 
## |  a|   0.0|   a| 
## |  c|   1.0|   c| 
## +--------+-------------+-------------+ 

スパーク< = 1.5

それは汚いハックですが、次のようにあなたは、単にJavaのインデクサからラベルを抽出することができます。

from pyspark.ml.feature import StringIndexerModel 

# A simple monkey patch so we don't have to _call_java later 
def labels(self): 
    return self._call_java("labels") 

StringIndexerModel.labels = labels 

# Fit indexer model 
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex").fit(df) 

# Extract mapping 
mapping = dict(enumerate(indexer.labels())) 
mapping 
## {0: 'a', 1: 'c', 2: 'b'} 
関連する問題