2017-05-10 11 views
4

言語モデルにTensorFlow LSTMを使用しています(単語のシーケンスがあり、次の単語を予測したい)。私が言語モデルを実行しているときに、入力、変換、および出力ゲートを各ステップで使用します。これはどうすればいいですか?TensorFlowでLSTMゲートの値をプリントアウトする方法は?

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/rnn_cell.pyにコードを検査するから、私はLayerNormBasicLSTMCellクラスは、私はプリントアウトしたいi, j, f, o変数を含むcall方法を持っていることがわかります。

def call(self, inputs, state): 
    """LSTM cell with layer normalization and recurrent dropout.""" 
    c, h = state 
    args = array_ops.concat([inputs, h], 1) 
    concat = self._linear(args) 

    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) 
    if self._layer_norm: 
     i = self._norm(i, "input") 
     j = self._norm(j, "transform") 
     f = self._norm(f, "forget") 
     o = self._norm(o, "output") 

    g = self._activation(j) 
    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: 
     g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) 

    new_c = (c * math_ops.sigmoid(f + self._forget_bias) 
      + math_ops.sigmoid(i) * g) 
    if self._layer_norm: 
     new_c = self._norm(new_c, "state") 
    new_h = self._activation(new_c) * math_ops.sigmoid(o) 

    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) 
    return new_h, new_state 

しかし、これらの変数を簡単に印刷する方法はありますか?または、私はLTSMを実行しているスクリプトで、このメソッドで関連するコード行を基本的に再作成する必要がありますか?

答えて

1

gitの問題で同様の質問をしました。そして、応答は元の細胞がch(すべてのステップの出力yでもある)を返すだけです。内部変数を取得するには、自分で行う必要があります。 https://github.com/tensorflow/tensorflow/issues/5731

+0

は、あなたがこれを達成することができたのは?私はまた、すべてのLSTMセルゲートを記録する必要があります。しかし、私はあまりにも多くのものを壊す呼び出しの出力を変更します。あなたがしたことの例がありますか? – dsalaj

0

基本的にはこの方法を行うことができます:ここで

はリンクになって

はまず、例えば、return new_h, new_state, i, j, f, oをあなたが必要とする状態を返します。このような変更を行うには、TensorFlowのソースコードファイルをコピーし、独自のコードのようにコードにインポートする必要があります。あなたのコード内
その後、session.run(to_return, feed_dict)に、このようなto_returnます

output, state, i, j, f, o = lstm_cell(input, state) 
to_return = { 
    "new_h": output, 
    "new_state": state, 
    "i": i, 
    "j": j, 
    "f": f, 
    "o": o, 
} 

results = session.run(to_return, feed_dict) # get what you want from the 
# graph(which are tensors), resulting in results of a dictionary with values 
# being numpy arrays. 

print results["i"] # you'll get a numpy array representing the i gate  
関連する問題