heart.csvファイルデータをバッチで読み取ろうとしています。 tensorflowウェブサイトからドキュメントに続いて、私は次のコードは、行単位で読み取るために働いてきたTensorflow 1.2.0でCSVファイルを読み取る
import tensorflow as tf
filename_queue = tf.train.string_input_producer(["heart.csv"])
reader = tf.TextLineReader(skip_header_lines=1)
_, csv_row = reader.read(filename_queue)
record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]]
sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults)
features = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age]
nof_examples = 10
with tf.Session() as sess:
tf.global_variables_initializer().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
while nof_examples > 0:
nof_examples -= 1
try:
data_features, data_chd = sess.run([features, chd])
# data_features[4] = 1 if data_features[4] == 'Present' else 0
print(data_features, data_chd)
except tf.errors.OutOfRangeError:
coord.request_stop()
coord.join(threads)
break
coord.request_stop()
coord.join(threads)
出力:
([160, 12.0, 5.73, 23.110001, 'Present', 49, 25.299999, 97.199997, 52], 1)
([144, 0.0099999998, 4.4099998, 28.610001, 'Absent', 55, 28.870001, 2.0599999, 63], 1)
([118, 0.079999998, 3.48, 32.279999, 'Present', 52, 29.139999, 3.8099999, 46], 0)
([170, 7.5, 6.4099998, 38.029999, 'Present', 51, 31.99, 24.26, 58], 1)
([134, 13.6, 3.5, 27.780001, 'Present', 60, 25.99, 57.34, 49], 1)
([132, 6.1999998, 6.4699998, 36.209999, 'Present', 62, 30.77, 14.14, 45], 0)
([142, 4.0500002, 3.3800001, 16.200001, 'Absent', 59, 20.809999, 2.6199999, 38], 0)
([114, 4.0799999, 4.5900002, 14.6, 'Present', 62, 23.110001, 6.7199998, 58], 1)
([114, 0.0, 3.8299999, 19.4, 'Present', 49, 24.860001, 2.49, 29], 0)
([132, 0.0, 5.8000002, 30.959999, 'Present', 69, 30.110001, 0.0, 53], 1)
が、私はバッチで読み込むしようとすると、tensorflowのドキュメントに示したように、私取得
TypeError: Cannot convert a list containing a tensor of dtype <dtype:
float32'> to <dtype: 'int32'> (Tensor is: <tf.Tensor 'DecodeCSV_6:1'
shape=() dtype=float32>)
バッチ処理コード
import tensorflow as tf
batch_size = 1
def read_my_file_format(filename_queue):
reader = tf.TextLineReader(skip_header_lines=1)
_, csv_row = reader.read(filename_queue)
record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]]
sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults)
feature = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age]
label = [chd]
return feature, label
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(filenames,
num_epochs=num_epochs,
shuffle=True)
feature, label = read_my_file_format(filename_queue)
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
feature_batch, label_batch = tf.train.shuffle_batch([feature, label],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
return feature_batch, label_batch
features, labels = input_pipeline(['heart.csv'], batch_size)
with tf.Session() as sess:
tf.global_variables_initializer().run()
# start populating filename queue
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop():
feature_batch, label_batch = sess.run([features, labels])
print(feature_batch)
except tf.errors.OutOfRangeError:
print('Done training, epoch reached')
finally:
coord.request_stop()
coord.join(threads)
テンソルフローを使用してCSVファイルを読み込むのは面倒ですが、ライブラリで重要なのは分散システムだと確信しています。私はそれが混乱していることを発見し、読み込みフィードパイプラインがcsvファイルのためにどのように働いていたかを把握するのに60分以上かかりました。ドキュメンテーションはより良くなり、より多くのビジュアルが必要になるかもしれません。
は[MCVE]あなたが与えたコードですか? – boardrider
@boardriderはいコードは完成しました – bicepjai
しかし、それは最小ですか? – boardrider