2017-04-20 1 views
3

私はJavaで訓練されたモデル(Tensorflow、Python)をインポートして使用しようとしています。TensorflowのJavaへのインポートモデル

モデルをPythonで保存できましたが、Javaで同じモデルを使用して予測を行うときに問題が発生しました。

Hereを使用すると、モデルの初期化、トレーニング、保存のためのPythonコードが表示されます。

Hereを入力すると、入力値のインポートと予測のためのJavaコードが表示されます。

私が取得エラーメッセージは次のとおりです。 Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7 [[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]] at org.tensorflow.Session.run(Native Method) at org.tensorflow.Session.access$100(Session.java:48) at org.tensorflow.Session$Runner.runHelper(Session.java:285) at org.tensorflow.Session$Runner.run(Session.java:235) at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)

私は信じて、問題がどこかにPythonコードであるが、私はそれを見つけることができませんでした。

ご協力いただきましてありがとうございます。

、ありがとう

ピーター

+0

Iは[これ](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java)使用インスピレーションとして – szi

答えて

5

Java importGraphDef()関数は、 (あなたのPythonコードでtf.train.write_graphで書かれています)、それは訓練された変数(チェックポイントに格納されている)の値をロードしていないため、初期化されていない変数について不平を言うエラーが発生します。

一方、TensorFlow SavedModel formatには、モデル(グラフ、チェックポイントの状態、その他のメタデータ)に関するすべての情報が含まれており、訓練された変数値で初期化されたセッションを作成するのに、SavedModelBundle.loadを使用すると便利です。

のPythonからこの形式でモデルをエクスポートするには、あなたがあなたのケースでは、関連する質問でDeploy retrained inception SavedModel to google cloud ml engine

をご覧になる場合があります、これはPythonで、次のようなものに達する必要があります。

def save_model(session, input_tensor, output_tensor): 
    signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)}, 
    outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)}, 
) 
    b = saved_model_builder.SavedModelBuilder('/tmp/model') 
    b.add_meta_graph_and_variables(session, 
           [tf.saved_model.tag_constants.SERVING], 
           signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}) 
    b.save() 

そしてsave_model(session, x, yhat)

そしてJavaの負荷のモデルを経由して使用していることを呼び出します。

try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) { 
    // b.session().run(...) 
} 
0123を

希望があれば

+0

警告:これはJavaで動作しますが、現在TFはAndroidでSavedModelの読み込みをサポートしていません。これを難しい方法で見つけました。 :/ – Keilaron

+0

代わりに、#12750または#13079の問題を参照するか、https://www.tensorflow.org/mobile/prepare_modelsを参照してください。 – Keilaron

1

あなたのpython-モデルは確かにこの時に失敗します。

sess.run(init) #<---this will fail 
save_model(sess) 
error = tf.reduce_mean(tf.square(prediction - y)) 

#accuracy = tf.reduce_mean(tf.cast(error, 'float')) 
print('Error:', error) 

initは、モデルで定義されていない - 私はあなたがで達成したいものをわかりませんよこの場所は、あなたに出発点を与えるはずです

1

Fwiw、Deeplearning4jは、Keras 1.0でTensorFlowで訓練されたモデルをインポートすることができます(Keras 2.0サポートが進行中です)。

https://deeplearning4j.org/model-import-keras

我々はまた、テンソルを扱うときPy4jよりも、それをより効率的に代わりにコピーするデータのポインタを使用numpyのアレイとPyjniusのラッパーですJumpyで呼ばれるライブラリを、構築されました。

https://deeplearning4j.org/jumpy

+0

ありがとうございます! 私はTensorflowでGANを構築しようとしていましたが、DL4Jは一般的にGANをサポートしていないことがわかっている限り、Tensorflowで訓練されたGANを使用でき、JVMで動作する回避策を探していました。 – szi

関連する問題