4

私は、特定のユーザーに製品を提案する協調フィルタリングアルゴリズムを作成しようとしています。TensorFlowの推奨システム(SVD)

私はまもなく開始し、TensorFlowで作業を開始しました(これは十分効果的で柔軟性があると思いました)。 私は、私が興味を持ってる何がこのコードを発見したモデルを作成し、ユーザID、製品、および評価を訓練:https://github.com/songgc/TF-recomm

私はコードを開始し、モデルを訓練しました。

モデルをトレーニングした後、予測を行う必要があります。つまり、NODE.jsアプリケーションでアクセスするDBに保存できるように、各ユーザーの提案を取得する必要があります。

トレーニング終了時に、このユーザーのリストを取得するにはどうすればよいですか?

if __name__ == '__main__': 
    df_train, df_test=get_data() 
    svd(df_train, df_test) 
    print("Done!") 

答えて

1

コードの予測部分を変更して、top K推奨製品を出力する必要があります。予測が行われる現在のコードがある:ここembed_user

embd_user = tf.nn.embedding_lookup(w_user, user_batch, name="embedding_user") 
embd_item = tf.nn.embedding_lookup(w_item, item_batch, name="embedding_item") 
infer = tf.reduce_sum(tf.multiply(embd_user, embd_item), 1) 

は、特定のユーザのユーザ埋め込みあり、embd_itemは、特定のアイテムのためのものです。したがって、particular userparticular itemを比較する代わりに、すべての項目と比較するために変更する必要があります。マトリックスw_itemは、すべてのアイテムの埋め込みです。これはして行うことができます。

embd_user = tf.nn.embedding_lookup(w_user, user_batch, name="embedding_user") 
# Multiply user embedding of shape: [1 x dim] 
# with every item embeddings of shape: [item_num, dim], 
# to produce rank of all items of shape: [item_num] 
predict = tf.matmul(embd_user, w_item, transpose_b=True) 

次に、あなたが予測出力で最大のtop kインデックスを選択することができます。

+0

おいしいです。 int32の代わりに英数字のidを持つDBデータを使用すると、embedding_lookupでエラーが発生します。エラーは次のようになります。 "TypeError:パラメータ 'index'に渡された値に許容値のリストでないDataType文字列があります。どうすれば修正できますか? –

+0

エラーがどの行に表示されますか? 「DBデータ」とは何ですか? embedding_lookupでは、このエラーの原因となっている入力は何ですか? –

+0

エラーを引き起こす行の1つ: "embd_user = tf.nn.embedding_lookup(w_user、user_batch、name =" embedding_user ")"。 これは、データベースからロードされるIDが英数字であり、int32ではないためです。https://www.dropbox.com/s/9s4vxsciue3mu38/Schermata%202017-07-08%20alle%2010.32.48.png?dl=0 。 ユーザバッチは、タイプtf.stringを持っています: "user_batch = tf.placeholder(tf.string、shape = [None]、name =" id_user ")" –

1

あなたは、すべてのアイテムIDのすべてのユーザーIDと項目を意味し、predict_resultは、すべての項目について、各ユーザのスコアである、あなたがDBにpredict_resultを格納できるユーザー

predict_result = sess.run(inter_op, feed_dict={user_batch:users, item_batch:items}) 

を実行することができます。

+0

ユーザーごとに推奨製品を10個見つけなければならない場合は、すべての製品をすべてのユーザーと手動で組み合わせてトップ10を見つけなければなりませんか?より効率的な方法はありませんか? predict_result = sess.run(infer、feed_dict = {user_batch:[users [0]、users [0]、users [0] .........]、item_batch:[items [0]、items [1]、items [2] ................]}) –