ここで私は、ドキュメント、他の似たようなソリューション、試行錯誤を読んでから思いついたものです。これは、ランダムなデータの簡単な自動エンコーダーです。走ってから再び走った場合、それは中断していたところから続きます(つまり、最初の走行でのコスト関数は〜0.5から0.3秒の実行開始〜0.3になります)。私が何かを見逃していない限り、すべての保存、コンストラクタ、モデル構築、add_to_collectionが必要であり、正確な順序であるが、より単純な方法があるかもしれない。 グラフをロードするのは実際には上に書かれているので実際には必要ありませんが、私の実際のアプリケーションでは必要なものです。
from __future__ import print_function
import tensorflow as tf
import os
import math
import numpy as np
output_dir = "/root/Data/temp"
model_checkpoint_file_base = os.path.join(output_dir, "model.ckpt")
input_length = 10
encoded_length = 3
learning_rate = 0.001
n_epochs = 10
n_batches = 10
if not os.path.exists(model_checkpoint_file_base + ".meta"):
print("Making new")
brand_new = True
x_in = tf.placeholder(tf.float32, [None, input_length], name="x_in")
W_enc = tf.Variable(tf.random_uniform([input_length, encoded_length],
-1.0/math.sqrt(input_length),
1.0/math.sqrt(input_length)), name="W_enc")
b_enc = tf.Variable(tf.zeros(encoded_length), name="b_enc")
encoded = tf.nn.tanh(tf.matmul(x_in, W_enc) + b_enc, name="encoded")
W_dec = tf.transpose(W_enc, name="W_dec")
b_dec = tf.Variable(tf.zeros(input_length), name="b_dec")
decoded = tf.nn.tanh(tf.matmul(encoded, W_dec) + b_dec, name="decoded")
cost = tf.sqrt(tf.reduce_mean(tf.square(decoded - x_in)), name="cost")
saver = tf.train.Saver()
else:
print("Reloading existing")
brand_new = False
saver = tf.train.import_meta_graph(model_checkpoint_file_base + ".meta")
g = tf.get_default_graph()
x_in = g.get_tensor_by_name("x_in:0")
cost = g.get_tensor_by_name("cost:0")
sess = tf.Session()
if brand_new:
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()
sess.run(init)
tf.add_to_collection("optimizer", optimizer)
else:
saver.restore(sess, model_checkpoint_file_base)
optimizer = tf.get_collection("optimizer")[0]
for epoch_i in range(n_epochs):
for batch in range(n_batches):
batch = np.random.rand(50, input_length)
_, curr_cost = sess.run([optimizer, cost], feed_dict={x_in: batch})
print("batch_cost:", curr_cost)
save_path = tf.train.Saver().save(sess, model_checkpoint_file_base)
出典
2017-04-06 00:12:11
Ken
私はAdamOptimizerで同じ問題がありました。私は、私の活動をコレクションに入れて仕事をすることができました。この例は私を助けました:http://www.seaandsailor.com/tensorflow-checkpointing.html –