2017-11-01 23 views
2

多分誰かが私をここで助けることができます。私は、ネットワークの所与の出力pytorch nn.CrossEntropyLoss()でのクロスエントロピー損失

print output 
Variable containing: 
1.00000e-02 * 
-2.2739 2.9964 -7.8353 7.4667 4.6921 0.1391 0.6118 5.2227 6.2540  
-7.3584 
[torch.FloatTensor of size 1x10] 

、フォームのxは0と9の間の整数である

print lab 
Variable containing: 
x 
[torch.FloatTensor of size 1] 

所望ラベルの交差エントロピー損失を計算しようとしています。 pytorchのドキュメント(http://pytorch.org/docs/master/nn.html

criterion = nn.CrossEntropyLoss() 
loss = criterion(output, lab) 

によれば、これは動作するはずですが、残念ながら私は奇妙なエラー

を取得します

誰でも助けてくれますか?私は本当に混乱して、私が助けてくれると想像していたほとんどすべてを試しました。

ベスト

+1

あなたはLongTensorに 'ラボ' を変換してみてくださいことはできますか? –

答えて

3

確認してください。これはうまく損失をプリントアウトします

import torch 
import torch.nn as nn 
from torch.autograd import Variable 

output = Variable(torch.rand(1,10)) 
target = Variable(torch.LongTensor([1])) 

criterion = nn.CrossEntropyLoss() 
loss = criterion(output, target) 
print(loss) 

このコード:

Variable containing: 
2.4498 
[torch.FloatTensor of size 1] 
+0

ありがとうございました。問題は、ターゲット変数がLongテンソルでなければならず、私のコードにfloat型であるという問題でした。ありがとうございました! – Elias

関連する問題