2016-08-03 11 views
0

私は次のように定義されたテンソル持っている:私はまた、テンソルからフェッチする行のインデックスの配列いるTensorflowのテンソルから特定の行をフェッチする方法は?

temp_var = tf.Variable(initial_value=np.asarray([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12]])) 

を:今、私はそれらのtemp_varのサブセットを取るために

idx = tf.constant([0, 2]) 

をしたいですインデックスは、私は単一のインデックスやスライスを取るために、私たちは

ような何かを行うことができますことを知っている idx

、すなわち0

temp_var[single_row_index, :] 

または

temp_var[start:end, :] 

しかし、どのようにidx配列で示される行をフェッチするには? 何かのようなtemp_var[idx, :]

答えて

2

tf.gather() opは、必要なものを正確に実行します。行列(またはN次元のテンソルからの一般的な(N-1)次元のスライス)から行を選択します。あなたのケースでどのように動作するかは次のとおりです。

temp_var = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])) 
idx = tf.constant([0, 2]) 

rows = tf.gather(temp_var, idx) 

init = tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init) 

print(sess.run(rows)) # ==> [[1, 2, 3], [7, 8, 9]]