2017-01-18 12 views
1

キューをより詳細に理解しようとしています。以下のコードを使用して、私はアルファベット順のリストをシャッフルしていないので、出力のコレクションはアルファベット順になると思います。これは初期の時代以外のすべての場合に当てはまると思われる。私は何かを誤解していますか?以下に上記の変更をテンソルフローキューからの順序

from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 

import time 
import tensorflow as tf 
import numpy as np 
import string 


# Basic model parameters as external flags. 
flags = tf.app.flags 
FLAGS = flags.FLAGS 
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') 
flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.') 
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.') 
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.') 
flags.DEFINE_integer('batch_size', 100, 'Batch size. ' 
        'Must divide evenly into the dataset sizes.') 
flags.DEFINE_string('train_dir', '/tmp/data', 
        'Directory to put the training data.') 
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data ' 
        'for unit testing.') 


def run_training(): 
    # Tell TensorFlow that the model will be built into the default Graph. 
    with tf.Graph().as_default(): 
    with tf.name_scope('input'): 
     # Input data 
     images_initializer = tf.placeholder(
      dtype=tf.int64, 
      shape=[52,1]) 
     input_images = tf.Variable(
      images_initializer, trainable=False, collections=[]) 

     image = tf.train.slice_input_producer(
      [input_images], num_epochs=2) 
     images = tf.train.batch(
      [image], batch_size=1) 

     alph_initializer = tf.placeholder(
      dtype=tf.string, 
      shape=[26,1]) 
     input_alph = tf.Variable(
      alph_initializer, trainable=False, collections=[]) 

     alph = tf.train.slice_input_producer(
      [input_alph], shuffle=False, capacity=26) 
     alphs = tf.train.batch(
      [alph], batch_size=1) 


    my_list = np.array(list(range(0,52))).reshape(52,1) 
    my_list_val = np.array(list(string.ascii_lowercase)).reshape(26,1) 


    # Create the op for initializing variables. 
    init_op = tf.initialize_all_variables() 

    # Create a session for running Ops on the Graph. 
    sess = tf.Session() 

    # Run the Op to initialize the variables. 
    sess.run(init_op) 
    sess.run(input_images.initializer, 
      feed_dict={images_initializer: my_list}) 
    sess.run(input_alph.initializer, 
      feed_dict={alph_initializer: my_list_val}) 

    sess.run(tf.local_variables_initializer()) 
    sess.run(tf.global_variables_initializer()) 
    # Start input enqueue threads. 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 

    # And then after everything is built, start the training loop. 
    collection = [] 
    try: 
     step = 0 
     while not coord.should_stop(): 
     start_time = time.time() 

     # Run one step of the model. 
     integer = sess.run(image) 
     #print("Integer val", integer) 

     char = sess.run(alph) 
     collection.append(char[0][0]) 
     print("String val", char) 


     duration = time.time() - start_time 

    except tf.errors.OutOfRangeError: 
     print('Saving') 
     print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) 
    finally: 
     # When done, ask the threads to stop. 
     coord.request_stop() 
    print(str(collection)) 


    # Wait for threads to finish. 
    coord.join(threads) 
    sess.close() 


def main(_): 
    run_training() 


if __name__ == '__main__': 
    tf.app.run() 

答えて

0

は私の混乱をクリアし

try: 
     step = 0 
     while not coord.should_stop(): 
     start_time = time.time() 

     # Run one step of the model. 
     integer = sess.run(images) 
     #print("Integer val", integer) 

     char = sess.run(alphs) 
     collection.append(char[0][0]) 
     print("String val", char) 


     duration = time.time() - start_time 

    except tf.errors.OutOfRangeError: 
     print('Saving') 
     print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) 
    finally: 
     # When done, ask the threads to stop. 
     coord.request_stop() 
    print(str(collection))