2016-11-18 5 views
2

私はtf.PaddingFIFOQueueまたはtf.contrib.data.PaddedBatchDatasetを使用してさまざまな長さのシーケンスを送り、dequeue_manyを使用して、0パッドのバッチを取り出します。PaddingFIFOQueueから動的シーケンスの長さを取得

このバッチのシーケンス長を取得する一般的な方法はありますか?

私の現在の解決策は、系列長をキューに追加入力として明示的に提供することです。すなわち、私はtf.PaddingFIFOQueue(names=["data", "seq_length"], ...)のようなsthを持っています。私もtf.ones_like()を使用することができますが、私の現在の方法は安く簡単です。しかし、それが正規/標準的な方法であるのか、それとも別の方法があるのだろうかと思います。

+0

あなたの質問をいくつかのコードで説明できますか?結果のテンソルの '.get_shape()'があなたの場合にはうまくいかないのはなぜですか? – sygi

+0

@sygi:get_shapeはゼロパディングされているため、(batch、max_length、...)を返します。では、どのように各シーケンスの長さを取得するのですか? – Albert

+0

元の文章の末尾に0がないと仮定できますか? – sygi

答えて

0

dataseq_lengthをタプル(またはリスト)に結合し、そのタプルをキューにプッシュすることができます。

import tensorflow as tf 
sess = tf.InteractiveSession() 
q = tf.PaddingFIFOQueue(capacity=10, dtypes=[tf.int32, tf.int32], shapes=[[], [None]]) 
eq1 = q.enqueue([1, [1]]) 
eq2 = q.enqueue([2, [2,3]]) 
eq3 = q.enqueue([3, [4,5,6]]) 
dq = q.dequeue() 
sess.run(eq1) 
sess.run(eq2) 
sess.run(eq3) 
sess.run(dq) # [1, array([1], dtype=int32)] 
sess.run(dq) # [2, array([2, 3], dtype=int32)] 
sess.run(dq) # [3, array([4, 5, 6], dtype=int32)] 
+0

これは私の現在の解決策とまったく同じです。だから、基本的にあなたの答えは、私がすでに説明したように、より良い方法がないということです。 – Albert

関連する問題