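I am working on a somewhat fictional machine learning problem using Python and scikit-learn's decision tree classifier. I have a binary outcome variable (wc_measure) that I believe depends on several other variables (cash, crisis, and industry). How can I train the tree on several attributes at once?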
# import necessary packages
import pandas as pd
import numpy as np
import sklearn as skl
from sklearn import tree
from sklearn.cross_validation import train_test_split as tts
# import data and give a little overview
sample = pd.read_stata('sample_data.dta')
s = sample
# What I want to learn on
X = [s.crisis, s.cash, s.industry]
y = s.wc_measure
X_train, X_test, y_train, y_test = tts(X, y, test_size = .5)
#let's learn a little
my_tree = tree.DecisionTreeClassifier()
clf = my_tree.fit(X_train, y_train)
predictions = my_tree.predict(X_test)
I get the following error: Number of labels=50 does not match number of samples=1
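A minimal sketch of one likely cause, assuming the data has one row per observation: a Python list of three Series is laid out as 3 rows by 100 columns, whereas scikit-learn estimators expect one row per sample and one column per feature. The synthetic frame below is a hypothetical stand-in for sample_data.dta (the values are random, not the real data), and the import uses sklearn.model_selection, which is where train_test_split lives in current scikit-learn releases.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in for sample_data.dta: 100 rows of random values.
rng = np.random.default_rng(0)
s = pd.DataFrame({
    "crisis": rng.integers(1, 6, size=100).astype("float32"),
    "cash": rng.uniform(0, 3000, size=100).astype("float32"),
    "industry": rng.integers(1, 7, size=100).astype("float32"),
    "wc_measure": rng.integers(0, 2, size=100),
})

# A list of three Series is interpreted as 3 rows x 100 columns,
# i.e. features and samples are swapped.
X_list = [s.crisis, s.cash, s.industry]
list_shape = np.asarray(X_list).shape
print(list_shape)

# Selecting the columns from the DataFrame keeps one row per sample.
X = s[["crisis", "cash", "industry"]]
y = s["wc_measure"]
print(X.shape)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0
)
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
predictions = clf.predict(X_test)
print(predictions.shape)
```

With the column selection, X and y have the same number of rows, so the split and the fit see matching sample counts.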
I tried the following: if I base X on a single variable (for example, X = s.crisis), I am asked to reshape X. I don't fully understand why I am running into either of these problems... any ideas?
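For the single-variable case, here is a small sketch under the assumption that the reshape request comes from passing a one-dimensional Series: scikit-learn expects a two-dimensional feature matrix even when there is only one feature. The frame is again a hypothetical stand-in with random values, not the real data.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in for the real data: 100 rows of random values.
rng = np.random.default_rng(1)
s = pd.DataFrame({
    "crisis": rng.integers(1, 6, size=100).astype("float32"),
    "wc_measure": rng.integers(0, 2, size=100),
})

# A single Series is one-dimensional: shape (100,).
X_1d = s.crisis
print(X_1d.shape)

# Two ways to get a 2-D matrix with a single feature column: shape (100, 1).
X_2d = s[["crisis"]]                           # double brackets keep a DataFrame
X_2d_np = s.crisis.to_numpy().reshape(-1, 1)   # explicit reshape on the array
print(X_2d.shape, X_2d_np.shape)

# The 2-D version can be passed to fit and predict directly.
clf = DecisionTreeClassifier(random_state=0).fit(X_2d, s.wc_measure)
predictions = clf.predict(X_2d)
print(predictions.shape)
```

Either form should work; the double-bracket selection keeps the column name attached, which can make later inspection easier.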
PS: This is what print(X) returns:
[0 4.0
1 4.0
2 5.0
3 3.0
4 4.0
5 2.0
6 2.0
7 1.0
8 3.0
9 3.0
10 4.0
11 3.0
12 2.0
13 4.0
14 5.0
15 4.0
16 2.0
17 2.0
18 3.0
19 2.0
20 5.0
21 4.0
22 2.0
23 4.0
24 5.0
25 1.0
26 5.0
27 3.0
28 4.0
29 2.0
...
70 1.0
71 4.0
72 4.0
73 1.0
74 4.0
75 3.0
76 4.0
77 2.0
78 2.0
79 5.0
80 2.0
81 3.0
82 5.0
83 4.0
84 4.0
85 5.0
86 3.0
87 3.0
88 4.0
89 2.0
90 2.0
91 3.0
92 3.0
93 4.0
94 3.0
95 1.0
96 4.0
97 2.0
98 3.0
99 4.0
Name: crisis, dtype: float32, 0 450.283417
1 113.472214
2 11.811784
3 1007.507446
4 293.895142
5 1133.297729
6 2237.830322
7 1475.787109
8 283.363678
9 626.888794
10 38.865730
11 991.999390
12 1115.746948
13 373.537231
14 97.570717
15 136.079193
16 2560.691406
17 667.062073
18 1378.384521
19 152.716400
20 5.779267
21 481.511566
22 677.809631
23 722.521790
24 32.927990
25 2504.450928
26 17.422865
27 651.585083
28 549.469177
29 297.458527
...
70 1198.370239
71 471.343933
72 389.709290
73 2962.622803
74 581.519287
75 1148.822388
76 67.653664
77 1346.391602
78 1764.086914
79 14.308219
80 973.152161
81 552.576904
82 2.863116
83 425.520752
84 321.773682
85 63.597332
86 1351.122559
87 735.856567
88 745.656677
89 2784.453125
90 1438.272705
91 768.780823
92 827.021423
93 591.778015
94 885.169434
95 1143.088867
96 399.816803
97 1517.454834
98 1311.692505
99 533.062561
Name: cash, dtype: float32, 0 5.0
1 2.0
2 3.0
3 5.0
4 4.0
5 3.0
6 5.0
7 1.0
8 1.0
9 2.0
10 1.0
11 5.0
12 2.0
13 4.0
14 6.0
15 2.0
16 6.0
17 2.0
18 5.0
19 1.0
20 3.0
21 4.0
22 2.0
23 6.0
24 4.0
25 4.0
26 3.0
27 3.0
28 5.0
29 1.0
...
70 2.0
71 4.0
72 3.0
73 6.0
74 6.0
75 5.0
76 1.0
77 3.0
78 5.0
79 4.0
80 2.0
81 3.0
82 2.0
83 5.0
84 3.0
85 5.0
86 5.0
87 4.0
88 6.0
89 6.0
90 4.0
91 3.0
92 4.0
93 6.0
94 3.0
95 2.0
96 3.0
97 4.0
98 6.0
99 4.0
PPS: In case you need it, here is the Stata script that generates the data:
clear matrix
clear all
set more off
set obs 100
gen id = _n
*Basics
gen industry = round(runiform()*5+1)
gen activity = round(runiform()*5+1)
gen crisis = round(runiform()*4+1)
egen min_crisis = min(crisis)
egen max_crisis = max(crisis)
gen n_crisis = (crisis - min_crisis)/(max_crisis-min_crisis)
*Company details
gen staff = round((0.5 * industry + 0.3 * activity - 0.2 * crisis) * runiform()*100+1)
gen revenue = (0.5 * industry + 0.2 * activity - 0.3 * crisis) * 1000 + runiform()
replace revenue = 0 if revenue<0
*Working Capital (wc)
gen stock = runiform()*0.5*crisis*revenue
gen receivables = runiform()*0.5*crisis*revenue
gen payables = runiform()*-0.5*crisis*revenue
replace payables = 0 if payables < 0
gen wc = stock + receivables - payables
egen avg_wc = mean(wc), by(industry)
*Liquidity
gen loan = (0.5 * industry + 0.2 * activity - 0.3 * crisis) * 1000 + runiform()
replace loan = 0 if loan<0
egen pc_loan = pctile(loan), p(0.2) by(industry)
replace loan = 0 if loan<pc_loan
gen current_debt = n_crisis * loan + runiform()*100
gen cash = (1-n_crisis)*revenue + runiform()*100
*Measures
*WC-measure (binary)
gen wc_status = (wc-avg_wc)
egen max_wc_status = max(wc_status), by(industry)
egen min_wc_status = min(wc_status), by(industry)
gen n_wc_status = (wc_status - min_wc_status)/(max_wc_status-min_wc_status)
gen wc_measure = round(n_wc_status)
Can you share the sample_data.dta file? – Xevaquor
Is T the correct input? –
I can't share it, but I will upload a Stata script that shows how the data is created. – Rachel