2017-01-13 6 views
1

私はしばらくしてアンドロイド上で動作する事前訓練モデルを取得しようとしました。問題は、私は訓練されたネットのためのckptとメタファイルしか持っていないということです。私の意見では、アンドロイドアプリの.pbが必要です。そこで、指定したファイルを.pbファイルに変換しようとしました。Pretrained Inception v3モデルをAndroidで動作するOpen Imagesデータセットから取得する

私はfreeze_graph.pyを試しましたが、succesはありませんでした。そこで、https://github.com/openimages/dataset/blob/master/tools/classify.pyのサンプルコードを使用して、それを修正してpbを保存しました。

if not os.path.exists(FLAGS.checkpoint): 
    tf.logging.fatal(
     'Checkpoint %s does not exist. Have you download it? See tools/download_data.sh', 
     FLAGS.checkpoint) 
    g = tf.Graph() 
with g.as_default(): 
    input_image = tf.placeholder(tf.string) 
    processed_image = PreprocessImage(input_image) 

    with slim.arg_scope(inception.inception_v3_arg_scope()): 
    logits, end_points = inception.inception_v3(
     processed_image, num_classes=FLAGS.num_classes, is_training=False) 

    predictions = end_points['multi_predictions'] = tf.nn.sigmoid(
     logits, name='multi_predictions') 
    init_op = control_flow_ops.group(tf.global_variables_initializer(), 
          tf.global_variables_initializer(), 
          data_flow_ops.initialize_all_tables()) 
    saver = tf_saver.Saver() 
    sess = tf.Session() 
    saver.restore(sess, FLAGS.checkpoint) 

    outpt_filename = 'output_graph.pb' 
    #output_graph_def = sess.graph.as_graph_def() 
    output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ["multi_predictions"]) 
    with gfile.FastGFile(outpt_filename, 'wb') as f: 
     f.write(output_graph_def.SerializeToString()) 

をロードした後のファイルは、今私の問題は、私は.pbファイルを持っているが、私は入力ノード名が何であるかを任意の意見を持っていないと私はmulti_predictionsは、右の出力名であるかどうかわからないですということです。私はアンドロイドアプリの例では、両方を指定する必要があります。アンドロイドアプリがクラッシュしました:

tensorflow_inference_jni.cc:138 Could not create Tensorflow Graph: Invalid argument: No OpKernel was registered to support Op 'DecodeJpeg' with these attrs. 

.pbの問題を修正することによってさらに問題があるかどうかわかりません。または、誰かが私のケースで.pdファイルにckptとメタファイルを移植するより良い方法を知っているか、入力と出力名を持つ最終ファイルのソースを知っているなら、このタスクを完了するためのヒントを与えてください。

おかげ

答えて

2

あなたは、グラフ内の未使用のノードを取り除くためにoptimize_for_inference.pyスクリプトを使用する必要があります。 "decodeJpeg"はAndroidではサポートされていません - ピクセル値は直接入力する必要があります。 ClassifierActivity.javaには、開始v3に使用する特定のノードの詳細があります。

+0

しばらくお待ちください。あなたは、optimize_for_inference.pyを使用するためのあなたのヒントで絶対に正しいですし、入力ノードとして "Mul"を使用することができます。多くのおかげで – lampep

+0

こんにちは@ lampepどのように画像を前処理していますか?私はこれをiosで実行しようとしており、同じ予測が繰り返されます。私はあなたのコードに従って、最適化されたグラフを生成しましたが、予測は正しく機能していません。ありがとう! –

関連する問題