zirobtc commited on
Commit
86523f8
·
1 Parent(s): e125fa3

Upload folder using huggingface_hub

Browse files
collator_dump.json ADDED
The diff for this file is too large to render. See raw diff
 
data/data_collator.py CHANGED
@@ -276,11 +276,25 @@ class MemecoinCollator:
276
  unique_wallets_data.update(item.get('wallets', {}))
277
  unique_tokens_data.update(item.get('tokens', {}))
278
 
279
- # Create mappings needed for indexing
280
- wallet_list_data = list(unique_wallets_data.values())
281
- token_list_data = list(unique_tokens_data.values())
282
- wallet_addr_to_batch_idx = {feat.get('profile', {}).get('wallet_address', f'__error_{i}'): i+1 for i, feat in enumerate(wallet_list_data)}
283
- token_addr_to_batch_idx = {feat.get('address', f'__error_{i}'): i+1 for i, feat in enumerate(token_list_data)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
  # Collate Static Raw Features (Tokens, Wallets, Graph)
286
  token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
@@ -297,7 +311,8 @@ class MemecoinCollator:
297
 
298
  # Initialize sequence tensors
299
  event_type_ids = torch.full((B, L), PAD_IDX_SEQ, dtype=torch.long, device=self.device)
300
- timestamps_float = torch.zeros((B, L), dtype=torch.float32, device=self.device)
 
301
  # Store relative_ts in float32 for stability; model will scale/log/normalize
302
  relative_ts = torch.zeros((B, L, 1), dtype=torch.float32, device=self.device)
303
  attention_mask = torch.zeros((B, L), dtype=torch.long, device=self.device)
@@ -601,6 +616,9 @@ class MemecoinCollator:
601
  ]
602
  boosted_token_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
603
 
 
 
 
604
  elif event_type == 'HolderSnapshot':
605
  # --- FIXED: Store raw holder data, not an index ---
606
  raw_holders = event.get('holders', [])
 
276
  unique_wallets_data.update(item.get('wallets', {}))
277
  unique_tokens_data.update(item.get('tokens', {}))
278
 
279
+ # Create mappings needed for indexing (use dict keys as source of truth)
280
+ wallet_items = list(unique_wallets_data.items())
281
+ token_items = list(unique_tokens_data.items())
282
+
283
+ wallet_list_data = []
284
+ for addr, feat in wallet_items:
285
+ profile = feat.get('profile', {})
286
+ if not profile.get('wallet_address'):
287
+ profile['wallet_address'] = addr
288
+ wallet_list_data.append(feat)
289
+
290
+ token_list_data = []
291
+ for addr, feat in token_items:
292
+ if not feat.get('address'):
293
+ feat['address'] = addr
294
+ token_list_data.append(feat)
295
+
296
+ wallet_addr_to_batch_idx = {addr: i + 1 for i, (addr, _) in enumerate(wallet_items)}
297
+ token_addr_to_batch_idx = {addr: i + 1 for i, (addr, _) in enumerate(token_items)}
298
 
299
  # Collate Static Raw Features (Tokens, Wallets, Graph)
300
  token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
 
311
 
312
  # Initialize sequence tensors
313
  event_type_ids = torch.full((B, L), PAD_IDX_SEQ, dtype=torch.long, device=self.device)
314
+ # Use float64 to preserve second-level precision for large Unix timestamps.
315
+ timestamps_float = torch.zeros((B, L), dtype=torch.float64, device=self.device)
316
  # Store relative_ts in float32 for stability; model will scale/log/normalize
317
  relative_ts = torch.zeros((B, L, 1), dtype=torch.float32, device=self.device)
318
  attention_mask = torch.zeros((B, L), dtype=torch.long, device=self.device)
 
616
  ]
617
  boosted_token_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
618
 
619
+ elif event_type == 'Migrated':
620
+ migrated_protocol_ids[i, j] = event.get('protocol_id', 0)
621
+
622
  elif event_type == 'HolderSnapshot':
623
  # --- FIXED: Store raw holder data, not an index ---
624
  raw_holders = event.get('holders', [])
data/data_fetcher.py CHANGED
@@ -1014,6 +1014,45 @@ class DataFetcher:
1014
  except Exception as e:
1015
  print(f"ERROR: Failed to count total holders for token {token_address}: {e}")
1016
  return 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1017
  def fetch_raw_token_data(
1018
  self,
1019
  token_address: str,
 
1014
  except Exception as e:
1015
  print(f"ERROR: Failed to count total holders for token {token_address}: {e}")
1016
  return 0
1017
+
1018
+ def fetch_holder_snapshot_stats_for_token(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> Tuple[int, List[Dict[str, Any]]]:
1019
+ """
1020
+ Fetch total holder count and top holders in a single query.
1021
+ Returns (count, top_holders_list).
1022
+ """
1023
+ if not token_address:
1024
+ return 0, []
1025
+ query = """
1026
+ WITH point_in_time_holdings AS (
1027
+ SELECT
1028
+ wallet_address,
1029
+ argMax(current_balance, updated_at) AS bal
1030
+ FROM wallet_holdings
1031
+ WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
1032
+ GROUP BY wallet_address
1033
+ )
1034
+ SELECT
1035
+ (SELECT count() FROM point_in_time_holdings WHERE bal > 0) AS holder_count,
1036
+ (SELECT groupArray((wallet_address, bal))
1037
+ FROM (
1038
+ SELECT wallet_address, bal
1039
+ FROM point_in_time_holdings
1040
+ WHERE bal > 0
1041
+ ORDER BY bal DESC
1042
+ LIMIT %(limit)s
1043
+ )) AS top_holders
1044
+ """
1045
+ params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
1046
+ try:
1047
+ rows = self.db_client.execute(query, params)
1048
+ if not rows:
1049
+ return 0, []
1050
+ holder_count, top_holders = rows[0]
1051
+ top_list = [{'wallet_address': wa, 'current_balance': bal} for wa, bal in (top_holders or [])]
1052
+ return int(holder_count or 0), top_list
1053
+ except Exception as e:
1054
+ print(f"ERROR: Failed to fetch holder snapshot stats for token {token_address}: {e}")
1055
+ return 0, []
1056
  def fetch_raw_token_data(
1057
  self,
1058
  token_address: str,
data/data_loader.py CHANGED
@@ -61,6 +61,7 @@ MIN_AMOUNT_TRANSFER_SUPPLY = 0.0 # 1.0% of total supply
61
  # Interval for HolderSnapshot events (seconds)
62
  HOLDER_SNAPSHOT_INTERVAL_SEC = 300
63
  HOLDER_SNAPSHOT_TOP_K = 200
 
64
 
65
 
66
  class EmbeddingPooler:
@@ -114,6 +115,7 @@ class OracleDataset(Dataset):
114
  """
115
  def __init__(self,
116
  data_fetcher: Optional[DataFetcher] = None, # OPTIONAL: Only needed for caching (Writer)
 
117
  horizons_seconds: List[int] = [],
118
  quantiles: List[float] = [],
119
  max_samples: Optional[int] = None,
@@ -129,18 +131,11 @@ class OracleDataset(Dataset):
129
 
130
  # --- NEW: Create a persistent requests session for efficiency ---
131
  # Configure robust HTTP session
132
- self.http_session = requests.Session()
133
- retry_strategy = Retry(
134
- total=3,
135
- backoff_factor=1,
136
- status_forcelist=[429, 500, 502, 503, 504],
137
- allowed_methods=["HEAD", "GET", "OPTIONS"]
138
- )
139
- adapter = HTTPAdapter(max_retries=retry_strategy)
140
- self.http_session.mount("http://", adapter)
141
- self.http_session.mount("https://", adapter)
142
 
143
  self.fetcher = data_fetcher
 
144
  self.cache_dir = Path(cache_dir) if cache_dir else None
145
  # Always define these so DataLoader workers don't crash with AttributeError if
146
  # initialization falls through an unexpected branch.
@@ -271,6 +266,51 @@ class OracleDataset(Dataset):
271
  print("INFO: No OHLC stats path provided. Using default normalization.")
272
 
273
  self.min_trade_usd = min_trade_usd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
  def __len__(self) -> int:
276
  return self.num_samples
@@ -287,6 +327,16 @@ class OracleDataset(Dataset):
287
  denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
288
  return [(float(v) - self.ohlc_price_mean) / denom for v in values]
289
 
 
 
 
 
 
 
 
 
 
 
290
  def _apply_dynamic_sampling(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
291
  """
292
  Applies dynamic context sampling to fit events within max_seq_len.
@@ -482,8 +532,11 @@ class OracleDataset(Dataset):
482
  holders_end = len(cached_holders_list[i])
483
  elif self.fetcher:
484
  cutoff_dt_ts = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
485
- holder_records_ts = self.fetcher.fetch_token_holders_for_snapshot(token_address, cutoff_dt_ts, limit=HOLDER_SNAPSHOT_TOP_K)
486
- holders_end = self.fetcher.fetch_total_holders_count_for_token(token_address, cutoff_dt_ts)
 
 
 
487
  else:
488
  holder_records_ts = []
489
  holders_end = 0
@@ -784,6 +837,9 @@ class OracleDataset(Dataset):
784
  # --- FIXED: Only add to pooler if data is valid ---
785
  image = None
786
  token_uri = data.get('token_uri')
 
 
 
787
 
788
  # --- NEW: Use multiple IPFS gateways for reliability ---
789
  if token_uri and isinstance(token_uri, str) and token_uri.strip():
@@ -841,11 +897,15 @@ class OracleDataset(Dataset):
841
  else: # If all gateways fail for the image
842
  raise RuntimeError(f"All IPFS gateways failed for image: {image_url}")
843
  else: # Handle regular HTTP image URLs
 
 
844
  image_resp = self.http_session.get(image_url, timeout=10)
845
  image_resp.raise_for_status()
846
  image = Image.open(BytesIO(image_resp.content))
847
  except (requests.RequestException, ValueError, IOError) as e:
848
-
 
 
849
  print(f"WARN: Could not fetch or process image for token {addr} from URI {token_uri}. Reason: {e}")
850
  image = None
851
 
@@ -954,6 +1014,9 @@ class OracleDataset(Dataset):
954
  """
955
  Loads raw data from cache, samples a random T_cutoff, and generates a training sample.
956
  """
 
 
 
957
  raw_data = None
958
  if self.cache_dir:
959
  if idx >= len(self.cached_files):
@@ -1024,8 +1087,8 @@ class OracleDataset(Dataset):
1024
  # ============================================================================
1025
  # 1. Use ALL trades (sorted by timestamp) for context
1026
  # 2. Find indices of SUCCESSFUL trades (needed for label computation)
1027
- # 3. Sample interval: [24, last_successful_idx - 1]
1028
- # 4. This guarantees: 24+ trades for context, 1+ successful trade for labels
1029
  # ============================================================================
1030
 
1031
  all_trades_raw = raw_data.get('trades', [])
@@ -1038,7 +1101,8 @@ class OracleDataset(Dataset):
1038
  key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
1039
  )
1040
 
1041
- if len(all_trades_sorted) < 26: # Need at least 24 for context + 2 for cutoff+label
 
1042
  return None
1043
 
1044
  # Find indices of SUCCESSFUL trades (valid for label computation)
@@ -1052,7 +1116,7 @@ class OracleDataset(Dataset):
1052
 
1053
  max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
1054
  # Define sampling interval
1055
- min_idx = 24 # At least 24 trades for context
1056
  max_idx = len(all_trades_sorted) - 2 # Need at least 1 trade after cutoff
1057
 
1058
  if max_idx < min_idx:
@@ -1148,38 +1212,57 @@ class OracleDataset(Dataset):
1148
  _add_wallet(holder.get('wallet_address'), wallets_to_fetch)
1149
 
1150
  pooler = EmbeddingPooler()
1151
- # Prepare offline token data
1152
- offline_token_data = {token_address: raw_data} # Assuming raw_data contains token metadata at root
1153
- main_token_data = self._process_token_data([token_address], pooler, T_cutoff, token_data=offline_token_data)
 
 
 
 
 
 
 
1154
  if not main_token_data:
1155
  return None
1156
 
1157
- # Prepare offline wallet data
1158
- # raw_data['socials'] structure: {'profiles': {...}, 'socials': {...}} usually.
1159
- # But wait, cached raw_data['socials'] might be just the dict we need?
1160
- # Let's handle graceful empty if not found.
1161
- cached_social_bundle = raw_data.get('socials', {})
1162
- offline_profiles = cached_social_bundle.get('profiles', {})
1163
- offline_socials = cached_social_bundle.get('socials', {})
1164
- offline_holdings = {} # Holdings not cached usually due to size
1165
-
1166
- wallet_data, all_token_data = self._process_wallet_data(
1167
- list(wallets_to_fetch),
1168
- main_token_data.copy(),
1169
- pooler,
1170
- T_cutoff,
1171
- profiles_override=offline_profiles,
1172
- socials_override=offline_socials,
1173
- holdings_override=offline_holdings
1174
- )
 
 
 
 
 
 
 
1175
 
 
1176
  graph_entities = {}
1177
  graph_links = {}
1178
- graph_entities = {}
1179
- graph_links = {}
1180
- # if wallets_to_fetch:
1181
- # graph_entities, graph_links = self.fetcher.fetch_graph_links(...)
1182
- # Offline Graph: check if raw_data has graph? Assuming no for now.
 
 
 
 
1183
 
1184
  # Generate the item
1185
  return self._generate_dataset_item(
@@ -1244,7 +1327,7 @@ class OracleDataset(Dataset):
1244
  max_horizon_seconds=self.max_cache_horizon_seconds,
1245
  include_wallet_data=False,
1246
  include_graph=False,
1247
- min_trades=25, # Enforce min trades for context
1248
  full_history=True, # Bypass H/B/H limits
1249
  prune_failed=False, # Keep failed trades for realistic simulation
1250
  prune_transfers=False # Keep transfers for snapshot reconstruction
@@ -1350,10 +1433,12 @@ class OracleDataset(Dataset):
1350
  # Time is end of bucket
1351
  snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval)
1352
 
1353
- # These queries can be slow.
1354
- count = self.fetcher.fetch_total_holders_count_for_token(token_address, snapshot_ts)
1355
- # Fetch Top 200 as per constant
1356
- top_holders = self.fetcher.fetch_token_holders_for_snapshot(token_address, snapshot_ts, limit=HOLDER_SNAPSHOT_TOP_K)
 
 
1357
 
1358
  total_supply = raw_data.get('total_supply', 0) or 1
1359
  if raw_data.get('decimals'):
@@ -1470,6 +1555,7 @@ class OracleDataset(Dataset):
1470
 
1471
  # 3. Process Trades (Events + Chart)
1472
  trade_events = []
 
1473
  aggregation_trades = []
1474
  high_def_chart_trades = []
1475
  middle_chart_trades = []
@@ -1563,6 +1649,40 @@ class OracleDataset(Dataset):
1563
  _register_event(trade_event, trade_sort_key)
1564
  trade_events.append(trade_event)
1565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1566
  # 4. Generate Chart Events
1567
  def _finalize_chart(t_list):
1568
  t_list.sort(key=lambda x: x['sort_key'])
@@ -1625,31 +1745,194 @@ class OracleDataset(Dataset):
1625
  chart_events.extend(_emit_chart_segments(high_def_chart_trades, HIGH_DEF_INTERVAL, "chart-hd", precomputed_ohlc=ohlc_1s_precomputed))
1626
  chart_events.extend(_emit_chart_segments(middle_chart_trades, MIDDLE_INTERVAL, "chart-mid"))
1627
 
1628
- # 5. Process Other Records (Pool, Liquidity, etc.) using filtering
1629
- # Note: We need to port the logic that converts raw records to events
1630
- # For simplicity, assuming these records are already processed or we add the logic here.
1631
- # Given the space constraint, I'll add a simplified pass for pool creation.
1632
- # Ideally we refactor this into helper methods too.
1633
-
1634
  for pool_record in pool_creation_records:
1635
- pool_ts = int(_timestamp_to_order_value(pool_record.get('timestamp')))
1636
- # ... process pool ...
1637
- # Simple placeholder for now:
 
 
 
 
 
 
 
 
 
 
 
 
 
1638
  pool_event = {
1639
  'event_type': 'PoolCreated',
1640
  'timestamp': pool_ts,
1641
- 'relative_ts': pool_ts - t0_timestamp,
1642
  'wallet_address': pool_record.get('creator_address'),
1643
  'token_address': token_address,
1644
- # ... other fields ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1645
  }
1646
- # _register_event(pool_event, val)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1647
 
1648
  # 6. Generate Snapshots
1649
  self._generate_onchain_snapshots(
1650
  token_address, int(t0_timestamp), T_cutoff,
1651
  300, # Interval
1652
- trade_events, [], # Transfer events
1653
  aggregation_trades,
1654
  wallet_data,
1655
  total_supply_dec,
 
61
  # Interval for HolderSnapshot events (seconds)
62
  HOLDER_SNAPSHOT_INTERVAL_SEC = 300
63
  HOLDER_SNAPSHOT_TOP_K = 200
64
+ DEAD_URI_RETRY_LIMIT = 2
65
 
66
 
67
  class EmbeddingPooler:
 
115
  """
116
  def __init__(self,
117
  data_fetcher: Optional[DataFetcher] = None, # OPTIONAL: Only needed for caching (Writer)
118
+ fetcher_config: Optional[Dict[str, Any]] = None,
119
  horizons_seconds: List[int] = [],
120
  quantiles: List[float] = [],
121
  max_samples: Optional[int] = None,
 
131
 
132
  # --- NEW: Create a persistent requests session for efficiency ---
133
  # Configure robust HTTP session
134
+ self.http_session = None
135
+ self._init_http_session()
 
 
 
 
 
 
 
 
136
 
137
  self.fetcher = data_fetcher
138
+ self.fetcher_config = fetcher_config
139
  self.cache_dir = Path(cache_dir) if cache_dir else None
140
  # Always define these so DataLoader workers don't crash with AttributeError if
141
  # initialization falls through an unexpected branch.
 
266
  print("INFO: No OHLC stats path provided. Using default normalization.")
267
 
268
  self.min_trade_usd = min_trade_usd
269
+ self._uri_fail_counts: Dict[str, int] = {}
270
+
271
+ def _init_http_session(self) -> None:
272
+ # Configure robust HTTP session
273
+ self.http_session = requests.Session()
274
+ retry_strategy = Retry(
275
+ total=3,
276
+ backoff_factor=1,
277
+ status_forcelist=[429, 500, 502, 503, 504],
278
+ allowed_methods=["HEAD", "GET", "OPTIONS"]
279
+ )
280
+ adapter = HTTPAdapter(max_retries=retry_strategy)
281
+ self.http_session.mount("http://", adapter)
282
+ self.http_session.mount("https://", adapter)
283
+
284
+ def init_fetcher(self) -> None:
285
+ """
286
+ Initialize DataFetcher from stored config (for DataLoader workers).
287
+ """
288
+ if self.fetcher is not None or not self.fetcher_config:
289
+ return
290
+ from clickhouse_driver import Client as ClickHouseClient
291
+ from neo4j import GraphDatabase
292
+ cfg = self.fetcher_config
293
+ clickhouse_client = ClickHouseClient(
294
+ host=cfg.get("clickhouse_host", "localhost"),
295
+ port=int(cfg.get("clickhouse_port", 9000)),
296
+ )
297
+ neo4j_driver = GraphDatabase.driver(
298
+ cfg.get("neo4j_uri", "bolt://localhost:7687"),
299
+ auth=(cfg.get("neo4j_user", "neo4j"), cfg.get("neo4j_password", "password"))
300
+ )
301
+ self.fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
302
+
303
+ def __getstate__(self):
304
+ state = self.__dict__.copy()
305
+ # Drop non-pickleable objects
306
+ state["fetcher"] = None
307
+ state["http_session"] = None
308
+ return state
309
+
310
+ def __setstate__(self, state):
311
+ self.__dict__.update(state)
312
+ if self.http_session is None:
313
+ self._init_http_session()
314
 
315
  def __len__(self) -> int:
316
  return self.num_samples
 
327
  denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
328
  return [(float(v) - self.ohlc_price_mean) / denom for v in values]
329
 
330
+ def _is_dead_uri(self, uri: Optional[str]) -> bool:
331
+ if not uri:
332
+ return False
333
+ return self._uri_fail_counts.get(uri, 0) >= DEAD_URI_RETRY_LIMIT
334
+
335
+ def _mark_uri_failure(self, uri: Optional[str]) -> None:
336
+ if not uri:
337
+ return
338
+ self._uri_fail_counts[uri] = self._uri_fail_counts.get(uri, 0) + 1
339
+
340
  def _apply_dynamic_sampling(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
341
  """
342
  Applies dynamic context sampling to fit events within max_seq_len.
 
532
  holders_end = len(cached_holders_list[i])
533
  elif self.fetcher:
534
  cutoff_dt_ts = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
535
+ holders_end, holder_records_ts = self.fetcher.fetch_holder_snapshot_stats_for_token(
536
+ token_address,
537
+ cutoff_dt_ts,
538
+ limit=HOLDER_SNAPSHOT_TOP_K
539
+ )
540
  else:
541
  holder_records_ts = []
542
  holders_end = 0
 
837
  # --- FIXED: Only add to pooler if data is valid ---
838
  image = None
839
  token_uri = data.get('token_uri')
840
+ if self._is_dead_uri(token_uri):
841
+ image = None
842
+ token_uri = None
843
 
844
  # --- NEW: Use multiple IPFS gateways for reliability ---
845
  if token_uri and isinstance(token_uri, str) and token_uri.strip():
 
897
  else: # If all gateways fail for the image
898
  raise RuntimeError(f"All IPFS gateways failed for image: {image_url}")
899
  else: # Handle regular HTTP image URLs
900
+ if self._is_dead_uri(image_url):
901
+ raise requests.RequestException("Skipping dead image URI after repeated failures.")
902
  image_resp = self.http_session.get(image_url, timeout=10)
903
  image_resp.raise_for_status()
904
  image = Image.open(BytesIO(image_resp.content))
905
  except (requests.RequestException, ValueError, IOError) as e:
906
+ self._mark_uri_failure(token_uri)
907
+ if isinstance(metadata.get('image') if 'metadata' in locals() and isinstance(metadata, dict) else None, str):
908
+ self._mark_uri_failure(metadata.get('image'))
909
  print(f"WARN: Could not fetch or process image for token {addr} from URI {token_uri}. Reason: {e}")
910
  image = None
911
 
 
1014
  """
1015
  Loads raw data from cache, samples a random T_cutoff, and generates a training sample.
1016
  """
1017
+ if self.fetcher is None and self.fetcher_config:
1018
+ # Lazy init in main or worker if not initialized.
1019
+ self.init_fetcher()
1020
  raw_data = None
1021
  if self.cache_dir:
1022
  if idx >= len(self.cached_files):
 
1087
  # ============================================================================
1088
  # 1. Use ALL trades (sorted by timestamp) for context
1089
  # 2. Find indices of SUCCESSFUL trades (needed for label computation)
1090
+ # 3. Sample interval: [min_context_trades-1, last_successful_idx - 1]
1091
+ # 4. This guarantees: N trades for context, 1+ successful trade for labels
1092
  # ============================================================================
1093
 
1094
  all_trades_raw = raw_data.get('trades', [])
 
1101
  key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
1102
  )
1103
 
1104
+ min_context_trades = 10
1105
+ if len(all_trades_sorted) < (min_context_trades + 1): # context + 1 trade after cutoff
1106
  return None
1107
 
1108
  # Find indices of SUCCESSFUL trades (valid for label computation)
 
1116
 
1117
  max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
1118
  # Define sampling interval
1119
+ min_idx = min_context_trades - 1 # At least N trades for context
1120
  max_idx = len(all_trades_sorted) - 2 # Need at least 1 trade after cutoff
1121
 
1122
  if max_idx < min_idx:
 
1212
  _add_wallet(holder.get('wallet_address'), wallets_to_fetch)
1213
 
1214
  pooler = EmbeddingPooler()
1215
+ # Token data: fetch time-aware data when fetcher is available.
1216
+ if self.fetcher:
1217
+ main_token_data = self._process_token_data([token_address], pooler, T_cutoff, token_data=None)
1218
+ # Fallback to cached raw data if DB returned nothing
1219
+ if not main_token_data:
1220
+ offline_token_data = {token_address: raw_data} # raw_data contains token metadata at root
1221
+ main_token_data = self._process_token_data([token_address], pooler, T_cutoff, token_data=offline_token_data)
1222
+ else:
1223
+ offline_token_data = {token_address: raw_data} # raw_data contains token metadata at root
1224
+ main_token_data = self._process_token_data([token_address], pooler, T_cutoff, token_data=offline_token_data)
1225
  if not main_token_data:
1226
  return None
1227
 
1228
+ # Wallet data: fetch time-aware data when fetcher is available.
1229
+ if self.fetcher:
1230
+ wallet_data, all_token_data = self._process_wallet_data(
1231
+ list(wallets_to_fetch),
1232
+ main_token_data.copy(),
1233
+ pooler,
1234
+ T_cutoff,
1235
+ profiles_override=None,
1236
+ socials_override=None,
1237
+ holdings_override=None
1238
+ )
1239
+ else:
1240
+ cached_social_bundle = raw_data.get('socials', {})
1241
+ offline_profiles = cached_social_bundle.get('profiles', {})
1242
+ offline_socials = cached_social_bundle.get('socials', {})
1243
+ offline_holdings = {} # Holdings not cached usually due to size
1244
+ wallet_data, all_token_data = self._process_wallet_data(
1245
+ list(wallets_to_fetch),
1246
+ main_token_data.copy(),
1247
+ pooler,
1248
+ T_cutoff,
1249
+ profiles_override=offline_profiles,
1250
+ socials_override=offline_socials,
1251
+ holdings_override=offline_holdings
1252
+ )
1253
 
1254
+ # Graph links: fetch time-aware graph when fetcher is available.
1255
  graph_entities = {}
1256
  graph_links = {}
1257
+ if self.fetcher and wallets_to_fetch:
1258
+ try:
1259
+ graph_entities, graph_links = self.fetcher.fetch_graph_links(
1260
+ list(wallets_to_fetch),
1261
+ T_cutoff,
1262
+ max_degrees=1
1263
+ )
1264
+ except Exception as e:
1265
+ print(f"ERROR: Failed to fetch graph links for {token_address}: {e}")
1266
 
1267
  # Generate the item
1268
  return self._generate_dataset_item(
 
1327
  max_horizon_seconds=self.max_cache_horizon_seconds,
1328
  include_wallet_data=False,
1329
  include_graph=False,
1330
+ min_trades=10, # Enforce min trades for context
1331
  full_history=True, # Bypass H/B/H limits
1332
  prune_failed=False, # Keep failed trades for realistic simulation
1333
  prune_transfers=False # Keep transfers for snapshot reconstruction
 
1433
  # Time is end of bucket
1434
  snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval)
1435
 
1436
+ # These queries can be slow; use single-call combined stats.
1437
+ count, top_holders = self.fetcher.fetch_holder_snapshot_stats_for_token(
1438
+ token_address,
1439
+ snapshot_ts,
1440
+ limit=HOLDER_SNAPSHOT_TOP_K
1441
+ )
1442
 
1443
  total_supply = raw_data.get('total_supply', 0) or 1
1444
  if raw_data.get('decimals'):
 
1555
 
1556
  # 3. Process Trades (Events + Chart)
1557
  trade_events = []
1558
+ transfer_events = []
1559
  aggregation_trades = []
1560
  high_def_chart_trades = []
1561
  middle_chart_trades = []
 
1649
  _register_event(trade_event, trade_sort_key)
1650
  trade_events.append(trade_event)
1651
 
1652
+ # 3b. Process Transfers
1653
+ for transfer in transfer_records:
1654
+ transfer_ts_val = _timestamp_to_order_value(transfer.get('timestamp'))
1655
+ transfer_ts_int = int(transfer_ts_val)
1656
+ amount_dec = float(transfer.get('amount_decimal', 0.0) or 0.0)
1657
+ source_balance = float(transfer.get('source_balance', 0.0) or 0.0)
1658
+ denom = source_balance + amount_dec if source_balance > 0 else 0.0
1659
+ transfer_pct_of_holding = (amount_dec / denom) if denom > 1e-9 else 0.0
1660
+ transfer_pct_of_supply = (amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0
1661
+ is_large_transfer = transfer_pct_of_supply >= LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD
1662
+
1663
+ transfer_event = {
1664
+ 'event_type': 'LargeTransfer' if is_large_transfer else 'Transfer',
1665
+ 'timestamp': transfer_ts_int,
1666
+ 'relative_ts': transfer_ts_val - t0_timestamp,
1667
+ 'wallet_address': transfer.get('source'),
1668
+ 'destination_wallet_address': transfer.get('destination'),
1669
+ 'token_address': token_address,
1670
+ 'token_amount': amount_dec,
1671
+ 'transfer_pct_of_total_supply': transfer_pct_of_supply,
1672
+ 'transfer_pct_of_holding': transfer_pct_of_holding,
1673
+ 'priority_fee': transfer.get('priority_fee', 0.0),
1674
+ 'success': transfer.get('success', False)
1675
+ }
1676
+ _register_event(
1677
+ transfer_event,
1678
+ _event_execution_sort_key(
1679
+ transfer.get('timestamp'),
1680
+ slot=transfer.get('slot', 0),
1681
+ signature=transfer.get('signature', '')
1682
+ )
1683
+ )
1684
+ transfer_events.append(transfer_event)
1685
+
1686
  # 4. Generate Chart Events
1687
  def _finalize_chart(t_list):
1688
  t_list.sort(key=lambda x: x['sort_key'])
 
1745
  chart_events.extend(_emit_chart_segments(high_def_chart_trades, HIGH_DEF_INTERVAL, "chart-hd", precomputed_ohlc=ohlc_1s_precomputed))
1746
  chart_events.extend(_emit_chart_segments(middle_chart_trades, MIDDLE_INTERVAL, "chart-mid"))
1747
 
1748
+ # 5. Process Other Records (Pool, Liquidity, Fees, Burns, Locks, Migrations)
1749
+ pool_meta_by_address = {}
 
 
 
 
1750
  for pool_record in pool_creation_records:
1751
+ pool_addr = pool_record.get('pool_address')
1752
+ if pool_addr:
1753
+ pool_meta_by_address[pool_addr] = pool_record
1754
+
1755
+ pool_ts_val = _timestamp_to_order_value(pool_record.get('timestamp'))
1756
+ pool_ts = int(pool_ts_val)
1757
+ base_decimals = pool_record.get('base_decimals')
1758
+ quote_decimals = pool_record.get('quote_decimals')
1759
+ base_decimals = int(base_decimals) if base_decimals is not None else 0
1760
+ quote_decimals = int(quote_decimals) if quote_decimals is not None else 0
1761
+
1762
+ base_amount_raw = pool_record.get('initial_base_liquidity', 0) or 0
1763
+ quote_amount_raw = pool_record.get('initial_quote_liquidity', 0) or 0
1764
+ base_amount = float(base_amount_raw) / (10 ** base_decimals) if base_decimals > 0 else float(base_amount_raw)
1765
+ quote_amount = float(quote_amount_raw) / (10 ** quote_decimals) if quote_decimals > 0 else float(quote_amount_raw)
1766
+
1767
  pool_event = {
1768
  'event_type': 'PoolCreated',
1769
  'timestamp': pool_ts,
1770
+ 'relative_ts': pool_ts_val - t0_timestamp,
1771
  'wallet_address': pool_record.get('creator_address'),
1772
  'token_address': token_address,
1773
+ 'quote_token_address': pool_record.get('quote_address'),
1774
+ 'protocol_id': pool_record.get('protocol', 0),
1775
+ 'pool_address': pool_addr,
1776
+ 'base_amount': base_amount,
1777
+ 'quote_amount': quote_amount,
1778
+ 'priority_fee': pool_record.get('priority_fee', 0.0),
1779
+ 'success': pool_record.get('success', False)
1780
+ }
1781
+ _register_event(
1782
+ pool_event,
1783
+ _event_execution_sort_key(
1784
+ pool_record.get('timestamp'),
1785
+ slot=pool_record.get('slot', 0),
1786
+ signature=pool_record.get('signature', '')
1787
+ )
1788
+ )
1789
+
1790
+ for liq_record in liquidity_change_records:
1791
+ liq_ts_val = _timestamp_to_order_value(liq_record.get('timestamp'))
1792
+ liq_ts = int(liq_ts_val)
1793
+ pool_addr = liq_record.get('pool_address')
1794
+ pool_meta = pool_meta_by_address.get(pool_addr, {})
1795
+ quote_decimals = pool_meta.get('quote_decimals')
1796
+ quote_decimals = int(quote_decimals) if quote_decimals is not None else 0
1797
+
1798
+ quote_amount_raw = liq_record.get('quote_amount', 0) or 0
1799
+ quote_amount = float(quote_amount_raw) / (10 ** quote_decimals) if quote_decimals > 0 else float(quote_amount_raw)
1800
+
1801
+ liq_event = {
1802
+ 'event_type': 'LiquidityChange',
1803
+ 'timestamp': liq_ts,
1804
+ 'relative_ts': liq_ts_val - t0_timestamp,
1805
+ 'wallet_address': liq_record.get('lp_provider'),
1806
+ 'token_address': token_address,
1807
+ 'quote_token_address': pool_meta.get('quote_address'),
1808
+ 'protocol_id': liq_record.get('protocol', 0),
1809
+ 'change_type_id': liq_record.get('change_type', 0),
1810
+ 'quote_amount': quote_amount,
1811
+ 'priority_fee': liq_record.get('priority_fee', 0.0),
1812
+ 'success': liq_record.get('success', False)
1813
  }
1814
+ _register_event(
1815
+ liq_event,
1816
+ _event_execution_sort_key(
1817
+ liq_record.get('timestamp'),
1818
+ slot=liq_record.get('slot', 0),
1819
+ signature=liq_record.get('signature', '')
1820
+ )
1821
+ )
1822
+
1823
+ for fee_record in fee_collection_records:
1824
+ fee_ts_val = _timestamp_to_order_value(fee_record.get('timestamp'))
1825
+ fee_ts = int(fee_ts_val)
1826
+ amount = 0.0
1827
+ if fee_record.get('token_0_mint_address') == token_address:
1828
+ amount = float(fee_record.get('token_0_amount', 0.0) or 0.0)
1829
+ elif fee_record.get('token_1_mint_address') == token_address:
1830
+ amount = float(fee_record.get('token_1_amount', 0.0) or 0.0)
1831
+
1832
+ fee_event = {
1833
+ 'event_type': 'FeeCollected',
1834
+ 'timestamp': fee_ts,
1835
+ 'relative_ts': fee_ts_val - t0_timestamp,
1836
+ 'wallet_address': fee_record.get('recipient_address'),
1837
+ 'token_address': token_address,
1838
+ 'sol_amount': amount,
1839
+ 'protocol_id': fee_record.get('protocol', 0),
1840
+ 'priority_fee': fee_record.get('priority_fee', 0.0),
1841
+ 'success': fee_record.get('success', False)
1842
+ }
1843
+ _register_event(
1844
+ fee_event,
1845
+ _event_execution_sort_key(
1846
+ fee_record.get('timestamp'),
1847
+ slot=fee_record.get('slot', 0),
1848
+ signature=fee_record.get('signature', '')
1849
+ )
1850
+ )
1851
+
1852
+ for burn_record in burn_records:
1853
+ burn_ts_val = _timestamp_to_order_value(burn_record.get('timestamp'))
1854
+ burn_ts = int(burn_ts_val)
1855
+ amount_dec = float(burn_record.get('amount_decimal', 0.0) or 0.0)
1856
+ amount_pct = (amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0
1857
+
1858
+ burn_event = {
1859
+ 'event_type': 'TokenBurn',
1860
+ 'timestamp': burn_ts,
1861
+ 'relative_ts': burn_ts_val - t0_timestamp,
1862
+ 'wallet_address': burn_record.get('source'),
1863
+ 'token_address': token_address,
1864
+ 'amount_pct_of_total_supply': amount_pct,
1865
+ 'amount_tokens_burned': amount_dec,
1866
+ 'priority_fee': burn_record.get('priority_fee', 0.0),
1867
+ 'success': burn_record.get('success', False)
1868
+ }
1869
+ _register_event(
1870
+ burn_event,
1871
+ _event_execution_sort_key(
1872
+ burn_record.get('timestamp'),
1873
+ slot=burn_record.get('slot', 0),
1874
+ signature=burn_record.get('signature', '')
1875
+ )
1876
+ )
1877
+
1878
+ for lock_record in supply_lock_records:
1879
+ lock_ts_val = _timestamp_to_order_value(lock_record.get('timestamp'))
1880
+ lock_ts = int(lock_ts_val)
1881
+ total_locked_amount = float(lock_record.get('total_locked_amount', 0.0) or 0.0)
1882
+ amount_pct = (total_locked_amount / total_supply_dec) if total_supply_dec > 0 else 0.0
1883
+ final_unlock_ts = lock_record.get('final_unlock_timestamp', 0) or 0
1884
+ lock_duration = float(final_unlock_ts) - float(lock_ts_val)
1885
+ if lock_duration < 0:
1886
+ lock_duration = 0.0
1887
+
1888
+ lock_event = {
1889
+ 'event_type': 'SupplyLock',
1890
+ 'timestamp': lock_ts,
1891
+ 'relative_ts': lock_ts_val - t0_timestamp,
1892
+ 'wallet_address': lock_record.get('sender'),
1893
+ 'token_address': token_address,
1894
+ 'amount_pct_of_total_supply': amount_pct,
1895
+ 'lock_duration': lock_duration,
1896
+ 'protocol_id': lock_record.get('protocol', 0),
1897
+ 'priority_fee': lock_record.get('priority_fee', 0.0),
1898
+ 'success': lock_record.get('success', False)
1899
+ }
1900
+ _register_event(
1901
+ lock_event,
1902
+ _event_execution_sort_key(
1903
+ lock_record.get('timestamp'),
1904
+ slot=lock_record.get('slot', 0),
1905
+ signature=lock_record.get('signature', '')
1906
+ )
1907
+ )
1908
+
1909
+ for migration_record in migration_records:
1910
+ mig_ts_val = _timestamp_to_order_value(migration_record.get('timestamp'))
1911
+ mig_ts = int(mig_ts_val)
1912
+ mig_event = {
1913
+ 'event_type': 'Migrated',
1914
+ 'timestamp': mig_ts,
1915
+ 'relative_ts': mig_ts_val - t0_timestamp,
1916
+ 'wallet_address': None,
1917
+ 'token_address': token_address,
1918
+ 'protocol_id': migration_record.get('protocol', 0),
1919
+ 'priority_fee': migration_record.get('priority_fee', 0.0),
1920
+ 'success': migration_record.get('success', False)
1921
+ }
1922
+ _register_event(
1923
+ mig_event,
1924
+ _event_execution_sort_key(
1925
+ migration_record.get('timestamp'),
1926
+ slot=migration_record.get('slot', 0),
1927
+ signature=migration_record.get('signature', '')
1928
+ )
1929
+ )
1930
 
1931
  # 6. Generate Snapshots
1932
  self._generate_onchain_snapshots(
1933
  token_address, int(t0_timestamp), T_cutoff,
1934
  300, # Interval
1935
+ trade_events, transfer_events,
1936
  aggregation_trades,
1937
  wallet_data,
1938
  total_supply_dec,
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:46809f070aa1dfcb4f53d7390b1b6ff370e6828e198df4c0df5632ac6fa9f607
3
  size 1660
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:038695f9d26e59395a2e69b52f9da029cf8796b06e4d503f0c18191288ad2a02
3
  size 1660
database.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
# Refresh the local database: pull the newest epoch artifacts, then ingest
# them. NOTE(review): the epoch number is hard-coded here — confirm whether
# 844 should instead be passed in or derived from the downloaded artifacts.
python scripts/download_epoch_artifacts.py
python scripts/ingest_epoch.py --epoch 844
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:41885991264f1522ec8b539dd4f3f738d537102a65103a800578229feef13880
3
- size 18007
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82f53ccd1f7ea7868893c879fabdea0b5939a3780cc5086790edb79d878e8121
3
+ size 32490
models/helper_encoders.py CHANGED
@@ -43,8 +43,8 @@ class ContextualTimeEncoder(nn.Module):
43
  half_dim = d_model // 2
44
 
45
  # Calculations for sinusoidal encoding are more stable in float32
46
- div_term = torch.exp(torch.arange(0, half_dim, device=device).float() * -(math.log(10000.0) / half_dim))
47
- args = values.float().unsqueeze(-1) * div_term
48
 
49
  return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
50
 
@@ -63,15 +63,16 @@ class ContextualTimeEncoder(nn.Module):
63
 
64
  # 1. Store original shape (e.g., [B, L]) and flatten
65
  original_shape = timestamps.shape
66
- timestamps_flat = timestamps.flatten().float() # Shape [N_total]
 
67
 
68
  # 2. Sinusoidal encode (already vectorized)
69
  ts_encoding = self._sinusoidal_encode(timestamps_flat, self.ts_dim)
70
 
71
  # 3. List comprehension (this is the only non-vectorized part)
72
  # This loop is now correct, as it iterates over the 1D flat tensor
73
- hours = torch.tensor([datetime.datetime.fromtimestamp(ts.item(), tz=datetime.timezone.utc).hour for ts in timestamps_flat], device=device, dtype=torch.float32)
74
- days = torch.tensor([datetime.datetime.fromtimestamp(ts.item(), tz=datetime.timezone.utc).weekday() for ts in timestamps_flat], device=device, dtype=torch.float32)
75
 
76
  # 4. Cyclical encode (already vectorized)
77
  hour_encoding = self._cyclical_encode(hours, self.hour_dim, max_val=24.0)
 
43
  half_dim = d_model // 2
44
 
45
  # Calculations for sinusoidal encoding are done in float64 to avoid precision loss on large timestamp values
46
+ div_term = torch.exp(torch.arange(0, half_dim, device=device, dtype=torch.float64) * -(math.log(10000.0) / half_dim))
47
+ args = values.double().unsqueeze(-1) * div_term
48
 
49
  return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
50
 
 
63
 
64
  # 1. Store original shape (e.g., [B, L]) and flatten
65
  original_shape = timestamps.shape
66
+ # Preserve precision for large Unix timestamps.
67
+ timestamps_flat = timestamps.flatten().double() # Shape [N_total]
68
 
69
  # 2. Sinusoidal encode (already vectorized)
70
  ts_encoding = self._sinusoidal_encode(timestamps_flat, self.ts_dim)
71
 
72
  # 3. List comprehension (this is the only non-vectorized part)
73
  # This loop is now correct, as it iterates over the 1D flat tensor
74
+ hours = torch.tensor([datetime.datetime.fromtimestamp(float(ts), tz=datetime.timezone.utc).hour for ts in timestamps_flat], device=device, dtype=torch.float32)
75
+ days = torch.tensor([datetime.datetime.fromtimestamp(float(ts), tz=datetime.timezone.utc).weekday() for ts in timestamps_flat], device=device, dtype=torch.float32)
76
 
77
  # 4. Cyclical encode (already vectorized)
78
  hour_encoding = self._cyclical_encode(hours, self.hour_dim, max_val=24.0)
pre_cache.sh CHANGED
@@ -3,7 +3,5 @@
3
 
4
  echo "Starting dataset caching..."
5
  python3 scripts/cache_dataset.py \
6
- --ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz" \
7
- --max_samples 50
8
-
9
  echo "Done!"
 
3
 
4
  echo "Starting dataset caching..."
5
  python3 scripts/cache_dataset.py \
6
+ --ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz"
 
 
7
  echo "Done!"
scripts/inspect_collator.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List
7
+
8
+ import torch
9
+
10
# Ensure repo root is on sys.path so the `data.*` and `models.*` imports
# below resolve when this script is run directly from anywhere.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))
13
+
14
+ from data.data_loader import OracleDataset
15
+ from data.data_fetcher import DataFetcher
16
+ from data.data_collator import MemecoinCollator
17
+ import models.vocabulary as vocab
18
+
19
+
20
def _decode_events(event_type_ids: torch.Tensor) -> List[str]:
    """Translate a 1-D tensor of event-type ids into vocabulary names.

    Id 0 is the padding slot and maps to ``"__PAD__"``; ids missing from
    ``vocab.ID_TO_EVENT`` render as ``"UNK_<id>"`` so gaps are visible in dumps.
    """
    return [
        "__PAD__" if type_id == 0 else vocab.ID_TO_EVENT.get(type_id, f"UNK_{type_id}")
        for type_id in event_type_ids.tolist()
    ]
28
+
29
+
30
+ def _tensor_to_list(t: torch.Tensor) -> List:
31
+ return t.detach().cpu().tolist()
32
+
33
+
34
def main() -> None:
    """Run the MemecoinCollator on cached dataset samples and dump the result.

    Loads the requested sample indices from the on-disk cache, collates them
    into a batch, and writes a JSON-friendly snapshot of every sequence,
    pointer, numerical, and categorical tensor (embeddings are omitted) to
    ``--out``. Requires reachable ClickHouse and Neo4j instances (configured
    via environment variables / .env) for time-aware wallet/token fetches.
    """
    parser = argparse.ArgumentParser(description="Inspect MemecoinCollator outputs on cached samples.")
    parser.add_argument("--cache_dir", type=str, default="data/cache")
    parser.add_argument("--idx", type=int, nargs="+", default=[0], help="Sample indices to inspect")
    parser.add_argument("--max_seq_len", type=int, default=16000)
    parser.add_argument("--out", type=str, default="collator_dump.json")
    args = parser.parse_args()

    cache_dir = Path(args.cache_dir)
    # Optional: enable time-aware fetches if DB env is set.
    # NOTE: `os` is imported at module level; the previous function-local
    # re-import was redundant and has been removed.
    from dotenv import load_dotenv
    from clickhouse_driver import Client as ClickHouseClient
    from neo4j import GraphDatabase

    load_dotenv()
    clickhouse_host = os.getenv("CLICKHOUSE_HOST", "localhost")
    # Prefer the native-protocol port; fall back to the generic port setting.
    clickhouse_port = int(os.getenv("CLICKHOUSE_NATIVE_PORT", os.getenv("CLICKHOUSE_PORT", 9000)))
    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_user = os.getenv("NEO4J_USER", "neo4j")
    neo4j_password = os.getenv("NEO4J_PASSWORD", "password")
    clickhouse_client = ClickHouseClient(host=clickhouse_host, port=clickhouse_port)
    neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
    data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

    dataset = OracleDataset(
        data_fetcher=data_fetcher,
        cache_dir=str(cache_dir),
        horizons_seconds=[30, 60, 120, 240, 420],
        quantiles=[0.1, 0.5, 0.9],
        max_samples=None,
        max_seq_len=args.max_seq_len,
    )
    if hasattr(dataset, "init_fetcher"):
        dataset.init_fetcher()

    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=torch.device("cpu"),
        dtype=torch.float32,
        max_seq_len=args.max_seq_len,
    )

    batch_items = [dataset[i] for i in args.idx]
    batch = collator(batch_items)

    # Build JSON-friendly dump (no truncation of events; embeddings are omitted)
    dump: Dict[str, Any] = {
        "batch_size": len(args.idx),
        "token_addresses": batch.get("token_addresses"),
        "t_cutoffs": batch.get("t_cutoffs"),
        "sample_indices": batch.get("sample_indices"),
        "raw_events": [item.get("event_sequence", []) for item in batch_items],
    }
    # Raw event type counts per sample, for a quick sanity overview.
    event_counts = []
    for item in batch_items:
        counts: Dict[str, int] = {}
        for ev in item.get("event_sequence", []):
            et = ev.get("event_type", "UNKNOWN")
            counts[et] = counts.get(et, 0) + 1
        event_counts.append(counts)
    dump["raw_event_counts"] = event_counts

    # Core sequence + features (full length)
    dump["event_type_ids"] = _tensor_to_list(batch["event_type_ids"])
    dump["event_type_names"] = [
        _decode_events(batch["event_type_ids"][i].cpu())
        for i in range(batch["event_type_ids"].shape[0])
    ]
    dump["timestamps_float"] = _tensor_to_list(batch["timestamps_float"])
    dump["relative_ts"] = _tensor_to_list(batch["relative_ts"])
    dump["attention_mask"] = _tensor_to_list(batch["attention_mask"])
    dump["wallet_addr_to_batch_idx"] = batch.get("wallet_addr_to_batch_idx", {})

    # Pointer tensors (per-timestep indices into the batch-level entity lists).
    for key in [
        "wallet_indices",
        "token_indices",
        "quote_token_indices",
        "trending_token_indices",
        "boosted_token_indices",
        "dest_wallet_indices",
        "original_author_indices",
        "ohlc_indices",
        "holder_snapshot_indices",
        "textual_event_indices",
    ]:
        if key in batch:
            dump[key] = _tensor_to_list(batch[key])

    # Numerical feature tensors; also track nonzero counts to spot dead features.
    nonzero_summary = {}
    for key in [
        "transfer_numerical_features",
        "trade_numerical_features",
        "deployer_trade_numerical_features",
        "smart_wallet_trade_numerical_features",
        "pool_created_numerical_features",
        "liquidity_change_numerical_features",
        "fee_collected_numerical_features",
        "token_burn_numerical_features",
        "supply_lock_numerical_features",
        "onchain_snapshot_numerical_features",
        "trending_token_numerical_features",
        "boosted_token_numerical_features",
        "dexboost_paid_numerical_features",
        "dexprofile_updated_flags",
        "global_trending_numerical_features",
        "chainsnapshot_numerical_features",
        "lighthousesnapshot_numerical_features",
    ]:
        if key in batch:
            t = batch[key]
            dump[key] = _tensor_to_list(t)
            nonzero_summary[key] = int(torch.count_nonzero(t).item())

    # Categorical feature tensors
    for key in [
        "trade_dex_ids",
        "trade_direction_ids",
        "trade_mev_protection_ids",
        "trade_is_bundle_ids",
        "pool_created_protocol_ids",
        "liquidity_change_type_ids",
        "trending_token_source_ids",
        "trending_token_timeframe_ids",
        "lighthousesnapshot_protocol_ids",
        "lighthousesnapshot_timeframe_ids",
        "migrated_protocol_ids",
        "alpha_group_ids",
        "channel_ids",
        "exchange_ids",
    ]:
        if key in batch:
            t = batch[key]
            dump[key] = _tensor_to_list(t)
            nonzero_summary[key] = int(torch.count_nonzero(t).item())

    # Labels (may be absent when inspecting unlabeled samples).
    if batch.get("labels") is not None:
        dump["labels"] = _tensor_to_list(batch["labels"])
    if batch.get("labels_mask") is not None:
        dump["labels_mask"] = _tensor_to_list(batch["labels_mask"])
    if batch.get("quality_score") is not None:
        dump["quality_score"] = _tensor_to_list(batch["quality_score"])

    dump["nonzero_summary"] = nonzero_summary

    # Raw wallet/token feature payloads used by encoders
    wallet_inputs = batch.get("wallet_encoder_inputs", {})
    token_inputs = batch.get("token_encoder_inputs", {})
    dump["wallet_encoder_inputs"] = {
        "profile_rows": wallet_inputs.get("profile_rows", []),
        "social_rows": wallet_inputs.get("social_rows", []),
        "holdings_batch": wallet_inputs.get("holdings_batch", []),
        "username_embed_indices": _tensor_to_list(wallet_inputs.get("username_embed_indices")) if "username_embed_indices" in wallet_inputs else [],
    }
    dump["token_encoder_inputs"] = {
        "addresses_for_lookup": token_inputs.get("_addresses_for_lookup", []),
        "protocol_ids": _tensor_to_list(token_inputs.get("protocol_ids")) if "protocol_ids" in token_inputs else [],
        "is_vanity_flags": _tensor_to_list(token_inputs.get("is_vanity_flags")) if "is_vanity_flags" in token_inputs else [],
        "name_embed_indices": _tensor_to_list(token_inputs.get("name_embed_indices")) if "name_embed_indices" in token_inputs else [],
        "symbol_embed_indices": _tensor_to_list(token_inputs.get("symbol_embed_indices")) if "symbol_embed_indices" in token_inputs else [],
        "image_embed_indices": _tensor_to_list(token_inputs.get("image_embed_indices")) if "image_embed_indices" in token_inputs else [],
    }
    dump["wallet_set_encoder_inputs"] = {
        "holdings_batch": wallet_inputs.get("holdings_batch", []),
        "token_vibe_lookup_keys": token_inputs.get("_addresses_for_lookup", []),
    }

    out_path = Path(args.out)

    def _json_default(o):
        # Best-effort serializer: ISO-format datetimes, stringify everything
        # else rather than failing the whole dump on one odd value.
        if isinstance(o, (str, int, float, bool)) or o is None:
            return o
        try:
            import datetime as _dt
            if isinstance(o, (_dt.datetime, _dt.date)):
                return o.isoformat()
        except Exception:
            pass
        try:
            return str(o)
        except Exception:
            return "<unserializable>"

    with out_path.open("w") as f:
        json.dump(dump, f, indent=2, default=_json_default)

    print(f"Wrote collator dump to {out_path.resolve()}")


if __name__ == "__main__":
    main()
train.py CHANGED
@@ -118,6 +118,15 @@ def quantile_pinball_loss(preds: torch.Tensor,
118
  return sum(losses) / mask.sum().clamp_min(1.0)
119
 
120
 
 
 
 
 
 
 
 
 
 
121
  def filtered_collate(collator: MemecoinCollator,
122
  batch: List[Optional[Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
123
  """Filter out None items from the dataset before collating."""
@@ -304,13 +313,18 @@ def main() -> None:
304
  max_seq_len=max_seq_len
305
  )
306
 
307
- # DB Connections - REMOVED for Training (Using Cache)
308
- # clickhouse_client = ClickHouseClient(...)
309
- # neo4j_driver = GraphDatabase.driver(...)
310
- # data_fetcher = DataFetcher(...)
 
 
 
 
311
 
312
  dataset = OracleDataset(
313
- data_fetcher=None, # Training Mode (Reader Only)
 
314
  horizons_seconds=horizons,
315
  quantiles=quantiles,
316
  max_samples=args.max_samples,
@@ -339,6 +353,10 @@ def main() -> None:
339
  else:
340
  logger.info("INFO: Weights found but shuffle=False. Ignoring weights (sequential mode).")
341
 
 
 
 
 
342
  dl_kwargs = dict(
343
  dataset=dataset,
344
  batch_size=batch_size,
@@ -353,6 +371,9 @@ def main() -> None:
353
  # re-initializes heavy per-worker state (e.g. SigLIP MultiModalEncoder).
354
  dl_kwargs["persistent_workers"] = True
355
  dl_kwargs["prefetch_factor"] = 2
 
 
 
356
  dataloader = DataLoader(**dl_kwargs)
357
 
358
  # --- 3. Model Init ---
@@ -599,7 +620,6 @@ def main() -> None:
599
  logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
600
 
601
  accelerator.end_training()
602
- # neo4j_driver.close() # REMOVED
603
 
604
  if __name__ == "__main__":
605
  main()
 
118
  return sum(losses) / mask.sum().clamp_min(1.0)
119
 
120
 
121
def init_worker_fetcher(worker_id: int) -> None:
    """Initialize per-worker DataFetcher for cached datasets.

    Intended as a DataLoader ``worker_init_fn``: in the main process (no
    worker context) it is a no-op; in a worker it calls the dataset's
    ``init_fetcher`` hook when one is defined.
    """
    info = torch.utils.data.get_worker_info()
    if info is None:
        return
    worker_dataset = info.dataset
    if hasattr(worker_dataset, "init_fetcher"):
        worker_dataset.init_fetcher()
128
+
129
+
130
  def filtered_collate(collator: MemecoinCollator,
131
  batch: List[Optional[Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
132
  """Filter out None items from the dataset before collating."""
 
313
  max_seq_len=max_seq_len
314
  )
315
 
316
+ # DB config (for time-aware wallet/token/graph features during training)
317
+ fetcher_config = {
318
+ "clickhouse_host": os.getenv("CLICKHOUSE_HOST", "localhost"),
319
+ "clickhouse_port": int(os.getenv("CLICKHOUSE_PORT", 9000)),
320
+ "neo4j_uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
321
+ "neo4j_user": os.getenv("NEO4J_USER", "neo4j"),
322
+ "neo4j_password": os.getenv("NEO4J_PASSWORD", "password"),
323
+ }
324
 
325
  dataset = OracleDataset(
326
+ data_fetcher=None,
327
+ fetcher_config=fetcher_config, # Training Mode (Cache + time-aware fetch per worker)
328
  horizons_seconds=horizons,
329
  quantiles=quantiles,
330
  max_samples=args.max_samples,
 
353
  else:
354
  logger.info("INFO: Weights found but shuffle=False. Ignoring weights (sequential mode).")
355
 
356
+ # Initialize DataFetcher in main process when not using workers.
357
+ if int(args.num_workers) == 0:
358
+ dataset.init_fetcher()
359
+
360
  dl_kwargs = dict(
361
  dataset=dataset,
362
  batch_size=batch_size,
 
371
  # re-initializes heavy per-worker state (e.g. SigLIP MultiModalEncoder).
372
  dl_kwargs["persistent_workers"] = True
373
  dl_kwargs["prefetch_factor"] = 2
374
+ if int(args.num_workers) > 0:
375
+ dl_kwargs["worker_init_fn"] = init_worker_fetcher
376
+
377
  dataloader = DataLoader(**dl_kwargs)
378
 
379
  # --- 3. Model Init ---
 
620
  logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
621
 
622
  accelerator.end_training()
 
623
 
624
  if __name__ == "__main__":
625
  main()