2017-07-28 1 views
1

テンソルフローでは、API tf.add_to_collcetionがコードに次のような値を追加しています。Tensorflowリセットまたはクリアコレクション

def accuracy_rate(logits, labels): 
    correct = tf.nn.in_top_k(logits, labels, 1) 
    # Return the accuracy of true entries. 
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) 
    return accuracy 
with tf.Session() as sess: 
    logits, labels = ... 
    accuracy = accuracy_rate(logits, labels) 
    tf.add_to_collection('total_accuracy', sess.run(accuracy)) 

APIで見つけられないことは、既に1つのコレクションに保存されているすべての値をクリアするにはどうすればよいですか?

+0

私はあなたが別の解決策を発見した知っているが、あなたはまた、この目的のために 'tf.get_default_graph()。clear_collection( 'total_accuracyを')'検討することができます。コレクションには 'variables'、' trainable_variables'と 'train_op'も含まれているので、キーで値を削除することをお勧めします。 –

答えて

3

あなたは(それだけでPythonのリストです)クリアできるコレクションへの変更可能な参照を取得するtf.get_collection_refを使用することができます。

+0

コレクションの値をクリアすることができましたか? – user6932206

+1

はい、クリアすることができるpythonリストです –

0

異なるtf.Graph()を使用して、別の解決策を探します。

0

これはあなたが探しているかもしれないと思いますか?

In [2]: import tensorflow as tf 
In [3]: w = tf.Variable([[1,2,3], [4,5,6], [7,8,9], [3,1,5], [4,1,7]], collections=[tf.GraphKeys.WEIGHTS, tf.GraphKeys.GLOBAL_VARIABLES], dtype=tf.float32) 
In [4]: params = tf.get_collection_ref(tf.GraphKeys.WEIGHTS) 
In [5]: del params[:] 
In [6]: tf.get_collection_ref(tf.GraphKeys.WEIGHTS)                                         
Out[6]: [] 
In [10]: params = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) 
In [11]: params 
Out[11]: [<tf.Variable 'Variable:0' shape=(5, 3) dtype=float32_ref>] 
関連する問題