2017-01-12 9 views
6

私は基本的にテンソルフローでGenerative Adversarial Network(GAN)をトレーニングしています。ここでは基本的に独自のオプティマイザを持つ2つの異なるネットワークがあります。Tensorflowで変数のサブセットを復元する

self.G, self.layer = self.generator(self.inputCT,batch_size_tf) 
self.D, self.D_logits = self.discriminator(self.GT_1hot) 

... 

self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step) 

self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \ 
         .minimize(self.d_loss, var_list=self.d_vars) 

問題は、最初にネットワーク(g)の1つを訓練してから、gとdを一緒に訓練したいということです。私は、ロード関数を呼び出すときしかし、:

self.sess.run(tf.initialize_all_variables()) 
self.sess.graph.finalize() 

self.load(self.checkpoint_dir) 

def load(self, checkpoint_dir): 
    print(" [*] Reading checkpoints...") 

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
     ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 
     self.saver.restore(self.sess, ckpt.model_checkpoint_path) 
     return True 
    else: 
     return False 

を私は(多くのトレースバック付き)このようなエラーがあります:

Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000 

私はgネットワークを復元し、その関数でトレーニングを続けることができ、しかし、私が最初からdに星をつけたいと思ったときに、そして記憶されたモデルからgにその誤差がある。

答えて

17

変数のサブセットを復元するには、新しいtf.train.Saverを作成し、オプションのvar_list引数に復元する変数の特定のリストを渡す必要があります。

デフォルトでは、tf.train.Saverはあなたが(名前で)saver.save()および(ii)のルックアップを呼び出すときにsaver.restore()を呼び出すとき(i)が与えられたチェックポイントであなたのグラフにすべての変数をすべての変数を保存するOPSを作成します。これは最も一般的なシナリオのために動作しますが、あなたは、変数の特定のサブセットで動作するようにより多くの情報を提供する必要があります。

  1. あなただけの変数のサブセットを復元したい場合は、これらの変数のリストを取得することができますtf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX)を呼び出して、 "g"ネットワークを共通のwith tf.name_scope(G_NETWORK_PREFIX):またはtf.variable_scope(G_NETWORK_PREFIX):ブロックに入れると仮定します。このリストをtf.train.Saverコンストラクタに渡すことができます。

  2. あなたがチェックポイントに変数および/または、それらの変数のサブセットを復元する場合異なる名前を持っている、あなたはvar_list引数として辞書を渡すことができますが。既定では、チェックポイント内の各変数は、tf.Variable.nameプロパティの値であるキーに関連付けられています。ターゲットグラフで名前が異なる場合(たとえば、スコーププレフィックスを追加したため)、チェックポイントファイルの文字列キーをtf.Variableオブジェクト(ターゲットグラフ内)にマップする辞書を指定できます。

0

あなたが復元したい変数に設定var_list引数でtf.train.Saver()の別のインスタンスを作成することができます。 変数を保存するために別のインスタンスを作成する

0

@mrryからインスパイアされた、この問題の解決方法を提案します。 これを明確にするために、モデルが事前に訓練されたモデルで構築されている場合、チェックポイントから変数のサブセットを復元するという問題を定式化します。 まず、私たちは図書館inspect_checkpointからprint_tensors_in_checkpoint_file機能を使用するか、単にすることで、この機能を抽出する必要があります。その後、我々はTFを使用

from tensorflow.python import pywrap_tensorflow 
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): 
    varlist=[] 
    reader = pywrap_tensorflow.NewCheckpointReader(file_name) 
    if all_tensors: 
     var_to_shape_map = reader.get_variable_to_shape_map() 
     for key in sorted(var_to_shape_map): 
     varlist.append(key) 
    return varlist 
varlist=print_tensors_in_checkpoint_file(file_name=the path of the ckpt file,all_tensors=True,tensor_name=None) 

。ちょうど@mrryのsaiedようget_collection():

variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 

最後に、我々はによってセーバーを初期化することができます。

saver = tf.train.Saver(variable[:len(varlist)]) 

完全なバージョンは、私のgithubので見つけることができます:私の状況でhttps://github.com/pobingwanghai/tensorflow_trick/blob/master/restore_from_checkpoint.py

新しい変数はモデルの最後に追加されますので、単純に[:length()]を使用して必要な変数を識別することができます。より複雑な状況では、手作業で作業するか、必要なvを決定する文字列照合関数ariables。

関連する問題