2017-09-23 4 views
0

ネットワークへの入力は、寸法がn mのRGB​​画像です。出力はどのようにしてn mの寸法になりますか?PyTorchのnxmラベルへのn * m * 3入力画像

class Net(nn.Module): 
    def __init__(self): 
     super(Net, self).__init__() 
     self.conv1 = nn.Conv2d(3, 20, kernel_size = 5) 
     self.conv2 = nn.Conv2d(20, 50, kernel_size = 3) 
     self.conv3 = nn.ConvTranspose2d(50,20, kernel_size = 5) 
     self.conv4 = nn.ConvTranspose2d(20,1, kernel_size = 3) 

    def forward(self, x): 
     x = F.relu(self.conv1(x)) 
     x = F.relu(self.conv2(x)) 
     x = F.relu(self.conv3(x)) 
     x = F.relu(self.conv4(x)) 
    return x 

現在、1 * n * mを出力しています。どのようにn * mを出力できますか?

答えて

0

テンソルのサイズを変更して同じ数の要素にする場合は、通常torch.viewを使用します。

さらに簡単な解決策があります。torch.squeezeは、サイズ1のすべての寸法が削除されたTensorを返します。

関連する問題