2017-10-21 13 views
9

私は、基本的なAndroidのTensorFlowInferenceの例を1つのスレッドで正常に実行しています。マルチコアデバイスでTensorFlowを実行する

public class InferenceExample { 

    private static final String MODEL_FILE = "file:///android_asset/model.pb"; 
    private static final String INPUT_NODE = "intput_node0"; 
    private static final String OUTPUT_NODE = "output_node0"; 
    private static final int[] INPUT_SIZE = {1, 8000, 1}; 
    public static final int CHUNK_SIZE = 8000; 
    public static final int STRIDE = 4; 
    private static final int NUM_OUTPUT_STATES = 5; 

    private static TensorFlowInferenceInterface inferenceInterface; 

    public InferenceExample(final Context context) { 
     inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE); 
    } 

    public float[] run(float[] data) { 

     float[] res = new float[CHUNK_SIZE/STRIDE * NUM_OUTPUT_STATES]; 

     inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]); 
     inferenceInterface.run(new String[]{OUTPUT_NODE}); 
     inferenceInterface.fetch(OUTPUT_NODE, res); 

     return res; 
    } 
} 

以下の例のようThreadPoolで実行しているときの例では、ので、私はそれがスレッドセーフではありませんだと思いますjava.lang.ArrayIndexOutOfBoundsExceptionjava.lang.NullPointerException含む様々な例外を除いてクラッシュします。

InferenceExample inference = new InferenceExample(context); 

ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_CORES);  
Collection<Future<?>> futures = new LinkedList<Future<?>>(); 

for (int i = 1; i <= 100; i++) { 
    Future<?> result = executor.submit(new Runnable() { 
     public void run() { 
      inference.call(randomData); 
     } 
    }); 
    futures.add(result); 
} 

for (Future<?> future:futures) { 
    try { future.get(); } 
    catch(ExecutionException | InterruptedException e) { 
     Log.e("TF", e.getMessage()); 
    } 
} 

それはTensorFlowInferenceInterfaceでマルチコアAndroidデバイスを活用することは可能ですか?

答えて

0

それはfeedへの呼び出し間で状態を保つようTensorFlowInferenceInterfaceクラスは、(runスレッドセーフではありません、fetchなど

しかし、それはSessionクラスのオブジェクトであるTensorFlowのJava APIの上に構築されますスレッドセーフ。

ですから、直接基礎となるJava APIを使用する場合があります、TensorFlowInferenceInterfaceのコンストラクタはSessionを作成し、AssetManagercode)からロードGraphでそれを設定します。

希望があれば

1

私はstaticからTensorFlowInferenceInterfaceを変更し、run方法​​作らInferenceExampleスレッドセーフにするには:その後、私はロビンにnumThreadsInterferenceExampleインスタンスのリストを丸める

private TensorFlowInferenceInterface inferenceInterface; 

public InferenceExample(final Context context) { 
    inferenceInterface = new TensorFlowInferenceInterface(assets, model); 
} 

public synchronized float[] run(float[] data) { ... } 

を。

for (int i = 1; i <= 100; i++) { 
    final int id = i % numThreads; 
    Future<?> result = executor.submit(new Runnable() { 
     public void run() { 
      list.get(id).run(data); 
     } 
    }); 
    futures.add(result); 
} 

これが行うの増加のパフォーマンスは、しかし 8コアデバイス上でこの2のnumThreadsでピークとは〜Android用スタジオモニターの50%のCPU使用率を示しています。

+0

このアプローチには強くお勧めします。もちろん、 'run'の呼び出しを同時に行うことができるようにしましたが、' TensorFlowInferenceface.feed() 'を呼び出して)入力を変更しないと意味がありません。 おそらく、スレッドが異なる入力をフィードして計算を実行できるようにしたいとします。そして、あなたが提案したアプローチはそれにとって安全ではありません。 – ash

+0

異なる入力にはなぜ安全でないのですか?ループの中で先物を 'id'の順番で格納するというマイナーな変更で、どの入力がどの出力と一致するかを知ることができます。 –

+0

ああ、申し訳ありませんが、私は誤解して、 'feed()'と 'fetch()'の呼び出しがあなたの同期 'run()'の中にあることに気付かなかった。だから、私は上記の私のコメントで間違っていた。 しかし、このアプローチでは、本質的にTensorFlowセッションの使用をシリアル化するため、並列性が制限されます。一度に1つのスレッドしかモデルを実行できません。 – ash

関連する問題