2017-08-22 7 views
0

ここはおもちゃのモデルです。 backwardを正確に1回呼び出す前にモデルパラメータを印刷してから、モデルパラメータを再度印刷します。パラメータは変更されません。 backwardを呼び出した後にmodel:updateParameters(<learning_rate>)行を追加すると、パラメータの更新が表示されます。Torch:モデルパラメータはどのように更新されますか?

しかし、例のコードでは、たとえばhttps://github.com/torch/demos/blob/master/train-a-digit-classifier/train-on-mnist.luaのように、誰も実際にupdateParametersを呼び出していません。また、optim.sgdoptim.adam、またはnn.StochasticGradientのいずれかがupdateParametersと呼ばれることはありません。私はここで何が欠けていますか?パラメータは自動的にどのように更新されますか? updateParametersに電話しなければならない場合、どうしてその例はありませんか?

backward()は、パラメータを変更することになっていない
require 'nn' 
require 'optim' 

local model = nn.Sequential() 
model:add(nn.Linear(4, 1, false)) 
local params, grads = model:getParameters() 

local criterion = nn.MSECriterion() 
local inputs = torch.randn(1, 4) 
local labels = torch.Tensor{1} 

print(params) 

model:zeroGradParameters() 
local output = model:forward(inputs) 
local loss = criterion:forward(output, labels) 
local dfdw = criterion:backward(output, labels) 
model:backward(inputs, dfdw) 

-- With the line below uncommented, the parameters are updated: 
-- model:updateParameters(1000) 

print(params) 

答えて

1

、単にネットワークのすべてのパラメータに対する誤差関数の導関数を計算します。

一般的にトレーニングは一連のステップである:

repeat 
    local output = model:forward(input) --see what model predicts 
    local loss = criterion:forward(output, answer) --see how wrong it is 
    local loss_grad = criterion:backward(output, answer) --see where it is the most wrong 
    model:backward(input,loss_grad) --see how much each particular parameter of network is responsible for error 
    model:updateParameters(learningRate) --fix the parameters based on their wrongness 
    model:zeroGradParameters() --network parameters are different now, so old gradients are of no use now 
until is_user_satisfied() 

updateParametersは、ここでは最も簡単な最適化アルゴリズム(勾配降下)を実装しています。 そのような傾向がある場合は、独自の関数を代わりに使用することができます。理論的には、ネットワークストレージを介して明示的なループを実行して値を更新することがあります。実際に は、あなたは通常、あなたのすべての値および勾配の均質テンソルを得getParameters()

local model_parameters,model_parameters_gradient=model:getParameters() 

を呼び出します。これらのテンソルはネットワーク内のビューなので、それらの変化はネットワークに影響します。 ネットワーク内のどのポイントがどの値に対応しているかわからない場合がありますが、ほとんどのオプティマイザはそれを気にしません。

optim.sgd(
    function_to_return_error_and_its_gradients, 
    model_parameters, 
    optimizer_special_settings) 

詳細はデモで覆われているが、ここでは、オプティマイザは、それがネットワークへの書き込みアクセスを与えるパラメータとしてmodel_parametersを受けることに関連して次のように

demooptim.sgdの使用量があります。また、ドキュメントに明示的には記載されていませんが、source codeでは、オプティマイザはその入力テンソルの値を変更します(また、と同じテンソルを戻していることに注意してください)。

+0

本当に徹底的に、ありがとうございます。繰り返すために、 'updateParameters'はパラメータを手動で更新するために使用されます。 'updateParameters'を呼び出すと最適化が考慮されないため、最適化アルゴリズムはフラット化されたパラメータへの参照を更新します。 – gwg

関連する問題