2016-11-23 25 views
4

私はlibtensorflow.soターゲットを構築してC APIを構築しました。私は事前に訓練されたモデルを読み込んで推論を実行して予測をしたいと思っています。私は 'c_api.h'ヘッダーファイル(適切な場所に 'libtensorflow.so'をコピーすることに加えて)を含めることでこれを行うことができると言われましたが、Web上のサンプルを見つけることはできませんでした。私が見つけることができるのは、Bazelビルドシステムを使用する例ですが、別のビルドシステムを使用し、TensorFlowをライブラリとして使用したいと考えています。誰かが私にどのようにa)メタグラフファイルをインポートするかの例を教えてもらえますか? b)protobufグラフファイルとチェックポイントファイルを使用して予測を行いますか?以下のPythonファイルに相当するC++に相当し、g ++でビルドされていますか?TensorFlowで訓練されたモデルとC APIを使用して予測する

#!/usr/bin/env python 

import tensorflow as tf 
import numpy as np 

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('./metagraph.meta') 
    saver.restore(sess, './checkpoint.ckpt') 
    x = tf.get_collection("x")[0] 
    yhat = tf.get_collection("yhat")[0] 
    print sess.run(yhat, feed_dict={x : np.array([[2, 3], [4, 5]])}) 

ありがとうございました!

PS:私はファイルをビルドするには、以下のなかったしている完全を期すために:/

#!/usr/bin/env python 

import tensorflow as tf 
import numpy as np 

x = tf.placeholder(tf.float32, shape=[None, 2], name='x') 
tf.add_to_collection("x", x) 
y = tf.placeholder(tf.float32, shape=[None, 1], name='y') 
w = tf.Variable(np.array([[10.0], [100.0]]), dtype=tf.float32, name='w') 
b = tf.Variable(0.0, dtype=tf.float32, name='b') 
yhat = tf.add(tf.matmul(x, w), b) 
tf.add_to_collection("yhat", yhat) 
mse_loss = tf.sqrt(tf.reduce_mean(tf.square(tf.sub(y, yhat)))) 
step_size = tf.constant(0.01) 
optimizer = tf.train.GradientDescentOptimizer(step_size) 
init_op = tf.initialize_all_variables() 
train_op = optimizer.minimize(mse_loss) 
saver = tf.train.Saver() 
with tf.Session() as sess: 
    sess.run(init_op) 
    for i in xrange(10000): 
     train_x = np.random.random([100, 2]) * 10 
     train_y = np.dot(train_x, np.array([[100.0], [10.0]])) + 1.0 
     sess.run(train_op, feed_dict={x : train_x, y : train_y}) 
    print sess.run(w) 
    print sess.run(b) 
    saver.save(sess, './checkpoint.ckpt') 
    saver.export_meta_graph('./metagraph.meta') 
    tf.train.write_graph(sess.graph_def, './', 'graph') 
+0

こんにちは、TensorFlow C++ APIドキュメントをチェックしましたか?このページでは、セッションを実行するために必要なものを提供する必要があります:https://www.tensorflow.org/versions/r0.11/api_docs/cc/ClassSession.htmlこれは、グラフを読むのに役立ちます:https: //www.tensorflow.org/versions/r0.11/api_docs/cc/index.html – Neal

+0

ありがとうございました。私はそれらのリンクを見ましたが、私は例を探していて見つけられませんでした。 –

答えて

1

私はEclipseを使用して、/ usr/localに私のプロジェクトファイルへとlibtensorflow.so c_api.hを追加しましたビン。次に、libtensorflow共有オブジェクトへの参照をGCC C++ Linker上のライブラリに追加して、簡単なプログラムを作成しました。

#include <iostream> 
#include "c_api.h" 

using namespace std; 

int main() { 
    cout << TF_Version(); 
    return 0; 
} 

これにより、Tensorflow関数をコンパイルして使用することができました。

関連する問題