2016-05-09 5 views
0

Keras Functional APIを使用して、単純なマルチ入力マルチ出力ネットワークを実現しています。しかし、いくつかのエラーが私に来て、私はそれを解決する方法を見つけることができません。ここ は、コードは次のとおりです。Kerasアサーションエラー

import numpy as np 
from keras.layers import Dense, Activation, Input, merge, Lambda 
from keras.models import Model 
from keras.optimizers import SGD 

def get_half_1(nparray): 
    return nparray[:,:5] 
def get_half_2(nparray): 
    return nparray[:,5:] 

train_x = np.random.uniform(0.0,1.0,size=(50,12)) 
train_y = np.random.uniform(0.0,1.0,(50,8)) 

x_row, x_col = train_x.shape 
y_row, y_col = train_y.shape 

x_input = Input(shape=(x_row,), name='x_input') 
y_input = Input(shape=(y_row,), name='y_input') 

x_hidden = Dense(5,activation='sigmoid')(x_input) 
y_hidden = Dense(5,activation='sigmoid')(y_input) 

# merge two layers 
com_x = merge([x_hidden, y_hidden],mode='concat') 

feature_layer = Dense(10, activation='sigmoid')(com_x) 

# decoding 
com_x_transpose = Dense(10,activation='sigmoid')(feature_layer) 

x_hidden_transpose = Lambda(get_half_1,output_shape=(50,5)) (com_x_transpose) 
y_hidden_transpose = Lambda(get_half_2,output_shape=(50,5))(com_x_transpose) 

x_recon_error = Dense(12,activation='sigmoid')(x_hidden_transpose) 
y_recon_error = Dense(8,activation='sigmoid')(y_hidden_transpose) 
# 
model = Model(input=[x_input, y_input],output=[x_recon_error, y_recon_error]) 


model.compile(optimizer='rmsprop',loss='mean_square_error') 

model.fit(train_x, train_y,nb_epoch=50,batch_size=50) 

私はのpython3でこのコードを実行すると、私は次のエラーを取得:

Traceback (most recent call last): 
    File "splittest.py", line 35, in <module> 
    x_recon_error = Dense(12,activation='sigmoid')(x_hidden_transpose) 
    File "/Users/lw/Library/Python/3.5/lib/python/site- packages/keras/engine/topology.py", line 458, in __call__ 
    self.build(input_shapes[0]) 
    File "/Users/lw/Library/Python/3.5/lib/python/site-packages/keras/layers/core.py", line 583, in build 
    assert len(input_shape) == 2 
AssertionError 
+0

x_input = Input(shape=(x_row,), name='x_input') y_input = Input(shape=(y_row,), name='y_input') 

を変更しますそれは非常に単純な例です。私は可変次元についていくつかの間違いをしていますか? –

答えて

0

は単に

x_input = Input(shape=train_x.shape, name='x_input') 
y_input = Input(shape=train_y.shape, name='y_input') 
関連する問題