2016-10-31 4 views
1

トーチのnngraphパッケージのグラフモジュール(gModule)に新しいノードを追加するにはどうすればよいですか?私はadd関数を使用しようとしましたが、gModulesオブジェクトのモジュールスロットにノードを追加しました。ただし、出力は前の最後のノードから取得されます。トーチのグラフモジュールにレイヤーを追加する方法

簡体コード:

require "nn" 
require "nngraph" 

-- Function that builds a gModule 
function buildModule(input_size,hidden_size) 
    local x = nn.Identity()() 
    local out = x - nn.Linear(input_size,hidden_size) - nn.Tanh() 
    return nn.gModule({x},{out}) 
end 

network = buildModule(5,3) 
-- Additional layer to add 
l2 = nn.Linear(3,10) 
network:add(l2) 

-- Expected a tensor of size 10 but got one with size 3 
print(network:forward(torch.randn(5))) 

答えて

1

gModuleが実際に変異していることが想定されていません。それがサポートする事実:addは、実際にはnn.Containerの子クラスであるという副作用であり、設計上の決定ではありません。一般的に、gModuleを作成したら、内部構造を変更しないでください。内部属性を変更して、非常に効果的に動作させる必要があります。代わりに、 "上に"何かを追加する場合は、前のものを入力とする新しいコンテナを定義します。

-- Function that builds a gModule 
function buildModule(input_size,hidden_size) 
    local x = nn.Identity()() 
    local out = x - nn.Linear(input_size,hidden_size) - nn.Tanh() 
    return nn.gModule({x},{out}) 
end 

network = buildModule(5,3) 

new_network = nn.Sequential() 
new_network:add(network) 
new_network:add(nn.Linear(3,10)) 
+0

ありがとうございます – nist

関連する問題