0

私は少し架空の機械学習問題で、pythonとscikit-learnのツリー分類子を使用しています。私はバイナリ結果変数(wc_measure)を持っており、それはいくつかの他の変数(cashcrisis、およびindustry)に依存していると信じています。いくつかの属性を学習(ツリー)する方法は?

# import neccessary packages 
import pandas as pd 
import numpy as np 
import sklearn as skl 
from sklearn import tree 
from sklearn.cross_validation import train_test_split as tts 


# import data and give a little overview 
sample = pd.read_stata('sample_data.dta') 

s = sample 


# What I want to learn on 
X = [s.crisis, s.cash, s.industry] 
y = s.wc_measure 
X_train, X_test, y_train, y_test = tts(X, y, test_size = .5) 


#let's learn a little 

my_tree = tree.DecisionTreeClassifier() 
clf = my_tree.fit(X_train, y_train) 
predictions = my_tree.predict(X_test) 

私が取得エラー以下:Number of labels=50 does not match number of samples=1私は次のことを試してみました。私がXを単一の変数(例えば、X = s.crisis)に基づいている場合、私はXを再形成するように求められます。私はこれらの問題のいずれかを持っている理由を完全に理解していません...アイデア?


PS:これは、印刷のリターン(X)

[0  4.0 
1  4.0 
2  5.0 
3  3.0 
4  4.0 
5  2.0 
6  2.0 
7  1.0 
8  3.0 
9  3.0 
10 4.0 
11 3.0 
12 2.0 
13 4.0 
14 5.0 
15 4.0 
16 2.0 
17 2.0 
18 3.0 
19 2.0 
20 5.0 
21 4.0 
22 2.0 
23 4.0 
24 5.0 
25 1.0 
26 5.0 
27 3.0 
28 4.0 
29 2.0 
    ... 
70 1.0 
71 4.0 
72 4.0 
73 1.0 
74 4.0 
75 3.0 
76 4.0 
77 2.0 
78 2.0 
79 5.0 
80 2.0 
81 3.0 
82 5.0 
83 4.0 
84 4.0 
85 5.0 
86 3.0 
87 3.0 
88 4.0 
89 2.0 
90 2.0 
91 3.0 
92 3.0 
93 4.0 
94 3.0 
95 1.0 
96 4.0 
97 2.0 
98 3.0 
99 4.0 
Name: crisis, dtype: float32, 0  450.283417 
1  113.472214 
2  11.811784 
3  1007.507446 
4  293.895142 
5  1133.297729 
6  2237.830322 
7  1475.787109 
8  283.363678 
9  626.888794 
10  38.865730 
11  991.999390 
12 1115.746948 
13  373.537231 
14  97.570717 
15  136.079193 
16 2560.691406 
17  667.062073 
18 1378.384521 
19  152.716400 
20  5.779267 
21  481.511566 
22  677.809631 
23  722.521790 
24  32.927990 
25 2504.450928 
26  17.422865 
27  651.585083 
28  549.469177 
29  297.458527 
     ...  
70 1198.370239 
71  471.343933 
72  389.709290 
73 2962.622803 
74  581.519287 
75 1148.822388 
76  67.653664 
77 1346.391602 
78 1764.086914 
79  14.308219 
80  973.152161 
81  552.576904 
82  2.863116 
83  425.520752 
84  321.773682 
85  63.597332 
86 1351.122559 
87  735.856567 
88  745.656677 
89 2784.453125 
90 1438.272705 
91  768.780823 
92  827.021423 
93  591.778015 
94  885.169434 
95 1143.088867 
96  399.816803 
97 1517.454834 
98 1311.692505 
99  533.062561 
Name: cash, dtype: float32, 0  5.0 
1  2.0 
2  3.0 
3  5.0 
4  4.0 
5  3.0 
6  5.0 
7  1.0 
8  1.0 
9  2.0 
10 1.0 
11 5.0 
12 2.0 
13 4.0 
14 6.0 
15 2.0 
16 6.0 
17 2.0 
18 5.0 
19 1.0 
20 3.0 
21 4.0 
22 2.0 
23 6.0 
24 4.0 
25 4.0 
26 3.0 
27 3.0 
28 5.0 
29 1.0 
    ... 
70 2.0 
71 4.0 
72 3.0 
73 6.0 
74 6.0 
75 5.0 
76 1.0 
77 3.0 
78 5.0 
79 4.0 
80 2.0 
81 3.0 
82 2.0 
83 5.0 
84 3.0 
85 5.0 
86 5.0 
87 4.0 
88 6.0 
89 6.0 
90 4.0 
91 3.0 
92 4.0 
93 6.0 
94 3.0 
95 2.0 
96 3.0 
97 4.0 
98 6.0 
99 4.0 

PPSです:あなたがする必要が

clear matrix 
clear all 
set more off 

set obs 100 
gen id = _n 


*Basics 
    gen industry = round(runiform()*5+1) 
    gen activity = round(runiform()*5+1) 
    gen crisis = round(runiform()*4+1) 
     egen min_crisis = min(crisis) 
     egen max_crisis = max(crisis) 
     gen n_crisis = (crisis - min_crisis)/(max_crisis-min_crisis) 

*Company details 
    gen staff = round((0.5 * industry + 0.3 * activity - 0.2 * crisis) * runiform()*100+1) 

    gen revenue = (0.5 * industry + 0.2 * activity - 0.3 * crisis) * 1000 + runiform() 
     replace revenue = 0 if revenue<0 

    *Working Capital (wc) 
    gen stock = runiform()*0.5*crisis*revenue 
    gen receivables = runiform()*0.5*crisis*revenue 
    gen payables = runiform()*-0.5*crisis*revenue 
     replace payables = 0 if payables < 0 
    gen wc = stock + receivables - payables 
     egen avg_wc = mean(wc), by(industry) 


    *Liquidity 
    gen loan = (0.5 * industry + 0.2 * activity - 0.3 * crisis) * 1000 + runiform() 
     replace loan = 0 if loan<0 
     egen pc_loan = pctile(loan), p(0.2) by(industry) 
     replace loan = 0 if loan<pc_loan 

    gen current_debt = n_crisis * loan + runiform()*100 

    gen cash = (1-n_crisis)*revenue + runiform()*100 


*Measures 

    *WC-measure (binary) 
     gen wc_status = (wc-avg_wc) 
      egen max_wc_status = max(wc_status), by(industry) 
      egen min_wc_status = min(wc_status), by(industry) 
      gen n_wc_status = (wc_status - min_wc_status)/(max_wc_status-min_wc_status) 
    gen wc_measure = round(n_wc_status) 
+1

あなたはsample_data.dtaファイルを共有することはできますか? – Xevaquor

+0

Tは正しい入力ですか? –

+0

私はそれを共有することはできません。しかし、データの作成方法を示すsataスクリプトをアップロードします。 – Rachel

答えて

0

私はついにこの問題を解決しました。問題は、サンプルsを配列として定義しなかったことです。対応して、Xはリストでした。助けてくれてありがとう!ここで

は私がやったことです:

# import data and give a little overview 
sample = pd.read_stata('sample_data.dta') 
s = sample 
print(s.shape) 


# Have some mor vars and an array of explanatory vars 


X = np.array((s.crisis, s.cash, s.industry)).reshape(100, 3) 
y = np.array(s.wc_measure) 
X_train, X_test, y_train, y_test = tts(X, y, test_size = .8) 


#let's learn a little 

my_tree = tree.DecisionTreeClassifier() 
clf = my_tree.fit(X_train, y_train) 
predictions = my_tree.predict(X_test) 
1

:ここで私はStataの中のデータを生成する方法でありますXがttsの正しい入力であるかどうかを確認しますか? Xは3行とN列を持ちます。 Xは3つの属性を持つN行を持つ必要があります。それは、番号が一致しないと不平を言う理由です。

+0

私は今のところこれを実際には得られません。私はPythonに全く新しいです。私はXのデータ構造をN行/ 3列と仮定しました。どうすれば確認できますか? – Rachel

+1

array([[1,2]、[3,4]、[5,6]、[7,9]])のようになります。この配列には2つの列と4つの行があります。こちらをご覧ください。 http://scikit-learn.org/stable/modules/generated/sklearn.cross_validation.train_test_split.html train_test_splitの例があります。うまくいけば、これは助ける –

+0

ありがとう!私は最初の数行などを印刷する方法を知っていますが、どのように配列の構造を理解するのですか? – Rachel

関連する問題