2017-08-10 7 views
0

機械学習でテンソルフローを使用する場合、モデルをあらかじめ訓練して、チェックポイントファイルによるtrain_model floder(saver.save()を使用)次に、プレトレインモデルレイヤの一部を使用して新しいネットワークを初期化し、チェックポイントファイルを使用してnew-train_modelフォルダに新しいトレインモデルを保存します。1つのチェックポイントファイルから一部のレイヤーをロードし、別のチェックポイントファイルにすべてのグラフ変数を保存する方法を学習するTensorflowの転送

この機能を実装するにはどうすればよいですか。

+0

可能な重複が[Tensorflowで学習転送を自動化する方法はありますか?](https://stackoverflow.com/questions/45569330/is-there-a- way-to-automate-transfer-learning-テンソルフロー) – jdehesa

答えて

1

Tensorflow: 私たちは列車のグラフ内のすべてのVARSセーバーたい場合は、通常、我々は

__init__(self, 
     vat_list=None 
     ...... 
     ) 
self._var_list = var_list 
....... 
if self._var_list in None: 
    self._var_list = variables._all_saveable_objects() 
...... 

、init関数で

saver = tf.train.Saver() 

としてセーバー定義我々はから2層をロードする場合他のCNモデル 定義可能

これはあなたのコードに合わせて変更することができる例です

新しいモデルを新しいモデルフォルダに保存するときには、次回にすべてのグラフ変数を復元するときに、ので、あなたのセーバー

saver = tf.train.Saver() 

にinit怒鳴るコードを使用すると、あなたはsaver.restore() を実行したとき。あなたはテンソルノードを得るか、または変数はエラーを見つけることができません。

この問題を解決するには、このようにすることができます。

転送あなたのモデルは、の

variables_to_restore = [var for var in tf.global_variables() 
         if var.name.startswith('conv_1') 
         or var.name.startswith('conv_2')] 
saver = tf.train.Saver(variables_to_restore) 
saver_all = tf.train.Saver() 
........ 
saver.restore(....) 
saver = saver_all 
...... 
saver.save(<in new model folder>) 
関連する問題