TensorFlowの操作と変数を効率的にデバイスに割り当てる方法が困惑しています。少なくとも私の基本的な畳み込みニューラルネットワークの実装では、可能な限り多くの操作をGPUに配置することが望ましいことは明らかです。しかし、私は現在へのアクセス権を持っているGPUは、フォームGPU実装の一環として、特定のTensorFlow計算をCPU上で実行するにはどうすればよいですか?
Ran out of memory trying to allocate 2.60GiB. The caller indicates that this is not a failure, but may mean that there could be performance gains if more memory is available.
と
のような特定の操作のために時折クラッシュし、多くの警告に記憶し、その結果を制限していますRan out of memory trying to allocate 83.74MiB. See logs for memory state. Resource exhausted: OOM when allocating tensor with shape[28000,1,28,28]
これは、CPUに変数を配置することで回避できますが、私の実装では、これによりトレーニングのエポックが10回esは計算するのに長い。
明らかに、理想的な方針は、エラーを生成するコードの特定のチャンクを特定し、それらをCPU上に配置しようとすることです。しかし、これらの計算は効率を達成するためにGPUを配置する必要がある他のものから分離できないため、これをどうやって行うのかはわかりません。例えば
、単に
x
は私のモデルへの入力の
tf.placeholder
で、
y
は私のネットワークの出力アクティベーションされ、上記のエラーが生成さ
evals = sess.run(tf.argmax(y, 1), feed_dict={x: use_x_all})
ようなものでテストセットに予測を生成しますuse_x_all
は大きな配列です(ここでは28000
の例)。この計算を単独でCPUに実行しようとすると、恐らくy
を生成するネットワーク評価がGPU上にあるために失敗します。このため
私は(ように見える)
明確にスケールしないuse_x_all, _ = data_loader.stack_data(use_data, as_cols=False)
use_x_split = np.split(use_x_all, splits)
for use_x in use_x_split:
# ... (full example below)
evals_part = sess.run(tf.argmax(y, 1), feed_dict={x: use_x})
# accumulate evals
のようなアプローチに頼る必要があります。
良い方法がありますか?具体的には
- 上記のような計算をCPUに置き、GPU上で実行される同じグラフ(例:トレーニング)の計算はまだありますか?
、あるいは
- より容易にそのような計算のメモリ要求を低減するために、このような状況でも適用することができる(バッチ処理など)イディオムがありますか?
実は、私は後者がTensorFlow APIの一部ではないことを驚いています。上記のようなコードを必要とせずに、デバイスに収まらないautomatically break up calculationsはできませんか?私のコードから
全例:
f = open('{0:s}/{1:s}_{2:3.0f}.csv'.format(FLAGS.pred_dir, FLAGS.session_name,
10000*float(sess.run(accuracy, feed_dict=valid_feed))), 'w')
f.write('ImageId,Label\n')
use_x_all, _ = data_loader.stack_data(use_data, as_cols=False)
use_x_split = np.split(use_x_all, splits)
last = 0
buff = ''
for use_x in use_x_split:
evals = sess.run(tf.argmax(y, 1), feed_dict={x: use_x})
f.write('\n'.join('{0},{1}'.format(r[0]+ last, r[1]) for r in enumerate(evals, start=1)))
last += int(len(use_x_all)/splits)
if last < len(use_x_all):
f.write('\n')
f.close()
TensorFlow [機能要求](https://github.com/tensorflow/tensorflow/issues/2431)とも呼ばれます。 – orome