2017-07-15 17 views
0

DynamicPartition操作を使用してベクトル[1,2,3,4,5,6]を[1、2、3]と[4]の2つのベクトルで分割するコードを作成しました。 、5、6]マスク[1、1、1、0、0を使用して、0]:DynamicPartitionは複数の代わりに単一の出力を返します

@Test 
public void dynamicPartition2() { 
    Graph graph = new Graph(); 

    Output a = graph.opBuilder("Const", "a") 
      .setAttr("dtype", DataType.INT64) 
      .setAttr("value", Tensor.create(new long[]{6}, LongBuffer.wrap(new long[] {1, 2, 3, 4, 5, 6}))) 
      .build().output(0); 

    Output partitions = graph.opBuilder("Const", "partitions") 
      .setAttr("dtype", DataType.INT32) 
      .setAttr("value", Tensor.create(new long[]{6}, IntBuffer.wrap(new int[] {1, 1, 1, 0, 0, 0}))) 
      .build().output(0); 

    graph.opBuilder("DynamicPartition", "result") 
      .addInput(a) 
      .addInput(partitions) 
      .setAttr("num_partitions", 2) 
      .build().output(0); 

    try (Session s = new Session(graph)) { 
     List<Tensor> outputs = s.runner().fetch("result").run(); 

     try (Tensor output = outputs.get(0)) { 
      LongBuffer result = LongBuffer.allocate(3); 
      output.writeTo(result); 

      assertArrayEquals("Shape", new long[]{3}, output.shape()); 
      assertArrayEquals("Values", new long[]{4, 5, 6}, result.array()); 
     } 

     //Test will fail here 
     try (Tensor output = outputs.get(1)) { 
      LongBuffer result = LongBuffer.allocate(3); 
      output.writeTo(result); 

      assertArrayEquals("Shape", new long[]{3}, output.shape()); 
      assertArrayEquals("Values", new long[]{1, 2, 3}, result.array()); 
     } 
    } 
} 

長さ1のs.runner().fetch("result").run()リストを呼び出した後の値[4、5と戻され、6]。私のグラフは1つの出力しか生成していないようです。

分割されたベクターの残りの部分を取得するにはどうすればよいですか?

+0

これはJavaの場合のみ必要ですか、またはPythonの答えで十分でしょうか? –

+0

すべての回答は歓迎です – Aeteros

+0

私の答えは何か説明されていますか? –

答えて

1

DynamicPartition操作では、複数の出力(パーティションごとに1つ)が返されますが、Session.Runner.fetchコールは0番目の出力のみを要求します。

Java APIには、Python APIに備わっている便利な砂糖が欠けていますが、すべての出力を明示的に要求することで、必要な機能を実行できます。

List<Tensor> outputs = s.runner().fetch("result").run(); 

役立ちます

List<Tensor> outputs = s.runner().fetch("result", 0).fetch("result", 1).run(); 

希望に:他の言葉では、からに変更。

+0

ありがとう、それは解決策です – Aeteros

0

Javaについてはわかりませんが(私にはわかりませんし、調査する環境もありません)、Pythonではすべて正常に動作します。たとえば、この

import tensorflow as tf 
a = tf.constant([1, 2, 3, 4, 5, 6]) 
b = tf.constant([1, 1, 1, 0, 0, 0]) 
c = tf.dynamic_partition(a, b, 2) 
with tf.Session() as sess: 
    v1, v2 = sess.run(c) 
    print v1 
    print v2 

は、正しいパーティションを返します。

関連する問題