2017-07-27 13 views
2

私はすでに5時間以上スクリプトを実行しています。 TFレコードに変換したい258 CSVファイルがあります。私は、次のスクリプトを書いた、と私は言ってきたように、私はすでに5時間以上のためにそれを実行してきた:CSVファイルをTFレコードに変換する

import argparse 
import os 
import sys 
import standardize_data 
import tensorflow as tf 

FLAGS = None 
PATH = '/home/darth/GitHub Projects/gru_svm/dataset/train' 

def _int64_feature(value): 
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 

def _float_feature(value): 
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 

def convert_to(dataset, name): 
    """Converts a dataset to tfrecords""" 

    filename_queue = tf.train.string_input_producer(dataset) 

    # TF reader 
    reader = tf.TextLineReader() 

    # default values, in case of empty columns 
    record_defaults = [[0.0] for x in range(24)] 

    key, value = reader.read(filename_queue) 

    duration, service, src_bytes, dest_bytes, count, same_srv_rate, \ 
    serror_rate, srv_serror_rate, dst_host_count, dst_host_srv_count, \ 
    dst_host_same_src_port_rate, dst_host_serror_rate, dst_host_srv_serror_rate, \ 
    flag, ids_detection, malware_detection, ashula_detection, label, src_ip_add, \ 
    src_port_num, dst_ip_add, dst_port_num, start_time, protocol = \ 
    tf.decode_csv(value, record_defaults=record_defaults) 

    features = tf.stack([duration, service, src_bytes, dest_bytes, count, same_srv_rate, 
         serror_rate, srv_serror_rate, dst_host_count, dst_host_srv_count, 
         dst_host_same_src_port_rate, dst_host_serror_rate, dst_host_srv_serror_rate, 
         flag, ids_detection, malware_detection, ashula_detection, src_ip_add, 
         src_port_num, dst_ip_add, dst_port_num, start_time, protocol]) 

    filename = os.path.join(FLAGS.directory, name + '.tfrecords') 
    print('Writing {}'.format(filename)) 
    writer = tf.python_io.TFRecordWriter(filename) 
    with tf.Session() as sess: 
     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 
     try: 
      while not coord.should_stop(): 
       example, l = sess.run([features, label]) 
       print('Writing {dataset} : {example}, {label}'.format(dataset=sess.run(key), 
         example=example, label=l)) 
       example_to_write = tf.train.Example(features=tf.train.Features(feature={ 
        'duration' : _float_feature(example[0]), 
        'service' : _int64_feature(int(example[1])), 
        'src_bytes' : _float_feature(example[2]), 
        'dest_bytes' : _float_feature(example[3]), 
        'count' : _float_feature(example[4]), 
        'same_srv_rate' : _float_feature(example[5]), 
        'serror_rate' : _float_feature(example[6]), 
        'srv_serror_rate' : _float_feature(example[7]), 
        'dst_host_count' : _float_feature(example[8]), 
        'dst_host_srv_count' : _float_feature(example[9]), 
        'dst_host_same_src_port_rate' : _float_feature(example[10]), 
        'dst_host_serror_rate' : _float_feature(example[11]), 
        'dst_host_srv_serror_rate' : _float_feature(example[12]), 
        'flag' : _int64_feature(int(example[13])), 
        'ids_detection' : _int64_feature(int(example[14])), 
        'malware_detection' : _int64_feature(int(example[15])), 
        'ashula_detection' : _int64_feature(int(example[16])), 
        'label' : _int64_feature(int(l)), 
        'src_ip_add' : _float_feature(example[17]), 
        'src_port_num' : _float_feature(example[18]), 
        'dst_ip_add' : _float_feature(example[19]), 
        'dst_port_num' : _float_feature(example[20]), 
        'start_time' : _float_feature(example[21]), 
        'protocol' : _int64_feature(int(example[22])), 
        })) 
       writer.write(example_to_write.SerializeToString()) 
      writer.close() 
     except tf.errors.OutOfRangeError: 
      print('Done converting -- EOF reached.') 
     finally: 
      coord.request_stop() 

     coord.join(threads) 

def main(unused_argv): 
    files = standardize_data.list_files(path=PATH) 

    convert_to(dataset=files, name='train') 

それはすでに、私はおそらくそれが無限ループに陥っていますことを考えるようになりましたか?私がしたいのは、各CSVファイル(258 CSVファイル)のすべての行を読み込み、それらの行をTFレコード(フィーチャとラベル、もちろん)に書き込むことです。そして、利用可能な行がなくなったとき、またはCSVファイルがすでに使い果たされたときに、ループを停止します。

standardize_data.list_files(path)は、別のモジュールで書いた関数です。私はこのスクリプトのために再使用しました。それは、PATHにあるすべてのファイルのリストを返すことです。私のPATHのファイルにはCSVファイルしか含まれていないことに注意してください。

答えて

1

num_epochs=1string_input_producerに設定します。別の注意:tfrecordsにこれらのcsvを変換すると、あなたはtfrecordsで見ている利点を提供しないかもしれません、この種類のデータ(単一の機能/ラベルの数が多い)でオーバーヘッドは非常に高いです。この部分を試してみてください。

+0

つまり、言い換えれば、それらをTFレコードに変換することを提案していませんか? –

+0

この実験を行います:ファイルを1つだけ変換して、それぞれのサイズを確認します。あなたのデータは 'tfrecords'表現にとって効率的ではありません。すべての機能はラベルと一緒に保存されるので、それはcsvとして保存するよりも非常に大きなサイズになります。 –

+0

サンプルのCSVファイルは10.1 MB、tfrecord相当のファイルは9.6 MBです –

関連する問題