2017-07-20 5 views
3

モデルを復元して予測に使用できるようにLSTMの最終状態を保存します。以下で説明するように、Saverはtf.assignを使用すると最終状態の知識しか持ちません。ただし、これによりエラーが発生します(後述)。Tensorflowは、予測のためにdynamic_rnnのLSTMの最終状態を保存します。

トレーニング中は、this postで説明されているように、私はいつも最後のLSTM状態をネットワークに戻します。ここでは、コードの重要な部分です:

  self.init_state = tf.placeholder(tf.float32, [ 
       self.n_layers, 2, self.batch_size, self.n_hidden 
      ]) 

      state_per_layer_list = tf.unstack(self.init_state, axis=0) 

      rnn_tuple_state = tuple([ 
       tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0], 
               state_per_layer_list[idx][1]) 

       for idx in range(self.n_layers) 
      ]) 

      outputs, self.final_state = tf.nn.dynamic_rnn(
       cell, inputs=self.inputs, initial_state=rnn_tuple_state) 

とトレーニング中:グラフ構築する際

私は後でチェックポイントからの私のモデルを復元すると

 _current_state = np.zeros((self.n_layers, 2, self.batch_size, 
            self.n_hidden)) 

      _train_step, _current_state, _loss, _acc, summary = self.sess.run(
       [ 
        self.train_step, self.final_state, 
        self.merged 
       ], 
       feed_dict={self.inputs: _inputs, 
          self.labels:_labels, 
          self.init_state: _current_state}) 

を、最終状態があります復元されません。 this postに概要が示されているように、Saverは新しい状態を知らないという問題があります。この記事では、tf.assignに基づいた解決法も提案しています。 self.init状態は変数が、プレースホルダではないので、残念ながら、私が提案し

  assign_op = tf.assign(self.init_state, _current_state) 
      self.sess.run(assign_op) 

を使用することはできません。 「テンソル」オブジェクトが、私は今、数時間、この問題を解決することを試みたが、私はそれを動作させることはできません

「を割り当てる」は属性を持っていない:私はエラー

はAttributeErrorを取得します。

ご協力いただきましてありがとうございます。

EDIT:

私が変更されているself.init_state

  self.init_state = tf.get_variable('saved_state', shape= 
      [self.n_layers, 2, self.batch_size, self.n_hidden]) 

      state_per_layer_list = tf.unstack(self.init_state, axis=0) 

      rnn_tuple_state = tuple([ 
       tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0], 
               state_per_layer_list[idx][1]) 

       for idx in range(self.n_layers) 
      ]) 

      outputs, self.final_state = tf.nn.dynamic_rnn(
       cell, inputs=self.inputs, initial_state=rnn_tuple_state) 

し、トレーニング中に、私はself.init_stateに値を供給していない:

しかし
  _train_step, _current_state, _loss, _acc, summary = self.sess.run(
       [ 
        self.train_step, self.final_state, 
        self.merged 
       ], 
       feed_dict={self.inputs: _inputs, 
          self.labels:_labels}) 

、私はまだ割り当てopを実行することはできません。予想される、 '割り当て' オペアンプの 'の値を' パラメータに渡された(LSTMStateTuple(C =配列([[0.07291573、-0.06366599、-0.23425588、...、0.05307654、

+0

あなたの質問はあなたがリンクしたものと重複しているようです。受け入れられた答えは、プレースホルダをtf.Variableに変えることを示唆しているため、TensorFlowの割り当て操作でネットワークの状態を更新することができます。 –

+0

はい、私はそれを試みました。たぶん、私はそれを変えるときに間違いをしています。どのようにtf.Variableに変換するのですか? self.init_state = tf.Variable(self.init_state)を使用すると、すべての変数を初期化するときにエラーが発生する – Lemon

+0

元の質問を変更しました(編集を参照)。それはそれが他のポストで説明された方法ではまだ動作しません。何か不足していますか? – Lemon

答えて

1
を得たのfloat32:私は

例外TypeErrorが得る知っています

最後の状態を保存するには、別のTF変数を作成してから、グラフを保存する前にassignを実行して、その変数に最新の状態を割り当ててから、グラフを保存します。あなたがSaverを宣言する前にその変数を宣言することです。それ以外の場合は、グラフには含まれません。

ここでは、作業コードを含めて詳細に説明しています。 TF LSTM: Save State from training session for prediction session later

*** UPDATE:答えは質問をフォローする:

あなたがstate_is_tuple=Trueで、BasicLSTMCellを使用しているように見えます。私が前に述べたのは、GRUCellstate_is_tuple=Falseを使用したことです。

トレーニング中は、最初にinitial_stateとして0を入力してから、dynamic_rnnに0を入力してから、自分の出力を元に戻し続けます。これは、次の2つの方法の違いです。 initial_stateと入力します。したがって、dynamic_rnnコールの最後の出力状態は、後で保存するものです。それはsess.run()呼び出しの結果であるため、本質的にそれはnumpy配列(テンソルではなくプレースホルダ)です。だから質問は「グラフの変数の残りの部分と一緒にTensorflow変数として数量の少ない配列を保存するにはどうすればよいのか」ということになります。だからこそ、最終状態をその目的だけの変数に代入します。

述べたように
# GRAPH DEFINITIONS: 
    state_in = tf.placeholder(tf.float32, [LAYERS, 2, None, CELL_SIZE], name='state_in') 
    l = tf.unstack(state_in, axis=0) 
    state_tup = tuple(
     [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1]) 
     for idx in range(NLAYERS)]) 
    #multicell = your BasicLSTMCell/MultiRNN definitions 
    output, state_out = tf.nn.dynamic_rnn(multicell, X, dtype=tf.float32, initial_state=state_tup) 

    savedState = tf.get_variable('savedState', shape=[LAYERS, 2, BATCHSIZE, CELL_SIZE]) 
    saver = tf.train.Saver(max_to_keep=1) 

    in_state = np.zeros((LAYERS, 2, BATCHSIZE, CELL_SIZE)) 

    # TRAINING LOOP: 
    feed_dict = {X: x, Y_: y_, batchsize: BATCHSIZE, state_in:in_state} 
    _, out_state = sess.run([training_step, state_out], feed_dict=feed_dict) 
    in_state = out_state 

    # ONCE TRAINING IS OVER: 
    assignOp = tf.assign(savedState, out_state) 
    sess.run(assignOp) 
    saver.save(sess, pathModel + '/my_model.ckpt') 

    # RECOVERING IN A DIFFERENT PROGRAM: 

    gInit = tf.global_variables_initializer().run() 
    lInit = tf.local_variables_initializer().run() 
    new_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta') 
    new_saver.restore(sess, pathModel + 'my_model.ckpt') 
    # retrieve State and get its LAST batch (latest obervarions) 
    savedState = sess.run('savedState:0') # this is FULL state from training 
    state = savedState[:,:,-1,:] # -1 gets only the LAST batch of the state (latest seen observations) 
    state = np.reshape(state, [state.shape[0], 2, -1, state.shape[2]]) #[LAYERS, 2, 1 (BATCH), SELL_SIZE] 
    #x = .... (YOUR INPUTS) 
    feed_dict = {'X:0': x, 'state_in:0':state} 
    #PREDICTION LOOP: 
    preds, state = sess.run(['preds:0', 'state_out:0'], feed_dict = feed_dict) 
    # so now state will be re-fed into feed_dict with the next loop iteration 

が、これはGRUCellstate_is_tuple = Falseと私のためによく働くものの修正アプローチです:

ので、コードは次のようなものです。私はBasicLSTMCellstate_is_tuple=Trueを試してみました。それは元のアプローチと同じくらい正確には機能しません。私はGRUがLSTMよりも優れているのか、それとも他の理由から私には分かりません。これがうまくいくかどうかを確認してください...

リカバリと予測コードでわかるように、予測はトレーニングループとは異なるバッチサイズに基づいている可能性があります1?)したがって、回復した状態を処理する方法を考える必要があります。最後のバッチを取るだけですか?または、他の何か?このコードでは、保存された状態の最後のレイヤー(トレーニングからの最新の観測値)が使用されます。これは、関連性が高いためです...

+0

はい。私はその議論を見て、私のポストにリンクを提供しました。しかし、提供されたコードは重要な部分を除外しています。それが役に立たなかった理由です。私があなたを正しく理解していれば、投稿の「savedState」変数は、すべてのエポックで最新のLSTM状態を割り当てるためにのみ使用されますか?しかし、あなたはトレーニング中にlstm州でどのように食べるのですか?上記のプレースホルダを使用していますか?あなたのポストでは、「フィードdict」は決して完全には指定されていません。 – Lemon

+0

また、変数を明確には復元するためのコード。トレーニングで最後のlstm状態を取得するには、sess.run( 'savedState:0')を実行します。しかし、どのようにネットにその情報を供給していますか?どの変数またはプレースホルダに与えられますか? – Lemon

+0

Apols ..あなたの質問は、予測のためにFINAL状態を保存する方法を尋ねました。私は数秒で私の答えを更新します。 –

関連する問題