2017-11-19 11 views
4

私はニューラルネットワークを学習しています。私は関数cross_entropyをpythonで書いています。それはクロスエントロピー関数(python)

cross entropyNはサンプルの数である

kと定義されるクラスの数、logは自然対数であり、試料iが他クラスj0であればt_i,jは1でありますp_i,jは、サンプルiがクラスjにある予測された確率です。 対数の数値問題を回避するには、予測値を[10^{−12}, 1 − 10^{−12}]の範囲にクリップします。

上記の説明によれば、私は、予測をclippintでコードを​​の範囲に書き留め、次に上記の式に基づいてクロスエントロピーを計算します。

def cross_entropy(predictions, targets, epsilon=1e-12): 
    """ 
    Computes cross entropy between targets (encoded as one-hot vectors) 
    and predictions. 
    Input: predictions (N, k) ndarray 
      targets (N, k) ndarray   
    Returns: scalar 
    """ 
    predictions = np.clip(predictions, epsilon, 1. - epsilon) 
    ce = - np.mean(np.log(predictions) * targets) 
    return ce 

次のコードは、機能がcross_entropyであるかどうかを確認するために使用されます。

predictions = np.array([[0.25,0.25,0.25,0.25], 
         [0.01,0.01,0.01,0.96]]) 
targets = np.array([[0,0,0,1], 
        [0,0,0,1]]) 
ans = 0.71355817782 #Correct answer 
x = cross_entropy(predictions, targets) 
print(np.isclose(x,ans)) 

上記のコードの出力は、それがcross_entropyが正しくない関数を定義するための私のコードを言うために、Falseです。次に、cross_entropy(predictions, targets)の結果を出力します。それは0.178389544455を与え、正しい結果はans = 0.71355817782であるはずです。誰かが私のコードの問題点を確認するのを手伝ってもらえますか?

答えて

5

あなたはそれほど離れているわけではありませんが、N個の合計の平均値を取っていることを覚えておいてください。ここではN = 2です。したがって、あなたのコードは以下のように読むことができます:

def cross_entropy(predictions, targets, epsilon=1e-12): 
    """ 
    Computes cross entropy between targets (encoded as one-hot vectors) 
    and predictions. 
    Input: predictions (N, k) ndarray 
      targets (N, k) ndarray   
    Returns: scalar 
    """ 
    predictions = np.clip(predictions, epsilon, 1. - epsilon) 
    N = predictions.shape[0] 
    ce = -np.sum(np.sum(targets*np.log(predictions+1e-9)))/N 
    return ce 

predictions = np.array([[0.25,0.25,0.25,0.25], 
         [0.01,0.01,0.01,0.96]]) 
targets = np.array([[0,0,0,1], 
        [0,0,0,1]]) 
ans = 0.71355817782 #Correct answer 
x = cross_entropy(predictions, targets) 
print(np.isclose(x,ans)) 

ここでは、np.sum()に固執すれば少しはっきりしていると思います。また、計算にlog(0)を持つ可能性を避けるために、1e-9をnp.log()に追加しました。お役に立てれば!

関連する問題