多分誰かが私をここで助けることができます。私は、ネットワークの所与の出力pytorch nn.CrossEntropyLoss()でのクロスエントロピー損失
print output
Variable containing:
1.00000e-02 *
-2.2739 2.9964 -7.8353 7.4667 4.6921 0.1391 0.6118 5.2227 6.2540
-7.3584
[torch.FloatTensor of size 1x10]
、フォームのxは0と9の間の整数である
print lab
Variable containing:
x
[torch.FloatTensor of size 1]
所望ラベルの交差エントロピー損失を計算しようとしています。 pytorchのドキュメント(http://pytorch.org/docs/master/nn.html)
criterion = nn.CrossEntropyLoss()
loss = criterion(output, lab)
によれば、これは動作するはずですが、残念ながら私は奇妙なエラー
を取得します誰でも助けてくれますか?私は本当に混乱して、私が助けてくれると想像していたほとんどすべてを試しました。
ベスト
あなたはLongTensorに 'ラボ' を変換してみてくださいことはできますか? –