2016-04-18 9 views
0

私はケラを使用してリコメンダモデルを作成しています。アイテムセットがかなり大きいので、正確度の尺度としてHits @ Nメトリックを計算したいと思います。すなわち、観測されたアイテムが予測された上位Nにある場合、それは関連する推奨としてカウントされる。ヒットを計算するTheanoのメトリック

numpyを使用してN関数でヒットを構築できました。しかし、私はkerasのカスタム損失関数に移植しようとしているので、私はテンソルに問題があります。具体的には、テンソルでの列挙は異なります。同等のものを見つけるために構文を調べると、私は全体のアプローチに疑問を持ち始めました。私の一般的なPythonの親しみやすさを反映して、ゆるやかで遅いです。

def hits_at(y_true, y_pred): #numpy version 
    a=y_pred.argsort(axis=1) #ascending, sort by row, return index 
    a = np.fliplr(a) #reverse to get descending 
    a = a[:,0:10] #return only the first 10 columns of each row 
    Ybool = [] #initialze 2D arrray 
    for t, idx in enumerate(a): 
     ybool = np.zeros(num_items +1) #zero fill; 0 index is reserved 
     ybool[idx] = 1 #flip the recommended item from 0 to 1 
     Ybool.append(ybool) 
    A = map(lambda t: list(t), Ybool) 
    right_sum = (A * y_true).max(axis=1) #element-wise multiplication, then find the max 
    right_sum = right_sum.sum() #how many times did we score a hit? 
    return right_sum/len(y_true) #fraction of observations where we scored a hit 

私はこれをよりコンパクトでテンソルに優しい方法でどのようにアプローチすればよいですか?

更新: トップ1のバージョンを取得できました。 GRU4Recの説明にゆるやかに基づいています。 def custom_objective(y_true, y_pred): y_pred_idx_sort = T.argsort(-y_pred, axis=1)[:,0] #returns the first element, which is the index of the row with the largest value y_act_idx = T.argmax(y_true, axis=1)#returns an array of indexes with the top value return T.cast(-T.mean(T.nnet.sigmoid((T.eq(y_pred_idx_sort,y_act_idx)))), theano.config.floatX)

私は、上位1つの予測の配列を要素ごとに実際の配列と比較するだけでした。そしてTheanoにはeq()関数があります。

答えて

0

Nとは無関係に、損失関数の可能な値の数は有限です。したがって、それは賢明なテンソルの方法で微分することはできませんし、Keras/Theanoの損失関数として使用することはできません。あなたはトップNの男とのtheanoログの損失を使用しようとする可能性があります。

更新:

ケラスで - あなた自身の損失機能を書くことができます。彼らは、フォームの宣言を持っている:

def loss_function(y_pred, y_true): 

y_truey_pred両方がnumpyのアレイである、あなたはイースリー所与の例では、別段の上面500及び0であるとき1であるベクトルvを取得してもよいです。次に、それをテナー定数のベクトルに変換し、それを適用します。

return theano.tensor.net.binary_crossentropy(y_pred * v, y_true * v) 

これは正しく動作するはずです。

UPDATE 2:

ログの損失が何binary_crossentropy同じものです。

+0

私は現在categorical_crossentropyを使用しています。私は "トップNのログ損失"に慣れていない。それは何らかの形でhttp://deeplearning.net/tutorial/logreg.html#the-modelを拡張することですか? –

+0

私のコメント –

+0

Hrmを更新しました.y_predとy_trueは実際にはテンソルで、numpy配列ではないようです。ドミナントを見つけてそれらを反復することは、数が少ない配列に対して同じことをするより複雑です。 –

関連する問題