2017-05-30 9 views
0

私はmy_tensorと呼ばれるテンソルのサイズが[batch_size, 5, 10]であるとします。 また、別のテンソルの[batch_size, 1]指数はselecterと呼ばれています。インデックス付きのテンソルに基づいてテンソルフローのTensorをフィルタリングする方法はありますか?

selecterに関してmy_tensorをフィルタリングして、新しいテンソルサイズ[batch_size, 10]を生成する、つまりselecterに含まれる値のみを選択したいとします。基本的には、それはちょっと中間のサイズ(サイズが5です)を減らしています。

tf.whereが正しい選択ですが、それについてはわかりません。 本当にありがとうございます!

答えて

1

解決策はtf.gather_ndとします。

tf.gather_nd(
    my_tensor, 
    tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1)) 

あなたが最初から1-Dであることをselecterを構築する場合は、squeezeを取り除くことができます。

+0

を。どうもありがとうございました! –

+0

あなたはテンソルのどのバージョンを使用していますか?私は1.3.0を持っていて、私のtf.gather_ndは軸パラメータを受け入れません。しかし、それを持っているtf.gatherがあります。 – omikron

0

代替ソリューションは、Tensorflow 1.3で動作します:完璧だ

max_selecter = tf.reduce_max(selecter) + 1 
my_tensor = tf.boolean_mask(
    outputs, 
    tf.logical_xor(
     tf.sequence_mask(my_tensor + 1, max_selecter), 
     tf.sequence_mask(my_tensor, max_selecter) 
    ) 
) 
関連する問題