2016-04-02 10 views
6

私は、pretrained word2vecベクトルを既存のtensorflow seq2seqモデルに注入しようとしていました。TensorFlowにpre-training word2vecベクトルを注入するseq2seq

this answerに続いて、次のコードを生成しました。しかし、変数の値は更新されますが、パフォーマンスを向上させることはできません。

私の理解では、EmbeddingWrapperまたはembedding_attention_decoderが語彙の順序とは独立して埋め込みを作成するという事実に起因している可能性があります。

事前訓練されたベクトルをテンソルフローモデルにロードするにはどうすればよいでしょうか?

SOURCE_EMBEDDING_KEY = "embedding_attention_seq2seq/RNN/EmbeddingWrapper/embedding" 
TARGET_EMBEDDING_KEY = "embedding_attention_seq2seq/embedding_attention_decoder/embedding" 


def inject_pretrained_word2vec(session, word2vec_path, input_size, dict_dir, source_vocab_size, target_vocab_size): 
    word2vec_model = word2vec.load(word2vec_path, encoding="latin-1") 
    print("w2v model created!") 
    session.run(tf.initialize_all_variables()) 

    assign_w2v_pretrained_vectors(session, word2vec_model, SOURCE_EMBEDDING_KEY, source_vocab_path, source_vocab_size) 
    assign_w2v_pretrained_vectors(session, word2vec_model, TARGET_EMBEDDING_KEY, target_vocab_path, target_vocab_size) 


def assign_w2v_pretrained_vectors(session, word2vec_model, embedding_key, vocab_path, vocab_size): 
    vectors_variable = [v for v in tf.trainable_variables() if embedding_key in v.name] 
    if len(vectors_variable) != 1: 
     print("Word vector variable not found or too many. key: " + embedding_key) 
     print("Existing embedding trainable variables:") 
     print([v.name for v in tf.trainable_variables() if "embedding" in v.name]) 
     sys.exit(1) 

    vectors_variable = vectors_variable[0] 
    vectors = vectors_variable.eval() 

    with gfile.GFile(vocab_path, mode="r") as vocab_file: 
     counter = 0 
     while counter < vocab_size: 
      vocab_w = vocab_file.readline().replace("\n", "") 
      # for each word in vocabulary check if w2v vector exist and inject. 
      # otherwise dont change the value. 
      if word2vec_model.__contains__(vocab_w): 
       w2w_word_vector = word2vec_model.get_vector(vocab_w) 
       vectors[counter] = w2w_word_vector 
      counter += 1 

    session.run([vectors_variable.initializer], 
      {vectors_variable.initializer.inputs[1]: vectors}) 

答えて

5

私はseq2seq例に精通していないですが、一般的に、あなたの埋め込みを注入するために、次のコードスニペットを使用することができます:あなたは、グラフ構築

with tf.device("/cpu:0"): 
    embedding = tf.get_variable("embedding", [vocabulary_size, embedding_size])  
    inputs = tf.nn.embedding_lookup(embedding, input_data) 

とき(グラフを作成してからトレーニングを開始する前に)実行すると、埋め込み変数に保存された埋め込みを割り当てます:

session.run(tf.assign(embedding, embeddings_that_you_want_to_use)) 

考えられるのは、embedding_lookupの値がinput_dataの値をembeddingの変数に置き換えたということです。