Upload folder using huggingface_hub
Browse files- .claude/settings.local.json +6 -1
- data/data_loader.py +440 -20
- pre_cache.sh +42 -1
- scripts/cache_dataset.py +186 -73
.claude/settings.local.json
CHANGED
|
@@ -9,7 +9,12 @@
|
|
| 9 |
"Bash(python3:*)",
|
| 10 |
"Bash(dir:*)",
|
| 11 |
"Bash(cmd /c \"dir /s /b\")",
|
| 12 |
-
"Bash(python -c:*)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
]
|
| 14 |
}
|
| 15 |
}
|
|
|
|
| 9 |
"Bash(python3:*)",
|
| 10 |
"Bash(dir:*)",
|
| 11 |
"Bash(cmd /c \"dir /s /b\")",
|
| 12 |
+
"Bash(python -c:*)",
|
| 13 |
+
"Bash(pip install:*)",
|
| 14 |
+
"Bash(huggingface-cli login:*)",
|
| 15 |
+
"Bash(hf whoami:*)",
|
| 16 |
+
"Bash(huggingface-cli whoami:*)",
|
| 17 |
+
"Bash(python -m huggingface_hub.commands.huggingface_cli:*)"
|
| 18 |
]
|
| 19 |
}
|
| 20 |
}
|
data/data_loader.py
CHANGED
|
@@ -1080,36 +1080,51 @@ class OracleDataset(Dataset):
|
|
| 1080 |
|
| 1081 |
def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
|
| 1082 |
"""
|
| 1083 |
-
Loads
|
| 1084 |
|
| 1085 |
-
|
|
|
|
|
|
|
|
|
|
| 1086 |
"""
|
| 1087 |
import time as _time
|
| 1088 |
_timings = {}
|
| 1089 |
_total_start = _time.perf_counter()
|
| 1090 |
|
| 1091 |
-
# --- REMOVED: No more fetcher initialization during training ---
|
| 1092 |
-
# We use fully offline mode with pre-cached data
|
| 1093 |
-
_timings['fetcher_init'] = 0.0
|
| 1094 |
-
|
| 1095 |
# --- TIMING: Cache load ---
|
| 1096 |
_t0 = _time.perf_counter()
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
|
| 1101 |
-
|
| 1102 |
-
|
| 1103 |
-
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
-
raise RuntimeError(f"Offline mode required. No cache directory provided or configured.")
|
| 1109 |
_timings['cache_load'] = _time.perf_counter() - _t0
|
| 1110 |
|
| 1111 |
-
if not
|
| 1112 |
-
raise RuntimeError(f"No
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1113 |
|
| 1114 |
required_keys = [
|
| 1115 |
"mint_timestamp",
|
|
@@ -2347,3 +2362,408 @@ class OracleDataset(Dataset):
|
|
| 2347 |
'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
|
| 2348 |
'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
|
| 2349 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1080 |
|
| 1081 |
def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
|
| 1082 |
"""
|
| 1083 |
+
Loads data from cache. Behavior depends on cache mode:
|
| 1084 |
|
| 1085 |
+
- RAW MODE: Loads raw token data, samples T_cutoff at runtime, applies H/B/H
|
| 1086 |
+
- CONTEXT MODE: Loads pre-computed training context directly (fully offline)
|
| 1087 |
+
|
| 1088 |
+
The cache mode is auto-detected from the cached file's 'cache_mode' field.
|
| 1089 |
"""
|
| 1090 |
import time as _time
|
| 1091 |
_timings = {}
|
| 1092 |
_total_start = _time.perf_counter()
|
| 1093 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1094 |
# --- TIMING: Cache load ---
|
| 1095 |
_t0 = _time.perf_counter()
|
| 1096 |
+
if not self.cache_dir:
|
| 1097 |
+
raise RuntimeError("Offline mode required. No cache directory provided.")
|
| 1098 |
+
|
| 1099 |
+
if idx >= len(self.cached_files):
|
| 1100 |
+
raise IndexError(f"Index {idx} out of range for {len(self.cached_files)} cached files.")
|
| 1101 |
+
|
| 1102 |
+
filepath = self.cached_files[idx]
|
| 1103 |
+
try:
|
| 1104 |
+
cached_data = torch.load(filepath, map_location='cpu', weights_only=False)
|
| 1105 |
+
except Exception as e:
|
| 1106 |
+
raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
|
|
|
|
| 1107 |
_timings['cache_load'] = _time.perf_counter() - _t0
|
| 1108 |
|
| 1109 |
+
if not cached_data:
|
| 1110 |
+
raise RuntimeError(f"No data loaded for index {idx}")
|
| 1111 |
+
|
| 1112 |
+
# Auto-detect cache mode
|
| 1113 |
+
cache_mode = cached_data.get('cache_mode', 'raw')
|
| 1114 |
+
|
| 1115 |
+
if cache_mode == 'context':
|
| 1116 |
+
# CONTEXT MODE: Return pre-computed training context directly
|
| 1117 |
+
# This is fully deterministic - no runtime sampling or processing
|
| 1118 |
+
_timings['total'] = _time.perf_counter() - _total_start
|
| 1119 |
+
|
| 1120 |
+
if idx % 100 == 0:
|
| 1121 |
+
print(f"[Sample {idx}] CONTEXT mode | cache_load: {_timings['cache_load']*1000:.1f}ms | "
|
| 1122 |
+
f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}")
|
| 1123 |
+
|
| 1124 |
+
return cached_data
|
| 1125 |
+
|
| 1126 |
+
# RAW MODE: Fall through to original __getitem__ logic with runtime T_cutoff sampling
|
| 1127 |
+
raw_data = cached_data
|
| 1128 |
|
| 1129 |
required_keys = [
|
| 1130 |
"mint_timestamp",
|
|
|
|
| 2362 |
'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
|
| 2363 |
'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
|
| 2364 |
}
|
| 2365 |
+
|
| 2366 |
+
def __cacheitem_context__(self, idx: int, num_samples_per_token: int = 1) -> List[Optional[Dict[str, Any]]]:
    """
    Generate fully processed training contexts for one token, ready to cache.

    Steps performed:
      1. Fetch the raw token data through ``self.fetcher``.
      2. Find every trade index that is eligible as a T_cutoff.
      3. Build the 1-second open/close series and the 5-minute holder snapshots.
      4. Pre-fetch profiles, socials, holdings and graph links for every wallet
         that appears anywhere in the raw data.
      5. Sample T_cutoff positions and build one training item per position
         via ``self._generate_dataset_item``.

    All randomness (the T_cutoff choice) happens here, at cache time.

    Args:
        idx: Index into ``self.sampled_mints``.
        num_samples_per_token: Maximum number of distinct T_cutoff positions to
            turn into samples. When fewer positions are eligible, every
            eligible position is used exactly once.

    Returns:
        List of training-ready samples. May be shorter than
        ``num_samples_per_token`` (or empty) when the token has no data, too
        few trades, no eligible T_cutoff, or when item generation yields
        nothing for a position.

    Raises:
        RuntimeError: If no mint records are loaded or no fetcher is set.
        IndexError: If ``idx`` is beyond the last mint record.
    """
    if not self.sampled_mints:
        raise RuntimeError("Dataset has no mint records loaded.")
    if idx >= len(self.sampled_mints):
        raise IndexError(f"Index {idx} exceeds mint count {len(self.sampled_mints)}.")

    initial_mint_record = self.sampled_mints[idx]
    t0 = initial_mint_record["timestamp"]
    # Naive datetimes are interpreted as UTC so later arithmetic never mixes
    # naive and aware values.
    if isinstance(t0, datetime.datetime) and t0.tzinfo is None:
        t0 = t0.replace(tzinfo=datetime.timezone.utc)
    creator_address = initial_mint_record['creator_address']
    token_address = initial_mint_record['mint_address']

    print(f"\n--- Caching CONTEXT for token: {token_address} (generating {num_samples_per_token} samples) ---")

    if not self.fetcher:
        raise RuntimeError("Dataset has no data fetcher.")

    # --- STEP 1: Fetch raw data ---
    raw_data = self.fetcher.fetch_raw_token_data(
        token_address=token_address,
        creator_address=creator_address,
        mint_timestamp=t0,
        max_horizon_seconds=self.max_cache_horizon_seconds,
        include_wallet_data=False,
        include_graph=False,
        min_trades=10,
        full_history=True,
        prune_failed=False,
        prune_transfers=False
    )

    if raw_data is None:
        print(f" SKIP: No raw data for {token_address}")
        return []

    def _timestamp_to_order_value(ts_value) -> float:
        """Map a datetime or numeric timestamp to epoch seconds (0.0 if unusable)."""
        if isinstance(ts_value, datetime.datetime):
            if ts_value.tzinfo is None:
                ts_value = ts_value.replace(tzinfo=datetime.timezone.utc)
            return ts_value.timestamp()
        try:
            return float(ts_value)
        except (TypeError, ValueError, OverflowError):
            # None, non-numeric strings, oversized ints: sort them first.
            return 0.0

    # --- STEP 2: Validate trades and find eligible T_cutoff indices ---
    all_trades_raw = raw_data.get('trades', [])
    if not all_trades_raw:
        print(f" SKIP: No trades for {token_address}")
        return []

    all_trades_sorted = sorted(
        [t for t in all_trades_raw if t.get('timestamp') is not None],
        key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
    )
    n_sorted = len(all_trades_sorted)

    min_context_trades = 10
    if n_sorted < (min_context_trades + 1):
        print(f" SKIP: Not enough trades ({n_sorted}) for {token_address}")
        return []

    # A trade counts as "successful" when flagged so and priced above zero.
    successful_indices = [
        i for i, t in enumerate(all_trades_sorted)
        if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0
    ]

    if len(successful_indices) < 2:
        print(f" SKIP: Not enough successful trades for {token_address}")
        return []

    max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
    # The length check above guarantees n_sorted >= min_context_trades + 1,
    # hence max_idx >= min_idx and i + 1 is always a valid index below.
    min_idx = min_context_trades - 1
    max_idx = n_sorted - 2

    # last_successful_before[i]: greatest successful index <= i, or -1.
    succ_set = set(successful_indices)
    last_successful_before = [-1] * n_sorted
    last_seen = -1
    for i in range(n_sorted):
        if i in succ_set:
            last_seen = i
        last_successful_before[i] = last_seen

    # next_successful_after[i]: smallest successful index >= i, or -1.
    next_successful_after = [-1] * n_sorted
    next_seen = -1
    for i in range(n_sorted - 1, -1, -1):
        if i in succ_set:
            next_seen = i
        next_successful_after[i] = next_seen

    # A position is eligible when a successful trade exists at or before it
    # and another one follows within the longest horizon.
    eligible_indices = []
    for i in range(min_idx, max_idx + 1):
        anchor_idx = last_successful_before[i]
        next_idx = next_successful_after[i + 1]
        if anchor_idx < 0 or next_idx < 0:
            continue
        cutoff_ts = _timestamp_to_order_value(all_trades_sorted[i].get('timestamp'))
        next_ts = _timestamp_to_order_value(all_trades_sorted[next_idx].get('timestamp'))
        if next_ts <= cutoff_ts + max_horizon_seconds:
            eligible_indices.append(i)

    if not eligible_indices:
        print(f" SKIP: No eligible T_cutoff indices for {token_address}")
        return []

    print(f" INFO: Found {len(eligible_indices)} eligible T_cutoff positions")

    # --- STEP 3: Generate OHLC and holder snapshots ---
    trades = raw_data.get('trades', [])
    trade_ts_values = [_timestamp_to_order_value(t.get('timestamp')) for t in trades]
    t0_val = _timestamp_to_order_value(t0)
    last_trade_ts_val = max(trade_ts_values)

    # 120 seconds of padding after the last trade.
    duration_seconds = int(last_trade_ts_val - t0_val) + 120
    ohlc_1s = torch.zeros((duration_seconds, 2), dtype=torch.float32)

    # In-place sort: raw_data['trades'] itself ends up time-ordered.
    trades.sort(key=lambda x: _timestamp_to_order_value(x.get('timestamp')))
    # NOTE(review): every trade, including failed or zero-priced ones, feeds
    # the price series here - confirm this is intended.
    trades_by_sec = defaultdict(list)
    for t in trades:
        ts = _timestamp_to_order_value(t.get('timestamp'))
        sec_idx = int(ts - t0_val)
        if 0 <= sec_idx < duration_seconds:
            trades_by_sec[sec_idx].append(t['price_usd'])

    # Seconds without trades carry the previous close forward.
    last_close = float(trades[0]['price_usd'])
    for i in range(duration_seconds):
        if i in trades_by_sec:
            prices = trades_by_sec[i]
            op, cl = prices[0], prices[-1]
            last_close = cl
        else:
            op = cl = last_close
        ohlc_1s[i, 0] = float(op)
        ohlc_1s[i, 1] = float(cl)

    raw_data['ohlc_1s'] = ohlc_1s

    # Generate holder snapshots, one per `interval` seconds.
    interval = 300
    num_intervals = (duration_seconds // interval) + 1
    snapshot_stats = torch.zeros((num_intervals, 6), dtype=torch.float32)

    buckets = defaultdict(list)
    for t in trades:
        ts = _timestamp_to_order_value(t.get('timestamp'))
        bucket_idx = int(ts - t0_val) // interval
        if bucket_idx >= 0:
            buckets[bucket_idx].append(t)

    # Supply is the same for every snapshot, so scale it once.
    total_supply = raw_data.get('total_supply', 0) or 1
    if raw_data.get('decimals'):
        total_supply /= (10 ** raw_data['decimals'])

    holder_snapshots_list = []
    for i in range(num_intervals):
        bucket_trades = buckets.get(i, [])
        vol = sum((t.get('total_usd') or 0.0) for t in bucket_trades)
        tx = len(bucket_trades)
        buys = sum(1 for t in bucket_trades if t.get('trade_direction') == 0 or t.get('trade_type') == 0)
        sells = tx - buys

        # Snapshot i is taken at the END of bucket i.
        snapshot_ts = t0 + datetime.timedelta(seconds=(i + 1) * interval)
        count, top_holders = self.fetcher.fetch_holder_snapshot_stats_for_token(
            token_address, snapshot_ts, limit=HOLDER_SNAPSHOT_TOP_K
        )

        top10_bal = sum((h.get('current_balance') or 0) for h in top_holders[:10])
        top10_pct = (top10_bal / total_supply) if total_supply > 0 else 0.0

        # Columns: volume, tx count, buys, sells, holder count, top-10 share.
        snapshot_stats[i, 0] = float(vol)
        snapshot_stats[i, 1] = float(tx)
        snapshot_stats[i, 2] = float(buys)
        snapshot_stats[i, 3] = float(sells)
        snapshot_stats[i, 4] = float(count)
        snapshot_stats[i, 5] = float(top10_pct)

        holder_snapshots_list.append({
            'timestamp': int(snapshot_ts.timestamp()),
            'holders': top_holders
        })

    raw_data['snapshots_5m'] = snapshot_stats
    raw_data['holder_snapshots_list'] = holder_snapshots_list
    raw_data['protocol_id'] = initial_mint_record.get('protocol')

    # --- STEP 4: Collect ALL wallets and pre-fetch their data ---
    wallet_fields_by_stream = (
        ('trades', ('maker',)),
        ('transfers', ('source', 'destination')),
        ('pool_creations', ('creator_address',)),
        ('liquidity_changes', ('lp_provider',)),
    )

    def _collect_event_wallets(cutoff=None) -> set:
        """Creator plus every wallet named in an event at or before `cutoff` (all events if None)."""
        wallets = {creator_address}
        for stream, fields in wallet_fields_by_stream:
            for record in raw_data.get(stream, []):
                if cutoff is not None and not (
                    _timestamp_to_order_value(record.get('timestamp')) <= cutoff
                ):
                    continue
                wallets.update(record[f] for f in fields if record.get(f))
        return wallets

    def _holder_wallets(snapshot) -> set:
        """Wallet addresses listed in one holder snapshot."""
        return {
            h['wallet_address']
            for h in snapshot.get('holders', [])
            if h.get('wallet_address')
        }

    all_wallets = _collect_event_wallets()
    for snapshot in holder_snapshots_list:
        all_wallets |= _holder_wallets(snapshot)

    all_wallets.discard(None)
    all_wallets.discard('')
    wallet_list = list(all_wallets)

    # Wallet data is fetched once, as of the last trade, and shared by all samples.
    max_T_cutoff = datetime.datetime.fromtimestamp(last_trade_ts_val, tz=datetime.timezone.utc)

    # Each fetch below is best-effort: a failure degrades to empty data.
    try:
        cached_profiles, cached_socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_list, max_T_cutoff)
    except Exception as e:
        print(f" WARN: Failed to fetch wallet profiles/socials: {e}")
        cached_profiles, cached_socials = {}, {}

    try:
        cached_holdings = self.fetcher.fetch_wallet_holdings(wallet_list, max_T_cutoff)
    except Exception as e:
        print(f" WARN: Failed to fetch wallet holdings: {e}")
        cached_holdings = {}

    try:
        cached_graph_entities, cached_graph_links = self.fetcher.fetch_graph_links(
            wallet_list, max_T_cutoff, max_degrees=1
        )
    except Exception as e:
        print(f" WARN: Failed to fetch graph links: {e}")
        cached_graph_entities, cached_graph_links = {}, {}

    # Fetch token image (optional; the sample is still built without it).
    cached_image_bytes = None
    try:
        bullx_image_url = f"https://image.bullx.io/1399811149/{token_address}?retry=0"
        resp = self.http_session.get(bullx_image_url, timeout=5)
        if resp.status_code == 200:
            cached_image_bytes = resp.content
    except Exception as e:
        print(f" WARN: Failed to fetch token image: {e}")

    # --- STEP 5: Sample T_cutoffs and generate complete training contexts ---
    results = []

    # Sampling is WITHOUT replacement: when fewer positions are eligible than
    # requested, each eligible position is used once.
    if num_samples_per_token >= len(eligible_indices):
        sampled_indices = eligible_indices.copy()
    else:
        sampled_indices = random.sample(eligible_indices, num_samples_per_token)

    print(f" INFO: Generating {len(sampled_indices)} training contexts...")

    for sample_num, cutoff_trade_idx in enumerate(sampled_indices):
        sample_trade = all_trades_sorted[cutoff_trade_idx]
        sample_offset_ts = _timestamp_to_order_value(sample_trade.get('timestamp'))
        T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc)

        # Collect wallets visible at T_cutoff
        wallets_to_fetch = _collect_event_wallets(sample_offset_ts)

        # Get holder snapshot at T_cutoff.
        # NOTE(review): snapshot `snap_idx` is stamped at the end of its
        # bucket, i.e. after T_cutoff - confirm this is not look-ahead.
        elapsed = (T_cutoff - t0).total_seconds()
        snap_idx = int(elapsed // interval)
        if 0 <= snap_idx < len(holder_snapshots_list):
            wallets_to_fetch |= _holder_wallets(holder_snapshots_list[snap_idx])

        wallets_to_fetch.discard(None)
        wallets_to_fetch.discard('')

        # Build offline data for this context
        pooler = EmbeddingPooler()

        # Shallow copy: each sample gets its own top-level dict (and image entry).
        offline_token_data = {token_address: raw_data.copy()}
        if cached_image_bytes:
            try:
                cached_image = Image.open(BytesIO(cached_image_bytes))
                offline_token_data[token_address]['_cached_image_pil'] = cached_image
            except Exception as e:
                print(f" WARN: Could not decode token image: {e}")

        main_token_data = self._process_token_data_offline(
            [token_address], pooler, T_cutoff, token_data=offline_token_data
        )

        if not main_token_data:
            continue

        # Process wallet data offline, from the pre-fetched overrides.
        wallet_data, all_token_data = self._process_wallet_data(
            list(wallets_to_fetch),
            main_token_data.copy(),
            pooler,
            T_cutoff,
            profiles_override=cached_profiles,
            socials_override=cached_socials,
            holdings_override=cached_holdings
        )

        # NOTE(review): 'protocol_id' is always set above (possibly to None),
        # so the 0 default below never applies - confirm the intended fallback.
        mint_event = {
            'event_type': 'Mint',
            'timestamp': int(t0.timestamp()),
            'relative_ts': 0,
            'wallet_address': creator_address,
            'token_address': token_address,
            'protocol_id': raw_data.get('protocol_id', 0)
        }

        # Generate the complete training item. Per the original author's note,
        # H/B/H sampling is applied inside _generate_dataset_item - TODO confirm.
        result = self._generate_dataset_item(
            token_address=token_address,
            t0=t0,
            T_cutoff=T_cutoff,
            mint_event=mint_event,
            trade_records=raw_data.get('trades', []),
            transfer_records=raw_data.get('transfers', []),
            pool_creation_records=raw_data.get('pool_creations', []),
            liquidity_change_records=raw_data.get('liquidity_changes', []),
            fee_collection_records=raw_data.get('fee_collections', []),
            burn_records=raw_data.get('burns', []),
            supply_lock_records=raw_data.get('supply_locks', []),
            migration_records=raw_data.get('migrations', []),
            wallet_data=wallet_data,
            all_token_data=all_token_data,
            graph_links=cached_graph_links,
            graph_seed_entities=wallets_to_fetch,
            all_graph_entities=cached_graph_entities,
            future_trades_for_labels=raw_data.get('trades', []),
            pooler=pooler,
            sample_idx=idx,
            cached_holders_list=holder_snapshots_list,
            cached_ohlc_1s=ohlc_1s,
            quality_score=None  # Will be injected by cache_dataset.py
        )

        if result is not None:
            # Store the T_cutoff used for this sample (for reproducibility tracking)
            result['cached_t_cutoff_ts'] = sample_offset_ts
            result['cached_sample_num'] = sample_num
            results.append(result)
            print(f" + Context {sample_num}: T_cutoff={T_cutoff.isoformat()}, events={len(result.get('event_sequence', []))}")

    print(f" INFO: Generated {len(results)} valid training contexts for {token_address}")
    return results
|
pre_cache.sh
CHANGED
|
@@ -1,7 +1,48 @@
|
|
| 1 |
#!/bin/bash
|
| 2 |
# Pre-caches the dataset for training
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
echo "Starting dataset caching..."
|
|
|
|
| 5 |
python3 scripts/cache_dataset.py \
|
| 6 |
-
--ohlc_stats_path "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
echo "Done!"
|
|
|
|
|
|
| 1 |
#!/bin/bash
# Pre-caches the dataset for training.
#
# Usage:
#   ./pre_cache.sh                         # Raw mode (old behavior)
#   ./pre_cache.sh --cache_mode context    # Context mode (new fully offline)
#
# Each setting can be supplied as an environment variable (CACHE_MODE,
# CONTEXT_LENGTH, MIN_TRADES, SAMPLES_PER_TOKEN, OHLC_STATS_PATH, OUTPUT_DIR)
# or as the matching command-line flag; the flag wins. Any other argument is
# forwarded unchanged to scripts/cache_dataset.py.
#
# Context mode arguments:
#   --context_length N      Max sequence length, triggers H/B/H when exceeded (default: 8192)
#   --min_trades N          Minimum trades for T_cutoff sampling (default: 10)
#   --samples_per_token N   Number of T_cutoff samples per token (default: 1)

# Abort on errors, on unset variables, and on failures inside pipelines.
set -euo pipefail

# Default values (environment variables override them)
CACHE_MODE="${CACHE_MODE:-raw}"
CONTEXT_LENGTH="${CONTEXT_LENGTH:-8192}"
MIN_TRADES="${MIN_TRADES:-10}"
SAMPLES_PER_TOKEN="${SAMPLES_PER_TOKEN:-1}"
OHLC_STATS_PATH="${OHLC_STATS_PATH:-/workspace/apollo/data/ohlc_stats.npz}"
OUTPUT_DIR="${OUTPUT_DIR:-data/cache}"

# Store the value of a recognised flag in its variable.
set_option() {
    case "$1" in
        --cache_mode)        CACHE_MODE="$2" ;;
        --context_length)    CONTEXT_LENGTH="$2" ;;
        --min_trades)        MIN_TRADES="$2" ;;
        --samples_per_token) SAMPLES_PER_TOKEN="$2" ;;
        --ohlc_stats_path)   OHLC_STATS_PATH="$2" ;;
        --output_dir)        OUTPUT_DIR="$2" ;;
    esac
}

# Pull the flags this script reports on out of the argument list, so the
# banner and the final "Cache saved to" line show the values actually used.
EXTRA_ARGS=()
while [ "$#" -gt 0 ]; do
    case "$1" in
        --cache_mode|--context_length|--min_trades|--samples_per_token|--ohlc_stats_path|--output_dir)
            if [ "$#" -lt 2 ]; then
                echo "ERROR: $1 requires a value" >&2
                exit 2
            fi
            set_option "$1" "$2"
            shift 2
            ;;
        --cache_mode=*|--context_length=*|--min_trades=*|--samples_per_token=*|--ohlc_stats_path=*|--output_dir=*)
            set_option "${1%%=*}" "${1#*=}"
            shift
            ;;
        *)
            EXTRA_ARGS+=("$1")
            shift
            ;;
    esac
done

echo "========================================"
echo "Apollo Dataset Pre-Caching"
echo "========================================"
echo "Cache Mode: $CACHE_MODE"
if [ "$CACHE_MODE" = "context" ]; then
    echo "Context Length (H/B/H threshold): $CONTEXT_LENGTH"
    echo "Min Trades (T_cutoff threshold): $MIN_TRADES"
    echo "Samples per Token: $SAMPLES_PER_TOKEN"
fi
echo "Output Directory: $OUTPUT_DIR"
echo "OHLC Stats Path: $OHLC_STATS_PATH"
echo "========================================"

echo "Starting dataset caching..."

# The ${VAR[@]+...} form expands to nothing for an empty array, which keeps
# `set -u` from failing on older bash versions.
python3 scripts/cache_dataset.py \
    --ohlc_stats_path "$OHLC_STATS_PATH" \
    --output_dir "$OUTPUT_DIR" \
    --cache_mode "$CACHE_MODE" \
    --context_length "$CONTEXT_LENGTH" \
    --min_trades "$MIN_TRADES" \
    --samples_per_token "$SAMPLES_PER_TOKEN" \
    ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}

echo "Done!"
echo "Cache saved to: $OUTPUT_DIR"
|
scripts/cache_dataset.py
CHANGED
|
@@ -194,14 +194,24 @@ def main():
|
|
| 194 |
parser.add_argument("--start_date", type=str, default=None, help="Start date (YYYY-MM-DD) for fetching new mints")
|
| 195 |
parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
|
| 196 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
# DB Args
|
| 199 |
parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
|
| 200 |
parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
|
| 201 |
parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
|
| 202 |
parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
|
| 203 |
parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
|
| 204 |
-
|
| 205 |
args = parser.parse_args()
|
| 206 |
|
| 207 |
output_dir = Path(args.output_dir)
|
|
@@ -240,7 +250,8 @@ def main():
|
|
| 240 |
ohlc_stats_path=args.ohlc_stats_path,
|
| 241 |
horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200],
|
| 242 |
quantiles=[0.5],
|
| 243 |
-
min_trade_usd=args.min_trade_usd
|
|
|
|
| 244 |
)
|
| 245 |
|
| 246 |
if len(dataset) == 0:
|
|
@@ -262,84 +273,186 @@ def main():
|
|
| 262 |
if len(dataset) == 0:
|
| 263 |
print("WARNING: No tokens remain after filtering by return_class_map.")
|
| 264 |
return
|
| 265 |
-
|
| 266 |
-
# --- 3. Iterate and cache
|
| 267 |
-
print(f"INFO:
|
| 268 |
-
|
|
|
|
| 269 |
skipped_count = 0
|
| 270 |
cached_count = 0
|
| 271 |
-
|
| 272 |
-
for i in tqdm(range(len(dataset)), desc="Caching samples"):
|
| 273 |
-
mint_addr = dataset.sampled_mints[i]['mint_address']
|
| 274 |
-
|
| 275 |
-
# (No need to check if in return_class_map anymore, we filtered)
|
| 276 |
-
class_id = return_class_map[mint_addr]
|
| 277 |
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
skipped_count += 1
|
| 282 |
continue
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
)
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
n_migrations = len(item.get("migrations", []))
|
| 312 |
-
n_mints = 1 if item.get("mint_timestamp") else 0
|
| 313 |
-
n_ohlc = len(item.get("ohlc_1s", [])) if item.get("ohlc_1s") is not None else 0
|
| 314 |
-
n_snapshots_5m = len(item.get("snapshots_5m", []))
|
| 315 |
-
n_holders = len(item.get("holder_snapshots_list", []))
|
| 316 |
-
|
| 317 |
-
tqdm.write(
|
| 318 |
-
f" + Cached: {mint_addr} | Class: {class_id} | Q: {q_score:.4f} | "
|
| 319 |
-
f"Events: Mint {n_mints}, Trades {n_trades}, Transfers {n_transfers}, Pool Creations {n_pool_creations}, "
|
| 320 |
-
f"Liquidity Changes {n_liquidity_changes}, Fee Collections {n_fee_collections}, "
|
| 321 |
-
f"Burns {n_burns}, Supply Locks {n_supply_locks}, Migrations {n_migrations} | "
|
| 322 |
-
f"Derived: Ohlc 1s {n_ohlc}, Snapshots 5m {n_snapshots_5m}, Holder Snapshots {n_holders}"
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
except Exception as e:
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
print(f"\n--- Caching Complete ---")
|
| 340 |
-
print(f"
|
| 341 |
-
print(f"
|
| 342 |
-
print(f"
|
|
|
|
|
|
|
| 343 |
print(f"Cache location: {output_dir.resolve()}")
|
| 344 |
|
| 345 |
finally:
|
|
|
|
| 194 |
parser.add_argument("--start_date", type=str, default=None, help="Start date (YYYY-MM-DD) for fetching new mints")
|
| 195 |
parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
|
| 196 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
| 197 |
+
|
| 198 |
+
# NEW: Context caching mode args
|
| 199 |
+
parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"],
|
| 200 |
+
help="Cache mode: 'raw' caches raw token data (old behavior), 'context' caches fully processed training contexts (new behavior)")
|
| 201 |
+
parser.add_argument("--context_length", type=int, default=8192,
|
| 202 |
+
help="Max sequence length for context caching mode. Triggers H/B/H dynamic sampling when events exceed this limit.")
|
| 203 |
+
parser.add_argument("--min_trades", type=int, default=10,
|
| 204 |
+
help="Minimum number of trades required for T_cutoff sampling. Tokens with fewer trades are skipped.")
|
| 205 |
+
parser.add_argument("--samples_per_token", type=int, default=1,
|
| 206 |
+
help="Number of different T_cutoff samples to generate per token in context mode.")
|
| 207 |
+
|
| 208 |
# DB Args
|
| 209 |
parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
|
| 210 |
parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
|
| 211 |
parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
|
| 212 |
parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
|
| 213 |
parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
|
| 214 |
+
|
| 215 |
args = parser.parse_args()
|
| 216 |
|
| 217 |
output_dir = Path(args.output_dir)
|
|
|
|
| 250 |
ohlc_stats_path=args.ohlc_stats_path,
|
| 251 |
horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200],
|
| 252 |
quantiles=[0.5],
|
| 253 |
+
min_trade_usd=args.min_trade_usd,
|
| 254 |
+
max_seq_len=args.context_length # Pass context_length for H/B/H threshold
|
| 255 |
)
|
| 256 |
|
| 257 |
if len(dataset) == 0:
|
|
|
|
| 273 |
if len(dataset) == 0:
|
| 274 |
print("WARNING: No tokens remain after filtering by return_class_map.")
|
| 275 |
return
|
| 276 |
+
|
| 277 |
+
# --- 3. Iterate and cache based on mode ---
|
| 278 |
+
print(f"INFO: Cache mode: {args.cache_mode}")
|
| 279 |
+
print(f"INFO: Starting to generate and cache from {len(dataset)} tokens...")
|
| 280 |
+
|
| 281 |
skipped_count = 0
|
| 282 |
cached_count = 0
|
| 283 |
+
global_sample_idx = 0 # Global counter for unique sample filenames
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
+
# Track class distribution for balanced sampling metadata
|
| 286 |
+
class_distribution = {}
|
| 287 |
+
|
| 288 |
+
if args.cache_mode == "context":
|
| 289 |
+
# =========================================================================
|
| 290 |
+
# CONTEXT MODE: Cache fully processed training contexts
|
| 291 |
+
# - Samples T_cutoff during caching (non-deterministic moved to cache time)
|
| 292 |
+
# - Applies H/B/H dynamic sampling based on context_length
|
| 293 |
+
# - Avoids caching tokens that won't be seen (garbage filtered out)
|
| 294 |
+
# - Training becomes fully deterministic (just loads cached contexts)
|
| 295 |
+
# =========================================================================
|
| 296 |
+
print(f"INFO: Context mode settings:")
|
| 297 |
+
print(f" - context_length (H/B/H threshold): {args.context_length}")
|
| 298 |
+
print(f" - min_trades (T_cutoff threshold): {args.min_trades}")
|
| 299 |
+
print(f" - samples_per_token: {args.samples_per_token}")
|
| 300 |
+
|
| 301 |
+
for i in tqdm(range(len(dataset)), desc="Caching contexts"):
|
| 302 |
+
mint_addr = dataset.sampled_mints[i]['mint_address']
|
| 303 |
+
class_id = return_class_map[mint_addr]
|
| 304 |
+
|
| 305 |
+
try:
|
| 306 |
+
# Generate multiple training contexts per token
|
| 307 |
+
contexts = dataset.__cacheitem_context__(i, num_samples_per_token=args.samples_per_token)
|
| 308 |
+
|
| 309 |
+
if not contexts:
|
| 310 |
+
skipped_count += 1
|
| 311 |
+
continue
|
| 312 |
+
|
| 313 |
+
# Require quality score
|
| 314 |
+
if mint_addr not in quality_scores_map:
|
| 315 |
+
reason = quality_missing_reason(mint_addr)
|
| 316 |
+
raise RuntimeError(
|
| 317 |
+
f"Missing quality score for mint {mint_addr}. Reason: {reason}."
|
| 318 |
+
)
|
| 319 |
+
q_score = quality_scores_map[mint_addr]
|
| 320 |
+
|
| 321 |
+
# Save each context as a separate sample
|
| 322 |
+
for ctx in contexts:
|
| 323 |
+
ctx["quality_score"] = q_score
|
| 324 |
+
ctx["class_id"] = class_id
|
| 325 |
+
ctx["source_token"] = mint_addr # Track origin for debugging
|
| 326 |
+
ctx["cache_mode"] = "context"
|
| 327 |
+
|
| 328 |
+
filename = f"sample_{global_sample_idx}.pt"
|
| 329 |
+
output_path = output_dir / filename
|
| 330 |
+
torch.save(ctx, output_path)
|
| 331 |
+
|
| 332 |
+
# Track class distribution
|
| 333 |
+
class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
|
| 334 |
+
|
| 335 |
+
global_sample_idx += 1
|
| 336 |
+
cached_count += 1
|
| 337 |
+
|
| 338 |
+
n_events = len(contexts[0].get("event_sequence", [])) if contexts else 0
|
| 339 |
+
tqdm.write(
|
| 340 |
+
f" + Cached {len(contexts)} contexts: {mint_addr} | Class: {class_id} | Q: {q_score:.4f} | Events: {n_events}"
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
except Exception as e:
|
| 344 |
+
error_msg = str(e)
|
| 345 |
+
if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
|
| 346 |
+
print(f"\nCRITICAL: Fatal error processing sample {i}. Stopping.\nError: {e}", file=sys.stderr)
|
| 347 |
+
sys.exit(1)
|
| 348 |
+
|
| 349 |
+
print(f"\nERROR: Failed to cache contexts for {mint_addr}. Error: {e}", file=sys.stderr)
|
| 350 |
+
import traceback
|
| 351 |
+
traceback.print_exc()
|
| 352 |
skipped_count += 1
|
| 353 |
continue
|
| 354 |
+
|
| 355 |
+
else:
|
| 356 |
+
# =========================================================================
|
| 357 |
+
# RAW MODE: Cache raw token data (original behavior)
|
| 358 |
+
# - T_cutoff sampling happens at runtime
|
| 359 |
+
# - H/B/H applied at runtime
|
| 360 |
+
# - Non-deterministic training
|
| 361 |
+
# =========================================================================
|
| 362 |
+
for i in tqdm(range(len(dataset)), desc="Caching raw samples"):
|
| 363 |
+
mint_addr = dataset.sampled_mints[i]['mint_address']
|
| 364 |
+
class_id = return_class_map[mint_addr]
|
| 365 |
+
|
| 366 |
+
try:
|
| 367 |
+
item = dataset.__cacheitem__(i)
|
| 368 |
+
if item is None:
|
| 369 |
+
skipped_count += 1
|
| 370 |
+
continue
|
| 371 |
+
|
| 372 |
+
if mint_addr not in quality_scores_map:
|
| 373 |
+
reason = quality_missing_reason(mint_addr)
|
| 374 |
+
raise RuntimeError(
|
| 375 |
+
f"Missing quality score for mint {mint_addr}. Reason: {reason}."
|
| 376 |
+
)
|
| 377 |
+
q_score = quality_scores_map[mint_addr]
|
| 378 |
+
|
| 379 |
+
item["quality_score"] = q_score
|
| 380 |
+
item["class_id"] = class_id
|
| 381 |
+
item["cache_mode"] = "raw"
|
| 382 |
+
|
| 383 |
+
filename = f"sample_{i}.pt"
|
| 384 |
+
output_path = output_dir / filename
|
| 385 |
+
torch.save(item, output_path)
|
| 386 |
+
|
| 387 |
+
# Track class distribution
|
| 388 |
+
class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
|
| 389 |
+
|
| 390 |
+
cached_count += 1
|
| 391 |
+
|
| 392 |
+
n_trades = len(item.get("trades", []))
|
| 393 |
+
n_transfers = len(item.get("transfers", []))
|
| 394 |
+
n_pool_creations = len(item.get("pool_creations", []))
|
| 395 |
+
n_liquidity_changes = len(item.get("liquidity_changes", []))
|
| 396 |
+
n_fee_collections = len(item.get("fee_collections", []))
|
| 397 |
+
n_burns = len(item.get("burns", []))
|
| 398 |
+
n_supply_locks = len(item.get("supply_locks", []))
|
| 399 |
+
n_migrations = len(item.get("migrations", []))
|
| 400 |
+
n_mints = 1 if item.get("mint_timestamp") else 0
|
| 401 |
+
n_ohlc = len(item.get("ohlc_1s", [])) if item.get("ohlc_1s") is not None else 0
|
| 402 |
+
n_snapshots_5m = len(item.get("snapshots_5m", []))
|
| 403 |
+
n_holders = len(item.get("holder_snapshots_list", []))
|
| 404 |
+
|
| 405 |
+
tqdm.write(
|
| 406 |
+
f" + Cached: {mint_addr} | Class: {class_id} | Q: {q_score:.4f} | "
|
| 407 |
+
f"Events: Mint {n_mints}, Trades {n_trades}, Transfers {n_transfers}, Pool Creations {n_pool_creations}, "
|
| 408 |
+
f"Liquidity Changes {n_liquidity_changes}, Fee Collections {n_fee_collections}, "
|
| 409 |
+
f"Burns {n_burns}, Supply Locks {n_supply_locks}, Migrations {n_migrations} | "
|
| 410 |
+
f"Derived: Ohlc 1s {n_ohlc}, Snapshots 5m {n_snapshots_5m}, Holder Snapshots {n_holders}"
|
| 411 |
)
|
| 412 |
+
|
| 413 |
+
except Exception as e:
|
| 414 |
+
error_msg = str(e)
|
| 415 |
+
if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
|
| 416 |
+
print(f"\nCRITICAL: Fatal error processing sample {i}. Stopping.\nError: {e}", file=sys.stderr)
|
| 417 |
+
sys.exit(1)
|
| 418 |
+
|
| 419 |
+
print(f"\nERROR: Failed to cache sample {i} for {mint_addr}. Error: {e}", file=sys.stderr)
|
| 420 |
+
import traceback
|
| 421 |
+
traceback.print_exc()
|
| 422 |
+
skipped_count += 1
|
| 423 |
+
continue
|
| 424 |
+
|
| 425 |
+
# --- Save class metadata for balanced sampling ---
|
| 426 |
+
# Build file_class_map for the metadata cache
|
| 427 |
+
file_class_map = {}
|
| 428 |
+
for sample_file in sorted(output_dir.glob("sample_*.pt")):
|
| 429 |
+
try:
|
| 430 |
+
sample_data = torch.load(sample_file, map_location="cpu", weights_only=False)
|
| 431 |
+
file_class_map[sample_file.name] = sample_data.get("class_id", 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
except Exception as e:
|
| 433 |
+
print(f"WARN: Could not read class_id from {sample_file.name}: {e}")
|
| 434 |
+
|
| 435 |
+
metadata_path = output_dir / "class_metadata.json"
|
| 436 |
+
try:
|
| 437 |
+
with open(metadata_path, 'w') as f:
|
| 438 |
+
json.dump({
|
| 439 |
+
'file_class_map': file_class_map,
|
| 440 |
+
'class_distribution': class_distribution,
|
| 441 |
+
'cache_mode': args.cache_mode,
|
| 442 |
+
'context_length': args.context_length if args.cache_mode == "context" else None,
|
| 443 |
+
'min_trades': args.min_trades if args.cache_mode == "context" else None,
|
| 444 |
+
'samples_per_token': args.samples_per_token if args.cache_mode == "context" else None,
|
| 445 |
+
}, f, indent=2)
|
| 446 |
+
print(f"INFO: Saved class metadata to {metadata_path}")
|
| 447 |
+
except Exception as e:
|
| 448 |
+
print(f"WARN: Failed to save class metadata: {e}")
|
| 449 |
+
|
| 450 |
print(f"\n--- Caching Complete ---")
|
| 451 |
+
print(f"Cache mode: {args.cache_mode}")
|
| 452 |
+
print(f"Successfully cached: {cached_count} samples.")
|
| 453 |
+
print(f"Filtered (Invalid/High Return): {filtered_count} tokens.")
|
| 454 |
+
print(f"Skipped (Errors/Empty): {skipped_count} tokens.")
|
| 455 |
+
print(f"Class distribution: {class_distribution}")
|
| 456 |
print(f"Cache location: {output_dir.resolve()}")
|
| 457 |
|
| 458 |
finally:
|