2017-03-31 4 views
4

に私はTensorflowに次のnumpyのコードやりたいような指標の2D-インデックスを使用してテンソルをする:Tensorflow:どのようにnumpyの

input = tf.constant([[1, 2, 3], 
        [4, 5, 6], 
        [7, 8, 9]]) 

I:入力など、与えられた

input = np.array([[1,2,3] 
        [4,5,6] 
        [7,8,9]]) 
index1 = [0,1,2] 
index2 = [2,2,0] 
output = input[index1, index2] 
>> output 
[3,6,7] 

を次のように試してみましたが、オーバーシュートのようです:

index3 = tf.range(0, input.get_shape()[0])*input.get_shape()[1] + index2 
output = tf.gather(tf.reshape(input, [-1]), index3) 
sess = tf.Session() 
sess.run(output) 
>> [3,6,7] 

これは、最初のインデックスが[0,1,2 ](例えば、本当に長くて醜い以外にも)[0,0,2]では実行できません。

もっと簡単な構文がありますか?よりテンソル/ピジョンがありますか?

答えて

4

次のようにtf.gather_nd(tf.gather_nd official doc)を使用してそれを行うことができます。

import tensorflow as tf 
inp = tf.constant([[1, 2, 3], 
        [4, 5, 6], 
        [7, 8, 9]]) 
res=tf.gather_nd(inp,list(zip([0,1,2],[2,2,0]))) 
sess = tf.Session() 
sess.run(res) 

結果はtf.gather_ndの使用に関するarray([3, 6, 7])

+0

ありがとう、ありがとう! – ylnor

3

どうですか?

In [61]: input = tf.constant([[1, 2, 3], 
    ...:      [4, 5, 6], 
    ...:      [7, 8, 9]]) 

In [63]: row_idx = tf.constant([0, 1, 2]) 
In [64]: col_idx = tf.constant([2, 2, 0]) 
In [65]: coords = tf.transpose(tf.pack([row_idx, col_idx])) 

In [67]: sess = tf.InteractiveSession() 

In [68]: tf.gather_nd(input, coords).eval() 
Out[68]: array([3, 6, 7], dtype=int32) 
関連する問題