2016-08-24 13 views
4

私はTensorFlowメタグラフとして再トレーニングされたInceptionV3 NNを正常にエクスポートしました。私はこのprotobufをPythonに戻して読みましたが、TensorFlowの外側にあるnnを再作成するために、メタグラフのprotobuf内に格納されていると仮定して、各レイヤーのウェイトとバイアスの値をエクスポートする方法を見て苦労しています。訓練されたメタグラフからの重みとバイアス

私のワークフローは、次のような次のとおりです。ここから

Retrain final layer for new categories 
Export meta graph tf.train.export_meta_graph(filename='model.meta') 
Build python pb2.py using Protoc and meta_graph.proto 
Load Protobuf: 

import meta_graph_pb2 
saved = meta_graph_pb2.CollectionDef() 
with open('model.meta', 'rb') as f: 
    saved.ParseFromString(f.read()) 

私はノード名と、そのような同じように、グラフのほとんどの側面を見ることができますが、私は私の経験不足は、それが困難に正しい方法を追跡するために作っていると思います関連する各レイヤーのウェイト値とバイアス値にアクセスします。

答えて

6

MetaGraphDef protoには、実際には重みと偏りの値は含まれていません。代わりにtf.train.Saverで書かれた1つまたは複数のチェックポイントファイルに格納されている重みにGraphDefを関連付ける方法を提供します。 MetaGraphDef tutorialは詳細を持っていますが、次のようにおおよその構造は次のとおりです。あなたのトレーニングプログラムで

  1. tf.train.Saverを使用して、チェックポイントを書き出します。同じディレクトリにある.metaファイルにもMetaGraphDefが書き込まれます。

    saver = tf.train.Saver(...) 
    # ... 
    saver.save(sess, "model") 
    

    あなたのチェックポイントディレクトリに(いくつかの整数NNNN用)model.metaと呼ばれるファイルやmodel-NNNNを見つける必要があります。

  2. 別のプログラムでは、作成したMetaGraphDefをインポートし、チェックポイントから復元することができます。あなたは、各変数の値を取得したい場合は

    、あなたは(たとえば)tf.all_variables()コレクション内の変数を見つけて、その値を取得するためにsess.run()にそれを渡すことができます。たとえば、すべての変数の値を印刷するには、次の操作を実行できます。また、あなたはモデルから抽出しようとしている特定の重みとバイアスを見つけることtf.all_variables()をフィルタリングすることができ

    for var in tf.all_variables(): 
        print var.name, sess.run(var) 
    

+0

ありがとう、これは大きな助けです。私は訓練されたネットワークを再作成し、この新しい[iOSの例](https://developer.apple.com/library/prerelease/content/samplecode/MetalImageRecognition/Introduction/Intro.html)のように展開しようとしています。あなたの専門知識を考えれば、ネットワークパラメータ(浮動小数点配列としての各レイヤの重みとバイアスのバイナリ.datファイル)を確立するための正しい方法でしょうか?最終的な目標は、TensorFlow訓練されたネットワーク、金属の推論です。 –

+0

ええと、それはネットワークの複雑さによって決まります。チェックポイント(例えば、 'tf.train.NewCheckpointReader() 'を使って)を直ちに実行し、' MetaGraphDef'をバイパスしてチェックポイントファイルからNumPy配列を読み出すことができます。確かに、iOSプログラムで 'CheckpointReader'のC++実装を使用することもできます(もしリンクがいかに難しいのか分かりませんが)。 – mrry

+0

'tf.train.NewCheckpointReader()は素晴らしいAPIです。上記のdatファイルに格納されている配列を再作成したいと考えているinception_v3チェックポイントをロードしました。彼らがネットワークパラメータをどのように生成しているかに関して、私のための切断があります。ネットワークパラメータとして190ファイル(1つのバイアス、各ノードにつき1つのウェイトファイル)をロードしますが、TFは開始モデルのために1200を超える格納変数を持っています。ギャップを埋めるための指針は? –

関連する問題