2017-01-18 15 views
2

私自身のデータセットのための再入モデルを持っています。 ThoモデルはPythonでビルドされています。保存されたグラフは.pbファイルとして、ラベルファイルは.txtとして保存されます。今私はJavaを介して画像のこのモデルを使用して予測する必要があります。誰でも助けてください。Javaからのトレーニングと予測のためにPythonで書かれたテンソルフローモデルを実行する

答えて

3

TensorFlowチームはJavaインターフェイスを開発中ですが、まだ安定していません。既存のコードはhttps://github.com/tensorflow/tensorflow/tree/master/tensorflow/javaにあり、開発の更新情報はhttps://github.com/tensorflow/tensorflow/issues/5に従ってください。あなたはGraphTest.javaSessionTest.javaおよびTensorTest.javaを見て、それが現在どのように使われているかを見ることができます(しかし、説明したように、これは将来変わるかもしれません)。基本的には、Graphオブジェクトにバイナリ保存されたグラフをロードし、それにSessionを作成し、適切な値(Tensor秒)で実行すると、出力がList<Tensor>になります。ソースの例からまとめ:

import java.nio.file.Files; 
import java.nio.file.Paths; 
import org.tensorflow.Graph; 
import org.tensorflow.Session; 
import org.tensorflow.Tensor; 

try (Graph graph = new Graph()) { 
    graph.importGraphDef(Files.readAllBytes(Paths.get("saved_model.pb")); 
    try (Session sess = new Session(graph)) { 
     try (Tensor x = Tensor.create(1.0f); 
      Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) { 
      System.out.println(y.floatValue()); 
     } 
    } 
} 
0

私はそれが.pbで終わる、protobufファイルを読んで働い使用されるコードを。

try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) { 
    Session sess = b.session(); 
    ... 
    float[][]matrix = sess.runner() 
     .feed("x", input) 
     .feed("keep_prob", keep_prob) 
     .fetch("y_conv") 
     .run() 
     .get(0) 
     .copyTo(new float[1][10]); 
    ... 
} 

私はそれを保存するために使用Pythonのコードはでした:

signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs = {'x': tf.saved_model.utils.build_tensor_info(x)}, 
    outputs = {'y_conv': tf.saved_model.utils.build_tensor_info(y_conv)}, 
) 
    builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model") 
    builder.add_meta_graph_and_variables(sess, 
     [tf.saved_model.tag_constants.SERVING], 
     signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} 
    ) 
    builder.save() 
関連する問題