0
six.moves.cPickle
を使用してオブジェクトを保存し、Pythonで読み込むことができます。Theanoの多層ネットワークのパラメータを保存してリセットする
次のコードを使用して、LeNetのパラメータを保存してリセットしました。
# save model
# params = layer3.params + layer2.params + layer1.params + layer0.params
import six.moves.cPickle as pickle
f = file('best_cnnmodel.save', 'wb')
pickle.dump(params, f, protocol=pickle.HIGHEST_PROTOCOL)
f.close()
# reset parameters
model_file = file('best_cnnmodel.save', 'rb')
params = pickle.load(model_file)
model_file.close()
layer3.W.set_value(params[0].get_value())
layer3.b.set_value(params[1].get_value())
layer2.W.set_value(params[2].get_value())
layer2.b.set_value(params[3].get_value())
layer1.W.set_value(params[4].get_value())
layer1.b.set_value(params[5].get_value())
layer0.W.set_value(params[6].get_value())
layer0.b.set_value(params[7].get_value())
コードはLeNetにとっては問題ありません。しかし、それはエレガントではありません。ディープネットワークの場合、このコードを使用してモデルを保存することはできません。この場合、私は何ができますか?