2
Torchでカスタム損失機能を実装するために必要な手順は何ですか?Torchでカスタム損失機能を追加する
updateOutputとupdateGradInputの実装を記述する必要があるようです。
これはすべてですか?だから、あなたは基本的に新しいクラスを作成します。
local CustomCriterion, parent = torch.class('CustomCriterion','nn.Criterion')
と次の2つの関数を実装:
function CustomCriterion:updateOutput(input, target)
function CustomCriterion:updateGradInput(input, target)
は正しい、または行うことがよりがあることですか?
また、これらの機能はC言語で実装されていますが、Luaの実装も動作すると思いますが、少し遅いですが?