2016-12-06 7 views
1

例えば:テンソルフローで多次元スライスを行う方法は?

array = [[1, 2, 3], [4, 5, 6]] 

slice = [[0, 0, 1], [0, 1, 2]] 

output = [[1, 1, 2], [4, 5,6]] 

私はarray[slice]を試してみたが、それは動作しませんでした。私もtf.gatherまたはtf.gather_ndを動作させることができませんでしたが、これらは当初は正しい機能のように思えました。これらはグラフ内のすべてのテンソルであることに注意してください。

スライスに従って配列内でこれらの値を選択するにはどうすればよいですか?

答えて

1

テンソルに次元を追加する必要があります。これはtf.packとすることができ、tf.gather_ndは問題なく使用できます。

import tensorflow as tf 

tensor = tf.constant([[1, 2, 3], [4, 5, 6]]) 
old_slice = tf.constant([[0, 0, 1], [0, 1, 2]]) 

# We need to add a dimension - we need a tensor of rank 2, 3, 2 instead of 2, 3 
dims = tf.constant([[0, 0, 0], [1, 1, 1]]) 
new_slice = tf.pack([dims, old_slice], 2) 
out = tf.gather_nd(tensor, new_slice) 

我々は以下のコードを実行した場合:

Input tensor: 
[[1 2 3] 
[4 5 6]] 
Correct param for gather_nd: 
[[[0 0] 
    [0 0] 
    [0 1]] 

[[1 0] 
    [1 1] 
    [1 2]]] 
Output: 
[[1 1 2] 
[4 5 6]] 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    run_tensor, run_slice, run_out = sess.run([tensor, new_slice, out]) 
    print 'Input tensor:' 
    print run_tensor 
    print 'Correct param for gather_nd:' 
    print run_slice 
    print 'Output:' 
    print run_out 

これは正しい出力を与える必要がありますが

関連する問題