2017-12-19 8 views
0

Tensorflow Dataset APIを使用して、フォルダごとに1つのバッチ(イメージを含む各フォルダ)を作成します。私は、次の簡単なコードスニペットを持っている:ListDirectoryを使用したTensorflowデータセットAPI

import tensorflow as tf 
import os 
import pdb 

def parse_file(filename): 
    image_string = tf.read_file(filename) 
    image_decoded = tf.image.decode_png(image_string) 
    image_resized = tf.image.resize_images(image_decoded, [48, 48]) 
    return image_resized #, label 

def parse_dir(frame_dir): 
    filenames = tf.gfile.ListDirectory(frame_dir) 
    batch = tf.constant(5) 
    batch = tf.map_fn(parse_file, filenames) 
    return batch 

directory = "../Detections/NAC20171125" 
# filenames = tf.constant([os.path.join(directory, f) for f in os.listdir(directory)]) 
frames = [os.path.join(directory, str(f)) for f in range(10)] 


dataset = tf.data.Dataset.from_tensor_slices((frames)) 
dataset = dataset.map(parse_dir) 

dataset = dataset.batch(256) 
iterator = dataset.make_initializable_iterator() 
next_element = iterator.get_next() 


with tf.Session() as sess: 
    sess.run(iterator.initializer) 
    while True: 
     try: 
      batch = sess.run(next_element) 
      print(batch.shape) 
     except tf.errors.OutOfRangeError: 
      break 

しかし、tf.gfile.ListDirectory(parse_dirでは)代わりテンソルの通常の文字列を期待しています。だから今エラーは

TypeError: Expected binary or unicode string, got <tf.Tensor 'arg0:0' shape=() dtype=string> 

これを解決する簡単な方法はありますか?ここ

答えて

2

問題はtf.gfile.ListDirectory()はPythonの文字列を期待するPythonの関数であり、parse_dir()frame_dir引数がtf.Tensorであるということです。したがって、ディレクトリ内のファイルをリストするには同等のTensorFlow操作が必要で、tf.data.Dataset.list_files()tf.matching_files()に基づく)は、おそらく最も近いものです。

directory = "../Detections/NAC20171125" 
frames = [os.path.join(directory, str(f)) for f in range(10)] 

# Start with a dataset of directory names. 
dataset = tf.data.Dataset.from_tensor_slices(frames) 

# Maps each subdirectory to the list of files in that subdirectory and flattens 
# the result. 
dataset = dataset.flat_map(lambda dir: tf.data.Dataset.list_files(dir + "/*")) 

# Maps each filename to the parsed and resized image data. 
dataset = dataset.map(parse_file) 

dataset = dataset.batch(256) 

iterator = dataset.make_initializable_iterator() 
next_element = iterator.get_next() 
関連する問題