最近私はGANネットワークを研究しています。私はそれをmnisitイメージを生成するために使っています。私のコンピューターの環境はubuntu16.04、tensorflow、python3です。GANからジェネレーターへnnistを使うが、何も勉強しない
コードはエラーなしで実行できますが、結果はネットワーク学習で何も表示されず、トレーニングによって出力画像がノイズの多い画像になります。
私は、入力が784次元のノイズの多いデータであり、隠れたレイヤーを介してそれをルール化し、784次元の画像を生成します。
私は弁別者ネットワークを設計します:入力は実像と偽の画像で、隠れた層を通してそれを支配し、出力は1次元のログです。
次に、generator_lossとdiscriminator_lossを定義して、列車発電機とディスクリミネータを定義しました。実行することはできますが、結果はネットワーク調査で何も表示されず、損失は収束できません。
import tensorflow as tf
import numpy as np
import tensorflow.contrib.slim as slim
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/home/zyw/data/tensor_mnist-master/MNIST_data/",one_hot=True)
batch_size = 100
G_in = tf.placeholder(tf.float32,[None,784])
G_h1 = tf.layers.dense(G_in, 128)
G_h1 = tf.maximum(0.01 * G_h1, G_h1)
G_out = tf.tanh(tf.layers.dense(G_h1, 784))
real = tf.placeholder(tf.float32,[None,784])
Dl0 = tf.layers.dense(G_out, 128)
Dl0 = tf.maximum(0.01 * Dl0, Dl0)
p0 = tf.layers.dense(Dl0, 1)
Dl1 = tf.layers.dense(real, 128)
Dl1 = tf.maximum(0.01 * Dl1, Dl1)
p1 = tf.layers.dense(Dl1, 1)
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits =p0,labels=tf.ones_like(p0)*0.9))
D_real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits =p1,labels=tf.ones_like(p1)*0.9))
D_fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits =p0,labels=tf.zeros_like(p0)))
D_total_loss = tf.add(D_fake_loss,D_real_loss)
G_train = tf.train.AdamOptimizer(0.01).minimize(G_loss)
D_train = tf.train.AdamOptimizer(0.01).minimize(D_total_loss)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for i in range(1000):
mnist_data,_ = mnist.train.next_batch(batch_size)
# noise_org = tf.random_normal([batch_size,784],stddev = 0.1,dtype = tf.float32)
noise_org = np.random.randn(batch_size, 784)
a,b,dloss= sess.run([D_real_loss,D_fake_loss,D_total_loss,G_train,D_train],feed_dict={G_in:noise_org,real:mnist_data})[:3]
if i%100==0:
print(a,b,dloss)
#test_generative_image
noise_org = np.random.randn(1, 784)
image = sess.run(G_out,feed_dict ={G_in:noise_org})
outimage = tf.reshape(image, [28,28])
plt.imshow(outimage.eval(),cmap='gray')
plt.show()
print('ok')
結果は:
0.80509 0.63548 1.44057
0.33512 0.20223 0.53735
0.332536 0.97737 1.30991
0.328048 0.814452 1.1425
0.326688 0.411907 0.738596
0.325864 0.570807 0.896671
0.325575 0.970406 1.29598
0.325421 1.02487 1.35029
0.325222 1.34089 1.66612
0.325217 0.747129 1.07235
あなたは本当に5秒余分に空白や改行を加えたり、文章間の句読点を間違えないようにread.doを書くのが面倒だと書いてありますが、これは痛いです。 – Julien
これは私の最初です私はそれを訂正するためにあなたに感謝のために、良いと私の英語を質問するためにstackoverflowを使用する時間。 – zyw