2017-06-03 9 views
1

私はTransfer Learningアプローチを使用しようとしています。ここに私のコードは、トレーニングデータ上で学習されているコードのスナップショットは、次のとおりです。Tensorflowの保存と復元モデルに関する問題

max_accuracy = 0.0 
    saver = tf.train.Saver() 
    for epoch in range(epocs): 
     shuffledRange = np.random.permutation(n_train) 
     y_one_hot_train = encode_one_hot(len(classes), Y_input) 
     y_one_hot_validation = encode_one_hot(len(classes), Y_validation) 
     shuffledX = X_input[shuffledRange,:] 
     shuffledY = y_one_hot_train[shuffledRange] 
     for Xi, Yi in iterate_mini_batches(shuffledX, shuffledY, mini_batch_size): 
      sess.run(train_step, 
        feed_dict={bottleneck_tensor: Xi, 
           ground_truth_tensor: Yi}) 
      # Every so often, print out how well the graph is training. 
      is_last_step = (i + 1 == FLAGS.how_many_training_steps) 
      if (i % FLAGS.eval_step_interval) == 0 or is_last_step: 
       train_accuracy, cross_entropy_value = sess.run(
        [evaluation_step, cross_entropy], 
        feed_dict={bottleneck_tensor: Xi, 
          ground_truth_tensor: Yi}) 
       validation_accuracy = sess.run(
        evaluation_step, 
        feed_dict={bottleneck_tensor: X_validation, 
          ground_truth_tensor: y_one_hot_validation}) 
       print('%s: Step %d: Train accuracy = %.1f%%, Cross entropy = %f, Validation accuracy = %.1f%%' % 
        (datetime.now(), i, train_accuracy * 100, cross_entropy_value, validation_accuracy * 100)) 
       result_tensor = sess.graph.get_tensor_by_name(ensure_name_has_port(FLAGS.final_tensor_name)) 
       probs = sess.run(result_tensor,feed_dict={'pool_3/_reshape:0': Xi[0].reshape(1,2048)}) 
       if validation_accuracy > max_accuracy : 
        saver.save(sess, 'models/superheroes_model') 
        max_accuracy = validation_accuracy 
        print(probs) 
      i+=1 

ここではどこ私はモデルロードしています私のコードです:同じデータのために今そう

def load_model() : 
    sess=tf.Session()  
    #First let's load meta graph and restore weights 
    saver = tf.train.import_meta_graph('models/superheroes_model.meta') 
    saver.restore(sess,tf.train.latest_checkpoint('models/')) 
    sess.run(tf.global_variables_initializer()) 
    result_tensor = sess.graph.get_tensor_by_name(ensure_name_has_port(FLAGS.final_tensor_name)) 
    X_feature = features[0].reshape(1,2048)   
    probs = sess.run(result_tensor, 
         feed_dict={'pool_3/_reshape:0': X_feature}) 
    print probs 
    return sess 

をポイント私は訓練とテストの間に全く異なる結果を得ています。その近くにもありません。テスト中、4つのクラスがあるので、私の確率は25%近くになります。しかし、訓練中、最高クラスの確率は90%です。
モデルを保存または復元する際に問題はありますか?

+0

問題が解決のために1/4の確率が生成されます初期値に置き換えられます。私は多数のエポチェスを訓練していたので、いくつかのエポッシュの後に確率が減少した。 – neel

答えて

2

は注意してください - あなたは

sess.run(tf.global_variables_initializer()) 

を呼び出している

saver.restore(sess,tf.train.latest_checkpoint('models/')) 

を呼び出した後、私は以前にも同様のやった、と私はすべてのあなたの訓練を受けた重み/偏見の/ etcをリセットと思います。復元されたモデルで。

モデルを復元する前にイニシャライザを呼び出し、復元されたモデルから特定のものを初期化する必要がある場合は、個別に行います。あなたがそれを行う場合

+0

それは助けになりましたが、確率は0.25から0.3までしか増加しません。私はすべてのトレーニングポイントを確認しました。 – neel

+0

モデルを保存して復元するためにこれを続けました。https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model – neel

+0

あなたの前にキーモデルの値を書き留めてみませんか?それを保存して(重み付け/バイアスなど)、復元したものと比較してください。私はいくつかの変数のために前もってこの運動を行ってきました。そして、私にとっては、それはOKでしたが、誰が知っていますか? –

2

あなたの関数load_modelsess.run(tf.global_variables_initializer())を削除するには、すべての訓練を受けたパラメータは、各クラス

+0

確率は0.25から0.3に増加するだけです。私はすべてのトレーニングポイントを確認しました。 – neel

+0

私はモデルを保存して復元するためにこれを続けました。https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model – neel

関連する問題