2017-04-12 5 views
1

tf.train.SequenceExampleとして保存されたシーケンス(シーケンス)シナリオのトレーニング例は、TFRecordWriterという1つ(またはそれ以上)のファイルにあります。私はそれらを読んで、デコードし、それらのシャッフルされたバッチを私のネットワークに送りたいと思います。私はドキュメンテーションやいくつかのチュートリアルで苦労してきましたが、私はそのようなことから何もできませんでした。私は以下の自己完結型の例について取り組んでいます。バッチ処理とシャッフルパディングtf.train.SequenceExample

import random 

import tensorflow as tf 

from six.moves import xrange 


MIN_LEN = 6 
MAX_LEN = 12 
NUM_EXAMPLES = 20 
BATCH_SIZE = 3 
PATH = 'ciaone.tfrecords' 
MIN_AFTER_DEQUEUE = 10 
NUM_THREADS = 2 
SAFETY_MARGIN = 1 
CAPACITY = MIN_AFTER_DEQUEUE + (NUM_THREADS + SAFETY_MARGIN) * BATCH_SIZE 


def generate_example(): 
    # fake examples which are just useful to have a quick visualization. 
    # The input is a sequence of random numbers. 
    # The output is a sequence made of those numbers from the 
    # input sequence which are greater or equal then the average. 
    length = random.randint(MIN_LEN, MAX_LEN) 
    input_ = [random.randint(0, 10) for _ in xrange(length)] 
    avg = sum([1.0 * item for item in input_])/len(input_) 
    output = [item for item in input_ if item >= avg] 
    return input_, output 


def encode(input_, output): 
    length = len(input_) 
    example = tf.train.SequenceExample(
     context=tf.train.Features(
      feature={ 
       'length': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[length])) 
      }), 
     feature_lists=tf.train.FeatureLists(
      feature_list={ 
       'input': tf.train.FeatureList(
        feature=[ 
         tf.train.Feature(
          int64_list=tf.train.Int64List(value=[item])) 
         for item in input_]), 
       'output': tf.train.FeatureList(
        feature=[ 
         tf.train.Feature(
          int64_list=tf.train.Int64List(value=[item])) 
         for item in output]) 
      } 
     ) 
    ) 
    return example 


def decode(example): 
    context_features = { 
     'length': tf.FixedLenFeature([], tf.int64) 
    } 
    sequence_features = { 
     'input': tf.FixedLenSequenceFeature([], tf.int64), 
     'output': tf.FixedLenSequenceFeature([], tf.int64) 
    } 
    ctx, seq = tf.parse_single_sequence_example(
     example, context_features, sequence_features) 
    input_ = seq['input'] 
    output = seq['output'] 
    return input_, output 

if __name__ == '__main__': 
    # STEP 1. -- generate a dataset. 
    with tf.python_io.TFRecordWriter(PATH) as writer: 
     for _ in xrange(NUM_EXAMPLES): 
      record = encode(*generate_example()) 
      writer.write(record.SerializeToString()) 

    with tf.Session() as sess: 
     queue = tf.train.string_input_producer([PATH]) 
     reader = tf.TFRecordReader() 
     _, value = reader.read(queue) 
     input_, output = decode(value) 

     # HERE I AM STUCK! 

     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 
     sess.run(tf.local_variables_initializer()) 
     sess.run(tf.global_variables_initializer()) 
     try: 
      while True: 
       # do something... 
     except tf.errors.OutOfRangeError, e: 
      coord.request_stop(e) 
     finally: 
      coord.request_stop() 
      coord.join(threads) 
     coord.request_stop() 
     coord.join(threads) 

どのように進めるのですか? ありがとうございます!

P.S.サイドリクエストとして:TensorFlowの入力パイプラインAPIをよりよく理解するためのリソースに関するポインタがありがたいです。

答えて

1

SequenceExampleではなくExampleを処理していた場合は、デコードされたテンソルにtf.train.shuffle_batchを追加するのは簡単です。

_, value = reader.read(queue) 
input_, output = decode(value) 
batch_input, batch_output = tf.train.shuffle_batch([input_, output], 
    batch_size=BATCH_SIZE, capacity=CAPACITY, 
    min_after_sequeue=MIN_AFTER_DEQUEUE) 

しかし、シャッフルバッチは、あなたが渡すテンソルはここに真実ではない静的な形状を持っていることが必要です。可変形状テンソルの場合は、代わりにtf.train.batchdynamic_pad=Trueを使用できます。これはバッチ処理(およびパディング)を処理しますが、サンプルをシャッフルしません。残念ながら、shuffle_batchdynamic_pad引数をとらない。

tf.train.batchの呼び出し前にRandomShuffleQueueを追加することができます回避策described hereあります:

inputs = decode(value) 
dtypes = list(map(lambda x: x.dtype, inputs)) 
shapes = list(map(lambda x: x.get_shape(), inputs)) 
queue = tf.RandomShuffleQueue(CAPACITY, MIN_AFTER_DEQUEUE, dtypes) 
enqueue_op = queue.enqueue(inputs) 
qr = tf.train.QueueRunner(queue, [enqueue_op] * NUM_THREADS) 
tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, qr) 
inputs = queue.dequeue() 
for tensor, shape in zip(inputs, shapes): 
    tensor.set_shape(shape) 

# Now you can use tf.train.batch with dynamic_pad=True, and the order in which 
# it enqueues elements will be permuted because of RandomShuffleQueue. 
batch_input, batch_output = tf.train.batch(inputs, batch_size, capacity=capacity, 
           dynamic_pad=True, name=name) 

このパターンの例がありますが(Googleのマゼンタプロジェクトで)hereを実装するには。

+0

これは私の 'Example'sで行ったこととまったく同じですが、' SequenceExamples'をどのように扱うべきかを理解しなければなりません。 githubの問題を指摘してくれてありがとう! – petrux