2017-11-22 19 views
0

私はいくつかのデータセット(例として、CIFAR10)でトレーニングしたいネットワークを持っています。pytorchデータセットのサブセットの取得

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
             download=True, transform=transform) 
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
              shuffle=True, num_workers=2) 

質問:私はいくつかの異なるトレーニング反復をしたいとします。奇妙な位置にあるすべての画像、次に偶数の位置にあるすべての画像などでネットワークを最初に訓練したいとしましょう。そのためには、私はそれらの画像にアクセスできる必要があります。残念ながら、trainsetはそのようなアクセスを許可していないようです。つまり、trainset[:1000]以上を実行しようとすると、通常trainset[mask]がエラーになります。

私はすでに、私はそうtrainset.train_dataを変更しかし、それは(各反復でフルセットの新しいコピーを作成するために私を強制する代わりに

trainset.train_data=trainset.train_data[mask] 
trainset.train_labels=trainset.train_labels[mask] 

、その後

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
               shuffle=True, num_workers=2) 

を行うことができますtrainsetを再定義する必要があります)。それを避ける方法はありますか?

理想的には、私はあなたがデータセットローダーが(ちょうどそれぞれ異なるサンプリングのための新しいローダーを作成する)データセットを再作成回避するためのカスタム・サンプラーを定義することができます

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4, 
               shuffle=True, num_workers=2) 

答えて

1

と「同等」何かを持っていると思います。

class YourSampler(Sampler): 
    def __init__(self, mask): 
     self.mask = mask 

    def __iter__(self): 
     return (self.indices[i] for i in torch.nonzero(self.mask)) 

    def __len__(self): 
     return len(self.mask) 

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
             download=True, transform=transform) 

sampler1 = YourSampler(your_mask) 
sampler2 = YourSampler(your_other_mask) 
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4, 
              sampler = sampler1, shuffle=True, num_workers=2) 
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4, 
              sampler = sampler2, shuffle=True, num_workers=2) 

PS:コードを確認しませんでした。

PS2:詳細はこちら:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler

+1

ありがとうございます! 1つの小さな発言:明らかにサンプラーはシャッフルと互換性がないので、同じ結果を得るためには、次のようにします:torch.utils.data.DataLoader(trainset、batch_size = 4、 sampler = SubsetRandomSampler(np.where(mask)[ 0])、シャッフル= False、num_workers = 2) –