Upload unknown (jax) trained on MIND-small — 3 seeds
Browse files- README.md +5 -5
- model.safetensors +2 -2
- seed_123/model.safetensors +2 -2
- seed_123/test_results.json +5 -5
- seed_123/training_run_summary.json +75 -58
- seed_42/model.safetensors +2 -2
- seed_42/test_results.json +5 -5
- seed_42/training_run_summary.json +72 -57
- seed_456/model.safetensors +2 -2
- seed_456/test_results.json +5 -5
- seed_456/training_run_summary.json +82 -59
- test_results.json +5 -5
- training_run_summary.json +76 -59
README.md
CHANGED
|
@@ -19,10 +19,10 @@ unknown news recommendation model trained on MIND-small using the
|
|
| 19 |
|
| 20 |
| Seed | AUC | MRR | NDCG@5 | NDCG@10 |
|
| 21 |
|------|-----|-----|--------|---------|
|
| 22 |
-
| 123 | 0.
|
| 23 |
-
| 42 | 0.
|
| 24 |
-
| 456
|
| 25 |
-
| **mean ± std** | **0.
|
| 26 |
|
| 27 |
\* Best seed (weights at repo root)
|
| 28 |
|
|
@@ -30,7 +30,7 @@ unknown news recommendation model trained on MIND-small using the
|
|
| 30 |
|
| 31 |
```
|
| 32 |
newsrex/unknown-JAX-MIND-small/
|
| 33 |
-
├── model.safetensors ← best seed (
|
| 34 |
├── test_results.json
|
| 35 |
├── training_run_summary.json
|
| 36 |
├── seed_123/model.safetensors
|
|
|
|
| 19 |
|
| 20 |
| Seed | AUC | MRR | NDCG@5 | NDCG@10 |
|
| 21 |
|------|-----|-----|--------|---------|
|
| 22 |
+
| 123 * | 0.6747 | 0.3214 | 0.3556 | 0.4193 |
|
| 23 |
+
| 42 | 0.6716 | 0.3195 | 0.3534 | 0.4178 |
|
| 24 |
+
| 456 | 0.6738 | 0.3198 | 0.3547 | 0.4184 |
|
| 25 |
+
| **mean ± std** | **0.6734±0.0013** | **0.3202±0.0009** | **0.3546±0.0009** | **0.4185±0.0006** |
|
| 26 |
|
| 27 |
\* Best seed (weights at repo root)
|
| 28 |
|
|
|
|
| 30 |
|
| 31 |
```
|
| 32 |
newsrex/unknown-JAX-MIND-small/
|
| 33 |
+
├── model.safetensors ← best seed (123)
|
| 34 |
├── test_results.json
|
| 35 |
├── training_run_summary.json
|
| 36 |
├── seed_123/model.safetensors
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fc1f1e1bc073a357b8f16209c0c38aa2b992377373af01fb3b1f650f99ebebf7
|
| 3 |
+
size 47322396
|
seed_123/model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fc1f1e1bc073a357b8f16209c0c38aa2b992377373af01fb3b1f650f99ebebf7
|
| 3 |
+
size 47322396
|
seed_123/test_results.json
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
{
|
| 2 |
-
"loss": 4.
|
| 3 |
-
"auc": 0.
|
| 4 |
-
"mrr": 0.
|
| 5 |
-
"ndcg@5": 0.
|
| 6 |
-
"ndcg@10": 0.
|
| 7 |
"num_impressions": 72903.0
|
| 8 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"loss": 4.895912475360167,
|
| 3 |
+
"auc": 0.6746637875073116,
|
| 4 |
+
"mrr": 0.3214330456909908,
|
| 5 |
+
"ndcg@5": 0.35560859266126893,
|
| 6 |
+
"ndcg@10": 0.4193156331068628,
|
| 7 |
"num_impressions": 72903.0
|
| 8 |
}
|
seed_123/training_run_summary.json
CHANGED
|
@@ -14,9 +14,9 @@
|
|
| 14 |
},
|
| 15 |
"num_workers": 4,
|
| 16 |
"train": {
|
| 17 |
-
"batch_size":
|
| 18 |
"num_epochs": 20,
|
| 19 |
-
"learning_rate":
|
| 20 |
"gradient_clip_val": 1.0,
|
| 21 |
"grad_accum_steps": 1,
|
| 22 |
"early_stopping": {
|
|
@@ -35,8 +35,8 @@
|
|
| 35 |
"logging": {
|
| 36 |
"project_name": "NewsReX",
|
| 37 |
"enable_wandb": true,
|
| 38 |
-
"experiment_name": "jax/MIND-small/
|
| 39 |
-
"wandb_group": "jax/MIND-small/
|
| 40 |
"progress_backend": "tqdm"
|
| 41 |
},
|
| 42 |
"metrics": {
|
|
@@ -57,33 +57,49 @@
|
|
| 57 |
},
|
| 58 |
"spec": {
|
| 59 |
"model": {
|
| 60 |
-
"name": "
|
| 61 |
"architecture": {
|
| 62 |
"news_encoder": {
|
| 63 |
"type": "multi_head_self_attention",
|
| 64 |
"num_heads": 20,
|
| 65 |
-
"head_dim":
|
| 66 |
-
"attention_hidden_dim": 200
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
},
|
| 68 |
"user_encoder": {
|
| 69 |
-
"type": "
|
| 70 |
-
"
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
},
|
| 73 |
"click_predictor": {
|
| 74 |
-
"type": "
|
| 75 |
}
|
| 76 |
},
|
| 77 |
"embedding": {
|
| 78 |
"size": 300,
|
| 79 |
"trainable": true
|
| 80 |
},
|
|
|
|
|
|
|
|
|
|
| 81 |
"dropout_rate": 0.2,
|
| 82 |
"seed": 42
|
| 83 |
},
|
| 84 |
"inputs": {
|
| 85 |
"title": {
|
| 86 |
-
"max_length":
|
| 87 |
},
|
| 88 |
"history": {
|
| 89 |
"max_length": 50
|
|
@@ -91,11 +107,13 @@
|
|
| 91 |
"impressions": {
|
| 92 |
"max_length": 5
|
| 93 |
},
|
|
|
|
| 94 |
"process_title": true,
|
| 95 |
"process_abstract": false,
|
| 96 |
-
"process_category":
|
| 97 |
"process_subcategory": false,
|
| 98 |
-
"process_user_id": false
|
|
|
|
| 99 |
},
|
| 100 |
"training": {
|
| 101 |
"loss": {
|
|
@@ -105,8 +123,8 @@
|
|
| 105 |
"label_smoothing": 0.0
|
| 106 |
},
|
| 107 |
"optimizer": "adam",
|
| 108 |
-
"learning_rate":
|
| 109 |
-
"batch_size":
|
| 110 |
"num_epochs": 20,
|
| 111 |
"gradient_clip_val": 1.0,
|
| 112 |
"grad_accum_steps": 1,
|
|
@@ -117,12 +135,11 @@
|
|
| 117 |
"negative_sampling": {
|
| 118 |
"strategy": "random",
|
| 119 |
"candidates": 4
|
| 120 |
-
}
|
| 121 |
-
"disagreement_beta": 0.0
|
| 122 |
},
|
| 123 |
"evaluation": {
|
| 124 |
"mode": "fast",
|
| 125 |
-
"evaluator": "
|
| 126 |
"metrics": [
|
| 127 |
"auc",
|
| 128 |
"mrr",
|
|
@@ -165,7 +182,7 @@
|
|
| 165 |
"test": "https://huggingface.co/datasets/yjw1029/MIND/resolve/main/MINDlarge_test.zip"
|
| 166 |
}
|
| 167 |
},
|
| 168 |
-
"max_title_length":
|
| 169 |
"max_abstract_length": 50,
|
| 170 |
"max_history_length": 50,
|
| 171 |
"max_impressions_length": 5,
|
|
@@ -201,10 +218,10 @@
|
|
| 201 |
},
|
| 202 |
"process_title": true,
|
| 203 |
"process_abstract": false,
|
| 204 |
-
"process_category":
|
| 205 |
"process_subcategory": false,
|
| 206 |
"process_user_id": false,
|
| 207 |
-
"process_entities":
|
| 208 |
},
|
| 209 |
"sampling": {
|
| 210 |
"max_impressions_length": 5,
|
|
@@ -233,55 +250,55 @@
|
|
| 233 |
"popularity_metric": "clicks"
|
| 234 |
}
|
| 235 |
},
|
| 236 |
-
"name": "
|
| 237 |
-
"model_name": "
|
| 238 |
-
"_output_run_dir": "outputs/train/MIND-small/
|
| 239 |
},
|
| 240 |
"initial_validation_metrics": {},
|
| 241 |
"best_validation_summary": {
|
| 242 |
"epoch_number": 10.0,
|
| 243 |
-
"train_loss": 1.
|
| 244 |
-
"average_metric_value": 0.
|
| 245 |
-
"val_loss": 4.
|
| 246 |
-
"val_auc": 0.
|
| 247 |
-
"val_mrr": 0.
|
| 248 |
-
"val_ndcg@5": 0.
|
| 249 |
-
"val_ndcg@10": 0.
|
| 250 |
"val_num_impressions": 7824.0,
|
| 251 |
"timing": {
|
| 252 |
"epoch_training_times": [
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
],
|
| 264 |
"epoch_validation_times": [
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
],
|
| 276 |
-
"total_training_time":
|
| 277 |
}
|
| 278 |
},
|
| 279 |
"final_test_metrics": {
|
| 280 |
-
"loss": 4.
|
| 281 |
-
"auc": 0.
|
| 282 |
-
"mrr": 0.
|
| 283 |
-
"ndcg@5": 0.
|
| 284 |
-
"ndcg@10": 0.
|
| 285 |
"num_impressions": 72903.0
|
| 286 |
}
|
| 287 |
}
|
|
|
|
| 14 |
},
|
| 15 |
"num_workers": 4,
|
| 16 |
"train": {
|
| 17 |
+
"batch_size": 64,
|
| 18 |
"num_epochs": 20,
|
| 19 |
+
"learning_rate": 5e-05,
|
| 20 |
"gradient_clip_val": 1.0,
|
| 21 |
"grad_accum_steps": 1,
|
| 22 |
"early_stopping": {
|
|
|
|
| 35 |
"logging": {
|
| 36 |
"project_name": "NewsReX",
|
| 37 |
"enable_wandb": true,
|
| 38 |
+
"experiment_name": "jax/MIND-small/CAUM",
|
| 39 |
+
"wandb_group": "jax/MIND-small/CAUM",
|
| 40 |
"progress_backend": "tqdm"
|
| 41 |
},
|
| 42 |
"metrics": {
|
|
|
|
| 57 |
},
|
| 58 |
"spec": {
|
| 59 |
"model": {
|
| 60 |
+
"name": "caum",
|
| 61 |
"architecture": {
|
| 62 |
"news_encoder": {
|
| 63 |
"type": "multi_head_self_attention",
|
| 64 |
"num_heads": 20,
|
| 65 |
+
"head_dim": 20,
|
| 66 |
+
"attention_hidden_dim": 200,
|
| 67 |
+
"entity_embedding_dim": 100,
|
| 68 |
+
"entity_num_heads": 4,
|
| 69 |
+
"entity_head_dim": 40,
|
| 70 |
+
"category_embedding_dim": 100
|
| 71 |
},
|
| 72 |
"user_encoder": {
|
| 73 |
+
"type": "candidate_aware",
|
| 74 |
+
"candi_selfatt": {
|
| 75 |
+
"num_heads": 20,
|
| 76 |
+
"head_dim": 20
|
| 77 |
+
},
|
| 78 |
+
"candi_cnn": {
|
| 79 |
+
"half_window": 1
|
| 80 |
+
},
|
| 81 |
+
"candi_att": {
|
| 82 |
+
"hidden_dim": 400,
|
| 83 |
+
"mid_dim": 256
|
| 84 |
+
}
|
| 85 |
},
|
| 86 |
"click_predictor": {
|
| 87 |
+
"type": "dot_product"
|
| 88 |
}
|
| 89 |
},
|
| 90 |
"embedding": {
|
| 91 |
"size": 300,
|
| 92 |
"trainable": true
|
| 93 |
},
|
| 94 |
+
"news_dim": 400,
|
| 95 |
+
"use_entity": true,
|
| 96 |
+
"use_category": true,
|
| 97 |
"dropout_rate": 0.2,
|
| 98 |
"seed": 42
|
| 99 |
},
|
| 100 |
"inputs": {
|
| 101 |
"title": {
|
| 102 |
+
"max_length": 30
|
| 103 |
},
|
| 104 |
"history": {
|
| 105 |
"max_length": 50
|
|
|
|
| 107 |
"impressions": {
|
| 108 |
"max_length": 5
|
| 109 |
},
|
| 110 |
+
"max_entities": 5,
|
| 111 |
"process_title": true,
|
| 112 |
"process_abstract": false,
|
| 113 |
+
"process_category": true,
|
| 114 |
"process_subcategory": false,
|
| 115 |
+
"process_user_id": false,
|
| 116 |
+
"process_entities": true
|
| 117 |
},
|
| 118 |
"training": {
|
| 119 |
"loss": {
|
|
|
|
| 123 |
"label_smoothing": 0.0
|
| 124 |
},
|
| 125 |
"optimizer": "adam",
|
| 126 |
+
"learning_rate": 5e-05,
|
| 127 |
+
"batch_size": 64,
|
| 128 |
"num_epochs": 20,
|
| 129 |
"gradient_clip_val": 1.0,
|
| 130 |
"grad_accum_steps": 1,
|
|
|
|
| 135 |
"negative_sampling": {
|
| 136 |
"strategy": "random",
|
| 137 |
"candidates": 4
|
| 138 |
+
}
|
|
|
|
| 139 |
},
|
| 140 |
"evaluation": {
|
| 141 |
"mode": "fast",
|
| 142 |
+
"evaluator": "caum",
|
| 143 |
"metrics": [
|
| 144 |
"auc",
|
| 145 |
"mrr",
|
|
|
|
| 182 |
"test": "https://huggingface.co/datasets/yjw1029/MIND/resolve/main/MINDlarge_test.zip"
|
| 183 |
}
|
| 184 |
},
|
| 185 |
+
"max_title_length": 30,
|
| 186 |
"max_abstract_length": 50,
|
| 187 |
"max_history_length": 50,
|
| 188 |
"max_impressions_length": 5,
|
|
|
|
| 218 |
},
|
| 219 |
"process_title": true,
|
| 220 |
"process_abstract": false,
|
| 221 |
+
"process_category": true,
|
| 222 |
"process_subcategory": false,
|
| 223 |
"process_user_id": false,
|
| 224 |
+
"process_entities": true
|
| 225 |
},
|
| 226 |
"sampling": {
|
| 227 |
"max_impressions_length": 5,
|
|
|
|
| 250 |
"popularity_metric": "clicks"
|
| 251 |
}
|
| 252 |
},
|
| 253 |
+
"name": "mind_caum",
|
| 254 |
+
"model_name": "CAUM",
|
| 255 |
+
"_output_run_dir": "outputs/train/MIND-small/CAUM/jax/seed_123"
|
| 256 |
},
|
| 257 |
"initial_validation_metrics": {},
|
| 258 |
"best_validation_summary": {
|
| 259 |
"epoch_number": 10.0,
|
| 260 |
+
"train_loss": 1.2371471108698318,
|
| 261 |
+
"average_metric_value": 0.5206724943439838,
|
| 262 |
+
"val_loss": 4.4820805537700945,
|
| 263 |
+
"val_auc": 0.7497600715485296,
|
| 264 |
+
"val_mrr": 0.39857799672194205,
|
| 265 |
+
"val_ndcg@5": 0.4387125845284695,
|
| 266 |
+
"val_ndcg@10": 0.4956393245769941,
|
| 267 |
"val_num_impressions": 7824.0,
|
| 268 |
"timing": {
|
| 269 |
"epoch_training_times": [
|
| 270 |
+
152.7604115009308,
|
| 271 |
+
128.3657763004303,
|
| 272 |
+
128.71619987487793,
|
| 273 |
+
126.83259797096252,
|
| 274 |
+
126.7269389629364,
|
| 275 |
+
127.33337998390198,
|
| 276 |
+
126.36471319198608,
|
| 277 |
+
127.29102325439453,
|
| 278 |
+
126.6120913028717,
|
| 279 |
+
126.97762560844421
|
| 280 |
],
|
| 281 |
"epoch_validation_times": [
|
| 282 |
+
156.19692087173462,
|
| 283 |
+
156.4893136024475,
|
| 284 |
+
156.90509462356567,
|
| 285 |
+
155.93345546722412,
|
| 286 |
+
156.1876072883606,
|
| 287 |
+
154.95236468315125,
|
| 288 |
+
157.0661187171936,
|
| 289 |
+
156.74578547477722,
|
| 290 |
+
156.08843064308167,
|
| 291 |
+
156.18353486061096
|
| 292 |
],
|
| 293 |
+
"total_training_time": 2860.9791276454926
|
| 294 |
}
|
| 295 |
},
|
| 296 |
"final_test_metrics": {
|
| 297 |
+
"loss": 4.895912475360167,
|
| 298 |
+
"auc": 0.6746637875073116,
|
| 299 |
+
"mrr": 0.3214330456909908,
|
| 300 |
+
"ndcg@5": 0.35560859266126893,
|
| 301 |
+
"ndcg@10": 0.4193156331068628,
|
| 302 |
"num_impressions": 72903.0
|
| 303 |
}
|
| 304 |
}
|
seed_42/model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46661584d959a93125cf9e06b08b854139a5e1b41f9fcaf59f431ddde61d5553
|
| 3 |
+
size 47322396
|
seed_42/test_results.json
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
{
|
| 2 |
-
"loss": 4.
|
| 3 |
-
"auc": 0.
|
| 4 |
-
"mrr": 0.
|
| 5 |
-
"ndcg@5": 0.
|
| 6 |
-
"ndcg@10": 0.
|
| 7 |
"num_impressions": 72903.0
|
| 8 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"loss": 4.885279593137135,
|
| 3 |
+
"auc": 0.6716255651925904,
|
| 4 |
+
"mrr": 0.31951080672544796,
|
| 5 |
+
"ndcg@5": 0.3533953887559558,
|
| 6 |
+
"ndcg@10": 0.417788739745737,
|
| 7 |
"num_impressions": 72903.0
|
| 8 |
}
|
seed_42/training_run_summary.json
CHANGED
|
@@ -14,9 +14,9 @@
|
|
| 14 |
},
|
| 15 |
"num_workers": 4,
|
| 16 |
"train": {
|
| 17 |
-
"batch_size":
|
| 18 |
"num_epochs": 20,
|
| 19 |
-
"learning_rate":
|
| 20 |
"gradient_clip_val": 1.0,
|
| 21 |
"grad_accum_steps": 1,
|
| 22 |
"early_stopping": {
|
|
@@ -35,8 +35,8 @@
|
|
| 35 |
"logging": {
|
| 36 |
"project_name": "NewsReX",
|
| 37 |
"enable_wandb": true,
|
| 38 |
-
"experiment_name": "jax/MIND-small/
|
| 39 |
-
"wandb_group": "jax/MIND-small/
|
| 40 |
"progress_backend": "tqdm"
|
| 41 |
},
|
| 42 |
"metrics": {
|
|
@@ -57,33 +57,49 @@
|
|
| 57 |
},
|
| 58 |
"spec": {
|
| 59 |
"model": {
|
| 60 |
-
"name": "
|
| 61 |
"architecture": {
|
| 62 |
"news_encoder": {
|
| 63 |
"type": "multi_head_self_attention",
|
| 64 |
"num_heads": 20,
|
| 65 |
-
"head_dim":
|
| 66 |
-
"attention_hidden_dim": 200
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
},
|
| 68 |
"user_encoder": {
|
| 69 |
-
"type": "
|
| 70 |
-
"
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
},
|
| 73 |
"click_predictor": {
|
| 74 |
-
"type": "
|
| 75 |
}
|
| 76 |
},
|
| 77 |
"embedding": {
|
| 78 |
"size": 300,
|
| 79 |
"trainable": true
|
| 80 |
},
|
|
|
|
|
|
|
|
|
|
| 81 |
"dropout_rate": 0.2,
|
| 82 |
"seed": 42
|
| 83 |
},
|
| 84 |
"inputs": {
|
| 85 |
"title": {
|
| 86 |
-
"max_length":
|
| 87 |
},
|
| 88 |
"history": {
|
| 89 |
"max_length": 50
|
|
@@ -91,11 +107,13 @@
|
|
| 91 |
"impressions": {
|
| 92 |
"max_length": 5
|
| 93 |
},
|
|
|
|
| 94 |
"process_title": true,
|
| 95 |
"process_abstract": false,
|
| 96 |
-
"process_category":
|
| 97 |
"process_subcategory": false,
|
| 98 |
-
"process_user_id": false
|
|
|
|
| 99 |
},
|
| 100 |
"training": {
|
| 101 |
"loss": {
|
|
@@ -105,8 +123,8 @@
|
|
| 105 |
"label_smoothing": 0.0
|
| 106 |
},
|
| 107 |
"optimizer": "adam",
|
| 108 |
-
"learning_rate":
|
| 109 |
-
"batch_size":
|
| 110 |
"num_epochs": 20,
|
| 111 |
"gradient_clip_val": 1.0,
|
| 112 |
"grad_accum_steps": 1,
|
|
@@ -117,12 +135,11 @@
|
|
| 117 |
"negative_sampling": {
|
| 118 |
"strategy": "random",
|
| 119 |
"candidates": 4
|
| 120 |
-
}
|
| 121 |
-
"disagreement_beta": 0.0
|
| 122 |
},
|
| 123 |
"evaluation": {
|
| 124 |
"mode": "fast",
|
| 125 |
-
"evaluator": "
|
| 126 |
"metrics": [
|
| 127 |
"auc",
|
| 128 |
"mrr",
|
|
@@ -165,7 +182,7 @@
|
|
| 165 |
"test": "https://huggingface.co/datasets/yjw1029/MIND/resolve/main/MINDlarge_test.zip"
|
| 166 |
}
|
| 167 |
},
|
| 168 |
-
"max_title_length":
|
| 169 |
"max_abstract_length": 50,
|
| 170 |
"max_history_length": 50,
|
| 171 |
"max_impressions_length": 5,
|
|
@@ -201,10 +218,10 @@
|
|
| 201 |
},
|
| 202 |
"process_title": true,
|
| 203 |
"process_abstract": false,
|
| 204 |
-
"process_category":
|
| 205 |
"process_subcategory": false,
|
| 206 |
"process_user_id": false,
|
| 207 |
-
"process_entities":
|
| 208 |
},
|
| 209 |
"sampling": {
|
| 210 |
"max_impressions_length": 5,
|
|
@@ -233,53 +250,51 @@
|
|
| 233 |
"popularity_metric": "clicks"
|
| 234 |
}
|
| 235 |
},
|
| 236 |
-
"name": "
|
| 237 |
-
"model_name": "
|
| 238 |
-
"_output_run_dir": "outputs/train/MIND-small/
|
| 239 |
},
|
| 240 |
"initial_validation_metrics": {},
|
| 241 |
"best_validation_summary": {
|
| 242 |
-
"epoch_number":
|
| 243 |
-
"train_loss": 1.
|
| 244 |
-
"average_metric_value": 0.
|
| 245 |
-
"val_loss": 4.
|
| 246 |
-
"val_auc": 0.
|
| 247 |
-
"val_mrr": 0.
|
| 248 |
-
"val_ndcg@5": 0.
|
| 249 |
-
"val_ndcg@10": 0.
|
| 250 |
"val_num_impressions": 7824.0,
|
| 251 |
"timing": {
|
| 252 |
"epoch_training_times": [
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
62.84117293357849
|
| 262 |
],
|
| 263 |
"epoch_validation_times": [
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
11.277428150177002
|
| 273 |
],
|
| 274 |
-
"total_training_time":
|
| 275 |
}
|
| 276 |
},
|
| 277 |
"final_test_metrics": {
|
| 278 |
-
"loss": 4.
|
| 279 |
-
"auc": 0.
|
| 280 |
-
"mrr": 0.
|
| 281 |
-
"ndcg@5": 0.
|
| 282 |
-
"ndcg@10": 0.
|
| 283 |
"num_impressions": 72903.0
|
| 284 |
}
|
| 285 |
}
|
|
|
|
| 14 |
},
|
| 15 |
"num_workers": 4,
|
| 16 |
"train": {
|
| 17 |
+
"batch_size": 64,
|
| 18 |
"num_epochs": 20,
|
| 19 |
+
"learning_rate": 5e-05,
|
| 20 |
"gradient_clip_val": 1.0,
|
| 21 |
"grad_accum_steps": 1,
|
| 22 |
"early_stopping": {
|
|
|
|
| 35 |
"logging": {
|
| 36 |
"project_name": "NewsReX",
|
| 37 |
"enable_wandb": true,
|
| 38 |
+
"experiment_name": "jax/MIND-small/CAUM",
|
| 39 |
+
"wandb_group": "jax/MIND-small/CAUM",
|
| 40 |
"progress_backend": "tqdm"
|
| 41 |
},
|
| 42 |
"metrics": {
|
|
|
|
| 57 |
},
|
| 58 |
"spec": {
|
| 59 |
"model": {
|
| 60 |
+
"name": "caum",
|
| 61 |
"architecture": {
|
| 62 |
"news_encoder": {
|
| 63 |
"type": "multi_head_self_attention",
|
| 64 |
"num_heads": 20,
|
| 65 |
+
"head_dim": 20,
|
| 66 |
+
"attention_hidden_dim": 200,
|
| 67 |
+
"entity_embedding_dim": 100,
|
| 68 |
+
"entity_num_heads": 4,
|
| 69 |
+
"entity_head_dim": 40,
|
| 70 |
+
"category_embedding_dim": 100
|
| 71 |
},
|
| 72 |
"user_encoder": {
|
| 73 |
+
"type": "candidate_aware",
|
| 74 |
+
"candi_selfatt": {
|
| 75 |
+
"num_heads": 20,
|
| 76 |
+
"head_dim": 20
|
| 77 |
+
},
|
| 78 |
+
"candi_cnn": {
|
| 79 |
+
"half_window": 1
|
| 80 |
+
},
|
| 81 |
+
"candi_att": {
|
| 82 |
+
"hidden_dim": 400,
|
| 83 |
+
"mid_dim": 256
|
| 84 |
+
}
|
| 85 |
},
|
| 86 |
"click_predictor": {
|
| 87 |
+
"type": "dot_product"
|
| 88 |
}
|
| 89 |
},
|
| 90 |
"embedding": {
|
| 91 |
"size": 300,
|
| 92 |
"trainable": true
|
| 93 |
},
|
| 94 |
+
"news_dim": 400,
|
| 95 |
+
"use_entity": true,
|
| 96 |
+
"use_category": true,
|
| 97 |
"dropout_rate": 0.2,
|
| 98 |
"seed": 42
|
| 99 |
},
|
| 100 |
"inputs": {
|
| 101 |
"title": {
|
| 102 |
+
"max_length": 30
|
| 103 |
},
|
| 104 |
"history": {
|
| 105 |
"max_length": 50
|
|
|
|
| 107 |
"impressions": {
|
| 108 |
"max_length": 5
|
| 109 |
},
|
| 110 |
+
"max_entities": 5,
|
| 111 |
"process_title": true,
|
| 112 |
"process_abstract": false,
|
| 113 |
+
"process_category": true,
|
| 114 |
"process_subcategory": false,
|
| 115 |
+
"process_user_id": false,
|
| 116 |
+
"process_entities": true
|
| 117 |
},
|
| 118 |
"training": {
|
| 119 |
"loss": {
|
|
|
|
| 123 |
"label_smoothing": 0.0
|
| 124 |
},
|
| 125 |
"optimizer": "adam",
|
| 126 |
+
"learning_rate": 5e-05,
|
| 127 |
+
"batch_size": 64,
|
| 128 |
"num_epochs": 20,
|
| 129 |
"gradient_clip_val": 1.0,
|
| 130 |
"grad_accum_steps": 1,
|
|
|
|
| 135 |
"negative_sampling": {
|
| 136 |
"strategy": "random",
|
| 137 |
"candidates": 4
|
| 138 |
+
}
|
|
|
|
| 139 |
},
|
| 140 |
"evaluation": {
|
| 141 |
"mode": "fast",
|
| 142 |
+
"evaluator": "caum",
|
| 143 |
"metrics": [
|
| 144 |
"auc",
|
| 145 |
"mrr",
|
|
|
|
| 182 |
"test": "https://huggingface.co/datasets/yjw1029/MIND/resolve/main/MINDlarge_test.zip"
|
| 183 |
}
|
| 184 |
},
|
| 185 |
+
"max_title_length": 30,
|
| 186 |
"max_abstract_length": 50,
|
| 187 |
"max_history_length": 50,
|
| 188 |
"max_impressions_length": 5,
|
|
|
|
| 218 |
},
|
| 219 |
"process_title": true,
|
| 220 |
"process_abstract": false,
|
| 221 |
+
"process_category": true,
|
| 222 |
"process_subcategory": false,
|
| 223 |
"process_user_id": false,
|
| 224 |
+
"process_entities": true
|
| 225 |
},
|
| 226 |
"sampling": {
|
| 227 |
"max_impressions_length": 5,
|
|
|
|
| 250 |
"popularity_metric": "clicks"
|
| 251 |
}
|
| 252 |
},
|
| 253 |
+
"name": "mind_caum",
|
| 254 |
+
"model_name": "CAUM",
|
| 255 |
+
"_output_run_dir": "outputs/train/MIND-small/CAUM/jax/seed_42"
|
| 256 |
},
|
| 257 |
"initial_validation_metrics": {},
|
| 258 |
"best_validation_summary": {
|
| 259 |
+
"epoch_number": 8.0,
|
| 260 |
+
"train_loss": 1.2569683923193469,
|
| 261 |
+
"average_metric_value": 0.5173871047050278,
|
| 262 |
+
"val_loss": 4.491243967664317,
|
| 263 |
+
"val_auc": 0.7467889007717999,
|
| 264 |
+
"val_mrr": 0.3941287371779766,
|
| 265 |
+
"val_ndcg@5": 0.43629077698141816,
|
| 266 |
+
"val_ndcg@10": 0.4923400038889164,
|
| 267 |
"val_num_impressions": 7824.0,
|
| 268 |
"timing": {
|
| 269 |
"epoch_training_times": [
|
| 270 |
+
163.71576189994812,
|
| 271 |
+
128.66470193862915,
|
| 272 |
+
124.40946054458618,
|
| 273 |
+
122.63417339324951,
|
| 274 |
+
123.02388978004456,
|
| 275 |
+
124.40860843658447,
|
| 276 |
+
123.64945530891418,
|
| 277 |
+
124.03472113609314
|
|
|
|
| 278 |
],
|
| 279 |
"epoch_validation_times": [
|
| 280 |
+
225.55031847953796,
|
| 281 |
+
155.1456437110901,
|
| 282 |
+
154.2360875606537,
|
| 283 |
+
154.55679368972778,
|
| 284 |
+
155.49029302597046,
|
| 285 |
+
155.03285694122314,
|
| 286 |
+
156.5079698562622,
|
| 287 |
+
172.0624237060547
|
|
|
|
| 288 |
],
|
| 289 |
+
"total_training_time": 2363.861449956894
|
| 290 |
}
|
| 291 |
},
|
| 292 |
"final_test_metrics": {
|
| 293 |
+
"loss": 4.885279593137135,
|
| 294 |
+
"auc": 0.6716255651925904,
|
| 295 |
+
"mrr": 0.31951080672544796,
|
| 296 |
+
"ndcg@5": 0.3533953887559558,
|
| 297 |
+
"ndcg@10": 0.417788739745737,
|
| 298 |
"num_impressions": 72903.0
|
| 299 |
}
|
| 300 |
}
|
seed_456/model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cd4d9a97441772d74ccab684c954e1171a9ecf441fe34670675fde572bea7cd
|
| 3 |
+
size 47322396
|
seed_456/test_results.json
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
{
|
| 2 |
-
"loss": 4.
|
| 3 |
-
"auc": 0.
|
| 4 |
-
"mrr": 0.
|
| 5 |
-
"ndcg@5": 0.
|
| 6 |
-
"ndcg@10": 0.
|
| 7 |
"num_impressions": 72903.0
|
| 8 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"loss": 4.950947674984418,
|
| 3 |
+
"auc": 0.6738278414537523,
|
| 4 |
+
"mrr": 0.3197613372828518,
|
| 5 |
+
"ndcg@5": 0.35466981589059265,
|
| 6 |
+
"ndcg@10": 0.4184225692327627,
|
| 7 |
"num_impressions": 72903.0
|
| 8 |
}
|
seed_456/training_run_summary.json
CHANGED
|
@@ -14,9 +14,9 @@
|
|
| 14 |
},
|
| 15 |
"num_workers": 4,
|
| 16 |
"train": {
|
| 17 |
-
"batch_size":
|
| 18 |
"num_epochs": 20,
|
| 19 |
-
"learning_rate":
|
| 20 |
"gradient_clip_val": 1.0,
|
| 21 |
"grad_accum_steps": 1,
|
| 22 |
"early_stopping": {
|
|
@@ -35,8 +35,8 @@
|
|
| 35 |
"logging": {
|
| 36 |
"project_name": "NewsReX",
|
| 37 |
"enable_wandb": true,
|
| 38 |
-
"experiment_name": "jax/MIND-small/
|
| 39 |
-
"wandb_group": "jax/MIND-small/
|
| 40 |
"progress_backend": "tqdm"
|
| 41 |
},
|
| 42 |
"metrics": {
|
|
@@ -57,33 +57,49 @@
|
|
| 57 |
},
|
| 58 |
"spec": {
|
| 59 |
"model": {
|
| 60 |
-
"name": "
|
| 61 |
"architecture": {
|
| 62 |
"news_encoder": {
|
| 63 |
"type": "multi_head_self_attention",
|
| 64 |
"num_heads": 20,
|
| 65 |
-
"head_dim":
|
| 66 |
-
"attention_hidden_dim": 200
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
},
|
| 68 |
"user_encoder": {
|
| 69 |
-
"type": "
|
| 70 |
-
"
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
},
|
| 73 |
"click_predictor": {
|
| 74 |
-
"type": "
|
| 75 |
}
|
| 76 |
},
|
| 77 |
"embedding": {
|
| 78 |
"size": 300,
|
| 79 |
"trainable": true
|
| 80 |
},
|
|
|
|
|
|
|
|
|
|
| 81 |
"dropout_rate": 0.2,
|
| 82 |
"seed": 42
|
| 83 |
},
|
| 84 |
"inputs": {
|
| 85 |
"title": {
|
| 86 |
-
"max_length":
|
| 87 |
},
|
| 88 |
"history": {
|
| 89 |
"max_length": 50
|
|
@@ -91,11 +107,13 @@
|
|
| 91 |
"impressions": {
|
| 92 |
"max_length": 5
|
| 93 |
},
|
|
|
|
| 94 |
"process_title": true,
|
| 95 |
"process_abstract": false,
|
| 96 |
-
"process_category":
|
| 97 |
"process_subcategory": false,
|
| 98 |
-
"process_user_id": false
|
|
|
|
| 99 |
},
|
| 100 |
"training": {
|
| 101 |
"loss": {
|
|
@@ -105,8 +123,8 @@
|
|
| 105 |
"label_smoothing": 0.0
|
| 106 |
},
|
| 107 |
"optimizer": "adam",
|
| 108 |
-
"learning_rate":
|
| 109 |
-
"batch_size":
|
| 110 |
"num_epochs": 20,
|
| 111 |
"gradient_clip_val": 1.0,
|
| 112 |
"grad_accum_steps": 1,
|
|
@@ -117,12 +135,11 @@
|
|
| 117 |
"negative_sampling": {
|
| 118 |
"strategy": "random",
|
| 119 |
"candidates": 4
|
| 120 |
-
}
|
| 121 |
-
"disagreement_beta": 0.0
|
| 122 |
},
|
| 123 |
"evaluation": {
|
| 124 |
"mode": "fast",
|
| 125 |
-
"evaluator": "
|
| 126 |
"metrics": [
|
| 127 |
"auc",
|
| 128 |
"mrr",
|
|
@@ -165,7 +182,7 @@
|
|
| 165 |
"test": "https://huggingface.co/datasets/yjw1029/MIND/resolve/main/MINDlarge_test.zip"
|
| 166 |
}
|
| 167 |
},
|
| 168 |
-
"max_title_length":
|
| 169 |
"max_abstract_length": 50,
|
| 170 |
"max_history_length": 50,
|
| 171 |
"max_impressions_length": 5,
|
|
@@ -201,10 +218,10 @@
|
|
| 201 |
},
|
| 202 |
"process_title": true,
|
| 203 |
"process_abstract": false,
|
| 204 |
-
"process_category":
|
| 205 |
"process_subcategory": false,
|
| 206 |
"process_user_id": false,
|
| 207 |
-
"process_entities":
|
| 208 |
},
|
| 209 |
"sampling": {
|
| 210 |
"max_impressions_length": 5,
|
|
@@ -233,55 +250,61 @@
|
|
| 233 |
"popularity_metric": "clicks"
|
| 234 |
}
|
| 235 |
},
|
| 236 |
-
"name": "
|
| 237 |
-
"model_name": "
|
| 238 |
-
"_output_run_dir": "outputs/train/MIND-small/
|
| 239 |
},
|
| 240 |
"initial_validation_metrics": {},
|
| 241 |
"best_validation_summary": {
|
| 242 |
-
"epoch_number":
|
| 243 |
-
"train_loss": 1.
|
| 244 |
-
"average_metric_value": 0.
|
| 245 |
-
"val_loss": 4.
|
| 246 |
-
"val_auc": 0.
|
| 247 |
-
"val_mrr": 0.
|
| 248 |
-
"val_ndcg@5": 0.
|
| 249 |
-
"val_ndcg@10": 0.
|
| 250 |
"val_num_impressions": 7824.0,
|
| 251 |
"timing": {
|
| 252 |
"epoch_training_times": [
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
| 263 |
],
|
| 264 |
"epoch_validation_times": [
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
| 275 |
],
|
| 276 |
-
"total_training_time":
|
| 277 |
}
|
| 278 |
},
|
| 279 |
"final_test_metrics": {
|
| 280 |
-
"loss": 4.
|
| 281 |
-
"auc": 0.
|
| 282 |
-
"mrr": 0.
|
| 283 |
-
"ndcg@5": 0.
|
| 284 |
-
"ndcg@10": 0.
|
| 285 |
"num_impressions": 72903.0
|
| 286 |
}
|
| 287 |
}
|
|
|
|
| 14 |
},
|
| 15 |
"num_workers": 4,
|
| 16 |
"train": {
|
| 17 |
+
"batch_size": 64,
|
| 18 |
"num_epochs": 20,
|
| 19 |
+
"learning_rate": 5e-05,
|
| 20 |
"gradient_clip_val": 1.0,
|
| 21 |
"grad_accum_steps": 1,
|
| 22 |
"early_stopping": {
|
|
|
|
| 35 |
"logging": {
|
| 36 |
"project_name": "NewsReX",
|
| 37 |
"enable_wandb": true,
|
| 38 |
+
"experiment_name": "jax/MIND-small/CAUM",
|
| 39 |
+
"wandb_group": "jax/MIND-small/CAUM",
|
| 40 |
"progress_backend": "tqdm"
|
| 41 |
},
|
| 42 |
"metrics": {
|
|
|
|
| 57 |
},
|
| 58 |
"spec": {
|
| 59 |
"model": {
|
| 60 |
+
"name": "caum",
|
| 61 |
"architecture": {
|
| 62 |
"news_encoder": {
|
| 63 |
"type": "multi_head_self_attention",
|
| 64 |
"num_heads": 20,
|
| 65 |
+
"head_dim": 20,
|
| 66 |
+
"attention_hidden_dim": 200,
|
| 67 |
+
"entity_embedding_dim": 100,
|
| 68 |
+
"entity_num_heads": 4,
|
| 69 |
+
"entity_head_dim": 40,
|
| 70 |
+
"category_embedding_dim": 100
|
| 71 |
},
|
| 72 |
"user_encoder": {
|
| 73 |
+
"type": "candidate_aware",
|
| 74 |
+
"candi_selfatt": {
|
| 75 |
+
"num_heads": 20,
|
| 76 |
+
"head_dim": 20
|
| 77 |
+
},
|
| 78 |
+
"candi_cnn": {
|
| 79 |
+
"half_window": 1
|
| 80 |
+
},
|
| 81 |
+
"candi_att": {
|
| 82 |
+
"hidden_dim": 400,
|
| 83 |
+
"mid_dim": 256
|
| 84 |
+
}
|
| 85 |
},
|
| 86 |
"click_predictor": {
|
| 87 |
+
"type": "dot_product"
|
| 88 |
}
|
| 89 |
},
|
| 90 |
"embedding": {
|
| 91 |
"size": 300,
|
| 92 |
"trainable": true
|
| 93 |
},
|
| 94 |
+
"news_dim": 400,
|
| 95 |
+
"use_entity": true,
|
| 96 |
+
"use_category": true,
|
| 97 |
"dropout_rate": 0.2,
|
| 98 |
"seed": 42
|
| 99 |
},
|
| 100 |
"inputs": {
|
| 101 |
"title": {
|
| 102 |
+
"max_length": 30
|
| 103 |
},
|
| 104 |
"history": {
|
| 105 |
"max_length": 50
|
|
|
|
| 107 |
"impressions": {
|
| 108 |
"max_length": 5
|
| 109 |
},
|
| 110 |
+
"max_entities": 5,
|
| 111 |
"process_title": true,
|
| 112 |
"process_abstract": false,
|
| 113 |
+
"process_category": true,
|
| 114 |
"process_subcategory": false,
|
| 115 |
+
"process_user_id": false,
|
| 116 |
+
"process_entities": true
|
| 117 |
},
|
| 118 |
"training": {
|
| 119 |
"loss": {
|
|
|
|
| 123 |
"label_smoothing": 0.0
|
| 124 |
},
|
| 125 |
"optimizer": "adam",
|
| 126 |
+
"learning_rate": 5e-05,
|
| 127 |
+
"batch_size": 64,
|
| 128 |
"num_epochs": 20,
|
| 129 |
"gradient_clip_val": 1.0,
|
| 130 |
"grad_accum_steps": 1,
|
|
|
|
| 135 |
"negative_sampling": {
|
| 136 |
"strategy": "random",
|
| 137 |
"candidates": 4
|
| 138 |
+
}
|
|
|
|
| 139 |
},
|
| 140 |
"evaluation": {
|
| 141 |
"mode": "fast",
|
| 142 |
+
"evaluator": "caum",
|
| 143 |
"metrics": [
|
| 144 |
"auc",
|
| 145 |
"mrr",
|
|
|
|
| 182 |
"test": "https://huggingface.co/datasets/yjw1029/MIND/resolve/main/MINDlarge_test.zip"
|
| 183 |
}
|
| 184 |
},
|
| 185 |
+
"max_title_length": 30,
|
| 186 |
"max_abstract_length": 50,
|
| 187 |
"max_history_length": 50,
|
| 188 |
"max_impressions_length": 5,
|
|
|
|
| 218 |
},
|
| 219 |
"process_title": true,
|
| 220 |
"process_abstract": false,
|
| 221 |
+
"process_category": true,
|
| 222 |
"process_subcategory": false,
|
| 223 |
"process_user_id": false,
|
| 224 |
+
"process_entities": true
|
| 225 |
},
|
| 226 |
"sampling": {
|
| 227 |
"max_impressions_length": 5,
|
|
|
|
| 250 |
"popularity_metric": "clicks"
|
| 251 |
}
|
| 252 |
},
|
| 253 |
+
"name": "mind_caum",
|
| 254 |
+
"model_name": "CAUM",
|
| 255 |
+
"_output_run_dir": "outputs/train/MIND-small/CAUM/jax/seed_456"
|
| 256 |
},
|
| 257 |
"initial_validation_metrics": {},
|
| 258 |
"best_validation_summary": {
|
| 259 |
+
"epoch_number": 13.0,
|
| 260 |
+
"train_loss": 1.2085938783405776,
|
| 261 |
+
"average_metric_value": 0.5244346178265784,
|
| 262 |
+
"val_loss": 4.477954248841599,
|
| 263 |
+
"val_auc": 0.7521172050673615,
|
| 264 |
+
"val_mrr": 0.40180118932638526,
|
| 265 |
+
"val_ndcg@5": 0.44354011060889836,
|
| 266 |
+
"val_ndcg@10": 0.5002799663036684,
|
| 267 |
"val_num_impressions": 7824.0,
|
| 268 |
"timing": {
|
| 269 |
"epoch_training_times": [
|
| 270 |
+
152.42811727523804,
|
| 271 |
+
125.81604671478271,
|
| 272 |
+
127.34339237213135,
|
| 273 |
+
127.5574357509613,
|
| 274 |
+
127.87119317054749,
|
| 275 |
+
129.09366917610168,
|
| 276 |
+
125.57747483253479,
|
| 277 |
+
126.41677451133728,
|
| 278 |
+
131.13811135292053,
|
| 279 |
+
133.3327136039734,
|
| 280 |
+
128.3407974243164,
|
| 281 |
+
126.74892687797546,
|
| 282 |
+
127.14060878753662
|
| 283 |
],
|
| 284 |
"epoch_validation_times": [
|
| 285 |
+
157.19076895713806,
|
| 286 |
+
158.20304155349731,
|
| 287 |
+
155.41660618782043,
|
| 288 |
+
156.2762279510498,
|
| 289 |
+
158.21885561943054,
|
| 290 |
+
157.79540181159973,
|
| 291 |
+
157.7053678035736,
|
| 292 |
+
156.86289072036743,
|
| 293 |
+
196.53685116767883,
|
| 294 |
+
156.58508348464966,
|
| 295 |
+
156.7785358428955,
|
| 296 |
+
155.07468962669373,
|
| 297 |
+
156.94045853614807
|
| 298 |
],
|
| 299 |
+
"total_training_time": 3768.693561077118
|
| 300 |
}
|
| 301 |
},
|
| 302 |
"final_test_metrics": {
|
| 303 |
+
"loss": 4.950947674984418,
|
| 304 |
+
"auc": 0.6738278414537523,
|
| 305 |
+
"mrr": 0.3197613372828518,
|
| 306 |
+
"ndcg@5": 0.35466981589059265,
|
| 307 |
+
"ndcg@10": 0.4184225692327627,
|
| 308 |
"num_impressions": 72903.0
|
| 309 |
}
|
| 310 |
}
|
test_results.json
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
{
|
| 2 |
-
"loss": 4.
|
| 3 |
-
"auc": 0.
|
| 4 |
-
"mrr": 0.
|
| 5 |
-
"ndcg@5": 0.
|
| 6 |
-
"ndcg@10": 0.
|
| 7 |
"num_impressions": 72903.0
|
| 8 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"loss": 4.895912475360167,
|
| 3 |
+
"auc": 0.6746637875073116,
|
| 4 |
+
"mrr": 0.3214330456909908,
|
| 5 |
+
"ndcg@5": 0.35560859266126893,
|
| 6 |
+
"ndcg@10": 0.4193156331068628,
|
| 7 |
"num_impressions": 72903.0
|
| 8 |
}
|
training_run_summary.json
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
"configuration": {
|
| 3 |
"framework": "jax",
|
| 4 |
"weights": null,
|
| 5 |
-
"seed":
|
| 6 |
"output_base_dir": "outputs",
|
| 7 |
"device": {
|
| 8 |
"gpu_ids": [
|
|
@@ -14,9 +14,9 @@
|
|
| 14 |
},
|
| 15 |
"num_workers": 4,
|
| 16 |
"train": {
|
| 17 |
-
"batch_size":
|
| 18 |
"num_epochs": 20,
|
| 19 |
-
"learning_rate":
|
| 20 |
"gradient_clip_val": 1.0,
|
| 21 |
"grad_accum_steps": 1,
|
| 22 |
"early_stopping": {
|
|
@@ -35,8 +35,8 @@
|
|
| 35 |
"logging": {
|
| 36 |
"project_name": "NewsReX",
|
| 37 |
"enable_wandb": true,
|
| 38 |
-
"experiment_name": "jax/MIND-small/
|
| 39 |
-
"wandb_group": "jax/MIND-small/
|
| 40 |
"progress_backend": "tqdm"
|
| 41 |
},
|
| 42 |
"metrics": {
|
|
@@ -57,33 +57,49 @@
|
|
| 57 |
},
|
| 58 |
"spec": {
|
| 59 |
"model": {
|
| 60 |
-
"name": "
|
| 61 |
"architecture": {
|
| 62 |
"news_encoder": {
|
| 63 |
"type": "multi_head_self_attention",
|
| 64 |
"num_heads": 20,
|
| 65 |
-
"head_dim":
|
| 66 |
-
"attention_hidden_dim": 200
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
},
|
| 68 |
"user_encoder": {
|
| 69 |
-
"type": "
|
| 70 |
-
"
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
},
|
| 73 |
"click_predictor": {
|
| 74 |
-
"type": "
|
| 75 |
}
|
| 76 |
},
|
| 77 |
"embedding": {
|
| 78 |
"size": 300,
|
| 79 |
"trainable": true
|
| 80 |
},
|
|
|
|
|
|
|
|
|
|
| 81 |
"dropout_rate": 0.2,
|
| 82 |
"seed": 42
|
| 83 |
},
|
| 84 |
"inputs": {
|
| 85 |
"title": {
|
| 86 |
-
"max_length":
|
| 87 |
},
|
| 88 |
"history": {
|
| 89 |
"max_length": 50
|
|
@@ -91,11 +107,13 @@
|
|
| 91 |
"impressions": {
|
| 92 |
"max_length": 5
|
| 93 |
},
|
|
|
|
| 94 |
"process_title": true,
|
| 95 |
"process_abstract": false,
|
| 96 |
-
"process_category":
|
| 97 |
"process_subcategory": false,
|
| 98 |
-
"process_user_id": false
|
|
|
|
| 99 |
},
|
| 100 |
"training": {
|
| 101 |
"loss": {
|
|
@@ -105,8 +123,8 @@
|
|
| 105 |
"label_smoothing": 0.0
|
| 106 |
},
|
| 107 |
"optimizer": "adam",
|
| 108 |
-
"learning_rate":
|
| 109 |
-
"batch_size":
|
| 110 |
"num_epochs": 20,
|
| 111 |
"gradient_clip_val": 1.0,
|
| 112 |
"grad_accum_steps": 1,
|
|
@@ -117,12 +135,11 @@
|
|
| 117 |
"negative_sampling": {
|
| 118 |
"strategy": "random",
|
| 119 |
"candidates": 4
|
| 120 |
-
}
|
| 121 |
-
"disagreement_beta": 0.0
|
| 122 |
},
|
| 123 |
"evaluation": {
|
| 124 |
"mode": "fast",
|
| 125 |
-
"evaluator": "
|
| 126 |
"metrics": [
|
| 127 |
"auc",
|
| 128 |
"mrr",
|
|
@@ -165,7 +182,7 @@
|
|
| 165 |
"test": "https://huggingface.co/datasets/yjw1029/MIND/resolve/main/MINDlarge_test.zip"
|
| 166 |
}
|
| 167 |
},
|
| 168 |
-
"max_title_length":
|
| 169 |
"max_abstract_length": 50,
|
| 170 |
"max_history_length": 50,
|
| 171 |
"max_impressions_length": 5,
|
|
@@ -201,10 +218,10 @@
|
|
| 201 |
},
|
| 202 |
"process_title": true,
|
| 203 |
"process_abstract": false,
|
| 204 |
-
"process_category":
|
| 205 |
"process_subcategory": false,
|
| 206 |
"process_user_id": false,
|
| 207 |
-
"process_entities":
|
| 208 |
},
|
| 209 |
"sampling": {
|
| 210 |
"max_impressions_length": 5,
|
|
@@ -233,55 +250,55 @@
|
|
| 233 |
"popularity_metric": "clicks"
|
| 234 |
}
|
| 235 |
},
|
| 236 |
-
"name": "
|
| 237 |
-
"model_name": "
|
| 238 |
-
"_output_run_dir": "outputs/train/MIND-small/
|
| 239 |
},
|
| 240 |
"initial_validation_metrics": {},
|
| 241 |
"best_validation_summary": {
|
| 242 |
"epoch_number": 10.0,
|
| 243 |
-
"train_loss": 1.
|
| 244 |
-
"average_metric_value": 0.
|
| 245 |
-
"val_loss": 4.
|
| 246 |
-
"val_auc": 0.
|
| 247 |
-
"val_mrr": 0.
|
| 248 |
-
"val_ndcg@5": 0.
|
| 249 |
-
"val_ndcg@10": 0.
|
| 250 |
"val_num_impressions": 7824.0,
|
| 251 |
"timing": {
|
| 252 |
"epoch_training_times": [
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
],
|
| 264 |
"epoch_validation_times": [
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
],
|
| 276 |
-
"total_training_time":
|
| 277 |
}
|
| 278 |
},
|
| 279 |
"final_test_metrics": {
|
| 280 |
-
"loss": 4.
|
| 281 |
-
"auc": 0.
|
| 282 |
-
"mrr": 0.
|
| 283 |
-
"ndcg@5": 0.
|
| 284 |
-
"ndcg@10": 0.
|
| 285 |
"num_impressions": 72903.0
|
| 286 |
}
|
| 287 |
}
|
|
|
|
| 2 |
"configuration": {
|
| 3 |
"framework": "jax",
|
| 4 |
"weights": null,
|
| 5 |
+
"seed": 123,
|
| 6 |
"output_base_dir": "outputs",
|
| 7 |
"device": {
|
| 8 |
"gpu_ids": [
|
|
|
|
| 14 |
},
|
| 15 |
"num_workers": 4,
|
| 16 |
"train": {
|
| 17 |
+
"batch_size": 64,
|
| 18 |
"num_epochs": 20,
|
| 19 |
+
"learning_rate": 5e-05,
|
| 20 |
"gradient_clip_val": 1.0,
|
| 21 |
"grad_accum_steps": 1,
|
| 22 |
"early_stopping": {
|
|
|
|
| 35 |
"logging": {
|
| 36 |
"project_name": "NewsReX",
|
| 37 |
"enable_wandb": true,
|
| 38 |
+
"experiment_name": "jax/MIND-small/CAUM",
|
| 39 |
+
"wandb_group": "jax/MIND-small/CAUM",
|
| 40 |
"progress_backend": "tqdm"
|
| 41 |
},
|
| 42 |
"metrics": {
|
|
|
|
| 57 |
},
|
| 58 |
"spec": {
|
| 59 |
"model": {
|
| 60 |
+
"name": "caum",
|
| 61 |
"architecture": {
|
| 62 |
"news_encoder": {
|
| 63 |
"type": "multi_head_self_attention",
|
| 64 |
"num_heads": 20,
|
| 65 |
+
"head_dim": 20,
|
| 66 |
+
"attention_hidden_dim": 200,
|
| 67 |
+
"entity_embedding_dim": 100,
|
| 68 |
+
"entity_num_heads": 4,
|
| 69 |
+
"entity_head_dim": 40,
|
| 70 |
+
"category_embedding_dim": 100
|
| 71 |
},
|
| 72 |
"user_encoder": {
|
| 73 |
+
"type": "candidate_aware",
|
| 74 |
+
"candi_selfatt": {
|
| 75 |
+
"num_heads": 20,
|
| 76 |
+
"head_dim": 20
|
| 77 |
+
},
|
| 78 |
+
"candi_cnn": {
|
| 79 |
+
"half_window": 1
|
| 80 |
+
},
|
| 81 |
+
"candi_att": {
|
| 82 |
+
"hidden_dim": 400,
|
| 83 |
+
"mid_dim": 256
|
| 84 |
+
}
|
| 85 |
},
|
| 86 |
"click_predictor": {
|
| 87 |
+
"type": "dot_product"
|
| 88 |
}
|
| 89 |
},
|
| 90 |
"embedding": {
|
| 91 |
"size": 300,
|
| 92 |
"trainable": true
|
| 93 |
},
|
| 94 |
+
"news_dim": 400,
|
| 95 |
+
"use_entity": true,
|
| 96 |
+
"use_category": true,
|
| 97 |
"dropout_rate": 0.2,
|
| 98 |
"seed": 42
|
| 99 |
},
|
| 100 |
"inputs": {
|
| 101 |
"title": {
|
| 102 |
+
"max_length": 30
|
| 103 |
},
|
| 104 |
"history": {
|
| 105 |
"max_length": 50
|
|
|
|
| 107 |
"impressions": {
|
| 108 |
"max_length": 5
|
| 109 |
},
|
| 110 |
+
"max_entities": 5,
|
| 111 |
"process_title": true,
|
| 112 |
"process_abstract": false,
|
| 113 |
+
"process_category": true,
|
| 114 |
"process_subcategory": false,
|
| 115 |
+
"process_user_id": false,
|
| 116 |
+
"process_entities": true
|
| 117 |
},
|
| 118 |
"training": {
|
| 119 |
"loss": {
|
|
|
|
| 123 |
"label_smoothing": 0.0
|
| 124 |
},
|
| 125 |
"optimizer": "adam",
|
| 126 |
+
"learning_rate": 5e-05,
|
| 127 |
+
"batch_size": 64,
|
| 128 |
"num_epochs": 20,
|
| 129 |
"gradient_clip_val": 1.0,
|
| 130 |
"grad_accum_steps": 1,
|
|
|
|
| 135 |
"negative_sampling": {
|
| 136 |
"strategy": "random",
|
| 137 |
"candidates": 4
|
| 138 |
+
}
|
|
|
|
| 139 |
},
|
| 140 |
"evaluation": {
|
| 141 |
"mode": "fast",
|
| 142 |
+
"evaluator": "caum",
|
| 143 |
"metrics": [
|
| 144 |
"auc",
|
| 145 |
"mrr",
|
|
|
|
| 182 |
"test": "https://huggingface.co/datasets/yjw1029/MIND/resolve/main/MINDlarge_test.zip"
|
| 183 |
}
|
| 184 |
},
|
| 185 |
+
"max_title_length": 30,
|
| 186 |
"max_abstract_length": 50,
|
| 187 |
"max_history_length": 50,
|
| 188 |
"max_impressions_length": 5,
|
|
|
|
| 218 |
},
|
| 219 |
"process_title": true,
|
| 220 |
"process_abstract": false,
|
| 221 |
+
"process_category": true,
|
| 222 |
"process_subcategory": false,
|
| 223 |
"process_user_id": false,
|
| 224 |
+
"process_entities": true
|
| 225 |
},
|
| 226 |
"sampling": {
|
| 227 |
"max_impressions_length": 5,
|
|
|
|
| 250 |
"popularity_metric": "clicks"
|
| 251 |
}
|
| 252 |
},
|
| 253 |
+
"name": "mind_caum",
|
| 254 |
+
"model_name": "CAUM",
|
| 255 |
+
"_output_run_dir": "outputs/train/MIND-small/CAUM/jax/seed_123"
|
| 256 |
},
|
| 257 |
"initial_validation_metrics": {},
|
| 258 |
"best_validation_summary": {
|
| 259 |
"epoch_number": 10.0,
|
| 260 |
+
"train_loss": 1.2371471108698318,
|
| 261 |
+
"average_metric_value": 0.5206724943439838,
|
| 262 |
+
"val_loss": 4.4820805537700945,
|
| 263 |
+
"val_auc": 0.7497600715485296,
|
| 264 |
+
"val_mrr": 0.39857799672194205,
|
| 265 |
+
"val_ndcg@5": 0.4387125845284695,
|
| 266 |
+
"val_ndcg@10": 0.4956393245769941,
|
| 267 |
"val_num_impressions": 7824.0,
|
| 268 |
"timing": {
|
| 269 |
"epoch_training_times": [
|
| 270 |
+
152.7604115009308,
|
| 271 |
+
128.3657763004303,
|
| 272 |
+
128.71619987487793,
|
| 273 |
+
126.83259797096252,
|
| 274 |
+
126.7269389629364,
|
| 275 |
+
127.33337998390198,
|
| 276 |
+
126.36471319198608,
|
| 277 |
+
127.29102325439453,
|
| 278 |
+
126.6120913028717,
|
| 279 |
+
126.97762560844421
|
| 280 |
],
|
| 281 |
"epoch_validation_times": [
|
| 282 |
+
156.19692087173462,
|
| 283 |
+
156.4893136024475,
|
| 284 |
+
156.90509462356567,
|
| 285 |
+
155.93345546722412,
|
| 286 |
+
156.1876072883606,
|
| 287 |
+
154.95236468315125,
|
| 288 |
+
157.0661187171936,
|
| 289 |
+
156.74578547477722,
|
| 290 |
+
156.08843064308167,
|
| 291 |
+
156.18353486061096
|
| 292 |
],
|
| 293 |
+
"total_training_time": 2860.9791276454926
|
| 294 |
}
|
| 295 |
},
|
| 296 |
"final_test_metrics": {
|
| 297 |
+
"loss": 4.895912475360167,
|
| 298 |
+
"auc": 0.6746637875073116,
|
| 299 |
+
"mrr": 0.3214330456909908,
|
| 300 |
+
"ndcg@5": 0.35560859266126893,
|
| 301 |
+
"ndcg@10": 0.4193156331068628,
|
| 302 |
"num_impressions": 72903.0
|
| 303 |
}
|
| 304 |
}
|