2017-12-03 8 views
1

私のコードは次のようになります。あるクラスの戻り値を別のクラスで使用する方法は?

class RegularConv(nn.module): 
    def __init__(self): 
     super().__init__.() 
     self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True) 

    def forward (self, inputs): 
     inputs, indices = self.pool1(inputs) 
     return inputs, indices 

class UpsamplingConv(nn.module): 
    def __init__(self, indices): 
     super().__init__.() 
     self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0) 

    def forward (self, output): 
     output = self.unpool(output, indices, output_size=output_size) 
     return output 
class Net(nn.module): 
    def __init__(self): 
     super().__init__.() 
     self.layer1 = RegularConv() 
     self.layer2 = UpsamplingConv() 

    def forward(self, input): 
     input, indices = self.layer1(input) 
     input = self.layer2(input, indices) 

RegularConvクラスのforwardメソッドは、入力とインデックスを返し、UpsamplingConvクラスの前方方法は、これらのインデックスを使用する必要があります。

私は別のクラスでこれらのクラスを両方とも使用しますNet。あるクラスの戻り値を使用し、別のクラスで使用したい。私はあなたが、以下のようなリストとしてそれを返すようにしようと

def forward (self, inputs): 
     inputs, indices = self.pool1(sub_inputs) 
     return [inputs, indices] 

そしてアウトそのパラメータを呼び出すために、リストのインデックスを使用することができますUpsamplingConv

答えて

0

に引数としてRegularConvから返されたインデックスを渡したいです。

test = UpsamplingConv(RegularConv()[1]) 
関連する問題