2017-10-31 39 views
1

新しいTF Estimator APIでの転送学習/最後のレイヤ再学習の使い方を理解できませんでした。TensorFlowエスティメーターでの学習/再学習

Estimatorには、documentationで定義されているように、ネットワークのアーキテクチャとトレーニングと評価動作を含むmodel_fnが必要です。 CNNアーキテクチャを使用するmodel_fnの例はhereです。

たとえば、開始アーキテクチャの最後のレイヤーを再トレーニングしたい場合は、model_fnにモデル全体を指定する必要があるかどうかわからず、事前にトレーニングされたウェイトをロードするか'従来の'アプローチ(例:here)で行われているように、保存されたグラフを使用する方法です。

これはissueとして公開されていますが、まだ公開されており、回答は私には不明です。

答えて

2

モデル定義中にメタデータをロードし、SessionRunHookを使用してckptファイルからウェイトをロードすることは可能です。

def model(features, labels, mode, params): 
    # Create the graph here 

    return tf.estimator.EstimatorSpec(mode, 
      predictions, 
      loss, 
      train_op, 
      training_hooks=[RestoreHook()]) 

SessionRunHookをすることができます:

class RestoreHook(tf.train.SessionRunHook): 

    def after_create_session(self, session, coord=None): 
     if session.run(tf.train.get_or_create_global_step()) == 0: 
      # load weights here 

この方法では、重みが最初のステップでロードされ、モデルのチェックポイントでのトレーニング中に保存されます。