2017-06-11 8 views
0

テンソルフローr1.1を使用して、事前にトレーニングされたvgg-16ネットワークをロードしようとしています。ネットワークは、3つのファイルで提供されていますテンソルフローで事前にトレーニングされたvgg-16をロードする

  • saved_model.pb
  • 変数/ variables.index
  • 変数/ variables.data-00000-の-00001

変数を初期化した後sess as tf.Session()

次のスクリプトを使用してネットワークを読み込み、特定のレイヤーを抽出します。

vgg_path='./' 
model_filename = os.path.join(vgg_path, "saved_model.pb") 
export_dir = os.path.join(vgg_path, "variables/") 

with gfile.FastGFile(model_filename, 'rb') as f: 
    data = compat.as_bytes(f.read()) 
    sm = saved_model_pb2.SavedModel() 
    sm.ParseFromString(data) 
    image_input, l7, l4, l3 = tf.import_graph_def(sm.meta_graphs[0].graph_def, 
      name='',return_elements=["image_input:0", "layer7_out:0", 
      "layer4_out:0", "layer3_out:0"]) 

tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, image_input) 
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l7) 
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l4) 
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l3) 

saver = tf.train.Saver(tf.global_variables()) 
print("load data") 
saver.restore(sess, export_dir) 

変数saverを初期化するとき、スクリプトは次のエラーで終了します。

TypeError: Variable to save is not a Variable: Tensor("image_input:0", shape=(?, ?, ?, 3), dtype=float32)

は、どのように私は私のスクリプトを修正し、事前に訓練されたVGGネットワ​​ークを復元することができますか?

+0

あなたはsaved_model.loader https://www.tensorflow.org/api_docs/python/tf/saved_model/loaderを試みることができます – thertweck

+0

ありがとう、それは私のために働いた!あなたがそれを答えさせたら私はそれを受け入れて、私のコード解を書き留めます。 –

答えて

1

あなたはSavedModelを持っているので、あなたはそれをロードするためtf.saved_model.loaderを使用することができます。

with tf.Session() as sess: 
    tf.saved_model.loader.load(sess, ["some_tag"], model_dir) 
+0

これは私の仕事を完全に解決します –

関連する問題