、あなたはtf.while_loop
を使用することができます。
import numpy as np
import tensorflow as tf
y = tf.constant([1, 2, 1, 1])
x = tf.constant([[[1,2],[3,4],[5,6]],
[[1,2],[3,4],[5,6]],
[[1,2],[3,4],[5,6]],
[[1,2],[3,4],[5,6]]], dtype=np.float32)
y_len = tf.shape(y)[0]
idx = tf.zeros((), dtype=tf.int32)
res = tf.zeros((0,2))
_, res = tf.while_loop(
lambda idx, res: idx < y_len,
lambda idx, res: (idx + 1, tf.concat([res, tf.reduce_mean(x[idx, :y[idx]], axis=0)[tf.newaxis]], axis=0)),
[idx, res],
shape_invariants=[idx.get_shape(), tf.TensorShape([None, 2])])
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
res.eval()
# returns
# array([[ 1., 2.],
# [ 2., 3.],
# [ 1., 2.],
# [ 1., 2.]], dtype=float32)
y
の長さは、グラフ構築時に知られている以下の一般的なケースでは、あなた自身tf.while_loop
の使用とループを惜しまことができ(y
に多くの要素がある場合は大きなグラフが表示される可能性があります)。
y_len = y.shape.num_elements()
res = tf.Variable(np.zeros((y_len, 2), dtype=np.float32))
res = tf.tuple([res] + [res[idx].assign(tf.reduce_mean(x[idx, :y[idx]], axis=0))
for idx in range(y_len)])[0]
あなたは、単にtf.while_loop
ではない一般の場合とは異なり、アップデートをカスケードでき注:
y_len = y.shape.num_elements()
res = tf.zeros((0,2))
for idx in range(y_len):
res = tf.concat([res, tf.reduce_mean(x[idx, :y[idx]], axis=0)[tf.newaxis]], axis=0)
が、今アップデートが順次発生する必要があります。前者の解決策では、発生する各行の更新は独立していて、並行して実行することができます。
ありがとうございます!コードの2番目のブロックは私にFailedPreconditionErrorを与えています。 – Alex
私はそれを間違った場所で実行しました。問題は修正されました。ありがとう。 – Alex