2017-08-07 5 views
2

概要:新しいtf.contrib.data.Datasetを使用すると、グラフのprotobuffファイルのサイズが2倍になり、Tensorboardでグラフを視覚化できません。Tensorflow Dataset APIはグラフprotobuffファイルサイズを2倍にします

詳細は:

私はtf.contrib.learn.Experimentフレームワークと一緒に新しいTensorFlow tf.contrib.data.Dataset機能を実行しようとしています。私の入力データは、テンソルの特徴とラベルを返すinput functionsと定義されています。

私は、次のコードブロック(フルコードhere)のようにtf.train.slice_input_producer機能で自分の入力機能を作成する場合は、その後、私の結果graph.pbtxtファイルが620Mで、.metaファイルのサイズは165Mの周りにあります。

def train_inputs(): 
    with tf.name_scope('Training_data'): 
     x = tf.constant(mnist.train.images.reshape([-1, 28, 28, 1])) 
     y = tf.constant(mnist.train.labels) 
     sliced_input = tf.train.slice_input_producer(
      tensor_list=[x, y], shuffle=True) 
     return tf.train.shuffle_batch(
      sliced_input, batch_size=batch_size, 
      capacity=10000, min_after_dequeue=batch_size*10) 

私はその後、私のは、1.3Gと.metaファイルのサイズにgraph.pbtxtファイルダブルス結果、以下のコードブロック(フルコードhere)のように新しいtf.contrib.data.Dataset.from_tensor_slicesと私の入力機能を作成する場合は、330Mの大きさに倍増します。

def train_inputs(): 
    with tf.name_scope('Training_data'): 
     images = mnist.train.images.reshape([-1, 28, 28, 1]) 
     labels = mnist.train.labels 
     dataset = tf.contrib.data.Dataset.from_tensor_slices(
      (images, labels)) 
     dataset = dataset.repeat(None) # Infinite 
     dataset = dataset.shuffle(buffer_size=10000) 
     dataset = dataset.batch(batch_size) 
     iterator = dataset.make_one_shot_iterator() 
     next_example, next_label = iterator.get_next() 
     return next_example, next_label 

graph.pbtxtファイルがとても大きいためTensorBoardは、このファイルを解析するために年齢をとり、そして私は、視覚的に私のモデルグラフをデバッグすることができませんでしだ。 Dataset documentationこのサイズの増加は、"配列の内容は複数回コピーされます"solutionはプレースホルダを使用することになります。 tf.contrib.learn.Experimentフレームワークを使用したとき、私のコントロール外であることを、しかし、これはそう

sess.run(iterator.initializer, feed_dict={features_placeholder: features, labels_placeholder: labels}) 

:ただし、この場合には、私はイテレータを初期化するためにアクティブなセッションでプレースホルダにnumpyのアレイに供給するために必要があるだろう。

実験フレームワークでイテレータのイニシャライザを初期化するにはどうすればよいですか?または、グラフサイズを大きくしないでデータセットAPIを使用する回避策を見つけますか?

答えて

2

tf.train.SessionRunHookを使用して問題を解決しました。初期化関数は、データセットのIteratorの作成時に設定されている

class IteratorInitializerHook(tf.train.SessionRunHook): 
    def __init__(self): 
     super(IteratorInitializerHook, self).__init__() 
     self.iterator_initiliser_func = None 

    def after_create_session(self, session, coord): 
     self.iterator_initiliser_func(session) 

:私は、セッションが作成された後、イテレータを初期化SessionRunHookオブジェクトを作成

iterator_initiliser_hook.iterator_initiliser_func = \ 
    lambda sess: sess.run(
     iterator.initializer, 
     feed_dict={images_placeholder: images, 
        labels_placeholder: labels}) 

を私はtrain_monitorsにフックオブジェクトを渡すと、 eval_hooksのパラメータはtf.contrib.learn.Experimentです。

のファイルはわずか244Kであるのに対して、結果のgraph.pbtxtファイルは現在わずか500Kです。

Full example here.

+0

ニース。私の問題も解決します。回避策のように思えますか? 私の投稿:https://stackoverflow.com/questions/46207211/tensorflow-dataset-api-causes-graph-size-to-explode –

関連する問題