2017-01-29 1 views
1

私はsklearntest_train_splitモジュールを使用してトレーニングとテスト用のランダムなデータセットを生成しています。元のデータセットを指すファイルパスのリストがあります。私はまた、シャッフル/分割データセットが利用可能になったら、ファイルパスをトレースできるように、データがどのようにシャッフルされているか、同じ方法でファイルパスリストをシャッフルしていますか?sklearn test train split - 元のリストのファイル名へのインデックスを取得

答えて

4

同じrandom_statetest_train_splitshuffleに指定すると、同じ順序になります。

デモンストレーションについては、下記のスニペットをご覧ください。いいです

from sklearn.utils import shuffle 
from sklearn.model_selection import train_test_split 
import random 

X = list() 
y = list() 
files = list() 

random_state = 42 
test_size = 0.9 
n = 100 

for i in range(n): 
    X.append(i) 
    y.append(i + random.random()) 
    files.append('file_{0:02d}.csv'.format(i)) 

X_train, X_test, y_train, y_test = train_test_split(X, 
               y, 
               test_size=test_size, 
               random_state=random_state) 
X_shuffle = shuffle(X, random_state=random_state) 
y_shuffle = shuffle(y, random_state=random_state) 
file_shuffle = shuffle(files, random_state=random_state) 

print(X_train) 
print(X_shuffle[int(n * test_size):]) 
print(y_shuffle[int(n * test_size):]) 
print(file_shuffle[int(n * test_size):]) 
print(X_train == X_shuffle[int(n * test_size):]) 
[91, 74, 86, 82, 20, 60, 71, 14, 92, 51] 
[91, 74, 86, 82, 20, 60, 71, 14, 92, 51] 
[91.64119581793204, 74.77493553783724, 86.62410189510936, 82.40452263996107, 20.22784747831378, 60.913989700418675, 71.1940538438253, 14.644282494118647, 92.97808337955185, 51.289858815186356] 
['file_91.csv', 'file_74.csv', 'file_86.csv', 'file_82.csv', 'file_20.csv', 'file_60.csv', 'file_71.csv', 'file_14.csv', 'file_92.csv', 'file_51.csv'] 

True 
+1

- 理想的には、最初にそれを試してみたはず!精巧なデモをするのは本当に素晴らしい – rookie

関連する問題