2017-01-30 17 views
0

テンソルフローでLSTMセルベースのRNNを実装しようとしています。私は、フィーチャを構成する各行に一連の単語を持ち、ターゲットとなるlabels(1または0のいずれか)という名前の列を持っています。私はdynamic_RNNの出力を解釈することに困難を抱えています。 dynamic_RNNは出力と状態を返します。 私はデータを256のバッチに分割しました。 の各行を一様な長さ160にパッディング/スライスしました。次に、各ワードが100次元のベクトルになるように(グローブを使用して)埋め込みました。この後、私の各例は160×100の行列になり、RNN関数の各バッチはテンソルの形状 - (256,160,100)になります。今、LSTMセルのhidden_​​dimensionが300に設定されていると仮定すると、160×100の各例の隠れ状態ベクトルの予想される次元は1×300です。しかし、rnn_states.hの形状をチェックすると、次のようになります -tensorflowのdynamic_rnnの出力を解釈するのが難しい

私は寸法が256×300であると予想していました。今後の進め方はわかりません。

私が使用した以下のコードとその形状を追加しています。私は2つのテンソル - train_contexttrain_utteranceを持っています。私はそれぞれ128サンプルを選択し、後でそれらを埋め込み、連結して256の入力バッチを作成します。どこに間違っているか教えていただけますか?

train_context.get_shape() 
# TensorShape([Dimension(128), Dimension(160)]) 

train_utterance.get_shape() 
# TensorShape([Dimension(128), Dimension(160)]) 

context_embedded = tf.cast(tf.nn.embedding_lookup(embeddings_W,\ 
train_context, name="embed_context"), dtype = tf.float64) 

utterance_embedded = tf.cast(tf.nn.embedding_lookup(embeddings_W,\ 
train_utterance, name="embed_utterance"), dtype = tf.float64) 

context_embedded.get_shape() 
# TensorShape([Dimension(128), Dimension(160), Dimension(100)]) 
utterance_embedded.get_shape() 
# TensorShape([Dimension(128), Dimension(160), Dimension(100)]) 

inputs = tf.concat(0, [context_embedded, utterance_embedded]) 
inputs.get_shape() 
# TensorShape([Dimension(256), Dimension(160), Dimension(100)]) 

cell = tf.nn.rnn_cell.LSTMCell(300, forget_bias=2.0, use_peepholes=True,\ 
state_is_tuple=True) 

cell.state_size 
# LSTMStateTuple(c=300, h=300) 

rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float64,\ 
parallel_iterations=64, time_major=False) 


rnn_states.h.get_shape() 
# TensorShape([Dimension(None), Dimension(300)]) 

私は次元が、私はそれが事前に256×300のおかげであることを期待していたいずれのx 300ではない理由を理解することはできませんよ。

答えて

1

get_shapeがベストエフォートです。一般に、Tensorflowはグラフ作成時に常にTensorsのサイズを認識しません。 Noneは、単に工事時に寸法が不明であることを意味します。

PythonからTensorの形を見つける唯一の保証された方法は、グラフを実行して(例えば、session.run()または.eval()を介して)テンソルを特定の値に評価することです。助け

(あなたは、グラフ内形状を消費したい場合は、tf.shapeもある。)

願っています!

関連する問題