2017-11-19 4 views
0

Tensorflow Estimator API(tf.estimator)では、model_fnの現在のセッションを使用してテンソルを評価し、その値をPythonに渡す方法はありますか?私はglobal_stepという値に依存してドロップアウトに種を持っていますが、前者はintが必要ですが、後者はtensorです。TF - in model_fn seedをグローバルに渡す

答えて

0

model_fnにアクセスする方法はありません。現在の値はglobal_stepです。

たとえ可能であったとしても、各ステップでtf.nn.dropoutのシードを変更すると、各ステップで異なるシードを持つ新しいグラフ操作が作成され、グラフが大きくなります。 tf.estimatorがなくても、これをどのように実装できるかわかりません。


あなたが望むのは、2回の実行の間に同じランダム性があることを確認することです。 tf.set_random_seed()でグラフレベルのランダムシードを設定するか、またはドロップアウトに通常のseedを使用するだけで、再現可能なマスクシーケンスが作成されるはずです。ここでは、コードと例です。

x = tf.ones(10) 
y = tf.nn.dropout(x, 0.5, seed=42) 

sess1 = tf.Session() 
y1 = sess1.run(y) 
y2 = sess1.run(y) 

sess2 = tf.Session() 
y3 = sess2.run(y) 
y4 = sess2.run(y) 

assert (y1 == y3).all() # y1 and y3 are the same 
assert (y2 == y4).all() # y2 and y4 are the same 

答えhereは、グラフのランダム性を再現可能にする方法の詳細を提供します。

関連する問題