2017-10-04 5 views
0

Martin Gornerのビデオを参考にしてTFを使用して深いネットワークを構築しようとしています。私は浅いネットワークの例でいくつかの成功を収めています。深いネットワークの精度は何らかの理由で約98%の精度に達した後に崩壊しています。ピーク値に達した後にネットワークの精度が崩壊する

実装されたネットワークは、5層ネットワークを使用してMNIST数値文字を認識するために使用されます。 10000回の反復で100回分のバッチを練習しています。精度は約98%に達するまで徐々に増加し、次に完全に9.8%に崩壊する。

どのようなアイデアをお願いしますか?

"""Tensor flow character recognition of Numerals""" 
import tensorflow as tf 

from tensorflow.examples.tutorials.mnist import input_data 
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 

# layer K will have 200 neuron and so on 
K = 200 
L = 100 
M = 60 
N = 30 

# ----- Initialization ----- 
# the None will become the batch size of 100 
# 28 by 28 grayscale images described by a single byte 
X = tf.placeholder(tf.float32, [None, 784]) 

# training will require computing variables W and b 

W1 = tf.Variable(tf.truncated_normal([28*28, K], stddev=0.1)) 
B1 = tf.Variable(tf.zeros([K])) 

W2 = tf.Variable(tf.truncated_normal([K, L], stddev=0.1)) 
B2 = tf.Variable(tf.zeros([L])) 

W3 = tf.Variable(tf.truncated_normal([L, M], stddev=0.1)) 
B3 = tf.Variable(tf.zeros([M])) 

W4 = tf.Variable(tf.truncated_normal([M, N], stddev=0.1)) 
B4 = tf.Variable(tf.zeros([N])) 

W5 = tf.Variable(tf.truncated_normal([N, 10], stddev=0.1)) 
B5 = tf.Variable(tf.zeros([10])) 

init = tf.global_variables_initializer() 

# ----- Model ----- 
# the model Y = WX+b 
# reshape is used to flatted the image into a 1D array of 784 locations 
# -1 is used to tell python to figure the reshape as there's only on solution 
#Y = tf.nn.softmax(tf.matmul(tf.reshape(X, [-1, 784]), W) + b) 

Y1 = tf.nn.relu(tf.matmul(X, W1) + B1) 

Y2 = tf.nn.relu(tf.matmul(Y1, W2) + B2) 

Y3 = tf.nn.relu(tf.matmul(Y2, W3) + B3) 

Y4 = tf.nn.relu(tf.matmul(Y3, W4) + B4) 

Y5 = tf.nn.softmax(tf.matmul(Y4, W5) + B5) 


# placeholder for correct answers 
# e.g. correct answer for 2 will be [0 0 1 0 0 0 0 0 0 0 ] 
Y_ = tf.placeholder(tf.float32, [None, 10]) 

# the loss function 
cross_entropy = tf.reduce_sum(Y_ * tf.log(Y5)) * -1 

# ----- Success Metrics ----- 
# calculate the % of correct answers found in batch 
is_correct = tf.equal(tf.argmax(Y5, 1), tf.argmax(Y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32)) 

# ----- Training Step ----- 
# pick an optimizer and tell it to minimize the cross entropy loss function 
optimizer = tf.train.GradientDescentOptimizer(0.003) 
train_step = optimizer.minimize(cross_entropy) 

# create the execution session 
sess = tf.Session() 
sess.run(init) 

for i in range(10000): 
    # load a batch of images from mnist 
    batch_X, batch_Y = mnist.train.next_batch(100) 
    train_data = {X: batch_X, Y_: batch_Y} 

    # ----- Execution ----- 
    # train 
    sess.run(train_step, feed_dict=train_data) 
    # test for success 
    a, c = sess.run([accuracy, cross_entropy], feed_dict=train_data) 

    # this is only to display information 
    if i%100 == 0: 

     # check for success on whole data set 
     test_data = {X: mnist.test.images, Y_:mnist.test.labels} 
     a, c = sess.run([accuracy, cross_entropy], feed_dict=test_data) 

     print(a) 
+0

1つのことは、ウェイトに何が起こるかをチェックすることです。それらはゼロに近いか無限近くに反復するのでしょうか? – Roelant

+0

ありがとうございます。実際には、私がrelu6からrelu6にアクティベーション機能を変更したとき、より良い結果が得られました。何が違うのかをチェックし、 –

答えて

0

崩壊する検証セットの精度です。右 ?

だから、あなたは劇的にオーバーフィットするかもしれません。 このような容量/構造のネットワークでは、98%の可能性があります。

関連する問題