1
batch_images
:形状の4Dテンソル(B, H, W, C)
x
:形状の三次元テンソル(B, H, W)
y
:形状の三次元テンソル(B, H, W)
ゴール
どの形状B, H, W, C
の4Dテンソルを得るためx
とy
座標を用いbatch_images
にIインデックス缶。つまり、各バッチについて、そして各ペアについて、(x, y)
テンソル形状C
を取得したいと考えています。
numpyでは、これは例えばinput_img[np.arange(B)[:,None,None], y, x]
を使用して達成されますが、テンソルフローでは機能しないようです。
これまでの形状(B, H, W)
の間違ったテンソルを返す
def get_pixel_value(img, x, y):
"""
Utility function to get pixel value for
coordinate vectors x and y from a 4D tensor image.
"""
H = tf.shape(img)[1]
W = tf.shape(img)[2]
C = tf.shape(img)[3]
# flatten image
img_flat = tf.reshape(img, [-1, C])
# flatten idx
idx_flat = (x*W) + y
return tf.gather(img_flat, idx_flat)
私の試み。