2017-03-08 11 views
1

私はtensorflowの初心者です。私はword2vec_basicスクリプトを理解しようとしています。tensoflow word2vec_basic入出力出力プレースホルダ

それは入力と出力を定義していました。

train_inputs = tf.placeholder(tf.int32, shape=[batch_size]) 
train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1]) 

Iが正しくtrain_inputsの形状を理解している場合 - [batch_size]は、整数の配列であり、配列の長さは、batch_sizeあります。次に、train_labelsの形状は[batch_size, 1]であり、これは1行の行列です。正しい?はいの場合、私はそれがなぜそうであるか、そしてそれらのプレースホルダーに保存されるべきものを理解していません。理論によると、labelはintであり、inputはコンテキストのスライディングウィンドウの配列なので、なぜbatch_sizeがそこに現れますか?

私は理論で何か基本的なものを見逃したようです。私は説明を感謝します。

答えて

0

トレーニングのために、複数の文章に対して並列に計算するのが便利です。これがbatch_sizeの目的です。バッチ内の各単語は、データ内の異なる部分に由来し、train_labelsは、対応するラベルを有する。勾配はバッチ全体で集計され、その後単一のパラメータ更新が行われます。

+0

この場合、ラベルの形状が[batch_size、1]で、[batch_size]でない理由は何ですか? – com

+0

私はよく分かりません。おそらく、1xnサイズの行列で乗算しなければならない場所があります。 – Aaron

0

私は最近word2vec_basic.pyをデバッグするために起こったと単純なデータについては、「犬は犬は猫が木に登った猫を追いかけた猫を見た」の例を設定し、dictionaryreverse_dictionary{0: 'the', 1: 'cat', 2: 'dog', 3: 'a', 4: 'saw', 5: 'chased', 6: 'climbed', 7: 'tree'}で、{'the': 0, 'cat': 1, 'dog': 2, 'a': 3, 'saw': 4, 'chased': 5, 'climbed': 6, 'tree': 7}です。

BATCH_SIZE = 8、skip_window = 1、embedding_size = 1、及びnum_skips = 2、generate_batch戻り呼び出し:単語に翻訳さ batch: [2 2 4 4 3 3 1 1] labels: [[0] [4] [2] [3] [4] [1] [0] [3]]

し、それらは:最初の三つの単語の batch: [dog, dog, saw, saw, a, a, cat, cat] labels: [[the], [saw], [dog], [a], [saw], [cat], [the], [a]]

skip_windowは1であり、期待される(ターゲット、コンテキスト)ペアは(dog、the)と(dog、saw)でなければならない - 詳細については、 "The Skip-gram Model"(https://www.tensorflow.org/tutorials/word2vec)を参照してください。必要に応じて。 1 [batch_size, 1]に使用されている理由としては

https://www.tensorflow.org/api_docs/python/tf/nn/nce_lossnce_lossドキュメントは、と言う「のラベル:タイプのInt64のテンソル及び形状【BATCH_SIZE、num_true]ターゲットクラス」 「num_true:An int。1つのトレーニングサンプルあたりのターゲットクラスの数」 (num_trueは、デフォルト値の1を持っている)

、コードでnce_lossが定義されています:

 loss = tf.reduce_mean(
     tf.nn.nce_loss(weights=nce_weights, 
        biases=nce_biases, 
        labels=train_labels, 
        inputs=embed, 
        num_sampled=num_sampled, 
        num_classes=vocabulary_size)) 

だから今、train_inputstrain_labelsの形状が理にかなって見ることは容易でなければなりません。

+0

ありがとうございました。バッチのnce_lossはスキップグラムにハードワイヤードされているようですが、何らかの理由でラベルにマルチラベルの分類が可能です。奇妙な。そして、ペア(犬、犬)と(犬、鋸)の場合、ターゲット単語がラベルと文脈にバッチで表示されるはずですが、それは別の方法です。 – com