2017-04-19 21 views
0

CNNTFにトレーニング中に事前にロードしたいのですが、簡単な実装は次のとおりです。しかし、私は1つの奇妙な現象を見つける。これは同期プロセスのようです。 1つのバッチデータをロードするための時間コストは、PRE_FETCHTrueまたはFalseであるかどうかにかかわらずほぼ同じです。TFの事前ロードデータ

class Demo(object): 
    def __init__(self): 
     self._name = 'demo' 

    def load_batch(self): 
     ... 

    def prefetch(self, func): 
     while True: 
      data = func() 
      self.queue.put(data) 

    def train(self): 
     input_data = tf.placeholder(tf.float32, shape=[B, H, W, C]) 
     optim_op = build_model(input_data) 

     if PRE_FETCH: 
      self.queue = Queue(30) 
      self.process = Process(target=self.prefetch, args=(self.load_batch)) 
      self.process.start() 
      def cleanup(): 
       self.process.terminate() 
       self.process.join() 
      import atexit 
      atexit.register(cleanup) 
     sess = tf.Session() 
     i = 1 
     while i < MAX_ITER_SIZE: 
      if PRE_FETCH: 
       start = time.time() 
       tmp = self.queue.get() 
       end = time.time() 
       print 'load data time: ', (end - start) 
      else: 
       start = time.time() 
       tmp = self.load_batch() 
       end = time.time() 
       print 'load data time: ', (end - start) 
      sess.run(optim_op, feed_dict={input_data: tmp} 

答えて

0

グラフにプレースホルダを使用してデータを読み込むには時間がかかります。プリロードを効果的にするには、pythonのキューとスレッドmecanismeをグラフ内のテンソルフローに置き換えることを検討する必要があります。テンソルフローのWebサイトでこれを行う方法に関する良いチュートリアルがあります:https://www.tensorflow.org/programmers_guide/reading_data

関連する問題