2016-02-19 11 views
15

トレーニングを受けたネットワークをC++にインポートするには、ネットワークをエクスポートしてエクスポートする必要があります。たくさんの情報を検索して情報をほとんど見つけていない場合は、freeze_graph()を使用する必要があることが明らかになりました。Tensorflow C++でグラフをエクスポートおよび実行するさまざまな方法

新しい0.7バージョンのTensorflowのおかげで、彼らはdocumentationを追加しました。

ドキュメンテーションに見た後、私は、いくつかの類似した方法があることが判明あなたはfreeze_graph()との違いが何であるかを伝えることができます:それは同様のパラメータを持っているよう tf.train.export_meta_graphが、それはCにモデルをインポートするためにも使用することができそうです++

write_graph()を使用する方法についても一つの質問(私は違いは、この方法でファイル出力を使用するためだけimport_graph_def()を使用することができるか、それが何か他のものだということですね?):graph_defsess.graph_defで与えられるドキュメンテーションで をが、 freeze_graph()の例ではsess.graph.as_graph_def()です。これら2つの違いは何ですか?

この質問はthis issue.

に関連しているありがとうございました!あなたが訓練した後あなたの変数または操作将来の使用のための

self.init = tf.initialize_variables(tf.all_variables(), name="nInit") 

をすべきpythonですべての

まず:あなたはこのような何かを行うことができます予測(および他のすべての操作)については

+0

「Tensorflow Serving」という別のライブラリがこの機能を提供しているようです:https://tensorflow.github.io/serving/serving_basic.html。私はまだプレーンなTensorflowでこれを行う方法を理解したいと思うでしょう! –

答えて

13

あなたの変数が割り当てられているときに、それらのすべてを調べて、グラフに定数として保存するときの計算。 (ほぼ同じにはそのフリーズツールで行うことができます、しかし、私は通常、自分でそれを行うPY、以下のcppに「名前= nWeights」をご確認ください)

def save(self, filename): 
    for variable in tf.trainable_variables(): 
     tensor = tf.constant(variable.eval()) 
     tf.assign(variable, tensor, name="nWeights") 

    tf.train.write_graph(self.sess.graph_def, 'graph/', 'my_graph.pb', as_text=False) 

今Cを行く++から私たちのグラフと負荷変数をロード保存された定数:

void load(std::string my_model) { 
     auto load_graph_status = 
       ReadBinaryProto(tensorflow::Env::Default(), my_model, &graph_def); 

     auto session_status = session->Create(graph_def); 

     std::vector<tensorflow::Tensor> out; 
     std::vector<string> vNames; 

     int node_count = graph_def.node_size(); 
     for (int i = 0; i < node_count; i++) { 
      auto n = graph_def.node(i); 

      if (n.name().find("nWeights") != std::string::npos) { 
       vNames.push_back(n.name()); 
      } 
     } 

     session->Run({}, vNames, {}, &out); 

これで、すべてのニューラルネットの重みやその他の変数が読み込まれました。

同様に、他の操作も実行できます(名前については覚えておいてください)。適切なサイズの入力テンソルと出力テンソルを作成し、データでテンソルを入力し、セッションを実行します。

auto operationStatus = session->Run(input, {"put_your_operation_here"}, {}, &out); 
+0

私はすでにC++で実行中のコードを作成していますが、私の質問では、異なる方法とその違いについて質問していますか? Btw、あなたが提案した方法はほとんどトリックですが、正しい方法は 'freeze_graph()'コードを使ってノードをフリーズすることです。 –

+0

正直言って私はまだフリーズグラフをテストする時間がありませんでした。私のプロジェクトで作業していたときには、これを行うための明確なAPIやツールはありませんでした( –

+0

@Alex Joz私はあなたの解決策を試しましたが、変数、たとえばいくつかのバッチ正規化の移動平均と分散、 tf.traceable_variables()の変数:tf.global_variables()の変数に 'wit'を置き換えました。問題は、C++の結果として得られるネットワークが常に異なるサンプルで同じ出力が出てくるのですが、どこから来たのか、修正する方法は? – gdelab

21

ここでは、TF 0.12で導入されたV2チェックポイントを利用しています。

すべての変数を定数またはfreeze the graphに変換する必要はありません。ただ、明確にするため、V2のチェックポイントは、私のディレクトリmodelsに次のようになります

checkpoint # some information on the name of the files in the checkpoint 
my-model.data-00000-of-00001 # the saved weights 
my-model.index # probably definition of data layout in the previous file 
my-model.meta # protobuf of the graph (nodes and topology info) 

Pythonの一部(保存)

with tf.Session() as sess: 
    tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model') 

あなたはtf.trainable_variables()Saverを作成する場合、次のことができます頭痛や記憶スペースを節約しましょう。しかし、もっと複雑なモデルの場合は、すべてのデータを保存する必要がありますが、この引数をSaverにすると、のグラフが作成された後にSaverが作成されていることを確認してください。すべての変数/レイヤに一意の名前を付けることも非常に賢明です。そうでなければdifferent problemsで実行できます。

Pythonの一部(推論)

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('models/my-model.meta') 
    saver.restore(sess, tf.train.latest_checkpoint('models/')) 
    outputTensors = sess.run(outputOps, feed_dict=feedDict) 

C++の部分(推論)checkpointPathは、既存のファイルのいずれかへのパスではないことを

ノート、ちょうど彼らの共通のプレフィックス。あなたが間違って.indexファイルへのパスを入れた場合、TFはそれが間違っているとは言いませんが、初期化されていない変数による推論の間に死ぬでしょう。

#include <tensorflow/core/public/session.h> 
#include <tensorflow/core/protobuf/meta_graph.pb.h> 

using namespace std; 
using namespace tensorflow; 

... 
// set up your input paths 
const string pathToGraph = "models/my-model.meta" 
const string checkpointPath = "models/my-model"; 
... 

auto session = NewSession(SessionOptions()); 
if (session == nullptr) { 
    throw runtime_error("Could not create Tensorflow session."); 
} 

Status status; 

// Read in the protobuf graph we exported 
MetaGraphDef graph_def; 
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def); 
if (!status.ok()) { 
    throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString()); 
} 

// Add the graph to the session 
status = session->Create(graph_def.graph_def()); 
if (!status.ok()) { 
    throw runtime_error("Error creating graph: " + status.ToString()); 
} 

// Read weights from the saved checkpoint 
Tensor checkpointPathTensor(DT_STRING, TensorShape()); 
checkpointPathTensor.scalar<std::string>()() = checkpointPath; 
status = session->Run(
     {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },}, 
     {}, 
     {graph_def.saver_def().restore_op_name()}, 
     nullptr); 
if (!status.ok()) { 
    throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString()); 
} 

// and run the inference to your liking 
auto feedDict = ... 
auto outputOps = ... 
std::vector<tensorflow::Tensor> outputTensors; 
status = session->Run(feedDict, outputOps, {}, &outputTensors); 
+1

こんにちは、@ peci1、あなたの仕事のおかげで。このC++コードをどのようにコンパイルできるのか教えてください。私はバゼルを使うべきですか?簡単な例のコマンドを表示できますか? –

+0

@YWPKwonこの例は、完全なコードまたはスタンドアロンコードではありません。あなたはそれをあなたにとって有益な方法で組み込むべきです(例えば、それをメインメソッドに入れる)。コンパイルは、bazelを使用して標準的な方法で行うこともできますし、C++環境を得るためにPython pipのインストールを(誤って)使用する方法を見てみることもできます。 https://github.com/tradr-project/tensorflow_rosにあり、catkinビルドツール(https://github.com/ros/catkin/)を使用しています。 –

+2

@YWPKwon TFを使用してC++コードを構築するには、http://www.tensorflow.org/api_guides/cc/guideまたはCMake(https:// github)に記載されているように、bazelを使用することもできます。com/FloopCZ/tensorflow_cc – Floop

関連する問題