0
私は訓練したTensorflowでモデルを構築しました。今私は出力で作業したいので、チェックポイント、メタ、および他のすべてのファイルをテンソルに読み込みたいと思います。訓練モデルをテンソルフローに戻す
私は、モデルを訓練するために、次のコードを使用しています
# Logging
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/train')
test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test')
validate_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/validate')
writer = tf.summary.FileWriter(FLAGS.summary_dir, sess.graph)
saver = tf.train.Saver() # for storing the best network
# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)
# Best validation accuracy seen so far
bestValidation = -0.1
# Training loop
coord = tf.train.Coordinator() # coordinator for threads
threads = tf.train.start_queue_runners(coord = coord, sess=sess) # start queue thread
# Training loop
for i in range(FLAGS.maxIter):
xTrain, yTrain = sess.run(data_batch)
sess.run(train_step, feed_dict={x_data: xTrain, y_target: np.transpose([yTrain])})
summary = sess.run(merged, feed_dict={x_data: xTrain, y_target: np.transpose([yTrain])})
train_writer.add_summary(summary, i)
if ((i + 1) % 10 == 0):
print("Iteration:", i + 1, "/", FLAGS.maxIter)
summary = sess.run(merged, feed_dict={x_data: dataTest.data, y_target: np.transpose([dataTest.target])})
test_writer.add_summary(summary, i)
currentValidation, summary = sess.run([accuracy, merged], feed_dict={x_data: dataTest.data,
y_target: np.transpose(
[dataTest.target])})
validate_writer.add_summary(summary, i)
if (currentValidation > bestValidation and currentValidation <= 0.9):
bestValidation = currentValidation
saver.save(sess=sess, save_path=FLAGS.summary_dir + '/bestNetwork')
print("\tbetter network stored,", currentValidation, ">", bestValidation)
coord.request_stop() # ask threads to stop
coord.join(threads) # wait for threads to stop
は、今私は戻ってTensorflowにモデルをロードします。
- 訓練とテストのデータセット用に既に作成した出力を使用してください。
- 新しいデータをモデルにロードしてから、同じウェイトを使用して新しい出力を作成できるようにします。
私は戻ってtensorflowにモデルをロードするには、次のコードを使用して試してみたが、それは動作しません:
with tf.Session() as sess:
saver = tf.train.import_meta_graph(FLAGS.summary_dir + '/bestNetwork.meta')
saver.restore(sess,tf.train.latest_checkpoint(FLAGS.summary_dir + '/checkpoint'))
コードを実行するときに、私は次のエラーを取得する:
TypeError:予想されるバイト、NoneTypeが見つかりました
前のセクションのメタグラフにtf.train.import_meta_graph()関数をロードして、チェックポイントの部分で荷重をロードします。では、なぜこれが機能していないのですか?