2016-01-14 8 views
6

今日私は私の分類モデルから混乱行列をプロットしようとしました。多くの文字列ラベルを持つMatplotlib matshow

一部のページで検索したところ、matshowからpyplotまでお手伝いできます。

import matplotlib.pyplot as plt 
from sklearn.metrics import confusion_matrix 

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None): 
    fig = plt.figure() 
    ax = fig.add_subplot(111) 
    cax = ax.matshow(cm) 
    plt.title(title) 
    fig.colorbar(cax) 
    if labels: 
     ax.set_xticklabels([''] + labels) 
     ax.set_yticklabels([''] + labels) 
    plt.xlabel('Predicted') 
    plt.ylabel('True') 
    plt.show() 

それは私が、私は多くのラベルを持っている場合でも、一部のラベルが表示されませんいくつかのラベル

y_true = ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'a', 'c', 'd', 'b', 'a', 'b', 'a'] 
y_pred = ['a', 'b', 'c', 'd', 'a', 'b', 'b', 'a', 'c', 'a', 'a', 'a', 'a', 'a'] 
labels = list(set(y_true)) 
cm = confusion_matrix(y_true, y_pred) 
plot_confusion_matrix(cm, labels=labels) 

enter image description here

を持っている場合はうまく動作正しく

y_true = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n'] 
y_pred = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n'] 
labels = list(set(y_true)) 
cm = confusion_matrix(y_true, y_pred) 
plot_confusion_matrix(cm, labels=labels) 

enter image description here

私の質問はどのように私はmatshowプロットですべてのラベルを表示できますか?私はfontdictのようなものを試しましたが、それでも動作しません。

答えて

3

matplotlib.tickerモジュールを使用してティックの頻度を制御できます。この場合

、あなたはダニに1のすべての倍数を設定したいので、我々はあなたがplt.show()を呼び出す前に、これらの2行を追加します。MultipleLocator

を使用することができます。

ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) 
ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) 

そして、それは生成されますあなたのy_truey_predの各文字のダニとラベル。

私はまた、あなたが関数呼び出しで指定したカラーマップを利用するために、あなたのmatshow呼び出しに変更:完全性については

cax = ax.matshow(cm,cmap=cmap) 

enter image description here

を、あなたの全体の機能は次のようになります。

import matplotlib.pyplot as plt 
from sklearn.metrics import confusion_matrix 
import matplotlib.ticker as ticker 

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None): 
    fig = plt.figure() 
    ax = fig.add_subplot(111) 

    # I also added cmap=cmap here, to make use of the 
    # colormap you specify in the function call 
    cax = ax.matshow(cm,cmap=cmap) 
    plt.title(title) 
    fig.colorbar(cax) 
    if labels: 
     ax.set_xticklabels([''] + labels) 
     ax.set_yticklabels([''] + labels) 

    ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) 
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) 

    plt.xlabel('Predicted') 
    plt.ylabel('True') 
    plt.savefig('confusionmatrix.png') 
+0

トムありがとう。別の質問、あなたはmatshowプロットで混乱行列の価値を引き出す方法を知っていますか? –

+0

あなたが望むものの詳細について新しい質問をする方がいいでしょう – tom

4

xticksメソッドを使用してラベルを指定できます。あなたの関数は次のようになります(上の答えから関数を修正します):

import matplotlib.pyplot as plt 
from sklearn.metrics import confusion_matrix 

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None): 
    fig = plt.figure() 
    ax = fig.add_subplot(111) 

    # I also added cmap=cmap here, to make use of the 
    # colormap you specify in the function call 
    cax = ax.matshow(cm,cmap=cmap) 
    plt.title(title) 
    fig.colorbar(cax) 
    if labels: 
     plt.xticks(range(len(labels)), labels) 
     plt.yticks(range(len(labels)), labels) 

    plt.xlabel('Predicted') 
    plt.ylabel('True') 
    plt.savefig('confusionmatrix.png') 
関連する問題