igor174 commited on
Commit
d5bf823
·
verified ·
1 Parent(s): cd4e73c

Upload NRMS (pytorch) trained on MIND-small — 3 seeds

Browse files
README.md CHANGED
@@ -19,10 +19,10 @@ NRMS news recommendation model trained on MIND-small using the
19
 
20
  | Seed | AUC | MRR | NDCG@5 | NDCG@10 |
21
  |------|-----|-----|--------|---------|
22
- | 123 | 0.6514 | 0.3041 | 0.3346 | 0.3996 |
23
- | 42 | 0.6503 | 0.3021 | 0.3326 | 0.3981 |
24
- | 456 * | 0.6551 | 0.3060 | 0.3369 | 0.4026 |
25
- | **mean ± std** | **0.6523±0.0021** | **0.3041±0.0016** | **0.3347±0.0018** | **0.4001±0.0019** |
26
 
27
  \* Best seed (weights at repo root)
28
 
 
19
 
20
  | Seed | AUC | MRR | NDCG@5 | NDCG@10 |
21
  |------|-----|-----|--------|---------|
22
+ | 123 | 0.6499 | 0.3024 | 0.3341 | 0.3986 |
23
+ | 42 | 0.6546 | 0.3061 | 0.3379 | 0.4031 |
24
+ | 456 * | 0.6557 | 0.3073 | 0.3382 | 0.4035 |
25
+ | **mean ± std** | **0.6534±0.0025** | **0.3052±0.0021** | **0.3367±0.0019** | **0.4017±0.0022** |
26
 
27
  \* Best seed (weights at repo root)
28
 
best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:793b4246af25be037c5d753152156abc1972706c8247134c62d9f3615f3cc74b
3
+ size 31185653
seed_123/best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f73339de61aa432e2177e5a8ea4f513cb2e5b7a25d41d6acbf5646152141564
3
+ size 31185653
seed_123/test_results.json CHANGED
@@ -1,8 +1,8 @@
1
  {
2
- "loss": 4.902909890087249,
3
- "auc": 0.651390856409277,
4
- "mrr": 0.30410659096918746,
5
- "ndcg@5": 0.3345759925405266,
6
- "ndcg@10": 0.39958514718105775,
7
  "num_impressions": 72903.0
8
  }
 
1
  {
2
+ "loss": 4.899622355878045,
3
+ "auc": 0.6499077096637903,
4
+ "mrr": 0.30237477582625777,
5
+ "ndcg@5": 0.33412206372866454,
6
+ "ndcg@10": 0.39860992006739304,
7
  "num_impressions": 72903.0
8
  }
seed_123/training_run_summary.json CHANGED
@@ -236,47 +236,49 @@
236
  },
237
  "initial_validation_metrics": {},
238
  "best_validation_summary": {
239
- "epoch_number": 4.0,
240
- "train_loss": 1.3360118222072583,
241
- "average_metric_value": 0.5007627030000833,
242
- "val_loss": 4.577544437914921,
243
- "val_auc": 0.7281870848065012,
244
- "val_mrr": 0.3817817234784785,
245
- "val_ndcg@5": 0.41838593386746503,
246
- "val_ndcg@10": 0.47469606984788854,
247
  "val_num_impressions": 7824.0,
248
  "timing": {
249
  "epoch_training_times": [
250
- 257.2602527141571,
251
- 258.0003807544708,
252
- 258.13529324531555,
253
- 257.61827874183655,
254
- 257.9994740486145,
255
- 257.9180815219879,
256
- 257.8408410549164,
257
- 258.0513799190521,
258
- 258.0490565299988
 
259
  ],
260
  "epoch_validation_times": [
261
- 7.482877969741821,
262
- 7.5352623462677,
263
- 6.93285870552063,
264
- 7.12975001335144,
265
- 7.973841667175293,
266
- 6.967895269393921,
267
- 7.366846799850464,
268
- 6.555442810058594,
269
- 7.397475004196167
 
270
  ],
271
- "total_training_time": 2386.75838971138
272
  }
273
  },
274
  "final_test_metrics": {
275
- "loss": 4.902909890087249,
276
- "auc": 0.651390856409277,
277
- "mrr": 0.30410659096918746,
278
- "ndcg@5": 0.3345759925405266,
279
- "ndcg@10": 0.39958514718105775,
280
  "num_impressions": 72903.0
281
  }
282
  }
 
236
  },
237
  "initial_validation_metrics": {},
238
  "best_validation_summary": {
239
+ "epoch_number": 5.0,
240
+ "train_loss": 1.3224480016789306,
241
+ "average_metric_value": 0.5042549489402233,
242
+ "val_loss": 4.554530820314837,
243
+ "val_auc": 0.7311774575461522,
244
+ "val_mrr": 0.38428358669989027,
245
+ "val_ndcg@5": 0.42384847309241136,
246
+ "val_ndcg@10": 0.4777102784224392,
247
  "val_num_impressions": 7824.0,
248
  "timing": {
249
  "epoch_training_times": [
250
+ 170.52551984786987,
251
+ 170.40193510055542,
252
+ 170.52726674079895,
253
+ 170.7886083126068,
254
+ 170.73736143112183,
255
+ 170.55459928512573,
256
+ 170.63723158836365,
257
+ 170.7861430644989,
258
+ 170.56233382225037,
259
+ 170.56839871406555
260
  ],
261
  "epoch_validation_times": [
262
+ 5.449059963226318,
263
+ 5.667611837387085,
264
+ 5.507464170455933,
265
+ 5.512972116470337,
266
+ 5.491974115371704,
267
+ 5.519476413726807,
268
+ 5.482134103775024,
269
+ 5.5142576694488525,
270
+ 5.708292722702026,
271
+ 5.473921060562134
272
  ],
273
+ "total_training_time": 1761.5992500782013
274
  }
275
  },
276
  "final_test_metrics": {
277
+ "loss": 4.899622355878045,
278
+ "auc": 0.6499077096637903,
279
+ "mrr": 0.30237477582625777,
280
+ "ndcg@5": 0.33412206372866454,
281
+ "ndcg@10": 0.39860992006739304,
282
  "num_impressions": 72903.0
283
  }
284
  }
seed_42/best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db0766a02877bd46b12770e7de74de969c35e1a167910cccb67ec259a2b03702
3
+ size 31185653
seed_42/test_results.json CHANGED
@@ -1,8 +1,8 @@
1
  {
2
- "loss": 4.90704488716887,
3
- "auc": 0.6502633736798069,
4
- "mrr": 0.3021278607642923,
5
- "ndcg@5": 0.3326071769634611,
6
- "ndcg@10": 0.3981011491197833,
7
  "num_impressions": 72903.0
8
  }
 
1
  {
2
+ "loss": 4.942550075925959,
3
+ "auc": 0.6546368368794008,
4
+ "mrr": 0.3060654427094731,
5
+ "ndcg@5": 0.33787584545553123,
6
+ "ndcg@10": 0.4030953465983671,
7
  "num_impressions": 72903.0
8
  }
seed_42/training_run_summary.json CHANGED
@@ -236,51 +236,59 @@
236
  },
237
  "initial_validation_metrics": {},
238
  "best_validation_summary": {
239
- "epoch_number": 6.0,
240
- "train_loss": 1.3115188907319253,
241
- "average_metric_value": 0.5044032370354452,
242
- "val_loss": 4.548184821488641,
243
- "val_auc": 0.7310099887877498,
244
- "val_mrr": 0.38385719474540975,
245
- "val_ndcg@5": 0.4235970467917191,
246
- "val_ndcg@10": 0.4791487178169021,
247
  "val_num_impressions": 7824.0,
248
  "timing": {
249
  "epoch_training_times": [
250
- 257.9344186782837,
251
- 257.6759557723999,
252
- 257.6570737361908,
253
- 257.4844753742218,
254
- 257.5554881095886,
255
- 257.585813999176,
256
- 258.16601943969727,
257
- 257.6333541870117,
258
- 257.7767117023468,
259
- 257.66410303115845,
260
- 257.29704117774963
 
 
 
 
261
  ],
262
  "epoch_validation_times": [
263
- 6.866185665130615,
264
- 6.339306116104126,
265
- 6.4490647315979,
266
- 7.0740966796875,
267
- 6.54096794128418,
268
- 7.421713590621948,
269
- 7.108036994934082,
270
- 6.734248399734497,
271
- 6.50663161277771,
272
- 7.467777729034424,
273
- 7.561642408370972
 
 
 
 
274
  ],
275
- "total_training_time": 2910.8110024929047
276
  }
277
  },
278
  "final_test_metrics": {
279
- "loss": 4.90704488716887,
280
- "auc": 0.6502633736798069,
281
- "mrr": 0.3021278607642923,
282
- "ndcg@5": 0.3326071769634611,
283
- "ndcg@10": 0.3981011491197833,
284
  "num_impressions": 72903.0
285
  }
286
  }
 
236
  },
237
  "initial_validation_metrics": {},
238
  "best_validation_summary": {
239
+ "epoch_number": 10.0,
240
+ "train_loss": 1.273207871465508,
241
+ "average_metric_value": 0.5112104808800277,
242
+ "val_loss": 4.527057873079249,
243
+ "val_auc": 0.7393426322797343,
244
+ "val_mrr": 0.3892933577817272,
245
+ "val_ndcg@5": 0.4302412008559449,
246
+ "val_ndcg@10": 0.48596473260270434,
247
  "val_num_impressions": 7824.0,
248
  "timing": {
249
  "epoch_training_times": [
250
+ 170.63399362564087,
251
+ 170.24239301681519,
252
+ 170.309086561203,
253
+ 170.81041073799133,
254
+ 170.47880387306213,
255
+ 170.415442943573,
256
+ 170.50919270515442,
257
+ 170.68091201782227,
258
+ 170.5483238697052,
259
+ 170.63278102874756,
260
+ 170.82863974571228,
261
+ 170.750239610672,
262
+ 170.46901392936707,
263
+ 170.35236501693726,
264
+ 170.57720494270325
265
  ],
266
  "epoch_validation_times": [
267
+ 5.3182737827301025,
268
+ 5.533895492553711,
269
+ 5.35921311378479,
270
+ 5.345299243927002,
271
+ 5.660549163818359,
272
+ 5.416679859161377,
273
+ 5.432182312011719,
274
+ 5.4819183349609375,
275
+ 5.647303342819214,
276
+ 5.431931495666504,
277
+ 5.479611158370972,
278
+ 5.470383882522583,
279
+ 5.629648447036743,
280
+ 5.457347631454468,
281
+ 5.433174133300781
282
  ],
283
+ "total_training_time": 2640.5629494190216
284
  }
285
  },
286
  "final_test_metrics": {
287
+ "loss": 4.942550075925959,
288
+ "auc": 0.6546368368794008,
289
+ "mrr": 0.3060654427094731,
290
+ "ndcg@5": 0.33787584545553123,
291
+ "ndcg@10": 0.4030953465983671,
292
  "num_impressions": 72903.0
293
  }
294
  }
seed_456/best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:793b4246af25be037c5d753152156abc1972706c8247134c62d9f3615f3cc74b
3
+ size 31185653
seed_456/test_results.json CHANGED
@@ -1,8 +1,8 @@
1
  {
2
- "loss": 4.908889115781316,
3
- "auc": 0.6551307987714882,
4
- "mrr": 0.3060355492115324,
5
- "ndcg@5": 0.3369321408156793,
6
- "ndcg@10": 0.4026482962112655,
7
  "num_impressions": 72903.0
8
  }
 
1
  {
2
+ "loss": 4.904600863656334,
3
+ "auc": 0.6556561797524001,
4
+ "mrr": 0.30725266403497975,
5
+ "ndcg@5": 0.3382060685619407,
6
+ "ndcg@10": 0.4034838957378341,
7
  "num_impressions": 72903.0
8
  }
seed_456/training_run_summary.json CHANGED
@@ -237,54 +237,54 @@
237
  "initial_validation_metrics": {},
238
  "best_validation_summary": {
239
  "epoch_number": 8.0,
240
- "train_loss": 1.2898840191987677,
241
- "average_metric_value": 0.5102701171090065,
242
- "val_loss": 4.527755548914938,
243
- "val_auc": 0.7371236196185322,
244
- "val_mrr": 0.3894036037225983,
245
- "val_ndcg@5": 0.4292683360191741,
246
- "val_ndcg@10": 0.4852849090757213,
247
  "val_num_impressions": 7824.0,
248
  "timing": {
249
  "epoch_training_times": [
250
- 257.83956575393677,
251
- 258.18328881263733,
252
- 258.01654505729675,
253
- 257.8664515018463,
254
- 257.9182119369507,
255
- 257.7222595214844,
256
- 257.910275220871,
257
- 257.8587996959686,
258
- 257.98480582237244,
259
- 257.97685956954956,
260
- 258.06838393211365,
261
- 258.0029225349426,
262
- 258.0660049915314
263
  ],
264
  "epoch_validation_times": [
265
- 6.457681655883789,
266
- 7.62598180770874,
267
- 7.037186145782471,
268
- 6.5134196281433105,
269
- 6.92226767539978,
270
- 7.230839252471924,
271
- 8.4147047996521,
272
- 7.745731830596924,
273
- 7.493134260177612,
274
- 6.909805536270142,
275
- 7.03028416633606,
276
- 6.726389169692993,
277
- 7.046547174453735
278
  ],
279
- "total_training_time": 3447.185839176178
280
  }
281
  },
282
  "final_test_metrics": {
283
- "loss": 4.908889115781316,
284
- "auc": 0.6551307987714882,
285
- "mrr": 0.3060355492115324,
286
- "ndcg@5": 0.3369321408156793,
287
- "ndcg@10": 0.4026482962112655,
288
  "num_impressions": 72903.0
289
  }
290
  }
 
237
  "initial_validation_metrics": {},
238
  "best_validation_summary": {
239
  "epoch_number": 8.0,
240
+ "train_loss": 1.2902591521040014,
241
+ "average_metric_value": 0.5108568923254101,
242
+ "val_loss": 4.527366265908255,
243
+ "val_auc": 0.7384775366529492,
244
+ "val_mrr": 0.3897771162612568,
245
+ "val_ndcg@5": 0.43001213182000775,
246
+ "val_ndcg@10": 0.4851607845674266,
247
  "val_num_impressions": 7824.0,
248
  "timing": {
249
  "epoch_training_times": [
250
+ 170.56755471229553,
251
+ 170.59103989601135,
252
+ 170.5441117286682,
253
+ 170.66781497001648,
254
+ 170.54219365119934,
255
+ 170.46811938285828,
256
+ 170.70471787452698,
257
+ 171.04492950439453,
258
+ 170.6264250278473,
259
+ 170.76346349716187,
260
+ 170.9272804260254,
261
+ 170.77860307693481,
262
+ 170.67607498168945
263
  ],
264
  "epoch_validation_times": [
265
+ 5.454913139343262,
266
+ 5.704206943511963,
267
+ 5.526165962219238,
268
+ 5.636821746826172,
269
+ 5.698307275772095,
270
+ 5.52191162109375,
271
+ 5.467671871185303,
272
+ 5.518418312072754,
273
+ 5.74641752243042,
274
+ 5.498151063919067,
275
+ 5.47381591796875,
276
+ 5.721930742263794,
277
+ 5.5079920291900635
278
  ],
279
+ "total_training_time": 2291.63307261467
280
  }
281
  },
282
  "final_test_metrics": {
283
+ "loss": 4.904600863656334,
284
+ "auc": 0.6556561797524001,
285
+ "mrr": 0.30725266403497975,
286
+ "ndcg@5": 0.3382060685619407,
287
+ "ndcg@10": 0.4034838957378341,
288
  "num_impressions": 72903.0
289
  }
290
  }
test_results.json CHANGED
@@ -1,8 +1,8 @@
1
  {
2
- "loss": 4.908889115781316,
3
- "auc": 0.6551307987714882,
4
- "mrr": 0.3060355492115324,
5
- "ndcg@5": 0.3369321408156793,
6
- "ndcg@10": 0.4026482962112655,
7
  "num_impressions": 72903.0
8
  }
 
1
  {
2
+ "loss": 4.904600863656334,
3
+ "auc": 0.6556561797524001,
4
+ "mrr": 0.30725266403497975,
5
+ "ndcg@5": 0.3382060685619407,
6
+ "ndcg@10": 0.4034838957378341,
7
  "num_impressions": 72903.0
8
  }
training_run_summary.json CHANGED
@@ -237,54 +237,54 @@
237
  "initial_validation_metrics": {},
238
  "best_validation_summary": {
239
  "epoch_number": 8.0,
240
- "train_loss": 1.2898840191987677,
241
- "average_metric_value": 0.5102701171090065,
242
- "val_loss": 4.527755548914938,
243
- "val_auc": 0.7371236196185322,
244
- "val_mrr": 0.3894036037225983,
245
- "val_ndcg@5": 0.4292683360191741,
246
- "val_ndcg@10": 0.4852849090757213,
247
  "val_num_impressions": 7824.0,
248
  "timing": {
249
  "epoch_training_times": [
250
- 257.83956575393677,
251
- 258.18328881263733,
252
- 258.01654505729675,
253
- 257.8664515018463,
254
- 257.9182119369507,
255
- 257.7222595214844,
256
- 257.910275220871,
257
- 257.8587996959686,
258
- 257.98480582237244,
259
- 257.97685956954956,
260
- 258.06838393211365,
261
- 258.0029225349426,
262
- 258.0660049915314
263
  ],
264
  "epoch_validation_times": [
265
- 6.457681655883789,
266
- 7.62598180770874,
267
- 7.037186145782471,
268
- 6.5134196281433105,
269
- 6.92226767539978,
270
- 7.230839252471924,
271
- 8.4147047996521,
272
- 7.745731830596924,
273
- 7.493134260177612,
274
- 6.909805536270142,
275
- 7.03028416633606,
276
- 6.726389169692993,
277
- 7.046547174453735
278
  ],
279
- "total_training_time": 3447.185839176178
280
  }
281
  },
282
  "final_test_metrics": {
283
- "loss": 4.908889115781316,
284
- "auc": 0.6551307987714882,
285
- "mrr": 0.3060355492115324,
286
- "ndcg@5": 0.3369321408156793,
287
- "ndcg@10": 0.4026482962112655,
288
  "num_impressions": 72903.0
289
  }
290
  }
 
237
  "initial_validation_metrics": {},
238
  "best_validation_summary": {
239
  "epoch_number": 8.0,
240
+ "train_loss": 1.2902591521040014,
241
+ "average_metric_value": 0.5108568923254101,
242
+ "val_loss": 4.527366265908255,
243
+ "val_auc": 0.7384775366529492,
244
+ "val_mrr": 0.3897771162612568,
245
+ "val_ndcg@5": 0.43001213182000775,
246
+ "val_ndcg@10": 0.4851607845674266,
247
  "val_num_impressions": 7824.0,
248
  "timing": {
249
  "epoch_training_times": [
250
+ 170.56755471229553,
251
+ 170.59103989601135,
252
+ 170.5441117286682,
253
+ 170.66781497001648,
254
+ 170.54219365119934,
255
+ 170.46811938285828,
256
+ 170.70471787452698,
257
+ 171.04492950439453,
258
+ 170.6264250278473,
259
+ 170.76346349716187,
260
+ 170.9272804260254,
261
+ 170.77860307693481,
262
+ 170.67607498168945
263
  ],
264
  "epoch_validation_times": [
265
+ 5.454913139343262,
266
+ 5.704206943511963,
267
+ 5.526165962219238,
268
+ 5.636821746826172,
269
+ 5.698307275772095,
270
+ 5.52191162109375,
271
+ 5.467671871185303,
272
+ 5.518418312072754,
273
+ 5.74641752243042,
274
+ 5.498151063919067,
275
+ 5.47381591796875,
276
+ 5.721930742263794,
277
+ 5.5079920291900635
278
  ],
279
+ "total_training_time": 2291.63307261467
280
  }
281
  },
282
  "final_test_metrics": {
283
+ "loss": 4.904600863656334,
284
+ "auc": 0.6556561797524001,
285
+ "mrr": 0.30725266403497975,
286
+ "ndcg@5": 0.3382060685619407,
287
+ "ndcg@10": 0.4034838957378341,
288
  "num_impressions": 72903.0
289
  }
290
  }