テンソルフローv0.11.0を使用しています。私は2つのネットワークを訓練しています:D
とG
。基本的な問題は、私が表示している値を記録することができず、それが混乱の原因になっていることです。 d
とg
の印刷された値は、私が記録しているものとはまったく異なります。テンソルフローの不適切なロギングとサマリー
#Defining Losses
loss_d1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(out_1, 0.9*tf.ones_like(in_1)))
loss_d2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(out_2, tf.zeros_like(in_2)))
loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(out_2, tf.ones_like(in_2)))
#combining losses
loss_d = loss_d1 + loss_d2
#creating summaries
loss_d1_sum = tf.scalar_summary("loss_d1", loss_d1)
loss_d2_sum = tf.scalar_summary("loss_d2", loss_d2)
loss_d_sum = tf.scalar_summary("loss_d", loss_d)
loss_g_sum = tf.scalar_summary("loss_g", loss_g)
#defining the optimizers
op_g = tf.train.AdamOptimizer(learning_rate).minimize(loss_g, var_list=vars_g)
op_d = tf.train.AdamOptimizer(learning_rate).minimize(loss_d, var_list=vars_d)
#merging summaries
self.g_sum = tf.merge_summary([loss_g])
self.d_sum = tf.merge_summary([loss_d, loss_d1, loss_d2])
self.writer = tf.train.SummaryWriter('./logs', sess.graph)
私は一度や二度G
D
を更新しています:
# Update D network
_, summary_str = self.sess.run([op_d, d_sum], feed_dict= fd)
self.writer.add_summary(summary_str, counter)
# Update G network
_, summary_str = self.sess.run([op_g, g_sum], feed_dict= fd)
self.writer.add_summary(summary_str, counter)
# Update G network again
_, summary_str = self.sess.run([op_g, g_sum], feed_dict= fd)
self.writer.add_summary(summary_str, counter)
私は、次の2
回実行しています。各反復は1 D
更新と2 G
更新を持ちます。
# Printing stats to stdout
[loss_d, loss_g] = sess.run([loss_d, loss_g], feed_dict = fd)
print 'd_loss:{}, g_loss:{}'.format(loss_d, loss_g)
これは私に出力を提供します:
d_loss: 1.43663359, g_loss: 0.68026298
d_loss: 1.72329867, g_loss: 0.55464995
私は今matplotlib
を使用してプロットするために、次のコードを使用して値を抽出しています。
for entry in tf.train.summary_iterator(log_file)
for v in entry.summary.value:
if v.tag == 'loss_d'
g.append(v.simple_value)
elif v.tag == 'loss_g'
d.append(v.simple_value)
print len(d), len(d) # 2 4
print 'd_loss', d # d_loss [2.196472406387329, 1.200419545173645]
print 'g_loss', g # g_loss [6.2136077880859375, 2.953800678253174, 5.429494857788086, 2.6036014556884766]
ロギング方式を正しく実装していますか?ログ後にopsを実行したので、loss_d
とloss_g
の値が変更されている可能性があります。しかし、それは最後のオペレーションが値を保存しなければならないことを意味し、sess.run
を実行している間は変更してはいけません。 fd
は反復を通じて一定です。