2017-12-12 11 views
0

Simple multi-task network can be done here.しかし、私はこのようなものが欲しいenter image description here。 今、私は以下のようにモデルを構築:torch7でマルチタスク学習を行うには?

model = nn.Sequential() 
model:add(nn.Linear(3,5)) 
prl1 = nn.ConcatTable() 
prl1:add(nn.Linear(5,1)) 
prl2 = nn.ConcatTable() 
prl2:add(nn.Linear(5,1)) 
prl2:add(nn.Linear(5,1)) 
prl1:add(prl2) 
model:add(prl1) 

そして、私の出力は次のとおりです。

input = torch.rand(5,3) 
output = model:forward(input) 
output 
{ 
    1 : DoubleTensor - size: 5x1 
    2 : 
    { 
     1 : DoubleTensor - size: 5x1 
     2 : DoubleTensor - size: 5x1 
    } 
} 

どのように私は私の基準を構築する必要がありますか?

答えて

0

私は2つのステップを通してそれを把握するように見える:上記のネットワーク内代わりnn.ConcatTableの

1.Use nn.Concat、これは出力は、例えば、簡単なN×Mのテンソルであることを作りますnx.ConcatTableの代わりにnn.Concatを使用している間に、5x3テンソルが上記のネットワークに入ります。

2.NxMテンソルを取得した後、出力結果をTensorの各結果を含む単純なテーブルにするには、nn.ConcatTable、nn.Concat、nn.Selectの組み合わせを使用します。

th> output 
{ 
    1 : DoubleTensor - size: 5x3 
    2 : DoubleTensor - size: 5x2 
} 

model = nn.Sequential() 
model:add(nn.Linear(3,5)) 

prl = nn.ConcatTable() 

spl1 = nn.Concat(2) 

seq1 = nn.Sequential() 
seq1:add(nn.Select(2, 1)) 
seq1:add(nn.Reshape(1)) 

seq2 = nn.Sequential() 
seq2:add(nn.Select(2, 2)) 
seq2:add(nn.Reshape(1)) 

seq3 = nn.Sequential() 
seq3:add(nn.Select(2, 3)) 
seq3:add(nn.Reshape(1)) 

spl1:add(seq1) 
spl1:add(seq2) 
spl1:add(seq3) 
prl:add(spl1) 

spl2 = nn.Concat(2) 

seq4 = nn.Sequential() 
seq4:add(nn.Select(2, 4)) 
seq4:add(nn.Reshape(1)) 

seq5 = nn.Sequential() 
seq5:add(nn.Select(2, 5)) 
seq5:add(nn.Reshape(1)) 

spl2:add(seq4) 
spl2:add(seq5) 
prl:add(spl2) 

model:add(prl) 

input = torch.rand(5,3) 
output = model:forward(input) 

出力は次のようになります。ここでは

は、ステップ2の簡単な例であります
関連する問題