2016-12-20 10 views
2

私はprintステートメントをセットアップしましたが、RNNに最初のバッチを送るときには埋め込みは存在しますが、次のエラーを取得する:ここでテンソルフロー埋め込みは最初のRNNの例の後には存在しません

ValueError: Variable RNNLM/RNNLM/Embedding/Adam_2/ does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

は埋め込みを生成するための私のコードです:

def add_embedding(self): 
    with tf.device('/gpu:0'): 
     embedding = tf.get_variable("Embedding", [len(self.vocab), self.config.embed_size]) 
     e_x = tf.nn.embedding_lookup(embedding, self.input_placeholder) 
     inputs = [tf.squeeze(s, [1]) for s in tf.split(1, self.config.num_steps, e_x)] 
     return inputs 

ここで私は問題が

に位置疑うところこれは、モデルがseutpある方法ですここではここ

def __init__(self, config): 
    self.config = config 
    self.load_data(debug=False) 
    self.add_placeholders() 
    self.inputs = self.add_embedding() 
    self.rnn_outputs = self.add_model(self.inputs) 
    self.outputs = self.add_projection(self.rnn_outputs) 
    self.predictions = [tf.nn.softmax(tf.cast(o, 'float64')) for o in self.outputs] 
    output = tf.reshape(tf.concat(1, self.outputs), [-1, len(self.vocab)]) 
    self.calculate_loss = self.add_loss_op(output) 
    self.train_step = self.add_training_op(self.calculate_loss) 

は、他の方法である皆を助けるためにいくつかの更新されたコードは次のとおりです。私は損失関数を得るとき

def add_training_op(self, loss): 
    opt = tf.train.AdamOptimizer(self.config.lr) 
    train_op = opt.minimize(loss) 
    return train_op 

EDITを次のように

問題が定義され、発生しますここではadd_projectioncalculate_lossに属しているので、それらを除外することができます。

def add_loss_op(self, output): 
    weights = tf.ones([self.config.batch_size * self.config.num_steps], tf.int32) 
    seq_loss = tf.python.seq2seq.sequence_loss(
     [output], 
     tf.reshape(self.labels_placeholder, [-1]), 
     weights 
    ) 
    tf.add_to_collection('total_loss', seq_loss) 
    loss = tf.add_n(tf.get_collection('total_loss')) 
    return loss 

def add_projection(self, rnn_outputs): 
    with tf.variable_scope("Projection", initializer=tf.contrib.layers.xavier_initializer()) as scope: 
     U = tf.get_variable("U", [self.config.hidden_size, len(self.vocab)]) 
     b_2 = tf.get_variable("b_2", [len(self.vocab)]) 

     outputs = [tf.matmul(x, U) + b_2 for x in rnn_outputs] 
     return outputs 


def train_RNNLM(): 
    config = Config() 
    gen_config = deepcopy(config) 
    gen_config.batch_size = gen_config.num_steps = 1 

    with tf.variable_scope('RNNLM') as scope: 
    model = RNNLM_Model(config) 
    # This instructs gen_model to reuse the same variables as the model above 
    scope.reuse_variables() 
    gen_model = RNNLM_Model(gen_config) 

    init = tf.initialize_all_variables() 
    saver = tf.train.Saver() 

    with tf.Session() as session: 
    best_val_pp = float('inf') 
    best_val_epoch = 0 

    session.run(init) 
    for epoch in xrange(config.max_epochs): 
     print 'Epoch {}'.format(epoch) 
     start = time.time() 
     ### 
     train_pp = model.run_epoch(
      session, model.encoded_train, 
      train_op=model.train_step) 
     valid_pp = model.run_epoch(session, model.encoded_valid) 
     print 'Training perplexity: {}'.format(train_pp) 
     print 'Validation perplexity: {}'.format(valid_pp) 
     if valid_pp < best_val_pp: 
     best_val_pp = valid_pp 
     best_val_epoch = epoch 
     saver.save(session, './ptb_rnnlm.weights') 
     if epoch - best_val_epoch > config.early_stopping: 
     break 
     print 'Total time: {}'.format(time.time() - start) 
+0

もっとコードを提供できますか? 'add_training_op'と' add_embedding'がコードの残りの部分でどのように動作しているのか分かりません。 – martianwars

+0

@martianwars私は先に進み、コード設定を追加しました。 – TheM00s3

答えて

0

問題は、次のコード行であることが判明:

model = RNNLM_Model(config) 
    # This instructs gen_model to reuse the same variables as the model above 
    scope.reuse_variables() 
    gen_model = RNNLM_Model(gen_config) 

それは第2のモデルはreuse_variables()を使用して、問題だったことが判明します。この行を削除することで、問題がなくなりました。

0

コードが各バッチで新しいAdam変数を作成しようとしているようです。 add_training_opが2回呼び出される可能性はありますか? また、def add_training_opのスニペットはreturn文がないため、不完全です。

+0

リターンステートメントの欠落はタイプミスで、修正されました。私は先に進み、printステートメントを追加して、そのコードがどのくらい頻繁に呼び出されるかを調べます。 – TheM00s3

+0

私は別の方法を追加しました。おそらくその問題を引き起こしています。 – TheM00s3

関連する問題