2016-03-07 23 views
8

TensorFlowでインデックスを作成する方法に関する基本的な質問があります。 numpyのでTensorFlow:テンソルを使用して別のテンソルをインデックスする

x = np.asarray([1,2,3,3,2,5,6,7,1,3]) 
e = np.asarray([0,1,0,1,1,1,0,1]) 
#numpy 
print x * e[x] 

私はTensorFlowでこれを行うことができますどのように

[1 0 3 3 0 5 0 7 1 3] 

を得ることができますか?

x = np.asarray([1,2,3,3,2,5,6,7,1,3]) 
e = np.asarray([0,1,0,1,1,1,0,1]) 
x_t = tf.constant(x) 
e_t = tf.constant(e) 
with tf.Session(): 
    ???? 

ありがとう!幸いにも、あなたは約求めている正確な場合がtf.gather()でTensorFlowでサポートされて

+0

http://stackoverflow.com/questions/33736795/tensorflow-numpy-like-tensor-indexing?rq=1 あなたが尋ねているのはそれですか? – Alleo

答えて

19

result = x_t * tf.gather(e_t, x_t) 

with tf.Session() as sess: 
    print sess.run(result) # ==> 'array([1, 0, 3, 3, 0, 5, 0, 7, 1, 3])' 

tf.gather() opがNumPy's advanced indexing未満強力です:それだけでその0番目の次元にテンソルの完全なスライスを抽出することをサポートしてい。より一般的なインデックス作成のサポートがリクエストされており、this GitHub issueで追跡されています。

+1

ありがとうございます! – user200340

+3

Tensorflowにはより強力な 'tf.gather_nd()' opが追加されました。 – fritzo

関連する問題