0
gather関数をpytorchで使用しようとしていますが、dim
パラメータの役割を理解できません。ギャザー機能のパラメータ次元の影響
コード:
t = torch.Tensor([[1,2],[3,4]])
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))
出力:
1 2
3 2
[torch.FloatTensor of size 2x2]
1にセット寸法:
print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))
出力は次のようになります。
1 1
4 3
[torch.FloatTensor of size 2x2]
どうすれば、実際にgather
関数が機能しますか?