2017-10-11 9 views
1

私は今スタンフォードcs231nコースを学習しています。 softmax_loss関数を完了すると、フルベクトル化された型、特にdwという用語を扱うのは簡単ではないことがわかりました。以下は私のコードです。誰かがコードを最適化できますか?感謝します。softmax_loss関数:ループを行列演算に変換

def softmax_loss_vectorized(W, X, y, reg): 

    loss = 0.0 
    dW = np.zeros_like(W) 


    num_train = X.shape[0] 
    num_classes = W.shape[1] 

    scores = X.dot(W) 
    scores -= np.max(scores, axis = 1)[:, np.newaxis] 
    exp_scores = np.exp(scores) 
    sum_exp_scores = np.sum(exp_scores, axis = 1) 
    correct_class_score = scores[range(num_train), y] 

    loss = np.sum(np.log(sum_exp_scores)) - np.sum(correct_class_score) 

    exp_scores = exp_scores/sum_exp_scores[:,np.newaxis] 

    # **maybe here can be rewroten into matrix operations** 
    for i in xrange(num_train): 
    dW += exp_scores[i] * X[i][:,np.newaxis] 
    dW[:, y[i]] -= X[i] 

    loss /= num_train 
    loss += 0.5 * reg * np.sum(W*W) 
    dW /= num_train 
    dW += reg * W 


    return loss, dW 

答えて

0

以下は、ベクトル化された実装です。しかし、私は少し時間を費やして自分自身で解決策を試してみることをお勧めします。考え方はすべてのsoftmax値を持つ行列を構築し、正しい要素から-1を引くことです。

def softmax_loss_vectorized(W, X, y, reg): 
    num_train = X.shape[0] 

    scores = X.dot(W) 
    scores -= np.max(scores) 
    correct_scores = scores[np.arange(num_train), y] 

    # Compute the softmax per correct scores in bulk, and sum over its logs. 
    exponents = np.exp(scores) 
    sums_per_row = np.sum(exponents, axis=1) 
    softmax_array = np.exp(correct_scores)/sums_per_row 
    information_array = -np.log(softmax_array) 
    loss = np.mean(information_array) 

    # Compute the softmax per whole scores matrix, which gives the matrix for X rows coefficients. 
    # Their linear combination is algebraically dot product X transpose. 
    all_softmax_matrix = (exponents.T/sums_per_row).T 
    grad_coeff = np.zeros_like(scores) 
    grad_coeff[np.arange(num_train), y] = -1 
    grad_coeff += all_softmax_matrix 
    dW = np.dot(X.T, grad_coeff)/num_train 

    # Regularization 
    loss += 0.5 * reg * np.sum(W * W) 
    dW += reg * W 

    return loss, dW 
+0

ご回答いただきありがとうございます。私はまだあなたの提案ごとにそれを学んでいます。 – luoshao23

関連する問題