2017-02-13 164 views
1

TensorFlowでscatter_nd関数を使用して、行列の行内の要素の順序を変更しようとしています。例えば、私は、コードを持っているとします行列の行と列内の要素を入れ替える - TensorFlow scatter_nd

indices = tf.constant([[1],[0]]) 
updates = tf.constant([ [5, 6, 7, 8], 
         [1, 2, 3, 4] ]) 
shape = tf.constant([2, 4]) 
scatter1 = tf.scatter_nd(indices, updates, shape) 
$ print(scatter1) = [[1,2,3,4] 
        [5,6,7,8]] 

これはupdates行列の行を並べ替えます。

行の順序を変更できるだけでなく、各行内の個々の要素の順序を変更したいと考えています。私はベクトル(ランク1のテンソルを)持っている場合、この例では動作します:

indices = tf.constant([[1],[0],[2],[3]]) 
updates = tf.constant([5, 6, 7, 8]) 
shape = tf.constant([4]) 
scatter2 = tf.scatter_nd(indices, updates, shape) 
$ print(scatter2) = [6,5,7,8] 

私は本当に気に私がscatter2に行っていたとして、scatter1の各行内の要素を交換できるようにすることですscatter1の行ごとに実行します。私はindicesのさまざまな組み合わせを試しましたが、サイズがscatter_nd関数によってスローされた矛盾したエラーが発生し続けます。

答えて

1

以下がscatter_nd

の出力与える
indices = tf.constant([[[0, 1], [0, 0], [0, 2], [0, 3]], 
         [[1, 1], [1, 0], [1, 2], [1, 3]]]) 
updates = tf.constant([ [5, 6, 7, 8], 
         [1, 2, 3, 4] ]) 
shape = tf.constant([2, 4]) 
scatter1 = tf.scatter_nd(indices, updates, shape) 
with tf.Session() as sess: 
    print(sess.run(scatter1)) 

用いて、各列の各行の要素を入れ替え:値が取られている場合indices座標の位置が定義
[[6 5 7 8] [2 1 3 4]]

をfrom updates、実際の座標はscatter1のどこに値が配置されるかを定義します。

この回答は数ヶ月遅れていますが、うまくいけばまだ役立ちます。

0

第1次元の順序を維持しながら第2次元の要素を入れ替えたいとします。

import tensorflow as tf 
sess = tf.InteractiveSession() 


def prepare_fd(fd_indices, sd_dims): 
    fd_indices = tf.expand_dims(fd_indices, 1) 
    fd_indices = tf.tile(fd_indices, [1, sd_dims]) 
    return fd_indices 

# define the updates 
updates = tf.constant([[11, 12, 13, 14], 
         [21, 22, 23, 24], 
         [31, 32, 33, 34]]) 
sd_dims = tf.shape(updates)[1] 

sd_indices = tf.constant([[1, 0, 2, 3], [0, 2, 1, 3], [0, 1, 3, 2]]) 
fd_indices_range = tf.range(0, limit=tf.shape(updates)[0]) 
fd_indices_custom = tf.constant([2, 0, 1]) 

# define the indices 
indices1 = tf.stack((prepare_fd(fd_indices_range, sd_dims), sd_indices), axis=2) 
indices2 = tf.stack((prepare_fd(fd_indices_custom, sd_dims), sd_indices), axis=2) 

# define the shape 
shape = tf.shape(updates) 

scatter1 = tf.scatter_nd(indices1, updates, shape) 
scatter2 = tf.scatter_nd(indices2, updates, shape) 

print(scatter1.eval()) 

# array([[12, 11, 13, 14], 
#  [21, 23, 22, 24], 
#  [31, 32, 34, 33]], dtype=int32) 

print(scatter2.eval()) 

# array([[21, 23, 22, 24], 
#  [31, 32, 34, 33], 
#  [12, 11, 13, 14]], dtype=int32) 

この例は役に立ちます。

関連する問題