2017-08-02 5 views
0

TensorArrayの使用に関連する質問があります。Tensorflowのtf.while_loopでTensorArray要素にアクセスするには?

問題:
私はtf.while_loopでTensorArrayのアクセス要素をしたいと思います。私はTensorArrayの内容を例えばu1.read(0)を使って読むことができます。

私の現在のコード:ここで
私がこれまで持っているものです。

embeds_raw = tf.constant(np.array([ 
    [1, 1], 
    [1, 1], 
    [2, 2], 
    [3, 3], 
    [3, 3], 
    [3, 3] 
], dtype='float32')) 
embeds = tf.Variable(initial_value=embeds_raw) 
container_variable = tf.zeros([512], dtype=tf.int32, name='container_variable') 
sen_len = tf.placeholder('int32', shape=[None], name='sen_len') 
# max_l = tf.reduce_max(sen_len) 
current_size = tf.shape(sen_len)[0] 
padded_sen_len = tf.pad(sen_len, [[0, 512 - current_size]], 'CONSTANT') 
added_container_variable = tf.add(container_variable, padded_sen_len) 
u1 = tf.TensorArray(dtype=tf.float32, size=512, clear_after_read=False) 
u1 = u1.split(embeds, added_container_variable) 

sentences = [] 
i = 0 

def condition(_i, _t_array): 
    return tf.less(_i, current_size) 

def body(_i, _t_array): 
    sentences.append(_t_array.read(_i)) 
    return _i + 1, _t_array 

idx, arr = tf.while_loop(condition, body, [i, u1]) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]}) 
    print(sents) 

エラーメッセージ:

Traceback (most recent call last): File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 267, in init fetch, allow_tensor=True, allow_operation=True)) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2584, in as_graph_element return self._as_graph_element_locked(obj, allow_tensor, allow_operation) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2673, in _as_graph_element_locked % (type(obj).name, types_str)) TypeError: Can not convert a TensorArray into a Tensor or Operation.

上記の例外の取り扱い時には、別の例外が発生しました:

Traceback (most recent call last): File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 191, in main() File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 187, in main variable_container() File "/home/ultimateai/Honain/new/ultimateai/exercises/dynamic_reshape.py", line 179, in variable_container sents = sess.run(arr, feed_dict={sen_len: [2, 1, 3]}) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 789, in run run_metadata_ptr) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 984, in _run self._graph, fetches, feed_dict_string, feed_handles=feed_handles) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 410, in init self._fetch_mapper = _FetchMapper.for_fetch(fetches) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 238, in for_fetch return _ElementFetchMapper(fetches, contraction_fn) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 271, in init % (fetch, type(fetch), str(e))) TypeError: Fetch argument has invalid type , must be a string or Tensor. (Can not convert a TensorArray into a Tensor or Operation.)

答えて

1

私はコメントする評判が十分ではないので、私は答えを書くでしょう。

あなたのコードが何を意図しているのかよく分かりませんが、sess.run()はTensorを返しますが、arrTensorArrayです。たとえば、次のようにすることができます。

sents = sess.run(arr.concat(), feed_dict={sen_len: [2, 1, 3]}) 

もちろん、分割を元に戻すだけです。あなたは多分、すべての値を取得したい場合:

sents = sess.run([arr.read(i) for i in range(512)], feed_dict={sen_len: [2, 1, 3]}) 

をしかし、私は確信している512をハードコーディングすると、おそらくあなたのWHILE_LOOPが何かをすることを意図しているよりもきれいな方法がなければなりません。

関連する問題