KerasとTensorFlowで非同期バージョンのactor-criticを実装しようとしています。私はネットワーク層を構築するフロントエンドとしてKerasを使用しています(テンソルフローで直接パラメータを更新しています)。私はglobal_model
と1つの主要テンソルフローセッションを持っています。しかし、各スレッド内でlocal_model
を作成しており、global_model
からパラメータをコピーしています。私のコードは、私が操作が同じグラフ上になければならないと言ってtf.assign
操作上のtensorflowエラーが続くKerasKerasとTensorflowのモデルをマルチスレッド設定で複製する
UserWarning: The default TensorFlow graph is not the graph associated with the TensorFlow session currently registered with Keras, and as such Keras was not able to automatically initialize a variable. You should consider registering the proper session with Keras via
K.set_session(sess)
からユーザー警告が表示されます。この
def main(args):
config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)
sess = tf.Session(config=config)
K.set_session(sess) # K is keras backend
global_model = ConvNetA3C(84,84,4,num_actions=3)
threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model)) for i in range(NUM_THREADS)]
for t in threads:
t.start()
def a3c_thread(i, sess, global_model):
K.set_session(sess) # registering a session for each thread (don't know if it matters)
local_model = ConvNetA3C(84,84,4,num_actions=3)
sync = local_model.get_from(global_model) # I get the error here
#in the get_from function I do tf.assign(dest.params[i], src.params[i])
ようになります。
ValueError: Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref, device=/device:CPU:0) must be from the same graph as Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref)
私は間違っているとは確信していません。 tf.get_default_graph() is sess.graph
はFalse
を返しているため、エラーがKerasから来
おかげ