2017-12-08 4 views
2

テンソルフローをキーボード割り込みで解除し、その時点でモデルを保存する方法はありますか?私は現在、一晩中セッションを実行したままにしていますが、停止する必要があるため、その日のPC使用のためにメモリを解放できます。訓練が進行するにつれて、各エポックが遅くなるので、プログラムの次のスケジュール保存に数時間待たなければならないこともあります。私はいつでも実行に侵入し、その時点から退くことができるという機能性を望みます。それが可能なのかどうかわからない。ポインタを感謝します。キーボード割り込みテンソルフローを実行してその時点で保存

答えて

2

1つのオプションは、tf.Sessionオブジェクトをサブクラス化し、キーボード割り込みが通過するときに現在の状態を保存する__exit__関数を作成することです。これは、新しいオブジェクトがwithブロックの一部として呼び出された場合にのみ機能します。 TensorFlowのmnistウォークスルーから

import tensorflow as tf 

class SessionWithExitSave(tf.Session): 
    def __init__(self, *args, saver=None, exit_save_path=None, **kwargs): 
     self.saver = saver 
     self.exit_save_path = exit_save_path 
     super().__init__(*args, **kwargs) 

    def __exit__(self, exc_type, exc_value, exc_tb): 
     if exc_type is KeyboardInterrupt: 
      if self.saver: 
       self.saver.save(self, self.exit_save_path) 
       print('Output saved to: "{}./*"'.format(self.exit_save_path)) 
     super().__exit__(exc_type, exc_value, exc_tb) 

使用例:ここでは

はサブクラスです。

import tensorflow as tf 
import datetime as dt 
from tensorflow.examples.tutorials.mnist import input_data 

mnist = input_data.read_data_sets('U:/mnist/', one_hot=True) 
x = tf.placeholder(tf.float32, [None, 784]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
y = tf.matmul(x, W) + b 
# Define loss and optimizer 
y_ = tf.placeholder(tf.float32, [None, 10]) 
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) 
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy) 

saver = tf.train.Saver() 

with SessionWithExitSave(
     saver=saver, 
     exit_save_path='./tf-saves/_lastest.ckpt') as sess: 
    sess.run(tf.global_variables_initializer()) 
    total_epochs = 50 
    for epoch in range(1, total_epochs+1): 
     for _ in range(1000): 
      batch_xs, batch_ys = mnist.train.next_batch(100) 
      sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 
     # Test trained model 
     correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

     print(f'Epoch {epoch} of {total_epochs} :: accuracy = ', end='') 
     print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) 
     save_time = dt.datetime.now().strftime('%Y%m%d-%H.%M.%S') 
     saver.save(sess, f'./tf-saves/mnist-{save_time}.ckpt') 

キーボードから割り込み信号を送信する前に、これを10エポックに実行させます。ここで出力されます:

Epoch 1 of 50 :: accuracy = 0.9169 
Epoch 2 of 50 :: accuracy = 0.919 
Epoch 3 of 50 :: accuracy = 0.9205 
Epoch 4 of 50 :: accuracy = 0.9221 
Epoch 5 of 50 :: accuracy = 0.92 
Epoch 6 of 50 :: accuracy = 0.9229 
Epoch 7 of 50 :: accuracy = 0.9234 
Epoch 8 of 50 :: accuracy = 0.9234 
Epoch 9 of 50 :: accuracy = 0.9252 
Epoch 10 of 50 :: accuracy = 0.9248 
Output saved to: "./tf-saves/_lastest.ckpt./*" 
--------------------------------------------------------------------------- 
KeyboardInterrupt       Traceback (most recent call last) 
... 
--> 768 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 
    769  return item[1]._is_present_in_parent 
    770 else: 
KeyboardInterrupt: 

そして実際、私は、保存されたすべてのファイルがシステムに送信されたキーボード割り込みから救う含まれています。

import os 

os.listdir('./tf-saves/') 
# returns: 
['checkpoint', 
'mnist-20171207-23.05.18.ckpt.data-00000-of-00001', 
'mnist-20171207-23.05.18.ckpt.index', 
'mnist-20171207-23.05.18.ckpt.meta', 
'mnist-20171207-23.05.22.ckpt.data-00000-of-00001', 
'mnist-20171207-23.05.22.ckpt.index', 
'mnist-20171207-23.05.22.ckpt.meta', 
'mnist-20171207-23.05.26.ckpt.data-00000-of-00001', 
'mnist-20171207-23.05.26.ckpt.index', 
'mnist-20171207-23.05.26.ckpt.meta', 
'mnist-20171207-23.05.31.ckpt.data-00000-of-00001', 
'mnist-20171207-23.05.31.ckpt.index', 
'_lastest.ckpt.data-00000-of-00001', 
'_lastest.ckpt.index', 
'_lastest.ckpt.meta'] 
関連する問題