2016-09-02 15 views
0

six.moves.cPickleを使用してオブジェクトを保存し、Pythonで読み込むことができます。Theanoの多層ネットワークのパラメータを保存してリセットする

次のコードを使用して、LeNetのパラメータを保存してリセットしました。

# save model 
# params = layer3.params + layer2.params + layer1.params + layer0.params 
import six.moves.cPickle as pickle 
f = file('best_cnnmodel.save', 'wb') 
pickle.dump(params, f, protocol=pickle.HIGHEST_PROTOCOL) 
f.close() 

# reset parameters 
model_file = file('best_cnnmodel.save', 'rb') 
params = pickle.load(model_file) 
model_file.close() 
layer3.W.set_value(params[0].get_value()) 
layer3.b.set_value(params[1].get_value()) 
layer2.W.set_value(params[2].get_value()) 
layer2.b.set_value(params[3].get_value()) 
layer1.W.set_value(params[4].get_value()) 
layer1.b.set_value(params[5].get_value()) 
layer0.W.set_value(params[6].get_value()) 
layer0.b.set_value(params[7].get_value()) 

コードはLeNetにとっては問題ありません。しかし、それはエレガントではありません。ディープネットワークの場合、このコードを使用してモデルを保存することはできません。この場合、私は何ができますか?

答えて

0

json形式の使用を検討できます。それは人間が読むことができ、使いやすいものです。ここで

は一例です:

import json 


data = { 
    'L1' : { 'W': layer1.W, 'b': layer1.b }, 
    'L2' : { 'W': layer2.W, 'b': layer2.b }, 
    'L3' : { 'W': layer3.W, 'b': layer3.b }, 
} 
json_data = json.dumps(data) 

json_dataデータを準備し、次のようになります

{"L2": {"b": 2, "W": 17}, "L3": {"b": 2, "W": 10}, "L1": {"b": 2, "W": 1}} 

は、データ

params = json.loads(json_data) 

for k, v in params.items(): 
    level = int(k[1:]) 
    # assume you save the layer in an array, but you can use 
    # different way to store and reference the layers 
    layer = layers[level] 
    layer.W = v['W'] 
    layer.b = v['b'] 
を開梱