2017-05-28 21 views
3

保存されたモデルを読み込もうとしているときに、次のエラーが表示されます。KeyError: 'state_dictの'予期しないキー 'module.encoder.embedding.weight'

KeyError: 'unexpected key "module.encoder.embedding.weight" in state_dict'

これは私が保存されたモデルをロードするために使用しています機能です。

def load_model_states(model, tag): 
    """Load a previously saved model states.""" 
    filename = os.path.join(args.save_path, tag) 
    with open(filename, 'rb') as f: 
     model.load_state_dict(torch.load(f)) 

このモデルは、init関数(コンストラクタ)が以下に示されているシーケンス間ネットワークです。

def __init__(self, dictionary, embedding_index, max_sent_length, args): 
    """"Constructor of the class.""" 
    super(Sequence2Sequence, self).__init__() 
    self.dictionary = dictionary 
    self.embedding_index = embedding_index 
    self.config = args 
    self.encoder = Encoder(len(self.dictionary), self.config) 
    self.decoder = AttentionDecoder(len(self.dictionary), max_sent_length, self.config) 
    self.criterion = nn.NLLLoss() # Negative log-likelihood loss 

    # Initializing the weight parameters for the embedding layer in the encoder. 
    self.encoder.init_embedding_weights(self.dictionary, self.embedding_index, self.config.emsize) 

私がモデル(配列間ネットワーク)を印刷すると、次のようになります。

Sequence2Sequence (
    (encoder): Encoder (
    (drop): Dropout (p = 0.25) 
    (embedding): Embedding(43723, 300) 
    (rnn): LSTM(300, 300, batch_first=True, dropout=0.25) 
) 
    (decoder): AttentionDecoder (
    (embedding): Embedding(43723, 300) 
    (attn): Linear (600 -> 12) 
    (attn_combine): Linear (600 -> 300) 
    (drop): Dropout (p = 0.25) 
    (out): Linear (300 -> 43723) 
    (rnn): LSTM(300, 300, batch_first=True, dropout=0.25) 
) 
    (criterion): NLLLoss (
) 
) 

そこで、module.encoder.embeddingは埋め込み層であり、module.encoder.embedding.weightは、関連する重み行列を表します。だから、それはなぜ - unexpected key "module.encoder.embedding.weight" in state_dict

答えて

6

私はこの問題を解決しました。実際には、モデルをモデルに保存しているnn.DataParallelを使用してモデルを保存していましたが、それをDataParallelなしでロードしようとしていました。したがって、ロード目的のために私のネットワークにnn.DataParallelを一時的に追加する必要があります。または、重みファイルをロードして、モジュール接頭辞なしで新しい順序付き辞書を作成し、ロードすることができます。

2番目の回避策は次のようになります。

# original saved file with DataParallel 
state_dict = torch.load('myfile.pth.tar') 
# create new OrderedDict that does not contain `module.` 
from collections import OrderedDict 
new_state_dict = OrderedDict() 
for k, v in state_dict.items(): 
    name = k[7:] # remove `module.` 
    new_state_dict[name] = v 
# load params 
model.load_state_dict(new_state_dict) 

参考:https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686

関連する問題