2017-03-25 13 views
0

私はテンソルフローを使い始めました。私はMNISTデータセットのチュートリアルの例に従いました。それはうまくいった、私は約90%の精度のようになった。テンソルフローmnist my get_next_minibatchの例

しかし、next_batchを自分のバージョンに置き換えた場合、結果はそれまでの50%より悪くなりました。

ダウンロードして解析したデータTensorflowを使用する代わりに、このwebsiteからデータセットをダウンロードします。 numpyを使って私が望むものを得る。

df = pd.read_csv('mnist_train.csv', header=None) 
X = df.drop(0,1) 
Y = df[0] 
temp = np.zeros((Y.size, Y.max()+1)) 
temp[np.arange(Y.size),Y] = 1 
np.save('X',X) 
np.save('Y',temp) 

その後、チュートリアル以下、何も右ここ

x = tf.placeholder(tf.float32, shape=[None, 784]) 
y_ = tf.placeholder(tf.float32, shape=[None, 10]) 
X = np.load('X.npy') 
Y = np.load('Y.npy') 
X_test = np.load('X_test.npy') 
Y_test = np.load('Y_test.npy') 
BATCHES = 1000 


W = tf.Variable(tf.truncated_normal([784,10], stddev=0.1)) 

# W = tf.Variable(tf.zeros([784,10])) 
b = tf.Variable(tf.zeros([10])) 
y = tf.nn.softmax(tf.matmul(x, W) + b) 

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) 
train_step = tf.train.GradientDescentOptimizer(0.05).minimize(cross_entropy) 

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 


sess = tf.InteractiveSession() 
tf.global_variables_initializer().run() 

に変更されていない、テストデータに同じことを行う私自身get_mini_batchで、私はその後、私が取得するたびに、元のデータのインデックスをシャッフル100個のデータがありますが、これはサンプルコードとまったく同じようです。唯一の違いは、データの一部を捨て去ることです。

pos = 0 
idx = np.arange(X.shape[0]) 
np.random.shuffle(idx) 


for _ in range(1000): 
    batch_xs, batch_ys = X[idx[range(pos,pos+BATCHES)],:], Y[idx[range(pos,pos+BATCHES)],] 
    if pos+BATCHES >= X.shape[0]: 
     pos = 0 
     idx = np.arange(X.shape[0]) 
     np.random.shuffle(idx) 
    pos += BATCHES 
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 
print(sess.run(accuracy, feed_dict={x: X_test, y_: Y_test})) 

私のバージョンは、チュートリアル1より道悪化している理由は私を混乱させる。

+0

あなたのデータからX [0]をプリントアウトします。数字は0-255ですか? – lejlot

+0

X [0]はヘッダではなく、最初のデータインスタンスを含んでいます – 7d9af0aec9

+1

ヘッダについては考えていませんでしたが、MNISTは通常[0、1]に値を持つように正規化されています。 "0-255の場合、255で除算したい場合があります。ネットワークに接続する前に、パフォーマンスが非常に悪いのは、 – lejlot

答えて

0

lejilotのように、データをニューラルネットワークにプッシュする前に正規化する必要があります。 See this post

関連する問題