2017-03-15 4 views
1

に `saver.save`を使用して保存しました。は、私は新しいJava APIを使用して、ディスクからモデルを読み込むしようとしていますTensorflowのJava API

The one example TensorflowのJava APIを使用すると、グラフ定義とパラメータの両方の重みを持つ.pbモデルファイルを読み取る方法が示されます。 Pythonの側では

、Tensorflowはディスクにモデルを保存するSaverオブジェクトを使用することを示唆しています。これは、定義を持ち、重みのファイルが.dataである.metaファイルを作成します。 Pythonでは、私はnew_saver=tf.train.import_meta_graph(var_filename) new_saver.restore(sess, model_filename)を使ってディスクからモデルを読み込みます。

Java APIでどうやったらいいですか?

答えて

0

SavedModelBundleクラスは、あなたが探しているものと考えられます。特に、SavedModelBundle.load()Sessionを返します。このモデルを使用して保存されたモデルを実行できます。

この機能は比較的最近にJava APIに追加されたため、まだバイナリリリースには存在しないため、TensorFlow 1.1がリリースされるまでbuild the Java API from sourceにする必要があります。

+0

グレート、ありがとう。私の現在の解決策は、 'freeze_graph'を使ってグラフのdefとweightsの両方を保存し、Javaで読み込むことです。このクラスは有望です。正式リリースまで私のコードで試してみるつもりですが、 –

0

私はPythonのインターフェイスを使用してHadoopクラスタ上でモデルを訓練し、Javaで予測するモデルと学んだのパラメータを使用して、類似した何かをやっています。

  • あなたがSavedModelBuilderを使用する必要がモデル

    を保存します。ここでは指導見つける:

詐欺関数のシグネチャを構築 https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py

https://tensorflow.github.io/serving/serving_basic.html を、あなたが彼らのmnist例を使用することができます使用方法は、Java側で非常に簡単です:

SavedModelBundle load = SavedModelBundle.load(modelDir, "serve"); 
     float[][] resultArray; 
     try (Graph g = load.graph()) { 
      try (Session s = load.session(); 
       Tensor result = s.runner().feed("data", data).fetch("prediction").run().get(0)) { 
       resultArray = result.copyTo(new float[10][1]); 
      } 
     } 
     load.close(); 
     return resultArray; 

の名前を取得するにはフィードとフェッチされた操作で署名を印刷し、入力値と出力値の名前を使用できます。

print(prediction_signature) 

https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py#L119

関連する問題