2016-10-20 13 views
3

KerasとTensorFlowで非同期バージョンのactor-criticを実装しようとしています。私はネットワーク層を構築するフロントエンドとしてKerasを使用しています(テンソルフローで直接パラメータを更新しています)。私はglobal_modelと1つの主要テンソルフローセッションを持っています。しかし、各スレッド内でlocal_modelを作成しており、global_modelからパラメータをコピーしています。私のコードは、私が操作が同じグラフ上になければならないと言ってtf.assign操作上のtensorflowエラーが続くKerasKerasとTensorflowのモデルをマルチスレッド設定で複製する

UserWarning: The default TensorFlow graph is not the graph associated with the TensorFlow session currently registered with Keras, and as such Keras was not able to automatically initialize a variable. You should consider registering the proper session with Keras via K.set_session(sess)

からユーザー警告が表示されます。この

def main(args): 
    config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True) 
    sess = tf.Session(config=config) 
    K.set_session(sess) # K is keras backend 
    global_model = ConvNetA3C(84,84,4,num_actions=3) 

    threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model)) for i in range(NUM_THREADS)] 

    for t in threads: 
     t.start() 

def a3c_thread(i, sess, global_model): 
    K.set_session(sess) # registering a session for each thread (don't know if it matters) 
    local_model = ConvNetA3C(84,84,4,num_actions=3) 
    sync = local_model.get_from(global_model) # I get the error here 

    #in the get_from function I do tf.assign(dest.params[i], src.params[i]) 

ようになります。

ValueError: Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref, device=/device:CPU:0) must be from the same graph as Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref)

私は間違っているとは確信していません。 tf.get_default_graph() is sess.graphFalseを返しているため、エラーがKerasから来

おかげ

答えて

5

。 TFドキュメントから、tf.get_default_graph()が現在のスレッドのデフォルトグラフを返すことがわかります。新しいスレッドを開始してグラフを作成すると、そのスレッドに固有の個別のグラフとして構築されます。私はこの問題を次のように解決することができます:

with sess.graph.as_default(): 
    local_model = ConvNetA3C(84,84,4,3) 
関連する問題