2017-07-22 20 views
2

KNearest Neighborsのマルチクラスマシン学習の問題を解決しようとしています。データ内の10個すべてのクラスの予測の混乱行列をプロットしています。 Matplotlib.pyplotのimshow。いくつかのクラスは、他のものよりもデータ内でかなり多くの出現を持ちますが、他のクラスでは50までしか持てないので、パーセンテージを表示するために正規化しました。グラフの横にカラーバーがあり、正規化されていない場合は1から3000の範囲で表示されます。しかし、それを正規化した後、その範囲は最大3000のままです。Scikitによって提供されるプロット関数を使用しています。hereで自分自身を学びます。欠けているものがあるのか​​、それともカラーバーの値の範囲を減らすために余分なステップがありますか?Matplotlibグラフのカラーバーが正規化された混同行列を使用して値を更新しない

コード

virdis = plt.cm.viridis 
blues = plt.cm.Blues 
autumn = plt.cm.autumn 

def plot_confusion_matrix(cm, classes, 
          normalize=False, 
          title='Confusion matrix', 
          cmap=blues): 
    """ 
    This function prints and plots the confusion matrix. 
    Normalization can be applied by setting `normalize=True`. 
    """ 
    plt.imshow(cm, interpolation='nearest', cmap=cmap) 
    plt.title(title) 

    bounds=[0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1] 
    plt.colorbar(boundaries=bounds) 

    tick_marks = np.arange(len(classes)) 
    plt.xticks(tick_marks, classes, rotation=45) 
    plt.yticks(tick_marks, classes) 

    if normalize: 
     cm = cm.astype('float')/cm.sum(axis=1)[:, np.newaxis] 
     print("Normalized confusion matrix") 
    else: 
     print('Confusion matrix, without normalization') 

    cm = np.around(cm, decimals=3) 

    thresh = cm.max()/2. 

    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 
     plt.text(j, i, cm[i, j], 
       horizontalalignment="center", 
       color="white" if i == 9 and j == 9 else "black") 

    plt.ylabel('True label') 
    plt.xlabel('Predicted label') 

knn = KNeighborsClassifier() 
knn.fit(X_train, y_train) 

knn_score = knn.score(X_test, y_test) 
knn_fold_score = model_selection.cross_val_score(knn, X_test, y_test, cv=10).mean() 
predictions = knn.predict(X_test) 

c_matrix = confusion_matrix(y_test, predictions) 

# Plot normalized confusion matrix 
plt.figure() 
plot_confusion_matrix(c_matrix, classes=country_names, normalize=True, 
         title='Normalized confusion matrix') 

plt.show() 

enter image description here

答えて

2

あなたは、カラーバーとその範囲ステーを理解することができ、プロット、すなわちplt.imshowに滞在する必要があります。 Scikit Learnの例とあなたの例では、正規化を行うかどうかを決定する前にマトリックスをプロットしています。したがって、2つのプロットと関連するカラーバーはまったく同じに見えます。あなたは、次のブロックに移動すなわち、プロットする前に正規化を処理する場合:plt.imshow(cm, interpolation='nearest', cmap=cmap)の前でに

if normalize: 
    cm = cm.astype('float')/cm.sum(axis=1)[:, np.newaxis] 
    print("Normalized confusion matrix") 
else: 
    print('Confusion matrix, without normalization') 

cm = np.around(cm, decimals=3) 

を、正規化されたプロットのカラーバーは、その後の(色、ちょうどあなたを思い出させるために、再び0から1の範囲にあるだろう)プロット自体も変化します。カラーバー自体とそれに関連するプロットを変更することなく、カラーバーのテキストラベルを0から1の範囲に変更するだけの良い考えではないと思います。

関連する問題