2016-08-29 5 views
3

tf.nn.embedding_lookupを試してみましたが、動作しています。しかし、それは高密度の入力データを必要とし、現在は疎入力用にtf.nn.embedding_lookup_sparseが必要です。TensorFlowでtf.nn.embedding_lookup_sparseを使用するには?

次のコードを書いていますが、エラーが発生しました。

import tensorflow as tf 
import numpy as np 

example1 = tf.SparseTensor(indices=[[4], [7]], values=[1, 1], shape=[10]) 
example2 = tf.SparseTensor(indices=[[3], [6], [9]], values=[1, 1, 1], shape=[10]) 

vocabulary_size = 10 
embedding_size = 1 
var = np.array([0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0]) 
#embeddings = tf.Variable(tf.ones([vocabulary_size, embedding_size])) 
embeddings = tf.Variable(var) 

embed = tf.nn.embedding_lookup_sparse(embeddings, example2, None) 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 

    print(sess.run(embed)) 

エラーログは次のようになります。

enter image description here

今私が正しく、この方法を修正して使用する方法は考えています。どんなコメントもありがとう。 safe_embedding_lookup_sparseのユニットテストにダイビング後

まばらな重みを与える場合、私はこの結果を得た理由は、私たちは3は、上記のコードに登場していないembedding_weights[0][3]のようなものを持って、特に理由は、より混乱しています。

​​

+0

、私の答えはあなたの問題を解決した場合:) – rvinas

+0

ありがとう@rvinasを私に知らせてください。 'safe_embedding_lookup_sparse'の単体テストを読んだあと、私はまだそれを理解しません。私は質問を更新しました。私たちのコードを説明しますか? – tobe

+0

生のコードを教えてください。 – rvinas

答えて

6

tf.nn.embedding_lookup_sparse()は0で開始し、このエラーが発生した理由ですその1ずつ増加させるSparseTensorからインデックスを必要とする、埋め込みを組み合わせることSegmentation使用しています。

ブール値の代わりに、スパーステンソルは、埋め込みから取得するすべての行のインデックスのみを保持する必要があります。

import tensorflow as tf 
import numpy as np 

example = tf.SparseTensor(indices=[[0], [1], [2]], values=[3, 6, 9], dense_shape=[3]) 

vocabulary_size = 10 
embedding_size = 1 
var = np.array([0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0]) 
embeddings = tf.Variable(var) 

embed = tf.nn.embedding_lookup_sparse(embeddings, example, None) 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    print(sess.run(embed)) # prints [ 9. 36. 81.] 

また、あなたが許可されtf.nn.embedding_lookup_sparse()コンバイナのいずれかを使用して単語の埋め込みを組み合わせることtf.SparseTensor()からインデックスを使用することができます:ここにあなたの微調整コードです

  • 「合計は」の加重和を計算します各行の結果を埋め込みます。
  • 「平均」は、合計重量を総重量で割ったものです。
  • "sqrtn"は、加重和を加重の二乗和の平方根で割ったものです。たとえば、

example = tf.SparseTensor(indices=[[0], [0]], values=[1, 2], dense_shape=[2]) 
... 
embed = tf.nn.embedding_lookup_sparse(embeddings, example, None, combiner='sum') 
... 
print(sess.run(embed)) # prints [ 5.] 
関連する問題