Upload folder using huggingface_hub
Browse files- TASK_LIST.md +87 -0
- cache_dataset.py +225 -58
- data/data_loader.py +154 -16
- log.log +2 -2
- resume.md +190 -0
- scripts/evaluate_sample.py +77 -3
- scripts/rebuild_metadata.py +15 -1
- train.py +77 -45
TASK_LIST.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Task List
|
| 2 |
+
|
| 3 |
+
Fix validation splitting by token identity
|
| 4 |
+
Replace class-only split logic in train.py (line 235)
|
| 5 |
+
Group cached context samples by source_token or token_address
|
| 6 |
+
Ensure one token can only exist in train or val, never both
|
| 7 |
+
Keep class balance as a secondary constraint, not the primary identity rule
|
| 8 |
+
Stop using current validation as decision-grade signal
|
| 9 |
+
Treat old val curves/checkpoints as contaminated
|
| 10 |
+
Re-evaluate only after token-grouped split is in place
|
| 11 |
+
Audit cache metadata and make token identity explicit
|
| 12 |
+
Ensure every cached sample has stable token identity fields
|
| 13 |
+
Required minimum: source_token, class_id
|
| 14 |
+
Prefer also storing lightweight cache-planning metadata for later analysis
|
| 15 |
+
Redesign cache generation around fixed budgets
|
| 16 |
+
Define total cache budget first
|
| 17 |
+
Allocate exact sample counts per token class before writing files
|
| 18 |
+
Do not let raw source distribution decide cache composition
|
| 19 |
+
Remove destructive dependence on token class map filtering alone
|
| 20 |
+
Token class should guide budget allocation
|
| 21 |
+
It should not be the only logic determining whether cache is useful
|
| 22 |
+
Add cache-time context-level balancing
|
| 23 |
+
After sampling a candidate context, evaluate realized future labels for that context
|
| 24 |
+
Use realized context outcome to decide whether to keep or skip it
|
| 25 |
+
Do this before saving to disk
|
| 26 |
+
Start with binary polarity, not movement-threshold balancing
|
| 27 |
+
Positive if max valid horizon return > 0
|
| 28 |
+
Negative otherwise
|
| 29 |
+
Use this only as cache-selection metadata first
|
| 30 |
+
Make polarity quotas class-conditional
|
| 31 |
+
For stronger classes, target positive/negative ratios
|
| 32 |
+
For garbage classes, do not force positives
|
| 33 |
+
Keep class 0 mostly natural/negative
|
| 34 |
+
Keep T_cutoff random during cache generation
|
| 35 |
+
Do not freeze a single deterministic cutoff per token
|
| 36 |
+
Determinism should be in the planning/budget logic, not in removing context diversity
|
| 37 |
+
Add exact acceptance accounting during cache build
|
| 38 |
+
Track how many samples have already been accepted per class
|
| 39 |
+
Track polarity counts per class
|
| 40 |
+
Stop accepting once quotas are filled
|
| 41 |
+
Avoid cache waste from duplicate low-value contexts
|
| 42 |
+
Add retry/attempt limits per token
|
| 43 |
+
If a token cannot satisfy desired quota type, stop oversampling it endlessly
|
| 44 |
+
Move on to other tokens instead of filling disk with junk
|
| 45 |
+
Keep label derivation in the data pipeline, not in training logic
|
| 46 |
+
Loader should produce final labels and masks
|
| 47 |
+
Collator should only stack/batch them
|
| 48 |
+
Model should only consume them
|
| 49 |
+
Reduce or remove train-time class reweighting after cache is fixed
|
| 50 |
+
Revisit WeightedRandomSampler
|
| 51 |
+
Revisit class_loss_weights
|
| 52 |
+
If cache is balanced upstream, training should not need heavy rescue weighting
|
| 53 |
+
Revisit movement head only after split and cache are fixed
|
| 54 |
+
Keep it auxiliary
|
| 55 |
+
Do not let movement-label threshold debates block the more important data fixes
|
| 56 |
+
Later simplify naming/threshold assumptions if needed
|
| 57 |
+
Add cache audit tooling
|
| 58 |
+
Report counts by class_id
|
| 59 |
+
Report counts by class x polarity
|
| 60 |
+
Report unique tokens by class
|
| 61 |
+
Report acceptance/rejection reasons
|
| 62 |
+
Report train/val token overlap check
|
| 63 |
+
Add validation integrity checks
|
| 64 |
+
Assert zero token overlap between train and val
|
| 65 |
+
Print per-class token counts, not just sample counts
|
| 66 |
+
Print per-class sample counts too
|
| 67 |
+
Rebuild cache after the new policy is implemented
|
| 68 |
+
Old cache is shaped by the wrong distribution
|
| 69 |
+
Old validation split is not trustworthy
|
| 70 |
+
New training should start from the rebuilt corpus
|
| 71 |
+
Retrain and re-baseline from scratch
|
| 72 |
+
New split
|
| 73 |
+
New cache
|
| 74 |
+
Minimal train-time rescue weighting
|
| 75 |
+
Recompare backbone behavior only after that
|
| 76 |
+
Recommended implementation order
|
| 77 |
+
|
| 78 |
+
Token-grouped validation split
|
| 79 |
+
Validation overlap checks
|
| 80 |
+
Cache metadata cleanup
|
| 81 |
+
Exact class quotas in cache generation
|
| 82 |
+
Class-conditional polarity quotas
|
| 83 |
+
Cache audit reports
|
| 84 |
+
Remove/reduce train-time weighting
|
| 85 |
+
Rebuild cache
|
| 86 |
+
Retrain
|
| 87 |
+
Reassess movement head
|
cache_dataset.py
CHANGED
|
@@ -23,6 +23,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| 23 |
|
| 24 |
from scripts.analyze_distribution import get_return_class_map
|
| 25 |
from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
|
|
|
|
| 26 |
|
| 27 |
from clickhouse_driver import Client as ClickHouseClient
|
| 28 |
from neo4j import GraphDatabase
|
|
@@ -32,6 +33,61 @@ _worker_return_class_map = None
|
|
| 32 |
_worker_quality_scores_map = None
|
| 33 |
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
|
| 36 |
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 37 |
from data.data_loader import OracleDataset
|
|
@@ -43,7 +99,7 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
|
|
| 43 |
|
| 44 |
_worker_dataset = OracleDataset(
|
| 45 |
data_fetcher=data_fetcher,
|
| 46 |
-
|
| 47 |
start_date=dataset_config['start_date'],
|
| 48 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 49 |
quantiles=dataset_config['quantiles'],
|
|
@@ -68,42 +124,15 @@ def _process_single_token_context(args):
|
|
| 68 |
q_score = _worker_quality_scores_map.get(mint_addr)
|
| 69 |
if q_score is None:
|
| 70 |
return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
saved_files.append(filename)
|
| 81 |
-
return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'files': saved_files}
|
| 82 |
-
except Exception as e:
|
| 83 |
-
import traceback
|
| 84 |
-
return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def _process_single_token_raw(args):
|
| 88 |
-
idx, mint_addr, output_dir = args
|
| 89 |
-
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 90 |
-
try:
|
| 91 |
-
class_id = _worker_return_class_map.get(mint_addr)
|
| 92 |
-
if class_id is None:
|
| 93 |
-
return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
|
| 94 |
-
item = _worker_dataset.__cacheitem__(idx)
|
| 95 |
-
if item is None:
|
| 96 |
-
return {'status': 'skipped', 'reason': 'cacheitem returned None', 'mint': mint_addr}
|
| 97 |
-
q_score = _worker_quality_scores_map.get(mint_addr)
|
| 98 |
-
if q_score is None:
|
| 99 |
-
return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
|
| 100 |
-
item["quality_score"] = q_score
|
| 101 |
-
item["class_id"] = class_id
|
| 102 |
-
item["cache_mode"] = "raw"
|
| 103 |
-
filename = f"sample_{mint_addr[:16]}.pt"
|
| 104 |
-
output_path = Path(output_dir) / filename
|
| 105 |
-
torch.save(item, output_path)
|
| 106 |
-
return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_trades': len(item.get('trades', [])), 'files': [filename]}
|
| 107 |
except Exception as e:
|
| 108 |
import traceback
|
| 109 |
return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
|
|
@@ -122,14 +151,16 @@ def main():
|
|
| 122 |
|
| 123 |
parser = argparse.ArgumentParser()
|
| 124 |
parser.add_argument("--output_dir", type=str, default="data/cache")
|
| 125 |
-
parser.add_argument("--max_samples", type=int, default=None)
|
| 126 |
parser.add_argument("--start_date", type=str, default=None)
|
| 127 |
|
| 128 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
| 129 |
-
parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"])
|
| 130 |
-
parser.add_argument("--context_length", type=int, default=8192)
|
| 131 |
parser.add_argument("--min_trades", type=int, default=10)
|
|
|
|
| 132 |
parser.add_argument("--samples_per_token", type=int, default=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
parser.add_argument("--num_workers", type=int, default=1)
|
| 134 |
parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
|
| 135 |
parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
|
|
@@ -138,6 +169,11 @@ def main():
|
|
| 138 |
parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
|
| 139 |
args = parser.parse_args()
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
if args.num_workers == 0:
|
| 142 |
args.num_workers = max(1, mp.cpu_count() - 4)
|
| 143 |
|
|
@@ -163,7 +199,15 @@ def main():
|
|
| 163 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 164 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 165 |
|
| 166 |
-
dataset = OracleDataset(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
if len(dataset) == 0:
|
| 169 |
print("WARNING: No samples. Exiting.")
|
|
@@ -178,25 +222,52 @@ def main():
|
|
| 178 |
print("WARNING: No tokens after filtering.")
|
| 179 |
return
|
| 180 |
|
| 181 |
-
print(f"INFO:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 184 |
-
dataset_config = {'
|
| 185 |
|
| 186 |
# Build tasks from filtered_mints directly
|
| 187 |
tasks = []
|
| 188 |
for i, mint_record in enumerate(filtered_mints):
|
| 189 |
mint_addr = mint_record['mint_address']
|
| 190 |
-
|
| 191 |
-
tasks.append((i, mint_addr, args.samples_per_token, str(output_dir)))
|
| 192 |
-
else:
|
| 193 |
-
tasks.append((i, mint_addr, str(output_dir)))
|
| 194 |
|
| 195 |
print(f"INFO: Starting to cache {len(tasks)} tokens...")
|
| 196 |
|
| 197 |
success_count, skipped_count, error_count = 0, 0, 0
|
| 198 |
class_distribution = {}
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
if args.num_workers == 1:
|
| 202 |
print("INFO: Single-threaded mode...")
|
|
@@ -204,8 +275,61 @@ def main():
|
|
| 204 |
for task in tqdm(tasks, desc="Caching"):
|
| 205 |
result = process_fn(task)
|
| 206 |
if result['status'] == 'success':
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
elif result['status'] == 'skipped':
|
| 210 |
skipped_count += 1
|
| 211 |
else:
|
|
@@ -219,8 +343,26 @@ def main():
|
|
| 219 |
try:
|
| 220 |
result = future.result(timeout=300)
|
| 221 |
if result['status'] == 'success':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
success_count += 1
|
| 223 |
-
class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
|
| 224 |
elif result['status'] == 'skipped':
|
| 225 |
skipped_count += 1
|
| 226 |
else:
|
|
@@ -230,15 +372,40 @@ def main():
|
|
| 230 |
tqdm.write(f"WORKER ERROR: {e}")
|
| 231 |
|
| 232 |
print("INFO: Building metadata...")
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
with open(output_dir / "class_metadata.json", 'w') as f:
|
| 241 |
-
json.dump({
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
|
| 244 |
|
|
|
|
| 23 |
|
| 24 |
from scripts.analyze_distribution import get_return_class_map
|
| 25 |
from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
|
| 26 |
+
from data.data_loader import summarize_context_window
|
| 27 |
|
| 28 |
from clickhouse_driver import Client as ClickHouseClient
|
| 29 |
from neo4j import GraphDatabase
|
|
|
|
| 33 |
_worker_quality_scores_map = None
|
| 34 |
|
| 35 |
|
| 36 |
+
def _build_context_quota_plan(
|
| 37 |
+
class_ids,
|
| 38 |
+
target_contexts_per_class,
|
| 39 |
+
target_contexts_total,
|
| 40 |
+
good_ratio_nonzero,
|
| 41 |
+
good_ratio_class0,
|
| 42 |
+
):
|
| 43 |
+
unique_class_ids = sorted(set(int(cid) for cid in class_ids))
|
| 44 |
+
if not unique_class_ids:
|
| 45 |
+
return {}
|
| 46 |
+
|
| 47 |
+
if target_contexts_per_class is not None:
|
| 48 |
+
per_class_target = int(target_contexts_per_class)
|
| 49 |
+
elif target_contexts_total is not None:
|
| 50 |
+
per_class_target = max(1, int(target_contexts_total) // len(unique_class_ids))
|
| 51 |
+
else:
|
| 52 |
+
return {}
|
| 53 |
+
|
| 54 |
+
if per_class_target <= 0:
|
| 55 |
+
raise RuntimeError("Context quota target must be positive.")
|
| 56 |
+
|
| 57 |
+
plan = {}
|
| 58 |
+
for class_id in unique_class_ids:
|
| 59 |
+
ratio = float(good_ratio_class0 if class_id == 0 else good_ratio_nonzero)
|
| 60 |
+
ratio = max(0.0, min(1.0, ratio))
|
| 61 |
+
good_target = int(round(per_class_target * ratio))
|
| 62 |
+
bad_target = per_class_target - good_target
|
| 63 |
+
plan[class_id] = {
|
| 64 |
+
"total_target": per_class_target,
|
| 65 |
+
"good_target": good_target,
|
| 66 |
+
"bad_target": bad_target,
|
| 67 |
+
}
|
| 68 |
+
return plan
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _should_accept_context(class_id, context_bucket, accepted_counts, quota_plan):
|
| 72 |
+
if not quota_plan:
|
| 73 |
+
return True
|
| 74 |
+
|
| 75 |
+
if class_id not in quota_plan:
|
| 76 |
+
return False
|
| 77 |
+
|
| 78 |
+
class_plan = quota_plan[class_id]
|
| 79 |
+
class_counts = accepted_counts[class_id]
|
| 80 |
+
if class_counts["total"] >= class_plan["total_target"]:
|
| 81 |
+
return False
|
| 82 |
+
|
| 83 |
+
bucket_key = "good" if context_bucket == "good" else "bad"
|
| 84 |
+
target_key = f"{bucket_key}_target"
|
| 85 |
+
if class_counts[bucket_key] >= class_plan[target_key]:
|
| 86 |
+
return False
|
| 87 |
+
|
| 88 |
+
return True
|
| 89 |
+
|
| 90 |
+
|
| 91 |
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
|
| 92 |
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 93 |
from data.data_loader import OracleDataset
|
|
|
|
| 99 |
|
| 100 |
_worker_dataset = OracleDataset(
|
| 101 |
data_fetcher=data_fetcher,
|
| 102 |
+
min_trades=dataset_config['min_trades'],
|
| 103 |
start_date=dataset_config['start_date'],
|
| 104 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 105 |
quantiles=dataset_config['quantiles'],
|
|
|
|
| 124 |
q_score = _worker_quality_scores_map.get(mint_addr)
|
| 125 |
if q_score is None:
|
| 126 |
return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
|
| 127 |
+
return {
|
| 128 |
+
'status': 'success',
|
| 129 |
+
'mint': mint_addr,
|
| 130 |
+
'class_id': class_id,
|
| 131 |
+
'q_score': q_score,
|
| 132 |
+
'n_contexts': len(contexts),
|
| 133 |
+
'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
|
| 134 |
+
'contexts': contexts,
|
| 135 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
except Exception as e:
|
| 137 |
import traceback
|
| 138 |
return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
|
|
|
|
| 151 |
|
| 152 |
parser = argparse.ArgumentParser()
|
| 153 |
parser.add_argument("--output_dir", type=str, default="data/cache")
|
|
|
|
| 154 |
parser.add_argument("--start_date", type=str, default=None)
|
| 155 |
|
| 156 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
|
|
|
|
|
|
| 157 |
parser.add_argument("--min_trades", type=int, default=10)
|
| 158 |
+
parser.add_argument("--context_length", type=int, default=8192)
|
| 159 |
parser.add_argument("--samples_per_token", type=int, default=1)
|
| 160 |
+
parser.add_argument("--target_contexts_per_class", type=int, default=None)
|
| 161 |
+
parser.add_argument("--target_contexts_total", type=int, default=None)
|
| 162 |
+
parser.add_argument("--good_ratio_nonzero", type=float, default=0.5)
|
| 163 |
+
parser.add_argument("--good_ratio_class0", type=float, default=0.0)
|
| 164 |
parser.add_argument("--num_workers", type=int, default=1)
|
| 165 |
parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
|
| 166 |
parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
|
|
|
|
| 169 |
parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
|
| 170 |
args = parser.parse_args()
|
| 171 |
|
| 172 |
+
if args.target_contexts_per_class is not None and args.target_contexts_total is not None:
|
| 173 |
+
raise RuntimeError(
|
| 174 |
+
"Choose exactly one cache budget: either --target_contexts_per_class or --target_contexts_total."
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
if args.num_workers == 0:
|
| 178 |
args.num_workers = max(1, mp.cpu_count() - 4)
|
| 179 |
|
|
|
|
| 199 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 200 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 201 |
|
| 202 |
+
dataset = OracleDataset(
|
| 203 |
+
data_fetcher=data_fetcher,
|
| 204 |
+
min_trades=args.min_trades,
|
| 205 |
+
start_date=start_date_dt,
|
| 206 |
+
horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200],
|
| 207 |
+
quantiles=[0.5],
|
| 208 |
+
min_trade_usd=args.min_trade_usd,
|
| 209 |
+
max_seq_len=args.context_length,
|
| 210 |
+
)
|
| 211 |
|
| 212 |
if len(dataset) == 0:
|
| 213 |
print("WARNING: No samples. Exiting.")
|
|
|
|
| 222 |
print("WARNING: No tokens after filtering.")
|
| 223 |
return
|
| 224 |
|
| 225 |
+
print(f"INFO: Building canonical context cache | Workers: {args.num_workers}")
|
| 226 |
+
|
| 227 |
+
if args.num_workers != 1 and (
|
| 228 |
+
args.target_contexts_per_class is not None or args.target_contexts_total is not None
|
| 229 |
+
):
|
| 230 |
+
raise RuntimeError(
|
| 231 |
+
"Quota-driven context caching currently requires --num_workers 1 so accepted contexts "
|
| 232 |
+
"can be planned and written deterministically in one process."
|
| 233 |
+
)
|
| 234 |
|
| 235 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 236 |
+
dataset_config = {'start_date': start_date_dt, 'min_trades': args.min_trades, 'horizons_seconds': [60, 180, 300, 600, 1800, 3600, 7200], 'quantiles': [0.5], 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
|
| 237 |
|
| 238 |
# Build tasks from filtered_mints directly
|
| 239 |
tasks = []
|
| 240 |
for i, mint_record in enumerate(filtered_mints):
|
| 241 |
mint_addr = mint_record['mint_address']
|
| 242 |
+
tasks.append((i, mint_addr, args.samples_per_token, str(output_dir)))
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
print(f"INFO: Starting to cache {len(tasks)} tokens...")
|
| 245 |
|
| 246 |
success_count, skipped_count, error_count = 0, 0, 0
|
| 247 |
class_distribution = {}
|
| 248 |
+
context_distribution = defaultdict(lambda: defaultdict(int))
|
| 249 |
+
file_class_map = {}
|
| 250 |
+
file_context_bucket_map = {}
|
| 251 |
+
file_context_summary_map = {}
|
| 252 |
+
process_fn = _process_single_token_context
|
| 253 |
+
quota_plan = {}
|
| 254 |
+
accepted_counts = defaultdict(lambda: {"total": 0, "good": 0, "bad": 0})
|
| 255 |
+
accepted_per_token = defaultdict(int)
|
| 256 |
+
|
| 257 |
+
quota_plan = _build_context_quota_plan(
|
| 258 |
+
class_ids=[return_class_map[m['mint_address']] for m in filtered_mints if m['mint_address'] in return_class_map],
|
| 259 |
+
target_contexts_per_class=args.target_contexts_per_class,
|
| 260 |
+
target_contexts_total=args.target_contexts_total,
|
| 261 |
+
good_ratio_nonzero=args.good_ratio_nonzero,
|
| 262 |
+
good_ratio_class0=args.good_ratio_class0,
|
| 263 |
+
)
|
| 264 |
+
if quota_plan:
|
| 265 |
+
print("INFO: Context quota plan:")
|
| 266 |
+
for class_id, plan in sorted(quota_plan.items()):
|
| 267 |
+
print(
|
| 268 |
+
f" Class {class_id}: total={plan['total_target']} "
|
| 269 |
+
f"(good={plan['good_target']}, bad={plan['bad_target']})"
|
| 270 |
+
)
|
| 271 |
|
| 272 |
if args.num_workers == 1:
|
| 273 |
print("INFO: Single-threaded mode...")
|
|
|
|
| 275 |
for task in tqdm(tasks, desc="Caching"):
|
| 276 |
result = process_fn(task)
|
| 277 |
if result['status'] == 'success':
|
| 278 |
+
if quota_plan:
|
| 279 |
+
class_id = result['class_id']
|
| 280 |
+
mint_addr = result['mint']
|
| 281 |
+
q_score = result['q_score']
|
| 282 |
+
saved_any = False
|
| 283 |
+
for ctx in result.get("contexts", []):
|
| 284 |
+
context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
|
| 285 |
+
context_bucket = context_summary["context_bucket"]
|
| 286 |
+
if not _should_accept_context(class_id, context_bucket, accepted_counts, quota_plan):
|
| 287 |
+
continue
|
| 288 |
+
|
| 289 |
+
ctx["quality_score"] = q_score
|
| 290 |
+
ctx["class_id"] = class_id
|
| 291 |
+
ctx["source_token"] = mint_addr
|
| 292 |
+
ctx["context_bucket"] = context_bucket
|
| 293 |
+
ctx["context_score"] = context_summary["context_score"]
|
| 294 |
+
|
| 295 |
+
file_idx = accepted_per_token[mint_addr]
|
| 296 |
+
filename = f"sample_{mint_addr[:16]}_{file_idx}.pt"
|
| 297 |
+
output_path = Path(output_dir) / filename
|
| 298 |
+
torch.save(ctx, output_path)
|
| 299 |
+
|
| 300 |
+
accepted_per_token[mint_addr] += 1
|
| 301 |
+
accepted_counts[class_id]["total"] += 1
|
| 302 |
+
accepted_counts[class_id][context_bucket] += 1
|
| 303 |
+
class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
|
| 304 |
+
context_distribution[class_id][context_bucket] += 1
|
| 305 |
+
file_class_map[filename] = class_id
|
| 306 |
+
file_context_bucket_map[filename] = context_bucket
|
| 307 |
+
file_context_summary_map[filename] = context_summary
|
| 308 |
+
saved_any = True
|
| 309 |
+
|
| 310 |
+
if saved_any:
|
| 311 |
+
success_count += 1
|
| 312 |
+
else:
|
| 313 |
+
class_id = result['class_id']
|
| 314 |
+
mint_addr = result['mint']
|
| 315 |
+
q_score = result['q_score']
|
| 316 |
+
for ctx_idx, ctx in enumerate(result.get("contexts", [])):
|
| 317 |
+
context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
|
| 318 |
+
context_bucket = context_summary["context_bucket"]
|
| 319 |
+
ctx["quality_score"] = q_score
|
| 320 |
+
ctx["class_id"] = class_id
|
| 321 |
+
ctx["source_token"] = mint_addr
|
| 322 |
+
ctx["context_bucket"] = context_bucket
|
| 323 |
+
ctx["context_score"] = context_summary["context_score"]
|
| 324 |
+
filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
|
| 325 |
+
output_path = Path(output_dir) / filename
|
| 326 |
+
torch.save(ctx, output_path)
|
| 327 |
+
file_class_map[filename] = class_id
|
| 328 |
+
file_context_bucket_map[filename] = context_bucket
|
| 329 |
+
file_context_summary_map[filename] = context_summary
|
| 330 |
+
class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
|
| 331 |
+
context_distribution[class_id][context_bucket] += 1
|
| 332 |
+
success_count += 1
|
| 333 |
elif result['status'] == 'skipped':
|
| 334 |
skipped_count += 1
|
| 335 |
else:
|
|
|
|
| 343 |
try:
|
| 344 |
result = future.result(timeout=300)
|
| 345 |
if result['status'] == 'success':
|
| 346 |
+
class_id = result['class_id']
|
| 347 |
+
mint_addr = result['mint']
|
| 348 |
+
q_score = result['q_score']
|
| 349 |
+
for ctx_idx, ctx in enumerate(result.get("contexts", [])):
|
| 350 |
+
context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
|
| 351 |
+
context_bucket = context_summary["context_bucket"]
|
| 352 |
+
ctx["quality_score"] = q_score
|
| 353 |
+
ctx["class_id"] = class_id
|
| 354 |
+
ctx["source_token"] = mint_addr
|
| 355 |
+
ctx["context_bucket"] = context_bucket
|
| 356 |
+
ctx["context_score"] = context_summary["context_score"]
|
| 357 |
+
filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
|
| 358 |
+
output_path = Path(output_dir) / filename
|
| 359 |
+
torch.save(ctx, output_path)
|
| 360 |
+
file_class_map[filename] = class_id
|
| 361 |
+
file_context_bucket_map[filename] = context_bucket
|
| 362 |
+
file_context_summary_map[filename] = context_summary
|
| 363 |
+
class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
|
| 364 |
+
context_distribution[class_id][context_bucket] += 1
|
| 365 |
success_count += 1
|
|
|
|
| 366 |
elif result['status'] == 'skipped':
|
| 367 |
skipped_count += 1
|
| 368 |
else:
|
|
|
|
| 372 |
tqdm.write(f"WORKER ERROR: {e}")
|
| 373 |
|
| 374 |
print("INFO: Building metadata...")
|
| 375 |
+
if not file_class_map:
|
| 376 |
+
for f in sorted(output_dir.glob("sample_*.pt")):
|
| 377 |
+
try:
|
| 378 |
+
cached = torch.load(f, map_location="cpu", weights_only=False)
|
| 379 |
+
file_class_map[f.name] = cached.get("class_id", 0)
|
| 380 |
+
if "labels" in cached and "labels_mask" in cached:
|
| 381 |
+
context_summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask"))
|
| 382 |
+
file_context_bucket_map[f.name] = context_summary["context_bucket"]
|
| 383 |
+
file_context_summary_map[f.name] = context_summary
|
| 384 |
+
except Exception:
|
| 385 |
+
pass
|
| 386 |
|
| 387 |
with open(output_dir / "class_metadata.json", 'w') as f:
|
| 388 |
+
json.dump({
|
| 389 |
+
'file_class_map': file_class_map,
|
| 390 |
+
'file_context_bucket_map': file_context_bucket_map,
|
| 391 |
+
'file_context_summary_map': file_context_summary_map,
|
| 392 |
+
'class_distribution': {str(k): v for k, v in class_distribution.items()},
|
| 393 |
+
'context_distribution': {
|
| 394 |
+
str(k): {bucket: count for bucket, count in bucket_counts.items()}
|
| 395 |
+
for k, bucket_counts in context_distribution.items()
|
| 396 |
+
},
|
| 397 |
+
'quota_plan': {str(k): v for k, v in quota_plan.items()},
|
| 398 |
+
'accepted_counts': {str(k): v for k, v in accepted_counts.items()},
|
| 399 |
+
'num_workers': args.num_workers,
|
| 400 |
+
}, f, indent=2)
|
| 401 |
+
|
| 402 |
+
if quota_plan:
|
| 403 |
+
print("INFO: Accepted context counts:")
|
| 404 |
+
for class_id, counts in sorted(accepted_counts.items()):
|
| 405 |
+
print(
|
| 406 |
+
f" Class {class_id}: total={counts['total']} "
|
| 407 |
+
f"good={counts['good']} bad={counts['bad']}"
|
| 408 |
+
)
|
| 409 |
|
| 410 |
print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
|
| 411 |
|
data/data_loader.py
CHANGED
|
@@ -64,6 +64,63 @@ MIN_AMOUNT_TRANSFER_SUPPLY = 0.0 # 1.0% of total supply
|
|
| 64 |
HOLDER_SNAPSHOT_INTERVAL_SEC = 300
|
| 65 |
HOLDER_SNAPSHOT_TOP_K = 200
|
| 66 |
DEAD_URI_RETRY_LIMIT = 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
class EmbeddingPooler:
|
|
@@ -123,6 +180,7 @@ class OracleDataset(Dataset):
|
|
| 123 |
horizons_seconds: List[int] = [],
|
| 124 |
quantiles: List[float] = [],
|
| 125 |
max_samples: Optional[int] = None,
|
|
|
|
| 126 |
|
| 127 |
token_allowlist: Optional[List[str]] = None,
|
| 128 |
cache_dir: Optional[Union[str, Path]] = None,
|
|
@@ -131,8 +189,11 @@ class OracleDataset(Dataset):
|
|
| 131 |
max_seq_len: int = 8192,
|
| 132 |
p99_clamps: Optional[Dict[str, float]] = None,
|
| 133 |
movement_label_config: Optional[Dict[str, float]] = None):
|
| 134 |
-
|
| 135 |
self.max_seq_len = max_seq_len
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
# --- P99 data-driven clamp values (replace hardcoded min/max) ---
|
| 138 |
self.p99_clamps = {
|
|
@@ -187,9 +248,12 @@ class OracleDataset(Dataset):
|
|
| 187 |
if not self.cached_files:
|
| 188 |
raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
|
| 189 |
|
| 190 |
-
# --- OPTIMIZED: Load
|
| 191 |
file_class_map = {}
|
|
|
|
|
|
|
| 192 |
class_counts = defaultdict(int)
|
|
|
|
| 193 |
metadata_path = self.cache_dir / "class_metadata.json"
|
| 194 |
|
| 195 |
if metadata_path.exists():
|
|
@@ -199,20 +263,29 @@ class OracleDataset(Dataset):
|
|
| 199 |
with open(metadata_path, 'r') as f:
|
| 200 |
cached_metadata = json.load(f)
|
| 201 |
file_class_map = cached_metadata.get('file_class_map', {})
|
|
|
|
|
|
|
| 202 |
# Validate that cached files match metadata
|
| 203 |
cached_file_names = {p.name for p in self.cached_files}
|
| 204 |
metadata_file_names = set(file_class_map.keys())
|
| 205 |
if cached_file_names != metadata_file_names:
|
| 206 |
print(f"WARN: Metadata cache mismatch ({len(cached_file_names)} files vs {len(metadata_file_names)} in metadata). Rebuilding...")
|
| 207 |
file_class_map = {}
|
|
|
|
|
|
|
| 208 |
else:
|
| 209 |
# Rebuild class_counts from loaded map
|
| 210 |
-
for cid in file_class_map.
|
| 211 |
class_counts[cid] += 1
|
|
|
|
|
|
|
|
|
|
| 212 |
print(f"INFO: Loaded metadata for {len(file_class_map)} samples in <1s")
|
| 213 |
except Exception as e:
|
| 214 |
print(f"WARN: Failed to load metadata cache: {e}. Rebuilding...")
|
| 215 |
file_class_map = {}
|
|
|
|
|
|
|
| 216 |
|
| 217 |
# Slow path: scan all files and build metadata cache
|
| 218 |
if not file_class_map:
|
|
@@ -229,23 +302,41 @@ class OracleDataset(Dataset):
|
|
| 229 |
if cid is None:
|
| 230 |
print(f"WARN: File {p.name} missing class_id. Skipping.")
|
| 231 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
file_class_map[p.name] = cid
|
|
|
|
|
|
|
| 233 |
class_counts[cid] += 1
|
|
|
|
| 234 |
except Exception as e:
|
| 235 |
print(f"WARN: Failed to read cached sample {p.name}: {e}")
|
| 236 |
|
| 237 |
# Save metadata cache for future runs
|
| 238 |
try:
|
| 239 |
with open(metadata_path, 'w') as f:
|
| 240 |
-
json.dump({
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
print(f"INFO: Saved class metadata cache to {metadata_path}")
|
| 242 |
except Exception as e:
|
| 243 |
print(f"WARN: Failed to save metadata cache: {e}")
|
| 244 |
|
| 245 |
print(f"INFO: Class Distribution: {dict(class_counts)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
# Store file_class_map for fast lookup by train.py's create_balanced_split
|
| 248 |
self.file_class_map = {p: cid for p, cid in file_class_map.items()}
|
|
|
|
|
|
|
| 249 |
|
| 250 |
# Compute Weights
|
| 251 |
self.weights_list = []
|
|
@@ -260,8 +351,24 @@ class OracleDataset(Dataset):
|
|
| 260 |
continue
|
| 261 |
|
| 262 |
cid = file_class_map[fname]
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
self.weights_list.append(weight)
|
| 266 |
valid_files.append(p)
|
| 267 |
|
|
@@ -273,6 +380,25 @@ class OracleDataset(Dataset):
|
|
| 273 |
self.cached_files = self.cached_files[:self.num_samples]
|
| 274 |
self.weights_list = self.weights_list[:self.num_samples]
|
| 275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
print(f"INFO: Weighted Dataset Ready. {self.num_samples} samples.")
|
| 277 |
self.sampled_mints = [] # Not needed in cached mode
|
| 278 |
self.available_mints = []
|
|
@@ -1297,7 +1423,7 @@ class OracleDataset(Dataset):
|
|
| 1297 |
key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
|
| 1298 |
)
|
| 1299 |
|
| 1300 |
-
min_context_trades =
|
| 1301 |
if len(all_trades_sorted) < (min_context_trades + 1): # context + 1 trade after cutoff
|
| 1302 |
return None
|
| 1303 |
|
|
@@ -1626,7 +1752,7 @@ class OracleDataset(Dataset):
|
|
| 1626 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 1627 |
include_wallet_data=False,
|
| 1628 |
include_graph=False,
|
| 1629 |
-
min_trades=
|
| 1630 |
full_history=True, # Bypass H/B/H limits
|
| 1631 |
prune_failed=False, # Keep failed trades for realistic simulation
|
| 1632 |
prune_transfers=False # Keep transfers for snapshot reconstruction
|
|
@@ -1641,8 +1767,14 @@ class OracleDataset(Dataset):
|
|
| 1641 |
raw_data['name'] = initial_mint_record.get('token_name', '')
|
| 1642 |
raw_data['symbol'] = initial_mint_record.get('token_symbol', '')
|
| 1643 |
raw_data['token_uri'] = initial_mint_record.get('token_uri', '')
|
| 1644 |
-
|
| 1645 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1646 |
raw_data['protocol'] = initial_mint_record.get('protocol', 1)
|
| 1647 |
|
| 1648 |
def _timestamp_to_order_value(ts_value: Any) -> float:
|
|
@@ -2686,7 +2818,7 @@ class OracleDataset(Dataset):
|
|
| 2686 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 2687 |
include_wallet_data=False,
|
| 2688 |
include_graph=False,
|
| 2689 |
-
min_trades=
|
| 2690 |
full_history=True,
|
| 2691 |
prune_failed=False,
|
| 2692 |
prune_transfers=False
|
|
@@ -2704,8 +2836,14 @@ class OracleDataset(Dataset):
|
|
| 2704 |
raw_data['name'] = initial_mint_record.get('token_name', '')
|
| 2705 |
raw_data['symbol'] = initial_mint_record.get('token_symbol', '')
|
| 2706 |
raw_data['token_uri'] = initial_mint_record.get('token_uri', '')
|
| 2707 |
-
|
| 2708 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2709 |
raw_data['protocol'] = initial_mint_record.get('protocol', 1)
|
| 2710 |
|
| 2711 |
def _timestamp_to_order_value(ts_value) -> float:
|
|
@@ -2734,7 +2872,7 @@ class OracleDataset(Dataset):
|
|
| 2734 |
key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
|
| 2735 |
)
|
| 2736 |
|
| 2737 |
-
min_context_trades =
|
| 2738 |
if len(all_trades_sorted) < (min_context_trades + 1):
|
| 2739 |
print(f" SKIP: Not enough trades ({len(all_trades_sorted)}) for {token_address}")
|
| 2740 |
return []
|
|
@@ -2821,9 +2959,9 @@ class OracleDataset(Dataset):
|
|
| 2821 |
total_supply_raw = int(raw_total_supply)
|
| 2822 |
token_decimals = int(raw_decimals)
|
| 2823 |
if total_supply_raw <= 0:
|
| 2824 |
-
|
| 2825 |
if token_decimals < 0:
|
| 2826 |
-
|
| 2827 |
token_scale = 10 ** token_decimals
|
| 2828 |
|
| 2829 |
def _strict_int(v: Any, field_name: str) -> int:
|
|
|
|
| 64 |
HOLDER_SNAPSHOT_INTERVAL_SEC = 300
HOLDER_SNAPSHOT_TOP_K = 200
DEAD_URI_RETRY_LIMIT = 2
DEFAULT_TOTAL_SUPPLY_RAW = 1_000_000_000_000_000
DEFAULT_TOKEN_DECIMALS = 6

CONTEXT_BUCKET_NEGATIVE = "bad"
CONTEXT_BUCKET_POSITIVE = "good"


def summarize_context_window(
    labels: Any,
    labels_mask: Any,
) -> Dict[str, Any]:
    """
    Summarize a realized context window using its valid future returns.

    Each masked-in horizon contributes the signed, log1p-compressed
    terminal PnL of buying at the cutoff.  The window is tagged
    ``good`` only when the mean signed contribution is strictly
    positive; ties and windows with no valid horizons fall through
    to ``bad``.

    Raises:
        RuntimeError: if either ``labels`` or ``labels_mask`` is None.
    """
    if labels is None or labels_mask is None:
        raise RuntimeError("Context weighting requires both 'labels' and 'labels_mask'.")

    def _as_list(values):
        # Accept torch tensors or any plain iterable of numbers.
        return values.tolist() if isinstance(values, torch.Tensor) else list(values)

    label_vals = _as_list(labels)
    mask_vals = _as_list(labels_mask)

    # Keep only horizons whose mask is "on".
    valid_returns = [
        float(ret)
        for ret, keep in zip(label_vals, mask_vals)
        if float(keep) > 0.0
    ]

    # Compress magnitude with log1p while preserving the sign of the raw return.
    signed_contributions = [
        np.log1p(abs(ret)) if ret > 0.0 else -np.log1p(abs(ret))
        for ret in valid_returns
    ]

    positive_count = sum(1 for ret in valid_returns if ret > 0.0)

    if signed_contributions:
        context_score = float(sum(signed_contributions) / len(signed_contributions))
    else:
        context_score = 0.0

    bucket = CONTEXT_BUCKET_POSITIVE if context_score > 0.0 else CONTEXT_BUCKET_NEGATIVE
    return {
        "context_bucket": bucket,
        "context_score": context_score,
        "positive_horizons": positive_count,
        "negative_horizons": len(valid_returns) - positive_count,
        "valid_horizons": len(valid_returns),
    }
|
| 124 |
|
| 125 |
|
| 126 |
class EmbeddingPooler:
|
|
|
|
| 180 |
horizons_seconds: List[int] = [],
|
| 181 |
quantiles: List[float] = [],
|
| 182 |
max_samples: Optional[int] = None,
|
| 183 |
+
min_trades: int = 10,
|
| 184 |
|
| 185 |
token_allowlist: Optional[List[str]] = None,
|
| 186 |
cache_dir: Optional[Union[str, Path]] = None,
|
|
|
|
| 189 |
max_seq_len: int = 8192,
|
| 190 |
p99_clamps: Optional[Dict[str, float]] = None,
|
| 191 |
movement_label_config: Optional[Dict[str, float]] = None):
|
| 192 |
+
|
| 193 |
self.max_seq_len = max_seq_len
|
| 194 |
+
self.min_trades = int(min_trades)
|
| 195 |
+
if self.min_trades < 1:
|
| 196 |
+
raise RuntimeError(f"min_trades must be >= 1, got {self.min_trades}")
|
| 197 |
|
| 198 |
# --- P99 data-driven clamp values (replace hardcoded min/max) ---
|
| 199 |
self.p99_clamps = {
|
|
|
|
| 248 |
if not self.cached_files:
|
| 249 |
raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
|
| 250 |
|
| 251 |
+
# --- OPTIMIZED: Load cached metadata if available ---
|
| 252 |
file_class_map = {}
|
| 253 |
+
file_context_bucket_map = {}
|
| 254 |
+
file_context_summary_map = {}
|
| 255 |
class_counts = defaultdict(int)
|
| 256 |
+
class_context_counts = defaultdict(lambda: defaultdict(int))
|
| 257 |
metadata_path = self.cache_dir / "class_metadata.json"
|
| 258 |
|
| 259 |
if metadata_path.exists():
|
|
|
|
| 263 |
with open(metadata_path, 'r') as f:
|
| 264 |
cached_metadata = json.load(f)
|
| 265 |
file_class_map = cached_metadata.get('file_class_map', {})
|
| 266 |
+
file_context_bucket_map = cached_metadata.get('file_context_bucket_map', {})
|
| 267 |
+
file_context_summary_map = cached_metadata.get('file_context_summary_map', {})
|
| 268 |
# Validate that cached files match metadata
|
| 269 |
cached_file_names = {p.name for p in self.cached_files}
|
| 270 |
metadata_file_names = set(file_class_map.keys())
|
| 271 |
if cached_file_names != metadata_file_names:
|
| 272 |
print(f"WARN: Metadata cache mismatch ({len(cached_file_names)} files vs {len(metadata_file_names)} in metadata). Rebuilding...")
|
| 273 |
file_class_map = {}
|
| 274 |
+
file_context_bucket_map = {}
|
| 275 |
+
file_context_summary_map = {}
|
| 276 |
else:
|
| 277 |
# Rebuild class_counts from loaded map
|
| 278 |
+
for fname, cid in file_class_map.items():
|
| 279 |
class_counts[cid] += 1
|
| 280 |
+
bucket = file_context_bucket_map.get(fname)
|
| 281 |
+
if bucket is not None:
|
| 282 |
+
class_context_counts[cid][bucket] += 1
|
| 283 |
print(f"INFO: Loaded metadata for {len(file_class_map)} samples in <1s")
|
| 284 |
except Exception as e:
|
| 285 |
print(f"WARN: Failed to load metadata cache: {e}. Rebuilding...")
|
| 286 |
file_class_map = {}
|
| 287 |
+
file_context_bucket_map = {}
|
| 288 |
+
file_context_summary_map = {}
|
| 289 |
|
| 290 |
# Slow path: scan all files and build metadata cache
|
| 291 |
if not file_class_map:
|
|
|
|
| 302 |
if cid is None:
|
| 303 |
print(f"WARN: File {p.name} missing class_id. Skipping.")
|
| 304 |
continue
|
| 305 |
+
context_summary = summarize_context_window(
|
| 306 |
+
cached_item.get("labels"),
|
| 307 |
+
cached_item.get("labels_mask"),
|
| 308 |
+
)
|
| 309 |
+
bucket = context_summary["context_bucket"]
|
| 310 |
file_class_map[p.name] = cid
|
| 311 |
+
file_context_bucket_map[p.name] = bucket
|
| 312 |
+
file_context_summary_map[p.name] = context_summary
|
| 313 |
class_counts[cid] += 1
|
| 314 |
+
class_context_counts[cid][bucket] += 1
|
| 315 |
except Exception as e:
|
| 316 |
print(f"WARN: Failed to read cached sample {p.name}: {e}")
|
| 317 |
|
| 318 |
# Save metadata cache for future runs
|
| 319 |
try:
|
| 320 |
with open(metadata_path, 'w') as f:
|
| 321 |
+
json.dump({
|
| 322 |
+
'file_class_map': file_class_map,
|
| 323 |
+
'file_context_bucket_map': file_context_bucket_map,
|
| 324 |
+
'file_context_summary_map': file_context_summary_map,
|
| 325 |
+
}, f)
|
| 326 |
print(f"INFO: Saved class metadata cache to {metadata_path}")
|
| 327 |
except Exception as e:
|
| 328 |
print(f"WARN: Failed to save metadata cache: {e}")
|
| 329 |
|
| 330 |
print(f"INFO: Class Distribution: {dict(class_counts)}")
|
| 331 |
+
print(
|
| 332 |
+
"INFO: Context Distribution by Class: "
|
| 333 |
+
f"{ {cid: dict(bucket_counts) for cid, bucket_counts in class_context_counts.items()} }"
|
| 334 |
+
)
|
| 335 |
|
| 336 |
# Store file_class_map for fast lookup by train.py's create_balanced_split
|
| 337 |
self.file_class_map = {p: cid for p, cid in file_class_map.items()}
|
| 338 |
+
self.file_context_bucket_map = {p: bucket for p, bucket in file_context_bucket_map.items()}
|
| 339 |
+
self.file_context_summary_map = {p: summary for p, summary in file_context_summary_map.items()}
|
| 340 |
|
| 341 |
# Compute Weights
|
| 342 |
self.weights_list = []
|
|
|
|
| 351 |
continue
|
| 352 |
|
| 353 |
cid = file_class_map[fname]
|
| 354 |
+
bucket = file_context_bucket_map.get(fname)
|
| 355 |
+
if bucket is None:
|
| 356 |
+
raise RuntimeError(
|
| 357 |
+
f"Cached sample '{fname}' is missing a context bucket. "
|
| 358 |
+
"Rebuild metadata or cache before training."
|
| 359 |
+
)
|
| 360 |
+
class_bucket_counts = class_context_counts[cid]
|
| 361 |
+
present_buckets = [name for name, cnt in class_bucket_counts.items() if cnt > 0]
|
| 362 |
+
if not present_buckets:
|
| 363 |
+
raise RuntimeError(
|
| 364 |
+
f"Class {cid} has no valid context buckets recorded. Cannot compute sampler weights."
|
| 365 |
+
)
|
| 366 |
+
bucket_count = class_bucket_counts[bucket]
|
| 367 |
+
if bucket_count <= 0:
|
| 368 |
+
raise RuntimeError(
|
| 369 |
+
f"Class {cid} bucket '{bucket}' has invalid count {bucket_count} for sample '{fname}'."
|
| 370 |
+
)
|
| 371 |
+
weight = 1.0 / (len(present_buckets) * bucket_count)
|
| 372 |
self.weights_list.append(weight)
|
| 373 |
valid_files.append(p)
|
| 374 |
|
|
|
|
| 380 |
self.cached_files = self.cached_files[:self.num_samples]
|
| 381 |
self.weights_list = self.weights_list[:self.num_samples]
|
| 382 |
|
| 383 |
+
# Recompute sampler weights against the active cached file subset so the
|
| 384 |
+
# class/context balancing reflects the actual dataset seen by training.
|
| 385 |
+
active_class_context_counts = defaultdict(lambda: defaultdict(int))
|
| 386 |
+
for p in self.cached_files:
|
| 387 |
+
fname = p.name
|
| 388 |
+
cid = file_class_map[fname]
|
| 389 |
+
bucket = file_context_bucket_map[fname]
|
| 390 |
+
active_class_context_counts[cid][bucket] += 1
|
| 391 |
+
|
| 392 |
+
self.weights_list = []
|
| 393 |
+
for p in self.cached_files:
|
| 394 |
+
fname = p.name
|
| 395 |
+
cid = file_class_map[fname]
|
| 396 |
+
bucket = file_context_bucket_map[fname]
|
| 397 |
+
class_bucket_counts = active_class_context_counts[cid]
|
| 398 |
+
present_buckets = [name for name, cnt in class_bucket_counts.items() if cnt > 0]
|
| 399 |
+
bucket_count = class_bucket_counts[bucket]
|
| 400 |
+
self.weights_list.append(1.0 / (len(present_buckets) * bucket_count))
|
| 401 |
+
|
| 402 |
print(f"INFO: Weighted Dataset Ready. {self.num_samples} samples.")
|
| 403 |
self.sampled_mints = [] # Not needed in cached mode
|
| 404 |
self.available_mints = []
|
|
|
|
| 1423 |
key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
|
| 1424 |
)
|
| 1425 |
|
| 1426 |
+
min_context_trades = self.min_trades
|
| 1427 |
if len(all_trades_sorted) < (min_context_trades + 1): # context + 1 trade after cutoff
|
| 1428 |
return None
|
| 1429 |
|
|
|
|
| 1752 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 1753 |
include_wallet_data=False,
|
| 1754 |
include_graph=False,
|
| 1755 |
+
min_trades=self.min_trades,
|
| 1756 |
full_history=True, # Bypass H/B/H limits
|
| 1757 |
prune_failed=False, # Keep failed trades for realistic simulation
|
| 1758 |
prune_transfers=False # Keep transfers for snapshot reconstruction
|
|
|
|
| 1767 |
raw_data['name'] = initial_mint_record.get('token_name', '')
|
| 1768 |
raw_data['symbol'] = initial_mint_record.get('token_symbol', '')
|
| 1769 |
raw_data['token_uri'] = initial_mint_record.get('token_uri', '')
|
| 1770 |
+
raw_total_supply = initial_mint_record.get('total_supply', DEFAULT_TOTAL_SUPPLY_RAW)
|
| 1771 |
+
raw_token_decimals = initial_mint_record.get('token_decimals', DEFAULT_TOKEN_DECIMALS)
|
| 1772 |
+
raw_data['total_supply'] = (
|
| 1773 |
+
int(raw_total_supply) if raw_total_supply and int(raw_total_supply) > 0 else DEFAULT_TOTAL_SUPPLY_RAW
|
| 1774 |
+
)
|
| 1775 |
+
raw_data['decimals'] = (
|
| 1776 |
+
int(raw_token_decimals) if raw_token_decimals is not None and int(raw_token_decimals) >= 0 else DEFAULT_TOKEN_DECIMALS
|
| 1777 |
+
)
|
| 1778 |
raw_data['protocol'] = initial_mint_record.get('protocol', 1)
|
| 1779 |
|
| 1780 |
def _timestamp_to_order_value(ts_value: Any) -> float:
|
|
|
|
| 2818 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 2819 |
include_wallet_data=False,
|
| 2820 |
include_graph=False,
|
| 2821 |
+
min_trades=self.min_trades,
|
| 2822 |
full_history=True,
|
| 2823 |
prune_failed=False,
|
| 2824 |
prune_transfers=False
|
|
|
|
| 2836 |
raw_data['name'] = initial_mint_record.get('token_name', '')
|
| 2837 |
raw_data['symbol'] = initial_mint_record.get('token_symbol', '')
|
| 2838 |
raw_data['token_uri'] = initial_mint_record.get('token_uri', '')
|
| 2839 |
+
raw_total_supply = initial_mint_record.get('total_supply', DEFAULT_TOTAL_SUPPLY_RAW)
|
| 2840 |
+
raw_token_decimals = initial_mint_record.get('token_decimals', DEFAULT_TOKEN_DECIMALS)
|
| 2841 |
+
raw_data['total_supply'] = (
|
| 2842 |
+
int(raw_total_supply) if raw_total_supply and int(raw_total_supply) > 0 else DEFAULT_TOTAL_SUPPLY_RAW
|
| 2843 |
+
)
|
| 2844 |
+
raw_data['decimals'] = (
|
| 2845 |
+
int(raw_token_decimals) if raw_token_decimals is not None and int(raw_token_decimals) >= 0 else DEFAULT_TOKEN_DECIMALS
|
| 2846 |
+
)
|
| 2847 |
raw_data['protocol'] = initial_mint_record.get('protocol', 1)
|
| 2848 |
|
| 2849 |
def _timestamp_to_order_value(ts_value) -> float:
|
|
|
|
| 2872 |
key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
|
| 2873 |
)
|
| 2874 |
|
| 2875 |
+
min_context_trades = self.min_trades
|
| 2876 |
if len(all_trades_sorted) < (min_context_trades + 1):
|
| 2877 |
print(f" SKIP: Not enough trades ({len(all_trades_sorted)}) for {token_address}")
|
| 2878 |
return []
|
|
|
|
| 2959 |
total_supply_raw = int(raw_total_supply)
|
| 2960 |
token_decimals = int(raw_decimals)
|
| 2961 |
if total_supply_raw <= 0:
|
| 2962 |
+
total_supply_raw = DEFAULT_TOTAL_SUPPLY_RAW
|
| 2963 |
if token_decimals < 0:
|
| 2964 |
+
token_decimals = DEFAULT_TOKEN_DECIMALS
|
| 2965 |
token_scale = 10 ** token_decimals
|
| 2966 |
|
| 2967 |
def _strict_int(v: Any, field_name: str) -> int:
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:df78e2b44dd97a148be762f91e3b00f397651f8e7e43ee21f938492291fdfa3a
|
| 3 |
+
size 83447
|
resume.md
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Resume
|
| 2 |
+
|
| 3 |
+
## Main conclusions from this chat
|
| 4 |
+
|
| 5 |
+
1. The main issue is data/sample construction, not just checkpoint choice.
|
| 6 |
+
- `class_id` is token-level.
|
| 7 |
+
- labels are context-level and depend on sampled `T_cutoff`.
|
| 8 |
+
- balanced token classes do not imply balanced future outcomes.
|
| 9 |
+
- the model can easily learn an over-negative prior if the cache is built naively.
|
| 10 |
+
|
| 11 |
+
2. Cache balancing must happen at cache generation time.
|
| 12 |
+
- train-time weighting is too late to fix disk waste or missing context diversity.
|
| 13 |
+
- the cache should control:
|
| 14 |
+
- class balance
|
| 15 |
+
- good/bad context balance within class
|
| 16 |
+
|
| 17 |
+
3. Validation needed token isolation, not just class balancing.
|
| 18 |
+
- same token appearing in train and val through different contexts makes validation misleading.
|
| 19 |
+
|
| 20 |
+
4. Movement-threshold circularity is not the main blocker anymore.
|
| 21 |
+
- movement labels are downstream of realized returns.
|
| 22 |
+
- movement thresholds should not control cache construction.
|
| 23 |
+
|
| 24 |
+
5. OHLC is important, but current usage looks like trend/regime summary, not true pattern detection.
|
| 25 |
+
- the chart branch is being used.
|
| 26 |
+
- but probe tests suggest it is mostly acting as continuation/trend context.
|
| 27 |
+
- not as breakout / support-resistance / head-and-shoulders intelligence.
|
| 28 |
+
|
| 29 |
+
6. The current code is still regression-first.
|
| 30 |
+
- main target is multi-horizon return prediction.
|
| 31 |
+
- there is also:
|
| 32 |
+
- quality head
|
| 33 |
+
- movement head
|
| 34 |
+
- calling movement auxiliary only matters if its loss contribution is actually secondary.
|
| 35 |
+
|
| 36 |
+
## What was implemented in code
|
| 37 |
+
|
| 38 |
+
### 1. Validation split
|
| 39 |
+
- `train.py`
|
| 40 |
+
- validation split was changed to group by token identity (`source_token` / `token_address`) instead of only `class_id`.
|
| 41 |
+
|
| 42 |
+
### 2. Task tracker
|
| 43 |
+
- `TASK_LIST.md`
|
| 44 |
+
- created as the running checklist for this work.
|
| 45 |
+
|
| 46 |
+
### 3. Context weighting signal
|
| 47 |
+
- `data/data_loader.py`
|
| 48 |
+
- added a context-quality summary derived from realized multi-horizon returns.
|
| 49 |
+
- current code computes:
|
| 50 |
+
- `context_score`
|
| 51 |
+
- `context_bucket` (`good` / `bad`)
|
| 52 |
+
- this is used for weighting and cache metadata.
|
| 53 |
+
|
| 54 |
+
### 4. Cached dataset weighting
|
| 55 |
+
- `data/data_loader.py`
|
| 56 |
+
- sampler weights now account for:
|
| 57 |
+
- class balance
|
| 58 |
+
- good/bad context balance inside class
|
| 59 |
+
|
| 60 |
+
### 5. Weighted cache generation
|
| 61 |
+
- `cache_dataset.py`
|
| 62 |
+
- canonical builder now writes context cache only.
|
| 63 |
+
- no `cache_mode`.
|
| 64 |
+
- no `max_samples`.
|
| 65 |
+
- cache budget is context-based:
|
| 66 |
+
- `--target_contexts_total`
|
| 67 |
+
- or `--target_contexts_per_class`
|
| 68 |
+
- quota-driven acceptance is implemented before save:
|
| 69 |
+
- level 1: class quotas
|
| 70 |
+
- level 2: good/bad quotas within class
|
| 71 |
+
|
| 72 |
+
### 6. Metadata rebuild
|
| 73 |
+
- `scripts/rebuild_metadata.py`
|
| 74 |
+
- now rebuilds:
|
| 75 |
+
- `file_class_map`
|
| 76 |
+
- `file_context_bucket_map`
|
| 77 |
+
- `file_context_summary_map`
|
| 78 |
+
|
| 79 |
+
### 7. `min_trades`
|
| 80 |
+
- re-added properly
|
| 81 |
+
- no longer a dead CLI arg
|
| 82 |
+
- now actually controls dataset/context eligibility thresholds
|
| 83 |
+
|
| 84 |
+
### 8. Evaluate script OHLC probes
|
| 85 |
+
- `scripts/evaluate_sample.py`
|
| 86 |
+
- added these OHLC probe modes:
|
| 87 |
+
- `ohlc_reverse`
|
| 88 |
+
- `ohlc_shuffle_chunks`
|
| 89 |
+
- `ohlc_mask_recent`
|
| 90 |
+
- `ohlc_trend_only`
|
| 91 |
+
- `ohlc_summary_shuffle`
|
| 92 |
+
- `ohlc_detrend`
|
| 93 |
+
- `ohlc_smooth`
|
| 94 |
+
|
| 95 |
+
### 9. Bad mint metadata fallback
|
| 96 |
+
- `data/data_loader.py`
|
| 97 |
+
- if mint metadata has invalid or zero supply/decimals, it now defaults to:
|
| 98 |
+
- `total_supply = 1000000000000000`
|
| 99 |
+
- `token_decimals = 6`
|
| 100 |
+
|
| 101 |
+
## Current cache-builder interface
|
| 102 |
+
|
| 103 |
+
`cache_dataset.py` now uses a canonical context-cache path only.
|
| 104 |
+
|
| 105 |
+
Relevant args:
|
| 106 |
+
- `--output_dir`
|
| 107 |
+
- `--start_date`
|
| 108 |
+
- `--min_trade_usd`
|
| 109 |
+
- `--min_trades`
|
| 110 |
+
- `--context_length`
|
| 111 |
+
- `--samples_per_token`
|
| 112 |
+
- `--target_contexts_per_class`
|
| 113 |
+
- `--target_contexts_total`
|
| 114 |
+
- `--good_ratio_nonzero`
|
| 115 |
+
- `--good_ratio_class0`
|
| 116 |
+
- `--num_workers`
|
| 117 |
+
- DB connection args
|
| 118 |
+
|
| 119 |
+
Rules:
|
| 120 |
+
- choose exactly one of:
|
| 121 |
+
- `--target_contexts_per_class`
|
| 122 |
+
- `--target_contexts_total`
|
| 123 |
+
- if quota-driven caching is used, current implementation expects:
|
| 124 |
+
- `--num_workers 1`
|
| 125 |
+
|
| 126 |
+
`--samples_per_token` is still present.
|
| 127 |
+
- It is not a cache budget.
|
| 128 |
+
- It is a candidate-generation knob.
|
| 129 |
+
- It may still be removable later if a better attempt-based planner replaces it.
|
| 130 |
+
|
| 131 |
+
## Important conceptual corrections from this chat
|
| 132 |
+
|
| 133 |
+
1. Binary context type matters operationally, but a naive binary rule is dangerous.
|
| 134 |
+
- examples like `+1%` then `-20%` show why simplistic rules fail.
|
| 135 |
+
- context typing should come from realized multi-horizon behavior, not one crude shortcut.
|
| 136 |
+
|
| 137 |
+
2. Patching alone does not force pattern learning.
|
| 138 |
+
- it only makes local pattern use possible.
|
| 139 |
+
- the model can still rely on trend shortcuts unless the representation and training setup make that harder.
|
| 140 |
+
|
| 141 |
+
3. Support/resistance may be inferable from current chart inputs in principle.
|
| 142 |
+
- but current encoder likely compresses too early and learns an easier shortcut instead.
|
| 143 |
+
|
| 144 |
+
4. For this question, the bottleneck is not just “what loss?”
|
| 145 |
+
- but also:
|
| 146 |
+
- chart input representation
|
| 147 |
+
- encoder compression
|
| 148 |
+
- whether the model can preserve local 1s structure
|
| 149 |
+
|
| 150 |
+
## OHLC probe findings from this chat
|
| 151 |
+
|
| 152 |
+
The key probe result repeated across runs was:
|
| 153 |
+
- `ohlc_detrend` had the largest impact
|
| 154 |
+
- `ohlc_trend_only` had the second-largest impact
|
| 155 |
+
- `ohlc_smooth`, `shuffle`, `summary_shuffle`, `reverse`, and `mask_recent` had very small impact
|
| 156 |
+
|
| 157 |
+
Interpretation:
|
| 158 |
+
- the model is using OHLC mostly as:
|
| 159 |
+
- broad trend/regime context
|
| 160 |
+
- continuation-style directional signal
|
| 161 |
+
- not as:
|
| 162 |
+
- local chart pattern detector
|
| 163 |
+
- support/resistance-aware trader logic
|
| 164 |
+
- breakout/rejection/fair-value-gap style reasoning
|
| 165 |
+
|
| 166 |
+
This strongly suggests many of the model’s bearish/bad predictions are driven by trend continuation behavior from the chart branch.
|
| 167 |
+
|
| 168 |
+
## What was explicitly rejected or corrected
|
| 169 |
+
|
| 170 |
+
- do not keep treating `cache_mode` as a real concept
|
| 171 |
+
- do not let `max_samples` and context-budget args overlap
|
| 172 |
+
- do not call something “auxiliary” if its weight can dominate optimization
|
| 173 |
+
- do not assume OHLC importance means chart-pattern understanding
|
| 174 |
+
- do not answer architecture questions by guessing from intention; inspect actual code
|
| 175 |
+
|
| 176 |
+
## What still matters next
|
| 177 |
+
|
| 178 |
+
1. verify token-overlap assertions are enforced in train/val, not just token-grouped split logic
|
| 179 |
+
2. rebuild cache with the new quota-based builder and inspect actual distributions
|
| 180 |
+
3. update cache audit tooling
|
| 181 |
+
4. decide whether `samples_per_token` should be replaced by a better attempt planner
|
| 182 |
+
5. decide how the chart branch should evolve if the goal is trader-like 1s pattern reasoning rather than trend summary
|
| 183 |
+
|
| 184 |
+
## Short current state
|
| 185 |
+
|
| 186 |
+
- cache construction is much closer to the right direction now
|
| 187 |
+
validation logic is improved: the split is now grouped by token identity, though token-overlap assertions between train and val still need to be verified
|
| 188 |
+
- OHLC is confirmed important
|
| 189 |
+
- but OHLC is currently behaving more like a trend/summary branch than a technical-pattern branch
|
| 190 |
+
- the codebase is still primarily training on return regression, with extra heads layered on top
|
scripts/evaluate_sample.py
CHANGED
|
@@ -49,6 +49,10 @@ OHLC_PROBE_MODES = [
|
|
| 49 |
"ohlc_reverse",
|
| 50 |
"ohlc_shuffle_chunks",
|
| 51 |
"ohlc_mask_recent",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
]
|
| 53 |
|
| 54 |
def unlog_transform(tensor):
|
|
@@ -221,6 +225,56 @@ def _chunk_permutation_indices(length, chunk_size):
|
|
| 221 |
return out
|
| 222 |
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
def apply_ohlc_probe(batch, mode):
|
| 225 |
probed = clone_batch(batch)
|
| 226 |
if "ohlc_price_tensors" not in probed or probed["ohlc_price_tensors"].numel() == 0:
|
|
@@ -243,6 +297,26 @@ def apply_ohlc_probe(batch, mode):
|
|
| 243 |
elif keep == 0:
|
| 244 |
ohlc.zero_()
|
| 245 |
probed["ohlc_price_tensors"] = ohlc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
return probed
|
| 248 |
|
|
@@ -683,7 +757,7 @@ def main():
|
|
| 683 |
batch=batch,
|
| 684 |
preds=full_preds,
|
| 685 |
quality_pred=full_quality,
|
| 686 |
-
|
| 687 |
gt_labels=gt_labels,
|
| 688 |
gt_mask=gt_mask,
|
| 689 |
gt_quality=gt_quality,
|
|
@@ -701,7 +775,7 @@ def main():
|
|
| 701 |
batch=probe_batch,
|
| 702 |
preds=probe_preds,
|
| 703 |
quality_pred=probe_quality,
|
| 704 |
-
|
| 705 |
gt_labels=gt_labels,
|
| 706 |
gt_mask=gt_mask,
|
| 707 |
gt_quality=gt_quality,
|
|
@@ -717,7 +791,7 @@ def main():
|
|
| 717 |
batch=ablated_batch,
|
| 718 |
preds=ablated_preds,
|
| 719 |
quality_pred=ablated_quality,
|
| 720 |
-
|
| 721 |
gt_labels=gt_labels,
|
| 722 |
gt_mask=gt_mask,
|
| 723 |
gt_quality=gt_quality,
|
|
|
|
| 49 |
"ohlc_reverse",
|
| 50 |
"ohlc_shuffle_chunks",
|
| 51 |
"ohlc_mask_recent",
|
| 52 |
+
"ohlc_trend_only",
|
| 53 |
+
"ohlc_summary_shuffle",
|
| 54 |
+
"ohlc_detrend",
|
| 55 |
+
"ohlc_smooth",
|
| 56 |
]
|
| 57 |
|
| 58 |
def unlog_transform(tensor):
|
|
|
|
| 225 |
return out
|
| 226 |
|
| 227 |
|
| 228 |
+
def _moving_average_1d(series, kernel_size):
|
| 229 |
+
if kernel_size <= 1 or series.numel() == 0:
|
| 230 |
+
return series
|
| 231 |
+
pad = kernel_size // 2
|
| 232 |
+
kernel = torch.ones(1, 1, kernel_size, device=series.device, dtype=series.dtype) / float(kernel_size)
|
| 233 |
+
x = series.view(1, 1, -1)
|
| 234 |
+
x = torch.nn.functional.pad(x, (pad, pad), mode="replicate")
|
| 235 |
+
smoothed = torch.nn.functional.conv1d(x, kernel)
|
| 236 |
+
return smoothed.view(-1)[: series.numel()]
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _linear_trend(series):
|
| 240 |
+
if series.numel() <= 1:
|
| 241 |
+
return series.clone()
|
| 242 |
+
start = series[0]
|
| 243 |
+
end = series[-1]
|
| 244 |
+
steps = torch.linspace(0.0, 1.0, series.numel(), device=series.device, dtype=series.dtype)
|
| 245 |
+
return start + (end - start) * steps
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _summary_preserving_shuffle(series, chunk_size=20):
|
| 249 |
+
length = series.numel()
|
| 250 |
+
if length <= 2:
|
| 251 |
+
return series
|
| 252 |
+
chunks = []
|
| 253 |
+
interior_start = 1
|
| 254 |
+
interior_end = length - 1
|
| 255 |
+
for i in range(interior_start, interior_end, chunk_size):
|
| 256 |
+
chunks.append(series[i:min(i + chunk_size, interior_end)].clone())
|
| 257 |
+
if len(chunks) <= 1:
|
| 258 |
+
return series
|
| 259 |
+
reordered = list(reversed(chunks))
|
| 260 |
+
out = series.clone()
|
| 261 |
+
cursor = 1
|
| 262 |
+
for chunk in reordered:
|
| 263 |
+
out[cursor:cursor + chunk.numel()] = chunk
|
| 264 |
+
cursor += chunk.numel()
|
| 265 |
+
out[0] = series[0]
|
| 266 |
+
out[-1] = series[-1]
|
| 267 |
+
return out
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _apply_per_series(ohlc, transform_fn):
|
| 271 |
+
out = ohlc.clone()
|
| 272 |
+
for batch_idx in range(out.shape[0]):
|
| 273 |
+
for channel_idx in range(out.shape[1]):
|
| 274 |
+
out[batch_idx, channel_idx] = transform_fn(out[batch_idx, channel_idx])
|
| 275 |
+
return out
|
| 276 |
+
|
| 277 |
+
|
| 278 |
def apply_ohlc_probe(batch, mode):
|
| 279 |
probed = clone_batch(batch)
|
| 280 |
if "ohlc_price_tensors" not in probed or probed["ohlc_price_tensors"].numel() == 0:
|
|
|
|
| 297 |
elif keep == 0:
|
| 298 |
ohlc.zero_()
|
| 299 |
probed["ohlc_price_tensors"] = ohlc
|
| 300 |
+
elif mode == "ohlc_trend_only":
|
| 301 |
+
probed["ohlc_price_tensors"] = _apply_per_series(ohlc, _linear_trend)
|
| 302 |
+
elif mode == "ohlc_summary_shuffle":
|
| 303 |
+
probed["ohlc_price_tensors"] = _apply_per_series(
|
| 304 |
+
ohlc,
|
| 305 |
+
lambda series: _summary_preserving_shuffle(series, chunk_size=20),
|
| 306 |
+
)
|
| 307 |
+
elif mode == "ohlc_detrend":
|
| 308 |
+
def detrend(series):
|
| 309 |
+
trend = _linear_trend(series)
|
| 310 |
+
detrended = series - trend + series[0]
|
| 311 |
+
detrended[0] = series[0]
|
| 312 |
+
detrended[-1] = series[0]
|
| 313 |
+
return detrended
|
| 314 |
+
probed["ohlc_price_tensors"] = _apply_per_series(ohlc, detrend)
|
| 315 |
+
elif mode == "ohlc_smooth":
|
| 316 |
+
probed["ohlc_price_tensors"] = _apply_per_series(
|
| 317 |
+
ohlc,
|
| 318 |
+
lambda series: _moving_average_1d(series, kernel_size=11),
|
| 319 |
+
)
|
| 320 |
|
| 321 |
return probed
|
| 322 |
|
|
|
|
| 757 |
batch=batch,
|
| 758 |
preds=full_preds,
|
| 759 |
quality_pred=full_quality,
|
| 760 |
+
movement_pred=full_direction,
|
| 761 |
gt_labels=gt_labels,
|
| 762 |
gt_mask=gt_mask,
|
| 763 |
gt_quality=gt_quality,
|
|
|
|
| 775 |
batch=probe_batch,
|
| 776 |
preds=probe_preds,
|
| 777 |
quality_pred=probe_quality,
|
| 778 |
+
movement_pred=probe_direction,
|
| 779 |
gt_labels=gt_labels,
|
| 780 |
gt_mask=gt_mask,
|
| 781 |
gt_quality=gt_quality,
|
|
|
|
| 791 |
batch=ablated_batch,
|
| 792 |
preds=ablated_preds,
|
| 793 |
quality_pred=ablated_quality,
|
| 794 |
+
movement_pred=ablated_direction,
|
| 795 |
gt_labels=gt_labels,
|
| 796 |
gt_mask=gt_mask,
|
| 797 |
gt_quality=gt_quality,
|
scripts/rebuild_metadata.py
CHANGED
|
@@ -5,6 +5,7 @@ import json
|
|
| 5 |
from pathlib import Path
|
| 6 |
from tqdm import tqdm
|
| 7 |
from collections import defaultdict
|
|
|
|
| 8 |
|
| 9 |
def rebuild_metadata(cache_dir="data/cache"):
|
| 10 |
cache_path = Path(cache_dir)
|
|
@@ -15,10 +16,13 @@ def rebuild_metadata(cache_dir="data/cache"):
|
|
| 15 |
print("No .pt files found!")
|
| 16 |
return
|
| 17 |
|
| 18 |
-
print(f"Found {len(files)} files. Reading class IDs...")
|
| 19 |
|
| 20 |
file_class_map = {}
|
|
|
|
|
|
|
| 21 |
class_distribution = defaultdict(int)
|
|
|
|
| 22 |
|
| 23 |
for f in tqdm(files):
|
| 24 |
try:
|
|
@@ -26,14 +30,24 @@ def rebuild_metadata(cache_dir="data/cache"):
|
|
| 26 |
# But torch.load loads everything. To be safe/fast, we just load on CPU.
|
| 27 |
data = torch.load(f, map_location="cpu", weights_only=False)
|
| 28 |
cid = data.get("class_id", 0)
|
|
|
|
| 29 |
file_class_map[f.name] = cid
|
|
|
|
|
|
|
| 30 |
class_distribution[cid] += 1
|
|
|
|
| 31 |
except Exception as e:
|
| 32 |
print(f"Error reading {f.name}: {e}")
|
| 33 |
|
| 34 |
output_data = {
|
| 35 |
'file_class_map': file_class_map,
|
|
|
|
|
|
|
| 36 |
'class_distribution': {str(k): v for k, v in class_distribution.items()},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# These are informational, setting defaults to avoid breaking if loader checks them
|
| 38 |
'num_workers': 1,
|
| 39 |
'horizons_seconds': [300, 900, 1800, 3600, 7200], # From user's pre_cache.sh
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
from tqdm import tqdm
|
| 7 |
from collections import defaultdict
|
| 8 |
+
from data.data_loader import summarize_context_window
|
| 9 |
|
| 10 |
def rebuild_metadata(cache_dir="data/cache"):
|
| 11 |
cache_path = Path(cache_dir)
|
|
|
|
| 16 |
print("No .pt files found!")
|
| 17 |
return
|
| 18 |
|
| 19 |
+
print(f"Found {len(files)} files. Reading class IDs and context summaries...")
|
| 20 |
|
| 21 |
file_class_map = {}
|
| 22 |
+
file_context_bucket_map = {}
|
| 23 |
+
file_context_summary_map = {}
|
| 24 |
class_distribution = defaultdict(int)
|
| 25 |
+
context_distribution = defaultdict(lambda: defaultdict(int))
|
| 26 |
|
| 27 |
for f in tqdm(files):
|
| 28 |
try:
|
|
|
|
| 30 |
# But torch.load loads everything. To be safe/fast, we just load on CPU.
|
| 31 |
data = torch.load(f, map_location="cpu", weights_only=False)
|
| 32 |
cid = data.get("class_id", 0)
|
| 33 |
+
context_summary = summarize_context_window(data.get("labels"), data.get("labels_mask"))
|
| 34 |
file_class_map[f.name] = cid
|
| 35 |
+
file_context_bucket_map[f.name] = context_summary["context_bucket"]
|
| 36 |
+
file_context_summary_map[f.name] = context_summary
|
| 37 |
class_distribution[cid] += 1
|
| 38 |
+
context_distribution[cid][context_summary["context_bucket"]] += 1
|
| 39 |
except Exception as e:
|
| 40 |
print(f"Error reading {f.name}: {e}")
|
| 41 |
|
| 42 |
output_data = {
|
| 43 |
'file_class_map': file_class_map,
|
| 44 |
+
'file_context_bucket_map': file_context_bucket_map,
|
| 45 |
+
'file_context_summary_map': file_context_summary_map,
|
| 46 |
'class_distribution': {str(k): v for k, v in class_distribution.items()},
|
| 47 |
+
'context_distribution': {
|
| 48 |
+
str(k): {bucket: count for bucket, count in bucket_counts.items()}
|
| 49 |
+
for k, bucket_counts in context_distribution.items()
|
| 50 |
+
},
|
| 51 |
# These are informational, setting defaults to avoid breaking if loader checks them
|
| 52 |
'num_workers': 1,
|
| 53 |
'horizons_seconds': [300, 900, 1800, 3600, 7200], # From user's pre_cache.sh
|
train.py
CHANGED
|
@@ -234,54 +234,77 @@ def collator_like_targets(labels, labels_mask, movement_label_config: Optional[D
|
|
| 234 |
|
| 235 |
def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
|
| 236 |
"""
|
| 237 |
-
Create train/val split
|
| 238 |
-
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
"""
|
| 241 |
import random
|
| 242 |
-
random.seed(seed)
|
| 243 |
|
| 244 |
-
|
| 245 |
class_to_indices = defaultdict(list)
|
|
|
|
| 246 |
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
class_to_indices[0].append(idx)
|
| 269 |
|
| 270 |
train_indices = []
|
| 271 |
val_indices = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
random.shuffle(indices)
|
| 276 |
-
n_val = min(len(indices), n_val_per_class) # Ensure we don't take more than we have
|
| 277 |
-
val_indices.extend(indices[:n_val])
|
| 278 |
-
train_indices.extend(indices[n_val:])
|
| 279 |
-
|
| 280 |
-
# Shuffle both sets
|
| 281 |
-
random.shuffle(train_indices)
|
| 282 |
-
random.shuffle(val_indices)
|
| 283 |
|
| 284 |
-
return train_indices, val_indices, class_to_indices
|
| 285 |
|
| 286 |
|
| 287 |
def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_fn, vocab):
|
|
@@ -688,18 +711,27 @@ def main() -> None:
|
|
| 688 |
raise RuntimeError("Dataset is empty.")
|
| 689 |
|
| 690 |
# --- NEW: Create balanced train/val split ---
|
| 691 |
-
logger.info(
|
| 692 |
-
|
|
|
|
|
|
|
| 693 |
dataset, n_val_per_class=args.val_samples_per_class, seed=seed
|
| 694 |
)
|
| 695 |
|
| 696 |
-
# Log class distribution (use set for O(1) lookup)
|
| 697 |
-
train_set = set(train_indices)
|
| 698 |
logger.info(f"Total samples: {len(dataset)}, Train: {len(train_indices)}, Val: {len(val_indices)}")
|
| 699 |
for class_id, indices in sorted(class_distribution.items()):
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
|
| 704 |
# --- Compute class weights for loss weighting ---
|
| 705 |
num_classes = max(class_distribution.keys()) + 1 if class_distribution else 7
|
|
|
|
| 234 |
|
| 235 |
def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
    """
    Create a token-grouped train/val split.

    Validation is selected by token identity first, then balanced by class:
    every cached window of a given token goes entirely to train or entirely to
    val, so a token can never leak across the split through different cached
    windows. Each token is assigned to exactly one class bucket — the class of
    its first-seen sample — even if individual windows disagree on class_id.
    This closes the cross-class leak where grouping per (class, token) could
    place a mixed-class token in train under one class and val under another.

    Args:
        dataset: Cached dataset exposing ``cached_files`` (paths to .pt samples
            containing at least ``class_id`` plus ``source_token`` or
            ``token_address``).
        n_val_per_class: Number of validation *tokens* reserved per class.
        seed: RNG seed; an isolated ``random.Random`` is used so global random
            state is untouched.

    Returns:
        (train_indices, val_indices, class_to_indices, split_stats) where
        ``class_to_indices`` maps each sample's own class_id to its indices and
        ``split_stats`` maps class_id to per-class token/sample counts.

    Raises:
        RuntimeError: If the dataset lacks ``cached_files``, a cached sample
            cannot be read, or a sample carries no token identity.
    """
    import random

    rng = random.Random(seed)
    class_to_indices = defaultdict(list)
    token_to_indices = defaultdict(list)
    token_to_class = {}

    if not hasattr(dataset, "cached_files"):
        raise RuntimeError("Token-grouped split requires a cached dataset with cached_files.")

    for idx, cached_file in enumerate(dataset.cached_files):
        fname = cached_file.name if hasattr(cached_file, "name") else str(cached_file)
        try:
            cached_item = torch.load(cached_file, map_location="cpu", weights_only=False)
        except Exception as exc:
            raise RuntimeError(f"Failed to read cached sample '{fname}' for split planning: {exc}") from exc

        class_id = int(cached_item.get("class_id", 0))
        source_token = cached_item.get("source_token") or cached_item.get("token_address")
        if not source_token:
            raise RuntimeError(
                f"Cached sample '{fname}' is missing both 'source_token' and 'token_address'; "
                "cannot build token-isolated validation split."
            )

        token = str(source_token)
        class_to_indices[class_id].append(idx)
        token_to_indices[token].append(idx)
        # First-seen class wins: each token belongs to exactly one class bucket,
        # guaranteeing all of its windows land on one side of the split.
        token_to_class.setdefault(token, class_id)

    # Regroup tokens under their single assigned class.
    class_to_token_groups = defaultdict(list)
    for token, indices in token_to_indices.items():
        class_to_token_groups[token_to_class[token]].append((token, indices))

    train_indices = []
    val_indices = []
    split_stats = {}

    for class_id, token_items in class_to_token_groups.items():
        rng.shuffle(token_items)

        n_val_tokens = min(len(token_items), n_val_per_class)
        val_token_items = token_items[:n_val_tokens]
        train_token_items = token_items[n_val_tokens:]

        class_val_indices = [i for _, indices in val_token_items for i in indices]
        class_train_indices = [i for _, indices in train_token_items for i in indices]

        val_indices.extend(class_val_indices)
        train_indices.extend(class_train_indices)
        split_stats[class_id] = {
            # NOTE: counts samples by their *own* class_id; for mixed-class
            # tokens this can differ from train_samples + val_samples, which
            # count by the token's assigned class bucket.
            "total_samples": len(class_to_indices[class_id]),
            "total_tokens": len(token_items),
            "val_samples": len(class_val_indices),
            "val_tokens": len(val_token_items),
            "train_samples": len(class_train_indices),
            "train_tokens": len(train_token_items),
        }

    rng.shuffle(train_indices)
    rng.shuffle(val_indices)

    return train_indices, val_indices, class_to_indices, split_stats
|
| 308 |
|
| 309 |
|
| 310 |
def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_fn, vocab):
|
|
|
|
| 711 |
raise RuntimeError("Dataset is empty.")
|
| 712 |
|
| 713 |
# --- NEW: Create balanced train/val split ---
|
| 714 |
+
logger.info(
|
| 715 |
+
f"Creating token-grouped split with {args.val_samples_per_class} validation tokens per class..."
|
| 716 |
+
)
|
| 717 |
+
train_indices, val_indices, class_distribution, split_stats = create_balanced_split(
|
| 718 |
dataset, n_val_per_class=args.val_samples_per_class, seed=seed
|
| 719 |
)
|
| 720 |
|
|
|
|
|
|
|
| 721 |
logger.info(f"Total samples: {len(dataset)}, Train: {len(train_indices)}, Val: {len(val_indices)}")
|
| 722 |
for class_id, indices in sorted(class_distribution.items()):
|
| 723 |
+
stats = split_stats.get(class_id, {})
|
| 724 |
+
logger.info(
|
| 725 |
+
" Class %s: %s samples across %s tokens (~%s train samples / %s val samples, "
|
| 726 |
+
"%s train tokens / %s val tokens)",
|
| 727 |
+
class_id,
|
| 728 |
+
len(indices),
|
| 729 |
+
stats.get("total_tokens", 0),
|
| 730 |
+
stats.get("train_samples", 0),
|
| 731 |
+
stats.get("val_samples", 0),
|
| 732 |
+
stats.get("train_tokens", 0),
|
| 733 |
+
stats.get("val_tokens", 0),
|
| 734 |
+
)
|
| 735 |
|
| 736 |
# --- Compute class weights for loss weighting ---
|
| 737 |
num_classes = max(class_distribution.keys()) + 1 if class_distribution else 7
|