私はいくつかのデータセット(例として、CIFAR10
)でトレーニングしたいネットワークを持っています。pytorchデータセットのサブセットの取得
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
質問:私はいくつかの異なるトレーニング反復をしたいとします。奇妙な位置にあるすべての画像、次に偶数の位置にあるすべての画像などでネットワークを最初に訓練したいとしましょう。そのためには、私はそれらの画像にアクセスできる必要があります。残念ながら、trainset
はそのようなアクセスを許可していないようです。つまり、trainset[:1000]
以上を実行しようとすると、通常trainset[mask]
がエラーになります。
私はすでに、私はそうtrainset.train_data
を変更しかし、それは(各反復でフルセットの新しいコピーを作成するために私を強制する代わりに
trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]
、その後
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
を行うことができますtrainset
を再定義する必要があります)。それを避ける方法はありますか?
理想的には、私はあなたがデータセットローダーが(ちょうどそれぞれ異なるサンプリングのための新しいローダーを作成する)データセットを再作成回避するためのカスタム・サンプラーを定義することができます
trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
shuffle=True, num_workers=2)
ありがとうございます! 1つの小さな発言:明らかにサンプラーはシャッフルと互換性がないので、同じ結果を得るためには、次のようにします:torch.utils.data.DataLoader(trainset、batch_size = 4、 sampler = SubsetRandomSampler(np.where(mask)[ 0])、シャッフル= False、num_workers = 2) –