2017-04-07 3 views
1

map_fn関数を使って遊んでいて、TensorArrayが出力されていることに気付きました。これは、テンソルの「ギザギザ」(内部のテンソルが異なる第1次元を持つ)tensorflow map_fn TensorArrayの形が矛盾しています

私はこのコードでこの動作を確認してみました:

import tensorflow as tf 
import numpy as np 

NUM_ARRAYS = 1000 
MAX_LENGTH = 1000 

lengths = tf.placeholder(tf.int32) 
tArray = tf.map_fn(lambda x: tf.random_normal((x,), 0, 1), 
        lengths, 
        dtype=tf.float32) # Should return a TensorArray. 

# startTensor = tf.random_normal((tf.reduce_sum(lengths),), 0, 1) 
# tArray = tf.TensorArray(tf.float32, NUM_ARRAYS) 
# tArray = tArray.split(startTensor, lengths) 
# outArray = tArray.concat() 


with tf.Session() as sess: 
    outputArray, l = sess.run(
     [tArray, lengths], 
     feed_dict={lengths: np.random.randint(MAX_LENGTH, size=NUM_ARRAYS)}) 
    print outputArray.shape, l 

は、しかし、エラーが発生しました:

「TensorArrayは、一貫性のない形状を持つインデックス0は形状をしている:[259]が、インデックス1は形状をしています:[773] "

私はTensorArraysがそれを処理できるはずであるという印象を受けているので、これは当然驚きです。私が間違っている? tf.map_fn()が一緒に要素を積み重ねることによってtf.Tensorに戻ってそのtf.TensorArray結果を変換しているため、そのまま

答えて

4

tf.map_fn()ながらtf.TensorArrayオブジェクト内部を使用しない、とtf.TensorArrayは、異なるサイズのオブジェクトを保持することができ、このプログラムは動作しません、この操作は失敗します。

ただしtf.TensorArrayではなく、下位レバーtf.while_loop()オペアンプを使用してベースの実装できます

lengths = tf.placeholder(tf.int32) 
num_elems = tf.shape(lengths)[0] 
init_array = tf.TensorArray(tf.float32, size=num_elems) 

def loop_body(i, ta): 
    return i + 1, ta.write(i, tf.random_normal((lengths[i],), 0, 1)) 

_, result_array = tf.while_loop(
    lambda i, ta: i < num_elems, loop_body, [0, init_array]) 
関連する問題