tf.train.SequenceExample
として保存されたシーケンス(シーケンス)シナリオのトレーニング例は、TFRecordWriter
という1つ(またはそれ以上)のファイルにあります。私はそれらを読んで、デコードし、それらのシャッフルされたバッチを私のネットワークに送りたいと思います。私はドキュメンテーションやいくつかのチュートリアルで苦労してきましたが、私はそのようなことから何もできませんでした。私は以下の自己完結型の例について取り組んでいます。バッチ処理とシャッフルパディングtf.train.SequenceExample
import random
import tensorflow as tf
from six.moves import xrange
MIN_LEN = 6
MAX_LEN = 12
NUM_EXAMPLES = 20
BATCH_SIZE = 3
PATH = 'ciaone.tfrecords'
MIN_AFTER_DEQUEUE = 10
NUM_THREADS = 2
SAFETY_MARGIN = 1
CAPACITY = MIN_AFTER_DEQUEUE + (NUM_THREADS + SAFETY_MARGIN) * BATCH_SIZE
def generate_example():
# fake examples which are just useful to have a quick visualization.
# The input is a sequence of random numbers.
# The output is a sequence made of those numbers from the
# input sequence which are greater or equal then the average.
length = random.randint(MIN_LEN, MAX_LEN)
input_ = [random.randint(0, 10) for _ in xrange(length)]
avg = sum([1.0 * item for item in input_])/len(input_)
output = [item for item in input_ if item >= avg]
return input_, output
def encode(input_, output):
length = len(input_)
example = tf.train.SequenceExample(
context=tf.train.Features(
feature={
'length': tf.train.Feature(
int64_list=tf.train.Int64List(value=[length]))
}),
feature_lists=tf.train.FeatureLists(
feature_list={
'input': tf.train.FeatureList(
feature=[
tf.train.Feature(
int64_list=tf.train.Int64List(value=[item]))
for item in input_]),
'output': tf.train.FeatureList(
feature=[
tf.train.Feature(
int64_list=tf.train.Int64List(value=[item]))
for item in output])
}
)
)
return example
def decode(example):
context_features = {
'length': tf.FixedLenFeature([], tf.int64)
}
sequence_features = {
'input': tf.FixedLenSequenceFeature([], tf.int64),
'output': tf.FixedLenSequenceFeature([], tf.int64)
}
ctx, seq = tf.parse_single_sequence_example(
example, context_features, sequence_features)
input_ = seq['input']
output = seq['output']
return input_, output
if __name__ == '__main__':
# STEP 1. -- generate a dataset.
with tf.python_io.TFRecordWriter(PATH) as writer:
for _ in xrange(NUM_EXAMPLES):
record = encode(*generate_example())
writer.write(record.SerializeToString())
with tf.Session() as sess:
queue = tf.train.string_input_producer([PATH])
reader = tf.TFRecordReader()
_, value = reader.read(queue)
input_, output = decode(value)
# HERE I AM STUCK!
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
try:
while True:
# do something...
except tf.errors.OutOfRangeError, e:
coord.request_stop(e)
finally:
coord.request_stop()
coord.join(threads)
coord.request_stop()
coord.join(threads)
どのように進めるのですか? ありがとうございます!
P.S.サイドリクエストとして:TensorFlowの入力パイプラインAPIをよりよく理解するためのリソースに関するポインタがありがたいです。
これは私の 'Example'sで行ったこととまったく同じですが、' SequenceExamples'をどのように扱うべきかを理解しなければなりません。 githubの問題を指摘してくれてありがとう! – petrux