2017-10-11 7 views
0

保存されたモデルoutput_graph.pbからすべての重み/バイアスを抽出しようとしています。tf.GraphKeys.TRAINABLE_VARIABLES on output_graph.pb空リストの結果

私がモデル読み:

def create_graph(modelFullPath): 
    """Creates a graph from saved GraphDef file and returns a saver.""" 
    # Creates graph from saved graph_def.pb. 
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f: 
     graph_def = tf.GraphDef() 
     graph_def.ParseFromString(f.read()) 
     tf.import_graph_def(graph_def, name='') 

GRAPH_DIR = r'C:\tmp\output_graph.pb' 
create_graph(GRAPH_DIR) 

をそして、私は、各レイヤ内のすべての重み/バイアスを 抽出することができるだろう期待してこれを実行しようとしました。

with tf.Session() as sess: 
    all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 
    print (len(all_vars)) 

しかし、私はlenとして0の値を得ています。

最後の目標は、重みと偏りを抽出し、テキストファイル/ np.arraysに保存することです。

答えて

1

tf.GraphKeys.TRAINABLE_VARIABLESコレクションを再構築するには、tf.import_graph_def()関数に十分な情報がありません(そのためには、MetaGraphDefが必要です)。ただし、output.pbに「フリーズ」が含まれている場合は、すべてのウェイトがグラフのノードtf.constant()に格納されます。 constant_valuesはおそらくちょうど重み以上の値が含まれていますので、あなたがop.nameまたはいくつかの他の基準によってさらにフィルタリングする必要があるかもしれないこと

create_graph(GRAPH_DIR) 

constant_values = {} 

with tf.Session() as sess: 
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"] 
    for constant_op in constant_ops: 
    constant_values[constant_op.name] = sess.run(constant_op.outputs[0]) 

注:それらを抽出するには、次のような何かを行うことができます。