私は私が何をしたいのか(1000, 10, 40)
MATMUL
を想定し、(1000, 50, 10)
Fy
形状(num, N, h)
を想定し、(1, 20, 40, 50)
Fx
形状(num, w, N)
を想定し、3テンソル
X
に形状(1, c, h, w)
を持っているがmatmul
をFy * (X * Fx)
(*
を意味しています)
X * Fx
形(num, c, h, N)
、(1000, 20, 40, 10)
Fy * (X * Fx)
形状(num, c, N, N)
、より良い方法を探してみてください 、私はそれを
を行うためにtf.tile
とtf.expand_dims
を使用していますが、私はそれが多くのメモリを使用すると思います(tile
コピーデータは右?)、および遅い(1000, 20, 10, 10)
を想定より速く、がmythical einsum
ため
# X: (1, c, h, w)
# Fx: (num, w, N)
# Fy: (num, N, h)
X = tf.tile(X, [tf.shape(Fx)[0], 1, 1, 1]) # (num, c, h, w)
Fx_ex = tf.expand_dims(Fx, axis=1) # (num, 1, w, N)
Fx_ex = tf.tile(Fx_ex, [1, c, 1, 1]) # (num, c, w, N)
tmp = tf.matmul(X, Fxt_ex) # (num, c, h, N)
Fy_ex = tf.expand_dims(Fy, axis=1) # (num, 1, N, h)
Fy_ex = tf.tile(Fy_ex, [1, c, 1, 1]) # (num, c, N, h)
res = tf.matmul(Fy_ex, tmp) # (num, c, N, N)
はい、私はこれまでにこれを見たことがありませんでした。理解するのが少し難しく、アイデアを理解するためにひねります – xxi
awwwwwesome、それは巨大なスピードの改善、多くのありがとう – xxi