2017-09-17 12 views
0

私はニューラルネットワークに新しいですし、hereからtensorflowを使用することを学ぶが、私は、コードを実行すると、エラーを与える:tensorflow.python.framework.errors_impl.InvalidArgumentError:互換性のない形状:[10]対[10000]

次のように
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [10] vs. [10000] 
[[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](ArgMax, _arg_Placeholder_2_0_2)]] 

私のコードは次のとおりです。

import tensorflow as tf 
import numpy as np 
from tensorflow.examples.tutorials.mnist import input_data 

data = input_data.read_data_sets("data/MNIST/", one_hot=True) 
data.test.cls = np.array([label.argmax() for label in data.test.labels]) 
img_size = 28 
img_size_flat = img_size * img_size 
img_shape = (img_size, img_size) 
num_classes = 10 

x = tf.placeholder(tf.float32, [None, img_size_flat]) 
y_true = tf.placeholder(tf.float32, [None, num_classes]) 
y_true_cls = tf.placeholder(tf.int64, [None]) 

weights = tf.Variable(tf.zeros([img_size_flat, num_classes])) 
biases = tf.Variable(tf.zeros([num_classes])) 

logits = tf.matmul(x, weights) + biases 
y_pred = tf.nn.softmax(logits) 
y_pred_cls = tf.argmax(y_pred) 

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_true) 
cost = tf.reduce_mean(cross_entropy) 
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(cost) 

correct_prediction = tf.equal(y_pred_cls, y_true_cls) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

session = tf.Session() 
session.run(tf.global_variables_initializer()) 
batch_size = 100 

feed_dict_test = {x: data.test.images, 
        y_true: data.test.labels, 
        y_true_cls: data.test.cls} 

def print_accuracy(): 
    acc = session.run(accuracy, feed_dict=feed_dict_test) 
    print("Accuracy on test-set: {0:.1%}".format(acc)) 

print_accuracy() 

誰かが、私はこのエラーを取得していますし、どのようにこの問題を解決するために理由を説明してくださいことはできますか?

答えて

0

私は解決策を見つけ、問題は、私はy_pred_clsのための次元を合格していないということでしたので、その形状は(10)だったが、それがために必要なは(?、10)ので、私はy_pred_cls = tf.argmax(y_pred, axis=1)

y_pred_cls = tf.argmax(y_pred)を更新
関連する問題