zirobtc commited on
Commit
c471f42
·
verified ·
1 Parent(s): 7064310

Upload folder using huggingface_hub

Browse files
Files changed (8) hide show
  1. TASK_LIST.md +87 -0
  2. cache_dataset.py +225 -58
  3. data/data_loader.py +154 -16
  4. log.log +2 -2
  5. resume.md +190 -0
  6. scripts/evaluate_sample.py +77 -3
  7. scripts/rebuild_metadata.py +15 -1
  8. train.py +77 -45
TASK_LIST.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Task List
2
+
3
+ Fix validation splitting by token identity
4
+ Replace class-only split logic in train.py (line 235)
5
+ Group cached context samples by source_token or token_address
6
+ Ensure one token can only exist in train or val, never both
7
+ Keep class balance as a secondary constraint, not the primary identity rule
8
+ Stop using current validation as decision-grade signal
9
+ Treat old val curves/checkpoints as contaminated
10
+ Re-evaluate only after token-grouped split is in place
11
+ Audit cache metadata and make token identity explicit
12
+ Ensure every cached sample has stable token identity fields
13
+ Required minimum: source_token, class_id
14
+ Prefer also storing lightweight cache-planning metadata for later analysis
15
+ Redesign cache generation around fixed budgets
16
+ Define total cache budget first
17
+ Allocate exact sample counts per token class before writing files
18
+ Do not let raw source distribution decide cache composition
19
+ Remove destructive dependence on token class map filtering alone
20
+ Token class should guide budget allocation
21
+ It should not be the only logic determining whether cache is useful
22
+ Add cache-time context-level balancing
23
+ After sampling a candidate context, evaluate realized future labels for that context
24
+ Use realized context outcome to decide whether to keep or skip it
25
+ Do this before saving to disk
26
+ Start with binary polarity, not movement-threshold balancing
27
+ Positive if max valid horizon return > 0
28
+ Negative otherwise
29
+ Use this only as cache-selection metadata first
30
+ Make polarity quotas class-conditional
31
+ For stronger classes, target positive/negative ratios
32
+ For garbage classes, do not force positives
33
+ Keep class 0 mostly natural/negative
34
+ Keep T_cutoff random during cache generation
35
+ Do not freeze a single deterministic cutoff per token
36
+ Determinism should be in the planning/budget logic, not in removing context diversity
37
+ Add exact acceptance accounting during cache build
38
+ Track how many samples have already been accepted per class
39
+ Track polarity counts per class
40
+ Stop accepting once quotas are filled
41
+ Avoid cache waste from duplicate low-value contexts
42
+ Add retry/attempt limits per token
43
+ If a token cannot satisfy desired quota type, stop oversampling it endlessly
44
+ Move on to other tokens instead of filling disk with junk
45
+ Keep label derivation in the data pipeline, not in training logic
46
+ Loader should produce final labels and masks
47
+ Collator should only stack/batch them
48
+ Model should only consume them
49
+ Reduce or remove train-time class reweighting after cache is fixed
50
+ Revisit WeightedRandomSampler
51
+ Revisit class_loss_weights
52
+ If cache is balanced upstream, training should not need heavy rescue weighting
53
+ Revisit movement head only after split and cache are fixed
54
+ Keep it auxiliary
55
+ Do not let movement-label threshold debates block the more important data fixes
56
+ Later simplify naming/threshold assumptions if needed
57
+ Add cache audit tooling
58
+ Report counts by class_id
59
+ Report counts by class x polarity
60
+ Report unique tokens by class
61
+ Report acceptance/rejection reasons
62
+ Report train/val token overlap check
63
+ Add validation integrity checks
64
+ Assert zero token overlap between train and val
65
+ Print per-class token counts, not just sample counts
66
+ Print per-class sample counts too
67
+ Rebuild cache after the new policy is implemented
68
+ Old cache is shaped by the wrong distribution
69
+ Old validation split is not trustworthy
70
+ New training should start from the rebuilt corpus
71
+ Retrain and re-baseline from scratch
72
+ New split
73
+ New cache
74
+ Minimal train-time rescue weighting
75
+ Recompare backbone behavior only after that
76
+ Recommended implementation order
77
+
78
+ Token-grouped validation split
79
+ Validation overlap checks
80
+ Cache metadata cleanup
81
+ Exact class quotas in cache generation
82
+ Class-conditional polarity quotas
83
+ Cache audit reports
84
+ Remove/reduce train-time weighting
85
+ Rebuild cache
86
+ Retrain
87
+ Reassess movement head
cache_dataset.py CHANGED
@@ -23,6 +23,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
 
24
  from scripts.analyze_distribution import get_return_class_map
25
  from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
 
26
 
27
  from clickhouse_driver import Client as ClickHouseClient
28
  from neo4j import GraphDatabase
@@ -32,6 +33,61 @@ _worker_return_class_map = None
32
  _worker_quality_scores_map = None
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
36
  global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
37
  from data.data_loader import OracleDataset
@@ -43,7 +99,7 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
43
 
44
  _worker_dataset = OracleDataset(
45
  data_fetcher=data_fetcher,
46
- max_samples=dataset_config['max_samples'],
47
  start_date=dataset_config['start_date'],
48
  horizons_seconds=dataset_config['horizons_seconds'],
49
  quantiles=dataset_config['quantiles'],
@@ -68,42 +124,15 @@ def _process_single_token_context(args):
68
  q_score = _worker_quality_scores_map.get(mint_addr)
69
  if q_score is None:
70
  return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
71
- saved_files = []
72
- for ctx_idx, ctx in enumerate(contexts):
73
- ctx["quality_score"] = q_score
74
- ctx["class_id"] = class_id
75
- ctx["source_token"] = mint_addr
76
- ctx["cache_mode"] = "context"
77
- filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
78
- output_path = Path(output_dir) / filename
79
- torch.save(ctx, output_path)
80
- saved_files.append(filename)
81
- return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'files': saved_files}
82
- except Exception as e:
83
- import traceback
84
- return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
85
-
86
-
87
- def _process_single_token_raw(args):
88
- idx, mint_addr, output_dir = args
89
- global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
90
- try:
91
- class_id = _worker_return_class_map.get(mint_addr)
92
- if class_id is None:
93
- return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
94
- item = _worker_dataset.__cacheitem__(idx)
95
- if item is None:
96
- return {'status': 'skipped', 'reason': 'cacheitem returned None', 'mint': mint_addr}
97
- q_score = _worker_quality_scores_map.get(mint_addr)
98
- if q_score is None:
99
- return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
100
- item["quality_score"] = q_score
101
- item["class_id"] = class_id
102
- item["cache_mode"] = "raw"
103
- filename = f"sample_{mint_addr[:16]}.pt"
104
- output_path = Path(output_dir) / filename
105
- torch.save(item, output_path)
106
- return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_trades': len(item.get('trades', [])), 'files': [filename]}
107
  except Exception as e:
108
  import traceback
109
  return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
@@ -122,14 +151,16 @@ def main():
122
 
123
  parser = argparse.ArgumentParser()
124
  parser.add_argument("--output_dir", type=str, default="data/cache")
125
- parser.add_argument("--max_samples", type=int, default=None)
126
  parser.add_argument("--start_date", type=str, default=None)
127
 
128
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
129
- parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"])
130
- parser.add_argument("--context_length", type=int, default=8192)
131
  parser.add_argument("--min_trades", type=int, default=10)
 
132
  parser.add_argument("--samples_per_token", type=int, default=1)
 
 
 
 
133
  parser.add_argument("--num_workers", type=int, default=1)
134
  parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
135
  parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
@@ -138,6 +169,11 @@ def main():
138
  parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
139
  args = parser.parse_args()
140
 
 
 
 
 
 
141
  if args.num_workers == 0:
142
  args.num_workers = max(1, mp.cpu_count() - 4)
143
 
@@ -163,7 +199,15 @@ def main():
163
  quality_scores_map = get_token_quality_scores(clickhouse_client)
164
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
165
 
166
- dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5], min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)
 
 
 
 
 
 
 
 
167
 
168
  if len(dataset) == 0:
169
  print("WARNING: No samples. Exiting.")
@@ -178,25 +222,52 @@ def main():
178
  print("WARNING: No tokens after filtering.")
179
  return
180
 
181
- print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")
 
 
 
 
 
 
 
 
182
 
183
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
184
- dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'horizons_seconds': [60, 180, 300, 600, 1800, 3600, 7200], 'quantiles': [0.5], 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
185
 
186
  # Build tasks from filtered_mints directly
187
  tasks = []
188
  for i, mint_record in enumerate(filtered_mints):
189
  mint_addr = mint_record['mint_address']
190
- if args.cache_mode == "context":
191
- tasks.append((i, mint_addr, args.samples_per_token, str(output_dir)))
192
- else:
193
- tasks.append((i, mint_addr, str(output_dir)))
194
 
195
  print(f"INFO: Starting to cache {len(tasks)} tokens...")
196
 
197
  success_count, skipped_count, error_count = 0, 0, 0
198
  class_distribution = {}
199
- process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  if args.num_workers == 1:
202
  print("INFO: Single-threaded mode...")
@@ -204,8 +275,61 @@ def main():
204
  for task in tqdm(tasks, desc="Caching"):
205
  result = process_fn(task)
206
  if result['status'] == 'success':
207
- success_count += 1
208
- class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  elif result['status'] == 'skipped':
210
  skipped_count += 1
211
  else:
@@ -219,8 +343,26 @@ def main():
219
  try:
220
  result = future.result(timeout=300)
221
  if result['status'] == 'success':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  success_count += 1
223
- class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
224
  elif result['status'] == 'skipped':
225
  skipped_count += 1
226
  else:
@@ -230,15 +372,40 @@ def main():
230
  tqdm.write(f"WORKER ERROR: {e}")
231
 
232
  print("INFO: Building metadata...")
233
- file_class_map = {}
234
- for f in sorted(output_dir.glob("sample_*.pt")):
235
- try:
236
- file_class_map[f.name] = torch.load(f, map_location="cpu", weights_only=False).get("class_id", 0)
237
- except:
238
- pass
 
 
 
 
 
239
 
240
  with open(output_dir / "class_metadata.json", 'w') as f:
241
- json.dump({'file_class_map': file_class_map, 'class_distribution': {str(k): v for k, v in class_distribution.items()}, 'cache_mode': args.cache_mode, 'num_workers': args.num_workers}, f, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
244
 
 
23
 
24
  from scripts.analyze_distribution import get_return_class_map
25
  from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
26
+ from data.data_loader import summarize_context_window
27
 
28
  from clickhouse_driver import Client as ClickHouseClient
29
  from neo4j import GraphDatabase
 
33
  _worker_quality_scores_map = None
34
 
35
 
36
+ def _build_context_quota_plan(
37
+ class_ids,
38
+ target_contexts_per_class,
39
+ target_contexts_total,
40
+ good_ratio_nonzero,
41
+ good_ratio_class0,
42
+ ):
43
+ unique_class_ids = sorted(set(int(cid) for cid in class_ids))
44
+ if not unique_class_ids:
45
+ return {}
46
+
47
+ if target_contexts_per_class is not None:
48
+ per_class_target = int(target_contexts_per_class)
49
+ elif target_contexts_total is not None:
50
+ per_class_target = max(1, int(target_contexts_total) // len(unique_class_ids))
51
+ else:
52
+ return {}
53
+
54
+ if per_class_target <= 0:
55
+ raise RuntimeError("Context quota target must be positive.")
56
+
57
+ plan = {}
58
+ for class_id in unique_class_ids:
59
+ ratio = float(good_ratio_class0 if class_id == 0 else good_ratio_nonzero)
60
+ ratio = max(0.0, min(1.0, ratio))
61
+ good_target = int(round(per_class_target * ratio))
62
+ bad_target = per_class_target - good_target
63
+ plan[class_id] = {
64
+ "total_target": per_class_target,
65
+ "good_target": good_target,
66
+ "bad_target": bad_target,
67
+ }
68
+ return plan
69
+
70
+
71
+ def _should_accept_context(class_id, context_bucket, accepted_counts, quota_plan):
72
+ if not quota_plan:
73
+ return True
74
+
75
+ if class_id not in quota_plan:
76
+ return False
77
+
78
+ class_plan = quota_plan[class_id]
79
+ class_counts = accepted_counts[class_id]
80
+ if class_counts["total"] >= class_plan["total_target"]:
81
+ return False
82
+
83
+ bucket_key = "good" if context_bucket == "good" else "bad"
84
+ target_key = f"{bucket_key}_target"
85
+ if class_counts[bucket_key] >= class_plan[target_key]:
86
+ return False
87
+
88
+ return True
89
+
90
+
91
  def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
92
  global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
93
  from data.data_loader import OracleDataset
 
99
 
100
  _worker_dataset = OracleDataset(
101
  data_fetcher=data_fetcher,
102
+ min_trades=dataset_config['min_trades'],
103
  start_date=dataset_config['start_date'],
104
  horizons_seconds=dataset_config['horizons_seconds'],
105
  quantiles=dataset_config['quantiles'],
 
124
  q_score = _worker_quality_scores_map.get(mint_addr)
125
  if q_score is None:
126
  return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
127
+ return {
128
+ 'status': 'success',
129
+ 'mint': mint_addr,
130
+ 'class_id': class_id,
131
+ 'q_score': q_score,
132
+ 'n_contexts': len(contexts),
133
+ 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
134
+ 'contexts': contexts,
135
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  except Exception as e:
137
  import traceback
138
  return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
 
151
 
152
  parser = argparse.ArgumentParser()
153
  parser.add_argument("--output_dir", type=str, default="data/cache")
 
154
  parser.add_argument("--start_date", type=str, default=None)
155
 
156
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
 
 
157
  parser.add_argument("--min_trades", type=int, default=10)
158
+ parser.add_argument("--context_length", type=int, default=8192)
159
  parser.add_argument("--samples_per_token", type=int, default=1)
160
+ parser.add_argument("--target_contexts_per_class", type=int, default=None)
161
+ parser.add_argument("--target_contexts_total", type=int, default=None)
162
+ parser.add_argument("--good_ratio_nonzero", type=float, default=0.5)
163
+ parser.add_argument("--good_ratio_class0", type=float, default=0.0)
164
  parser.add_argument("--num_workers", type=int, default=1)
165
  parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
166
  parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
 
169
  parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
170
  args = parser.parse_args()
171
 
172
+ if args.target_contexts_per_class is not None and args.target_contexts_total is not None:
173
+ raise RuntimeError(
174
+ "Choose exactly one cache budget: either --target_contexts_per_class or --target_contexts_total."
175
+ )
176
+
177
  if args.num_workers == 0:
178
  args.num_workers = max(1, mp.cpu_count() - 4)
179
 
 
199
  quality_scores_map = get_token_quality_scores(clickhouse_client)
200
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
201
 
202
+ dataset = OracleDataset(
203
+ data_fetcher=data_fetcher,
204
+ min_trades=args.min_trades,
205
+ start_date=start_date_dt,
206
+ horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200],
207
+ quantiles=[0.5],
208
+ min_trade_usd=args.min_trade_usd,
209
+ max_seq_len=args.context_length,
210
+ )
211
 
212
  if len(dataset) == 0:
213
  print("WARNING: No samples. Exiting.")
 
222
  print("WARNING: No tokens after filtering.")
223
  return
224
 
225
+ print(f"INFO: Building canonical context cache | Workers: {args.num_workers}")
226
+
227
+ if args.num_workers != 1 and (
228
+ args.target_contexts_per_class is not None or args.target_contexts_total is not None
229
+ ):
230
+ raise RuntimeError(
231
+ "Quota-driven context caching currently requires --num_workers 1 so accepted contexts "
232
+ "can be planned and written deterministically in one process."
233
+ )
234
 
235
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
236
+ dataset_config = {'start_date': start_date_dt, 'min_trades': args.min_trades, 'horizons_seconds': [60, 180, 300, 600, 1800, 3600, 7200], 'quantiles': [0.5], 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
237
 
238
  # Build tasks from filtered_mints directly
239
  tasks = []
240
  for i, mint_record in enumerate(filtered_mints):
241
  mint_addr = mint_record['mint_address']
242
+ tasks.append((i, mint_addr, args.samples_per_token, str(output_dir)))
 
 
 
243
 
244
  print(f"INFO: Starting to cache {len(tasks)} tokens...")
245
 
246
  success_count, skipped_count, error_count = 0, 0, 0
247
  class_distribution = {}
248
+ context_distribution = defaultdict(lambda: defaultdict(int))
249
+ file_class_map = {}
250
+ file_context_bucket_map = {}
251
+ file_context_summary_map = {}
252
+ process_fn = _process_single_token_context
253
+ quota_plan = {}
254
+ accepted_counts = defaultdict(lambda: {"total": 0, "good": 0, "bad": 0})
255
+ accepted_per_token = defaultdict(int)
256
+
257
+ quota_plan = _build_context_quota_plan(
258
+ class_ids=[return_class_map[m['mint_address']] for m in filtered_mints if m['mint_address'] in return_class_map],
259
+ target_contexts_per_class=args.target_contexts_per_class,
260
+ target_contexts_total=args.target_contexts_total,
261
+ good_ratio_nonzero=args.good_ratio_nonzero,
262
+ good_ratio_class0=args.good_ratio_class0,
263
+ )
264
+ if quota_plan:
265
+ print("INFO: Context quota plan:")
266
+ for class_id, plan in sorted(quota_plan.items()):
267
+ print(
268
+ f" Class {class_id}: total={plan['total_target']} "
269
+ f"(good={plan['good_target']}, bad={plan['bad_target']})"
270
+ )
271
 
272
  if args.num_workers == 1:
273
  print("INFO: Single-threaded mode...")
 
275
  for task in tqdm(tasks, desc="Caching"):
276
  result = process_fn(task)
277
  if result['status'] == 'success':
278
+ if quota_plan:
279
+ class_id = result['class_id']
280
+ mint_addr = result['mint']
281
+ q_score = result['q_score']
282
+ saved_any = False
283
+ for ctx in result.get("contexts", []):
284
+ context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
285
+ context_bucket = context_summary["context_bucket"]
286
+ if not _should_accept_context(class_id, context_bucket, accepted_counts, quota_plan):
287
+ continue
288
+
289
+ ctx["quality_score"] = q_score
290
+ ctx["class_id"] = class_id
291
+ ctx["source_token"] = mint_addr
292
+ ctx["context_bucket"] = context_bucket
293
+ ctx["context_score"] = context_summary["context_score"]
294
+
295
+ file_idx = accepted_per_token[mint_addr]
296
+ filename = f"sample_{mint_addr[:16]}_{file_idx}.pt"
297
+ output_path = Path(output_dir) / filename
298
+ torch.save(ctx, output_path)
299
+
300
+ accepted_per_token[mint_addr] += 1
301
+ accepted_counts[class_id]["total"] += 1
302
+ accepted_counts[class_id][context_bucket] += 1
303
+ class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
304
+ context_distribution[class_id][context_bucket] += 1
305
+ file_class_map[filename] = class_id
306
+ file_context_bucket_map[filename] = context_bucket
307
+ file_context_summary_map[filename] = context_summary
308
+ saved_any = True
309
+
310
+ if saved_any:
311
+ success_count += 1
312
+ else:
313
+ class_id = result['class_id']
314
+ mint_addr = result['mint']
315
+ q_score = result['q_score']
316
+ for ctx_idx, ctx in enumerate(result.get("contexts", [])):
317
+ context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
318
+ context_bucket = context_summary["context_bucket"]
319
+ ctx["quality_score"] = q_score
320
+ ctx["class_id"] = class_id
321
+ ctx["source_token"] = mint_addr
322
+ ctx["context_bucket"] = context_bucket
323
+ ctx["context_score"] = context_summary["context_score"]
324
+ filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
325
+ output_path = Path(output_dir) / filename
326
+ torch.save(ctx, output_path)
327
+ file_class_map[filename] = class_id
328
+ file_context_bucket_map[filename] = context_bucket
329
+ file_context_summary_map[filename] = context_summary
330
+ class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
331
+ context_distribution[class_id][context_bucket] += 1
332
+ success_count += 1
333
  elif result['status'] == 'skipped':
334
  skipped_count += 1
335
  else:
 
343
  try:
344
  result = future.result(timeout=300)
345
  if result['status'] == 'success':
346
+ class_id = result['class_id']
347
+ mint_addr = result['mint']
348
+ q_score = result['q_score']
349
+ for ctx_idx, ctx in enumerate(result.get("contexts", [])):
350
+ context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
351
+ context_bucket = context_summary["context_bucket"]
352
+ ctx["quality_score"] = q_score
353
+ ctx["class_id"] = class_id
354
+ ctx["source_token"] = mint_addr
355
+ ctx["context_bucket"] = context_bucket
356
+ ctx["context_score"] = context_summary["context_score"]
357
+ filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
358
+ output_path = Path(output_dir) / filename
359
+ torch.save(ctx, output_path)
360
+ file_class_map[filename] = class_id
361
+ file_context_bucket_map[filename] = context_bucket
362
+ file_context_summary_map[filename] = context_summary
363
+ class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
364
+ context_distribution[class_id][context_bucket] += 1
365
  success_count += 1
 
366
  elif result['status'] == 'skipped':
367
  skipped_count += 1
368
  else:
 
372
  tqdm.write(f"WORKER ERROR: {e}")
373
 
374
  print("INFO: Building metadata...")
375
+ if not file_class_map:
376
+ for f in sorted(output_dir.glob("sample_*.pt")):
377
+ try:
378
+ cached = torch.load(f, map_location="cpu", weights_only=False)
379
+ file_class_map[f.name] = cached.get("class_id", 0)
380
+ if "labels" in cached and "labels_mask" in cached:
381
+ context_summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask"))
382
+ file_context_bucket_map[f.name] = context_summary["context_bucket"]
383
+ file_context_summary_map[f.name] = context_summary
384
+ except Exception:
385
+ pass
386
 
387
  with open(output_dir / "class_metadata.json", 'w') as f:
388
+ json.dump({
389
+ 'file_class_map': file_class_map,
390
+ 'file_context_bucket_map': file_context_bucket_map,
391
+ 'file_context_summary_map': file_context_summary_map,
392
+ 'class_distribution': {str(k): v for k, v in class_distribution.items()},
393
+ 'context_distribution': {
394
+ str(k): {bucket: count for bucket, count in bucket_counts.items()}
395
+ for k, bucket_counts in context_distribution.items()
396
+ },
397
+ 'quota_plan': {str(k): v for k, v in quota_plan.items()},
398
+ 'accepted_counts': {str(k): v for k, v in accepted_counts.items()},
399
+ 'num_workers': args.num_workers,
400
+ }, f, indent=2)
401
+
402
+ if quota_plan:
403
+ print("INFO: Accepted context counts:")
404
+ for class_id, counts in sorted(accepted_counts.items()):
405
+ print(
406
+ f" Class {class_id}: total={counts['total']} "
407
+ f"good={counts['good']} bad={counts['bad']}"
408
+ )
409
 
410
  print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
411
 
data/data_loader.py CHANGED
@@ -64,6 +64,63 @@ MIN_AMOUNT_TRANSFER_SUPPLY = 0.0 # 1.0% of total supply
64
  HOLDER_SNAPSHOT_INTERVAL_SEC = 300
65
  HOLDER_SNAPSHOT_TOP_K = 200
66
  DEAD_URI_RETRY_LIMIT = 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
  class EmbeddingPooler:
@@ -123,6 +180,7 @@ class OracleDataset(Dataset):
123
  horizons_seconds: List[int] = [],
124
  quantiles: List[float] = [],
125
  max_samples: Optional[int] = None,
 
126
 
127
  token_allowlist: Optional[List[str]] = None,
128
  cache_dir: Optional[Union[str, Path]] = None,
@@ -131,8 +189,11 @@ class OracleDataset(Dataset):
131
  max_seq_len: int = 8192,
132
  p99_clamps: Optional[Dict[str, float]] = None,
133
  movement_label_config: Optional[Dict[str, float]] = None):
134
-
135
  self.max_seq_len = max_seq_len
 
 
 
136
 
137
  # --- P99 data-driven clamp values (replace hardcoded min/max) ---
138
  self.p99_clamps = {
@@ -187,9 +248,12 @@ class OracleDataset(Dataset):
187
  if not self.cached_files:
188
  raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
189
 
190
- # --- OPTIMIZED: Load class_ids from metadata cache file ---
191
  file_class_map = {}
 
 
192
  class_counts = defaultdict(int)
 
193
  metadata_path = self.cache_dir / "class_metadata.json"
194
 
195
  if metadata_path.exists():
@@ -199,20 +263,29 @@ class OracleDataset(Dataset):
199
  with open(metadata_path, 'r') as f:
200
  cached_metadata = json.load(f)
201
  file_class_map = cached_metadata.get('file_class_map', {})
 
 
202
  # Validate that cached files match metadata
203
  cached_file_names = {p.name for p in self.cached_files}
204
  metadata_file_names = set(file_class_map.keys())
205
  if cached_file_names != metadata_file_names:
206
  print(f"WARN: Metadata cache mismatch ({len(cached_file_names)} files vs {len(metadata_file_names)} in metadata). Rebuilding...")
207
  file_class_map = {}
 
 
208
  else:
209
  # Rebuild class_counts from loaded map
210
- for cid in file_class_map.values():
211
  class_counts[cid] += 1
 
 
 
212
  print(f"INFO: Loaded metadata for {len(file_class_map)} samples in <1s")
213
  except Exception as e:
214
  print(f"WARN: Failed to load metadata cache: {e}. Rebuilding...")
215
  file_class_map = {}
 
 
216
 
217
  # Slow path: scan all files and build metadata cache
218
  if not file_class_map:
@@ -229,23 +302,41 @@ class OracleDataset(Dataset):
229
  if cid is None:
230
  print(f"WARN: File {p.name} missing class_id. Skipping.")
231
  continue
 
 
 
 
 
232
  file_class_map[p.name] = cid
 
 
233
  class_counts[cid] += 1
 
234
  except Exception as e:
235
  print(f"WARN: Failed to read cached sample {p.name}: {e}")
236
 
237
  # Save metadata cache for future runs
238
  try:
239
  with open(metadata_path, 'w') as f:
240
- json.dump({'file_class_map': file_class_map}, f)
 
 
 
 
241
  print(f"INFO: Saved class metadata cache to {metadata_path}")
242
  except Exception as e:
243
  print(f"WARN: Failed to save metadata cache: {e}")
244
 
245
  print(f"INFO: Class Distribution: {dict(class_counts)}")
 
 
 
 
246
 
247
  # Store file_class_map for fast lookup by train.py's create_balanced_split
248
  self.file_class_map = {p: cid for p, cid in file_class_map.items()}
 
 
249
 
250
  # Compute Weights
251
  self.weights_list = []
@@ -260,8 +351,24 @@ class OracleDataset(Dataset):
260
  continue
261
 
262
  cid = file_class_map[fname]
263
- count = class_counts[cid]
264
- weight = 1.0 / (count ** 0.5) if count > 0 else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  self.weights_list.append(weight)
266
  valid_files.append(p)
267
 
@@ -273,6 +380,25 @@ class OracleDataset(Dataset):
273
  self.cached_files = self.cached_files[:self.num_samples]
274
  self.weights_list = self.weights_list[:self.num_samples]
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  print(f"INFO: Weighted Dataset Ready. {self.num_samples} samples.")
277
  self.sampled_mints = [] # Not needed in cached mode
278
  self.available_mints = []
@@ -1297,7 +1423,7 @@ class OracleDataset(Dataset):
1297
  key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
1298
  )
1299
 
1300
- min_context_trades = 10
1301
  if len(all_trades_sorted) < (min_context_trades + 1): # context + 1 trade after cutoff
1302
  return None
1303
 
@@ -1626,7 +1752,7 @@ class OracleDataset(Dataset):
1626
  max_horizon_seconds=self.max_cache_horizon_seconds,
1627
  include_wallet_data=False,
1628
  include_graph=False,
1629
- min_trades=10, # Enforce min trades for context
1630
  full_history=True, # Bypass H/B/H limits
1631
  prune_failed=False, # Keep failed trades for realistic simulation
1632
  prune_transfers=False # Keep transfers for snapshot reconstruction
@@ -1641,8 +1767,14 @@ class OracleDataset(Dataset):
1641
  raw_data['name'] = initial_mint_record.get('token_name', '')
1642
  raw_data['symbol'] = initial_mint_record.get('token_symbol', '')
1643
  raw_data['token_uri'] = initial_mint_record.get('token_uri', '')
1644
- raw_data['total_supply'] = initial_mint_record.get('total_supply', 0)
1645
- raw_data['decimals'] = initial_mint_record.get('token_decimals', 6)
 
 
 
 
 
 
1646
  raw_data['protocol'] = initial_mint_record.get('protocol', 1)
1647
 
1648
  def _timestamp_to_order_value(ts_value: Any) -> float:
@@ -2686,7 +2818,7 @@ class OracleDataset(Dataset):
2686
  max_horizon_seconds=self.max_cache_horizon_seconds,
2687
  include_wallet_data=False,
2688
  include_graph=False,
2689
- min_trades=10,
2690
  full_history=True,
2691
  prune_failed=False,
2692
  prune_transfers=False
@@ -2704,8 +2836,14 @@ class OracleDataset(Dataset):
2704
  raw_data['name'] = initial_mint_record.get('token_name', '')
2705
  raw_data['symbol'] = initial_mint_record.get('token_symbol', '')
2706
  raw_data['token_uri'] = initial_mint_record.get('token_uri', '')
2707
- raw_data['total_supply'] = initial_mint_record.get('total_supply', 0)
2708
- raw_data['decimals'] = initial_mint_record.get('token_decimals', 6)
 
 
 
 
 
 
2709
  raw_data['protocol'] = initial_mint_record.get('protocol', 1)
2710
 
2711
  def _timestamp_to_order_value(ts_value) -> float:
@@ -2734,7 +2872,7 @@ class OracleDataset(Dataset):
2734
  key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
2735
  )
2736
 
2737
- min_context_trades = 10
2738
  if len(all_trades_sorted) < (min_context_trades + 1):
2739
  print(f" SKIP: Not enough trades ({len(all_trades_sorted)}) for {token_address}")
2740
  return []
@@ -2821,9 +2959,9 @@ class OracleDataset(Dataset):
2821
  total_supply_raw = int(raw_total_supply)
2822
  token_decimals = int(raw_decimals)
2823
  if total_supply_raw <= 0:
2824
- raise RuntimeError(f"Invalid total_supply for {token_address}: {total_supply_raw}")
2825
  if token_decimals < 0:
2826
- raise RuntimeError(f"Invalid decimals for {token_address}: {token_decimals}")
2827
  token_scale = 10 ** token_decimals
2828
 
2829
  def _strict_int(v: Any, field_name: str) -> int:
 
64
  HOLDER_SNAPSHOT_INTERVAL_SEC = 300
65
  HOLDER_SNAPSHOT_TOP_K = 200
66
  DEAD_URI_RETRY_LIMIT = 2
67
DEFAULT_TOTAL_SUPPLY_RAW = 1_000_000_000_000_000
DEFAULT_TOKEN_DECIMALS = 6

CONTEXT_BUCKET_NEGATIVE = "bad"
CONTEXT_BUCKET_POSITIVE = "good"


def summarize_context_window(
    labels: Any,
    labels_mask: Any,
) -> Dict[str, Any]:
    """
    Classify a realized context window from its valid future returns.

    Each masked-in horizon contributes the signed, log1p-compressed terminal
    PnL of buying at the cutoff. The window is labelled
    ``CONTEXT_BUCKET_POSITIVE`` only when the mean signed contribution is
    strictly positive; a window with zero valid horizons scores 0.0 and
    therefore falls into ``CONTEXT_BUCKET_NEGATIVE``.

    Raises:
        RuntimeError: if either ``labels`` or ``labels_mask`` is None.
    """
    if labels is None or labels_mask is None:
        raise RuntimeError("Context weighting requires both 'labels' and 'labels_mask'.")

    def _to_plain_list(values: Any) -> List[float]:
        # Accept either torch tensors or generic iterables.
        return values.tolist() if isinstance(values, torch.Tensor) else list(values)

    # Keep only horizons whose mask entry is strictly positive.
    realized = [
        float(ret)
        for ret, keep in zip(_to_plain_list(labels), _to_plain_list(labels_mask))
        if float(keep) > 0.0
    ]

    # Compress magnitude with log1p but keep the sign; ret == 0 counts as non-positive.
    contributions = [
        np.log1p(abs(ret)) if ret > 0.0 else -np.log1p(abs(ret))
        for ret in realized
    ]

    winners = sum(1 for ret in realized if ret > 0.0)
    score = float(sum(contributions) / len(contributions)) if contributions else 0.0
    bucket = CONTEXT_BUCKET_POSITIVE if score > 0.0 else CONTEXT_BUCKET_NEGATIVE
    return {
        "context_bucket": bucket,
        "context_score": score,
        "positive_horizons": winners,
        "negative_horizons": len(realized) - winners,
        "valid_horizons": len(realized),
    }
124
 
125
 
126
  class EmbeddingPooler:
 
180
  horizons_seconds: List[int] = [],
181
  quantiles: List[float] = [],
182
  max_samples: Optional[int] = None,
183
+ min_trades: int = 10,
184
 
185
  token_allowlist: Optional[List[str]] = None,
186
  cache_dir: Optional[Union[str, Path]] = None,
 
189
  max_seq_len: int = 8192,
190
  p99_clamps: Optional[Dict[str, float]] = None,
191
  movement_label_config: Optional[Dict[str, float]] = None):
192
+
193
  self.max_seq_len = max_seq_len
194
+ self.min_trades = int(min_trades)
195
+ if self.min_trades < 1:
196
+ raise RuntimeError(f"min_trades must be >= 1, got {self.min_trades}")
197
 
198
  # --- P99 data-driven clamp values (replace hardcoded min/max) ---
199
  self.p99_clamps = {
 
248
  if not self.cached_files:
249
  raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
250
 
251
+ # --- OPTIMIZED: Load cached metadata if available ---
252
  file_class_map = {}
253
+ file_context_bucket_map = {}
254
+ file_context_summary_map = {}
255
  class_counts = defaultdict(int)
256
+ class_context_counts = defaultdict(lambda: defaultdict(int))
257
  metadata_path = self.cache_dir / "class_metadata.json"
258
 
259
  if metadata_path.exists():
 
263
  with open(metadata_path, 'r') as f:
264
  cached_metadata = json.load(f)
265
  file_class_map = cached_metadata.get('file_class_map', {})
266
+ file_context_bucket_map = cached_metadata.get('file_context_bucket_map', {})
267
+ file_context_summary_map = cached_metadata.get('file_context_summary_map', {})
268
  # Validate that cached files match metadata
269
  cached_file_names = {p.name for p in self.cached_files}
270
  metadata_file_names = set(file_class_map.keys())
271
  if cached_file_names != metadata_file_names:
272
  print(f"WARN: Metadata cache mismatch ({len(cached_file_names)} files vs {len(metadata_file_names)} in metadata). Rebuilding...")
273
  file_class_map = {}
274
+ file_context_bucket_map = {}
275
+ file_context_summary_map = {}
276
  else:
277
  # Rebuild class_counts from loaded map
278
+ for fname, cid in file_class_map.items():
279
  class_counts[cid] += 1
280
+ bucket = file_context_bucket_map.get(fname)
281
+ if bucket is not None:
282
+ class_context_counts[cid][bucket] += 1
283
  print(f"INFO: Loaded metadata for {len(file_class_map)} samples in <1s")
284
  except Exception as e:
285
  print(f"WARN: Failed to load metadata cache: {e}. Rebuilding...")
286
  file_class_map = {}
287
+ file_context_bucket_map = {}
288
+ file_context_summary_map = {}
289
 
290
  # Slow path: scan all files and build metadata cache
291
  if not file_class_map:
 
302
  if cid is None:
303
  print(f"WARN: File {p.name} missing class_id. Skipping.")
304
  continue
305
+ context_summary = summarize_context_window(
306
+ cached_item.get("labels"),
307
+ cached_item.get("labels_mask"),
308
+ )
309
+ bucket = context_summary["context_bucket"]
310
  file_class_map[p.name] = cid
311
+ file_context_bucket_map[p.name] = bucket
312
+ file_context_summary_map[p.name] = context_summary
313
  class_counts[cid] += 1
314
+ class_context_counts[cid][bucket] += 1
315
  except Exception as e:
316
  print(f"WARN: Failed to read cached sample {p.name}: {e}")
317
 
318
  # Save metadata cache for future runs
319
  try:
320
  with open(metadata_path, 'w') as f:
321
+ json.dump({
322
+ 'file_class_map': file_class_map,
323
+ 'file_context_bucket_map': file_context_bucket_map,
324
+ 'file_context_summary_map': file_context_summary_map,
325
+ }, f)
326
  print(f"INFO: Saved class metadata cache to {metadata_path}")
327
  except Exception as e:
328
  print(f"WARN: Failed to save metadata cache: {e}")
329
 
330
  print(f"INFO: Class Distribution: {dict(class_counts)}")
331
+ print(
332
+ "INFO: Context Distribution by Class: "
333
+ f"{ {cid: dict(bucket_counts) for cid, bucket_counts in class_context_counts.items()} }"
334
+ )
335
 
336
  # Store file_class_map for fast lookup by train.py's create_balanced_split
337
  self.file_class_map = {p: cid for p, cid in file_class_map.items()}
338
+ self.file_context_bucket_map = {p: bucket for p, bucket in file_context_bucket_map.items()}
339
+ self.file_context_summary_map = {p: summary for p, summary in file_context_summary_map.items()}
340
 
341
  # Compute Weights
342
  self.weights_list = []
 
351
  continue
352
 
353
  cid = file_class_map[fname]
354
+ bucket = file_context_bucket_map.get(fname)
355
+ if bucket is None:
356
+ raise RuntimeError(
357
+ f"Cached sample '{fname}' is missing a context bucket. "
358
+ "Rebuild metadata or cache before training."
359
+ )
360
+ class_bucket_counts = class_context_counts[cid]
361
+ present_buckets = [name for name, cnt in class_bucket_counts.items() if cnt > 0]
362
+ if not present_buckets:
363
+ raise RuntimeError(
364
+ f"Class {cid} has no valid context buckets recorded. Cannot compute sampler weights."
365
+ )
366
+ bucket_count = class_bucket_counts[bucket]
367
+ if bucket_count <= 0:
368
+ raise RuntimeError(
369
+ f"Class {cid} bucket '{bucket}' has invalid count {bucket_count} for sample '{fname}'."
370
+ )
371
+ weight = 1.0 / (len(present_buckets) * bucket_count)
372
  self.weights_list.append(weight)
373
  valid_files.append(p)
374
 
 
380
  self.cached_files = self.cached_files[:self.num_samples]
381
  self.weights_list = self.weights_list[:self.num_samples]
382
 
383
+ # Recompute sampler weights against the active cached file subset so the
384
+ # class/context balancing reflects the actual dataset seen by training.
385
+ active_class_context_counts = defaultdict(lambda: defaultdict(int))
386
+ for p in self.cached_files:
387
+ fname = p.name
388
+ cid = file_class_map[fname]
389
+ bucket = file_context_bucket_map[fname]
390
+ active_class_context_counts[cid][bucket] += 1
391
+
392
+ self.weights_list = []
393
+ for p in self.cached_files:
394
+ fname = p.name
395
+ cid = file_class_map[fname]
396
+ bucket = file_context_bucket_map[fname]
397
+ class_bucket_counts = active_class_context_counts[cid]
398
+ present_buckets = [name for name, cnt in class_bucket_counts.items() if cnt > 0]
399
+ bucket_count = class_bucket_counts[bucket]
400
+ self.weights_list.append(1.0 / (len(present_buckets) * bucket_count))
401
+
402
  print(f"INFO: Weighted Dataset Ready. {self.num_samples} samples.")
403
  self.sampled_mints = [] # Not needed in cached mode
404
  self.available_mints = []
 
1423
  key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
1424
  )
1425
 
1426
+ min_context_trades = self.min_trades
1427
  if len(all_trades_sorted) < (min_context_trades + 1): # context + 1 trade after cutoff
1428
  return None
1429
 
 
1752
  max_horizon_seconds=self.max_cache_horizon_seconds,
1753
  include_wallet_data=False,
1754
  include_graph=False,
1755
+ min_trades=self.min_trades,
1756
  full_history=True, # Bypass H/B/H limits
1757
  prune_failed=False, # Keep failed trades for realistic simulation
1758
  prune_transfers=False # Keep transfers for snapshot reconstruction
 
1767
  raw_data['name'] = initial_mint_record.get('token_name', '')
1768
  raw_data['symbol'] = initial_mint_record.get('token_symbol', '')
1769
  raw_data['token_uri'] = initial_mint_record.get('token_uri', '')
1770
+ raw_total_supply = initial_mint_record.get('total_supply', DEFAULT_TOTAL_SUPPLY_RAW)
1771
+ raw_token_decimals = initial_mint_record.get('token_decimals', DEFAULT_TOKEN_DECIMALS)
1772
+ raw_data['total_supply'] = (
1773
+ int(raw_total_supply) if raw_total_supply and int(raw_total_supply) > 0 else DEFAULT_TOTAL_SUPPLY_RAW
1774
+ )
1775
+ raw_data['decimals'] = (
1776
+ int(raw_token_decimals) if raw_token_decimals is not None and int(raw_token_decimals) >= 0 else DEFAULT_TOKEN_DECIMALS
1777
+ )
1778
  raw_data['protocol'] = initial_mint_record.get('protocol', 1)
1779
 
1780
  def _timestamp_to_order_value(ts_value: Any) -> float:
 
2818
  max_horizon_seconds=self.max_cache_horizon_seconds,
2819
  include_wallet_data=False,
2820
  include_graph=False,
2821
+ min_trades=self.min_trades,
2822
  full_history=True,
2823
  prune_failed=False,
2824
  prune_transfers=False
 
2836
  raw_data['name'] = initial_mint_record.get('token_name', '')
2837
  raw_data['symbol'] = initial_mint_record.get('token_symbol', '')
2838
  raw_data['token_uri'] = initial_mint_record.get('token_uri', '')
2839
+ raw_total_supply = initial_mint_record.get('total_supply', DEFAULT_TOTAL_SUPPLY_RAW)
2840
+ raw_token_decimals = initial_mint_record.get('token_decimals', DEFAULT_TOKEN_DECIMALS)
2841
+ raw_data['total_supply'] = (
2842
+ int(raw_total_supply) if raw_total_supply and int(raw_total_supply) > 0 else DEFAULT_TOTAL_SUPPLY_RAW
2843
+ )
2844
+ raw_data['decimals'] = (
2845
+ int(raw_token_decimals) if raw_token_decimals is not None and int(raw_token_decimals) >= 0 else DEFAULT_TOKEN_DECIMALS
2846
+ )
2847
  raw_data['protocol'] = initial_mint_record.get('protocol', 1)
2848
 
2849
  def _timestamp_to_order_value(ts_value) -> float:
 
2872
  key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
2873
  )
2874
 
2875
+ min_context_trades = self.min_trades
2876
  if len(all_trades_sorted) < (min_context_trades + 1):
2877
  print(f" SKIP: Not enough trades ({len(all_trades_sorted)}) for {token_address}")
2878
  return []
 
2959
  total_supply_raw = int(raw_total_supply)
2960
  token_decimals = int(raw_decimals)
2961
  if total_supply_raw <= 0:
2962
+ total_supply_raw = DEFAULT_TOTAL_SUPPLY_RAW
2963
  if token_decimals < 0:
2964
+ token_decimals = DEFAULT_TOKEN_DECIMALS
2965
  token_scale = 10 ** token_decimals
2966
 
2967
  def _strict_int(v: Any, field_name: str) -> int:
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:935233e4d7669b2a25173d7ae164317e85f1a5e8b0fc1d8d1832ab0893fca471
3
- size 19258
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df78e2b44dd97a148be762f91e3b00f397651f8e7e43ee21f938492291fdfa3a
3
+ size 83447
resume.md ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Resume
2
+
3
+ ## Main conclusions from this chat
4
+
5
+ 1. The main issue is data/sample construction, not just checkpoint choice.
6
+ - `class_id` is token-level.
7
+ - labels are context-level and depend on sampled `T_cutoff`.
8
+ - balanced token classes do not imply balanced future outcomes.
9
+ - the model can easily learn an over-negative prior if the cache is built naively.
10
+
11
+ 2. Cache balancing must happen at cache generation time.
12
+ - train-time weighting is too late to fix disk waste or missing context diversity.
13
+ - the cache should control:
14
+ - class balance
15
+ - good/bad context balance within class
16
+
17
+ 3. Validation needed token isolation, not just class balancing.
18
+ - same token appearing in train and val through different contexts makes validation misleading.
19
+
20
+ 4. Movement-threshold circularity is not the main blocker anymore.
21
+ - movement labels are downstream of realized returns.
22
+ - movement thresholds should not control cache construction.
23
+
24
+ 5. OHLC is important, but current usage looks like trend/regime summary, not true pattern detection.
25
+ - the chart branch is being used.
26
+ - but probe tests suggest it is mostly acting as continuation/trend context.
27
+ - not as breakout / support-resistance / head-and-shoulders intelligence.
28
+
29
+ 6. The current code is still regression-first.
30
+ - main target is multi-horizon return prediction.
31
+ - there is also:
32
+ - quality head
33
+ - movement head
34
+ - calling movement auxiliary only matters if its loss contribution is actually secondary.
35
+
36
+ ## What was implemented in code
37
+
38
+ ### 1. Validation split
39
+ - `train.py`
40
+ - validation split was changed to group by token identity (`source_token` / `token_address`) instead of only `class_id`.
41
+
42
+ ### 2. Task tracker
43
+ - `TASK_LIST.md`
44
+ - created as the running checklist for this work.
45
+
46
+ ### 3. Context weighting signal
47
+ - `data/data_loader.py`
48
+ - added a context-quality summary derived from realized multi-horizon returns.
49
+ - current code computes:
50
+ - `context_score`
51
+ - `context_bucket` (`good` / `bad`)
52
+ - this is used for weighting and cache metadata.
53
+
54
+ ### 4. Cached dataset weighting
55
+ - `data/data_loader.py`
56
+ - sampler weights now account for:
57
+ - class balance
58
+ - good/bad context balance inside class
59
+
60
+ ### 5. Weighted cache generation
61
+ - `cache_dataset.py`
62
+ - canonical builder now writes context cache only.
63
+ - no `cache_mode`.
64
+ - no `max_samples`.
65
+ - cache budget is context-based:
66
+ - `--target_contexts_total`
67
+ - or `--target_contexts_per_class`
68
+ - quota-driven acceptance is implemented before save:
69
+ - level 1: class quotas
70
+ - level 2: good/bad quotas within class
71
+
72
+ ### 6. Metadata rebuild
73
+ - `scripts/rebuild_metadata.py`
74
+ - now rebuilds:
75
+ - `file_class_map`
76
+ - `file_context_bucket_map`
77
+ - `file_context_summary_map`
78
+
79
+ ### 7. `min_trades`
80
+ - re-added properly
81
+ - no longer a dead CLI arg
82
+ - now actually controls dataset/context eligibility thresholds
83
+
84
+ ### 8. Evaluate script OHLC probes
85
+ - `scripts/evaluate_sample.py`
86
+ - added these OHLC probe modes:
87
+ - `ohlc_reverse`
88
+ - `ohlc_shuffle_chunks`
89
+ - `ohlc_mask_recent`
90
+ - `ohlc_trend_only`
91
+ - `ohlc_summary_shuffle`
92
+ - `ohlc_detrend`
93
+ - `ohlc_smooth`
94
+
95
+ ### 9. Bad mint metadata fallback
96
+ - `data/data_loader.py`
97
+ - if mint metadata has invalid or zero supply/decimals, it now defaults to:
98
+ - `total_supply = 1000000000000000`
99
+ - `token_decimals = 6`
100
+
101
+ ## Current cache-builder interface
102
+
103
+ `cache_dataset.py` now uses a canonical context-cache path only.
104
+
105
+ Relevant args:
106
+ - `--output_dir`
107
+ - `--start_date`
108
+ - `--min_trade_usd`
109
+ - `--min_trades`
110
+ - `--context_length`
111
+ - `--samples_per_token`
112
+ - `--target_contexts_per_class`
113
+ - `--target_contexts_total`
114
+ - `--good_ratio_nonzero`
115
+ - `--good_ratio_class0`
116
+ - `--num_workers`
117
+ - DB connection args
118
+
119
+ Rules:
120
+ - choose exactly one of:
121
+ - `--target_contexts_per_class`
122
+ - `--target_contexts_total`
123
+ - if quota-driven caching is used, current implementation expects:
124
+ - `--num_workers 1`
125
+
126
+ `--samples_per_token` is still present.
127
+ - It is not a cache budget.
128
+ - It is a candidate-generation knob.
129
+ - It may still be removable later if a better attempt-based planner replaces it.
130
+
131
+ ## Important conceptual corrections from this chat
132
+
133
+ 1. Binary context type matters operationally, but a naive binary rule is dangerous.
134
+ - examples like `+1%` then `-20%` show why simplistic rules fail.
135
+ - context typing should come from realized multi-horizon behavior, not one crude shortcut.
136
+
137
+ 2. Patching alone does not force pattern learning.
138
+ - it only makes local pattern use possible.
139
+ - the model can still rely on trend shortcuts unless the representation and training setup make that harder.
140
+
141
+ 3. Support/resistance may be inferable from current chart inputs in principle.
142
+ - but current encoder likely compresses too early and learns an easier shortcut instead.
143
+
144
+ 4. For this question, the bottleneck is not just “what loss?”
145
+ - but also:
146
+ - chart input representation
147
+ - encoder compression
148
+ - whether the model can preserve local 1s structure
149
+
150
+ ## OHLC probe findings from this chat
151
+
152
+ The key probe results, repeated across runs, were:
153
+ - `ohlc_detrend` had the largest impact
154
+ - `ohlc_trend_only` had the second-largest impact
155
+ - `ohlc_smooth`, `shuffle`, `summary_shuffle`, `reverse`, and `mask_recent` had very small impact
156
+
157
+ Interpretation:
158
+ - the model is using OHLC mostly as:
159
+ - broad trend/regime context
160
+ - continuation-style directional signal
161
+ - not as:
162
+ - local chart pattern detector
163
+ - support/resistance-aware trader logic
164
+ - breakout/rejection/fair-value-gap style reasoning
165
+
166
+ This strongly suggests many of the model’s bearish/bad predictions are driven by trend continuation behavior from the chart branch.
167
+
168
+ ## What was explicitly rejected or corrected
169
+
170
+ - do not keep treating `cache_mode` as a real concept
171
+ - do not let `max_samples` and context-budget args overlap
172
+ - do not call something “auxiliary” if its weight can dominate optimization
173
+ - do not assume OHLC importance means chart-pattern understanding
174
+ - do not answer architecture questions by guessing from intention; inspect actual code
175
+
176
+ ## What still matters next
177
+
178
+ 1. verify token-overlap assertions are enforced in train/val, not just token-grouped split logic
179
+ 2. rebuild cache with the new quota-based builder and inspect actual distributions
180
+ 3. update cache audit tooling
181
+ 4. decide whether `samples_per_token` should be replaced by a better attempt planner
182
+ 5. decide how the chart branch should evolve if the goal is trader-like 1s pattern reasoning rather than trend summary
183
+
184
+ ## Short current state
185
+
186
+ - cache construction is now much closer to the intended design
187
+ - validation logic is less broken than before
188
+ - OHLC is confirmed important
189
+ - but OHLC is currently behaving more like a trend/summary branch than a technical-pattern branch
190
+ - the codebase is still primarily training on return regression, with extra heads layered on top
scripts/evaluate_sample.py CHANGED
@@ -49,6 +49,10 @@ OHLC_PROBE_MODES = [
49
  "ohlc_reverse",
50
  "ohlc_shuffle_chunks",
51
  "ohlc_mask_recent",
 
 
 
 
52
  ]
53
 
54
  def unlog_transform(tensor):
@@ -221,6 +225,56 @@ def _chunk_permutation_indices(length, chunk_size):
221
  return out
222
 
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  def apply_ohlc_probe(batch, mode):
225
  probed = clone_batch(batch)
226
  if "ohlc_price_tensors" not in probed or probed["ohlc_price_tensors"].numel() == 0:
@@ -243,6 +297,26 @@ def apply_ohlc_probe(batch, mode):
243
  elif keep == 0:
244
  ohlc.zero_()
245
  probed["ohlc_price_tensors"] = ohlc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
  return probed
248
 
@@ -683,7 +757,7 @@ def main():
683
  batch=batch,
684
  preds=full_preds,
685
  quality_pred=full_quality,
686
- direction_pred=full_direction,
687
  gt_labels=gt_labels,
688
  gt_mask=gt_mask,
689
  gt_quality=gt_quality,
@@ -701,7 +775,7 @@ def main():
701
  batch=probe_batch,
702
  preds=probe_preds,
703
  quality_pred=probe_quality,
704
- direction_pred=probe_direction,
705
  gt_labels=gt_labels,
706
  gt_mask=gt_mask,
707
  gt_quality=gt_quality,
@@ -717,7 +791,7 @@ def main():
717
  batch=ablated_batch,
718
  preds=ablated_preds,
719
  quality_pred=ablated_quality,
720
- direction_pred=ablated_direction,
721
  gt_labels=gt_labels,
722
  gt_mask=gt_mask,
723
  gt_quality=gt_quality,
 
49
  "ohlc_reverse",
50
  "ohlc_shuffle_chunks",
51
  "ohlc_mask_recent",
52
+ "ohlc_trend_only",
53
+ "ohlc_summary_shuffle",
54
+ "ohlc_detrend",
55
+ "ohlc_smooth",
56
  ]
57
 
58
  def unlog_transform(tensor):
 
225
  return out
226
 
227
 
228
+ def _moving_average_1d(series, kernel_size):
229
+ if kernel_size <= 1 or series.numel() == 0:
230
+ return series
231
+ pad = kernel_size // 2
232
+ kernel = torch.ones(1, 1, kernel_size, device=series.device, dtype=series.dtype) / float(kernel_size)
233
+ x = series.view(1, 1, -1)
234
+ x = torch.nn.functional.pad(x, (pad, pad), mode="replicate")
235
+ smoothed = torch.nn.functional.conv1d(x, kernel)
236
+ return smoothed.view(-1)[: series.numel()]
237
+
238
+
239
+ def _linear_trend(series):
240
+ if series.numel() <= 1:
241
+ return series.clone()
242
+ start = series[0]
243
+ end = series[-1]
244
+ steps = torch.linspace(0.0, 1.0, series.numel(), device=series.device, dtype=series.dtype)
245
+ return start + (end - start) * steps
246
+
247
+
248
+ def _summary_preserving_shuffle(series, chunk_size=20):
249
+ length = series.numel()
250
+ if length <= 2:
251
+ return series
252
+ chunks = []
253
+ interior_start = 1
254
+ interior_end = length - 1
255
+ for i in range(interior_start, interior_end, chunk_size):
256
+ chunks.append(series[i:min(i + chunk_size, interior_end)].clone())
257
+ if len(chunks) <= 1:
258
+ return series
259
+ reordered = list(reversed(chunks))
260
+ out = series.clone()
261
+ cursor = 1
262
+ for chunk in reordered:
263
+ out[cursor:cursor + chunk.numel()] = chunk
264
+ cursor += chunk.numel()
265
+ out[0] = series[0]
266
+ out[-1] = series[-1]
267
+ return out
268
+
269
+
270
+ def _apply_per_series(ohlc, transform_fn):
271
+ out = ohlc.clone()
272
+ for batch_idx in range(out.shape[0]):
273
+ for channel_idx in range(out.shape[1]):
274
+ out[batch_idx, channel_idx] = transform_fn(out[batch_idx, channel_idx])
275
+ return out
276
+
277
+
278
  def apply_ohlc_probe(batch, mode):
279
  probed = clone_batch(batch)
280
  if "ohlc_price_tensors" not in probed or probed["ohlc_price_tensors"].numel() == 0:
 
297
  elif keep == 0:
298
  ohlc.zero_()
299
  probed["ohlc_price_tensors"] = ohlc
300
+ elif mode == "ohlc_trend_only":
301
+ probed["ohlc_price_tensors"] = _apply_per_series(ohlc, _linear_trend)
302
+ elif mode == "ohlc_summary_shuffle":
303
+ probed["ohlc_price_tensors"] = _apply_per_series(
304
+ ohlc,
305
+ lambda series: _summary_preserving_shuffle(series, chunk_size=20),
306
+ )
307
+ elif mode == "ohlc_detrend":
308
+ def detrend(series):
309
+ trend = _linear_trend(series)
310
+ detrended = series - trend + series[0]
311
+ detrended[0] = series[0]
312
+ detrended[-1] = series[0]
313
+ return detrended
314
+ probed["ohlc_price_tensors"] = _apply_per_series(ohlc, detrend)
315
+ elif mode == "ohlc_smooth":
316
+ probed["ohlc_price_tensors"] = _apply_per_series(
317
+ ohlc,
318
+ lambda series: _moving_average_1d(series, kernel_size=11),
319
+ )
320
 
321
  return probed
322
 
 
757
  batch=batch,
758
  preds=full_preds,
759
  quality_pred=full_quality,
760
+ movement_pred=full_direction,
761
  gt_labels=gt_labels,
762
  gt_mask=gt_mask,
763
  gt_quality=gt_quality,
 
775
  batch=probe_batch,
776
  preds=probe_preds,
777
  quality_pred=probe_quality,
778
+ movement_pred=probe_direction,
779
  gt_labels=gt_labels,
780
  gt_mask=gt_mask,
781
  gt_quality=gt_quality,
 
791
  batch=ablated_batch,
792
  preds=ablated_preds,
793
  quality_pred=ablated_quality,
794
+ movement_pred=ablated_direction,
795
  gt_labels=gt_labels,
796
  gt_mask=gt_mask,
797
  gt_quality=gt_quality,
scripts/rebuild_metadata.py CHANGED
@@ -5,6 +5,7 @@ import json
5
  from pathlib import Path
6
  from tqdm import tqdm
7
  from collections import defaultdict
 
8
 
9
  def rebuild_metadata(cache_dir="data/cache"):
10
  cache_path = Path(cache_dir)
@@ -15,10 +16,13 @@ def rebuild_metadata(cache_dir="data/cache"):
15
  print("No .pt files found!")
16
  return
17
 
18
- print(f"Found {len(files)} files. Reading class IDs...")
19
 
20
  file_class_map = {}
 
 
21
  class_distribution = defaultdict(int)
 
22
 
23
  for f in tqdm(files):
24
  try:
@@ -26,14 +30,24 @@ def rebuild_metadata(cache_dir="data/cache"):
26
  # But torch.load loads everything. To be safe/fast, we just load on CPU.
27
  data = torch.load(f, map_location="cpu", weights_only=False)
28
  cid = data.get("class_id", 0)
 
29
  file_class_map[f.name] = cid
 
 
30
  class_distribution[cid] += 1
 
31
  except Exception as e:
32
  print(f"Error reading {f.name}: {e}")
33
 
34
  output_data = {
35
  'file_class_map': file_class_map,
 
 
36
  'class_distribution': {str(k): v for k, v in class_distribution.items()},
 
 
 
 
37
  # These are informational, setting defaults to avoid breaking if loader checks them
38
  'num_workers': 1,
39
  'horizons_seconds': [300, 900, 1800, 3600, 7200], # From user's pre_cache.sh
 
5
  from pathlib import Path
6
  from tqdm import tqdm
7
  from collections import defaultdict
8
+ from data.data_loader import summarize_context_window
9
 
10
  def rebuild_metadata(cache_dir="data/cache"):
11
  cache_path = Path(cache_dir)
 
16
  print("No .pt files found!")
17
  return
18
 
19
+ print(f"Found {len(files)} files. Reading class IDs and context summaries...")
20
 
21
  file_class_map = {}
22
+ file_context_bucket_map = {}
23
+ file_context_summary_map = {}
24
  class_distribution = defaultdict(int)
25
+ context_distribution = defaultdict(lambda: defaultdict(int))
26
 
27
  for f in tqdm(files):
28
  try:
 
30
  # But torch.load loads everything. To be safe/fast, we just load on CPU.
31
  data = torch.load(f, map_location="cpu", weights_only=False)
32
  cid = data.get("class_id", 0)
33
+ context_summary = summarize_context_window(data.get("labels"), data.get("labels_mask"))
34
  file_class_map[f.name] = cid
35
+ file_context_bucket_map[f.name] = context_summary["context_bucket"]
36
+ file_context_summary_map[f.name] = context_summary
37
  class_distribution[cid] += 1
38
+ context_distribution[cid][context_summary["context_bucket"]] += 1
39
  except Exception as e:
40
  print(f"Error reading {f.name}: {e}")
41
 
42
  output_data = {
43
  'file_class_map': file_class_map,
44
+ 'file_context_bucket_map': file_context_bucket_map,
45
+ 'file_context_summary_map': file_context_summary_map,
46
  'class_distribution': {str(k): v for k, v in class_distribution.items()},
47
+ 'context_distribution': {
48
+ str(k): {bucket: count for bucket, count in bucket_counts.items()}
49
+ for k, bucket_counts in context_distribution.items()
50
+ },
51
  # These are informational, setting defaults to avoid breaking if loader checks them
52
  'num_workers': 1,
53
  'horizons_seconds': [300, 900, 1800, 3600, 7200], # From user's pre_cache.sh
train.py CHANGED
@@ -234,54 +234,77 @@ def collator_like_targets(labels, labels_mask, movement_label_config: Optional[D
234
 
235
  def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
236
  """
237
- Create train/val split with balanced classes in validation set.
238
- Uses dataset's internal file_class_map for speed (no file loading).
239
- Returns (train_indices, val_indices, class_to_indices).
 
 
 
 
 
240
  """
241
  import random
242
- random.seed(seed)
243
 
244
- # Group indices by class_id - use dataset's existing map if available
245
  class_to_indices = defaultdict(list)
 
246
 
247
- # Fast path: use dataset's sample_labels (aligned with __getitem__)
248
- if hasattr(dataset, 'sample_labels') and dataset.sample_labels:
249
- for idx, class_id in enumerate(dataset.sample_labels):
250
- class_to_indices[class_id].append(idx)
251
- # Legacy path: use dataset's file_class_map (for 1-file-1-sample datasets)
252
- elif hasattr(dataset, 'file_class_map') and dataset.file_class_map:
253
- for idx, cached_file in enumerate(dataset.cached_files):
254
- # file_class_map uses filename strings as keys, cached_files are Path objects
255
- fname = cached_file.name if hasattr(cached_file, 'name') else str(cached_file)
256
- class_id = dataset.file_class_map.get(fname, 0)
257
- class_to_indices[class_id].append(idx)
258
- else:
259
- # Fallback: load from files (slow but works)
260
- logger.info("No file_class_map found, loading class IDs from files (this may take a while)...")
261
- import torch
262
- for idx in range(len(dataset.cached_files)):
263
- try:
264
- cached_item = torch.load(dataset.cached_files[idx], map_location="cpu", weights_only=False)
265
- class_id = cached_item.get("class_id", 0)
266
- class_to_indices[class_id].append(idx)
267
- except Exception:
268
- class_to_indices[0].append(idx)
269
 
270
  train_indices = []
271
  val_indices = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
- # For each class, take n_val_per_class samples for validation
274
- for class_id, indices in class_to_indices.items():
275
- random.shuffle(indices)
276
- n_val = min(len(indices), n_val_per_class) # Ensure we don't take more than we have
277
- val_indices.extend(indices[:n_val])
278
- train_indices.extend(indices[n_val:])
279
-
280
- # Shuffle both sets
281
- random.shuffle(train_indices)
282
- random.shuffle(val_indices)
283
 
284
- return train_indices, val_indices, class_to_indices
285
 
286
 
287
  def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_fn, vocab):
@@ -688,18 +711,27 @@ def main() -> None:
688
  raise RuntimeError("Dataset is empty.")
689
 
690
  # --- NEW: Create balanced train/val split ---
691
- logger.info(f"Creating balanced split with {args.val_samples_per_class} validation samples per class...")
692
- train_indices, val_indices, class_distribution = create_balanced_split(
 
 
693
  dataset, n_val_per_class=args.val_samples_per_class, seed=seed
694
  )
695
 
696
- # Log class distribution (use set for O(1) lookup)
697
- train_set = set(train_indices)
698
  logger.info(f"Total samples: {len(dataset)}, Train: {len(train_indices)}, Val: {len(val_indices)}")
699
  for class_id, indices in sorted(class_distribution.items()):
700
- n_val = min(len(indices), args.val_samples_per_class)
701
- n_train = len(indices) - n_val
702
- logger.info(f" Class {class_id}: {len(indices)} total (~{n_train} train, {n_val} val)")
 
 
 
 
 
 
 
 
 
703
 
704
  # --- Compute class weights for loss weighting ---
705
  num_classes = max(class_distribution.keys()) + 1 if class_distribution else 7
 
234
 
235
def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
    """
    Create a token-grouped train/val split.

    Validation is selected by token identity first, then balanced by class:
    for each class we hold out whole tokens, and every cached window that
    belongs to a held-out token goes to validation. This prevents the same
    token from appearing in both splits through different cached windows.

    Args:
        dataset: Cached dataset exposing a ``cached_files`` sequence of paths
            to torch-saved sample dicts carrying ``class_id`` and
            ``source_token`` (or legacy ``token_address``) metadata.
        n_val_per_class: Number of validation *tokens* (not samples) to hold
            out per class. If a class has fewer tokens, all of them go to
            validation (leaving that class with no training samples).
        seed: Seed for a private RNG so the split is reproducible.

    Returns:
        Tuple ``(train_indices, val_indices, class_to_indices, split_stats)``
        where ``class_to_indices`` maps class_id -> all sample indices of that
        class, and ``split_stats[class_id]`` reports per-split sample/token
        counts.

    Raises:
        RuntimeError: If the dataset has no ``cached_files``, a cached sample
            cannot be read, or a sample lacks token-identity metadata.
    """
    import random

    rng = random.Random(seed)
    class_to_indices = defaultdict(list)
    # class_id -> token -> [sample indices]: groups cached windows by source
    # token so a whole token moves to exactly one split.
    class_to_token_groups = defaultdict(lambda: defaultdict(list))

    if not hasattr(dataset, "cached_files"):
        raise RuntimeError("Token-grouped split requires a cached dataset with cached_files.")

    for idx, cached_file in enumerate(dataset.cached_files):
        class_id, source_token = _load_sample_identity(cached_file)
        class_to_indices[class_id].append(idx)
        class_to_token_groups[class_id][source_token].append(idx)

    train_indices = []
    val_indices = []
    split_stats = {}

    for class_id, token_groups in class_to_token_groups.items():
        token_items = list(token_groups.items())
        rng.shuffle(token_items)

        # Hold out whole tokens, capped at what the class actually has.
        n_val_tokens = min(len(token_items), n_val_per_class)
        val_token_items = token_items[:n_val_tokens]
        train_token_items = token_items[n_val_tokens:]

        class_val_indices = [i for _, indices in val_token_items for i in indices]
        class_train_indices = [i for _, indices in train_token_items for i in indices]

        val_indices.extend(class_val_indices)
        train_indices.extend(class_train_indices)
        split_stats[class_id] = {
            "total_samples": len(class_to_indices[class_id]),
            "total_tokens": len(token_items),
            "val_samples": len(class_val_indices),
            "val_tokens": len(val_token_items),
            "train_samples": len(class_train_indices),
            "train_tokens": len(train_token_items),
        }

    rng.shuffle(train_indices)
    rng.shuffle(val_indices)

    return train_indices, val_indices, class_to_indices, split_stats


def _load_sample_identity(cached_file):
    """Read ``(class_id, source_token)`` from one cached sample for split planning.

    Raises RuntimeError when the file cannot be read or when both
    ``source_token`` and ``token_address`` are missing, since a token-isolated
    split cannot be built without a stable token identity.
    """
    fname = cached_file.name if hasattr(cached_file, "name") else str(cached_file)
    try:
        # weights_only=False is required because cached items hold plain
        # metadata dicts, not just tensors. NOTE(review): this deserializes
        # arbitrary pickles — only load caches produced by this project.
        cached_item = torch.load(cached_file, map_location="cpu", weights_only=False)
    except Exception as exc:
        raise RuntimeError(f"Failed to read cached sample '{fname}' for split planning: {exc}") from exc

    class_id = int(cached_item.get("class_id", 0))
    source_token = cached_item.get("source_token") or cached_item.get("token_address")
    if not source_token:
        raise RuntimeError(
            f"Cached sample '{fname}' is missing both 'source_token' and 'token_address'; "
            "cannot build token-isolated validation split."
        )
    return class_id, str(source_token)
308
 
309
 
310
  def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_fn, vocab):
 
711
  raise RuntimeError("Dataset is empty.")
712
 
713
  # --- NEW: Create balanced train/val split ---
714
+ logger.info(
715
+ f"Creating token-grouped split with {args.val_samples_per_class} validation tokens per class..."
716
+ )
717
+ train_indices, val_indices, class_distribution, split_stats = create_balanced_split(
718
  dataset, n_val_per_class=args.val_samples_per_class, seed=seed
719
  )
720
 
 
 
721
  logger.info(f"Total samples: {len(dataset)}, Train: {len(train_indices)}, Val: {len(val_indices)}")
722
  for class_id, indices in sorted(class_distribution.items()):
723
+ stats = split_stats.get(class_id, {})
724
+ logger.info(
725
+ " Class %s: %s samples across %s tokens (~%s train samples / %s val samples, "
726
+ "%s train tokens / %s val tokens)",
727
+ class_id,
728
+ len(indices),
729
+ stats.get("total_tokens", 0),
730
+ stats.get("train_samples", 0),
731
+ stats.get("val_samples", 0),
732
+ stats.get("train_tokens", 0),
733
+ stats.get("val_tokens", 0),
734
+ )
735
 
736
  # --- Compute class weights for loss weighting ---
737
  num_classes = max(class_distribution.keys()) + 1 if class_distribution else 7