2017-10-18 11 views
0

gather関数をpytorchで使用しようとしていますが、dimパラメータの役割を理解できません。ギャザー機能のパラメータ次元の影響

コード:

t = torch.Tensor([[1,2],[3,4]]) 
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]]))) 

出力:

1 2 
3 2 
[torch.FloatTensor of size 2x2] 

1にセット寸法:

print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))) 

出力は次のようになります。

1 1 
4 3 
[torch.FloatTensor of size 2x2] 

どうすれば、実際にgather関数が機能しますか?

答えて

2

ギャザー機能の仕組みを理解しました。

​​

出力になるようdimensionは、ゼロであるので:

| t[index[0, 0] 0] t[index[0, 1] 1] | 
| t[index[1, 0] 0] t[index[1, 1] 1] | 

dimensionが1に設定されている場合、出力はとなるであろう:

| t[0 index[0, 0]] t[0 index[0, 1]] | 
| t[1 index[1, 0]] t[1 index[1, 1]] | 

それで式は:

For a 3-D tensor the output is specified by: 

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 

リファレンス:http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather