zirobtc committed on
Commit
85e02a7
·
1 Parent(s): 4fdcff3

Upload folder using huggingface_hub

Browse files
.claude/settings.local.json CHANGED
@@ -9,7 +9,12 @@
9
  "Bash(python3:*)",
10
  "Bash(dir:*)",
11
  "Bash(cmd /c \"dir /s /b\")",
12
- "Bash(python -c:*)"
 
 
 
 
 
13
  ]
14
  }
15
  }
 
9
  "Bash(python3:*)",
10
  "Bash(dir:*)",
11
  "Bash(cmd /c \"dir /s /b\")",
12
+ "Bash(python -c:*)",
13
+ "Bash(pip install:*)",
14
+ "Bash(huggingface-cli login:*)",
15
+ "Bash(hf whoami:*)",
16
+ "Bash(huggingface-cli whoami:*)",
17
+ "Bash(python -m huggingface_hub.commands.huggingface_cli:*)"
18
  ]
19
  }
20
  }
data/data_loader.py CHANGED
@@ -1080,36 +1080,51 @@ class OracleDataset(Dataset):
1080
 
1081
  def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
1082
  """
1083
- Loads raw data from cache, samples a random T_cutoff, and generates a training sample.
1084
 
1085
- OPTIMIZED: Uses fully cached wallet/graph/image data for zero DB calls during training.
 
 
 
1086
  """
1087
  import time as _time
1088
  _timings = {}
1089
  _total_start = _time.perf_counter()
1090
 
1091
- # --- REMOVED: No more fetcher initialization during training ---
1092
- # We use fully offline mode with pre-cached data
1093
- _timings['fetcher_init'] = 0.0
1094
-
1095
  # --- TIMING: Cache load ---
1096
  _t0 = _time.perf_counter()
1097
- raw_data = None
1098
- if self.cache_dir:
1099
- if idx >= len(self.cached_files):
1100
- raise IndexError(f"Index {idx} out of range for {len(self.cached_files)} cached files.")
1101
- filepath = self.cached_files[idx]
1102
- try:
1103
- raw_data = torch.load(filepath, map_location='cpu', weights_only=False)
1104
- except Exception as e:
1105
- raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
1106
- else:
1107
- # Strict Offline Mode: No dynamic generation fallback
1108
- raise RuntimeError(f"Offline mode required. No cache directory provided or configured.")
1109
  _timings['cache_load'] = _time.perf_counter() - _t0
1110
 
1111
- if not raw_data:
1112
- raise RuntimeError(f"No raw data loaded for index {idx}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1113
 
1114
  required_keys = [
1115
  "mint_timestamp",
@@ -2347,3 +2362,408 @@ class OracleDataset(Dataset):
2347
  'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
2348
  'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
2349
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1080
 
1081
  def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
1082
  """
1083
+ Loads data from cache. Behavior depends on cache mode:
1084
 
1085
+ - RAW MODE: Loads raw token data, samples T_cutoff at runtime, applies H/B/H
1086
+ - CONTEXT MODE: Loads pre-computed training context directly (fully offline)
1087
+
1088
+ The cache mode is auto-detected from the cached file's 'cache_mode' field.
1089
  """
1090
  import time as _time
1091
  _timings = {}
1092
  _total_start = _time.perf_counter()
1093
 
 
 
 
 
1094
  # --- TIMING: Cache load ---
1095
  _t0 = _time.perf_counter()
1096
+ if not self.cache_dir:
1097
+ raise RuntimeError("Offline mode required. No cache directory provided.")
1098
+
1099
+ if idx >= len(self.cached_files):
1100
+ raise IndexError(f"Index {idx} out of range for {len(self.cached_files)} cached files.")
1101
+
1102
+ filepath = self.cached_files[idx]
1103
+ try:
1104
+ cached_data = torch.load(filepath, map_location='cpu', weights_only=False)
1105
+ except Exception as e:
1106
+ raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
 
1107
  _timings['cache_load'] = _time.perf_counter() - _t0
1108
 
1109
+ if not cached_data:
1110
+ raise RuntimeError(f"No data loaded for index {idx}")
1111
+
1112
+ # Auto-detect cache mode
1113
+ cache_mode = cached_data.get('cache_mode', 'raw')
1114
+
1115
+ if cache_mode == 'context':
1116
+ # CONTEXT MODE: Return pre-computed training context directly
1117
+ # This is fully deterministic - no runtime sampling or processing
1118
+ _timings['total'] = _time.perf_counter() - _total_start
1119
+
1120
+ if idx % 100 == 0:
1121
+ print(f"[Sample {idx}] CONTEXT mode | cache_load: {_timings['cache_load']*1000:.1f}ms | "
1122
+ f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}")
1123
+
1124
+ return cached_data
1125
+
1126
+ # RAW MODE: Fall through to original __getitem__ logic with runtime T_cutoff sampling
1127
+ raw_data = cached_data
1128
 
1129
  required_keys = [
1130
  "mint_timestamp",
 
2362
  'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
2363
  'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
2364
  }
2365
+
2366
+ def __cacheitem_context__(self, idx: int, num_samples_per_token: int = 1) -> List[Optional[Dict[str, Any]]]:
2367
+ """
2368
+ Generates fully processed training contexts for caching.
2369
+
2370
+ This method:
2371
+ 1. Fetches raw token data (like __cacheitem__)
2372
+ 2. Samples T_cutoff(s) using the weight sampling logic
2373
+ 3. Applies H/B/H dynamic sampling based on max_seq_len
2374
+ 4. Returns complete training-ready samples that can be loaded directly
2375
+
2376
+ This moves ALL non-determinism into cache time, making training fully offline
2377
+ and avoiding caching tokens that would never be seen (98% garbage filtered out
2378
+ by weight sampling and T_cutoff eligibility).
2379
+
2380
+ Args:
2381
+ idx: Index into sampled_mints
2382
+ num_samples_per_token: Number of different T_cutoff samples to generate per token
2383
+
2384
+ Returns:
2385
+ List of training-ready samples (may be fewer than num_samples_per_token if
2386
+ some T_cutoffs are invalid)
2387
+ """
2388
+ import time as _time
2389
+
2390
+ if not self.sampled_mints:
2391
+ raise RuntimeError("Dataset has no mint records loaded.")
2392
+ if idx >= len(self.sampled_mints):
2393
+ raise IndexError(f"Index {idx} exceeds mint count {len(self.sampled_mints)}.")
2394
+
2395
+ initial_mint_record = self.sampled_mints[idx]
2396
+ t0 = initial_mint_record["timestamp"]
2397
+ if isinstance(t0, datetime.datetime) and t0.tzinfo is None:
2398
+ t0 = t0.replace(tzinfo=datetime.timezone.utc)
2399
+ creator_address = initial_mint_record['creator_address']
2400
+ token_address = initial_mint_record['mint_address']
2401
+
2402
+ print(f"\n--- Caching CONTEXT for token: {token_address} (generating {num_samples_per_token} samples) ---")
2403
+
2404
+ if not self.fetcher:
2405
+ raise RuntimeError("Dataset has no data fetcher.")
2406
+
2407
+ # --- STEP 1: Fetch raw data (same as __cacheitem__) ---
2408
+ raw_data = self.fetcher.fetch_raw_token_data(
2409
+ token_address=token_address,
2410
+ creator_address=creator_address,
2411
+ mint_timestamp=t0,
2412
+ max_horizon_seconds=self.max_cache_horizon_seconds,
2413
+ include_wallet_data=False,
2414
+ include_graph=False,
2415
+ min_trades=10,
2416
+ full_history=True,
2417
+ prune_failed=False,
2418
+ prune_transfers=False
2419
+ )
2420
+
2421
+ if raw_data is None:
2422
+ print(f" SKIP: No raw data for {token_address}")
2423
+ return []
2424
+
2425
+ def _timestamp_to_order_value(ts_value) -> float:
2426
+ if isinstance(ts_value, datetime.datetime):
2427
+ if ts_value.tzinfo is None:
2428
+ ts_value = ts_value.replace(tzinfo=datetime.timezone.utc)
2429
+ return ts_value.timestamp()
2430
+ try:
2431
+ return float(ts_value)
2432
+ except:
2433
+ return 0.0
2434
+
2435
+ # --- STEP 2: Validate trades and find eligible T_cutoff indices ---
2436
+ all_trades_raw = raw_data.get('trades', [])
2437
+ if not all_trades_raw:
2438
+ print(f" SKIP: No trades for {token_address}")
2439
+ return []
2440
+
2441
+ all_trades_sorted = sorted(
2442
+ [t for t in all_trades_raw if t.get('timestamp') is not None],
2443
+ key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
2444
+ )
2445
+
2446
+ min_context_trades = 10
2447
+ if len(all_trades_sorted) < (min_context_trades + 1):
2448
+ print(f" SKIP: Not enough trades ({len(all_trades_sorted)}) for {token_address}")
2449
+ return []
2450
+
2451
+ # Find successful trade indices
2452
+ successful_indices = [
2453
+ i for i, t in enumerate(all_trades_sorted)
2454
+ if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0
2455
+ ]
2456
+
2457
+ if len(successful_indices) < 2:
2458
+ print(f" SKIP: Not enough successful trades for {token_address}")
2459
+ return []
2460
+
2461
+ max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
2462
+ min_idx = min_context_trades - 1
2463
+ max_idx = len(all_trades_sorted) - 2
2464
+
2465
+ if max_idx < min_idx:
2466
+ print(f" SKIP: Invalid index range for {token_address}")
2467
+ return []
2468
+
2469
+ # Build lookup arrays
2470
+ last_successful_before = [-1] * len(all_trades_sorted)
2471
+ last_seen = -1
2472
+ succ_set = set(successful_indices)
2473
+ for i in range(len(all_trades_sorted)):
2474
+ if i in succ_set:
2475
+ last_seen = i
2476
+ last_successful_before[i] = last_seen
2477
+
2478
+ next_successful_after = [-1] * len(all_trades_sorted)
2479
+ next_seen = -1
2480
+ for i in range(len(all_trades_sorted) - 1, -1, -1):
2481
+ if i in succ_set:
2482
+ next_seen = i
2483
+ next_successful_after[i] = next_seen
2484
+
2485
+ # Find all eligible T_cutoff indices
2486
+ eligible_indices = []
2487
+ for i in range(min_idx, max_idx + 1):
2488
+ anchor_idx = last_successful_before[i]
2489
+ next_idx = next_successful_after[i + 1] if i + 1 < len(all_trades_sorted) else -1
2490
+ if anchor_idx < 0 or next_idx < 0:
2491
+ continue
2492
+ cutoff_ts = _timestamp_to_order_value(all_trades_sorted[i].get('timestamp'))
2493
+ next_ts = _timestamp_to_order_value(all_trades_sorted[next_idx].get('timestamp'))
2494
+ if next_ts <= cutoff_ts + max_horizon_seconds:
2495
+ eligible_indices.append(i)
2496
+
2497
+ if not eligible_indices:
2498
+ print(f" SKIP: No eligible T_cutoff indices for {token_address}")
2499
+ return []
2500
+
2501
+ print(f" INFO: Found {len(eligible_indices)} eligible T_cutoff positions")
2502
+
2503
+ # --- STEP 3: Generate OHLC and holder snapshots (same as __cacheitem__) ---
2504
+ trades = raw_data.get('trades', [])
2505
+ trade_ts_values = [_timestamp_to_order_value(t.get('timestamp')) for t in trades]
2506
+ t0_val = _timestamp_to_order_value(t0)
2507
+ last_trade_ts_val = max(trade_ts_values)
2508
+
2509
+ duration_seconds = int(last_trade_ts_val - t0_val) + 120
2510
+ ohlc_1s = torch.zeros((duration_seconds, 2), dtype=torch.float32)
2511
+
2512
+ trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp']))
2513
+ trades_by_sec = defaultdict(list)
2514
+ for t in trades:
2515
+ ts = _timestamp_to_order_value(t['timestamp'])
2516
+ sec_idx = int(ts - t0_val)
2517
+ if 0 <= sec_idx < duration_seconds:
2518
+ trades_by_sec[sec_idx].append(t['price_usd'])
2519
+
2520
+ last_close = float(trades[0]['price_usd'])
2521
+ for i in range(duration_seconds):
2522
+ if i in trades_by_sec:
2523
+ prices = trades_by_sec[i]
2524
+ op, cl = prices[0], prices[-1]
2525
+ last_close = cl
2526
+ else:
2527
+ op = cl = last_close
2528
+ ohlc_1s[i, 0] = float(op)
2529
+ ohlc_1s[i, 1] = float(cl)
2530
+
2531
+ raw_data['ohlc_1s'] = ohlc_1s
2532
+
2533
+ # Generate holder snapshots
2534
+ interval = 300
2535
+ num_intervals = (duration_seconds // interval) + 1
2536
+ snapshot_stats = torch.zeros((num_intervals, 6), dtype=torch.float32)
2537
+
2538
+ buckets = defaultdict(list)
2539
+ for t in trades:
2540
+ ts = _timestamp_to_order_value(t['timestamp'])
2541
+ bucket_idx = int(ts - t0_val) // interval
2542
+ if bucket_idx >= 0:
2543
+ buckets[bucket_idx].append(t)
2544
+
2545
+ holder_snapshots_list = []
2546
+ for i in range(num_intervals):
2547
+ bucket_trades = buckets[i]
2548
+ vol = sum(t.get('total_usd', 0.0) for t in bucket_trades)
2549
+ tx = len(bucket_trades)
2550
+ buys = sum(1 for t in bucket_trades if t.get('trade_direction') == 0 or t.get('trade_type') == 0)
2551
+ sells = tx - buys
2552
+
2553
+ snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval)
2554
+ count, top_holders = self.fetcher.fetch_holder_snapshot_stats_for_token(
2555
+ token_address, snapshot_ts, limit=HOLDER_SNAPSHOT_TOP_K
2556
+ )
2557
+
2558
+ total_supply = raw_data.get('total_supply', 0) or 1
2559
+ if raw_data.get('decimals'):
2560
+ total_supply /= (10 ** raw_data['decimals'])
2561
+
2562
+ top10_bal = sum(h.get('current_balance', 0) for h in top_holders[:10])
2563
+ top10_pct = (top10_bal / total_supply) if total_supply > 0 else 0.0
2564
+
2565
+ snapshot_stats[i, 0] = float(vol)
2566
+ snapshot_stats[i, 1] = float(tx)
2567
+ snapshot_stats[i, 2] = float(buys)
2568
+ snapshot_stats[i, 3] = float(sells)
2569
+ snapshot_stats[i, 4] = float(count)
2570
+ snapshot_stats[i, 5] = float(top10_pct)
2571
+
2572
+ holder_snapshots_list.append({
2573
+ 'timestamp': int(snapshot_ts.timestamp()),
2574
+ 'holders': top_holders
2575
+ })
2576
+
2577
+ raw_data['snapshots_5m'] = snapshot_stats
2578
+ raw_data['holder_snapshots_list'] = holder_snapshots_list
2579
+ raw_data['protocol_id'] = initial_mint_record.get('protocol')
2580
+
2581
+ # --- STEP 4: Collect ALL wallets and pre-fetch their data ---
2582
+ all_wallets = set()
2583
+ all_wallets.add(creator_address)
2584
+
2585
+ for trade in raw_data.get('trades', []):
2586
+ if trade.get('maker'):
2587
+ all_wallets.add(trade['maker'])
2588
+ for transfer in raw_data.get('transfers', []):
2589
+ if transfer.get('source'):
2590
+ all_wallets.add(transfer['source'])
2591
+ if transfer.get('destination'):
2592
+ all_wallets.add(transfer['destination'])
2593
+ for pool in raw_data.get('pool_creations', []):
2594
+ if pool.get('creator_address'):
2595
+ all_wallets.add(pool['creator_address'])
2596
+ for liq in raw_data.get('liquidity_changes', []):
2597
+ if liq.get('lp_provider'):
2598
+ all_wallets.add(liq['lp_provider'])
2599
+ for snapshot in holder_snapshots_list:
2600
+ for holder in snapshot.get('holders', []):
2601
+ if holder.get('wallet_address'):
2602
+ all_wallets.add(holder['wallet_address'])
2603
+
2604
+ all_wallets.discard(None)
2605
+ all_wallets.discard('')
2606
+ wallet_list = list(all_wallets)
2607
+
2608
+ max_T_cutoff = datetime.datetime.fromtimestamp(last_trade_ts_val, tz=datetime.timezone.utc)
2609
+
2610
+ try:
2611
+ cached_profiles, cached_socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_list, max_T_cutoff)
2612
+ except Exception as e:
2613
+ print(f" WARN: Failed to fetch wallet profiles/socials: {e}")
2614
+ cached_profiles, cached_socials = {}, {}
2615
+
2616
+ try:
2617
+ cached_holdings = self.fetcher.fetch_wallet_holdings(wallet_list, max_T_cutoff)
2618
+ except Exception as e:
2619
+ print(f" WARN: Failed to fetch wallet holdings: {e}")
2620
+ cached_holdings = {}
2621
+
2622
+ try:
2623
+ cached_graph_entities, cached_graph_links = self.fetcher.fetch_graph_links(
2624
+ wallet_list, max_T_cutoff, max_degrees=1
2625
+ )
2626
+ except Exception as e:
2627
+ print(f" WARN: Failed to fetch graph links: {e}")
2628
+ cached_graph_entities, cached_graph_links = {}, {}
2629
+
2630
+ # Fetch token image
2631
+ cached_image_bytes = None
2632
+ try:
2633
+ bullx_image_url = f"https://image.bullx.io/1399811149/{token_address}?retry=0"
2634
+ resp = self.http_session.get(bullx_image_url, timeout=5)
2635
+ if resp.status_code == 200:
2636
+ cached_image_bytes = resp.content
2637
+ except:
2638
+ pass
2639
+
2640
+ # --- STEP 5: Sample T_cutoffs and generate complete training contexts ---
2641
+ results = []
2642
+
2643
+ # Sample indices (with replacement if needed)
2644
+ if num_samples_per_token >= len(eligible_indices):
2645
+ sampled_indices = eligible_indices.copy()
2646
+ else:
2647
+ sampled_indices = random.sample(eligible_indices, num_samples_per_token)
2648
+
2649
+ print(f" INFO: Generating {len(sampled_indices)} training contexts...")
2650
+
2651
+ for sample_num, sample_idx in enumerate(sampled_indices):
2652
+ sample_trade = all_trades_sorted[sample_idx]
2653
+ sample_offset_ts = _timestamp_to_order_value(sample_trade.get('timestamp'))
2654
+ T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc)
2655
+ cutoff_ts = sample_offset_ts
2656
+
2657
+ # Collect wallets visible at T_cutoff
2658
+ wallets_to_fetch = set()
2659
+ wallets_to_fetch.add(creator_address)
2660
+
2661
+ for trade in raw_data.get('trades', []):
2662
+ if _timestamp_to_order_value(trade.get('timestamp')) <= cutoff_ts:
2663
+ if trade.get('maker'):
2664
+ wallets_to_fetch.add(trade['maker'])
2665
+
2666
+ for transfer in raw_data.get('transfers', []):
2667
+ if _timestamp_to_order_value(transfer.get('timestamp')) <= cutoff_ts:
2668
+ if transfer.get('source'):
2669
+ wallets_to_fetch.add(transfer['source'])
2670
+ if transfer.get('destination'):
2671
+ wallets_to_fetch.add(transfer['destination'])
2672
+
2673
+ for pool in raw_data.get('pool_creations', []):
2674
+ if _timestamp_to_order_value(pool.get('timestamp')) <= cutoff_ts:
2675
+ if pool.get('creator_address'):
2676
+ wallets_to_fetch.add(pool['creator_address'])
2677
+
2678
+ for liq in raw_data.get('liquidity_changes', []):
2679
+ if _timestamp_to_order_value(liq.get('timestamp')) <= cutoff_ts:
2680
+ if liq.get('lp_provider'):
2681
+ wallets_to_fetch.add(liq['lp_provider'])
2682
+
2683
+ # Get holder snapshot at T_cutoff
2684
+ elapsed = (T_cutoff - t0).total_seconds()
2685
+ snap_idx = int(elapsed // 300)
2686
+ if 0 <= snap_idx < len(holder_snapshots_list):
2687
+ snapshot_data = holder_snapshots_list[snap_idx]
2688
+ for holder in snapshot_data.get('holders', []):
2689
+ if holder.get('wallet_address'):
2690
+ wallets_to_fetch.add(holder['wallet_address'])
2691
+
2692
+ wallets_to_fetch.discard(None)
2693
+ wallets_to_fetch.discard('')
2694
+
2695
+ # Build offline data for this context
2696
+ pooler = EmbeddingPooler()
2697
+
2698
+ # Process token data offline
2699
+ offline_token_data = {token_address: raw_data.copy()}
2700
+ if cached_image_bytes:
2701
+ try:
2702
+ cached_image = Image.open(BytesIO(cached_image_bytes))
2703
+ offline_token_data[token_address]['_cached_image_pil'] = cached_image
2704
+ except:
2705
+ pass
2706
+
2707
+ main_token_data = self._process_token_data_offline(
2708
+ [token_address], pooler, T_cutoff, token_data=offline_token_data
2709
+ )
2710
+
2711
+ if not main_token_data:
2712
+ continue
2713
+
2714
+ # Process wallet data offline
2715
+ wallet_data, all_token_data = self._process_wallet_data(
2716
+ list(wallets_to_fetch),
2717
+ main_token_data.copy(),
2718
+ pooler,
2719
+ T_cutoff,
2720
+ profiles_override=cached_profiles,
2721
+ socials_override=cached_socials,
2722
+ holdings_override=cached_holdings
2723
+ )
2724
+
2725
+ # Generate the complete training item (with H/B/H applied via _generate_dataset_item)
2726
+ mint_event = {
2727
+ 'event_type': 'Mint',
2728
+ 'timestamp': int(t0.timestamp()),
2729
+ 'relative_ts': 0,
2730
+ 'wallet_address': creator_address,
2731
+ 'token_address': token_address,
2732
+ 'protocol_id': raw_data.get('protocol_id', 0)
2733
+ }
2734
+
2735
+ result = self._generate_dataset_item(
2736
+ token_address=token_address,
2737
+ t0=t0,
2738
+ T_cutoff=T_cutoff,
2739
+ mint_event=mint_event,
2740
+ trade_records=raw_data['trades'],
2741
+ transfer_records=raw_data['transfers'],
2742
+ pool_creation_records=raw_data['pool_creations'],
2743
+ liquidity_change_records=raw_data['liquidity_changes'],
2744
+ fee_collection_records=raw_data['fee_collections'],
2745
+ burn_records=raw_data['burns'],
2746
+ supply_lock_records=raw_data['supply_locks'],
2747
+ migration_records=raw_data['migrations'],
2748
+ wallet_data=wallet_data,
2749
+ all_token_data=all_token_data,
2750
+ graph_links=cached_graph_links,
2751
+ graph_seed_entities=wallets_to_fetch,
2752
+ all_graph_entities=cached_graph_entities,
2753
+ future_trades_for_labels=raw_data['trades'],
2754
+ pooler=pooler,
2755
+ sample_idx=idx,
2756
+ cached_holders_list=holder_snapshots_list,
2757
+ cached_ohlc_1s=ohlc_1s,
2758
+ quality_score=None # Will be injected by cache_dataset.py
2759
+ )
2760
+
2761
+ if result is not None:
2762
+ # Store the T_cutoff used for this sample (for reproducibility tracking)
2763
+ result['cached_t_cutoff_ts'] = sample_offset_ts
2764
+ result['cached_sample_num'] = sample_num
2765
+ results.append(result)
2766
+ print(f" + Context {sample_num}: T_cutoff={T_cutoff.isoformat()}, events={len(result.get('event_sequence', []))}")
2767
+
2768
+ print(f" INFO: Generated {len(results)} valid training contexts for {token_address}")
2769
+ return results
pre_cache.sh CHANGED
@@ -1,7 +1,48 @@
1
  #!/bin/bash
2
  # Pre-caches the dataset for training
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  echo "Starting dataset caching..."
 
5
  python3 scripts/cache_dataset.py \
6
- --ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz"
 
 
 
 
 
 
 
7
  echo "Done!"
 
 
1
  #!/bin/bash
2
  # Pre-caches the dataset for training
3
+ #
4
+ # Usage:
5
+ # ./pre_cache.sh # Raw mode (old behavior)
6
+ # ./pre_cache.sh --cache_mode context # Context mode (new fully offline)
7
+ #
8
+ # Context mode arguments:
9
+ # --context_length N Max sequence length, triggers H/B/H when exceeded (default: 8192)
10
+ # --min_trades N Minimum trades for T_cutoff sampling (default: 10)
11
+ # --samples_per_token N Number of T_cutoff samples per token (default: 1)
12
+
13
+ set -e
14
+
15
+ # Default values
16
+ CACHE_MODE="${CACHE_MODE:-raw}"
17
+ CONTEXT_LENGTH="${CONTEXT_LENGTH:-8192}"
18
+ MIN_TRADES="${MIN_TRADES:-10}"
19
+ SAMPLES_PER_TOKEN="${SAMPLES_PER_TOKEN:-1}"
20
+ OHLC_STATS_PATH="${OHLC_STATS_PATH:-/workspace/apollo/data/ohlc_stats.npz}"
21
+ OUTPUT_DIR="${OUTPUT_DIR:-data/cache}"
22
+
23
+ echo "========================================"
24
+ echo "Apollo Dataset Pre-Caching"
25
+ echo "========================================"
26
+ echo "Cache Mode: $CACHE_MODE"
27
+ if [ "$CACHE_MODE" = "context" ]; then
28
+ echo "Context Length (H/B/H threshold): $CONTEXT_LENGTH"
29
+ echo "Min Trades (T_cutoff threshold): $MIN_TRADES"
30
+ echo "Samples per Token: $SAMPLES_PER_TOKEN"
31
+ fi
32
+ echo "Output Directory: $OUTPUT_DIR"
33
+ echo "OHLC Stats Path: $OHLC_STATS_PATH"
34
+ echo "========================================"
35
 
36
  echo "Starting dataset caching..."
37
+
38
  python3 scripts/cache_dataset.py \
39
+ --ohlc_stats_path "$OHLC_STATS_PATH" \
40
+ --output_dir "$OUTPUT_DIR" \
41
+ --cache_mode "$CACHE_MODE" \
42
+ --context_length "$CONTEXT_LENGTH" \
43
+ --min_trades "$MIN_TRADES" \
44
+ --samples_per_token "$SAMPLES_PER_TOKEN" \
45
+ "$@"
46
+
47
  echo "Done!"
48
+ echo "Cache saved to: $OUTPUT_DIR"
scripts/cache_dataset.py CHANGED
@@ -194,14 +194,24 @@ def main():
194
  parser.add_argument("--start_date", type=str, default=None, help="Start date (YYYY-MM-DD) for fetching new mints")
195
  parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
196
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
197
-
 
 
 
 
 
 
 
 
 
 
198
  # DB Args
199
  parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
200
  parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
201
  parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
202
  parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
203
  parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
204
-
205
  args = parser.parse_args()
206
 
207
  output_dir = Path(args.output_dir)
@@ -240,7 +250,8 @@ def main():
240
  ohlc_stats_path=args.ohlc_stats_path,
241
  horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200],
242
  quantiles=[0.5],
243
- min_trade_usd=args.min_trade_usd
 
244
  )
245
 
246
  if len(dataset) == 0:
@@ -262,84 +273,186 @@ def main():
262
  if len(dataset) == 0:
263
  print("WARNING: No tokens remain after filtering by return_class_map.")
264
  return
265
-
266
- # --- 3. Iterate and cache each item ---
267
- print(f"INFO: Starting to generate and cache {len(dataset)} samples...")
268
-
 
269
  skipped_count = 0
270
  cached_count = 0
271
-
272
- for i in tqdm(range(len(dataset)), desc="Caching samples"):
273
- mint_addr = dataset.sampled_mints[i]['mint_address']
274
-
275
- # (No need to check if in return_class_map anymore, we filtered)
276
- class_id = return_class_map[mint_addr]
277
 
278
- try:
279
- item = dataset.__cacheitem__(i)
280
- if item is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  skipped_count += 1
282
  continue
283
-
284
- # Require quality score only for samples that will be cached
285
- if mint_addr not in quality_scores_map:
286
- reason = quality_missing_reason(mint_addr)
287
- raise RuntimeError(
288
- f"Missing quality score for mint {mint_addr}. Reason: {reason}. "
289
- "Refusing to cache without quality_score."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  )
291
- q_score = quality_scores_map[mint_addr]
292
-
293
- # INJECT QUALITY SCORE INTO TENSOR DICT
294
- item["quality_score"] = q_score
295
- item["class_id"] = class_id
296
-
297
- filename = f"sample_{i}.pt"
298
- output_path = output_dir / filename
299
- torch.save(item, output_path)
300
-
301
- cached_count += 1
302
-
303
- # Log progress details (reflect all cached event lists)
304
- n_trades = len(item.get("trades", []))
305
- n_transfers = len(item.get("transfers", []))
306
- n_pool_creations = len(item.get("pool_creations", []))
307
- n_liquidity_changes = len(item.get("liquidity_changes", []))
308
- n_fee_collections = len(item.get("fee_collections", []))
309
- n_burns = len(item.get("burns", []))
310
- n_supply_locks = len(item.get("supply_locks", []))
311
- n_migrations = len(item.get("migrations", []))
312
- n_mints = 1 if item.get("mint_timestamp") else 0
313
- n_ohlc = len(item.get("ohlc_1s", [])) if item.get("ohlc_1s") is not None else 0
314
- n_snapshots_5m = len(item.get("snapshots_5m", []))
315
- n_holders = len(item.get("holder_snapshots_list", []))
316
-
317
- tqdm.write(
318
- f" + Cached: {mint_addr} | Class: {class_id} | Q: {q_score:.4f} | "
319
- f"Events: Mint {n_mints}, Trades {n_trades}, Transfers {n_transfers}, Pool Creations {n_pool_creations}, "
320
- f"Liquidity Changes {n_liquidity_changes}, Fee Collections {n_fee_collections}, "
321
- f"Burns {n_burns}, Supply Locks {n_supply_locks}, Migrations {n_migrations} | "
322
- f"Derived: Ohlc 1s {n_ohlc}, Snapshots 5m {n_snapshots_5m}, Holder Snapshots {n_holders}"
323
- )
324
-
325
  except Exception as e:
326
- error_msg = str(e)
327
- # If a FATAL error occurs (e.g. persistent DB auth failure), stop the script immediately.
328
- if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
329
- print(f"\nCRITICAL: Fatal error encountered processing sample {i}. Stopping execution.\nError: {e}", file=sys.stderr)
330
- sys.exit(1)
331
-
332
- print(f"\nERROR: Failed to generate or save sample {i} for mint '{mint_addr}'. Error: {e}", file=sys.stderr)
333
- # print traceback
334
- import traceback
335
- traceback.print_exc()
336
- skipped_count += 1
337
- continue
338
-
 
 
 
 
339
  print(f"\n--- Caching Complete ---")
340
- print(f"Successfully cached: {cached_count} items.")
341
- print(f"Filtered (Invalid/High Return): {filtered_count} items.")
342
- print(f"Skipped (Errors/Empty): {skipped_count} items.")
 
 
343
  print(f"Cache location: {output_dir.resolve()}")
344
 
345
  finally:
 
194
  parser.add_argument("--start_date", type=str, default=None, help="Start date (YYYY-MM-DD) for fetching new mints")
195
  parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
196
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
197
+
198
+ # NEW: Context caching mode args
199
+ parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"],
200
+ help="Cache mode: 'raw' caches raw token data (old behavior), 'context' caches fully processed training contexts (new behavior)")
201
+ parser.add_argument("--context_length", type=int, default=8192,
202
+ help="Max sequence length for context caching mode. Triggers H/B/H dynamic sampling when events exceed this limit.")
203
+ parser.add_argument("--min_trades", type=int, default=10,
204
+ help="Minimum number of trades required for T_cutoff sampling. Tokens with fewer trades are skipped.")
205
+ parser.add_argument("--samples_per_token", type=int, default=1,
206
+ help="Number of different T_cutoff samples to generate per token in context mode.")
207
+
208
  # DB Args
209
  parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
210
  parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
211
  parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
212
  parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
213
  parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
214
+
215
  args = parser.parse_args()
216
 
217
  output_dir = Path(args.output_dir)
 
250
  ohlc_stats_path=args.ohlc_stats_path,
251
  horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200],
252
  quantiles=[0.5],
253
+ min_trade_usd=args.min_trade_usd,
254
+ max_seq_len=args.context_length # Pass context_length for H/B/H threshold
255
  )
256
 
257
  if len(dataset) == 0:
 
273
  if len(dataset) == 0:
274
  print("WARNING: No tokens remain after filtering by return_class_map.")
275
  return
276
+
277
+ # --- 3. Iterate and cache based on mode ---
278
+ print(f"INFO: Cache mode: {args.cache_mode}")
279
+ print(f"INFO: Starting to generate and cache from {len(dataset)} tokens...")
280
+
281
  skipped_count = 0
282
  cached_count = 0
283
+ global_sample_idx = 0 # Global counter for unique sample filenames
 
 
 
 
 
284
 
285
+ # Track class distribution for balanced sampling metadata
286
+ class_distribution = {}
287
+
288
+ if args.cache_mode == "context":
289
+ # =========================================================================
290
+ # CONTEXT MODE: Cache fully processed training contexts
291
+ # - Samples T_cutoff during caching (non-determinism moved to cache time)
292
+ # - Applies H/B/H dynamic sampling based on context_length
293
+ # - Avoids caching tokens that won't be seen (garbage filtered out)
294
+ # - Training becomes fully deterministic (just loads cached contexts)
295
+ # =========================================================================
296
+ print(f"INFO: Context mode settings:")
297
+ print(f" - context_length (H/B/H threshold): {args.context_length}")
298
+ print(f" - min_trades (T_cutoff threshold): {args.min_trades}")
299
+ print(f" - samples_per_token: {args.samples_per_token}")
300
+
301
+ for i in tqdm(range(len(dataset)), desc="Caching contexts"):
302
+ mint_addr = dataset.sampled_mints[i]['mint_address']
303
+ class_id = return_class_map[mint_addr]
304
+
305
+ try:
306
+ # Generate multiple training contexts per token
307
+ contexts = dataset.__cacheitem_context__(i, num_samples_per_token=args.samples_per_token)
308
+
309
+ if not contexts:
310
+ skipped_count += 1
311
+ continue
312
+
313
+ # Require quality score
314
+ if mint_addr not in quality_scores_map:
315
+ reason = quality_missing_reason(mint_addr)
316
+ raise RuntimeError(
317
+ f"Missing quality score for mint {mint_addr}. Reason: {reason}."
318
+ )
319
+ q_score = quality_scores_map[mint_addr]
320
+
321
+ # Save each context as a separate sample
322
+ for ctx in contexts:
323
+ ctx["quality_score"] = q_score
324
+ ctx["class_id"] = class_id
325
+ ctx["source_token"] = mint_addr # Track origin for debugging
326
+ ctx["cache_mode"] = "context"
327
+
328
+ filename = f"sample_{global_sample_idx}.pt"
329
+ output_path = output_dir / filename
330
+ torch.save(ctx, output_path)
331
+
332
+ # Track class distribution
333
+ class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
334
+
335
+ global_sample_idx += 1
336
+ cached_count += 1
337
+
338
+ n_events = len(contexts[0].get("event_sequence", [])) if contexts else 0
339
+ tqdm.write(
340
+ f" + Cached {len(contexts)} contexts: {mint_addr} | Class: {class_id} | Q: {q_score:.4f} | Events: {n_events}"
341
+ )
342
+
343
+ except Exception as e:
344
+ error_msg = str(e)
345
+ if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
346
+ print(f"\nCRITICAL: Fatal error processing sample {i}. Stopping.\nError: {e}", file=sys.stderr)
347
+ sys.exit(1)
348
+
349
+ print(f"\nERROR: Failed to cache contexts for {mint_addr}. Error: {e}", file=sys.stderr)
350
+ import traceback
351
+ traceback.print_exc()
352
  skipped_count += 1
353
  continue
354
+
355
+ else:
356
+ # =========================================================================
357
+ # RAW MODE: Cache raw token data (original behavior)
358
+ # - T_cutoff sampling happens at runtime
359
+ # - H/B/H applied at runtime
360
+ # - Non-deterministic training
361
+ # =========================================================================
362
+ for i in tqdm(range(len(dataset)), desc="Caching raw samples"):
363
+ mint_addr = dataset.sampled_mints[i]['mint_address']
364
+ class_id = return_class_map[mint_addr]
365
+
366
+ try:
367
+ item = dataset.__cacheitem__(i)
368
+ if item is None:
369
+ skipped_count += 1
370
+ continue
371
+
372
+ if mint_addr not in quality_scores_map:
373
+ reason = quality_missing_reason(mint_addr)
374
+ raise RuntimeError(
375
+ f"Missing quality score for mint {mint_addr}. Reason: {reason}."
376
+ )
377
+ q_score = quality_scores_map[mint_addr]
378
+
379
+ item["quality_score"] = q_score
380
+ item["class_id"] = class_id
381
+ item["cache_mode"] = "raw"
382
+
383
+ filename = f"sample_{i}.pt"
384
+ output_path = output_dir / filename
385
+ torch.save(item, output_path)
386
+
387
+ # Track class distribution
388
+ class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
389
+
390
+ cached_count += 1
391
+
392
+ n_trades = len(item.get("trades", []))
393
+ n_transfers = len(item.get("transfers", []))
394
+ n_pool_creations = len(item.get("pool_creations", []))
395
+ n_liquidity_changes = len(item.get("liquidity_changes", []))
396
+ n_fee_collections = len(item.get("fee_collections", []))
397
+ n_burns = len(item.get("burns", []))
398
+ n_supply_locks = len(item.get("supply_locks", []))
399
+ n_migrations = len(item.get("migrations", []))
400
+ n_mints = 1 if item.get("mint_timestamp") else 0
401
+ n_ohlc = len(item.get("ohlc_1s", [])) if item.get("ohlc_1s") is not None else 0
402
+ n_snapshots_5m = len(item.get("snapshots_5m", []))
403
+ n_holders = len(item.get("holder_snapshots_list", []))
404
+
405
+ tqdm.write(
406
+ f" + Cached: {mint_addr} | Class: {class_id} | Q: {q_score:.4f} | "
407
+ f"Events: Mint {n_mints}, Trades {n_trades}, Transfers {n_transfers}, Pool Creations {n_pool_creations}, "
408
+ f"Liquidity Changes {n_liquidity_changes}, Fee Collections {n_fee_collections}, "
409
+ f"Burns {n_burns}, Supply Locks {n_supply_locks}, Migrations {n_migrations} | "
410
+ f"Derived: Ohlc 1s {n_ohlc}, Snapshots 5m {n_snapshots_5m}, Holder Snapshots {n_holders}"
411
  )
412
+
413
+ except Exception as e:
414
+ error_msg = str(e)
415
+ if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
416
+ print(f"\nCRITICAL: Fatal error processing sample {i}. Stopping.\nError: {e}", file=sys.stderr)
417
+ sys.exit(1)
418
+
419
+ print(f"\nERROR: Failed to cache sample {i} for {mint_addr}. Error: {e}", file=sys.stderr)
420
+ import traceback
421
+ traceback.print_exc()
422
+ skipped_count += 1
423
+ continue
424
+
425
+ # --- Save class metadata for balanced sampling ---
426
+ # Build file_class_map for the metadata cache
427
+ file_class_map = {}
428
+ for sample_file in sorted(output_dir.glob("sample_*.pt")):
429
+ try:
430
+ sample_data = torch.load(sample_file, map_location="cpu", weights_only=False)
431
+ file_class_map[sample_file.name] = sample_data.get("class_id", 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  except Exception as e:
433
+ print(f"WARN: Could not read class_id from {sample_file.name}: {e}")
434
+
435
+ metadata_path = output_dir / "class_metadata.json"
436
+ try:
437
+ with open(metadata_path, 'w') as f:
438
+ json.dump({
439
+ 'file_class_map': file_class_map,
440
+ 'class_distribution': class_distribution,
441
+ 'cache_mode': args.cache_mode,
442
+ 'context_length': args.context_length if args.cache_mode == "context" else None,
443
+ 'min_trades': args.min_trades if args.cache_mode == "context" else None,
444
+ 'samples_per_token': args.samples_per_token if args.cache_mode == "context" else None,
445
+ }, f, indent=2)
446
+ print(f"INFO: Saved class metadata to {metadata_path}")
447
+ except Exception as e:
448
+ print(f"WARN: Failed to save class metadata: {e}")
449
+
450
  print(f"\n--- Caching Complete ---")
451
+ print(f"Cache mode: {args.cache_mode}")
452
+ print(f"Successfully cached: {cached_count} samples.")
453
+ print(f"Filtered (Invalid/High Return): {filtered_count} tokens.")
454
+ print(f"Skipped (Errors/Empty): {skipped_count} tokens.")
455
+ print(f"Class distribution: {class_distribution}")
456
  print(f"Cache location: {output_dir.resolve()}")
457
 
458
  finally: