2017-10-29 17 views
0

私は形状の点のxyz座標の配列を持っています(nsamples, npoints, 3)テンソルフローの第2次元のインデックス化

私は、起点に最も近い2つの点を選択するテンソルフローグラフを作成しようとしています。

私はこれまでのところ

r2 = tf.reduce_sum(tf.pow(centeredxyz, 2), axis=2) 
idx = tf.nn.top_k(-r2, 2)[1] 

を得ているこれは私の形状(nsamples, 2)の2Dマトリックスすなわち

[[3, 15], [6, 2], ...]の形で最も近い二つの点のインデックスを提供します。

centeredxyzからこれらのインデックスを使用してポイントを戻すにはどうすればよいですか? tf.gather_ndを試しましたが、最初のサンプルの3番目と15番目の点、2番目のサンプルの6番目と2番目などを求めているうちに、3番目のサンプルの15番目の点の座標を求めていると考えています。

私はtf.rangeを作成し、[[0, 3], [0, 15], [1, 6], [1, 2], ...]を得るために、インデックスにそれを積み重ねてみましたが、それはそれで現在、私は次のしようとする内容へと全く無知だ未知の寸法ValueError: Cannot convert an unknown Dimension to a Tensor: ?

の範囲を作成することはできませんので、それが失敗しました。

答えて

0

私は醜いバージョンを一緒にパッチを当てることができました。脳を傷つけますが、うまくいくようです。

def gather_second_multicol(data, idx): 
    nsamples = tf.shape(idx)[0] 
    nselcol = tf.shape(idx)[1] 
    idx = tf.reshape(idx, [-1, 1]) 
    range = tf.range(nsamples) 
    range = tf.tile(tf.expand_dims(range, 0), [nselcol, 1]) 
    range = tf.transpose(range) 
    range = tf.reshape(range, [-1, 1]) 
    idx = tf.concat([range, idx], 1) 
    gath = tf.gather_nd(data, idx) 
    return tf.reshape(gath, [-1, nselcol, 3]) 

def get_closest(centeredxyz): 
    r2 = tf.reduce_sum(tf.pow(centeredxyz, 2), axis=2) 
    idx = tf.nn.top_k(-r2, 2)[1] 
    closest = gather_second_multicol(centeredxyz, idx) 
    return closest 
関連する問題