Upload folder using huggingface_hub
Browse files- collator_dump.json +0 -0
- data/data_collator.py +24 -6
- data/data_fetcher.py +39 -0
- data/data_loader.py +344 -61
- data/ohlc_stats.npz +1 -1
- database.sh +2 -0
- log.log +2 -2
- models/helper_encoders.py +6 -5
- pre_cache.sh +1 -3
- scripts/inspect_collator.py +227 -0
- train.py +26 -6
collator_dump.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/data_collator.py
CHANGED
|
@@ -276,11 +276,25 @@ class MemecoinCollator:
|
|
| 276 |
unique_wallets_data.update(item.get('wallets', {}))
|
| 277 |
unique_tokens_data.update(item.get('tokens', {}))
|
| 278 |
|
| 279 |
-
# Create mappings needed for indexing
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
# Collate Static Raw Features (Tokens, Wallets, Graph)
|
| 286 |
token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
|
|
@@ -297,7 +311,8 @@ class MemecoinCollator:
|
|
| 297 |
|
| 298 |
# Initialize sequence tensors
|
| 299 |
event_type_ids = torch.full((B, L), PAD_IDX_SEQ, dtype=torch.long, device=self.device)
|
| 300 |
-
|
|
|
|
| 301 |
# Store relative_ts in float32 for stability; model will scale/log/normalize
|
| 302 |
relative_ts = torch.zeros((B, L, 1), dtype=torch.float32, device=self.device)
|
| 303 |
attention_mask = torch.zeros((B, L), dtype=torch.long, device=self.device)
|
|
@@ -601,6 +616,9 @@ class MemecoinCollator:
|
|
| 601 |
]
|
| 602 |
boosted_token_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 603 |
|
|
|
|
|
|
|
|
|
|
| 604 |
elif event_type == 'HolderSnapshot':
|
| 605 |
# --- FIXED: Store raw holder data, not an index ---
|
| 606 |
raw_holders = event.get('holders', [])
|
|
|
|
| 276 |
unique_wallets_data.update(item.get('wallets', {}))
|
| 277 |
unique_tokens_data.update(item.get('tokens', {}))
|
| 278 |
|
| 279 |
+
# Create mappings needed for indexing (use dict keys as source of truth)
|
| 280 |
+
wallet_items = list(unique_wallets_data.items())
|
| 281 |
+
token_items = list(unique_tokens_data.items())
|
| 282 |
+
|
| 283 |
+
wallet_list_data = []
|
| 284 |
+
for addr, feat in wallet_items:
|
| 285 |
+
profile = feat.get('profile', {})
|
| 286 |
+
if not profile.get('wallet_address'):
|
| 287 |
+
profile['wallet_address'] = addr
|
| 288 |
+
wallet_list_data.append(feat)
|
| 289 |
+
|
| 290 |
+
token_list_data = []
|
| 291 |
+
for addr, feat in token_items:
|
| 292 |
+
if not feat.get('address'):
|
| 293 |
+
feat['address'] = addr
|
| 294 |
+
token_list_data.append(feat)
|
| 295 |
+
|
| 296 |
+
wallet_addr_to_batch_idx = {addr: i + 1 for i, (addr, _) in enumerate(wallet_items)}
|
| 297 |
+
token_addr_to_batch_idx = {addr: i + 1 for i, (addr, _) in enumerate(token_items)}
|
| 298 |
|
| 299 |
# Collate Static Raw Features (Tokens, Wallets, Graph)
|
| 300 |
token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
|
|
|
|
| 311 |
|
| 312 |
# Initialize sequence tensors
|
| 313 |
event_type_ids = torch.full((B, L), PAD_IDX_SEQ, dtype=torch.long, device=self.device)
|
| 314 |
+
# Use float64 to preserve second-level precision for large Unix timestamps.
|
| 315 |
+
timestamps_float = torch.zeros((B, L), dtype=torch.float64, device=self.device)
|
| 316 |
# Store relative_ts in float32 for stability; model will scale/log/normalize
|
| 317 |
relative_ts = torch.zeros((B, L, 1), dtype=torch.float32, device=self.device)
|
| 318 |
attention_mask = torch.zeros((B, L), dtype=torch.long, device=self.device)
|
|
|
|
| 616 |
]
|
| 617 |
boosted_token_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 618 |
|
| 619 |
+
elif event_type == 'Migrated':
|
| 620 |
+
migrated_protocol_ids[i, j] = event.get('protocol_id', 0)
|
| 621 |
+
|
| 622 |
elif event_type == 'HolderSnapshot':
|
| 623 |
# --- FIXED: Store raw holder data, not an index ---
|
| 624 |
raw_holders = event.get('holders', [])
|
data/data_fetcher.py
CHANGED
|
@@ -1014,6 +1014,45 @@ class DataFetcher:
|
|
| 1014 |
except Exception as e:
|
| 1015 |
print(f"ERROR: Failed to count total holders for token {token_address}: {e}")
|
| 1016 |
return 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1017 |
def fetch_raw_token_data(
|
| 1018 |
self,
|
| 1019 |
token_address: str,
|
|
|
|
| 1014 |
except Exception as e:
|
| 1015 |
print(f"ERROR: Failed to count total holders for token {token_address}: {e}")
|
| 1016 |
return 0
|
| 1017 |
+
|
| 1018 |
+
def fetch_holder_snapshot_stats_for_token(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> Tuple[int, List[Dict[str, Any]]]:
|
| 1019 |
+
"""
|
| 1020 |
+
Fetch total holder count and top holders in a single query.
|
| 1021 |
+
Returns (count, top_holders_list).
|
| 1022 |
+
"""
|
| 1023 |
+
if not token_address:
|
| 1024 |
+
return 0, []
|
| 1025 |
+
query = """
|
| 1026 |
+
WITH point_in_time_holdings AS (
|
| 1027 |
+
SELECT
|
| 1028 |
+
wallet_address,
|
| 1029 |
+
argMax(current_balance, updated_at) AS bal
|
| 1030 |
+
FROM wallet_holdings
|
| 1031 |
+
WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
|
| 1032 |
+
GROUP BY wallet_address
|
| 1033 |
+
)
|
| 1034 |
+
SELECT
|
| 1035 |
+
(SELECT count() FROM point_in_time_holdings WHERE bal > 0) AS holder_count,
|
| 1036 |
+
(SELECT groupArray((wallet_address, bal))
|
| 1037 |
+
FROM (
|
| 1038 |
+
SELECT wallet_address, bal
|
| 1039 |
+
FROM point_in_time_holdings
|
| 1040 |
+
WHERE bal > 0
|
| 1041 |
+
ORDER BY bal DESC
|
| 1042 |
+
LIMIT %(limit)s
|
| 1043 |
+
)) AS top_holders
|
| 1044 |
+
"""
|
| 1045 |
+
params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
|
| 1046 |
+
try:
|
| 1047 |
+
rows = self.db_client.execute(query, params)
|
| 1048 |
+
if not rows:
|
| 1049 |
+
return 0, []
|
| 1050 |
+
holder_count, top_holders = rows[0]
|
| 1051 |
+
top_list = [{'wallet_address': wa, 'current_balance': bal} for wa, bal in (top_holders or [])]
|
| 1052 |
+
return int(holder_count or 0), top_list
|
| 1053 |
+
except Exception as e:
|
| 1054 |
+
print(f"ERROR: Failed to fetch holder snapshot stats for token {token_address}: {e}")
|
| 1055 |
+
return 0, []
|
| 1056 |
def fetch_raw_token_data(
|
| 1057 |
self,
|
| 1058 |
token_address: str,
|
data/data_loader.py
CHANGED
|
@@ -61,6 +61,7 @@ MIN_AMOUNT_TRANSFER_SUPPLY = 0.0 # 1.0% of total supply
|
|
| 61 |
# Interval for HolderSnapshot events (seconds)
|
| 62 |
HOLDER_SNAPSHOT_INTERVAL_SEC = 300
|
| 63 |
HOLDER_SNAPSHOT_TOP_K = 200
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
class EmbeddingPooler:
|
|
@@ -114,6 +115,7 @@ class OracleDataset(Dataset):
|
|
| 114 |
"""
|
| 115 |
def __init__(self,
|
| 116 |
data_fetcher: Optional[DataFetcher] = None, # OPTIONAL: Only needed for caching (Writer)
|
|
|
|
| 117 |
horizons_seconds: List[int] = [],
|
| 118 |
quantiles: List[float] = [],
|
| 119 |
max_samples: Optional[int] = None,
|
|
@@ -129,18 +131,11 @@ class OracleDataset(Dataset):
|
|
| 129 |
|
| 130 |
# --- NEW: Create a persistent requests session for efficiency ---
|
| 131 |
# Configure robust HTTP session
|
| 132 |
-
self.http_session =
|
| 133 |
-
|
| 134 |
-
total=3,
|
| 135 |
-
backoff_factor=1,
|
| 136 |
-
status_forcelist=[429, 500, 502, 503, 504],
|
| 137 |
-
allowed_methods=["HEAD", "GET", "OPTIONS"]
|
| 138 |
-
)
|
| 139 |
-
adapter = HTTPAdapter(max_retries=retry_strategy)
|
| 140 |
-
self.http_session.mount("http://", adapter)
|
| 141 |
-
self.http_session.mount("https://", adapter)
|
| 142 |
|
| 143 |
self.fetcher = data_fetcher
|
|
|
|
| 144 |
self.cache_dir = Path(cache_dir) if cache_dir else None
|
| 145 |
# Always define these so DataLoader workers don't crash with AttributeError if
|
| 146 |
# initialization falls through an unexpected branch.
|
|
@@ -271,6 +266,51 @@ class OracleDataset(Dataset):
|
|
| 271 |
print("INFO: No OHLC stats path provided. Using default normalization.")
|
| 272 |
|
| 273 |
self.min_trade_usd = min_trade_usd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
def __len__(self) -> int:
|
| 276 |
return self.num_samples
|
|
@@ -287,6 +327,16 @@ class OracleDataset(Dataset):
|
|
| 287 |
denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
|
| 288 |
return [(float(v) - self.ohlc_price_mean) / denom for v in values]
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
def _apply_dynamic_sampling(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 291 |
"""
|
| 292 |
Applies dynamic context sampling to fit events within max_seq_len.
|
|
@@ -482,8 +532,11 @@ class OracleDataset(Dataset):
|
|
| 482 |
holders_end = len(cached_holders_list[i])
|
| 483 |
elif self.fetcher:
|
| 484 |
cutoff_dt_ts = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
|
| 485 |
-
holder_records_ts = self.fetcher.
|
| 486 |
-
|
|
|
|
|
|
|
|
|
|
| 487 |
else:
|
| 488 |
holder_records_ts = []
|
| 489 |
holders_end = 0
|
|
@@ -784,6 +837,9 @@ class OracleDataset(Dataset):
|
|
| 784 |
# --- FIXED: Only add to pooler if data is valid ---
|
| 785 |
image = None
|
| 786 |
token_uri = data.get('token_uri')
|
|
|
|
|
|
|
|
|
|
| 787 |
|
| 788 |
# --- NEW: Use multiple IPFS gateways for reliability ---
|
| 789 |
if token_uri and isinstance(token_uri, str) and token_uri.strip():
|
|
@@ -841,11 +897,15 @@ class OracleDataset(Dataset):
|
|
| 841 |
else: # If all gateways fail for the image
|
| 842 |
raise RuntimeError(f"All IPFS gateways failed for image: {image_url}")
|
| 843 |
else: # Handle regular HTTP image URLs
|
|
|
|
|
|
|
| 844 |
image_resp = self.http_session.get(image_url, timeout=10)
|
| 845 |
image_resp.raise_for_status()
|
| 846 |
image = Image.open(BytesIO(image_resp.content))
|
| 847 |
except (requests.RequestException, ValueError, IOError) as e:
|
| 848 |
-
|
|
|
|
|
|
|
| 849 |
print(f"WARN: Could not fetch or process image for token {addr} from URI {token_uri}. Reason: {e}")
|
| 850 |
image = None
|
| 851 |
|
|
@@ -954,6 +1014,9 @@ class OracleDataset(Dataset):
|
|
| 954 |
"""
|
| 955 |
Loads raw data from cache, samples a random T_cutoff, and generates a training sample.
|
| 956 |
"""
|
|
|
|
|
|
|
|
|
|
| 957 |
raw_data = None
|
| 958 |
if self.cache_dir:
|
| 959 |
if idx >= len(self.cached_files):
|
|
@@ -1024,8 +1087,8 @@ class OracleDataset(Dataset):
|
|
| 1024 |
# ============================================================================
|
| 1025 |
# 1. Use ALL trades (sorted by timestamp) for context
|
| 1026 |
# 2. Find indices of SUCCESSFUL trades (needed for label computation)
|
| 1027 |
-
# 3. Sample interval: [
|
| 1028 |
-
# 4. This guarantees:
|
| 1029 |
# ============================================================================
|
| 1030 |
|
| 1031 |
all_trades_raw = raw_data.get('trades', [])
|
|
@@ -1038,7 +1101,8 @@ class OracleDataset(Dataset):
|
|
| 1038 |
key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
|
| 1039 |
)
|
| 1040 |
|
| 1041 |
-
|
|
|
|
| 1042 |
return None
|
| 1043 |
|
| 1044 |
# Find indices of SUCCESSFUL trades (valid for label computation)
|
|
@@ -1052,7 +1116,7 @@ class OracleDataset(Dataset):
|
|
| 1052 |
|
| 1053 |
max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
|
| 1054 |
# Define sampling interval
|
| 1055 |
-
min_idx =
|
| 1056 |
max_idx = len(all_trades_sorted) - 2 # Need at least 1 trade after cutoff
|
| 1057 |
|
| 1058 |
if max_idx < min_idx:
|
|
@@ -1148,38 +1212,57 @@ class OracleDataset(Dataset):
|
|
| 1148 |
_add_wallet(holder.get('wallet_address'), wallets_to_fetch)
|
| 1149 |
|
| 1150 |
pooler = EmbeddingPooler()
|
| 1151 |
-
#
|
| 1152 |
-
|
| 1153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1154 |
if not main_token_data:
|
| 1155 |
return None
|
| 1156 |
|
| 1157 |
-
#
|
| 1158 |
-
|
| 1159 |
-
|
| 1160 |
-
|
| 1161 |
-
|
| 1162 |
-
|
| 1163 |
-
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
-
|
| 1174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1175 |
|
|
|
|
| 1176 |
graph_entities = {}
|
| 1177 |
graph_links = {}
|
| 1178 |
-
|
| 1179 |
-
|
| 1180 |
-
|
| 1181 |
-
|
| 1182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1183 |
|
| 1184 |
# Generate the item
|
| 1185 |
return self._generate_dataset_item(
|
|
@@ -1244,7 +1327,7 @@ class OracleDataset(Dataset):
|
|
| 1244 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 1245 |
include_wallet_data=False,
|
| 1246 |
include_graph=False,
|
| 1247 |
-
min_trades=
|
| 1248 |
full_history=True, # Bypass H/B/H limits
|
| 1249 |
prune_failed=False, # Keep failed trades for realistic simulation
|
| 1250 |
prune_transfers=False # Keep transfers for snapshot reconstruction
|
|
@@ -1350,10 +1433,12 @@ class OracleDataset(Dataset):
|
|
| 1350 |
# Time is end of bucket
|
| 1351 |
snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval)
|
| 1352 |
|
| 1353 |
-
# These queries can be slow
|
| 1354 |
-
count = self.fetcher.
|
| 1355 |
-
|
| 1356 |
-
|
|
|
|
|
|
|
| 1357 |
|
| 1358 |
total_supply = raw_data.get('total_supply', 0) or 1
|
| 1359 |
if raw_data.get('decimals'):
|
|
@@ -1470,6 +1555,7 @@ class OracleDataset(Dataset):
|
|
| 1470 |
|
| 1471 |
# 3. Process Trades (Events + Chart)
|
| 1472 |
trade_events = []
|
|
|
|
| 1473 |
aggregation_trades = []
|
| 1474 |
high_def_chart_trades = []
|
| 1475 |
middle_chart_trades = []
|
|
@@ -1563,6 +1649,40 @@ class OracleDataset(Dataset):
|
|
| 1563 |
_register_event(trade_event, trade_sort_key)
|
| 1564 |
trade_events.append(trade_event)
|
| 1565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
# 4. Generate Chart Events
|
| 1567 |
def _finalize_chart(t_list):
|
| 1568 |
t_list.sort(key=lambda x: x['sort_key'])
|
|
@@ -1625,31 +1745,194 @@ class OracleDataset(Dataset):
|
|
| 1625 |
chart_events.extend(_emit_chart_segments(high_def_chart_trades, HIGH_DEF_INTERVAL, "chart-hd", precomputed_ohlc=ohlc_1s_precomputed))
|
| 1626 |
chart_events.extend(_emit_chart_segments(middle_chart_trades, MIDDLE_INTERVAL, "chart-mid"))
|
| 1627 |
|
| 1628 |
-
# 5. Process Other Records (Pool, Liquidity,
|
| 1629 |
-
|
| 1630 |
-
# For simplicity, assuming these records are already processed or we add the logic here.
|
| 1631 |
-
# Given the space constraint, I'll add a simplified pass for pool creation.
|
| 1632 |
-
# Ideally we refactor this into helper methods too.
|
| 1633 |
-
|
| 1634 |
for pool_record in pool_creation_records:
|
| 1635 |
-
|
| 1636 |
-
|
| 1637 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1638 |
pool_event = {
|
| 1639 |
'event_type': 'PoolCreated',
|
| 1640 |
'timestamp': pool_ts,
|
| 1641 |
-
'relative_ts':
|
| 1642 |
'wallet_address': pool_record.get('creator_address'),
|
| 1643 |
'token_address': token_address,
|
| 1644 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1645 |
}
|
| 1646 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1647 |
|
| 1648 |
# 6. Generate Snapshots
|
| 1649 |
self._generate_onchain_snapshots(
|
| 1650 |
token_address, int(t0_timestamp), T_cutoff,
|
| 1651 |
300, # Interval
|
| 1652 |
-
trade_events,
|
| 1653 |
aggregation_trades,
|
| 1654 |
wallet_data,
|
| 1655 |
total_supply_dec,
|
|
|
|
| 61 |
# Interval for HolderSnapshot events (seconds)
|
| 62 |
HOLDER_SNAPSHOT_INTERVAL_SEC = 300
|
| 63 |
HOLDER_SNAPSHOT_TOP_K = 200
|
| 64 |
+
DEAD_URI_RETRY_LIMIT = 2
|
| 65 |
|
| 66 |
|
| 67 |
class EmbeddingPooler:
|
|
|
|
| 115 |
"""
|
| 116 |
def __init__(self,
|
| 117 |
data_fetcher: Optional[DataFetcher] = None, # OPTIONAL: Only needed for caching (Writer)
|
| 118 |
+
fetcher_config: Optional[Dict[str, Any]] = None,
|
| 119 |
horizons_seconds: List[int] = [],
|
| 120 |
quantiles: List[float] = [],
|
| 121 |
max_samples: Optional[int] = None,
|
|
|
|
| 131 |
|
| 132 |
# --- NEW: Create a persistent requests session for efficiency ---
|
| 133 |
# Configure robust HTTP session
|
| 134 |
+
self.http_session = None
|
| 135 |
+
self._init_http_session()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
self.fetcher = data_fetcher
|
| 138 |
+
self.fetcher_config = fetcher_config
|
| 139 |
self.cache_dir = Path(cache_dir) if cache_dir else None
|
| 140 |
# Always define these so DataLoader workers don't crash with AttributeError if
|
| 141 |
# initialization falls through an unexpected branch.
|
|
|
|
| 266 |
print("INFO: No OHLC stats path provided. Using default normalization.")
|
| 267 |
|
| 268 |
self.min_trade_usd = min_trade_usd
|
| 269 |
+
self._uri_fail_counts: Dict[str, int] = {}
|
| 270 |
+
|
| 271 |
+
def _init_http_session(self) -> None:
|
| 272 |
+
# Configure robust HTTP session
|
| 273 |
+
self.http_session = requests.Session()
|
| 274 |
+
retry_strategy = Retry(
|
| 275 |
+
total=3,
|
| 276 |
+
backoff_factor=1,
|
| 277 |
+
status_forcelist=[429, 500, 502, 503, 504],
|
| 278 |
+
allowed_methods=["HEAD", "GET", "OPTIONS"]
|
| 279 |
+
)
|
| 280 |
+
adapter = HTTPAdapter(max_retries=retry_strategy)
|
| 281 |
+
self.http_session.mount("http://", adapter)
|
| 282 |
+
self.http_session.mount("https://", adapter)
|
| 283 |
+
|
| 284 |
+
def init_fetcher(self) -> None:
|
| 285 |
+
"""
|
| 286 |
+
Initialize DataFetcher from stored config (for DataLoader workers).
|
| 287 |
+
"""
|
| 288 |
+
if self.fetcher is not None or not self.fetcher_config:
|
| 289 |
+
return
|
| 290 |
+
from clickhouse_driver import Client as ClickHouseClient
|
| 291 |
+
from neo4j import GraphDatabase
|
| 292 |
+
cfg = self.fetcher_config
|
| 293 |
+
clickhouse_client = ClickHouseClient(
|
| 294 |
+
host=cfg.get("clickhouse_host", "localhost"),
|
| 295 |
+
port=int(cfg.get("clickhouse_port", 9000)),
|
| 296 |
+
)
|
| 297 |
+
neo4j_driver = GraphDatabase.driver(
|
| 298 |
+
cfg.get("neo4j_uri", "bolt://localhost:7687"),
|
| 299 |
+
auth=(cfg.get("neo4j_user", "neo4j"), cfg.get("neo4j_password", "password"))
|
| 300 |
+
)
|
| 301 |
+
self.fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 302 |
+
|
| 303 |
+
def __getstate__(self):
|
| 304 |
+
state = self.__dict__.copy()
|
| 305 |
+
# Drop non-pickleable objects
|
| 306 |
+
state["fetcher"] = None
|
| 307 |
+
state["http_session"] = None
|
| 308 |
+
return state
|
| 309 |
+
|
| 310 |
+
def __setstate__(self, state):
|
| 311 |
+
self.__dict__.update(state)
|
| 312 |
+
if self.http_session is None:
|
| 313 |
+
self._init_http_session()
|
| 314 |
|
| 315 |
def __len__(self) -> int:
|
| 316 |
return self.num_samples
|
|
|
|
| 327 |
denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
|
| 328 |
return [(float(v) - self.ohlc_price_mean) / denom for v in values]
|
| 329 |
|
| 330 |
+
def _is_dead_uri(self, uri: Optional[str]) -> bool:
|
| 331 |
+
if not uri:
|
| 332 |
+
return False
|
| 333 |
+
return self._uri_fail_counts.get(uri, 0) >= DEAD_URI_RETRY_LIMIT
|
| 334 |
+
|
| 335 |
+
def _mark_uri_failure(self, uri: Optional[str]) -> None:
|
| 336 |
+
if not uri:
|
| 337 |
+
return
|
| 338 |
+
self._uri_fail_counts[uri] = self._uri_fail_counts.get(uri, 0) + 1
|
| 339 |
+
|
| 340 |
def _apply_dynamic_sampling(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 341 |
"""
|
| 342 |
Applies dynamic context sampling to fit events within max_seq_len.
|
|
|
|
| 532 |
holders_end = len(cached_holders_list[i])
|
| 533 |
elif self.fetcher:
|
| 534 |
cutoff_dt_ts = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
|
| 535 |
+
holders_end, holder_records_ts = self.fetcher.fetch_holder_snapshot_stats_for_token(
|
| 536 |
+
token_address,
|
| 537 |
+
cutoff_dt_ts,
|
| 538 |
+
limit=HOLDER_SNAPSHOT_TOP_K
|
| 539 |
+
)
|
| 540 |
else:
|
| 541 |
holder_records_ts = []
|
| 542 |
holders_end = 0
|
|
|
|
| 837 |
# --- FIXED: Only add to pooler if data is valid ---
|
| 838 |
image = None
|
| 839 |
token_uri = data.get('token_uri')
|
| 840 |
+
if self._is_dead_uri(token_uri):
|
| 841 |
+
image = None
|
| 842 |
+
token_uri = None
|
| 843 |
|
| 844 |
# --- NEW: Use multiple IPFS gateways for reliability ---
|
| 845 |
if token_uri and isinstance(token_uri, str) and token_uri.strip():
|
|
|
|
| 897 |
else: # If all gateways fail for the image
|
| 898 |
raise RuntimeError(f"All IPFS gateways failed for image: {image_url}")
|
| 899 |
else: # Handle regular HTTP image URLs
|
| 900 |
+
if self._is_dead_uri(image_url):
|
| 901 |
+
raise requests.RequestException("Skipping dead image URI after repeated failures.")
|
| 902 |
image_resp = self.http_session.get(image_url, timeout=10)
|
| 903 |
image_resp.raise_for_status()
|
| 904 |
image = Image.open(BytesIO(image_resp.content))
|
| 905 |
except (requests.RequestException, ValueError, IOError) as e:
|
| 906 |
+
self._mark_uri_failure(token_uri)
|
| 907 |
+
if isinstance(metadata.get('image') if 'metadata' in locals() and isinstance(metadata, dict) else None, str):
|
| 908 |
+
self._mark_uri_failure(metadata.get('image'))
|
| 909 |
print(f"WARN: Could not fetch or process image for token {addr} from URI {token_uri}. Reason: {e}")
|
| 910 |
image = None
|
| 911 |
|
|
|
|
| 1014 |
"""
|
| 1015 |
Loads raw data from cache, samples a random T_cutoff, and generates a training sample.
|
| 1016 |
"""
|
| 1017 |
+
if self.fetcher is None and self.fetcher_config:
|
| 1018 |
+
# Lazy init in main or worker if not initialized.
|
| 1019 |
+
self.init_fetcher()
|
| 1020 |
raw_data = None
|
| 1021 |
if self.cache_dir:
|
| 1022 |
if idx >= len(self.cached_files):
|
|
|
|
| 1087 |
# ============================================================================
|
| 1088 |
# 1. Use ALL trades (sorted by timestamp) for context
|
| 1089 |
# 2. Find indices of SUCCESSFUL trades (needed for label computation)
|
| 1090 |
+
# 3. Sample interval: [min_context_trades-1, last_successful_idx - 1]
|
| 1091 |
+
# 4. This guarantees: N trades for context, 1+ successful trade for labels
|
| 1092 |
# ============================================================================
|
| 1093 |
|
| 1094 |
all_trades_raw = raw_data.get('trades', [])
|
|
|
|
| 1101 |
key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
|
| 1102 |
)
|
| 1103 |
|
| 1104 |
+
min_context_trades = 10
|
| 1105 |
+
if len(all_trades_sorted) < (min_context_trades + 1): # context + 1 trade after cutoff
|
| 1106 |
return None
|
| 1107 |
|
| 1108 |
# Find indices of SUCCESSFUL trades (valid for label computation)
|
|
|
|
| 1116 |
|
| 1117 |
max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
|
| 1118 |
# Define sampling interval
|
| 1119 |
+
min_idx = min_context_trades - 1 # At least N trades for context
|
| 1120 |
max_idx = len(all_trades_sorted) - 2 # Need at least 1 trade after cutoff
|
| 1121 |
|
| 1122 |
if max_idx < min_idx:
|
|
|
|
| 1212 |
_add_wallet(holder.get('wallet_address'), wallets_to_fetch)
|
| 1213 |
|
| 1214 |
pooler = EmbeddingPooler()
|
| 1215 |
+
# Token data: fetch time-aware data when fetcher is available.
|
| 1216 |
+
if self.fetcher:
|
| 1217 |
+
main_token_data = self._process_token_data([token_address], pooler, T_cutoff, token_data=None)
|
| 1218 |
+
# Fallback to cached raw data if DB returned nothing
|
| 1219 |
+
if not main_token_data:
|
| 1220 |
+
offline_token_data = {token_address: raw_data} # raw_data contains token metadata at root
|
| 1221 |
+
main_token_data = self._process_token_data([token_address], pooler, T_cutoff, token_data=offline_token_data)
|
| 1222 |
+
else:
|
| 1223 |
+
offline_token_data = {token_address: raw_data} # raw_data contains token metadata at root
|
| 1224 |
+
main_token_data = self._process_token_data([token_address], pooler, T_cutoff, token_data=offline_token_data)
|
| 1225 |
if not main_token_data:
|
| 1226 |
return None
|
| 1227 |
|
| 1228 |
+
# Wallet data: fetch time-aware data when fetcher is available.
|
| 1229 |
+
if self.fetcher:
|
| 1230 |
+
wallet_data, all_token_data = self._process_wallet_data(
|
| 1231 |
+
list(wallets_to_fetch),
|
| 1232 |
+
main_token_data.copy(),
|
| 1233 |
+
pooler,
|
| 1234 |
+
T_cutoff,
|
| 1235 |
+
profiles_override=None,
|
| 1236 |
+
socials_override=None,
|
| 1237 |
+
holdings_override=None
|
| 1238 |
+
)
|
| 1239 |
+
else:
|
| 1240 |
+
cached_social_bundle = raw_data.get('socials', {})
|
| 1241 |
+
offline_profiles = cached_social_bundle.get('profiles', {})
|
| 1242 |
+
offline_socials = cached_social_bundle.get('socials', {})
|
| 1243 |
+
offline_holdings = {} # Holdings not cached usually due to size
|
| 1244 |
+
wallet_data, all_token_data = self._process_wallet_data(
|
| 1245 |
+
list(wallets_to_fetch),
|
| 1246 |
+
main_token_data.copy(),
|
| 1247 |
+
pooler,
|
| 1248 |
+
T_cutoff,
|
| 1249 |
+
profiles_override=offline_profiles,
|
| 1250 |
+
socials_override=offline_socials,
|
| 1251 |
+
holdings_override=offline_holdings
|
| 1252 |
+
)
|
| 1253 |
|
| 1254 |
+
# Graph links: fetch time-aware graph when fetcher is available.
|
| 1255 |
graph_entities = {}
|
| 1256 |
graph_links = {}
|
| 1257 |
+
if self.fetcher and wallets_to_fetch:
|
| 1258 |
+
try:
|
| 1259 |
+
graph_entities, graph_links = self.fetcher.fetch_graph_links(
|
| 1260 |
+
list(wallets_to_fetch),
|
| 1261 |
+
T_cutoff,
|
| 1262 |
+
max_degrees=1
|
| 1263 |
+
)
|
| 1264 |
+
except Exception as e:
|
| 1265 |
+
print(f"ERROR: Failed to fetch graph links for {token_address}: {e}")
|
| 1266 |
|
| 1267 |
# Generate the item
|
| 1268 |
return self._generate_dataset_item(
|
|
|
|
| 1327 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 1328 |
include_wallet_data=False,
|
| 1329 |
include_graph=False,
|
| 1330 |
+
min_trades=10, # Enforce min trades for context
|
| 1331 |
full_history=True, # Bypass H/B/H limits
|
| 1332 |
prune_failed=False, # Keep failed trades for realistic simulation
|
| 1333 |
prune_transfers=False # Keep transfers for snapshot reconstruction
|
|
|
|
| 1433 |
# Time is end of bucket
|
| 1434 |
snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval)
|
| 1435 |
|
| 1436 |
+
# These queries can be slow; use single-call combined stats.
|
| 1437 |
+
count, top_holders = self.fetcher.fetch_holder_snapshot_stats_for_token(
|
| 1438 |
+
token_address,
|
| 1439 |
+
snapshot_ts,
|
| 1440 |
+
limit=HOLDER_SNAPSHOT_TOP_K
|
| 1441 |
+
)
|
| 1442 |
|
| 1443 |
total_supply = raw_data.get('total_supply', 0) or 1
|
| 1444 |
if raw_data.get('decimals'):
|
|
|
|
| 1555 |
|
| 1556 |
# 3. Process Trades (Events + Chart)
|
| 1557 |
trade_events = []
|
| 1558 |
+
transfer_events = []
|
| 1559 |
aggregation_trades = []
|
| 1560 |
high_def_chart_trades = []
|
| 1561 |
middle_chart_trades = []
|
|
|
|
| 1649 |
_register_event(trade_event, trade_sort_key)
|
| 1650 |
trade_events.append(trade_event)
|
| 1651 |
|
| 1652 |
+
# 3b. Process Transfers
|
| 1653 |
+
for transfer in transfer_records:
|
| 1654 |
+
transfer_ts_val = _timestamp_to_order_value(transfer.get('timestamp'))
|
| 1655 |
+
transfer_ts_int = int(transfer_ts_val)
|
| 1656 |
+
amount_dec = float(transfer.get('amount_decimal', 0.0) or 0.0)
|
| 1657 |
+
source_balance = float(transfer.get('source_balance', 0.0) or 0.0)
|
| 1658 |
+
denom = source_balance + amount_dec if source_balance > 0 else 0.0
|
| 1659 |
+
transfer_pct_of_holding = (amount_dec / denom) if denom > 1e-9 else 0.0
|
| 1660 |
+
transfer_pct_of_supply = (amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0
|
| 1661 |
+
is_large_transfer = transfer_pct_of_supply >= LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD
|
| 1662 |
+
|
| 1663 |
+
transfer_event = {
|
| 1664 |
+
'event_type': 'LargeTransfer' if is_large_transfer else 'Transfer',
|
| 1665 |
+
'timestamp': transfer_ts_int,
|
| 1666 |
+
'relative_ts': transfer_ts_val - t0_timestamp,
|
| 1667 |
+
'wallet_address': transfer.get('source'),
|
| 1668 |
+
'destination_wallet_address': transfer.get('destination'),
|
| 1669 |
+
'token_address': token_address,
|
| 1670 |
+
'token_amount': amount_dec,
|
| 1671 |
+
'transfer_pct_of_total_supply': transfer_pct_of_supply,
|
| 1672 |
+
'transfer_pct_of_holding': transfer_pct_of_holding,
|
| 1673 |
+
'priority_fee': transfer.get('priority_fee', 0.0),
|
| 1674 |
+
'success': transfer.get('success', False)
|
| 1675 |
+
}
|
| 1676 |
+
_register_event(
|
| 1677 |
+
transfer_event,
|
| 1678 |
+
_event_execution_sort_key(
|
| 1679 |
+
transfer.get('timestamp'),
|
| 1680 |
+
slot=transfer.get('slot', 0),
|
| 1681 |
+
signature=transfer.get('signature', '')
|
| 1682 |
+
)
|
| 1683 |
+
)
|
| 1684 |
+
transfer_events.append(transfer_event)
|
| 1685 |
+
|
| 1686 |
# 4. Generate Chart Events
|
| 1687 |
def _finalize_chart(t_list):
|
| 1688 |
t_list.sort(key=lambda x: x['sort_key'])
|
|
|
|
| 1745 |
chart_events.extend(_emit_chart_segments(high_def_chart_trades, HIGH_DEF_INTERVAL, "chart-hd", precomputed_ohlc=ohlc_1s_precomputed))
|
| 1746 |
chart_events.extend(_emit_chart_segments(middle_chart_trades, MIDDLE_INTERVAL, "chart-mid"))
|
| 1747 |
|
| 1748 |
+
# 5. Process Other Records (Pool, Liquidity, Fees, Burns, Locks, Migrations)
|
| 1749 |
+
pool_meta_by_address = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1750 |
for pool_record in pool_creation_records:
|
| 1751 |
+
pool_addr = pool_record.get('pool_address')
|
| 1752 |
+
if pool_addr:
|
| 1753 |
+
pool_meta_by_address[pool_addr] = pool_record
|
| 1754 |
+
|
| 1755 |
+
pool_ts_val = _timestamp_to_order_value(pool_record.get('timestamp'))
|
| 1756 |
+
pool_ts = int(pool_ts_val)
|
| 1757 |
+
base_decimals = pool_record.get('base_decimals')
|
| 1758 |
+
quote_decimals = pool_record.get('quote_decimals')
|
| 1759 |
+
base_decimals = int(base_decimals) if base_decimals is not None else 0
|
| 1760 |
+
quote_decimals = int(quote_decimals) if quote_decimals is not None else 0
|
| 1761 |
+
|
| 1762 |
+
base_amount_raw = pool_record.get('initial_base_liquidity', 0) or 0
|
| 1763 |
+
quote_amount_raw = pool_record.get('initial_quote_liquidity', 0) or 0
|
| 1764 |
+
base_amount = float(base_amount_raw) / (10 ** base_decimals) if base_decimals > 0 else float(base_amount_raw)
|
| 1765 |
+
quote_amount = float(quote_amount_raw) / (10 ** quote_decimals) if quote_decimals > 0 else float(quote_amount_raw)
|
| 1766 |
+
|
| 1767 |
pool_event = {
|
| 1768 |
'event_type': 'PoolCreated',
|
| 1769 |
'timestamp': pool_ts,
|
| 1770 |
+
'relative_ts': pool_ts_val - t0_timestamp,
|
| 1771 |
'wallet_address': pool_record.get('creator_address'),
|
| 1772 |
'token_address': token_address,
|
| 1773 |
+
'quote_token_address': pool_record.get('quote_address'),
|
| 1774 |
+
'protocol_id': pool_record.get('protocol', 0),
|
| 1775 |
+
'pool_address': pool_addr,
|
| 1776 |
+
'base_amount': base_amount,
|
| 1777 |
+
'quote_amount': quote_amount,
|
| 1778 |
+
'priority_fee': pool_record.get('priority_fee', 0.0),
|
| 1779 |
+
'success': pool_record.get('success', False)
|
| 1780 |
+
}
|
| 1781 |
+
_register_event(
|
| 1782 |
+
pool_event,
|
| 1783 |
+
_event_execution_sort_key(
|
| 1784 |
+
pool_record.get('timestamp'),
|
| 1785 |
+
slot=pool_record.get('slot', 0),
|
| 1786 |
+
signature=pool_record.get('signature', '')
|
| 1787 |
+
)
|
| 1788 |
+
)
|
| 1789 |
+
|
| 1790 |
+
for liq_record in liquidity_change_records:
|
| 1791 |
+
liq_ts_val = _timestamp_to_order_value(liq_record.get('timestamp'))
|
| 1792 |
+
liq_ts = int(liq_ts_val)
|
| 1793 |
+
pool_addr = liq_record.get('pool_address')
|
| 1794 |
+
pool_meta = pool_meta_by_address.get(pool_addr, {})
|
| 1795 |
+
quote_decimals = pool_meta.get('quote_decimals')
|
| 1796 |
+
quote_decimals = int(quote_decimals) if quote_decimals is not None else 0
|
| 1797 |
+
|
| 1798 |
+
quote_amount_raw = liq_record.get('quote_amount', 0) or 0
|
| 1799 |
+
quote_amount = float(quote_amount_raw) / (10 ** quote_decimals) if quote_decimals > 0 else float(quote_amount_raw)
|
| 1800 |
+
|
| 1801 |
+
liq_event = {
|
| 1802 |
+
'event_type': 'LiquidityChange',
|
| 1803 |
+
'timestamp': liq_ts,
|
| 1804 |
+
'relative_ts': liq_ts_val - t0_timestamp,
|
| 1805 |
+
'wallet_address': liq_record.get('lp_provider'),
|
| 1806 |
+
'token_address': token_address,
|
| 1807 |
+
'quote_token_address': pool_meta.get('quote_address'),
|
| 1808 |
+
'protocol_id': liq_record.get('protocol', 0),
|
| 1809 |
+
'change_type_id': liq_record.get('change_type', 0),
|
| 1810 |
+
'quote_amount': quote_amount,
|
| 1811 |
+
'priority_fee': liq_record.get('priority_fee', 0.0),
|
| 1812 |
+
'success': liq_record.get('success', False)
|
| 1813 |
}
|
| 1814 |
+
_register_event(
|
| 1815 |
+
liq_event,
|
| 1816 |
+
_event_execution_sort_key(
|
| 1817 |
+
liq_record.get('timestamp'),
|
| 1818 |
+
slot=liq_record.get('slot', 0),
|
| 1819 |
+
signature=liq_record.get('signature', '')
|
| 1820 |
+
)
|
| 1821 |
+
)
|
| 1822 |
+
|
| 1823 |
+
for fee_record in fee_collection_records:
|
| 1824 |
+
fee_ts_val = _timestamp_to_order_value(fee_record.get('timestamp'))
|
| 1825 |
+
fee_ts = int(fee_ts_val)
|
| 1826 |
+
amount = 0.0
|
| 1827 |
+
if fee_record.get('token_0_mint_address') == token_address:
|
| 1828 |
+
amount = float(fee_record.get('token_0_amount', 0.0) or 0.0)
|
| 1829 |
+
elif fee_record.get('token_1_mint_address') == token_address:
|
| 1830 |
+
amount = float(fee_record.get('token_1_amount', 0.0) or 0.0)
|
| 1831 |
+
|
| 1832 |
+
fee_event = {
|
| 1833 |
+
'event_type': 'FeeCollected',
|
| 1834 |
+
'timestamp': fee_ts,
|
| 1835 |
+
'relative_ts': fee_ts_val - t0_timestamp,
|
| 1836 |
+
'wallet_address': fee_record.get('recipient_address'),
|
| 1837 |
+
'token_address': token_address,
|
| 1838 |
+
'sol_amount': amount,
|
| 1839 |
+
'protocol_id': fee_record.get('protocol', 0),
|
| 1840 |
+
'priority_fee': fee_record.get('priority_fee', 0.0),
|
| 1841 |
+
'success': fee_record.get('success', False)
|
| 1842 |
+
}
|
| 1843 |
+
_register_event(
|
| 1844 |
+
fee_event,
|
| 1845 |
+
_event_execution_sort_key(
|
| 1846 |
+
fee_record.get('timestamp'),
|
| 1847 |
+
slot=fee_record.get('slot', 0),
|
| 1848 |
+
signature=fee_record.get('signature', '')
|
| 1849 |
+
)
|
| 1850 |
+
)
|
| 1851 |
+
|
| 1852 |
+
for burn_record in burn_records:
|
| 1853 |
+
burn_ts_val = _timestamp_to_order_value(burn_record.get('timestamp'))
|
| 1854 |
+
burn_ts = int(burn_ts_val)
|
| 1855 |
+
amount_dec = float(burn_record.get('amount_decimal', 0.0) or 0.0)
|
| 1856 |
+
amount_pct = (amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0
|
| 1857 |
+
|
| 1858 |
+
burn_event = {
|
| 1859 |
+
'event_type': 'TokenBurn',
|
| 1860 |
+
'timestamp': burn_ts,
|
| 1861 |
+
'relative_ts': burn_ts_val - t0_timestamp,
|
| 1862 |
+
'wallet_address': burn_record.get('source'),
|
| 1863 |
+
'token_address': token_address,
|
| 1864 |
+
'amount_pct_of_total_supply': amount_pct,
|
| 1865 |
+
'amount_tokens_burned': amount_dec,
|
| 1866 |
+
'priority_fee': burn_record.get('priority_fee', 0.0),
|
| 1867 |
+
'success': burn_record.get('success', False)
|
| 1868 |
+
}
|
| 1869 |
+
_register_event(
|
| 1870 |
+
burn_event,
|
| 1871 |
+
_event_execution_sort_key(
|
| 1872 |
+
burn_record.get('timestamp'),
|
| 1873 |
+
slot=burn_record.get('slot', 0),
|
| 1874 |
+
signature=burn_record.get('signature', '')
|
| 1875 |
+
)
|
| 1876 |
+
)
|
| 1877 |
+
|
| 1878 |
+
for lock_record in supply_lock_records:
|
| 1879 |
+
lock_ts_val = _timestamp_to_order_value(lock_record.get('timestamp'))
|
| 1880 |
+
lock_ts = int(lock_ts_val)
|
| 1881 |
+
total_locked_amount = float(lock_record.get('total_locked_amount', 0.0) or 0.0)
|
| 1882 |
+
amount_pct = (total_locked_amount / total_supply_dec) if total_supply_dec > 0 else 0.0
|
| 1883 |
+
final_unlock_ts = lock_record.get('final_unlock_timestamp', 0) or 0
|
| 1884 |
+
lock_duration = float(final_unlock_ts) - float(lock_ts_val)
|
| 1885 |
+
if lock_duration < 0:
|
| 1886 |
+
lock_duration = 0.0
|
| 1887 |
+
|
| 1888 |
+
lock_event = {
|
| 1889 |
+
'event_type': 'SupplyLock',
|
| 1890 |
+
'timestamp': lock_ts,
|
| 1891 |
+
'relative_ts': lock_ts_val - t0_timestamp,
|
| 1892 |
+
'wallet_address': lock_record.get('sender'),
|
| 1893 |
+
'token_address': token_address,
|
| 1894 |
+
'amount_pct_of_total_supply': amount_pct,
|
| 1895 |
+
'lock_duration': lock_duration,
|
| 1896 |
+
'protocol_id': lock_record.get('protocol', 0),
|
| 1897 |
+
'priority_fee': lock_record.get('priority_fee', 0.0),
|
| 1898 |
+
'success': lock_record.get('success', False)
|
| 1899 |
+
}
|
| 1900 |
+
_register_event(
|
| 1901 |
+
lock_event,
|
| 1902 |
+
_event_execution_sort_key(
|
| 1903 |
+
lock_record.get('timestamp'),
|
| 1904 |
+
slot=lock_record.get('slot', 0),
|
| 1905 |
+
signature=lock_record.get('signature', '')
|
| 1906 |
+
)
|
| 1907 |
+
)
|
| 1908 |
+
|
| 1909 |
+
for migration_record in migration_records:
|
| 1910 |
+
mig_ts_val = _timestamp_to_order_value(migration_record.get('timestamp'))
|
| 1911 |
+
mig_ts = int(mig_ts_val)
|
| 1912 |
+
mig_event = {
|
| 1913 |
+
'event_type': 'Migrated',
|
| 1914 |
+
'timestamp': mig_ts,
|
| 1915 |
+
'relative_ts': mig_ts_val - t0_timestamp,
|
| 1916 |
+
'wallet_address': None,
|
| 1917 |
+
'token_address': token_address,
|
| 1918 |
+
'protocol_id': migration_record.get('protocol', 0),
|
| 1919 |
+
'priority_fee': migration_record.get('priority_fee', 0.0),
|
| 1920 |
+
'success': migration_record.get('success', False)
|
| 1921 |
+
}
|
| 1922 |
+
_register_event(
|
| 1923 |
+
mig_event,
|
| 1924 |
+
_event_execution_sort_key(
|
| 1925 |
+
migration_record.get('timestamp'),
|
| 1926 |
+
slot=migration_record.get('slot', 0),
|
| 1927 |
+
signature=migration_record.get('signature', '')
|
| 1928 |
+
)
|
| 1929 |
+
)
|
| 1930 |
|
| 1931 |
# 6. Generate Snapshots
|
| 1932 |
self._generate_onchain_snapshots(
|
| 1933 |
token_address, int(t0_timestamp), T_cutoff,
|
| 1934 |
300, # Interval
|
| 1935 |
+
trade_events, transfer_events,
|
| 1936 |
aggregation_trades,
|
| 1937 |
wallet_data,
|
| 1938 |
total_supply_dec,
|
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1660
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:038695f9d26e59395a2e69b52f9da029cf8796b06e4d503f0c18191288ad2a02
|
| 3 |
size 1660
|
database.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python scripts/download_epoch_artifacts.py
|
| 2 |
+
python scripts/ingest_epoch.py --epoch 844
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:82f53ccd1f7ea7868893c879fabdea0b5939a3780cc5086790edb79d878e8121
|
| 3 |
+
size 32490
|
models/helper_encoders.py
CHANGED
|
@@ -43,8 +43,8 @@ class ContextualTimeEncoder(nn.Module):
|
|
| 43 |
half_dim = d_model // 2
|
| 44 |
|
| 45 |
# Calculations for sinusoidal encoding are more stable in float32
|
| 46 |
-
div_term = torch.exp(torch.arange(0, half_dim, device=device
|
| 47 |
-
args = values.
|
| 48 |
|
| 49 |
return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
|
| 50 |
|
|
@@ -63,15 +63,16 @@ class ContextualTimeEncoder(nn.Module):
|
|
| 63 |
|
| 64 |
# 1. Store original shape (e.g., [B, L]) and flatten
|
| 65 |
original_shape = timestamps.shape
|
| 66 |
-
|
|
|
|
| 67 |
|
| 68 |
# 2. Sinusoidal encode (already vectorized)
|
| 69 |
ts_encoding = self._sinusoidal_encode(timestamps_flat, self.ts_dim)
|
| 70 |
|
| 71 |
# 3. List comprehension (this is the only non-vectorized part)
|
| 72 |
# This loop is now correct, as it iterates over the 1D flat tensor
|
| 73 |
-
hours = torch.tensor([datetime.datetime.fromtimestamp(
|
| 74 |
-
days = torch.tensor([datetime.datetime.fromtimestamp(
|
| 75 |
|
| 76 |
# 4. Cyclical encode (already vectorized)
|
| 77 |
hour_encoding = self._cyclical_encode(hours, self.hour_dim, max_val=24.0)
|
|
|
|
| 43 |
half_dim = d_model // 2
|
| 44 |
|
| 45 |
# Calculations for sinusoidal encoding are more stable in float32
|
| 46 |
+
div_term = torch.exp(torch.arange(0, half_dim, device=device, dtype=torch.float64) * -(math.log(10000.0) / half_dim))
|
| 47 |
+
args = values.double().unsqueeze(-1) * div_term
|
| 48 |
|
| 49 |
return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
|
| 50 |
|
|
|
|
| 63 |
|
| 64 |
# 1. Store original shape (e.g., [B, L]) and flatten
|
| 65 |
original_shape = timestamps.shape
|
| 66 |
+
# Preserve precision for large Unix timestamps.
|
| 67 |
+
timestamps_flat = timestamps.flatten().double() # Shape [N_total]
|
| 68 |
|
| 69 |
# 2. Sinusoidal encode (already vectorized)
|
| 70 |
ts_encoding = self._sinusoidal_encode(timestamps_flat, self.ts_dim)
|
| 71 |
|
| 72 |
# 3. List comprehension (this is the only non-vectorized part)
|
| 73 |
# This loop is now correct, as it iterates over the 1D flat tensor
|
| 74 |
+
hours = torch.tensor([datetime.datetime.fromtimestamp(float(ts), tz=datetime.timezone.utc).hour for ts in timestamps_flat], device=device, dtype=torch.float32)
|
| 75 |
+
days = torch.tensor([datetime.datetime.fromtimestamp(float(ts), tz=datetime.timezone.utc).weekday() for ts in timestamps_flat], device=device, dtype=torch.float32)
|
| 76 |
|
| 77 |
# 4. Cyclical encode (already vectorized)
|
| 78 |
hour_encoding = self._cyclical_encode(hours, self.hour_dim, max_val=24.0)
|
pre_cache.sh
CHANGED
|
@@ -3,7 +3,5 @@
|
|
| 3 |
|
| 4 |
echo "Starting dataset caching..."
|
| 5 |
python3 scripts/cache_dataset.py \
|
| 6 |
-
--ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz"
|
| 7 |
-
--max_samples 50
|
| 8 |
-
|
| 9 |
echo "Done!"
|
|
|
|
| 3 |
|
| 4 |
echo "Starting dataset caching..."
|
| 5 |
python3 scripts/cache_dataset.py \
|
| 6 |
+
--ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz"
|
|
|
|
|
|
|
| 7 |
echo "Done!"
|
scripts/inspect_collator.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, List
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
# Ensure repo root is on sys.path
|
| 11 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 13 |
+
|
| 14 |
+
from data.data_loader import OracleDataset
|
| 15 |
+
from data.data_fetcher import DataFetcher
|
| 16 |
+
from data.data_collator import MemecoinCollator
|
| 17 |
+
import models.vocabulary as vocab
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _decode_events(event_type_ids: torch.Tensor) -> List[str]:
|
| 21 |
+
names = []
|
| 22 |
+
for eid in event_type_ids.tolist():
|
| 23 |
+
if eid == 0:
|
| 24 |
+
names.append("__PAD__")
|
| 25 |
+
else:
|
| 26 |
+
names.append(vocab.ID_TO_EVENT.get(eid, f"UNK_{eid}"))
|
| 27 |
+
return names
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _tensor_to_list(t: torch.Tensor) -> List:
|
| 31 |
+
return t.detach().cpu().tolist()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main() -> None:
|
| 35 |
+
parser = argparse.ArgumentParser(description="Inspect MemecoinCollator outputs on cached samples.")
|
| 36 |
+
parser.add_argument("--cache_dir", type=str, default="data/cache")
|
| 37 |
+
parser.add_argument("--idx", type=int, nargs="+", default=[0], help="Sample indices to inspect")
|
| 38 |
+
parser.add_argument("--max_seq_len", type=int, default=16000)
|
| 39 |
+
parser.add_argument("--out", type=str, default="collator_dump.json")
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
|
| 42 |
+
cache_dir = Path(args.cache_dir)
|
| 43 |
+
# Optional: enable time-aware fetches if DB env is set.
|
| 44 |
+
import os
|
| 45 |
+
from dotenv import load_dotenv
|
| 46 |
+
from clickhouse_driver import Client as ClickHouseClient
|
| 47 |
+
from neo4j import GraphDatabase
|
| 48 |
+
|
| 49 |
+
load_dotenv()
|
| 50 |
+
clickhouse_host = os.getenv("CLICKHOUSE_HOST", "localhost")
|
| 51 |
+
clickhouse_port = int(os.getenv("CLICKHOUSE_NATIVE_PORT", os.getenv("CLICKHOUSE_PORT", 9000)))
|
| 52 |
+
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
|
| 53 |
+
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
|
| 54 |
+
neo4j_password = os.getenv("NEO4J_PASSWORD", "password")
|
| 55 |
+
clickhouse_client = ClickHouseClient(host=clickhouse_host, port=clickhouse_port)
|
| 56 |
+
neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
|
| 57 |
+
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 58 |
+
|
| 59 |
+
dataset = OracleDataset(
|
| 60 |
+
data_fetcher=data_fetcher,
|
| 61 |
+
cache_dir=str(cache_dir),
|
| 62 |
+
horizons_seconds=[30, 60, 120, 240, 420],
|
| 63 |
+
quantiles=[0.1, 0.5, 0.9],
|
| 64 |
+
max_samples=None,
|
| 65 |
+
max_seq_len=args.max_seq_len,
|
| 66 |
+
)
|
| 67 |
+
if hasattr(dataset, "init_fetcher"):
|
| 68 |
+
dataset.init_fetcher()
|
| 69 |
+
|
| 70 |
+
collator = MemecoinCollator(
|
| 71 |
+
event_type_to_id=vocab.EVENT_TO_ID,
|
| 72 |
+
device=torch.device("cpu"),
|
| 73 |
+
dtype=torch.float32,
|
| 74 |
+
max_seq_len=args.max_seq_len,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
batch_items = [dataset[i] for i in args.idx]
|
| 78 |
+
batch = collator(batch_items)
|
| 79 |
+
|
| 80 |
+
# Build JSON-friendly dump (no truncation of events; embeddings are omitted)
|
| 81 |
+
dump: Dict[str, Any] = {
|
| 82 |
+
"batch_size": len(args.idx),
|
| 83 |
+
"token_addresses": batch.get("token_addresses"),
|
| 84 |
+
"t_cutoffs": batch.get("t_cutoffs"),
|
| 85 |
+
"sample_indices": batch.get("sample_indices"),
|
| 86 |
+
"raw_events": [item.get("event_sequence", []) for item in batch_items],
|
| 87 |
+
}
|
| 88 |
+
# Raw event type counts
|
| 89 |
+
event_counts = []
|
| 90 |
+
for item in batch_items:
|
| 91 |
+
counts: Dict[str, int] = {}
|
| 92 |
+
for ev in item.get("event_sequence", []):
|
| 93 |
+
et = ev.get("event_type", "UNKNOWN")
|
| 94 |
+
counts[et] = counts.get(et, 0) + 1
|
| 95 |
+
event_counts.append(counts)
|
| 96 |
+
dump["raw_event_counts"] = event_counts
|
| 97 |
+
|
| 98 |
+
# Core sequence + features (full length)
|
| 99 |
+
dump["event_type_ids"] = _tensor_to_list(batch["event_type_ids"])
|
| 100 |
+
dump["event_type_names"] = [
|
| 101 |
+
_decode_events(batch["event_type_ids"][i].cpu())
|
| 102 |
+
for i in range(batch["event_type_ids"].shape[0])
|
| 103 |
+
]
|
| 104 |
+
dump["timestamps_float"] = _tensor_to_list(batch["timestamps_float"])
|
| 105 |
+
dump["relative_ts"] = _tensor_to_list(batch["relative_ts"])
|
| 106 |
+
dump["attention_mask"] = _tensor_to_list(batch["attention_mask"])
|
| 107 |
+
dump["wallet_addr_to_batch_idx"] = batch.get("wallet_addr_to_batch_idx", {})
|
| 108 |
+
|
| 109 |
+
# Pointer tensors
|
| 110 |
+
for key in [
|
| 111 |
+
"wallet_indices",
|
| 112 |
+
"token_indices",
|
| 113 |
+
"quote_token_indices",
|
| 114 |
+
"trending_token_indices",
|
| 115 |
+
"boosted_token_indices",
|
| 116 |
+
"dest_wallet_indices",
|
| 117 |
+
"original_author_indices",
|
| 118 |
+
"ohlc_indices",
|
| 119 |
+
"holder_snapshot_indices",
|
| 120 |
+
"textual_event_indices",
|
| 121 |
+
]:
|
| 122 |
+
if key in batch:
|
| 123 |
+
dump[key] = _tensor_to_list(batch[key])
|
| 124 |
+
|
| 125 |
+
# Numerical feature tensors
|
| 126 |
+
nonzero_summary = {}
|
| 127 |
+
for key in [
|
| 128 |
+
"transfer_numerical_features",
|
| 129 |
+
"trade_numerical_features",
|
| 130 |
+
"deployer_trade_numerical_features",
|
| 131 |
+
"smart_wallet_trade_numerical_features",
|
| 132 |
+
"pool_created_numerical_features",
|
| 133 |
+
"liquidity_change_numerical_features",
|
| 134 |
+
"fee_collected_numerical_features",
|
| 135 |
+
"token_burn_numerical_features",
|
| 136 |
+
"supply_lock_numerical_features",
|
| 137 |
+
"onchain_snapshot_numerical_features",
|
| 138 |
+
"trending_token_numerical_features",
|
| 139 |
+
"boosted_token_numerical_features",
|
| 140 |
+
"dexboost_paid_numerical_features",
|
| 141 |
+
"dexprofile_updated_flags",
|
| 142 |
+
"global_trending_numerical_features",
|
| 143 |
+
"chainsnapshot_numerical_features",
|
| 144 |
+
"lighthousesnapshot_numerical_features",
|
| 145 |
+
]:
|
| 146 |
+
if key in batch:
|
| 147 |
+
t = batch[key]
|
| 148 |
+
dump[key] = _tensor_to_list(t)
|
| 149 |
+
nonzero_summary[key] = int(torch.count_nonzero(t).item())
|
| 150 |
+
|
| 151 |
+
# Categorical feature tensors
|
| 152 |
+
for key in [
|
| 153 |
+
"trade_dex_ids",
|
| 154 |
+
"trade_direction_ids",
|
| 155 |
+
"trade_mev_protection_ids",
|
| 156 |
+
"trade_is_bundle_ids",
|
| 157 |
+
"pool_created_protocol_ids",
|
| 158 |
+
"liquidity_change_type_ids",
|
| 159 |
+
"trending_token_source_ids",
|
| 160 |
+
"trending_token_timeframe_ids",
|
| 161 |
+
"lighthousesnapshot_protocol_ids",
|
| 162 |
+
"lighthousesnapshot_timeframe_ids",
|
| 163 |
+
"migrated_protocol_ids",
|
| 164 |
+
"alpha_group_ids",
|
| 165 |
+
"channel_ids",
|
| 166 |
+
"exchange_ids",
|
| 167 |
+
]:
|
| 168 |
+
if key in batch:
|
| 169 |
+
t = batch[key]
|
| 170 |
+
dump[key] = _tensor_to_list(t)
|
| 171 |
+
nonzero_summary[key] = int(torch.count_nonzero(t).item())
|
| 172 |
+
|
| 173 |
+
# Labels
|
| 174 |
+
if batch.get("labels") is not None:
|
| 175 |
+
dump["labels"] = _tensor_to_list(batch["labels"])
|
| 176 |
+
if batch.get("labels_mask") is not None:
|
| 177 |
+
dump["labels_mask"] = _tensor_to_list(batch["labels_mask"])
|
| 178 |
+
if batch.get("quality_score") is not None:
|
| 179 |
+
dump["quality_score"] = _tensor_to_list(batch["quality_score"])
|
| 180 |
+
|
| 181 |
+
dump["nonzero_summary"] = nonzero_summary
|
| 182 |
+
|
| 183 |
+
# Raw wallet/token feature payloads used by encoders
|
| 184 |
+
wallet_inputs = batch.get("wallet_encoder_inputs", {})
|
| 185 |
+
token_inputs = batch.get("token_encoder_inputs", {})
|
| 186 |
+
dump["wallet_encoder_inputs"] = {
|
| 187 |
+
"profile_rows": wallet_inputs.get("profile_rows", []),
|
| 188 |
+
"social_rows": wallet_inputs.get("social_rows", []),
|
| 189 |
+
"holdings_batch": wallet_inputs.get("holdings_batch", []),
|
| 190 |
+
"username_embed_indices": _tensor_to_list(wallet_inputs.get("username_embed_indices")) if "username_embed_indices" in wallet_inputs else [],
|
| 191 |
+
}
|
| 192 |
+
dump["token_encoder_inputs"] = {
|
| 193 |
+
"addresses_for_lookup": token_inputs.get("_addresses_for_lookup", []),
|
| 194 |
+
"protocol_ids": _tensor_to_list(token_inputs.get("protocol_ids")) if "protocol_ids" in token_inputs else [],
|
| 195 |
+
"is_vanity_flags": _tensor_to_list(token_inputs.get("is_vanity_flags")) if "is_vanity_flags" in token_inputs else [],
|
| 196 |
+
"name_embed_indices": _tensor_to_list(token_inputs.get("name_embed_indices")) if "name_embed_indices" in token_inputs else [],
|
| 197 |
+
"symbol_embed_indices": _tensor_to_list(token_inputs.get("symbol_embed_indices")) if "symbol_embed_indices" in token_inputs else [],
|
| 198 |
+
"image_embed_indices": _tensor_to_list(token_inputs.get("image_embed_indices")) if "image_embed_indices" in token_inputs else [],
|
| 199 |
+
}
|
| 200 |
+
dump["wallet_set_encoder_inputs"] = {
|
| 201 |
+
"holdings_batch": wallet_inputs.get("holdings_batch", []),
|
| 202 |
+
"token_vibe_lookup_keys": token_inputs.get("_addresses_for_lookup", []),
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
out_path = Path(args.out)
|
| 206 |
+
def _json_default(o):
|
| 207 |
+
if isinstance(o, (str, int, float, bool)) or o is None:
|
| 208 |
+
return o
|
| 209 |
+
try:
|
| 210 |
+
import datetime as _dt
|
| 211 |
+
if isinstance(o, (_dt.datetime, _dt.date)):
|
| 212 |
+
return o.isoformat()
|
| 213 |
+
except Exception:
|
| 214 |
+
pass
|
| 215 |
+
try:
|
| 216 |
+
return str(o)
|
| 217 |
+
except Exception:
|
| 218 |
+
return "<unserializable>"
|
| 219 |
+
|
| 220 |
+
with out_path.open("w") as f:
|
| 221 |
+
json.dump(dump, f, indent=2, default=_json_default)
|
| 222 |
+
|
| 223 |
+
print(f"Wrote collator dump to {out_path.resolve()}")
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
if __name__ == "__main__":
|
| 227 |
+
main()
|
train.py
CHANGED
|
@@ -118,6 +118,15 @@ def quantile_pinball_loss(preds: torch.Tensor,
|
|
| 118 |
return sum(losses) / mask.sum().clamp_min(1.0)
|
| 119 |
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
def filtered_collate(collator: MemecoinCollator,
|
| 122 |
batch: List[Optional[Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
|
| 123 |
"""Filter out None items from the dataset before collating."""
|
|
@@ -304,13 +313,18 @@ def main() -> None:
|
|
| 304 |
max_seq_len=max_seq_len
|
| 305 |
)
|
| 306 |
|
| 307 |
-
# DB
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
dataset = OracleDataset(
|
| 313 |
-
data_fetcher=None,
|
|
|
|
| 314 |
horizons_seconds=horizons,
|
| 315 |
quantiles=quantiles,
|
| 316 |
max_samples=args.max_samples,
|
|
@@ -339,6 +353,10 @@ def main() -> None:
|
|
| 339 |
else:
|
| 340 |
logger.info("INFO: Weights found but shuffle=False. Ignoring weights (sequential mode).")
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
dl_kwargs = dict(
|
| 343 |
dataset=dataset,
|
| 344 |
batch_size=batch_size,
|
|
@@ -353,6 +371,9 @@ def main() -> None:
|
|
| 353 |
# re-initializes heavy per-worker state (e.g. SigLIP MultiModalEncoder).
|
| 354 |
dl_kwargs["persistent_workers"] = True
|
| 355 |
dl_kwargs["prefetch_factor"] = 2
|
|
|
|
|
|
|
|
|
|
| 356 |
dataloader = DataLoader(**dl_kwargs)
|
| 357 |
|
| 358 |
# --- 3. Model Init ---
|
|
@@ -599,7 +620,6 @@ def main() -> None:
|
|
| 599 |
logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
|
| 600 |
|
| 601 |
accelerator.end_training()
|
| 602 |
-
# neo4j_driver.close() # REMOVED
|
| 603 |
|
| 604 |
if __name__ == "__main__":
|
| 605 |
main()
|
|
|
|
| 118 |
return sum(losses) / mask.sum().clamp_min(1.0)
|
| 119 |
|
| 120 |
|
| 121 |
+
def init_worker_fetcher(worker_id: int) -> None:
|
| 122 |
+
"""Initialize per-worker DataFetcher for cached datasets."""
|
| 123 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 124 |
+
if worker_info is not None:
|
| 125 |
+
ds = worker_info.dataset
|
| 126 |
+
if hasattr(ds, "init_fetcher"):
|
| 127 |
+
ds.init_fetcher()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
def filtered_collate(collator: MemecoinCollator,
|
| 131 |
batch: List[Optional[Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
|
| 132 |
"""Filter out None items from the dataset before collating."""
|
|
|
|
| 313 |
max_seq_len=max_seq_len
|
| 314 |
)
|
| 315 |
|
| 316 |
+
# DB config (for time-aware wallet/token/graph features during training)
|
| 317 |
+
fetcher_config = {
|
| 318 |
+
"clickhouse_host": os.getenv("CLICKHOUSE_HOST", "localhost"),
|
| 319 |
+
"clickhouse_port": int(os.getenv("CLICKHOUSE_PORT", 9000)),
|
| 320 |
+
"neo4j_uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
|
| 321 |
+
"neo4j_user": os.getenv("NEO4J_USER", "neo4j"),
|
| 322 |
+
"neo4j_password": os.getenv("NEO4J_PASSWORD", "password"),
|
| 323 |
+
}
|
| 324 |
|
| 325 |
dataset = OracleDataset(
|
| 326 |
+
data_fetcher=None,
|
| 327 |
+
fetcher_config=fetcher_config, # Training Mode (Cache + time-aware fetch per worker)
|
| 328 |
horizons_seconds=horizons,
|
| 329 |
quantiles=quantiles,
|
| 330 |
max_samples=args.max_samples,
|
|
|
|
| 353 |
else:
|
| 354 |
logger.info("INFO: Weights found but shuffle=False. Ignoring weights (sequential mode).")
|
| 355 |
|
| 356 |
+
# Initialize DataFetcher in main process when not using workers.
|
| 357 |
+
if int(args.num_workers) == 0:
|
| 358 |
+
dataset.init_fetcher()
|
| 359 |
+
|
| 360 |
dl_kwargs = dict(
|
| 361 |
dataset=dataset,
|
| 362 |
batch_size=batch_size,
|
|
|
|
| 371 |
# re-initializes heavy per-worker state (e.g. SigLIP MultiModalEncoder).
|
| 372 |
dl_kwargs["persistent_workers"] = True
|
| 373 |
dl_kwargs["prefetch_factor"] = 2
|
| 374 |
+
if int(args.num_workers) > 0:
|
| 375 |
+
dl_kwargs["worker_init_fn"] = init_worker_fetcher
|
| 376 |
+
|
| 377 |
dataloader = DataLoader(**dl_kwargs)
|
| 378 |
|
| 379 |
# --- 3. Model Init ---
|
|
|
|
| 620 |
logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
|
| 621 |
|
| 622 |
accelerator.end_training()
|
|
|
|
| 623 |
|
| 624 |
if __name__ == "__main__":
|
| 625 |
main()
|