docsには、tf.train.shuffle_batch()
とenqueue_many=True
を使用する方法の簡単な説明があります。
X, y = tf.parse_single_example(...)
is_correctly_labelled = correctly_labelled(X, y)
X = tf.expand_dims(X, 0)
y = tf.expand_dims(y, 0)
empty = tf.constant([], tf.int32)
X, y = tf.cond(is_correctly_labelled,
lambda: [X, y],
lambda: [tf.gather(X, empty), tf.gather(y, empty)])
Xs, ys = tf.train.shuffle_batch(
[X, y], batch_size, capacity, min_after_dequeue,
enqueue_many=True)
tf.gather
がゼロサイズのスライスを取得するだけの方法である:例えば、グラフ操作を使用して誤ってラベル付けされているかどうかを判断することができる場合は、それほど(another SO answerから適応)のような結果をフィルタリングすることができます。 numpyでは、それはちょうどX[[], ...]
になります。
ありがとうございます - これはうまくいくようです。 expand_dimsの仕組みを説明できますか?私には1サイズの余分なサイズを追加するように見えますが、どのように使用されているのかは分かりません。 –
余分な1サイズの寸法は、 'enqueue_many = True'で' shuffle_batch() 'によって使用されます。最初の次元のサイズはエンキューするアイテムの数であり、 'is_correctly_labelled == True'の場合は1であり、そうでなければゼロです。それは理にかなっていますか? – Boris
ありがとう、ありがとう –