LSTMコードを実行しようとしていますが、word2Vecワード埋め込み入力を接続しようとしていますが、埋め込み参照を取得する際にエラーが発生します。続きTensorFlowのワード埋め込みの無効な引数例外インデックス
はコードです:
batchSize = 24
lstmUnits = 64
numClasses = 2
iterations = 100000
maxSeqLength = 250
numDimensions = 128
import tensorflow as tf
tf.reset_default_graph()
labels = tf.placeholder(tf.float32, [batchSize, numClasses])
input_data = tf.placeholder(tf.int32, [batchSize, maxSeqLength])
data = tf.Variable(tf.zeros([batchSize, maxSeqLength, numDimensions]),dtype=tf.float32)
# word Vector Shape = (13277, 128)
data = tf.nn.embedding_lookup(wordVectors,input_data)
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
try:
for i in range(iterations):
#nextBatch shape is (24, 250)
nextBatch, nextBatchLabels = getTrainBatch()
sess.run(optimizer, feed_dict={input_data: nextBatch, labels: nextBatchLabels})
except Exception as ex:
print(ex)
私が欠けている小さなステップがあるかもしれません。それは何でしょうか。私はあなたがあなたのケースで単語の埋め込みを使用する方法を理解させるためには報告しているコードを単純化している
まず、コードにエラーを修正する必要があります。あなたは 'data'を2回宣言しました。最初の出現は 'wordVectors'と呼ばれるべきです。 –
'wordVectors'はすでに定義されています。埋め込みで使用するimは次元'(13277,128) 'で検索します。私はそれをコメントすることができます。 –