2017-06-19 18 views
0

LSTMコードを実行しようとしていますが、word2Vecワード埋め込み入力を接続しようとしていますが、埋め込み参照を取得する際にエラーが発生します。続きTensorFlowのワード埋め込みの無効な引数例外インデックス

はコードです:

batchSize = 24 
lstmUnits = 64 
numClasses = 2 
iterations = 100000 
maxSeqLength = 250 
numDimensions = 128  
import tensorflow as tf 
tf.reset_default_graph() 

labels = tf.placeholder(tf.float32, [batchSize, numClasses]) 
input_data = tf.placeholder(tf.int32, [batchSize, maxSeqLength]) 

data = tf.Variable(tf.zeros([batchSize, maxSeqLength, numDimensions]),dtype=tf.float32) 
# word Vector Shape = (13277, 128) 
data = tf.nn.embedding_lookup(wordVectors,input_data) 




saver = tf.train.Saver() 
sess.run(tf.global_variables_initializer()) 
try: 
    for i in range(iterations): 
    #nextBatch shape is (24, 250) 
     nextBatch, nextBatchLabels = getTrainBatch() 
     sess.run(optimizer, feed_dict={input_data: nextBatch, labels: nextBatchLabels}) 
except Exception as ex: 
    print(ex) 

私が欠けている小さなステップがあるかもしれません。それは何でしょうか。私はあなたがあなたのケースで単語の埋め込みを使用する方法を理解させるためには報告しているコードを単純化している enter image description here

+0

まず、コードにエラーを修正する必要があります。あなたは 'data'を2回宣言しました。最初の出現は 'wordVectors'と呼ばれるべきです。 –

+0

'wordVectors'はすでに定義されています。埋め込みで使用するimは次元'(13277,128) 'で検索します。私はそれをコメントすることができます。 –

答えて

0

: 私は、コードを実行すると、私は例外を取得します。また、コードを完全に再現することができないため、すべてを指定していません(optimizer変数を参照)。

私はここで、形状(batchSize, maxSeqLength)の入力マトリックスから単語の埋め込みを取得するための簡単なスニペットを報告します。

batchSize = 24 
lstmUnits = 64 
numClasses = 2 
iterations = 100000 
maxSeqLength = 250 
numDimensions = 128 
numTokens = 50 

import tensorflow as tf 
import numpy as np 

session = tf.InteractiveSession() 
input_data = tf.placeholder(tf.int32, [batchSize, maxSeqLength]) 
# you should NOT use tf.Variable() but tf.get_variable() instead 
embeddings_weights = tf.get_variable("embeddings_weights", initializer=tf.random_normal_initializer(0, 1), shape=(numTokens, numDimensions)) 
input_embeddings = tf.nn.embedding_lookup(embeddings_weights, input_data) 
result = session.run(input_embeddings, feed_dict={input_data: np.random.randint(0, numTokens, (batchSize, maxSeqLength))}) 
print(result.shape) 
// should print (24, 250, 300) 

あなたはそのエラーを受け取り、なぜあなたのコードをデバッグする必要がありますし、トレーニングデータには有効でないインデックスが指定されたかどうかを確認理解しようとしている場合。私のスニペットコードでは、np.random.randint()を使用することによって、出力要素が(0、numTokens)の範囲にあるようにして、エラーを回避しました。これは、TensorFlowが範囲外になるIDのルックアップ操作を完了できないために発生します。

関連する問題