2017-08-02 2 views
0

私はセンテンスレベルの注意層で訓練された深いニューラルネットワークを持っています。次のように、ネットワークはGRUと呼ばれています。私はテスト後に注目価値(sen_alpha)の結果を得たいと考えています。numpyをTensorFlowのpythonでcPickle.PicklingErrorとして.npy形式で保存することはできません。

class GRU: 
def __init__(self,is_training,word_embeddings,settings): 

    self.big_num = big_num = settings.big_num  
    for i in range(big_num): 

     sen_repre.append(tf.tanh(attention_r[self.total_shape[i]:self.total_shape[i+1]])) 
     batch_size = self.total_shape[i+1]-self.total_shape[i] 
       sen_alpha.append(tf.reshape(tf.nn.softmax(tf.reshape(tf.matmul(tf.mul(sen_repre[i],sen_a),sen_r),[batch_size])),[1,batch_size])) 
       self.attentions.append(sen_alpha[i]) 

テストコード:

def main(_): 
test_settings = Settings() 
with tf.Graph().as_default(): 

    sess = tf.Session() 
    with sess.as_default():  
     with tf.variable_scope("model"): 
          mtest = GRU(is_training=False, word_embeddings = None, settings = test_settings) 
        saver = tf.train.Saver() 

      attentions = mtest.attentions 
      att = np.array(attentions)  
      print(str(type(att))) 
      print(att[0:100]) 
      np.save("attentions.npy",att) 

結果:

タイプ:タイプ 'numpy.ndarray'

ATT [0:100]:

[<tf.Tensor 'model/Reshape_9:0' shape=(1, ?) dtype=float32<tf.Tensor 'model/Reshape_17:0' shape=(1, ?) dtype=float32<tf.Tensor 'model/Reshape_25:0' shape=(1, ?) dtype=float32>

エラー:

メイン np.save( "attentions.npy"、ATT)

cPickle.PicklingErrorで

ファイル "test_GRU.py"、行242、:酸洗いすることはできません:属性検索組み込み .moduleは

を失敗しました

結果を正しく保存するにはどうすればよいですか?おかげ

+0

なぜあなたは 'tpy.Tensor'オブジェクトの' numpy'オブジェクト配列を使用していますか?それはほとんど意味がありません。ちょうど私たちのリストまたは何か。 –

+0

'att'はおそらくオブジェクトdtype配列です。つまり、 'attentions'オブジェクトへのポインタを1つ以上持つ配列です。 'np.save'は' pickle'を使ってオブジェクトを保存します。数値データバッファをファイルに直接書き込むことはできますが、バイト列を作成するには 'pickle'を使用する必要があります。私の推測では、 'tf.Tensor'は酸洗いの方法を定義していません。独自の定義済みの保存方法については、Tensorflowをチェックしてください。 – hpaulj

答えて

0

私はあなたのコードを修正することはできませんが、私はそれから値を抽出するモデル定義から、ステップのデザインであなたのステップの短いバージョンを与えることができます:

  1. はモデルグラフを定義します。それはGRUがグラフの一部であることを意味します。
  2. セッションを開始します。 sess = tf.Session()
  3. グラフの変数を初期化します。 sess.run(tf.global_variables_initializer())
  4. セッションメソッドを使用して、対応するグラフから値を取得します。 sess.run(the_tensor, dictionary_of_numpy_array_as_input_to_graph)

出力は保存できるnumpyの配列になります。

関連する問題