2017-03-06 5 views
0

私は現在トーチ7を使用していますが、特にクロスエントロピーエラー機能をカスタマイズする必要があります。Torchの損失機能をカスタマイズするにはどうすればよいですか?

クロスエントロピーエラー関数にいくつかのパラメータを追加することを考えていますが、どの部分を修正するべきかを見つけることができませんでした。

私はCrossEntropyCriterion.luaを見ましたが、このファイルにはまだ何も表示されません。

誰でも私の方程式はどこですか?どのファイルを変更する必要がありますか?

+0

可能な複製(http://stackoverflow.com/questions/33648796/add -my-custom-loss-function-tochch) –

答えて

0

損失関数をカスタマイズするには、__init,updateOutputおよびupdateGradInputのメソッドを変更する必要があります。

  • __init

    はあなたが:backward()を使用する場合
  • updateGradInputが呼び出されます、あなたの基準で:forward()メソッドを使用する場合
  • updateOutputが呼ばれるクラスの初期化関数であり、それはあなたの基準の勾配である

カスタマイズされた基準の構造は次のようになります。

local yourCriterion, parent = torch.class('nn.yourCriterion', 'nn.Criterion') 

function yourCriterion:__init(your_parameters): 
    parent.__init(self) 
    ... (you can add as many parameters as you want to your criterion 
     and give them the name your prefer) 
    self.parameters = your_parameters 


function yourCriterion:updateOutput(input) 
    ... (your criterion code here) 
    return value_of_the_criterion 

function yourCriterion:updateGradInput(input): 
    ... (your criterion gradient code here) 
    return gradient 

[EDIT]:あなたがここにクロスエントロピー基準のコードを見つけることができますhttps://github.com/torch/nn/blob/master/CrossEntropyCriterion.lua

[トーチに私のカスタムの損失関数を追加]の
+0

申し訳ありませんが、クロスエントロピー基準へのリンクを忘れています。私はそれを追加する私の答えを編集しました。 –

+0

あなたの答えに感謝:)私は試してみます –

関連する問題