2017-08-02 14 views
5

私は新しいデータセットで微調整された初期モデルを持っていて、それをKerasの ".h5"モデルとして保存しました。私の目標は ".pb"拡張子だけを受け入れるAndroid Tensorflowで自分のモデルを実行することです。質問は、この変換を行うためにKerasまたはtensorflowにライブラリがあるかどうかです。私は今までこの投稿を見てきました:https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.htmlしかしまだ分かりません。Keras .h5をテンソルフロー.pbにエクスポートするには?

答えて

9

Kerasには、TensorFlowグラフをプロトコルバッファファイルとしてエクスポートする手段は含まれていませんが、通常のTensorFlowユーティリティを使用して実行できます。 Hereは、TensorFlowに含まれているユーティリティスクリプトfreeze_graph.pyを使用して行う方法を説明するブログ記事です。これは「一般的な」方法です。

しかし、私は個人的にチェックポイントを作成して、モデルを得るために外部スクリプトを実行し、代わりに自分のPythonコードからそれを行うことを好むようになる迷惑を見つけたので、私はこのような機能を使用します。

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): 
    """ 
    Freezes the state of a session into a pruned computation graph. 

    Creates a new computation graph where variable nodes are replaced by 
    constants taking their current value in the session. The new graph will be 
    pruned so subgraphs that are not necessary to compute the requested 
    outputs are removed. 
    @param session The TensorFlow session to be frozen. 
    @param keep_var_names A list of variable names that should not be frozen, 
          or None to freeze all the variables in the graph. 
    @param output_names Names of the relevant graph outputs. 
    @param clear_devices Remove the device directives from the graph for better portability. 
    @return The frozen graph definition. 
    """ 
    from tensorflow.python.framework.graph_util import convert_variables_to_constants 
    graph = session.graph 
    with graph.as_default(): 
     freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) 
     output_names = output_names or [] 
     output_names += [v.op.name for v in tf.global_variables()] 
     input_graph_def = graph.as_graph_def() 
     if clear_devices: 
      for node in input_graph_def.node: 
       node.device = "" 
     frozen_graph = convert_variables_to_constants(session, input_graph_def, 
                 output_names, freeze_var_names) 
     return frozen_graph 

freeze_graph.pyの実装に影響を受けています。パラメータもスクリプトに似ています。 sessionはTensorFlowセッションオブジェクトです。 keep_var_namesは、一部の変数をフリーズしない(ステートフルモデルの場合など)ようにしたい場合にのみ必要です。 output_namesは、必要な出力を生成する操作の名前のリストです。 clear_devicesは、グラフをより移植性の高いものにするために、デバイス指示文を削除するだけです。だから、一つの出力を持つ典型的なKeras modelのために、あなたのようなものだろう。そして、あなたはtf.train.write_graphといつものようにファイルにグラフを書くことができます

from keras import backend as K 

# Create, compile and train model... 

frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name]) 

を:

tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False) 
1

freeze_session方法が正常に動作します。しかし、チェックポイントファイルに保存するのと比較して、TensorFlowに付属しているfreeze_graphツールを使用する方がメンテナンスが簡単なので、もっと簡単です。あなたがする必要があるのは以下の2つのステップです:

まず、あなたのKerasコードmodel.fit(...)後に追加して、あなたのモデルを訓練:

from keras import backend as K 
import tensorflow as tf 
print(model.output.op.name) 
saver = tf.train.Saver() 
saver.save(K.get_session(), '/tmp/keras_model.ckpt') 

その後、あなたのTensorFlowのルートディレクトリにcdし、実行します。

python tensorflow/python/tools/freeze_graph.py \ 
--input_meta_graph=/tmp/keras_model.ckpt.meta \ 
--input_checkpoint=/tmp/keras_model.ckpt \ 
--output_graph=/tmp/keras_frozen.pb \ 
--output_node_names="<output_node_name_printed_in_step_1>" \ 
--input_binary=true 
関連する問題