2017-06-25 16 views
1

heart.csvファイルデータをバッチで読み取ろうとしています。 tensorflowウェブサイトからドキュメントに続いて、私は次のコードは、行単位で読み取るために働いてきたTensorflow 1.2.0でCSVファイルを読み取る

import tensorflow as tf 
filename_queue = tf.train.string_input_producer(["heart.csv"]) 
reader = tf.TextLineReader(skip_header_lines=1) 
_, csv_row = reader.read(filename_queue) 

record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]] 
sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults) 
features = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age] 

nof_examples = 10 
with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    while nof_examples > 0: 
     nof_examples -= 1 
     try: 
      data_features, data_chd = sess.run([features, chd]) 
#    data_features[4] = 1 if data_features[4] == 'Present' else 0 
      print(data_features, data_chd) 
     except tf.errors.OutOfRangeError: 
      coord.request_stop() 
      coord.join(threads) 
      break 
    coord.request_stop() 
    coord.join(threads) 

出力:

([160, 12.0, 5.73, 23.110001, 'Present', 49, 25.299999, 97.199997, 52], 1) 
([144, 0.0099999998, 4.4099998, 28.610001, 'Absent', 55, 28.870001, 2.0599999, 63], 1) 
([118, 0.079999998, 3.48, 32.279999, 'Present', 52, 29.139999, 3.8099999, 46], 0) 
([170, 7.5, 6.4099998, 38.029999, 'Present', 51, 31.99, 24.26, 58], 1) 
([134, 13.6, 3.5, 27.780001, 'Present', 60, 25.99, 57.34, 49], 1) 
([132, 6.1999998, 6.4699998, 36.209999, 'Present', 62, 30.77, 14.14, 45], 0) 
([142, 4.0500002, 3.3800001, 16.200001, 'Absent', 59, 20.809999, 2.6199999, 38], 0) 
([114, 4.0799999, 4.5900002, 14.6, 'Present', 62, 23.110001, 6.7199998, 58], 1) 
([114, 0.0, 3.8299999, 19.4, 'Present', 49, 24.860001, 2.49, 29], 0) 
([132, 0.0, 5.8000002, 30.959999, 'Present', 69, 30.110001, 0.0, 53], 1) 

が、私はバッチで読み込むしようとすると、tensorflowのドキュメントに示したように、私取得

TypeError: Cannot convert a list containing a tensor of dtype <dtype: 
float32'> to <dtype: 'int32'> (Tensor is: <tf.Tensor 'DecodeCSV_6:1' 
shape=() dtype=float32>) 

バッチ処理コード

import tensorflow as tf 
batch_size = 1 
def read_my_file_format(filename_queue): 
    reader = tf.TextLineReader(skip_header_lines=1) 
    _, csv_row = reader.read(filename_queue) 
    record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]] 
    sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults) 
    feature = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age] 
    label = [chd] 
    return feature, label 

def input_pipeline(filenames, batch_size, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(filenames, 
                num_epochs=num_epochs, 
                shuffle=True) 
    feature, label = read_my_file_format(filename_queue) 
    min_after_dequeue = 10000 
    capacity = min_after_dequeue + 3 * batch_size 
    feature_batch, label_batch = tf.train.shuffle_batch([feature, label], 
                 batch_size=batch_size, 
                 capacity=capacity, 
                 min_after_dequeue=min_after_dequeue) 
    return feature_batch, label_batch 

features, labels = input_pipeline(['heart.csv'], batch_size) 

with tf.Session() as sess: 
    tf.global_variables_initializer().run() 

    # start populating filename queue 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    try: 
     while not coord.should_stop(): 
      feature_batch, label_batch = sess.run([features, labels]) 
      print(feature_batch) 
    except tf.errors.OutOfRangeError: 
     print('Done training, epoch reached') 
    finally: 
     coord.request_stop() 
    coord.join(threads) 

テンソルフローを使用してCSVファイルを読み込むのは面倒ですが、ライブラリで重要なのは分散システムだと確信しています。私はそれが混乱していることを発見し、読み込みフィードパイプラインがcsvファイルのためにどのように働いていたかを把握するのに60分以上かかりました。ドキュメンテーションはより良くなり、より多くのビジュアルが必要になるかもしれません。

+0

は[MCVE]あなたが与えたコードですか? – boardrider

+0

@boardriderはいコードは完成しました – bicepjai

+0

しかし、それは最小ですか? – boardrider

答えて

1

私はコードを見ていたし、tf.train.shuffle_batchにおける内部関数の1行のすべてのテンソルは(tf.int32を小文字あなたに、最初の要素から推測される)同じdtypeを持っていることが必要と思われます。それらをストリングでデコードし、後で正しいタイプに変換することができます。それほど便利ではありません。

しかし、TensorFlow 1.2.0を使用しているときに私が助言するのは、新しいデータセットAPIを使用することです。これは、データを処理する新しい方法です(例:this answer参照)。ここでは、引用された回答に基づいて

作業をする必要があり、新しいAPIを使用した例です。

def read_row(csv_row): 
    record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]] 
    row = tf.decode_csv(csv_row, record_defaults=record_defaults) 
    return row[:-1], row[-1] 

def input_pipeline(filenames, batch_size): 
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data. 
    dataset = (tf.contrib.data.TextLineDataset(filenames) 
       .skip(1) 
       .map(lambda line: read_row(line)) 
       .shuffle(buffer_size=10) # Equivalent to min_after_dequeue=10. 
       .batch(batch_size)) 

    # Return an *initializable* iterator over the dataset, which will allow us to 
    # re-initialize it at the beginning of each epoch. 
    return dataset.make_initializable_iterator() 

iterator = input_pipeline(['heart.csv'], batch_size) 
features, labels = iterator.get_next() 


nof_examples = 10 
with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    sess.run(iterator.initializer) 
    while nof_examples > 0: 
     nof_examples -= 1 
     try: 
      data_features, data_labels = sess.run([features, labels]) 
      print(data_features) 
     except tf.errors.OutOfRangeError: 
      pass 
+0

コード編集を追加した後、正常に機能しました。ありがとう – bicepjai

関連する問題