2016-05-28 12 views
2

Torchでカスタム損失機能を実装するために必要な手順は何ですか?Torchでカスタム損失機能を追加する

updateOutputとupdateGradInputの実装を記述する必要があるようです。

これはすべてですか?だから、あなたは基本的に新しいクラスを作成します。

local CustomCriterion, parent = torch.class('CustomCriterion','nn.Criterion') 

と次の2つの関数を実装:

function CustomCriterion:updateOutput(input, target) 
function CustomCriterion:updateGradInput(input, target) 

は正しい、または行うことがよりがあることですか?

また、これらの機能はC言語で実装されていますが、Luaの実装も動作すると思いますが、少し遅いですが?

答えて

0

私は(擬似コードで)あなたはそれを記述するだけの方法を

--assuming input is partitioned in input_a,input_b 
--   target is accordingly partitionend in target_a, target_b 
f(input)=MSE(input_a,target_a)+ custom_sutff(input_b,target_b) 

かなりの数回のフォームの機能を実装しました。だから、私が知る限り、私はあなたの両方の質問に対する答えははいと思う。

基本的にnn/MSECriterion.luathisがこれをバックアップしているようです。

関連する問題