0
私はdata
テンソルオブメーター[B X N X 3]
を持ち、indices
テンソルディメンション[B X M]
を持っています。 indices
テンソルを使用してdata
テンソルから[B X M X 3]
テンソルを抽出したいと思います。テンソルを使用してテンソルフローの別のテンソルをインデックスする方法
new_data= []
for i in range(B):
new_data.append(tf.gather(data[i], indices[i]))
new_data= tf.stack(new_data)
は、しかし、私はそれがこれを行うには正しい方法ではありません確信しています:
は私が働くこのコードを持っています。誰か良い方法を知っていますか? (私は何とかtf.gather_nd()
を使うべきだと思いますが、どうしたらいいか分かりませんでした)
私は同様の質問に対するいくつかの回答を見てきましたhere。しかし私は私の問題の解決策を見つけることができませんでした。
それは働きます! 'indices_help'、' indices_ext'、 'new_data')の3行のコードの背後にあるロジックは何かを説明できますか? –
'tf.gather_nd'はスライスのインデックス(2つの座標 '')を必要としますが、 ''という1つの座標しか持たず、 'indices'配列の暗黙の座標は' b'です。したがって、インデックスの_batch_を追加することで、 'index 'のすべての要素を拡張する必要があります。これだよ。まず、indices_help'をバッチのインデックスで作成し、ソース 'indices'と連結します –