2016-10-06 13 views
0

私は2つのクラスを持っています:正(1)と負(0)です。TensorFlow:バイナリ分類のためのクラスごとの損失関数を実装する方法

データセットは非常に不均衡なので、私のミニバッチにはほとんど0が含まれています。実際、多くのバッチには0しか含まれません。私は、正と負の例に対して別々のコストをかけて実験したかったのです。以下のコードを参照してください。

私のコードの問題は、多くのnanがあることです。なぜなら、bound_indexリストが空になるからです。これを解決するためのエレガントな方法は何ですか?

def calc_loss_debug(logits, labels): 
    logits = tf.reshape(logits, [-1]) 
    labels = tf.reshape(labels, [-1]) 
    index_bound = tf.where(tf.equal(labels, tf.constant(1, dtype=tf.float32))) 
    index_unbound = tf.where(tf.equal(labels, tf.constant(0, dtype=tf.float32))) 
    entropies = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels) 
    entropies_bound = tf.gather(entropies, index_bound) 
    entropies_unbound = tf.gather(entropies, index_unbound) 
    loss_bound = tf.reduce_mean(entropies_bound) 
    loss_unbound = tf.reduce_mean(entropies_unbound) 
+0

私は空のリストの平均をとっているので(Entropies_boundは空になります) – Stackd

答えて

1

あなたは0と1のラベルを持っているので、あなたは簡単に平均損失を取得するには、この

labels = ... 
entropies = ... 
labels_complement = tf.constant(1.0, dtype=tf.float32) - labels 
entropy_ones = tf.reduce_sum(tf.mul(labels, entropies)) 
entropy_zeros = tf.reduce_sum(tf.mul(labels_complement, entropies)) 

などの建設をtf.whereを避けることができ、あなたは0と1の数で分割する必要があります

num_ones = tf.reduce_sum(labels) 
num_zeros = tf.reduce_sum(labels_complement) 

もちろん、バッチに1がない場合でも、0で除算する必要はありません。私はtf.cond(tf.equal(num_ones, 0), ...)を使用することをお勧めします。

+0

ありがとうございます。 tf.confの0はtf.constant(0)でなければなりませんか? – Stackd

関連する問題