1
私のコードは次のようになります。あるクラスの戻り値を別のクラスで使用する方法は?
class RegularConv(nn.module):
def __init__(self):
super().__init__.()
self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True)
def forward (self, inputs):
inputs, indices = self.pool1(inputs)
return inputs, indices
class UpsamplingConv(nn.module):
def __init__(self, indices):
super().__init__.()
self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0)
def forward (self, output):
output = self.unpool(output, indices, output_size=output_size)
return output
class Net(nn.module):
def __init__(self):
super().__init__.()
self.layer1 = RegularConv()
self.layer2 = UpsamplingConv()
def forward(self, input):
input, indices = self.layer1(input)
input = self.layer2(input, indices)
RegularConv
クラスのforwardメソッドは、入力とインデックスを返し、UpsamplingConv
クラスの前方方法は、これらのインデックスを使用する必要があります。
私は別のクラスでこれらのクラスを両方とも使用しますNet
。あるクラスの戻り値を使用し、別のクラスで使用したい。私はあなたが、以下のようなリストとしてそれを返すようにしようと
def forward (self, inputs):
inputs, indices = self.pool1(sub_inputs)
return [inputs, indices]
そしてアウトそのパラメータを呼び出すために、リストのインデックスを使用することができますUpsamplingConv