2017-08-10 14 views
0

decode_csv関数はCSVのすべての行を読み取るようにするにはどうすればよいですか?Tensorflowのdecode_csvは1行だけを読み取る

現在、私のCSVファイルからGPUにデータをロードしようとしています。データはGPUに正常にロードされますが、実際には640行のCSVファイルの1行だけが実際に読み込まれます。どこが間違っていると思いますか?私はその後、場合

import tensorflow as tf 

with tf.device('/gpu:0'): 
    filename_queue = tf.train.string_input_producer(['dataset.csv']) 
    reader = tf.TextLineReader() 
    key, value = reader.read(filename_queue) 

    record_defaults = [['']]*121 
    all_columns = tf.decode_csv(value, record_defaults=record_defaults) 

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 
     # Start populating the filename queue. 
     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 

     # Iterate through all the columns 
     vals = [] 
     for x in range(121): 
      tmp = all_columns.pop() 
      myval = tmp.eval(session=sess) 
      vals.append(myval) 

     coord.request_stop() 
     coord.join(threads) 

...

>>> import numpy as np 
>>> vals = np.asarray(vals) 
>>> vals.shape 
(121,) 

は、私は私のCSVで640行のそれぞれにつき121列を持っています。 valsの値は私にはうまく見えますが、実際には640行すべてを読み取っているわけではありません。私はそれがとの関係を持っていると推測している:

all_columns = tf.decode_csv(value, record_defaults=record_defaults) 

答えて

0

Nvm。理解した。

明らかに、行データを抽出する方法に関して、sess.run()pop()の間に違いがあります。

私はそれゆえ、私のCSVファイルの640行と121個の列を持って起こる:

record_defaults = [['']]*121 

for x in range(640): 

これは主に、単にテスト目的のためにハードコードされていることに注意してください。以下の解決策:

import tensorflow as tf 

with tf.device('/gpu:0'): 
filename_queue = tf.train.string_input_producer(['../Datasets/CMU_face_images_dataset.csv']) 
    reader = tf.TextLineReader() 
    key, value = reader.read(filename_queue) 

    record_defaults = [['']]*121 
    all_columns = tf.decode_csv(value, record_defaults=record_defaults) 

    # TWO NEW LINES 
    name = all_columns[0] 
    data = all_columns[1:] 

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 
     # Start populating the filename queue. 
     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 

     vals = [] 
     names = [] 
     for x in range(640): 

      # THIS IS THE NEW LINE 
      _name, _val = sess.run([name, data]) 

      # OLD LINES 
      # tmp = all_columns.pop() 
      # myval = tmp.eval(session=sess) 
      # vals.append(myval) 

      names.append(_name) 
      vals.append(_val) 

     coord.request_stop() 
     coord.join(threads) 
関連する問題