2016-06-16 11 views
3

をロードするときに、レコードをスキップこれらのレコードをスキップします。ドロップ/ので、私は必要な私は私のトレーニングセット(誤ラベル例)で、いくつかの誤ったデータを発見したと私はソースを修正した一方で、私は同じデータセットを試していきたいと思いデータ

私はTFRecordReaderを使用していて、parse_single_example &シャッフルバッチをロードしています。どこかでフィルタを提供できますか?

答えて

3

docsには、tf.train.shuffle_batch()enqueue_many=Trueを使用する方法の簡単な説明があります。

X, y = tf.parse_single_example(...) 
is_correctly_labelled = correctly_labelled(X, y) 
X = tf.expand_dims(X, 0) 
y = tf.expand_dims(y, 0) 
empty = tf.constant([], tf.int32) 
X, y = tf.cond(is_correctly_labelled, 
       lambda: [X, y], 
       lambda: [tf.gather(X, empty), tf.gather(y, empty)]) 
Xs, ys = tf.train.shuffle_batch(
    [X, y], batch_size, capacity, min_after_dequeue, 
    enqueue_many=True) 

tf.gatherがゼロサイズのスライスを取得するだけの方法である:例えば、グラフ操作を使用して誤ってラベル付けされているかどうかを判断することができる場合は、それほど(another SO answerから適応)のような結果をフィルタリングすることができます。 numpyでは、それはちょうどX[[], ...]になります。

+0

ありがとうございます - これはうまくいくようです。 expand_dimsの仕組みを説明できますか?私には1サイズの余分なサイズを追加するように見えますが、どのように使用されているのかは分かりません。 –

+0

余分な1サイズの寸法は、 'enqueue_many = True'で' shuffle_batch() 'によって使用されます。最初の次元のサイズはエンキューするアイテムの数であり、 'is_correctly_labelled == True'の場合は1であり、そうでなければゼロです。それは理にかなっていますか? – Boris

+0

ありがとう、ありがとう –