言語モデルにTensorFlow LSTMを使用しています(単語のシーケンスがあり、次の単語を予測したい)。私が言語モデルを実行しているときに、入力、変換、および出力ゲートを各ステップで使用します。これはどうすればいいですか?TensorFlowでLSTMゲートの値をプリントアウトする方法は?
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/rnn_cell.pyにコードを検査するから、私はLayerNormBasicLSTMCell
クラスは、私はプリントアウトしたいi, j, f, o
変数を含むcall
方法を持っていることがわかります。
def call(self, inputs, state):
"""LSTM cell with layer normalization and recurrent dropout."""
c, h = state
args = array_ops.concat([inputs, h], 1)
concat = self._linear(args)
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
if self._layer_norm:
i = self._norm(i, "input")
j = self._norm(j, "transform")
f = self._norm(f, "forget")
o = self._norm(o, "output")
g = self._activation(j)
if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
new_c = (c * math_ops.sigmoid(f + self._forget_bias)
+ math_ops.sigmoid(i) * g)
if self._layer_norm:
new_c = self._norm(new_c, "state")
new_h = self._activation(new_c) * math_ops.sigmoid(o)
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
return new_h, new_state
しかし、これらの変数を簡単に印刷する方法はありますか?または、私はLTSMを実行しているスクリプトで、このメソッドで関連するコード行を基本的に再作成する必要がありますか?
は、あなたがこれを達成することができたのは?私はまた、すべてのLSTMセルゲートを記録する必要があります。しかし、私はあまりにも多くのものを壊す呼び出しの出力を変更します。あなたがしたことの例がありますか? – dsalaj