2017-08-22 25 views
0

私はアンドロイドスタジオとテンソルフロー、アンドロイドバージョンを使用して画像認識について取り組んでいます。 追跡と認識は連続的ではなく、1つの画像の認識だけです。 私は既にグラフpbとラベルtxtファイルを持っていて、必要な設定をしました。 しかし、大きな問題があります。 イメージ、寸法誤差について同じエラーが繰り返し発生しました。 ここにエラーログとソースコードがあります。Android Tensorflow IllegalArgumentExceptionエラー

java.lang.IllegalArgumentException: input must be 4-dimensional[1,1,299,299,3] 
                     [[Node: ResizeBilinear = ResizeBilinear[T=DT_FLOAT, align_corners=false, _device="/job:localhost/replica:0/task:0/cpu:0"](ExpandDims, ResizeBilinear/size)]] 
                     at org.tensorflow.Session.run(Native Method) 
                     at org.tensorflow.Session.access$100(Session.java:48) 
                     at org.tensorflow.Session$Runner.runHelper(Session.java:295) 
                     at org.tensorflow.Session$Runner.run(Session.java:245) 
                     at org.tensorflow.contrib.android.TensorFlowInferenceInterface.run(TensorFlowInferenceInterface.java:144) 
                     at com.example.yuuuuu.tensorTest.TensorFlowImageClassifier.recognizeImage(TensorFlowImageClassifier.java:119) 
                     at com.example.yuuuuu.tensorTest.MainActivity.runTensor(MainActivity.java:69) 
                     at com.example.yuuuuu.tensorTest.MainActivity$1.onClick(MainActivity.java:42) 
                     at android.view.View.performClick(View.java:6205) 
                     at android.widget.TextView.performClick(TextView.java:11103) 
                     at android.view.View$PerformClick.run(View.java:23653) 
                     at android.os.Handler.handleCallback(Handler.java:751) 
                     at android.os.Handler.dispatchMessage(Handler.java:95) 
                     at android.os.Looper.loop(Looper.java:154) 
                     at android.app.ActivityThread.main(ActivityThread.java:6682) 
                     at java.lang.reflect.Method.invoke(Native Method) 
                     at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:1520) 
                     at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1410) 

問題の最初の行[1,1,299,299,3]はわかりません。私は2つの299はImageSize、1つはImageStdだと思いますが、別の1と3が何であるかわかりません... 私はtensorflow githubの公式コードと同じコードを入力し、ちょっと変わったところです。 これはMainActivityです。

public class MainActivity extends AppCompatActivity { 

private static final String MODEL_FILE = "file:///android_asset/optimized_graph.pb"; 
private static final String LABEL_FILE = "file:///android_asset/output_labels.txt"; 
private static final String INPUT_NAME = "Cast"; 
private static final String OUTPUT_NAME = "final_result"; 
private static final int INPUT_SIZE = 299; 
private static final int IMAGE_MEAN = 117; 
private static final float IMAGE_STD = 1; 

private Classifier classifier; 
private TextView textView; 
private ImageView img; 
private Button button; 

@Override 
protected void onCreate(Bundle savedInstanceState) { 
    super.onCreate(savedInstanceState); 
    setContentView(R.layout.activity_main); 

    textView = (TextView)findViewById(R.id.textView); 
    button = (Button)findViewById(R.id.btn); 
    img = (ImageView)findViewById(R.id.img); 

    button.setOnClickListener(new View.OnClickListener(){ 
     public void onClick(View v){ 
      runTensor(); 
     } 
    }); 

    initTensor(); 
} 

public void initTensor(){ 
    classifier = TensorFlowImageClassifier.create(
      getAssets(), 
      MODEL_FILE, 
      LABEL_FILE, 
      INPUT_SIZE, 
      IMAGE_MEAN, 
      IMAGE_STD, 
      INPUT_NAME, 
      OUTPUT_NAME 
    ); 
} 

public void runTensor(){ 
    Bitmap bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.test); 
    bitmap = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, false); 

    img = (ImageView)findViewById(R.id.img); 
    img.setImageBitmap(bitmap); 

    final List<Classifier.Recognition> results = classifier.recognizeImage(bitmap); 
    textView.setText(results.toString()); 
} 

protected void onDestroy(){ 
    super.onDestroy(); 
    classifier.close(); 
} 

} 

これは公式コードと同じです。

public interface Classifier { 

public class Recognition{ 
    private final String id; 
    private final String title; 
    private final Float confidence; 
    private RectF location; 

    public Recognition(
      final String id, final String title, final Float confidence, final RectF location){ 
     this.id = id; 
     this.title = title; 
     this.confidence = confidence; 
     this.location = location; 
    } 

    public String getId(){return id;} 
    public String getTitle(){return title;} 
    public Float getConfidence(){return confidence;} 
    public RectF getLocation(){return location;} 
    public void setLocation(RectF location){this.location = location;} 

    public String toString(){ 
     String resultString = ""; 
     if (id != null) { 
      resultString += "[" + id + "] "; 
     } 

     if (title != null) { 
      resultString += title + " "; 
     } 

     if (confidence != null) { 
      resultString += String.format("(%.1f%%) ", confidence * 100.0f); 
     } 

     if (location != null) { 
      resultString += location + " "; 
     } 

     return resultString.trim(); 
    } 
} 

List<Recognition> recognizeImage(Bitmap bitmap); 
void enableStatLogging(final boolean debug); 
String getStatString(); 
void close(); 
} 

最後にTensorFlowImageClassifierも公式と同じです。

public class TensorFlowImageClassifier implements Classifier { 
private static final String TAG = "TensorFlowImageClassifier"; 

private static final int MAX_RESULTS = 3; 
private static final float THRESHOLD = 0.1f; 

private String inputName; 
private String outputName; 
private int inputSize; 
private int imageMean; 
private float imageStd; 

private Vector<String> labels = new Vector<String>(); 
private int[] intValues; 
private float[] floatValues; 
private float[] outputs; 
private String[] outputNames; 

private boolean logStats = false; 
private TensorFlowInferenceInterface inferenceInterface; 
private TensorFlowImageClassifier() {} 

/* 
assetManager : assets 로드하는데 사용 
modelFilename : pb 파일 
labelFilename : txt 파일 
inputSize : 정사각형 길이, inputSize * inputSize 
imageMean : image values 평균값 
imageStd : image values 표준값? 
inputName : image input 노드 레이블 
outputName : output 노드 레이블 
*/ 

public static Classifier create(
     AssetManager assetManager, String modelFilename, String labelFilename, int inputSize, int imageMean, float imageStd, String inputName, String outputName){ 
    TensorFlowImageClassifier c = new TensorFlowImageClassifier(); 
    c.inputName = inputName; 
    c.outputName = outputName; 

    String actualFilename = labelFilename.split("file:///android_asset/")[1]; 
    Log.d(TAG, "reading labels from : " + actualFilename); 
    BufferedReader br = null; 

    try { 
     br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename))); 
     String line; 
     while((line = br.readLine()) != null){ 
      c.labels.add(line); 
     } 
     br.close(); 
    } catch (IOException e) { 
     throw new RuntimeException("failed reading labels" , e); 
    } 

    c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); 

    final Operation operation = c.inferenceInterface.graphOperation(outputName); 
    final int numClasses = (int)operation.output(0).shape().size(1); 
    Log.d(TAG, "reading " + c.labels.size() + " labels, size of output layers : " + numClasses); 

    c.inputSize = inputSize; 
    c.imageMean = imageMean; 
    c.imageStd = imageStd; 

    c.outputNames = new String[]{outputName}; 
    c.intValues = new int[inputSize * inputSize]; 
    c.floatValues = new float[inputSize * inputSize * 3]; 
    c.outputs = new float[numClasses]; 

    return c; 
} 

@RequiresApi(api = Build.VERSION_CODES.JELLY_BEAN_MR2) 
public List<Recognition> recognizeImage(final Bitmap bitmap){ 
    beginSection("recognizeImage"); 
    beginSection("preprocessBitmap"); 

    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); 
    for(int i = 0; i < intValues.length; i++){ 
     final int val = intValues[i]; 
     floatValues[i*3+0] = (((val >> 16) & 0xFF) - imageMean)/imageStd; 
     floatValues[i*3+1] = (((val >> 8) & 0xFF) - imageMean)/imageStd; 
     floatValues[i*3+2] = ((val & 0xFF) - imageMean)/imageStd; 
    } 
    endSection(); 

    beginSection("feed"); 
    inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); 
    endSection(); 

    beginSection("run"); 
    inferenceInterface.run(outputNames, logStats); 
    endSection(); 

    beginSection("fetch"); 
    inferenceInterface.fetch(outputName, outputs); 
    endSection(); 


    PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>(
      3, 
      new Comparator<Recognition>(){ 
       public int compare(Recognition lhs, Recognition rhs){ 
        return Float.compare(rhs.getConfidence(), lhs.getConfidence()); 
       } 
      } 
    ); 

    for(int i = 0; i < outputs.length; ++i){ 
     if(outputs[i] > THRESHOLD){ 
      pq.add(
        new Recognition("" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null)); 
     } 
    } 

    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); 
    int recognitionSize = Math.min(pq.size(), MAX_RESULTS); 
    for(int i = 0; i < recognitionSize; ++i){ 
     recognitions.add(pq.poll()); 
    } 
    endSection(); 

    return recognitions; 
} 

public void enableStatLogging(boolean logStats){this.logStats = logStats;} 
public String getStatString(){return inferenceInterface.getStatString();} 
public void close(){inferenceInterface.close();} 
} 

これらのコードを修正する方法がわかっている場合は、どうすればいいですか。

答えて

0

私がネイティブライブラリを使って作業していたときに、通常はアセットからファイルを取得しないことに気づいたので、それをアクセス可能なファイルストレージパスにコピーし、絶対パスをライブラリに渡す必要があります。

あなたのエラーは、リソースの読み込みによるものかもしれません。

+0


は、返信用のuをありがとうございます。私はそれを試しましたが、私はアクセス可能なファイルストレージが正確にあなたが言っているか分からない。それが何であるか教えてくれますか? – user7860075

+0

/sdcard/downloadsなどのパブリックスペース用のコピー、またはご希望のもの –

+0

ありがとうございました。やってみます。 – user7860075

0

java.lang.IllegalArgumentExceptionが:入力がなければならない 4次元[1,1,299,299,3]エラーメッセージは、問題を説明

:誤っの代わりに5項目の配列を渡すを4アイテム配列。つまり、おそらく[1,1,299,299,3]の代わりに[1,299,299,1]のようなものを渡す必要があります。

あなたが実際に行ったコードの変更はあなたの質問からは分かりません。単一のGitコミットとして変更を加えることができれば、誰かが問題の原因となっている変更を特定するのが簡単かもしれません。

あなたは、彼らはあなたが設定した値と一致することを確認するために、入力と出力のノードを検査するTensorBoardであなたのTensorFlowモデルを見てみてください: https://medium.com/@daj/how-to-inspect-a-pre-trained-tensorflow-model-5fd2ee79ced0

関連する問題