2016-03-22 15 views
0

私はlasagneのConv3DDNNLayerを使用しています。入力次元は(N x 1 x 9 x 9 x 9)です。各9x9x9キューブは分類対象のサンプルです。Theano次元誤差 - ターゲット次元

したがって、(N×1)のターゲットディメンションがあり、各エントリはキューブに対応しています。これにより、エラーが発生しています。

Bad input argument to theano function with name "Conv_Net_1.py:45" at index 1(0-based)', 'Wrong number of dimensions: expected 1, 
got 2 with shape (324640, 1).')´ 

この場合、どのサイズのターゲットを使用する必要がありますか?

11 dtensor5 = TensorType('float32', (False,)*5) 
12 input_var = dtensor5('X_Train') 
13 target_var = T.ivector('Y_train') 
14 
15 X_train, Y_train = DP.data_gen('/home/Upload/Smalls', 9) 

16 print X_train.shape 
17 print Y_train.shape 

18 # Build Neural Network: 
19 input = lasagne.layers.InputLayer((None, 1, 9, 9, 9), input_var=input_var) 
20 
21 l_conv_1 = lasagne.layers.dnn.Conv3DDNNLayer(input, 20, (2,2,2)) 
22 
29 l_hidden1 = lasagne.layers.DenseLayer(l_conv_1, num_units=256,nonlinearity=lasagne.nonlinearities.rectify,W=l asagne.init.HeNormal(gain='relu')) 
30 
31 l_hidden1_dropout = lasagne.layers.DropoutLayer(l_hidden1, p=0.5) 
32 
33 output = lasagne.layers.DenseLayer(l_hidden1_dropout, num_units=2, nonlinearity = lasagne.nonlinearities.soft max) 
34 
35 ## 
36 prediction = lasagne.layers.get_output(output) 
37 loss = T.mean(lasagne.objectives.categorical_crossentropy(prediction, target_var) 
39 
40 # Get list of all trainable parameters in the network. 
41 params = lasagne.layers.get_all_params(output, trainable=True) 
42 updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.01, momentum=0.3) 
43 
44 ## 
45 train_fn = theano.function([input_var, target_var], loss, updates=updates) 
46 
47 ## 
48 for epoch in range(500): 
49  print('training') 
50  loss = train_fn(X_train, Y_train) 
51  print(loss.type) 
52  print("Epoch %d: Loss %g" % (epoch + 1, loss)) 
53 
54 
55 ## 
56 test_prediction = lasagne.layers.get_output(output, deterministic=True) 
57 predict_fn = theano.function([input_var], T.argmax(test_prediction, axis=1)) 

編集 - 追加されたコード

ありがとう!

+0

コードを投稿することはできますか? "期待1、取得2"あなたのコードで何かを指示する次元ではなく、2つではない – malioboro

+0

私は上記のコードを掲載しました。探している1つのディメンションは、各9x9x9のラベルになります。私はN 9x9x9用のラベルベクトルを持っていますが、一度にすべてにアクセスしようとしています。どのように私はちょうど1つをターゲットすることができますか? – JB1

答えて

0

データが(N、1)ではなく(N、)だったからです。

問題を解決したようです! - 次へ..