2017-11-15 10 views
0

の文字列の出力を解析する方法:ここでのコードを使用してモデルを作成しtensorflowモデル

public static void main(String[] args) { 
    Session session = SavedModelBundle.load("/Users/gagandeep.malhotra/Documents/SampleTF_projects/tf_iris_model/1510707746/", "serve").session(); 

    Tensor x = 
     Tensor.create(
      new long[] {2, 4}, 
      FloatBuffer.wrap(
       new float[] { 
        6.4f, 3.2f, 4.5f, 1.5f, 
        5.8f, 3.1f, 5.0f, 1.7f 
       })); 

    final String xName = "Placeholder:0"; 
    final String scoresName = "dnn/head/predictions/probabilities:0"; 

    List<Tensor<?>> outputs = session.runner() 
     .feed(xName, x) 
     .fetch(scoresName) 
     .run(); 

    // Outer dimension is batch size; inner dimension is number of classes 
    float[][] scores = new float[2][3]; 

    outputs.get(0).copyTo(scores); 
    System.out.println(Arrays.deepToString(scores)); 
    } 

https://gist.github.com/gaganmalhotra/1424bd3d0617e784976b29d5846b16b1

が、それはコードの下に使用して行うことができるJavaでprobabilitesの予測を取得するには私たちは以下のコードのための予測クラス(文字列型)をコピーしたい場合は:

final String xName = "Placeholder:0"; 
final String className = "dnn/head/predictions/str_classes:0"; 

List<Tensor<?>> outputs = session.runner() 
    .feed(xName, x) 
    .fetch(className) 
    .run(); 

// Outer dimension is batch size; inner dimension is number of classes 
String[][] classes = new String[2][1]; 

outputs.get(0).copyTo(classes); 
System.out.println(Arrays.deepToString(classes)); 

私はERROで終わりますRのような:

Exception in thread "main" java.lang.IllegalArgumentException: cannot copy Tensor with 2 dimensions into an object with 1 
    at org.tensorflow.Tensor.throwExceptionIfTypeIsIncompatible(Tensor.java:739) 
    at org.tensorflow.Tensor.copyTo(Tensor.java:450) 
    at deeplearning.IrisTFLoad.main(IrisTFLoad.java:71) 

しかし、寸法が出力テンソルと同じである:[形状のSTRINGテンソル[2、1]]

PS:署名定義は以下のように求めることができる -

試み
The given SavedModel SignatureDef contains the following input(s): 
    inputs['x'] tensor_info: 
     dtype: DT_FLOAT 
     shape: (-1, 4) 
     name: Placeholder:0 
    The given SavedModel SignatureDef contains the following output(s): 
    outputs['class_ids'] tensor_info: 
     dtype: DT_INT64 
     shape: (-1, 1) 
     name: dnn/head/predictions/ExpandDims:0 
    outputs['classes'] tensor_info: 
     dtype: DT_STRING 
     shape: (-1, 1) 
     name: dnn/head/predictions/str_classes:0 
    outputs['logits'] tensor_info: 
     dtype: DT_FLOAT 
     shape: (-1, 3) 
     name: dnn/head/logits:0 
    outputs['probabilities'] tensor_info: 
     dtype: DT_FLOAT 
     shape: (-1, 3) 
     name: dnn/head/predictions/probabilities:0 
    Method name is: tensorflow/serving/predict 

もの:

テンソルテンソル=(テンソル)outputs.get(0)。 byte [] [] [] result = tensor.copyTo(新しいバイト[2] [1] []);

が、以下のようなエラーが出:

Exception in thread "main" java.lang.IllegalStateException: invalid DataType(7) 
    at org.tensorflow.Tensor.readNDArray(Native Method) 
    at org.tensorflow.Tensor.copyTo(Tensor.java:451) 
    at deeplearning.IrisTFLoad.main(IrisTFLoad.java:74) 

答えて

1

DT_STRING型付けTensorFlowテンソルが含まれているarbitrary byte sequences要素ではなく、JavaのString S(文字列)として。

したがって、あなたがしたいことは、このようなものです:あなたはJavaのStringオブジェクトを取得したい場合は、あなたがあなたのモデルをコードするのクラスを生成しますかを知る必要があります

byte[][][] classes = new byte[2][1][]; 
outputs.get(0).copyTo(classes); 

、および(UTF-8エンコーディングを想定して)次のようにすることができます:

String[][] classesStrings = new String[2][1]; 
for (int i = 0; i < classes.length; ++i) { 
    for (int j = 0; j < classes[i].length; ++j) { 
    classesString[i][j] = new String(classes[i][j], UTF_8); 
    } 
} 

希望することができます。 unittestも参考になるかもしれません。

+0

ありがとうございます@ash –

関連する問題