2017-07-14 4 views
2

1d CNNレイヤのウェイトを抽出し、予測値の正確な計算方法を理解したいと思います。私は、get_weights()関数からの重みを使用して予測値を再生成することができません。ケラス:get_weights関数を使用してウェイトを抽出する

私の理解を説明するために、ここに小さなデータセットがあります。

n_filter = 64 
kernel_size = 10 
len_timeseries = 123 
n_feature = 3 
X = np.random.random(sample_size*len_timeseries*n_feature).reshape(sample_size,len_timeseries,n_feature) 
y = np.random.random(sample_size*(len_timeseries-kernel_size+1)*n_filter).reshape(sample_size, 
                        (len_timeseries-kernel_size+1), 
                        n_filter) 

、のような単純な1D CNNモデルを作成します。

model = Sequential() 
model.add(Conv1D(n_filter,kernel_size, 
       input_shape=(len_timeseries,n_feature))) 
model.compile(loss="mse",optimizer="adam") 

モデルをフィットとしてXの値を予測:1000(

model.fit(X,y,nb_epoch=1) 
y_pred = model.predict(X) 

あるy_predの次元を114,64)を必要とする。

ここで、model.layerに格納された重みを使用して、y_pred [irow、0、ilayer]の値を再現したいとします。単一のレイヤーしかないので、len(model.layer)= 1です。だから、第としてのみレイヤからの重みを抽出:

weight = model.layers[0].get_weights() 
print(len(weight)) 
> 2 
weight0 = np.array(weight[0]) 
print(weight0.shape) 
> (10, 1, 3, 64) 
weight1 = np.array(weight[1]) 
print(weight1.shape) 
> (64,) 

重量が長さ2を有し、Iは0番目の位置が特徴の重みを含み、第1の位置にバイアスを含むと仮定する。 weight0.shape =(kernel_size、1、n_feature、n_filter)として、私は[i層、0、irow]私はy_predの値を得ることができることを考えた:

ifilter = 0 
irow = 0 
y_pred_by_hand = weight1[ifilter] + np.sum(weight0[:,0,:,ifilter] * X[irow,:kernel_size,:]) 
y_pred_by_hand 
> 0.5124888777 

しかし、この値はかなり異なっていますy_pred [irow、0、ifilter]から:

y_pred[irow,0,ifilter] 
>0.408206 

私が間違っていた場所を教えてください。

答えて

0

ここではweightsの属性を誤解しています。あなたが探しているのはの結果であるレイヤーのoutput属性です。これは、layer.outputによって得ることができます。典型的には、Layerには入力テンソルが与えられ、使用される層のタイプに依存するマトリックスによって作用される。この計算はあなたが探しているものである出力テンソルを与えます。例えば、入力テンソルAが形状(1,3)であり、出力S字状層がテンソルB(1,1)を放出し、重み行列Wが様々な技術によって初期化される単純な密集層を考える。 Wの形状は、入力形状と出力形状に基づいて決定されます。従ってこの場合、密集層はA matmul Wであり、その結果は予測Bとなる。したがって、Wの形状は、(1,1)の出力形状をもたらすことができる(3,1)としてのみ決定される。あなたが探しているものはBで、アクセスしようとしているのはWです。

+1

あなたの例で "B"を定義する方法を聞いてもよろしいですか? – ThePrincess

+0

あなたは "B"(model.predictによって与えられた結果)を持ちたいと思っています。しかし、私は単純にmodel.predictを使うことができます。私が "B"を必要とする理由は、入力テンソル "X"と一緒にどのようにウェイトを使用して "B"を計算するかを理解することです。私がBを使って計算した方法は、model.predictの出力には同意しないので間違っています。 – ThePrincess

関連する問題