2017-04-05 7 views
4

イテレータから塗りつぶされたキューを作成したいと思います。以下MWEではしかし、常に同じ値がエンキューされていますPythonイテレータからのキューの塗りつぶし

import tensorflow as tf 
import numpy as np 

# data 
imgs = [np.random.randn(i,i) for i in [2,3,4,5]] 

# iterate through data infinitly 
def data_iterator(): 
    while True: 
     for img in imgs: 
      yield img 

it = data_iterator() 

# create queue for data 
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64]) 

# feed next element from iterator 
enqueue_op = q.enqueue(list(next(it))) 

# setup queue runner 
numberOfThreads = 1 
qr = tf.train.QueueRunner(q, [enqueue_op] * numberOfThreads) 
tf.train.add_queue_runner(qr) 

# dequeue 
dequeue_op = q.dequeue() 
dequeue_op = tf.Print(dequeue_op, data=[dequeue_op], message="dequeue()") 

# We start the session as usual ... 
with tf.Session() as sess: 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(10): 
     data = sess.run(dequeue_op) 
     print(data) 
. 
    coord.request_stop() 
    coord.join(threads) 

私は必ずしもfeed_dictを使用する必要がありますか?はいの場合は、どのようにQueueRunnerと組み合わせて使用​​する必要がありますか?

答えて

3

(次の(それは))、正確に1回のリストを実行します

enqueue_op = q.enqueue(list(next(it))) 

tensorflowを実行しています。その後、最初のリストを保存し、enqueue_opを実行するたびにqに追加します。これを避けるには、プレースホルダを使用する必要があります。給餌プレースホルダはtf.train.QueueRunnerと互換性がありません。代わりにこれを使用してください:

import tensorflow as tf 
import numpy as np 
import threading 

# data 
imgs = [np.random.randn(i,i) for i in [2,3,4,5]] 

# iterate through data infinitly 
def data_iterator(): 
    while True: 
     for img in imgs: 
      yield img 

it = data_iterator() 

# create queue for data 
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64]) 

# feed next element from iterator 

img_p = tf.placeholder(tf.float64, [None, None]) 
enqueue_op = q.enqueue(img_p) 

dequeue_op = q.dequeue() 


with tf.Session() as sess: 
    coord = tf.train.Coordinator() 

    def enqueue_thread(): 
     with coord.stop_on_exception(): 
      while not coord.should_stop(): 
       sess.run(enqueue_op, feed_dict={img_p: list(next(it))}) 

    numberOfThreads = 1 
    for i in range(numberOfThreads): 
     threading.Thread(target=enqueue_thread).start() 



    for i in range(3): 
     data = sess.run(dequeue_op) 
     print(data)