2016-01-26 11 views
6

sklearnのStratifiedKFold関数を使用して、ここでエラーを理解するのを助けることができますか?StratifiedKFold:IndexError:配列のインデックスが多すぎます

私の推測では、ラベルの入力配列と何か関係があると思いますが、私がそれらを印刷すると(この例では最初の16個)、インデックスは0から15になりますが、余分な0はその上に印刷されます私は期待していませんでした。たぶん私はちょうどpythonのnoobですが、それは奇妙に見えます。

誰もがここでうわーを見ますか?

ドキュメント:http://scikit-learn.org...StratifiedKFold.html

コード:

import nltk 
import sklearn 

print('The nltk version is {}.'.format(nltk.__version__)) 
print('The scikit-learn version is {}.'.format(sklearn.__version__)) 

print type(skew_gendata_targets.values), skew_gendata_targets.values.shape 
print skew_gendata_targets.head(16) 

skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121) 

結果

The nltk version is 3.1. 
The scikit-learn version is 0.17. 
<type 'numpy.ndarray'> (500L, 1L) 
    0 
0 0 
1 0 
2 0 
3 0 
4 0 
5 0 
6 0 
7 0 
8 0 
9 0 
10 0 
11 0 
12 0 
13 0 
14 1 
15 0 
--------------------------------------------------------------------------- 
IndexError        Traceback (most recent call last) 
<ipython-input-373-653b6010b806> in <module>() 
     8 print skew_gendata_targets.head(16) 
     9 
---> 10 skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121) 
    11 
    12 #print '\nSkewed Generated Dataset (', len(skew_gendata_data), ')' 

d:\Program Files\Anaconda2\lib\site-packages\sklearn\cross_validation.pyc in __init__(self, y, n_folds, shuffle, random_state) 
    531   for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)): 
    532    for label, (_, test_split) in zip(unique_labels, per_label_splits): 
--> 533     label_test_folds = test_folds[y == label] 
    534     # the test split can be too big because we used 
    535     # KFold(max(c, self.n_folds), self.n_folds) instead of 

IndexError: too many indices for array 

答えて

11

skew_gendata_targets.valuesの形状を確認してください。 StratifiedKFoldが期待するように1d配列(shape(500、))ではなく、(500,1)配列であることがわかります。 SKlearnはこれらを同じものと強制するのではなく、別々に扱います。それが役に立たない場合は教えてください

+0

印刷出力は、印刷タイプ(skew_gendata_targets.values)、skew_gendata_targets.values.shape、(500,1)numpyの配列です。私はpythonsのピットに投げ込まれたmatlabジャンキーで、500x1と500xnanadaの行列/配列/ thingieの違いを知らない。少なくともmatlabの世界では違いはありません。 –

+2

うん - その不幸な、やや紛らわしい。違いは、 '*'のような操作を行うときに重要です。あるケースでは、Pandas/numpyは要素ごとの乗算を行いますが、もう一方は行列乗算を行います。 StratifiedKFold操作を(500、)配列に強制した後でうまくいきたいと考えています。 – Brian

+1

matlabiesを再構成するのはmatlaberが理解できるものですが、これは修正されているようです:np.reshape(skew_gendata_targets.values、[500、])、ありがとう!! –

関連する問題