2017-08-06 18 views
0

私はdataテンソルオブメーター[B X N X 3]を持ち、indicesテンソルディメンション[B X M]を持っています。 indicesテンソルを使用してdataテンソルから[B X M X 3]テンソルを抽出したいと思います。テンソルを使用してテンソルフローの別のテンソルをインデックスする方法

new_data= []  
for i in range(B): 
     new_data.append(tf.gather(data[i], indices[i])) 
new_data= tf.stack(new_data) 

は、しかし、私はそれがこれを行うには正しい方法ではありません確信しています:

は私が働くこのコードを持っています。誰か良い方法を知っていますか? (私は何とかtf.gather_nd()を使うべきだと思いますが、どうしたらいいか分かりませんでした)

私は同様の質問に対するいくつかの回答を見てきましたhere。しかし私は私の問題の解決策を見つけることができませんでした。

答えて

0

あなたはこのようなコードでtf.gather_nd()を使用することができます。

import tensorflow as tf 

# B = 3 
# N = 4 
# M = 2 
# [B x N x 3] 
data = tf.constant([ 
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], 
    [[100, 101, 102], [103, 104, 105], [106, 107, 108], [109, 110, 111]], 
    [[200, 201, 202], [203, 204, 205], [206, 207, 208], [209, 210, 211]], 
    ]) 

# [B x M] 
indices = tf.constant([ 
    [0, 2], 
    [1, 3], 
    [3, 2], 
    ]) 

indices_shape = tf.shape(indices) 

indices_help = tf.tile(tf.reshape(tf.range(indices_shape[0]), [indices_shape[0], 1]) ,[1, indices_shape[1]]); 
indices_ext = tf.concat([tf.expand_dims(indices_help, 2), tf.expand_dims(indices, 2)], axis = 2) 
new_data = tf.gather_nd(data, indices_ext) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print('data') 
    print(sess.run(data)) 
    print('\nindices') 
    print(sess.run(indices)) 
    print('\nnew_data') 
    print(sess.run(new_data)) 

new_data次のようになります。

[[[ 0 1 2] 
    [ 6 7 8]] 

[[103 104 105] 
    [109 110 111]] 

[[209 210 211] 
    [206 207 208]]] 
+0

それは働きます! 'indices_help'、' indices_ext'、 'new_data')の3行のコードの背後にあるロジックは何かを説明できますか? –

+0

'tf.gather_nd'はスライスのインデックス(2つの座標 '')を必要としますが、 ''という1つの座標しか持たず、 'indices'配列の暗黙の座標は' b'です。したがって、インデックスの_batch_を追加することで、 'index 'のすべての要素を拡張する必要があります。これだよ。まず、indices_help'をバッチのインデックスで作成し、ソース 'indices'と連結します –

関連する問題