2016-12-23 3 views
0

最初はタイトルが非常に良いかどうかはわかりませんが、状況を理解していれば最高でした。エンキュー時にテンソルを強制的にコピーする

バックグラウンドは、キューがテンソルフローで動作する方法を理解しようとしていて、次の問題に遭遇して困惑しました。

は私はtf.FIFOQueueにエンキュー可変Nを有し、そして、私は、変数をインクリメントします。これは数回繰り返され、0,1,2、...と同様の結果が期待されます。ただし、キューを空にすると、すべての値が同じになります。

次のようにより正確には、コードは次のとおりです。

私は印刷するだろうと期待しており
from __future__ import print_function 

import tensorflow as tf 

q = tf.FIFOQueue(10, tf.float32) 

n = tf.Variable(0, trainable=False, dtype=tf.float32) 
inc = n.assign(n+1) 
enqueue = q.enqueue(n) 

init = tf.global_variables_initializer() 

sess = tf.Session() 
sess.run(init) 

sess.run(enqueue) 
sess.run(inc) 

sess.run(enqueue) 
sess.run(inc) 

sess.run(enqueue) 
sess.run(inc) 

print(sess.run(q.dequeue())) 
print(sess.run(q.dequeue())) 
print(sess.run(q.dequeue())) 

0.0 
1.0 
2.0 

代わりに、私は次のような結果を得る:

3.0 
3.0 
3.0 

をそれはのように思えます私は実際の値の代わりに、キューにnへのポインタをいくつか押しています。これは私が望むものです。しかし、私は実際にテンソルフローの実際の理解を持っていないので、何か他のことが起こっているのでしょうか?

私はHow can I copy a variable in tensorflowIn TensorFlow, what is tf.identity used for?への回答は私にそれが役立つかもしれないが、それは結果を変更しないという印象を与えますので、

enqueue = q.enqueue(tf.identity(n)) 

enqueue = q.enqueue(n) 

を変えてみました。また、tf.control_dependencies()を追加しようとしましたが、デキュー時にもすべての値が同じです。

編集:上の出力は、単一のCPUを搭載したコンピュータ上でコードを実行したものです。異なるバージョンのテンソルフローに違いがあるかどうかを確認するときに、CPUとGPU私は "期待された"結果を得る。確かに、私がCUDA_VISIBLE_DEVICES = "" CUDA_VISIBLE_DEVICES = "0"で上記の結果を得た場合、 "期待される"結果が得られます。

+0

'q.enqueue(n)'ではなく 'q.enqueue_many([[n]])'という別の回避策は、参照ではなく値でエンキューします –

答えて

3

非キャッシュを強制するためには、あなたが

q.enqueue(tf.add(q, 0)) 

これはコピーを強制的にバッチ正規化層によってcurrently doneを何で行うことができますお読みください。

変数の読み込みと参照のセマンティクスは、直感的ではないように改訂されています。具体的には、私はq.enqueue(v.read_value())が非キャッシュ読み取りを強制することを期待しましたが、TF 0.12rc1であなたの例を修正しません

GPUマシンはGPUに変数を置きます。キューはCPUだけなので、enqueue op GPU →CPUコピー。

0

私が助けてくれる場合は、他の回答が正しいにもかかわらずすべてのdtypeで機能しないことがわかりました。

q.enqueue(tf.add(n, 0)) 

この1は、キューは、異種のタイプ(例えば、int型や山車)でタプルを使用すると失敗します:

nは文字列のテンソルであるとき、例えば

が、これは浮動小数点数または整数型で正常に動作しますが、失敗しましただから、

q.enqueue_many([[n]]) 

、あなた自身がこれらのいずれかの状況に巻き込ま見た場合、代わりにこの方法を試してください。

q.enqueue(tf.add(n, tf.zeros_like(n))) 

あるいは、タプルトンエンキューする:

q.enqueue([tf.add(n, tf.zeros_like(n)) for n in t]) 

でも文字列テンソルと異種のタプルタイプのために働きます。

希望すると助かります!

-

更新:tf.bool種類がtf.zeros_likeで動作しないように見えます()。そのためには、整数型への明示的なキャストが必要な場合があります。

関連する問題