テンソルフローr1.1を使用して、事前にトレーニングされたvgg-16ネットワークをロードしようとしています。ネットワークは、3つのファイルで提供されていますテンソルフローで事前にトレーニングされたvgg-16をロードする
- saved_model.pb
- 変数/ variables.index
- 変数/ variables.data-00000-の-00001
変数を初期化した後sess
as tf.Session()
次のスクリプトを使用してネットワークを読み込み、特定のレイヤーを抽出します。
vgg_path='./'
model_filename = os.path.join(vgg_path, "saved_model.pb")
export_dir = os.path.join(vgg_path, "variables/")
with gfile.FastGFile(model_filename, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
image_input, l7, l4, l3 = tf.import_graph_def(sm.meta_graphs[0].graph_def,
name='',return_elements=["image_input:0", "layer7_out:0",
"layer4_out:0", "layer3_out:0"])
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, image_input)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l7)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l4)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l3)
saver = tf.train.Saver(tf.global_variables())
print("load data")
saver.restore(sess, export_dir)
変数saver
を初期化するとき、スクリプトは次のエラーで終了します。
TypeError: Variable to save is not a Variable: Tensor("image_input:0", shape=(?, ?, ?, 3), dtype=float32)
は、どのように私は私のスクリプトを修正し、事前に訓練されたVGGネットワークを復元することができますか?
あなたはsaved_model.loader https://www.tensorflow.org/api_docs/python/tf/saved_model/loaderを試みることができます – thertweck
ありがとう、それは私のために働いた!あなたがそれを答えさせたら私はそれを受け入れて、私のコード解を書き留めます。 –