2017-09-09 45 views
2

私は1つのモデルを訓練し、そのモデルをフリーズして1つの.pbファイルを作成しました。 私の質問は、.pbファイルからウェイトを取得する方法か、ウェイトを得るためにさらに処理する必要があります。Tensorflowによる.pbファイルからの値の重みの取得

@mryry、私を案内してください。

+0

残念ながら私はmrryではありませんが、モデルをフリーズするとGraphDefが得られます。 [PythonでGraphDefを解析する](https://www.tensorflow.org/extend/tool_developers/#graphdef)があります。これは定数の値(凍結された重みを含む)を持ちます。 –

+0

ohk ..ありがとうございました。 –

答えて

4

まずグラフを.pbファイルから読み込みます。あなたは.pbにグラフを凍結するとき今、あなたの変数はConst型に変換されているファイルとtrainabe変数だった重みも.pbファイルにConstとして格納されます。

import tensorflow as tf 
from tensorflow.python.platform import gfile 

GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb' #path to your .pb file 
with tf.Session(config=config) as sess: 
    print("load graph") 
    with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
    sess.graph.as_default() 
    tf.import_graph_def(graph_def, name='') 
    graph_nodes=[n for n in graph_def.node] 

graph_nodesには、グラフ内のすべてのノードが含まれます。しかし、私たちはすべてのConstタイプノードに興味があります。

wts = [n for n in graph_nodes if n.op=='Const'] 

wtsの各要素はNodeDef型です。名前、演算子などいくつかの属性があります。値は次のように抽出できます。

from tensorflow.python.framework import tensor_util 

for n in wts: 
    print "Name of the node - %s" % n.name 
    print "Value - " 
    print tensor_util.MakeNdarray(n.attr['value'].tensor) 

これがあなたの懸念を解決することを望みます。

関連する問題