2016-04-04 8 views
2

こんにちはテンソルフローからVGGモデルを微調整したいと思います。私には2つの質問があります。テンソルフローモデルからウェイトを取得

ネットワークからウェイトを取得するにはどうすればよいですか? trainable_variablesは空リストを返します。

私はここから既存のモデルを使用しました:https://github.com/ry/tensorflow-vgg16。 私は重量を得ることについてのポストを見つけるが、これはimport_graph_defのために私にとってはうまくいかない。 Get the value of some weights in a model trained by TensorFlow

import tensorflow as tf 
import PIL.Image 
import numpy as np 

with open("../vgg16.tfmodel", mode='rb') as f: 
    fileContent = f.read() 

graph_def = tf.GraphDef() 
graph_def.ParseFromString(fileContent) 

images = tf.placeholder("float", [None, 224, 224, 3]) 

tf.import_graph_def(graph_def, input_map={ "images": images }) 
print("graph loaded from disk") 

graph = tf.get_default_graph() 

cat = np.asarray(PIL.Image.open('../cat224.jpg')) 
print(cat.shape) 
init = tf.initialize_all_variables() 

with tf.Session(graph=graph) as sess: 
    print(tf.trainable_variables()) 
    sess.run(init) 
+0

複数の質問を一度に避けるようにしてください。いずれかの回答が見つからず、両方とも貴重な質問であると思われる場合は、別の場所で両方の質問をしてください。 – Giewev

答えて

4

このpretrained VGG-16 modeltf.constant() OPSなどのモデルパラメータのすべてをコードします。たとえば、tf.constant()hereの呼び出しを参照してください。その結果、モデルパラメータはtf.trainable_variables()に表示されず、モデルは実質的な手術なしには変更できません。定数ノードをtf.Variableで始まるオブジェクトに置き換える必要があります訓練を続けるために同じ値。

一般に、再学習用のグラフをインポートするときは、tf.train.import_meta_graph()関数を使用する必要があります。この関数は、追加のメタデータ(変数のコレクションを含む)を読み込むためです。 tf.import_graph_def()関数は、より低いレベルであり、これらのコレクションにデータを設定しません。