Project 2

This commit is contained in:
unknown 2021-06-09 22:12:22 +02:00
parent 4499231c79
commit fa6ba669c8

View File

@ -67,8 +67,8 @@ def evaluate(test,
def ranking_metrics(test_ui, reco, super_reactions=[], topK=10):
nb_items=test_ui.shape[1]
relevant_users, super_relevant_users, prec, rec, F_1, F_05, prec_super, rec_super, ndcg, mAP, MRR, LAUC, HR=\
0,0,0,0,0,0,0,0,0,0,0,0,0
relevant_users, super_relevant_users, prec, rec, F_1, F_05, prec_super, rec_super, ndcg, mAP, MRR, LAUC, HR, HitRate2, HitRate3=\
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
cg = (1.0 / np.log2(np.arange(2, topK + 2)))
cg_sum = np.cumsum(cg)
@ -121,7 +121,10 @@ def ranking_metrics(test_ui, reco, super_reactions=[], topK=10):
((nb_items-nb_u_rated_items)*nb_u_rated_items)
HR+=nb_user_successes>0
HitRate2+=nb_user_successes>1
HitRate3+=nb_user_successes>2
result=[]
result.append(('precision', prec/relevant_users))
@ -135,6 +138,9 @@ def ranking_metrics(test_ui, reco, super_reactions=[], topK=10):
result.append(('MRR', MRR/relevant_users))
result.append(('LAUC', LAUC/relevant_users))
result.append(('HR', HR/relevant_users))
result.append(('HitRate2', HitRate2/relevant_users))
result.append(('HitRate3', HitRate3/relevant_users))
df_result=pd.DataFrame()
if len(result)>0: