バープロットを使用して視覚化しました。最初のプロットはクラスの分布を示しています。最初のタイトルは最初の分割基準を表します。この基準を満たすすべてのデータは、左下のサブプロットになります。そうでない場合は、正しいプロットが結果になります。したがって、すべてのタイトルは、次の分割の分割基準を示します。
パーセントは初期分布からの値です。したがって、前払い額を見ることで、少数の分割後に最初の量のデータがどれだけ残っているかを簡単に知ることができます。
注意、あなたはこれがサブプロットの多くを必要とするであろうことをMAX_DEPTH高く設定されている場合(MAX_DEPTH、2 ^奥行き)
Tree visualization using bar plots
コード:
def give_nodes(nodes,amount_of_branches,left,right):
amount_of_branches*=2
nodes_splits=[]
for node in nodes:
nodes_splits.append(left[node])
nodes_splits.append(right[node])
return (nodes_splits,amount_of_branches)
def plot_tree(tree, feature_names):
from matplotlib import gridspec
import matplotlib.pyplot as plt
from matplotlib import rc
import pylab
color = plt.cm.coolwarm(np.linspace(1,0,len(feature_names)))
plt.rc('text', usetex=True)
plt.rc('font', family='sans-serif')
plt.rc('font', size=14)
params = {'legend.fontsize': 20,
'axes.labelsize': 20,
'axes.titlesize':25,
'xtick.labelsize':20,
'ytick.labelsize':20}
plt.rcParams.update(params)
max_depth=tree.max_depth
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
fig = plt.figure(figsize=(3*2**max_depth,2*2**max_depth))
gs = gridspec.GridSpec(max_depth, 2**max_depth)
plt.subplots_adjust(hspace = 0.6, wspace=0.8)
# All data
amount_of_branches=1
nodes=[0]
normalize=np.sum(value[0][0])
for i,node in enumerate(nodes):
ax=fig.add_subplot(gs[0,(2**max_depth*i)/amount_of_branches:(2**max_depth*(i+1))/amount_of_branches])
ax.set_title(features[node]+"$<= "+str(threshold[node])+"$")
if(i==0): ax.set_ylabel(r'$\%$')
ind=np.arange(1,len(value[node][0])+1,1)
width=0.2
bars= (np.array(value[node][0])/normalize)*100
plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0)
plt.xticks(ind, [int(i) for i in ind-1])
pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2))
# Splits
for j in range(1,max_depth):
nodes,amount_of_branches=give_nodes(nodes,amount_of_branches,left,right)
for i,node in enumerate(nodes):
ax=fig.add_subplot(gs[j,(2**max_depth*i)/amount_of_branches:(2**max_depth*(i+1))/amount_of_branches])
ax.set_title(features[node]+"$<= "+str(threshold[node])+"$")
if(i==0): ax.set_ylabel(r'$\%$')
ind=np.arange(1,len(value[node][0])+1,1)
width=0.2
bars= (np.array(value[node][0])/normalize)*100
plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0)
plt.xticks(ind, [int(i) for i in ind-1])
pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2))
plt.tight_layout()
return fig
例:
X=[]
Y=[]
amount_of_labels=5
feature_names=[ '$x_1$','$x_2$','$x_3$','$x_4$','$x_5$']
for i in range(200):
X.append([np.random.normal(),np.random.randint(0,100),np.random.uniform(200,500) ])
Y.append(np.random.randint(0,amount_of_labels))
clf = tree.DecisionTreeClassifier(criterion='entropy',max_depth=4)
clf = clf.fit(X,Y)
fig=plot_tree(clf, feature_names)
>>> import os >>> os.unlink( 'iris.dot') –
I tはこれをすると言う^。ただし、ファイルを削除するだけです。何か案は?私もpydotplusを持っていません。私はピップを使用してダウンロードしようとしましたが、うまくいきませんでした。 –
私は問題がGraphizだと思うし、それをダウンロードする必要があります:http://www.graphviz.org/Download..php http://stackoverflow.com/questions/18438997/why-is-pydot-unable-to-find -graphvizs-executables-in-windows-8です。最初にgraphizをインストールしてからpydotをインストールします。または、linuxを使用してください。少し後で私はそれに戻ってくるでしょう。 – Roxanne