set_weightsメソッドを使用して、畳み込みレイヤーにnumpy配列を渡す必要があります。
畳み込みレイヤの重みは、個々のフィルタの重みだけでなく、バイアスでもあることに注意してください。だからあなたがあなたの体重を設定したい場合は、余分な次元を追加する必要があります。
w = np.asarray([
[[[
[0,0,0],
[0,2,0],
[0,0,0]
]]]
])
をし、それを設定します。あなたはすべての重みが中心的な要素を除いてゼロで1x3x3フィルタを設定したい場合は
たとえば、あなたがそれを作る必要があります。あなたが実行できるコードの場合
:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import numpy as np
np.random.seed(1234)
from keras.layers import Input
from keras.layers.convolutional import Convolution2D
from keras.models import Model
print("Building Model...")
inp = Input(shape=(1,None,None))
output = Convolution2D(1, 3, 3, border_mode='same', init='normal',bias=False)(inp)
model_network = Model(input=inp, output=output)
print("Weights before change:")
print (model_network.layers[1].get_weights())
w = np.asarray([
[[[
[0,0,0],
[0,2,0],
[0,0,0]
]]]
])
input_mat = np.asarray([
[[
[1.,2.,3.],
[4.,5.,6.],
[7.,8.,9.]
]]
])
model_network.layers[1].set_weights(w)
print("Weights after change:")
print(model_network.layers[1].get_weights())
print("Input:")
print(input_mat)
print("Output:")
print(model_network.predict(input_mat))
は、畳み込みfillter(例では2)の中心的要素を変更してみてください。
コードの内容:
最初にモデルを作成します。
inp = Input(shape=(1,None,None))
output = Convolution2D(1, 3, 3, border_mode='same', init='normal',bias=False)(inp)
model_network = Model(input=inp, output=output)
印刷元の重みは
w = np.asarray([
[[[
[0,0,0],
[0,2,0],
[0,0,0]
]]]
])
input_mat = np.asarray([
[[
[1.,2.,3.],
[4.,5.,6.],
[7.,8.,9.]
]]
])
あなたの重みを設定input_mat
print (model_network.layers[1].get_weights())
希望重みWテンソルと、いくつかの入力を作成し(INIT = '正常' は、正規分布で初期化)それらを印刷します。
model_network.layers[1].set_weights(w)
print("Weights after change:")
print(model_network.layers[1].get_weights())
最後に、(自動的にモデルのコンパイル予測)を予測して出力を生成するためにそれを使用
print(model_network.predict(input_mat))
出力例:
Using Theano backend.
Building Model...
Weights before change:
[array([[[[ 0.02357176, -0.05954878, 0.07163535],
[-0.01563259, -0.03602944, 0.04435815],
[ 0.04297942, -0.03182618, 0.00078482]]]], dtype=float32)]
Weights after change:
[array([[[[ 0., 0., 0.],
[ 0., 2., 0.],
[ 0., 0., 0.]]]], dtype=float32)]
Input:
[[[[ 1. 2. 3.]
[ 4. 5. 6.]
[ 7. 8. 9.]]]]
Output:
[[[[ 2. 4. 6.]
[ 8. 10. 12.]
[ 14. 16. 18.]]]]
ああ、ありがとう!それは私には分かりませんでした。ドキュメントには、ウェイトの形の正確な要件は明記されていませんでした。あなたの例をありがとう! – displayname
参考:https://github.com/fchollet/keras/issues/1671 – maz