2017-08-16 7 views
0

私はテンソルの流れの中で自分のやり方を学ぶために、いくつかの既存の行列処理ロジックを変換しています。ステップの1つは、以下の例ではscatter_addを使用するなどの散布操作です。この例の私の問題は、操作が評価されるたびに、前の結果の上に累積的に加算されることです。以下に示すように3ラン()の結果が印刷され、呼び出しである:私が何をしたいのに対し非累積テンソルフローscatter_addを実行するにはどうすればよいですか?

[[8 12 8]...] 
[[16 24 16]...] 
[[24 36 24]...] 

[[8 12 8]...]毎回です。 indicesベクターには重複が含まれており、updatesの対応する要素は一緒に追加する必要がありますが、すでにscatteredに保持されている既存の値には追加する必要はありません。

テンソルフローの散布操作は、私が探しているものではありません。使用する適切な操作はありますか?そうでない場合は、私が必要とするものを達成する最良の方法は何ですか?

import tensorflow as tf 

indices = tf.constant([0, 1, 0, 1, 0, 1, 0, 1], tf.int32) 

updates = tf.constant([ 
      [1., 2., 3., 4.], 
      [2., 3., 4., 1.], 
      [3., 4., 1., 2.], 
      [4., 1., 2., 3.], 
      [1., 2., 3., 4.], 
      [2., 3., 4., 1.], 
      [3., 4., 1., 2.], 
      [4., 1., 2., 3.]], tf.float32) 

scattered = tf.Variable([ 
      [0., 0., 0., 0.,], 
      [0., 0., 0., 0.,]], tf.float32) 

# Requirement: 
# scattered[i, j] = Sum of updates[k, j] where indices[k] == i 
# 
# i.e. 
# scattered_data = [ 
#  [1+3+1+3, 2+4+2+4, 3+1+3+1, 4+2+4+2], 
#  [2+4+2+4, 3+1+3+1, 4+2+4+2, 1+3+1+3]] 
# == [ 
#  [ 8, 12, 8, 12], 
#  [12, 8, 12, 8]] 

scattered = tf.scatter_add(scattered, indices, updates, use_locking=True, name='scattered') 
scattered_print = tf.Print(scattered, [scattered]) 

sess = tf.InteractiveSession() 
sess.run(tf.global_variables_initializer()) 
sess.run(scattered_print) 
# Printout: [[8 12 8]...] 
sess.run(scattered_print) 
# Printout: [[16 24 16]...] 
sess.run(scattered_print) 
# Printout: [[24 36 24]...] 
sess.close() 

答えて

0

Scatter_add可変参照への更新。

tf.matmul(tf.cast(tf.concat([indices[tf.newaxis,...], 1-indices[tf.newaxis,...]], axis=0), tf.float32),updates) 
+0

返信いただきありがとうございます。しかし、表示されているコードは実際の問題のおもちゃバージョンです(もちろん)。実際には、indexの値は0から約15,000の範囲であり、その長さは10,000,000のオーダーです。 32ビットのエントリを持つ行列による乗算に変換するには、約600GBの中間行列を作成する必要があります。私はこれをGPUに入れようとしています。既存のインプリメンテーションでは、インデックスエントリあたり16ビットが使用されるため、インデックス配列は約20MBになります。 –

0

scatter_addコールに以下の変更が必要に応じて、物事を動作させるために表示されます:代わりに、以下に示すように、あなたが何かを行うことができますので、

with tf.control_dependencies([scattered.initializer]): 
    scattered = tf.scatter_add(scattered, indices, updates, use_locking=True, name='scattered') 

根拠は以来、私はゼロで変数を初期化するということですすべてのscatter_add操作が実行される前にイニシャライザを強制的に再実行し、累積的な更新を回避します。

これは私にとってやや厄介なように思えます。私はそこにシングルコールソリューションがあると思います。そして、不必要なメモリ割り当てや解放がどれだけ進んでいるのかよく分かりませんが、それは後で調べるべきことです。

関連する問題