0
import tensorflow as tf
slim = tf.contrib.slim
def create_learning_rate(curr_step, lr_config):
base_lr = lr_config.get('base_lr', 0.1)
decay_steps = lr_config.get('decay_steps', [])
decay_rate = lr_config.get('decay_rate', 0.1)
scale_rates = [
lambda: tf.constant(decay_rate**i, dtype=tf.float32)
for i in range(len(decay_steps) + 1)
]
conds = []
prev = -1
for decay_step in decay_steps:
conds.append(tf.logical_and(curr_step > prev, curr_step <= decay_step))
prev = decay_step
conds.append(curr_step > decay_steps[-1])
learning_rate_scale = tf.case(
list(zip(conds, scale_rates)), lambda: 0.0, exclusive=True)
return learning_rate_scale * base_lr
global_step = slim.create_global_step()
train_op = tf.assign_add(global_step, 1)
lr = create_learning_rate(
global_step, {"base_lr": 0.1,
"decay_steps": [10, 20],
"decay_rate": 0.1})
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
for i in range(30):
curr_lr, step, _ = sess.run([lr, global_step, train_op])
print(curr_lr, step)
特定の時刻に学習率を下げたいと思っています。ただし、常に0.001です。何か案は?あるいは、学習率を調整するより良い方法がありますか?テンソルフローの学習率調整
ありがとうございました。
のですか? –
ええ、それは適切な解決策です。 –