2016-08-31 6 views
4

Random Forest classifierを3倍の相互検証を使用してトレーニングする必要があります。各サンプルについて、テストセット内にあるときに予測確率を取得する必要があります。scikit-learnのpredict_probaで使用されるcross_val_predictからクラスラベルを取得する方法

私はscikit-learnバージョン0.18.dev0を使用しています。

この新しいバージョンでは、追加のパラメータmethodを使用してメソッドcross_val_predict()を使用して、エスティメータからどの種類の予測が必要かを定義する機能を追加しました。

私のケースでは、マルチクラスのシナリオでは、各クラスの確率を返すpredict_proba()メソッドを使用したいと考えています。

しかし、このメソッドを実行すると、予測確率の行列が得られます。各行はサンプルを表し、各列は特定のクラスの予測確率を表します。

問題は、このメソッドが各クラスに対応するクラスを示さないことです。

Iは( RandomForestClassifierを使用して私の場合)と同じである必要がある値が属性に返さ

は以下のように定義classes_:

classes_:形状の配列= [n_classes]またはそのようなアレイ ザのリストクラスラベル(単一出力問題)、またはクラスラベルの配列のリスト(複数出力問題)が含まれます。

クラスの順序は、属性classes_のものに対応していますので、そのマニュアルpredict_proba()によって必要とされる

は、いることを書かれています。

最小の例は以下である:

import numpy as np 
from sklearn.ensemble import RandomForestClassifier 
from sklearn.model_selection import cross_val_predict 

clf = RandomForestClassifier() 

X = np.random.randn(10, 10) 
y = y = np.array([1] * 4 + [0] * 3 + [2] * 3) 

# how to get classes from here? 
proba = cross_val_predict(estimator=clf, X=X, y=y, method="predict_proba") 

# using the classifier without cross-validation 
# it is possible to get the classes in this way: 
clf.fit(X, y) 
proba = clf.predict_proba(X) 
classes = clf.classes_ 
+0

'y'にアクセスできない場合は? –

+0

@ juanpa.arrivillaga私は 'y'にアクセスできますが、ラベルのソート順序はわかりません。私は彼らが昇順にソートされていると推測するかもしれませんが、私は完全にはわかりません。 – gc5

+0

申し訳ありませんが、私はまだ0.18を持っていないので、これをテストすることはできませんし、明白かもしれませんが、 'clf'オブジェクトに' class_'属性が含まれていませんか? –

答えて

2

はい、それらはソートされた順序であろう。 DecisionTreeClassifierRandomForestClassifierの場合はデフォルトのbase_estimatoruses np.unique to construct the classes_ attributeが入力配列のソートされた一意の値を返すためです。

+0

それは聞いていいです:)ありがとう。したがって、 'np.unique'を呼び出す前に同じ' class_'属性を取得する必要がありますか? – gc5

+1

うん、それはそれを行うべきだと思います – maxymoo

+0

他の分類器はどうですか?クラスは常にラベル順であることが保証されていますか? – simon

関連する問題