2017-05-24 10 views
2

私はインターネット上で見つけたものでは解決できない問題に直面しています。Tensorflow:入力キューを使用した同じグラフへのトレーニングとテスト

私は自分のニューラルネットワークを構築し、これをパイプラインに接続しました。 tf.train.batchとqueueRunners、COORDS、などとtfrecordから 読み取りデータ、..私は、私は次のように使う "モデル" という名前のPythonのクラスに私のNNを構築してい

モデル=モデル(ここでは...すべてのハイパー...)

...

model.predict()

または

model.step()

非常によく、すべてのトレーニング段階の作品。

しかし今、トレーニングのXエポック/ステップごとにテストフェーズを追加したいと思います。

私は本当にこれを行う方法がわかりません。 私はいくつかのアイデアを持っているが、私は最高のものを見つけることができません:loss_trainとloss_test、というように、私のグラフの各ノードのために:

  • は得るために、私のクラスにコードを複製しますか? =

model_train(=リユース真)

model_test =モデル(リユース= false)を

モデル:
  • は私のモデルの2のインスタンスを作成します(電車とテストの間で共有変数を使用して)
    • tf.make_templateを使用しますか?私は本当にこの功績の良い例を見つけられませんでした...
    • 他の解決策はありますか? TFRecordsデータセットの実験をするとき、私はどんな提案をいただければ幸いです

  • 答えて

    1

    私は同じ問題に出くわしました。いくつかの可能性があります。私はとにかく唯一のGPUを搭載したコンピュータ上でこれをやってみたかったので、次のように私はそれを実装:

    # Training Dataset 
    train_dataset = tf.contrib.data.TFRecordDataset(train_files) 
    train_dataset = train_dataset.map(parse_function) 
    train_dataset = train_dataset.shuffle(buffer_size=10000) 
    train_dataset = train_dataset.batch(200) 
    # Validation Dataset 
    validation_dataset = tf.contrib.data.TFRecordDataset(val_files) 
    validation_dataset = validation_dataset.map(parse_function) 
    validation_dataset = validation_dataset.batch(200) 
    
    # A feedable iterator is defined by a handle placeholder and its structure. We 
    # could use the `output_types` and `output_shapes` properties of either 
    # `training_dataset` or `validation_dataset` here, because they have 
    # identical structure. 
    handle = tf.placeholder(tf.string, shape=[]) 
    iterator = tf.contrib.data.Iterator.from_string_handle(handle, 
    train_dataset.output_types, train_dataset.output_shapes) 
    next_element = iterator.get_next() 
    
    # Generate the Iterators 
    training_iterator = train_dataset.make_initializable_iterator() 
    validation_iterator = validation_dataset.make_one_shot_iterator() 
    
    # The `Iterator.string_handle()` method returns a tensor that can be evaluated 
    # and used to feed the `handle` placeholder. 
    training_handle = sess.run(training_iterator.string_handle()) 
    validation_handle = sess.run(validation_iterator.string_handle()) 
    

    は、その後の要素にアクセスするために、あなただけのように行くことができます:

    img, lbl = sess.run(next_element, feed_dict={handle: training_handle}) 
    

    、ハンドルを交換しますあなたがATMをやりたいことに依存しています。

    ただし、これは並列化できません。このリンクの後、複数の入力パイプラインを作成するさまざまな方法についての洞察を得ることができますTensorflow | Reading Data

    関連する問題