2017-09-07 7 views
0
import tensorflow as tf 

slim = tf.contrib.slim 


def create_learning_rate(curr_step, lr_config): 
    base_lr = lr_config.get('base_lr', 0.1) 
    decay_steps = lr_config.get('decay_steps', []) 
    decay_rate = lr_config.get('decay_rate', 0.1) 

    scale_rates = [ 
     lambda: tf.constant(decay_rate**i, dtype=tf.float32) 
     for i in range(len(decay_steps) + 1) 
    ] 

    conds = [] 
    prev = -1 
    for decay_step in decay_steps: 
     conds.append(tf.logical_and(curr_step > prev, curr_step <= decay_step)) 
     prev = decay_step 
    conds.append(curr_step > decay_steps[-1]) 

    learning_rate_scale = tf.case(
     list(zip(conds, scale_rates)), lambda: 0.0, exclusive=True) 
    return learning_rate_scale * base_lr 


global_step = slim.create_global_step() 
train_op = tf.assign_add(global_step, 1) 
lr = create_learning_rate(
    global_step, {"base_lr": 0.1, 
       "decay_steps": [10, 20], 
       "decay_rate": 0.1}) 

with tf.Session() as sess: 
    init = tf.global_variables_initializer() 
    sess.run(init) 
    for i in range(30): 
     curr_lr, step, _ = sess.run([lr, global_step, train_op]) 
     print(curr_lr, step) 

特定の時刻に学習率を下げたいと思っています。ただし、常に0.001です。何か案は?あるいは、学習率を調整するより良い方法がありますか?テンソルフローの学習率調整

ありがとうございました。

+0

のですか? –

+0

ええ、それは適切な解決策です。 –

答えて

0

これは、lambda関数がvalueではなくreferenceで変数を取り込むためです。

だから、正しい方法では、なぜあなたはそれにしたい任意の値のプレースホルダとして、各反復フィードであなたの学習率を設定していない

def create_learning_rate(global_step, lr_config): 
    base_lr = lr_config.get('base_lr', 0.1) 
    decay_steps = lr_config.get('decay_steps', []) 
    decay_rate = lr_config.get('decay_rate', 0.1) 

    prev = -1 
    scale_rate = 1.0 

    cases = [] 
    for decay_step in decay_steps: 
     cases.append((tf.logical_and(global_step > prev, 
            global_step <= decay_step), 
        lambda v=scale_rate: v)) 
     scale_rate *= decay_rate 
     prev = decay_step 
    cases.append((global_step > decay_step, lambda v=scale_rate: v)) 
    learning_rate_scale = tf.case(cases, lambda: 0.0, exclusive=True) 
    return learning_rate_scale * base_lr 
関連する問題