zirobtc commited on
Commit
d195287
·
verified ·
1 Parent(s): c471f42

Upload folder using huggingface_hub

Browse files
audit_cache.py CHANGED
@@ -1,151 +1,249 @@
1
- import os
2
- import torch
3
- import math
4
  import argparse
 
 
5
  from pathlib import Path
6
- from collections import defaultdict
7
- import glob
8
  from tqdm import tqdm
9
 
10
- def audit_cache(cache_dir, num_samples=10000):
11
- files = glob.glob(os.path.join(cache_dir, "sample_*.pt"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  if not files:
13
- print(f"No .pt files found in {cache_dir}")
14
  return
15
 
16
- files = files[:num_samples]
17
-
18
- issues = defaultdict(int)
19
- total_files = len(files)
20
-
 
 
 
 
 
 
21
  stats = {
22
- 'max_label_return': -float('inf'),
23
- 'min_label_return': float('inf'),
24
- 'nan_labels': 0,
25
- 'nan_masks': 0,
26
- 'missing_quality_score': 0,
27
- 'negative_quality_score': 0,
28
- 'empty_event_sequence': 0,
29
- 'missing_wallets': 0,
30
- 'nan_in_wallet_profile': 0,
31
- 'nan_in_events': 0,
32
- 'inf_in_events': 0,
33
- 'invalid_pool_idx': 0,
34
- 'max_seq_len_exceeded': 0,
35
- 'negative_prices': 0,
 
 
 
 
36
  }
37
 
38
- for fpath in tqdm(files, desc="Auditing"):
39
  try:
40
- try:
41
- data = torch.load(fpath, map_location="cpu", weights_only=False)
42
- except Exception:
43
- issues['load_error'] += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  continue
45
-
46
- # 1. Quality score
47
- q_score = data.get("quality_score")
48
- if q_score is None:
49
- stats['missing_quality_score'] += 1
50
- elif math.isnan(q_score):
51
- issues['nan_quality_score'] += 1
52
- elif q_score < 0:
53
- stats['negative_quality_score'] += 1
54
-
55
- # 2. Labels & Masks
56
- labels = data.get("future_return_labels")
57
- masks = data.get("future_return_masks")
58
- if labels is not None:
59
- for v in labels.tolist():
60
- if math.isnan(v):
61
- stats['nan_labels'] += 1
62
- break
63
- stats['max_label_return'] = max(stats['max_label_return'], float(v))
64
- stats['min_label_return'] = min(stats['min_label_return'], float(v))
65
- if masks is not None:
66
- for v in masks.tolist():
67
- if math.isnan(v):
68
- stats['nan_masks'] += 1
69
- break
70
-
71
- # 3. Events
72
- events = data.get("event_sequence", [])
73
- if not events:
74
- stats['empty_event_sequence'] += 1
75
- elif len(events) > 8192:
76
- stats['max_seq_len_exceeded'] += 1
77
-
78
- has_nan_event = False
79
- has_inf_event = False
80
- has_neg_price = False
81
-
82
- for event in events:
83
- for k, v in event.items():
84
- if isinstance(v, float):
85
- if math.isnan(v):
86
- has_nan_event = True
87
- elif math.isinf(v):
88
- has_inf_event = True
89
-
90
- # Check chart segments specifically
91
- if event.get("event_type") == "Chart_Segment":
92
- opens = event.get("opens", [])
93
- closes = event.get("closes", [])
94
- # The user report said spread is zero due to z-score
95
- # Let's check if prices are negative
96
- if opens and min(opens) < 0:
97
- has_neg_price = True
98
-
99
- if has_nan_event:
100
- stats['nan_in_events'] += 1
101
- if has_inf_event:
102
- stats['inf_in_events'] += 1
103
- if has_neg_price:
104
- stats['negative_prices'] += 1
105
-
106
- # 4. Wallets
107
- wallets = data.get("wallets")
108
- if not wallets:
109
- stats['missing_wallets'] += 1
110
- else:
111
- has_nan_wallet = False
112
- for w_addr, w_data in wallets.items():
113
- profile = w_data.get("profile", {})
114
- for k, v in profile.items():
115
- if isinstance(v, float) and math.isnan(v):
116
- has_nan_wallet = True
117
- if has_nan_wallet:
118
- stats['nan_in_wallet_profile'] += 1
119
-
120
- # 5. Pool index out of bounds
121
- pool = data.get("embedding_pool", [])
122
- pool_idxs = [item.get("idx") for item in pool]
123
- max_idx = max(pool_idxs) if pool_idxs else 0
124
-
125
- for event in events:
126
- for k, v in event.items():
127
- if k.endswith("_idx") and isinstance(v, int):
128
- if v > max_idx:
129
- stats['invalid_pool_idx'] += 1
130
- break
131
-
132
- except Exception as e:
133
- issues[f'processing_error_{type(e).__name__}'] += 1
134
-
135
-
136
- print("\n--- Cache Audit Report (New Issues) ---")
137
- print(f"Audited {total_files} files.")
138
- for k, v in stats.items():
139
- print(f"{k}: {v}")
140
-
141
- print("\nIssues encountered:")
142
- for k, v in issues.items():
143
- print(f"{k}: {v}")
144
 
145
  if __name__ == "__main__":
146
  parser = argparse.ArgumentParser()
147
  parser.add_argument("--cache_dir", type=str, default="/workspace/apollo/data/cache")
148
- parser.add_argument("--num", type=int, default=1000)
149
  args = parser.parse_args()
150
-
151
  audit_cache(args.cache_dir, args.num)
 
 
 
 
1
  import argparse
2
+ import math
3
+ from collections import Counter, defaultdict
4
  from pathlib import Path
5
+
6
+ import torch
7
  from tqdm import tqdm
8
 
9
+ from data.data_loader import summarize_context_window
10
+ from data.quant_ohlc_feature_schema import FEATURE_VERSION, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
11
+
12
+
13
+ REQUIRED_CONTEXT_FIELDS = [
14
+ "event_sequence",
15
+ "wallets",
16
+ "tokens",
17
+ "labels",
18
+ "labels_mask",
19
+ "quality_score",
20
+ "class_id",
21
+ "source_token",
22
+ "context_bucket",
23
+ "context_score",
24
+ "quant_ohlc_features",
25
+ "quant_feature_version",
26
+ ]
27
+
28
+
29
+ def _to_list(value):
30
+ if value is None:
31
+ return []
32
+ if isinstance(value, torch.Tensor):
33
+ return value.tolist()
34
+ return list(value)
35
+
36
+
37
+ def _safe_float(value):
38
+ if isinstance(value, torch.Tensor):
39
+ if value.numel() != 1:
40
+ raise ValueError("Expected scalar tensor.")
41
+ return float(value.item())
42
+ return float(value)
43
+
44
+
45
+ def audit_cache(cache_dir, num_samples=None):
46
+ cache_path = Path(cache_dir)
47
+ files = sorted(cache_path.glob("sample_*.pt"))
48
  if not files:
49
+ print(f"No sample_*.pt files found in {cache_path}")
50
  return
51
 
52
+ if num_samples is not None and num_samples > 0:
53
+ files = files[:num_samples]
54
+
55
+ issues = Counter()
56
+ class_counts = Counter()
57
+ bucket_counts = Counter()
58
+ class_bucket_counts = defaultdict(Counter)
59
+ token_counts_by_class = defaultdict(Counter)
60
+ samples_per_token = Counter()
61
+ missing_fields = Counter()
62
+
63
  stats = {
64
+ "files_audited": len(files),
65
+ "empty_event_sequence": 0,
66
+ "missing_wallets": 0,
67
+ "missing_tokens": 0,
68
+ "nan_labels": 0,
69
+ "nan_masks": 0,
70
+ "nan_quality_score": 0,
71
+ "negative_quality_score": 0,
72
+ "max_label_return": -float("inf"),
73
+ "min_label_return": float("inf"),
74
+ "max_events": 0,
75
+ "min_events": float("inf"),
76
+ "contexts_with_no_valid_horizons": 0,
77
+ "context_bucket_mismatch": 0,
78
+ "context_score_mismatch": 0,
79
+ "quant_feature_version_mismatch": 0,
80
+ "chart_events_missing_quant": 0,
81
+ "quant_segments_total": 0,
82
  }
83
 
84
+ for filepath in tqdm(files, desc="Auditing cache", unit="file"):
85
  try:
86
+ data = torch.load(filepath, map_location="cpu", weights_only=False)
87
+ except Exception:
88
+ issues["load_error"] += 1
89
+ continue
90
+
91
+ if not isinstance(data, dict):
92
+ issues["not_dict"] += 1
93
+ continue
94
+
95
+ missing_for_file = []
96
+ for field in REQUIRED_CONTEXT_FIELDS:
97
+ if field not in data:
98
+ missing_for_file.append(field)
99
+ missing_fields[field] += 1
100
+
101
+ if missing_for_file:
102
+ issues["missing_required_fields"] += 1
103
+ continue
104
+
105
+ class_id = int(data["class_id"])
106
+ source_token = str(data["source_token"])
107
+ context_bucket = str(data["context_bucket"])
108
+
109
+ class_counts[class_id] += 1
110
+ bucket_counts[context_bucket] += 1
111
+ class_bucket_counts[class_id][context_bucket] += 1
112
+ token_counts_by_class[class_id][source_token] += 1
113
+ samples_per_token[source_token] += 1
114
+
115
+ events = data.get("event_sequence") or []
116
+ wallets = data.get("wallets") or {}
117
+ tokens = data.get("tokens") or {}
118
+ labels = _to_list(data.get("labels"))
119
+ masks = _to_list(data.get("labels_mask"))
120
+
121
+ if not events:
122
+ stats["empty_event_sequence"] += 1
123
+ stats["max_events"] = max(stats["max_events"], len(events))
124
+ stats["min_events"] = min(stats["min_events"], len(events))
125
+
126
+ if not wallets:
127
+ stats["missing_wallets"] += 1
128
+ if not tokens:
129
+ stats["missing_tokens"] += 1
130
+
131
+ has_nan_label = False
132
+ for value in labels:
133
+ if math.isnan(float(value)):
134
+ has_nan_label = True
135
+ break
136
+ stats["max_label_return"] = max(stats["max_label_return"], float(value))
137
+ stats["min_label_return"] = min(stats["min_label_return"], float(value))
138
+ if has_nan_label:
139
+ stats["nan_labels"] += 1
140
+
141
+ has_nan_mask = False
142
+ for value in masks:
143
+ if math.isnan(float(value)):
144
+ has_nan_mask = True
145
+ break
146
+ if has_nan_mask:
147
+ stats["nan_masks"] += 1
148
+
149
+ try:
150
+ quality_score = _safe_float(data.get("quality_score"))
151
+ if math.isnan(quality_score):
152
+ stats["nan_quality_score"] += 1
153
+ elif quality_score < 0:
154
+ stats["negative_quality_score"] += 1
155
+ except Exception:
156
+ issues["invalid_quality_score"] += 1
157
+
158
+ try:
159
+ summary = summarize_context_window(data.get("labels"), data.get("labels_mask"))
160
+ if summary["valid_horizons"] == 0:
161
+ stats["contexts_with_no_valid_horizons"] += 1
162
+ if summary["context_bucket"] != context_bucket:
163
+ stats["context_bucket_mismatch"] += 1
164
+ stored_score = _safe_float(data.get("context_score"))
165
+ if not math.isclose(summary["context_score"], stored_score, rel_tol=1e-6, abs_tol=1e-6):
166
+ stats["context_score_mismatch"] += 1
167
+ except Exception:
168
+ issues["context_summary_error"] += 1
169
+
170
+ if data.get("quant_feature_version") != FEATURE_VERSION:
171
+ stats["quant_feature_version_mismatch"] += 1
172
+
173
+ chart_events = [event for event in events if event.get("event_type") == "Chart_Segment"]
174
+ stats["quant_segments_total"] += len(chart_events)
175
+ for event in chart_events:
176
+ quant_payload = event.get("quant_ohlc_features")
177
+ if not isinstance(quant_payload, list):
178
+ stats["chart_events_missing_quant"] += 1
179
  continue
180
+ if len(quant_payload) > TOKENS_PER_SEGMENT:
181
+ issues["quant_too_many_tokens"] += 1
182
+ for token_payload in quant_payload:
183
+ vec = token_payload.get("feature_vector")
184
+ if not isinstance(vec, list) or len(vec) != NUM_QUANT_OHLC_FEATURES:
185
+ issues["quant_bad_vector_shape"] += 1
186
+ break
187
+
188
+ if stats["min_events"] == float("inf"):
189
+ stats["min_events"] = 0
190
+ if stats["min_label_return"] == float("inf"):
191
+ stats["min_label_return"] = 0.0
192
+ if stats["max_label_return"] == -float("inf"):
193
+ stats["max_label_return"] = 0.0
194
+
195
+ unique_tokens_total = len(samples_per_token)
196
+ duplicate_tokens_total = sum(1 for count in samples_per_token.values() if count > 1)
197
+
198
+ print("\n=== Cache Audit ===")
199
+ print(f"Cache dir: {cache_path}")
200
+ print(f"Files audited: {stats['files_audited']}")
201
+ print(f"Unique source tokens: {unique_tokens_total}")
202
+ print(f"Tokens with >1 cached context: {duplicate_tokens_total}")
203
+ print(f"Samples per token max: {max(samples_per_token.values()) if samples_per_token else 0}")
204
+
205
+ print("\n--- Class Counts ---")
206
+ for class_id in sorted(class_counts):
207
+ unique_tokens = len(token_counts_by_class[class_id])
208
+ print(f"Class {class_id}: samples={class_counts[class_id]} unique_tokens={unique_tokens}")
209
+
210
+ print("\n--- Context Buckets ---")
211
+ for bucket, count in sorted(bucket_counts.items()):
212
+ print(f"{bucket}: {count}")
213
+
214
+ print("\n--- Class x Context Bucket ---")
215
+ for class_id in sorted(class_bucket_counts):
216
+ bucket_summary = dict(sorted(class_bucket_counts[class_id].items()))
217
+ print(f"Class {class_id}: {bucket_summary}")
218
+
219
+ print("\n--- General Stats ---")
220
+ for key, value in stats.items():
221
+ print(f"{key}: {value}")
222
+
223
+ print("\n--- Missing Fields ---")
224
+ if missing_fields:
225
+ for field, count in sorted(missing_fields.items()):
226
+ print(f"{field}: {count}")
227
+ else:
228
+ print("none")
229
+
230
+ print("\n--- Issues ---")
231
+ if issues:
232
+ for key, value in sorted(issues.items()):
233
+ print(f"{key}: {value}")
234
+ else:
235
+ print("none")
236
+
237
+ print("\n--- Duplicate-Heavy Tokens ---")
238
+ heavy_tokens = sorted(samples_per_token.items(), key=lambda item: (-item[1], item[0]))[:20]
239
+ for token, count in heavy_tokens:
240
+ print(f"{token}: {count}")
241
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  if __name__ == "__main__":
244
  parser = argparse.ArgumentParser()
245
  parser.add_argument("--cache_dir", type=str, default="/workspace/apollo/data/cache")
246
+ parser.add_argument("--num", type=int, default=None, help="Audit only the first N files.")
247
  args = parser.parse_args()
248
+
249
  audit_cache(args.cache_dir, args.num)
data/data_collator.py CHANGED
@@ -10,6 +10,7 @@ from PIL import Image
10
 
11
  import models.vocabulary as vocab
12
  from data.data_loader import EmbeddingPooler
 
13
 
14
  NATIVE_MINT = "So11111111111111111111111111111111111111112"
15
  QUOTE_MINTS = {
@@ -40,6 +41,8 @@ class MemecoinCollator:
40
  self.device = device
41
  self.dtype = dtype
42
  self.ohlc_seq_len = 300 # HARDCODED
 
 
43
  self.max_seq_len = max_seq_len
44
 
45
  def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
@@ -86,10 +89,16 @@ class MemecoinCollator:
86
  if not chart_events:
87
  return {
88
  'price_tensor': torch.empty(0, 2, self.ohlc_seq_len, device=self.device, dtype=self.dtype),
89
- 'interval_ids': torch.empty(0, device=self.device, dtype=torch.long)
 
 
 
90
  }
91
  ohlc_tensors = []
92
  interval_ids_list = []
 
 
 
93
  seq_len = self.ohlc_seq_len
94
  unknown_id = vocab.INTERVAL_TO_ID.get("Unknown", 0)
95
  for segment_data in chart_events:
@@ -105,9 +114,36 @@ class MemecoinCollator:
105
  ohlc_tensors.append(torch.stack([o, c]))
106
  interval_id = vocab.INTERVAL_TO_ID.get(interval_str, unknown_id)
107
  interval_ids_list.append(interval_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  return {
109
  'price_tensor': torch.stack(ohlc_tensors).to(self.device),
110
- 'interval_ids': torch.tensor(interval_ids_list, device=self.device, dtype=torch.long)
 
 
 
111
  }
112
 
113
  def _collate_graph_links(self,
@@ -677,6 +713,9 @@ class MemecoinCollator:
677
  'wallet_encoder_inputs': wallet_encoder_inputs, # ADDED BACK
678
  'ohlc_price_tensors': ohlc_inputs_dict['price_tensor'],
679
  'ohlc_interval_ids': ohlc_inputs_dict['interval_ids'],
 
 
 
680
  'graph_updater_links': graph_updater_links,
681
  'wallet_addr_to_batch_idx': wallet_addr_to_batch_idx, # NEW: Pass the mapping
682
 
 
10
 
11
  import models.vocabulary as vocab
12
  from data.data_loader import EmbeddingPooler
13
+ from data.quant_ohlc_feature_schema import FEATURE_VERSION, FEATURE_VERSION_ID, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
14
 
15
  NATIVE_MINT = "So11111111111111111111111111111111111111112"
16
  QUOTE_MINTS = {
 
41
  self.device = device
42
  self.dtype = dtype
43
  self.ohlc_seq_len = 300 # HARDCODED
44
+ self.quant_ohlc_tokens = TOKENS_PER_SEGMENT
45
+ self.quant_ohlc_num_features = NUM_QUANT_OHLC_FEATURES
46
  self.max_seq_len = max_seq_len
47
 
48
  def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
 
89
  if not chart_events:
90
  return {
91
  'price_tensor': torch.empty(0, 2, self.ohlc_seq_len, device=self.device, dtype=self.dtype),
92
+ 'interval_ids': torch.empty(0, device=self.device, dtype=torch.long),
93
+ 'quant_feature_tensors': torch.empty(0, self.quant_ohlc_tokens, self.quant_ohlc_num_features, device=self.device, dtype=self.dtype),
94
+ 'quant_feature_mask': torch.empty(0, self.quant_ohlc_tokens, device=self.device, dtype=self.dtype),
95
+ 'quant_feature_version_ids': torch.empty(0, device=self.device, dtype=torch.long),
96
  }
97
  ohlc_tensors = []
98
  interval_ids_list = []
99
+ quant_feature_tensors = []
100
+ quant_feature_masks = []
101
+ quant_feature_version_ids = []
102
  seq_len = self.ohlc_seq_len
103
  unknown_id = vocab.INTERVAL_TO_ID.get("Unknown", 0)
104
  for segment_data in chart_events:
 
114
  ohlc_tensors.append(torch.stack([o, c]))
115
  interval_id = vocab.INTERVAL_TO_ID.get(interval_str, unknown_id)
116
  interval_ids_list.append(interval_id)
117
+ quant_payload = segment_data.get('quant_ohlc_features')
118
+ if quant_payload is None:
119
+ raise RuntimeError("Chart_Segment missing quant_ohlc_features. Rebuild cache with quantitative chart features.")
120
+ if not isinstance(quant_payload, list):
121
+ raise RuntimeError("Chart_Segment quant_ohlc_features must be a list.")
122
+ feature_rows = []
123
+ feature_mask = []
124
+ for token_idx in range(self.quant_ohlc_tokens):
125
+ if token_idx < len(quant_payload):
126
+ payload = quant_payload[token_idx]
127
+ vec = payload.get('feature_vector')
128
+ if not isinstance(vec, list) or len(vec) != self.quant_ohlc_num_features:
129
+ raise RuntimeError(
130
+ f"Chart_Segment quant feature vector must have length {self.quant_ohlc_num_features}."
131
+ )
132
+ feature_rows.append(vec)
133
+ feature_mask.append(1.0)
134
+ else:
135
+ feature_rows.append([0.0] * self.quant_ohlc_num_features)
136
+ feature_mask.append(0.0)
137
+ quant_feature_tensors.append(torch.tensor(feature_rows, device=self.device, dtype=self.dtype))
138
+ quant_feature_masks.append(torch.tensor(feature_mask, device=self.device, dtype=self.dtype))
139
+ version = segment_data.get('quant_feature_version', FEATURE_VERSION)
140
+ quant_feature_version_ids.append(FEATURE_VERSION_ID if version == FEATURE_VERSION else 0)
141
  return {
142
  'price_tensor': torch.stack(ohlc_tensors).to(self.device),
143
+ 'interval_ids': torch.tensor(interval_ids_list, device=self.device, dtype=torch.long),
144
+ 'quant_feature_tensors': torch.stack(quant_feature_tensors).to(self.device),
145
+ 'quant_feature_mask': torch.stack(quant_feature_masks).to(self.device),
146
+ 'quant_feature_version_ids': torch.tensor(quant_feature_version_ids, device=self.device, dtype=torch.long),
147
  }
148
 
149
  def _collate_graph_links(self,
 
713
  'wallet_encoder_inputs': wallet_encoder_inputs, # ADDED BACK
714
  'ohlc_price_tensors': ohlc_inputs_dict['price_tensor'],
715
  'ohlc_interval_ids': ohlc_inputs_dict['interval_ids'],
716
+ 'quant_ohlc_feature_tensors': ohlc_inputs_dict['quant_feature_tensors'],
717
+ 'quant_ohlc_feature_mask': ohlc_inputs_dict['quant_feature_mask'],
718
+ 'quant_ohlc_feature_version_ids': ohlc_inputs_dict['quant_feature_version_ids'],
719
  'graph_updater_links': graph_updater_links,
720
  'wallet_addr_to_batch_idx': wallet_addr_to_batch_idx, # NEW: Pass the mapping
721
 
data/data_loader.py CHANGED
@@ -18,6 +18,19 @@ import models.vocabulary as vocab
18
  from models.multi_modal_processor import MultiModalEncoder
19
  from data.data_fetcher import DataFetcher # NEW: Import the DataFetcher
20
  from data.context_targets import derive_movement_targets
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  from requests.adapters import HTTPAdapter
22
  from urllib3.util.retry import Retry
23
 
@@ -1278,14 +1291,186 @@ class OracleDataset(Dataset):
1278
 
1279
  return full_ohlc
1280
 
1281
- def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
1282
- """
1283
- Loads data from cache. Behavior depends on cache mode:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1284
 
1285
- - RAW MODE: Loads raw token data, samples T_cutoff at runtime, applies H/B/H
1286
- - CONTEXT MODE: Loads pre-computed training context directly (fully offline)
1287
 
1288
- The cache mode is auto-detected from the cached file's 'cache_mode' field.
 
 
1289
  """
1290
  import time as _time
1291
  _timings = {}
@@ -1309,23 +1494,17 @@ class OracleDataset(Dataset):
1309
  if not cached_data:
1310
  raise RuntimeError(f"No data loaded for index {idx}")
1311
 
1312
- # Auto-detect cache mode. New compact context cache may omit 'cache_mode'.
1313
- if 'cache_mode' in cached_data:
1314
- cache_mode = cached_data.get('cache_mode', 'raw')
1315
- else:
1316
- has_context_shape = (
1317
- isinstance(cached_data, dict) and
1318
- 'event_sequence' in cached_data and
1319
- 'tokens' in cached_data and
1320
- 'wallets' in cached_data and
1321
- 'labels' in cached_data and
1322
- 'labels_mask' in cached_data
1323
- )
1324
- cache_mode = 'context' if has_context_shape else 'raw'
1325
 
1326
- if cache_mode == 'context':
1327
- # CONTEXT MODE: Return pre-computed training context directly
1328
- # This is fully deterministic - no runtime sampling or processing
1329
  _timings['total'] = _time.perf_counter() - _total_start
1330
 
1331
  if 'movement_class_targets' not in cached_data and 'labels' in cached_data and 'labels_mask' in cached_data:
@@ -1346,318 +1525,16 @@ class OracleDataset(Dataset):
1346
  )
1347
 
1348
  if idx % 100 == 0:
1349
- print(f"[Sample {idx}] CONTEXT mode | cache_load: {_timings['cache_load']*1000:.1f}ms | "
1350
  f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}")
1351
 
1352
  return cached_data
1353
 
1354
- # RAW MODE: Fall through to original __getitem__ logic with runtime T_cutoff sampling
1355
- raw_data = cached_data
1356
-
1357
- required_keys = [
1358
- "mint_timestamp",
1359
- "max_limit_time",
1360
- "token_address",
1361
- "creator_address",
1362
- "trades",
1363
- "transfers",
1364
- "pool_creations",
1365
- "liquidity_changes",
1366
- "fee_collections",
1367
- "burns",
1368
- "supply_locks",
1369
- "migrations",
1370
- "quality_score"
1371
- ]
1372
- missing_keys = [key for key in required_keys if key not in raw_data]
1373
- if missing_keys:
1374
- raise RuntimeError(
1375
- f"Cached sample missing raw fields ({missing_keys}). Rebuild cache with raw caching enabled."
1376
- )
1377
-
1378
- # --- CHECK: Determine if we have new-style complete cache ---
1379
- has_complete_cache = 'cached_wallet_data' in raw_data and 'cached_graph_data' in raw_data
1380
-
1381
- # --- TIMING: T_cutoff sampling prep ---
1382
- _t0 = _time.perf_counter()
1383
-
1384
- def _timestamp_to_order_value(ts_value: Any) -> float:
1385
- if isinstance(ts_value, datetime.datetime):
1386
- if ts_value.tzinfo is None:
1387
- ts_value = ts_value.replace(tzinfo=datetime.timezone.utc)
1388
- return ts_value.timestamp()
1389
- try:
1390
- return float(ts_value)
1391
- except (TypeError, ValueError):
1392
- return 0.0
1393
-
1394
- # --- DYNAMIC SAMPLING LOGIC ---
1395
- mint_timestamp = raw_data['mint_timestamp']
1396
- if isinstance(mint_timestamp, datetime.datetime) and mint_timestamp.tzinfo is None:
1397
- mint_timestamp = mint_timestamp.replace(tzinfo=datetime.timezone.utc)
1398
-
1399
- min_window = 30 # seconds
1400
- horizons = sorted(self.horizons_seconds)
1401
- first_horizon = horizons[0] if horizons else 60
1402
- min_label = max(60, first_horizon)
1403
- preferred_horizon = horizons[1] if len(horizons) > 1 else min_label
1404
-
1405
- mint_ts_value = _timestamp_to_order_value(mint_timestamp)
1406
-
1407
- # ============================================================================
1408
- # T_CUTOFF SAMPLING: Index-based with Successful Trade Guarantee
1409
- # ============================================================================
1410
- # 1. Use ALL trades (sorted by timestamp) for context
1411
- # 2. Find indices of SUCCESSFUL trades (needed for label computation)
1412
- # 3. Sample interval: [min_context_trades-1, last_successful_idx - 1]
1413
- # 4. This guarantees: N trades for context, 1+ successful trade for labels
1414
- # ============================================================================
1415
-
1416
- all_trades_raw = raw_data.get('trades', [])
1417
- if not all_trades_raw:
1418
- return None
1419
-
1420
- # Sort ALL trades by timestamp
1421
- all_trades_sorted = sorted(
1422
- [t for t in all_trades_raw if t.get('timestamp') is not None],
1423
- key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
1424
  )
1425
 
1426
- min_context_trades = self.min_trades
1427
- if len(all_trades_sorted) < (min_context_trades + 1): # context + 1 trade after cutoff
1428
- return None
1429
-
1430
- # Find indices of SUCCESSFUL trades (valid for label computation)
1431
- successful_indices = [
1432
- i for i, t in enumerate(all_trades_sorted)
1433
- if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0
1434
- ]
1435
-
1436
- if len(successful_indices) < 2: # Need at least 2 successful: anchor + future
1437
- return None
1438
-
1439
- max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
1440
- # Define sampling interval
1441
- min_idx = min_context_trades - 1 # At least N trades for context
1442
- max_idx = len(all_trades_sorted) - 2 # Need at least 1 trade after cutoff
1443
-
1444
- if max_idx < min_idx:
1445
- return None
1446
-
1447
- # Precompute last successful index <= i and next successful index > i
1448
- last_successful_before = [-1] * len(all_trades_sorted)
1449
- last_seen = -1
1450
- succ_set = set(successful_indices)
1451
- for i in range(len(all_trades_sorted)):
1452
- if i in succ_set:
1453
- last_seen = i
1454
- last_successful_before[i] = last_seen
1455
-
1456
- next_successful_after = [-1] * len(all_trades_sorted)
1457
- next_seen = -1
1458
- for i in range(len(all_trades_sorted) - 1, -1, -1):
1459
- if i in succ_set:
1460
- next_seen = i
1461
- next_successful_after[i] = next_seen
1462
-
1463
- # Build eligible indices that guarantee:
1464
- # 1) anchor successful trade at or before cutoff
1465
- # 2) successful trade within max_horizon_seconds after cutoff
1466
- eligible_indices = []
1467
- for i in range(min_idx, max_idx + 1):
1468
- anchor_idx = last_successful_before[i]
1469
- next_idx = next_successful_after[i + 1] if i + 1 < len(all_trades_sorted) else -1
1470
- if anchor_idx < 0 or next_idx < 0:
1471
- continue
1472
- cutoff_ts = _timestamp_to_order_value(all_trades_sorted[i].get('timestamp'))
1473
- next_ts = _timestamp_to_order_value(all_trades_sorted[next_idx].get('timestamp'))
1474
- if next_ts <= cutoff_ts + max_horizon_seconds:
1475
- eligible_indices.append(i)
1476
-
1477
- if not eligible_indices:
1478
- return None
1479
-
1480
- # Sample random index within eligible interval
1481
- sample_idx = random.choice(eligible_indices)
1482
-
1483
- # T_cutoff = timestamp of the sampled trade
1484
- sample_trade = all_trades_sorted[sample_idx]
1485
- sample_offset_ts = _timestamp_to_order_value(sample_trade.get('timestamp'))
1486
- T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc)
1487
- _timings['t_cutoff_sampling'] = _time.perf_counter() - _t0
1488
-
1489
- # --- TIMING: Wallet collection ---
1490
- _t0 = _time.perf_counter()
1491
- token_address = raw_data['token_address']
1492
- creator_address = raw_data['creator_address']
1493
- cutoff_ts = _timestamp_to_order_value(T_cutoff)
1494
-
1495
- def _add_wallet(addr: Optional[str], wallet_set: set):
1496
- if addr:
1497
- wallet_set.add(addr)
1498
-
1499
- wallets_to_fetch = set()
1500
- _add_wallet(creator_address, wallets_to_fetch)
1501
-
1502
- for trade in raw_data.get('trades', []):
1503
- if _timestamp_to_order_value(trade.get('timestamp')) <= cutoff_ts:
1504
- _add_wallet(trade.get('maker'), wallets_to_fetch)
1505
-
1506
- for transfer in raw_data.get('transfers', []):
1507
- if _timestamp_to_order_value(transfer.get('timestamp')) <= cutoff_ts:
1508
- _add_wallet(transfer.get('source'), wallets_to_fetch)
1509
- _add_wallet(transfer.get('destination'), wallets_to_fetch)
1510
-
1511
- for pool in raw_data.get('pool_creations', []):
1512
- if _timestamp_to_order_value(pool.get('timestamp')) <= cutoff_ts:
1513
- _add_wallet(pool.get('creator_address'), wallets_to_fetch)
1514
-
1515
- for liq in raw_data.get('liquidity_changes', []):
1516
- if _timestamp_to_order_value(liq.get('timestamp')) <= cutoff_ts:
1517
- _add_wallet(liq.get('lp_provider'), wallets_to_fetch)
1518
-
1519
- # Offline Holder Lookup using raw_data['holder_snapshots_list']
1520
- # We need the snapshot corresponding to T_cutoff.
1521
- # Intervals are every 300s from mint_ts.
1522
- # idx = (T_cutoff - mint) // 300
1523
- elapsed = (T_cutoff - mint_timestamp).total_seconds()
1524
- snap_idx = int(elapsed // 300)
1525
- holder_records = []
1526
- cached_holders_list = raw_data.get('holder_snapshots_list')
1527
- if not isinstance(cached_holders_list, list):
1528
- raise RuntimeError("Invalid cache: holder_snapshots_list must be a list.")
1529
- if not (0 <= snap_idx < len(cached_holders_list)):
1530
- raise RuntimeError(
1531
- f"Invalid cache: holder_snapshots_list index out of range (snap_idx={snap_idx}, len={len(cached_holders_list)})."
1532
- )
1533
- snapshot_data = cached_holders_list[snap_idx]
1534
- if not isinstance(snapshot_data, dict) or not isinstance(snapshot_data.get('holders'), list):
1535
- raise RuntimeError("Invalid cache: holder snapshot entry must be a dict with list field 'holders'.")
1536
- holder_records = snapshot_data['holders']
1537
- for holder in holder_records:
1538
- if not isinstance(holder, dict) or 'wallet_address' not in holder or 'current_balance' not in holder:
1539
- raise RuntimeError("Invalid cache: each holder record must include wallet_address and current_balance.")
1540
- _add_wallet(holder['wallet_address'], wallets_to_fetch)
1541
- _timings['wallet_collection'] = _time.perf_counter() - _t0
1542
- _timings['num_wallets'] = len(wallets_to_fetch)
1543
-
1544
- pooler = EmbeddingPooler()
1545
-
1546
- # --- TIMING: Token data (OFFLINE - uses cached image bytes) ---
1547
- _t0 = _time.perf_counter()
1548
-
1549
- # Build minimal main token metadata from cache (no HTTP calls)
1550
- offline_token_data = {token_address: self._build_main_token_seed(token_address, raw_data)}
1551
-
1552
- # If we have cached image bytes, convert to PIL Image for the pooler
1553
- cached_image_bytes = raw_data.get('cached_image_bytes')
1554
- if cached_image_bytes:
1555
- try:
1556
- cached_image = Image.open(BytesIO(cached_image_bytes))
1557
- offline_token_data[token_address]['_cached_image_pil'] = cached_image
1558
- except Exception as e:
1559
- pass # Image decoding failed, will use None
1560
-
1561
- main_token_data = self._process_token_data_offline(
1562
- [token_address], pooler, T_cutoff, token_data=offline_token_data
1563
- )
1564
- _timings['token_data_and_images'] = _time.perf_counter() - _t0
1565
-
1566
- if not main_token_data:
1567
- return None
1568
-
1569
- # --- TIMING: Wallet data (OFFLINE - uses pre-cached profiles/socials/holdings) ---
1570
- _t0 = _time.perf_counter()
1571
-
1572
- if has_complete_cache:
1573
- # Use new complete cache format
1574
- cached_wallet_bundle = raw_data.get('cached_wallet_data', {})
1575
- offline_profiles = cached_wallet_bundle.get('profiles', {})
1576
- offline_socials = cached_wallet_bundle.get('socials', {})
1577
- offline_holdings = cached_wallet_bundle.get('holdings', {})
1578
- else:
1579
- # Fallback to old partial cache format
1580
- cached_social_bundle = raw_data.get('socials', {})
1581
- offline_profiles = cached_social_bundle.get('profiles', {})
1582
- offline_socials = cached_social_bundle.get('socials', {})
1583
- offline_holdings = {}
1584
-
1585
- wallet_data, all_token_data = self._process_wallet_data(
1586
- list(wallets_to_fetch),
1587
- main_token_data.copy(),
1588
- pooler,
1589
- T_cutoff,
1590
- profiles_override=offline_profiles,
1591
- socials_override=offline_socials,
1592
- holdings_override=offline_holdings
1593
- )
1594
- _timings['wallet_data'] = _time.perf_counter() - _t0
1595
- _timings['num_tokens_in_holdings'] = len(all_token_data) - 1 if all_token_data else 0
1596
-
1597
- # --- TIMING: Graph links (OFFLINE - uses pre-cached graph data) ---
1598
- _t0 = _time.perf_counter()
1599
-
1600
- if has_complete_cache:
1601
- # Use new complete cache format
1602
- cached_graph_bundle = raw_data.get('cached_graph_data', {})
1603
- graph_entities = cached_graph_bundle.get('entities', {})
1604
- graph_links = cached_graph_bundle.get('links', {})
1605
- else:
1606
- # No graph data in old cache format
1607
- graph_entities = {}
1608
- graph_links = {}
1609
-
1610
- _timings['graph_links'] = _time.perf_counter() - _t0
1611
- _timings['num_graph_entities'] = len(graph_entities)
1612
-
1613
- # --- TIMING: Generate dataset item ---
1614
- _t0 = _time.perf_counter()
1615
- # Generate the item
1616
- result = self._generate_dataset_item(
1617
- token_address=token_address,
1618
- t0=mint_timestamp,
1619
- T_cutoff=T_cutoff,
1620
- mint_event={ # Reconstruct simplified mint event
1621
- 'event_type': 'Mint',
1622
- 'timestamp': int(mint_timestamp.timestamp()),
1623
- 'relative_ts': 0,
1624
- 'wallet_address': creator_address,
1625
- 'token_address': token_address,
1626
- 'protocol_id': raw_data.get('protocol_id', 0)
1627
- },
1628
- trade_records=raw_data['trades'],
1629
- transfer_records=raw_data['transfers'],
1630
- pool_creation_records=raw_data['pool_creations'],
1631
- liquidity_change_records=raw_data['liquidity_changes'],
1632
- fee_collection_records=raw_data['fee_collections'],
1633
- burn_records=raw_data['burns'],
1634
- supply_lock_records=raw_data['supply_locks'],
1635
- migration_records=raw_data['migrations'],
1636
- wallet_data=wallet_data,
1637
- all_token_data=all_token_data,
1638
- graph_links=graph_links,
1639
- graph_seed_entities=wallets_to_fetch,
1640
- all_graph_entities=graph_entities,
1641
- future_trades_for_labels=raw_data['trades'], # We utilize full trade history for labels!
1642
- pooler=pooler,
1643
- sample_idx=idx,
1644
- cached_holders_list=raw_data.get('holder_snapshots_list'),
1645
- cached_ohlc_1s=raw_data.get('ohlc_1s'),
1646
- quality_score=raw_data.get('quality_score')
1647
- )
1648
- _timings['generate_item'] = _time.perf_counter() - _t0
1649
-
1650
- # --- TIMING: Total and summary ---
1651
- _timings['total'] = _time.perf_counter() - _total_start
1652
-
1653
- # Only print timing summary occasionally to reduce log spam
1654
- if idx % 100 == 0:
1655
- print(f"[Sample {idx}] OFFLINE mode | cache_load: {_timings['cache_load']*1000:.1f}ms | "
1656
- f"total: {_timings['total']*1000:.1f}ms | wallets: {_timings['num_wallets']} | "
1657
- f"graph_entities: {_timings['num_graph_entities']}")
1658
-
1659
- return result
1660
-
1661
  def _process_token_data_offline(self, token_addresses: List[str], pooler: EmbeddingPooler,
1662
  T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Dict[str, Dict[str, Any]]:
1663
  """
@@ -2301,7 +2178,9 @@ class OracleDataset(Dataset):
2301
  'relative_ts': int(last_ts) - int(t0_timestamp),
2302
  'opens': self._normalize_price_series(opens_raw),
2303
  'closes': self._normalize_price_series(closes_raw),
2304
- 'i': interval_label
 
 
2305
  }
2306
  emitted_events.append(chart_event)
2307
  return emitted_events
@@ -2601,22 +2480,22 @@ class OracleDataset(Dataset):
2601
 
2602
  if not all_trades:
2603
  # No valid trades for label computation
2604
- movement_targets = derive_movement_targets(
2605
- [0.0] * len(self.horizons_seconds),
2606
- [0.0] * len(self.horizons_seconds),
2607
- movement_label_config=self.movement_label_config,
2608
- )
2609
  return {
2610
  'event_sequence': event_sequence,
2611
  'wallets': wallet_data,
2612
  'tokens': all_token_data,
2613
  'graph_links': graph_links,
2614
  'embedding_pooler': pooler,
 
 
2615
  'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
2616
  'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
2617
  'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
2618
- 'movement_class_targets': torch.tensor(movement_targets['movement_class_targets'], dtype=torch.long),
2619
- 'movement_class_mask': torch.tensor(movement_targets['movement_class_mask'], dtype=torch.long),
2620
  }
2621
 
2622
  # Ensure sorted
@@ -2696,12 +2575,11 @@ class OracleDataset(Dataset):
2696
 
2697
  # DEBUG: Mask summaries removed after validation
2698
 
2699
- movement_targets = derive_movement_targets(
2700
- label_values,
2701
- mask_values,
2702
- movement_label_config=self.movement_label_config,
2703
- )
2704
-
2705
  return {
2706
  'sample_idx': sample_idx if sample_idx is not None else -1, # Debug trace
2707
  'token_address': token_address, # For debugging
@@ -2711,11 +2589,11 @@ class OracleDataset(Dataset):
2711
  'tokens': all_token_data,
2712
  'graph_links': graph_links,
2713
  'embedding_pooler': pooler,
 
 
2714
  'labels': torch.tensor(label_values, dtype=torch.float32),
2715
  'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
2716
  'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
2717
- 'movement_class_targets': torch.tensor(movement_targets['movement_class_targets'], dtype=torch.long),
2718
- 'movement_class_mask': torch.tensor(movement_targets['movement_class_mask'], dtype=torch.long),
2719
  }
2720
 
2721
  def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None:
 
18
  from models.multi_modal_processor import MultiModalEncoder
19
  from data.data_fetcher import DataFetcher # NEW: Import the DataFetcher
20
  from data.context_targets import derive_movement_targets
21
+ from data.quant_ohlc_feature_schema import (
22
+ FEATURE_VERSION,
23
+ FEATURE_VERSION_ID,
24
+ LOOKBACK_SECONDS,
25
+ TOKENS_PER_SEGMENT,
26
+ WINDOW_SECONDS,
27
+ empty_feature_dict,
28
+ feature_dict_to_vector,
29
+ )
30
+ from signals.patterns import compute_pattern_features
31
+ from signals.rolling_quant import compute_rolling_quant_features
32
+ from signals.support_resistance import compute_support_resistance_features
33
+ from signals.trendlines import compute_trendline_features
34
  from requests.adapters import HTTPAdapter
35
  from urllib3.util.retry import Retry
36
 
 
1291
 
1292
  return full_ohlc
1293
 
1294
+ def _compute_quant_rolling_features(
1295
+ self,
1296
+ closes: List[float],
1297
+ end_idx: int,
1298
+ ) -> Dict[str, float]:
1299
+ return compute_rolling_quant_features(closes, end_idx)
1300
+
1301
+ def _compute_support_resistance_features(
1302
+ self,
1303
+ closes: List[float],
1304
+ highs: List[float],
1305
+ lows: List[float],
1306
+ end_idx: int,
1307
+ window_start: int,
1308
+ window_end: int,
1309
+ timestamps: List[int],
1310
+ ) -> Dict[str, float]:
1311
+ return compute_support_resistance_features(
1312
+ closes=closes,
1313
+ highs=highs,
1314
+ lows=lows,
1315
+ end_idx=end_idx,
1316
+ window_start=window_start,
1317
+ window_end=window_end,
1318
+ timestamps=timestamps,
1319
+ )
1320
+
1321
+ def _compute_trendline_features(
1322
+ self,
1323
+ closes: List[float],
1324
+ highs: List[float],
1325
+ lows: List[float],
1326
+ end_idx: int,
1327
+ ) -> Dict[str, float]:
1328
+ return compute_trendline_features(closes, highs, lows, end_idx)
1329
+
1330
+ def _compute_optional_pattern_flags(
1331
+ self,
1332
+ closes: List[float],
1333
+ highs: List[float],
1334
+ lows: List[float],
1335
+ end_idx: int,
1336
+ ) -> Dict[str, float]:
1337
+ return compute_pattern_features(closes, highs, lows, end_idx)
1338
+
1339
+ def _extract_quant_ohlc_features_for_segment(
1340
+ self,
1341
+ segment: List[tuple],
1342
+ interval_label: str,
1343
+ ) -> List[Dict[str, Any]]:
1344
+ del interval_label
1345
+ if not segment:
1346
+ return []
1347
+
1348
+ timestamps = [int(row[0]) for row in segment]
1349
+ opens = [float(row[1]) for row in segment]
1350
+ closes = [float(row[2]) for row in segment]
1351
+ highs = [max(o, c) for o, c in zip(opens, closes)]
1352
+ lows = [min(o, c) for o, c in zip(opens, closes)]
1353
+ log_closes = np.log(np.clip(np.asarray(closes, dtype=np.float64), 1e-8, None))
1354
+ one_sec_returns = np.diff(log_closes)
1355
+ feature_windows: List[Dict[str, Any]] = []
1356
+
1357
+ for window_idx in range(TOKENS_PER_SEGMENT):
1358
+ window_start = window_idx * WINDOW_SECONDS
1359
+ if window_start >= len(segment):
1360
+ break
1361
+ window_end = min(len(segment), window_start + WINDOW_SECONDS)
1362
+ current_end_idx = window_end - 1
1363
+ window_returns = one_sec_returns[window_start:max(window_start, current_end_idx)]
1364
+ window_closes = closes[window_start:window_end]
1365
+ window_highs = highs[window_start:window_end]
1366
+ window_lows = lows[window_start:window_end]
1367
+ features = empty_feature_dict()
1368
+
1369
+ if window_closes:
1370
+ window_close_arr = np.asarray(window_closes, dtype=np.float64)
1371
+ window_return_sum = float(np.sum(window_returns)) if window_returns.size > 0 else 0.0
1372
+ range_width = max(max(window_highs) - min(window_lows), 0.0)
1373
+ first_close = float(window_close_arr[0])
1374
+ last_close = float(window_close_arr[-1])
1375
+ accel_proxy = 0.0
1376
+ if window_returns.size >= 2:
1377
+ accel_proxy = float(window_returns[-1] - window_returns[0])
1378
+ features.update({
1379
+ "cum_log_return": window_return_sum,
1380
+ "mean_log_return_1s": float(np.mean(window_returns)) if window_returns.size > 0 else 0.0,
1381
+ "std_log_return_1s": float(np.std(window_returns)) if window_returns.size > 0 else 0.0,
1382
+ "max_up_1s": float(np.max(window_returns)) if window_returns.size > 0 else 0.0,
1383
+ "max_down_1s": float(np.min(window_returns)) if window_returns.size > 0 else 0.0,
1384
+ "realized_vol": float(np.sqrt(np.sum(np.square(window_returns)))) if window_returns.size > 0 else 0.0,
1385
+ "window_range_frac": range_width / max(abs(last_close), 1e-8),
1386
+ "close_to_close_slope": (last_close - first_close) / max(abs(first_close), 1e-8),
1387
+ "accel_proxy": accel_proxy,
1388
+ "frac_pos_1s": float(np.mean(window_returns > 0)) if window_returns.size > 0 else 0.0,
1389
+ "frac_neg_1s": float(np.mean(window_returns < 0)) if window_returns.size > 0 else 0.0,
1390
+ })
1391
+
1392
+ current_price = closes[current_end_idx]
1393
+ current_high = highs[current_end_idx]
1394
+ current_low = lows[current_end_idx]
1395
+ for lookback in LOOKBACK_SECONDS:
1396
+ prefix = f"lb_{lookback}s"
1397
+ lookback_start = max(0, current_end_idx - lookback + 1)
1398
+ hist_closes = closes[lookback_start: current_end_idx + 1]
1399
+ hist_highs = highs[lookback_start: current_end_idx + 1]
1400
+ hist_lows = lows[lookback_start: current_end_idx + 1]
1401
+ hist_range = max(max(hist_highs) - min(hist_lows), 1e-8)
1402
+ rolling_high = max(hist_highs)
1403
+ rolling_low = min(hist_lows)
1404
+ hist_returns = np.diff(np.log(np.clip(np.asarray(hist_closes, dtype=np.float64), 1e-8, None)))
1405
+ current_width = max(max(window_highs) - min(window_lows), 0.0)
1406
+ prev_hist_width = max(max(hist_highs[:-len(window_highs)]) - min(hist_lows[:-len(window_lows)]), 0.0) if len(hist_highs) > len(window_highs) else current_width
1407
+ prev_close = closes[current_end_idx - 1] if current_end_idx > 0 else current_price
1408
+
1409
+ features.update({
1410
+ f"{prefix}_dist_high": (rolling_high - current_price) / max(abs(current_price), 1e-8),
1411
+ f"{prefix}_dist_low": (current_price - rolling_low) / max(abs(current_price), 1e-8),
1412
+ f"{prefix}_drawdown_high": (current_price - rolling_high) / max(abs(rolling_high), 1e-8),
1413
+ f"{prefix}_rebound_low": (current_price - rolling_low) / max(abs(rolling_low), 1e-8),
1414
+ f"{prefix}_pos_in_range": (current_price - rolling_low) / hist_range,
1415
+ f"{prefix}_range_width": hist_range / max(abs(current_price), 1e-8),
1416
+ f"{prefix}_compression_ratio": current_width / max(prev_hist_width, 1e-8),
1417
+ f"{prefix}_breakout_high": 1.0 if current_high > rolling_high and prev_close <= rolling_high else 0.0,
1418
+ f"{prefix}_breakdown_low": 1.0 if current_low < rolling_low and prev_close >= rolling_low else 0.0,
1419
+ f"{prefix}_reclaim_breakdown": 1.0 if current_low < rolling_low and current_price >= rolling_low else 0.0,
1420
+ f"{prefix}_rejection_breakout": 1.0 if current_high > rolling_high and current_price <= rolling_high else 0.0,
1421
+ })
1422
+
1423
+ features.update(self._compute_support_resistance_features(
1424
+ closes=closes,
1425
+ highs=highs,
1426
+ lows=lows,
1427
+ end_idx=current_end_idx,
1428
+ window_start=window_start,
1429
+ window_end=window_end,
1430
+ timestamps=timestamps,
1431
+ ))
1432
+ features.update(self._compute_trendline_features(
1433
+ closes=closes,
1434
+ highs=highs,
1435
+ lows=lows,
1436
+ end_idx=current_end_idx,
1437
+ ))
1438
+ features.update(self._compute_quant_rolling_features(
1439
+ closes=closes,
1440
+ end_idx=current_end_idx,
1441
+ ))
1442
+ features.update(self._compute_optional_pattern_flags(
1443
+ closes=closes,
1444
+ highs=highs,
1445
+ lows=lows,
1446
+ end_idx=current_end_idx,
1447
+ ))
1448
+
1449
+ feature_windows.append({
1450
+ "start_ts": timestamps[window_start],
1451
+ "end_ts": timestamps[current_end_idx],
1452
+ "window_seconds": WINDOW_SECONDS,
1453
+ "feature_vector": feature_dict_to_vector(features),
1454
+ "feature_names_version": FEATURE_VERSION,
1455
+ "feature_version_id": FEATURE_VERSION_ID,
1456
+ "level_snapshot": {
1457
+ "support_distance": features.get("nearest_support_dist", 0.0),
1458
+ "resistance_distance": features.get("nearest_resistance_dist", 0.0),
1459
+ "support_strength": features.get("support_strength", 0.0),
1460
+ "resistance_strength": features.get("resistance_strength", 0.0),
1461
+ },
1462
+ "pattern_flags": {
1463
+ key.replace("pattern_", "").replace("_confidence", ""): features[key]
1464
+ for key in features.keys()
1465
+ if key.startswith("pattern_") and key.endswith("_confidence")
1466
+ },
1467
+ })
1468
 
1469
+ return feature_windows
 
1470
 
1471
+ def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
1472
+ """
1473
+ Loads data from cache.
1474
  """
1475
  import time as _time
1476
  _timings = {}
 
1494
  if not cached_data:
1495
  raise RuntimeError(f"No data loaded for index {idx}")
1496
 
1497
+ has_context_shape = (
1498
+ isinstance(cached_data, dict) and
1499
+ 'event_sequence' in cached_data and
1500
+ 'tokens' in cached_data and
1501
+ 'wallets' in cached_data and
1502
+ 'labels' in cached_data and
1503
+ 'labels_mask' in cached_data
1504
+ )
 
 
 
 
 
1505
 
1506
+ if has_context_shape:
1507
+ # Return pre-computed training context directly.
 
1508
  _timings['total'] = _time.perf_counter() - _total_start
1509
 
1510
  if 'movement_class_targets' not in cached_data and 'labels' in cached_data and 'labels_mask' in cached_data:
 
1525
  )
1526
 
1527
  if idx % 100 == 0:
1528
+ print(f"[Sample {idx}] context cache | cache_load: {_timings['cache_load']*1000:.1f}ms | "
1529
  f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}")
1530
 
1531
  return cached_data
1532
 
1533
+ raise RuntimeError(
1534
+ f"Cached item at {filepath} is not a valid context cache. "
1535
+ "Rebuild the cache with scripts/cache_dataset.py."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1536
  )
1537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1538
  def _process_token_data_offline(self, token_addresses: List[str], pooler: EmbeddingPooler,
1539
  T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Dict[str, Dict[str, Any]]:
1540
  """
 
2178
  'relative_ts': int(last_ts) - int(t0_timestamp),
2179
  'opens': self._normalize_price_series(opens_raw),
2180
  'closes': self._normalize_price_series(closes_raw),
2181
+ 'i': interval_label,
2182
+ 'quant_ohlc_features': self._extract_quant_ohlc_features_for_segment(segment, interval_label) if interval_label == "1s" else [],
2183
+ 'quant_feature_version': FEATURE_VERSION,
2184
  }
2185
  emitted_events.append(chart_event)
2186
  return emitted_events
 
2480
 
2481
  if not all_trades:
2482
  # No valid trades for label computation
2483
+ quant_payload = [
2484
+ event.get('quant_ohlc_features', [])
2485
+ for event in event_sequence
2486
+ if event.get('event_type') == 'Chart_Segment'
2487
+ ]
2488
  return {
2489
  'event_sequence': event_sequence,
2490
  'wallets': wallet_data,
2491
  'tokens': all_token_data,
2492
  'graph_links': graph_links,
2493
  'embedding_pooler': pooler,
2494
+ 'quant_ohlc_features': quant_payload,
2495
+ 'quant_feature_version': FEATURE_VERSION,
2496
  'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
2497
  'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
2498
  'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
 
 
2499
  }
2500
 
2501
  # Ensure sorted
 
2575
 
2576
  # DEBUG: Mask summaries removed after validation
2577
 
2578
+ quant_payload = [
2579
+ event.get('quant_ohlc_features', [])
2580
+ for event in event_sequence
2581
+ if event.get('event_type') == 'Chart_Segment'
2582
+ ]
 
2583
  return {
2584
  'sample_idx': sample_idx if sample_idx is not None else -1, # Debug trace
2585
  'token_address': token_address, # For debugging
 
2589
  'tokens': all_token_data,
2590
  'graph_links': graph_links,
2591
  'embedding_pooler': pooler,
2592
+ 'quant_ohlc_features': quant_payload,
2593
+ 'quant_feature_version': FEATURE_VERSION,
2594
  'labels': torch.tensor(label_values, dtype=torch.float32),
2595
  'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
2596
  'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
 
 
2597
  }
2598
 
2599
  def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None:
data/quant_ohlc_feature_schema.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Dict, Iterable, List
3
+
4
+
5
+ FEATURE_VERSION = "qohlc_v1"
6
+ FEATURE_VERSION_ID = 1
7
+ WINDOW_SECONDS = 5
8
+ SEGMENT_SECONDS = 300
9
+ TOKENS_PER_SEGMENT = SEGMENT_SECONDS // WINDOW_SECONDS
10
+ PATTERN_NAMES = [
11
+ "double_top",
12
+ "double_bottom",
13
+ "ascending_triangle",
14
+ "descending_triangle",
15
+ "head_shoulders",
16
+ "inverse_head_shoulders",
17
+ ]
18
+ LOOKBACK_SECONDS = [15, 30, 60, 120]
19
+
20
+
21
+ FEATURE_NAMES: List[str] = [
22
+ "cum_log_return",
23
+ "mean_log_return_1s",
24
+ "std_log_return_1s",
25
+ "max_up_1s",
26
+ "max_down_1s",
27
+ "realized_vol",
28
+ "window_range_frac",
29
+ "close_to_close_slope",
30
+ "accel_proxy",
31
+ "frac_pos_1s",
32
+ "frac_neg_1s",
33
+ ]
34
+
35
+ for lookback in LOOKBACK_SECONDS:
36
+ prefix = f"lb_{lookback}s"
37
+ FEATURE_NAMES.extend([
38
+ f"{prefix}_dist_high",
39
+ f"{prefix}_dist_low",
40
+ f"{prefix}_drawdown_high",
41
+ f"{prefix}_rebound_low",
42
+ f"{prefix}_pos_in_range",
43
+ f"{prefix}_range_width",
44
+ f"{prefix}_compression_ratio",
45
+ f"{prefix}_breakout_high",
46
+ f"{prefix}_breakdown_low",
47
+ f"{prefix}_reclaim_breakdown",
48
+ f"{prefix}_rejection_breakout",
49
+ ])
50
+
51
+ FEATURE_NAMES.extend([
52
+ "nearest_support_dist",
53
+ "nearest_resistance_dist",
54
+ "support_touch_count",
55
+ "resistance_touch_count",
56
+ "support_age_sec",
57
+ "resistance_age_sec",
58
+ "support_strength",
59
+ "resistance_strength",
60
+ "inside_support_zone",
61
+ "inside_resistance_zone",
62
+ "support_swept",
63
+ "resistance_swept",
64
+ "support_reclaim",
65
+ "resistance_reject",
66
+ "lower_trendline_slope",
67
+ "upper_trendline_slope",
68
+ "dist_to_lower_line",
69
+ "dist_to_upper_line",
70
+ "trend_channel_width",
71
+ "trend_convergence",
72
+ "trend_breakout_upper",
73
+ "trend_breakdown_lower",
74
+ "trend_reentry",
75
+ "ema_fast",
76
+ "ema_medium",
77
+ "sma_fast",
78
+ "sma_medium",
79
+ "price_minus_ema_fast",
80
+ "price_minus_ema_medium",
81
+ "ema_spread",
82
+ "price_zscore",
83
+ "mean_reversion_score",
84
+ "rolling_vol_zscore",
85
+ ])
86
+
87
+ for pattern_name in PATTERN_NAMES:
88
+ FEATURE_NAMES.append(f"pattern_{pattern_name}_confidence")
89
+
90
+ FEATURE_NAMES.extend([
91
+ "sr_available",
92
+ "trendline_available",
93
+ "pattern_available",
94
+ ])
95
+
96
+ FEATURE_INDEX = {name: idx for idx, name in enumerate(FEATURE_NAMES)}
97
+ NUM_QUANT_OHLC_FEATURES = len(FEATURE_NAMES)
98
+
99
+ FEATURE_GROUPS = OrderedDict([
100
+ ("price_path", [
101
+ "cum_log_return",
102
+ "mean_log_return_1s",
103
+ "std_log_return_1s",
104
+ "max_up_1s",
105
+ "max_down_1s",
106
+ "realized_vol",
107
+ "window_range_frac",
108
+ "close_to_close_slope",
109
+ "accel_proxy",
110
+ "frac_pos_1s",
111
+ "frac_neg_1s",
112
+ ]),
113
+ ("relative_structure", [name for name in FEATURE_NAMES if name.startswith("lb_")]),
114
+ ("levels_breaks", [
115
+ "nearest_support_dist",
116
+ "nearest_resistance_dist",
117
+ "support_touch_count",
118
+ "resistance_touch_count",
119
+ "support_age_sec",
120
+ "resistance_age_sec",
121
+ "support_strength",
122
+ "resistance_strength",
123
+ "inside_support_zone",
124
+ "inside_resistance_zone",
125
+ "support_swept",
126
+ "resistance_swept",
127
+ "support_reclaim",
128
+ "resistance_reject",
129
+ ]),
130
+ ("trendlines", [
131
+ "lower_trendline_slope",
132
+ "upper_trendline_slope",
133
+ "dist_to_lower_line",
134
+ "dist_to_upper_line",
135
+ "trend_channel_width",
136
+ "trend_convergence",
137
+ "trend_breakout_upper",
138
+ "trend_breakdown_lower",
139
+ "trend_reentry",
140
+ ]),
141
+ ("rolling_quant", [
142
+ "ema_fast",
143
+ "ema_medium",
144
+ "sma_fast",
145
+ "sma_medium",
146
+ "price_minus_ema_fast",
147
+ "price_minus_ema_medium",
148
+ "ema_spread",
149
+ "price_zscore",
150
+ "mean_reversion_score",
151
+ "rolling_vol_zscore",
152
+ ]),
153
+ ("patterns", [name for name in FEATURE_NAMES if name.startswith("pattern_")]),
154
+ ("availability", [
155
+ "sr_available",
156
+ "trendline_available",
157
+ "pattern_available",
158
+ ]),
159
+ ])
160
+
161
+
162
def empty_feature_dict() -> Dict[str, float]:
    """Return a fresh mapping of every schema feature name to 0.0."""
    return dict.fromkeys(FEATURE_NAMES, 0.0)
164
+
165
+
166
def feature_dict_to_vector(features: Dict[str, float]) -> List[float]:
    """Pack a feature dict into a list ordered by ``FEATURE_NAMES``.

    Missing keys, and values that cannot be coerced with ``float()`` (e.g.
    ``None`` or a non-numeric string), are emitted as ``0.0`` so the result
    always has exactly ``NUM_QUANT_OHLC_FEATURES`` entries.
    """
    out: List[float] = []
    for name in FEATURE_NAMES:
        value = features.get(name, 0.0)
        try:
            out.append(float(value))
        except (TypeError, ValueError):
            # float() raises exactly these for bad payloads; narrowed from a
            # bare `except Exception` that could mask unrelated bugs.
            out.append(0.0)
    return out
175
+
176
+
177
def group_feature_indices(group_names: Iterable[str]) -> List[int]:
    """Return the sorted, de-duplicated vector indices for the given groups.

    Raises ``KeyError`` when a group name is not present in ``FEATURE_GROUPS``.
    """
    collected = {
        FEATURE_INDEX[feature_name]
        for group_name in group_names
        for feature_name in FEATURE_GROUPS[group_name]
    }
    return sorted(collected)
inference.py CHANGED
@@ -15,7 +15,9 @@ from models.token_encoder import TokenEncoder
15
  from models.wallet_encoder import WalletEncoder
16
  from models.graph_updater import GraphUpdater
17
  from models.ohlc_embedder import OHLCEmbedder
 
18
  import models.vocabulary as vocab
 
19
 
20
  # --- NEW: Import database clients ---
21
  from clickhouse_driver import Client as ClickHouseClient
@@ -56,7 +58,11 @@ if __name__ == "__main__":
56
 
57
  real_ohlc_emb = OHLCEmbedder(
58
  num_intervals=vocab.NUM_OHLC_INTERVALS,
59
- sequence_length=OHLC_SEQ_LEN,
 
 
 
 
60
  dtype=dtype
61
  )
62
 
@@ -96,7 +102,8 @@ if __name__ == "__main__":
96
  quantiles=_test_quantiles,
97
  horizons_seconds=_test_horizons,
98
  dtype=dtype,
99
- ohlc_embedder=real_ohlc_emb
 
100
  ).to(device)
101
  model.eval()
102
  print(f"Oracle d_model: {model.d_model}")
 
15
  from models.wallet_encoder import WalletEncoder
16
  from models.graph_updater import GraphUpdater
17
  from models.ohlc_embedder import OHLCEmbedder
18
+ from models.quant_ohlc_embedder import QuantOHLCEmbedder
19
  import models.vocabulary as vocab
20
+ from data.quant_ohlc_feature_schema import NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
21
 
22
  # --- NEW: Import database clients ---
23
  from clickhouse_driver import Client as ClickHouseClient
 
58
 
59
  real_ohlc_emb = OHLCEmbedder(
60
  num_intervals=vocab.NUM_OHLC_INTERVALS,
61
+ dtype=dtype
62
+ )
63
+ real_quant_ohlc_emb = QuantOHLCEmbedder(
64
+ num_features=NUM_QUANT_OHLC_FEATURES,
65
+ sequence_length=TOKENS_PER_SEGMENT,
66
  dtype=dtype
67
  )
68
 
 
102
  quantiles=_test_quantiles,
103
  horizons_seconds=_test_horizons,
104
  dtype=dtype,
105
+ ohlc_embedder=real_ohlc_emb,
106
+ quant_ohlc_embedder=real_quant_ohlc_emb
107
  ).to(device)
108
  model.eval()
109
  print(f"Oracle d_model: {model.d_model}")
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:df78e2b44dd97a148be762f91e3b00f397651f8e7e43ee21f938492291fdfa3a
3
- size 83447
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49c4d500d58301a7c158851716c6cf0c7e6bc60cdf59f7a166fcb92b4e77b04a
3
+ size 390587
models/model.py CHANGED
@@ -14,6 +14,7 @@ from models.token_encoder import TokenEncoder
14
  from models.wallet_encoder import WalletEncoder
15
  from models.graph_updater import GraphUpdater
16
  from models.ohlc_embedder import OHLCEmbedder
 
17
  from models.HoldersEncoder import HolderDistributionEncoder # NEW
18
  from models.SocialEncoders import SocialEncoder # NEW
19
  import models.vocabulary as vocab # For vocab sizes
@@ -28,6 +29,7 @@ class Oracle(nn.Module):
28
  wallet_encoder: WalletEncoder,
29
  graph_updater: GraphUpdater,
30
  ohlc_embedder: OHLCEmbedder, # NEW
 
31
  time_encoder: ContextualTimeEncoder,
32
  num_event_types: int,
33
  multi_modal_dim: int,
@@ -124,7 +126,8 @@ class Oracle(nn.Module):
124
  self.token_encoder = token_encoder
125
  self.wallet_encoder = wallet_encoder
126
  self.graph_updater = graph_updater
127
- self.ohlc_embedder = ohlc_embedder
 
128
  self.time_encoder = time_encoder # Store time_encoder
129
 
130
  self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype) # Now self.d_model is defined
@@ -140,7 +143,7 @@ class Oracle(nn.Module):
140
  # --- 5. Define Entity Padding (Learnable) ---
141
  self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model))
142
  self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim))
143
- self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.ohlc_embedder.output_dim))
144
  self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim)) # NEW: For text/images
145
 
146
  # --- NEW: Instantiate HolderDistributionEncoder internally ---
@@ -157,7 +160,16 @@ class Oracle(nn.Module):
157
  self.rel_ts_norm = nn.LayerNorm(1)
158
  self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model)
159
  self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model)
160
- self.ohlc_proj = nn.Linear(self.ohlc_embedder.output_dim, self.d_model)
 
 
 
 
 
 
 
 
 
161
  # self.holder_snapshot_proj is no longer needed as HolderDistributionEncoder outputs directly to d_model
162
 
163
 
@@ -309,7 +321,7 @@ class Oracle(nn.Module):
309
 
310
  @classmethod
311
  def from_pretrained(cls, load_directory: str,
312
- token_encoder, wallet_encoder, graph_updater, ohlc_embedder, time_encoder):
313
  """
314
  Loads the Oracle model from a saved directory.
315
  Note: You must still provide the initialized sub-encoders (or we can refactor to save them too).
@@ -329,6 +341,7 @@ class Oracle(nn.Module):
329
  wallet_encoder=wallet_encoder,
330
  graph_updater=graph_updater,
331
  ohlc_embedder=ohlc_embedder,
 
332
  time_encoder=time_encoder,
333
  num_event_types=config["num_event_types"],
334
  multi_modal_dim=config["multi_modal_dim"],
@@ -431,6 +444,14 @@ class Oracle(nn.Module):
431
  neginf=0.0
432
  )
433
  ohlc_interval_ids = batch['ohlc_interval_ids'].to(device)
 
 
 
 
 
 
 
 
434
  graph_updater_links = batch['graph_updater_links']
435
 
436
  # 1a. Encode Tokens
@@ -490,9 +511,39 @@ class Oracle(nn.Module):
490
 
491
  # 1c. Encode OHLC
492
  if ohlc_price_tensors.shape[0] > 0:
493
- batch_ohlc_embeddings_raw = self.ohlc_embedder(ohlc_price_tensors, ohlc_interval_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  else:
495
- batch_ohlc_embeddings_raw = torch.empty(0, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype)
496
 
497
  # 1d. Run Graph Updater
498
  pad_wallet_raw = self.pad_wallet_emb.to(self.dtype)
 
14
  from models.wallet_encoder import WalletEncoder
15
  from models.graph_updater import GraphUpdater
16
  from models.ohlc_embedder import OHLCEmbedder
17
+ from models.quant_ohlc_embedder import QuantOHLCEmbedder
18
  from models.HoldersEncoder import HolderDistributionEncoder # NEW
19
  from models.SocialEncoders import SocialEncoder # NEW
20
  import models.vocabulary as vocab # For vocab sizes
 
29
  wallet_encoder: WalletEncoder,
30
  graph_updater: GraphUpdater,
31
  ohlc_embedder: OHLCEmbedder, # NEW
32
+ quant_ohlc_embedder: QuantOHLCEmbedder,
33
  time_encoder: ContextualTimeEncoder,
34
  num_event_types: int,
35
  multi_modal_dim: int,
 
126
  self.token_encoder = token_encoder
127
  self.wallet_encoder = wallet_encoder
128
  self.graph_updater = graph_updater
129
+ self.ohlc_embedder = ohlc_embedder
130
+ self.quant_ohlc_embedder = quant_ohlc_embedder
131
  self.time_encoder = time_encoder # Store time_encoder
132
 
133
  self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype) # Now self.d_model is defined
 
143
  # --- 5. Define Entity Padding (Learnable) ---
144
  self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model))
145
  self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim))
146
+ self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.quant_ohlc_embedder.output_dim))
147
  self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim)) # NEW: For text/images
148
 
149
  # --- NEW: Instantiate HolderDistributionEncoder internally ---
 
160
  self.rel_ts_norm = nn.LayerNorm(1)
161
  self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model)
162
  self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model)
163
+ self.ohlc_proj = nn.Linear(self.quant_ohlc_embedder.output_dim, self.d_model)
164
+ self.chart_interval_fusion_embedding = nn.Embedding(vocab.NUM_OHLC_INTERVALS, 32, padding_idx=0)
165
+ fusion_input_dim = self.ohlc_embedder.output_dim + self.quant_ohlc_embedder.output_dim + 32
166
+ self.chart_fusion = nn.Sequential(
167
+ nn.Linear(fusion_input_dim, self.quant_ohlc_embedder.output_dim),
168
+ nn.GELU(),
169
+ nn.LayerNorm(self.quant_ohlc_embedder.output_dim),
170
+ nn.Linear(self.quant_ohlc_embedder.output_dim, self.quant_ohlc_embedder.output_dim),
171
+ nn.LayerNorm(self.quant_ohlc_embedder.output_dim),
172
+ )
173
  # self.holder_snapshot_proj is no longer needed as HolderDistributionEncoder outputs directly to d_model
174
 
175
 
 
321
 
322
  @classmethod
323
  def from_pretrained(cls, load_directory: str,
324
+ token_encoder, wallet_encoder, graph_updater, ohlc_embedder, quant_ohlc_embedder, time_encoder):
325
  """
326
  Loads the Oracle model from a saved directory.
327
  Note: You must still provide the initialized sub-encoders (or we can refactor to save them too).
 
341
  wallet_encoder=wallet_encoder,
342
  graph_updater=graph_updater,
343
  ohlc_embedder=ohlc_embedder,
344
+ quant_ohlc_embedder=quant_ohlc_embedder,
345
  time_encoder=time_encoder,
346
  num_event_types=config["num_event_types"],
347
  multi_modal_dim=config["multi_modal_dim"],
 
444
  neginf=0.0
445
  )
446
  ohlc_interval_ids = batch['ohlc_interval_ids'].to(device)
447
+ quant_ohlc_feature_tensors = torch.nan_to_num(
448
+ batch['quant_ohlc_feature_tensors'].to(device, self.dtype),
449
+ nan=0.0,
450
+ posinf=0.0,
451
+ neginf=0.0
452
+ )
453
+ quant_ohlc_feature_mask = batch['quant_ohlc_feature_mask'].to(device)
454
+ quant_ohlc_feature_version_ids = batch['quant_ohlc_feature_version_ids'].to(device)
455
  graph_updater_links = batch['graph_updater_links']
456
 
457
  # 1a. Encode Tokens
 
511
 
512
  # 1c. Encode OHLC
513
  if ohlc_price_tensors.shape[0] > 0:
514
+ raw_chart_embeddings = self.ohlc_embedder(ohlc_price_tensors, ohlc_interval_ids)
515
+ else:
516
+ raw_chart_embeddings = torch.empty(0, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype)
517
+ if quant_ohlc_feature_tensors.shape[0] > 0:
518
+ quant_chart_embeddings = self.quant_ohlc_embedder(
519
+ quant_ohlc_feature_tensors,
520
+ quant_ohlc_feature_mask,
521
+ quant_ohlc_feature_version_ids,
522
+ )
523
+ else:
524
+ quant_chart_embeddings = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype)
525
+ num_chart_segments = max(raw_chart_embeddings.shape[0], quant_chart_embeddings.shape[0])
526
+ if num_chart_segments > 0:
527
+ if raw_chart_embeddings.shape[0] == 0:
528
+ raw_chart_embeddings = torch.zeros(
529
+ num_chart_segments,
530
+ self.ohlc_embedder.output_dim,
531
+ device=device,
532
+ dtype=self.dtype,
533
+ )
534
+ if quant_chart_embeddings.shape[0] == 0:
535
+ quant_chart_embeddings = torch.zeros(
536
+ num_chart_segments,
537
+ self.quant_ohlc_embedder.output_dim,
538
+ device=device,
539
+ dtype=self.dtype,
540
+ )
541
+ interval_embeds = self.chart_interval_fusion_embedding(ohlc_interval_ids[:num_chart_segments]).to(self.dtype)
542
+ batch_ohlc_embeddings_raw = self.chart_fusion(
543
+ torch.cat([raw_chart_embeddings, quant_chart_embeddings, interval_embeds], dim=-1)
544
+ )
545
  else:
546
+ batch_ohlc_embeddings_raw = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype)
547
 
548
  # 1d. Run Graph Updater
549
  pad_wallet_raw = self.pad_wallet_emb.to(self.dtype)
models/ohlc_embedder.py CHANGED
@@ -19,11 +19,11 @@ class OHLCEmbedder(nn.Module):
19
  num_intervals: int,
20
  input_channels: int = 2, # Open, Close
21
  # sequence_length: int = 300, # REMOVED: HARDCODED
22
- cnn_channels: List[int] = [16, 32, 64],
23
  kernel_sizes: List[int] = [3, 3, 3],
24
  # --- NEW: Interval embedding dim ---
25
- interval_embed_dim: int = 32,
26
- output_dim: int = 4096,
27
  dtype: torch.dtype = torch.float16
28
  ):
29
  super().__init__()
@@ -116,4 +116,3 @@ class OHLCEmbedder(nn.Module):
116
  # Shape: [batch_size, output_dim]
117
 
118
  return x
119
-
 
19
  num_intervals: int,
20
  input_channels: int = 2, # Open, Close
21
  # sequence_length: int = 300, # REMOVED: HARDCODED
22
+ cnn_channels: List[int] = [8, 16, 32],
23
  kernel_sizes: List[int] = [3, 3, 3],
24
  # --- NEW: Interval embedding dim ---
25
+ interval_embed_dim: int = 16,
26
+ output_dim: int = 512,
27
  dtype: torch.dtype = torch.float16
28
  ):
29
  super().__init__()
 
116
  # Shape: [batch_size, output_dim]
117
 
118
  return x
 
models/quant_ohlc_embedder.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class QuantOHLCEmbedder(nn.Module):
    """Transformer encoder that embeds fixed-length windows of quantitative OHLC features.

    Each sample is a window of ``sequence_length`` per-bar feature vectors
    (``num_features`` wide). Tokens are LayerNorm'd and projected to
    ``hidden_dim``, combined with a learned positional embedding and a
    per-sample feature-schema ("version") embedding, encoded with a
    pre-norm Transformer, masked-mean pooled over valid positions, and
    projected to ``output_dim``. Samples whose mask is entirely zero
    produce an all-zero embedding (and no NaNs).
    """

    def __init__(
        self,
        num_features: int,
        sequence_length: int = 60,
        version_vocab_size: int = 4,
        hidden_dim: int = 320,
        num_layers: int = 3,
        num_heads: int = 8,
        output_dim: int = 1536,
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        self.num_features = num_features
        self.sequence_length = sequence_length
        self.output_dim = output_dim
        self.dtype = dtype

        # Per-token projection: normalize raw features, then lift to hidden_dim.
        self.feature_proj = nn.Sequential(
            nn.LayerNorm(num_features),
            nn.Linear(num_features, hidden_dim),
            nn.GELU(),
        )
        # Learned absolute positions for the fixed-length window.
        self.position_embedding = nn.Parameter(torch.zeros(1, sequence_length, hidden_dim))
        # Feature-schema version embedding; id 0 is reserved as padding.
        self.version_embedding = nn.Embedding(version_vocab_size, hidden_dim, padding_idx=0)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.LayerNorm(hidden_dim * 2),
            nn.Linear(hidden_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )
        self.to(dtype)

    def forward(
        self,
        feature_tokens: torch.Tensor,
        feature_mask: torch.Tensor,
        version_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed a batch of feature windows.

        Args:
            feature_tokens: [B, sequence_length, num_features] feature values.
            feature_mask: [B, sequence_length]; > 0 marks valid positions.
            version_ids: [B] feature-schema version id per sample.

        Returns:
            [B, output_dim] embeddings; rows are all-zero for fully padded samples.

        Raises:
            ValueError: if ``feature_tokens`` is not [B, sequence_length, num_features].
        """
        if feature_tokens.ndim != 3:
            raise ValueError(f"Expected [B, T, F], got {feature_tokens.shape}")
        if feature_tokens.shape[1] != self.sequence_length:
            raise ValueError(f"Expected T={self.sequence_length}, got {feature_tokens.shape[1]}")
        if feature_tokens.shape[2] != self.num_features:
            raise ValueError(f"Expected F={self.num_features}, got {feature_tokens.shape[2]}")

        x = self.feature_proj(feature_tokens.to(self.dtype))
        version_embed = self.version_embedding(version_ids).unsqueeze(1)
        x = x + self.position_embedding[:, : x.shape[1], :].to(x.dtype) + version_embed

        # True = position is ignored. BUGFIX: a row that is ALL True makes
        # every attention softmax operate over only -inf, producing NaNs
        # that the final masking cannot remove (NaN * 0 == NaN). Un-mask a
        # single position for fully padded rows so the encoder stays finite;
        # their output is zeroed below via `valid_any`.
        key_padding_mask = ~(feature_mask > 0)
        fully_padded = key_padding_mask.all(dim=1)
        if fully_padded.any():
            key_padding_mask = key_padding_mask.clone()
            key_padding_mask[fully_padded, 0] = False
        x = self.encoder(x, src_key_padding_mask=key_padding_mask)

        # Masked mean over valid positions; denominator floored at 1 so
        # fully padded rows divide by 1 instead of 0.
        mask = feature_mask.to(x.dtype).unsqueeze(-1)
        valid_any = (feature_mask.sum(dim=1, keepdim=True) > 0).to(x.dtype)
        denom = mask.sum(dim=1).clamp_min(1.0)
        pooled = (x * mask).sum(dim=1) / denom
        out = self.output_head(pooled)
        # Zero the embedding for samples with no valid positions at all.
        return out * valid_any
pre_cache.sh CHANGED
@@ -4,6 +4,7 @@
4
  CONTEXT_LENGTH=4096
5
  MIN_TRADES=10
6
  SAMPLES_PER_TOKEN=1
 
7
  NUM_WORKERS=1
8
 
9
  OUTPUT_DIR="data/cache"
@@ -20,6 +21,7 @@ echo "========================================"
20
  echo "Context Length (H/B/H threshold): $CONTEXT_LENGTH"
21
  echo "Min Trades (T_cutoff threshold): $MIN_TRADES"
22
  echo "Samples per Token: $SAMPLES_PER_TOKEN"
 
23
  echo "Num Workers: $NUM_WORKERS"
24
  echo "Horizons (sec): ${HORIZONS_SECONDS[*]}"
25
  echo "Quantiles: ${QUANTILES[*]}"
@@ -36,10 +38,10 @@ python3 scripts/cache_dataset.py \
36
  --context_length "$CONTEXT_LENGTH" \
37
  --min_trades "$MIN_TRADES" \
38
  --samples_per_token "$SAMPLES_PER_TOKEN" \
 
39
  --num_workers "$NUM_WORKERS" \
40
  --horizons_seconds "${HORIZONS_SECONDS[@]}" \
41
  --quantiles "${QUANTILES[@]}" \
42
- --max_samples 10 \
43
  "$@"
44
 
45
  echo "Done!"
 
4
  CONTEXT_LENGTH=4096
5
  MIN_TRADES=10
6
  SAMPLES_PER_TOKEN=1
7
+ TARGET_CONTEXTS_PER_CLASS=10
8
  NUM_WORKERS=1
9
 
10
  OUTPUT_DIR="data/cache"
 
21
  echo "Context Length (H/B/H threshold): $CONTEXT_LENGTH"
22
  echo "Min Trades (T_cutoff threshold): $MIN_TRADES"
23
  echo "Samples per Token: $SAMPLES_PER_TOKEN"
24
+ echo "Target Contexts per Class: $TARGET_CONTEXTS_PER_CLASS"
25
  echo "Num Workers: $NUM_WORKERS"
26
  echo "Horizons (sec): ${HORIZONS_SECONDS[*]}"
27
  echo "Quantiles: ${QUANTILES[*]}"
 
38
  --context_length "$CONTEXT_LENGTH" \
39
  --min_trades "$MIN_TRADES" \
40
  --samples_per_token "$SAMPLES_PER_TOKEN" \
41
+ --target_contexts_per_class "$TARGET_CONTEXTS_PER_CLASS" \
42
  --num_workers "$NUM_WORKERS" \
43
  --horizons_seconds "${HORIZONS_SECONDS[@]}" \
44
  --quantiles "${QUANTILES[@]}" \
 
45
  "$@"
46
 
47
  echo "Done!"
sample_2kGqvM18kGLby9bY_5.json ADDED
The diff for this file is too large to render. See raw diff
 
scripts/cache_dataset.py CHANGED
@@ -6,15 +6,13 @@ import argparse
6
  import datetime
7
  import torch
8
  import json
9
- import math
10
  from pathlib import Path
11
  from tqdm import tqdm
12
  from dotenv import load_dotenv
13
  import huggingface_hub
14
  import logging
15
- from concurrent.futures import ProcessPoolExecutor, as_completed
16
  import multiprocessing as mp
17
- from PIL import Image
18
 
19
  logging.getLogger("httpx").setLevel(logging.WARNING)
20
  logging.getLogger("transformers").setLevel(logging.ERROR)
@@ -23,7 +21,9 @@ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
23
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
 
25
  from scripts.analyze_distribution import get_return_class_map, compute_p99_clamps
26
- from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
 
 
27
 
28
  from clickhouse_driver import Client as ClickHouseClient
29
  from neo4j import GraphDatabase
@@ -58,6 +58,32 @@ def _representative_context_polarity(context):
58
  return "positive" if max(valid_returns) > 0.0 else "negative"
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desired_negative=None):
62
  if len(contexts) <= max_keep:
63
  polarity_counts = {}
@@ -113,68 +139,6 @@ def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desi
113
  return selected[:max_keep], polarity_counts
114
 
115
 
116
- def _allocate_class_targets(mints_by_class, target_total, positive_balance_min_class, positive_ratio):
117
- from collections import defaultdict
118
- import random
119
-
120
- class_ids = sorted(mints_by_class.keys())
121
- if not class_ids:
122
- return {}, {}, {}
123
-
124
- target_per_class = target_total // len(class_ids)
125
- remainder = target_total % len(class_ids)
126
-
127
- token_plans = {}
128
- class_targets = {}
129
- class_polarity_targets = {}
130
-
131
- for pos, class_id in enumerate(class_ids):
132
- class_target = target_per_class + (1 if pos < remainder else 0)
133
- class_targets[class_id] = class_target
134
-
135
- token_list = list(mints_by_class[class_id])
136
- random.shuffle(token_list)
137
- if not token_list or class_target <= 0:
138
- class_polarity_targets[class_id] = {"positive": 0, "negative": 0}
139
- continue
140
-
141
- if class_id >= positive_balance_min_class:
142
- positive_target = int(round(class_target * positive_ratio))
143
- positive_target = min(max(positive_target, 0), class_target)
144
- else:
145
- positive_target = 0
146
- negative_target = class_target - positive_target
147
- class_polarity_targets[class_id] = {
148
- "positive": positive_target,
149
- "negative": negative_target,
150
- }
151
-
152
- assigned_positive = 0
153
- assigned_negative = 0
154
- token_count = len(token_list)
155
- for sample_num in range(class_target):
156
- token_idx, mint_record = token_list[sample_num % token_count]
157
- mint_addr = mint_record["mint_address"]
158
- plan_key = (token_idx, mint_addr)
159
- if plan_key not in token_plans:
160
- token_plans[plan_key] = {
161
- "samples_to_keep": 0,
162
- "desired_positive": 0,
163
- "desired_negative": 0,
164
- "class_id": class_id,
165
- }
166
- token_plans[plan_key]["samples_to_keep"] += 1
167
-
168
- if assigned_positive < positive_target:
169
- token_plans[plan_key]["desired_positive"] += 1
170
- assigned_positive += 1
171
- else:
172
- token_plans[plan_key]["desired_negative"] += 1
173
- assigned_negative += 1
174
-
175
- return token_plans, class_targets, class_polarity_targets
176
-
177
-
178
  def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
179
  global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
180
  from data.data_loader import OracleDataset
@@ -200,7 +164,6 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
200
 
201
  _worker_dataset = OracleDataset(
202
  data_fetcher=data_fetcher,
203
- max_samples=dataset_config['max_samples'],
204
  start_date=dataset_config['start_date'],
205
  horizons_seconds=dataset_config['horizons_seconds'],
206
  quantiles=dataset_config['quantiles'],
@@ -214,7 +177,7 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
214
 
215
 
216
  def _process_single_token_context(args):
217
- idx, mint_addr, samples_per_token, output_dir, oversample_factor, desired_positive, desired_negative = args
218
  global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
219
  try:
220
  class_id = _worker_return_class_map.get(mint_addr)
@@ -244,15 +207,6 @@ def _process_single_token_context(args):
244
  q_score = _worker_quality_scores_map.get(mint_addr)
245
  if q_score is None:
246
  return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
247
- saved_files = []
248
- for ctx_idx, ctx in enumerate(contexts):
249
- ctx["quality_score"] = q_score
250
- ctx["class_id"] = class_id
251
- filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
252
- output_path = Path(output_dir) / filename
253
-
254
- torch.save(ctx, output_path)
255
- saved_files.append(filename)
256
  return {
257
  'status': 'success',
258
  'mint': mint_addr,
@@ -260,7 +214,7 @@ def _process_single_token_context(args):
260
  'q_score': q_score,
261
  'n_contexts': len(contexts),
262
  'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
263
- 'files': saved_files,
264
  'polarity_counts': polarity_counts,
265
  }
266
  except Exception as e:
@@ -284,7 +238,6 @@ def main():
284
 
285
  parser = argparse.ArgumentParser()
286
  parser.add_argument("--output_dir", type=str, default="data/cache")
287
- parser.add_argument("--max_samples", type=int, default=None)
288
  parser.add_argument("--start_date", type=str, default=None)
289
 
290
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
@@ -292,8 +245,8 @@ def main():
292
  parser.add_argument("--context_length", type=int, default=8192)
293
  parser.add_argument("--min_trades", type=int, default=10)
294
  parser.add_argument("--samples_per_token", type=int, default=1)
 
295
  parser.add_argument("--context_oversample_factor", type=int, default=4)
296
- parser.add_argument("--cache_balance_mode", type=str, default="hybrid", choices=["class", "uniform", "hybrid"])
297
  parser.add_argument("--positive_balance_min_class", type=int, default=2)
298
  parser.add_argument("--positive_context_ratio", type=float, default=0.5)
299
  parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
@@ -308,6 +261,10 @@ def main():
308
 
309
  if args.num_workers == 0:
310
  args.num_workers = max(1, mp.cpu_count() - 4)
 
 
 
 
311
 
312
  output_dir = Path(args.output_dir)
313
  output_dir.mkdir(parents=True, exist_ok=True)
@@ -334,7 +291,7 @@ def main():
334
  quality_scores_map = get_token_quality_scores(clickhouse_client)
335
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
336
 
337
- dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps)
338
 
339
  if len(dataset) == 0:
340
  print("WARNING: No samples. Exiting.")
@@ -370,93 +327,73 @@ def main():
370
  print(f"INFO: Workers: {args.num_workers}")
371
 
372
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
373
- dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps}
374
 
375
- # Build tasks with class-aware multi-sampling for balanced cache
376
  import random
377
- from collections import Counter, defaultdict
378
 
379
- # Count eligible tokens per class
380
  eligible_class_counts = Counter()
381
- mints_by_class = defaultdict(list)
382
  for i, m in enumerate(filtered_mints):
383
  cid = return_class_map.get(m['mint_address'])
384
  if cid is not None:
385
  eligible_class_counts[cid] += 1
386
- mints_by_class[cid].append((i, m))
387
 
388
  print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")
389
 
390
- num_classes = len(eligible_class_counts)
391
- if args.max_samples:
392
- target_total = args.max_samples
393
- else:
394
- target_total = 15000 # Default target: 15k balanced files
395
- target_per_class = target_total // max(num_classes, 1)
396
-
397
- token_plans, class_targets, class_polarity_targets = _allocate_class_targets(
398
- mints_by_class=mints_by_class,
399
- target_total=target_total,
400
- positive_balance_min_class=args.positive_balance_min_class,
401
- positive_ratio=args.positive_context_ratio,
402
- )
403
 
404
- print(f"INFO: Target total: {target_total}, Target per class: {target_per_class}")
 
405
  print(f"INFO: Exact class targets: {dict(sorted(class_targets.items()))}")
406
  print(f"INFO: Class polarity targets: {dict(sorted(class_polarity_targets.items()))}")
407
 
408
- # Build balanced task list
409
- tasks = []
410
- if args.cache_balance_mode == "uniform":
411
- target_tokens = len(filtered_mints)
412
- if args.max_samples:
413
- target_tokens = min(len(filtered_mints), max(1, math.ceil(args.max_samples / max(args.samples_per_token, 1))))
414
- mint_pool = list(enumerate(filtered_mints))
415
- random.shuffle(mint_pool)
416
- for i, m in mint_pool[:target_tokens]:
417
- tasks.append((i, m['mint_address'], args.samples_per_token, str(output_dir), args.context_oversample_factor, 0, args.samples_per_token))
418
- else:
419
- for (token_idx, mint_addr), plan in token_plans.items():
420
- tasks.append((
421
- token_idx,
422
- mint_addr,
423
- plan["samples_to_keep"],
424
- str(output_dir),
425
- args.context_oversample_factor,
426
- plan["desired_positive"],
427
- plan["desired_negative"],
428
- ))
429
-
430
- random.shuffle(tasks) # Shuffle tasks for even load distribution across workers
431
- expected_files = sum(task[2] for task in tasks)
432
- print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")
433
 
434
  success_count, skipped_count, error_count = 0, 0, 0
435
- class_distribution = {}
436
- polarity_distribution = {}
 
 
 
 
437
 
438
- # --- Resume support: skip tokens that already have cached files ---
439
  existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
440
  if existing_files:
441
- pre_resume = len(tasks)
442
- filtered_tasks = []
443
  already_cached = 0
444
- for task in tasks:
445
- mint_addr = task[1] # task = (idx, mint_addr, ...)
446
- # Check if any file exists for this mint (context mode: sample_MINT_0.pt, raw mode: sample_MINT.pt)
447
- mint_prefix = f"sample_{mint_addr[:16]}"
448
- has_cached = any(ef.startswith(mint_prefix) for ef in existing_files)
449
- if has_cached:
450
- already_cached += 1
451
- # Count existing files toward class distribution
452
- cid = return_class_map.get(mint_addr)
453
- if cid is not None:
454
- class_distribution[cid] = class_distribution.get(cid, 0) + 1
455
- success_count += 1
456
- else:
457
- filtered_tasks.append(task)
458
- tasks = filtered_tasks
459
- print(f"INFO: Resume: {already_cached} tokens already cached, {len(tasks)} remaining (was {pre_resume})")
 
 
 
 
460
 
461
  print(f"INFO: Starting to cache {len(tasks)} tokens...")
462
  process_fn = _process_single_token_context
@@ -484,64 +421,113 @@ def main():
484
  error_log_path = Path(args.output_dir) / "cache_errors.log"
485
  error_samples = [] # First 20 unique error messages
486
 
487
- if args.num_workers == 1:
488
- print("INFO: Single-threaded mode...")
489
- _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
490
- start_time = _time.perf_counter()
491
- recent_times = []
492
- for task_num, task in enumerate(tqdm(tasks, desc="Caching", unit="tok")):
493
- t0 = _time.perf_counter()
494
- result = process_fn(task)
495
- elapsed = _time.perf_counter() - t0
496
- recent_times.append(elapsed)
497
- if len(recent_times) > 50:
498
- recent_times.pop(0)
499
- if result['status'] == 'success':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  success_count += 1
501
- class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
502
- for polarity, count in result.get('polarity_counts', {}).items():
503
- polarity_distribution[polarity] = polarity_distribution.get(polarity, 0) + count
504
- elif result['status'] == 'skipped':
505
- skipped_count += 1
506
  else:
507
- error_count += 1
508
- err_msg = result.get('error', 'unknown')
509
- tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}")
510
- if len(error_samples) < 20:
511
- error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
 
 
 
 
 
 
 
512
  _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
513
- else:
514
- print(f"INFO: Running with {args.num_workers} workers...")
515
- start_time = _time.perf_counter()
516
- recent_times = []
517
- with ProcessPoolExecutor(max_workers=args.num_workers, initializer=_init_worker, initargs=(db_config, dataset_config, return_class_map, quality_scores_map)) as executor:
518
- futures = {executor.submit(process_fn, task): task for task in tasks}
519
- for task_num, future in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Caching", unit="tok")):
520
- t0 = _time.perf_counter()
521
- try:
522
- result = future.result(timeout=300)
523
- elapsed = _time.perf_counter() - t0
524
- recent_times.append(elapsed)
525
- if len(recent_times) > 50:
526
- recent_times.pop(0)
527
- if result['status'] == 'success':
528
- success_count += 1
529
- class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
530
- for polarity, count in result.get('polarity_counts', {}).items():
531
- polarity_distribution[polarity] = polarity_distribution.get(polarity, 0) + count
532
- elif result['status'] == 'skipped':
533
- skipped_count += 1
534
- else:
535
- error_count += 1
536
- err_msg = result.get('error', 'unknown')
537
- if len(error_samples) < 20:
538
- error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
539
- if error_count <= 5:
540
- tqdm.write(f"ERROR: {result.get('mint', '?')[:16]} - {err_msg}")
541
- except Exception as e:
542
- error_count += 1
543
- tqdm.write(f"WORKER ERROR: {e}")
544
- _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
545
 
546
  # Write error log
547
  if error_samples:
@@ -553,28 +539,24 @@ def main():
553
  print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")
554
 
555
  print("INFO: Building metadata...")
556
- file_class_map = {}
557
- for f in sorted(output_dir.glob("sample_*.pt")):
558
- try:
559
- file_class_map[f.name] = torch.load(f, map_location="cpu", weights_only=False).get("class_id", 0)
560
- except:
561
- pass
562
-
563
  with open(output_dir / "class_metadata.json", 'w') as f:
564
  json.dump({
565
  'file_class_map': file_class_map,
 
 
566
  'class_distribution': {str(k): v for k, v in class_distribution.items()},
567
  'num_workers': args.num_workers,
568
  'horizons_seconds': args.horizons_seconds,
569
  'quantiles': args.quantiles,
570
  'target_total': target_total,
571
- 'target_per_class': target_per_class,
572
- 'cache_balance_mode': args.cache_balance_mode,
573
  'context_polarity_distribution': polarity_distribution,
574
  'class_targets': {str(k): v for k, v in class_targets.items()},
575
  'class_polarity_targets': {str(k): v for k, v in class_polarity_targets.items()},
 
576
  'positive_balance_min_class': args.positive_balance_min_class,
577
  'positive_context_ratio': args.positive_context_ratio,
 
578
  }, f, indent=2)
579
 
580
  print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
 
6
  import datetime
7
  import torch
8
  import json
 
9
  from pathlib import Path
10
  from tqdm import tqdm
11
  from dotenv import load_dotenv
12
  import huggingface_hub
13
  import logging
 
14
  import multiprocessing as mp
15
+ from collections import Counter, defaultdict
16
 
17
  logging.getLogger("httpx").setLevel(logging.WARNING)
18
  logging.getLogger("transformers").setLevel(logging.ERROR)
 
21
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
22
 
23
  from scripts.analyze_distribution import get_return_class_map, compute_p99_clamps
24
+ from scripts.compute_quality_score import get_token_quality_scores
25
+ from data.data_loader import summarize_context_window
26
+ from data.quant_ohlc_feature_schema import FEATURE_VERSION
27
 
28
  from clickhouse_driver import Client as ClickHouseClient
29
  from neo4j import GraphDatabase
 
58
  return "positive" if max(valid_returns) > 0.0 else "negative"
59
 
60
 
61
+ def _class_polarity_targets(class_id, target_contexts_per_class, positive_balance_min_class, positive_ratio):
62
+ if class_id >= positive_balance_min_class:
63
+ positive_target = int(round(target_contexts_per_class * positive_ratio))
64
+ positive_target = min(max(positive_target, 0), target_contexts_per_class)
65
+ else:
66
+ positive_target = 0
67
+ return {
68
+ "positive": positive_target,
69
+ "negative": max(0, target_contexts_per_class - positive_target),
70
+ }
71
+
72
+
73
def _remaining_polarity_targets(class_id, accepted_counts, target_contexts_per_class, positive_balance_min_class, positive_ratio):
    """Return how many positive/negative contexts class ``class_id`` still needs.

    Computes the class quotas via ``_class_polarity_targets`` and subtracts
    the counts already accepted for that class (``accepted_counts`` maps
    class_id -> {"positive": n, "negative": n, ...}), flooring each
    remainder at zero.
    """
    quota = _class_polarity_targets(
        class_id=class_id,
        target_contexts_per_class=target_contexts_per_class,
        positive_balance_min_class=positive_balance_min_class,
        positive_ratio=positive_ratio,
    )
    taken = accepted_counts[class_id]
    remaining = {}
    for polarity in ("positive", "negative"):
        remaining[polarity] = max(0, quota[polarity] - taken[polarity])
    return remaining
85
+
86
+
87
  def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desired_negative=None):
88
  if len(contexts) <= max_keep:
89
  polarity_counts = {}
 
139
  return selected[:max_keep], polarity_counts
140
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
143
  global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
144
  from data.data_loader import OracleDataset
 
164
 
165
  _worker_dataset = OracleDataset(
166
  data_fetcher=data_fetcher,
 
167
  start_date=dataset_config['start_date'],
168
  horizons_seconds=dataset_config['horizons_seconds'],
169
  quantiles=dataset_config['quantiles'],
 
177
 
178
 
179
  def _process_single_token_context(args):
180
+ idx, mint_addr, samples_per_token, oversample_factor, desired_positive, desired_negative = args
181
  global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
182
  try:
183
  class_id = _worker_return_class_map.get(mint_addr)
 
207
  q_score = _worker_quality_scores_map.get(mint_addr)
208
  if q_score is None:
209
  return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
 
 
 
 
 
 
 
 
 
210
  return {
211
  'status': 'success',
212
  'mint': mint_addr,
 
214
  'q_score': q_score,
215
  'n_contexts': len(contexts),
216
  'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
217
+ 'contexts': contexts,
218
  'polarity_counts': polarity_counts,
219
  }
220
  except Exception as e:
 
238
 
239
  parser = argparse.ArgumentParser()
240
  parser.add_argument("--output_dir", type=str, default="data/cache")
 
241
  parser.add_argument("--start_date", type=str, default=None)
242
 
243
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
 
245
  parser.add_argument("--context_length", type=int, default=8192)
246
  parser.add_argument("--min_trades", type=int, default=10)
247
  parser.add_argument("--samples_per_token", type=int, default=1)
248
+ parser.add_argument("--target_contexts_per_class", type=int, default=2500)
249
  parser.add_argument("--context_oversample_factor", type=int, default=4)
 
250
  parser.add_argument("--positive_balance_min_class", type=int, default=2)
251
  parser.add_argument("--positive_context_ratio", type=float, default=0.5)
252
  parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
 
261
 
262
  if args.num_workers == 0:
263
  args.num_workers = max(1, mp.cpu_count() - 4)
264
+ if args.num_workers != 1:
265
+ raise RuntimeError("Quota-based caching requires --num_workers 1 so class counters remain exact.")
266
+ if args.target_contexts_per_class <= 0:
267
+ raise RuntimeError("--target_contexts_per_class must be positive.")
268
 
269
  output_dir = Path(args.output_dir)
270
  output_dir.mkdir(parents=True, exist_ok=True)
 
291
  quality_scores_map = get_token_quality_scores(clickhouse_client)
292
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
293
 
294
+ dataset = OracleDataset(data_fetcher=data_fetcher, start_date=start_date_dt, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps)
295
 
296
  if len(dataset) == 0:
297
  print("WARNING: No samples. Exiting.")
 
327
  print(f"INFO: Workers: {args.num_workers}")
328
 
329
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
330
+ dataset_config = {'start_date': start_date_dt, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps}
331
 
 
332
  import random
 
333
 
 
334
  eligible_class_counts = Counter()
 
335
  for i, m in enumerate(filtered_mints):
336
  cid = return_class_map.get(m['mint_address'])
337
  if cid is not None:
338
  eligible_class_counts[cid] += 1
 
339
 
340
  print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")
341
 
342
+ class_targets = {
343
+ int(class_id): int(args.target_contexts_per_class)
344
+ for class_id in sorted(eligible_class_counts.keys())
345
+ }
346
+ class_polarity_targets = {
347
+ class_id: _class_polarity_targets(
348
+ class_id=class_id,
349
+ target_contexts_per_class=args.target_contexts_per_class,
350
+ positive_balance_min_class=args.positive_balance_min_class,
351
+ positive_ratio=args.positive_context_ratio,
352
+ )
353
+ for class_id in class_targets
354
+ }
355
 
356
+ target_total = args.target_contexts_per_class * len(class_targets)
357
+ print(f"INFO: Target total: {target_total}, Target per class: {args.target_contexts_per_class}")
358
  print(f"INFO: Exact class targets: {dict(sorted(class_targets.items()))}")
359
  print(f"INFO: Class polarity targets: {dict(sorted(class_polarity_targets.items()))}")
360
 
361
+ tasks = list(enumerate(filtered_mints))
362
+ random.shuffle(tasks)
363
+ print(f"INFO: Total candidate tokens: {len(tasks)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  success_count, skipped_count, error_count = 0, 0, 0
366
+ class_distribution = defaultdict(int)
367
+ polarity_distribution = defaultdict(int)
368
+ file_class_map = {}
369
+ file_context_bucket_map = {}
370
+ file_context_summary_map = {}
371
+ accepted_counts = defaultdict(lambda: {"total": 0, "positive": 0, "negative": 0})
372
 
373
+ # Resume support: count existing files toward quotas.
374
  existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
375
  if existing_files:
 
 
376
  already_cached = 0
377
+ for f in sorted(output_dir.glob("sample_*.pt")):
378
+ try:
379
+ cached = torch.load(f, map_location="cpu", weights_only=False)
380
+ except Exception:
381
+ continue
382
+ class_id = cached.get("class_id")
383
+ if class_id is None or int(class_id) not in class_targets:
384
+ continue
385
+ class_id = int(class_id)
386
+ context_summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask"))
387
+ polarity = _representative_context_polarity(cached)
388
+ file_class_map[f.name] = class_id
389
+ file_context_bucket_map[f.name] = context_summary["context_bucket"]
390
+ file_context_summary_map[f.name] = context_summary
391
+ class_distribution[class_id] += 1
392
+ polarity_distribution[polarity] += 1
393
+ accepted_counts[class_id]["total"] += 1
394
+ accepted_counts[class_id][polarity] += 1
395
+ already_cached += 1
396
+ print(f"INFO: Resume: counted {already_cached} cached contexts toward quotas.")
397
 
398
  print(f"INFO: Starting to cache {len(tasks)} tokens...")
399
  process_fn = _process_single_token_context
 
421
  error_log_path = Path(args.output_dir) / "cache_errors.log"
422
  error_samples = [] # First 20 unique error messages
423
 
424
+ print("INFO: Single-threaded mode...")
425
+ _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
426
+ start_time = _time.perf_counter()
427
+ recent_times = []
428
+ completed_classes = set()
429
+ for task_num, (idx, mint_record) in enumerate(tqdm(tasks, desc="Caching", unit="tok")):
430
+ mint_addr = mint_record["mint_address"]
431
+ class_id = return_class_map.get(mint_addr)
432
+ if class_id is None:
433
+ skipped_count += 1
434
+ _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
435
+ continue
436
+ if accepted_counts[class_id]["total"] >= class_targets[class_id]:
437
+ if class_id not in completed_classes:
438
+ completed_classes.add(class_id)
439
+ tqdm.write(f"INFO: Class {class_id} quota filled. Skipping remaining tokens for this class.")
440
+ skipped_count += 1
441
+ _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
442
+ continue
443
+
444
+ remaining_total = class_targets[class_id] - accepted_counts[class_id]["total"]
445
+ remaining_polarity = _remaining_polarity_targets(
446
+ class_id=class_id,
447
+ accepted_counts=accepted_counts,
448
+ target_contexts_per_class=args.target_contexts_per_class,
449
+ positive_balance_min_class=args.positive_balance_min_class,
450
+ positive_ratio=args.positive_context_ratio,
451
+ )
452
+ desired_positive = min(remaining_polarity["positive"], args.samples_per_token, remaining_total)
453
+ desired_negative = min(
454
+ remaining_polarity["negative"],
455
+ max(0, min(args.samples_per_token, remaining_total) - desired_positive),
456
+ )
457
+ samples_to_keep = min(args.samples_per_token, remaining_total)
458
+ task = (
459
+ idx,
460
+ mint_addr,
461
+ samples_to_keep,
462
+ args.context_oversample_factor,
463
+ desired_positive,
464
+ desired_negative,
465
+ )
466
+
467
+ t0 = _time.perf_counter()
468
+ result = process_fn(task)
469
+ elapsed = _time.perf_counter() - t0
470
+ recent_times.append(elapsed)
471
+ if len(recent_times) > 50:
472
+ recent_times.pop(0)
473
+
474
+ if result["status"] == "success":
475
+ saved_contexts = 0
476
+ for ctx in result.get("contexts", []):
477
+ if accepted_counts[class_id]["total"] >= class_targets[class_id]:
478
+ break
479
+
480
+ polarity = _representative_context_polarity(ctx)
481
+ remaining_polarity = _remaining_polarity_targets(
482
+ class_id=class_id,
483
+ accepted_counts=accepted_counts,
484
+ target_contexts_per_class=args.target_contexts_per_class,
485
+ positive_balance_min_class=args.positive_balance_min_class,
486
+ positive_ratio=args.positive_context_ratio,
487
+ )
488
+ other_polarity = "negative" if polarity == "positive" else "positive"
489
+ if remaining_polarity[polarity] <= 0 and remaining_polarity[other_polarity] > 0:
490
+ continue
491
+
492
+ ctx["quality_score"] = result["q_score"]
493
+ ctx["class_id"] = class_id
494
+ ctx["source_token"] = mint_addr
495
+ context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
496
+ ctx["context_bucket"] = context_summary["context_bucket"]
497
+ ctx["context_score"] = context_summary["context_score"]
498
+ file_idx = accepted_counts[class_id]["total"]
499
+ filename = f"sample_{mint_addr[:16]}_{file_idx}.pt"
500
+ output_path = output_dir / filename
501
+ torch.save(ctx, output_path)
502
+
503
+ file_class_map[filename] = class_id
504
+ file_context_bucket_map[filename] = context_summary["context_bucket"]
505
+ file_context_summary_map[filename] = context_summary
506
+ class_distribution[class_id] += 1
507
+ polarity_distribution[polarity] += 1
508
+ accepted_counts[class_id]["total"] += 1
509
+ accepted_counts[class_id][polarity] += 1
510
+ saved_contexts += 1
511
+
512
+ if saved_contexts > 0:
513
  success_count += 1
 
 
 
 
 
514
  else:
515
+ skipped_count += 1
516
+ elif result["status"] == "skipped":
517
+ skipped_count += 1
518
+ else:
519
+ error_count += 1
520
+ err_msg = result.get("error", "unknown")
521
+ tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}")
522
+ if len(error_samples) < 20:
523
+ error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
524
+
525
+ if all(accepted_counts[cid]["total"] >= class_targets[cid] for cid in class_targets):
526
+ tqdm.write("INFO: All class quotas filled. Stopping early.")
527
  _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
528
+ break
529
+
530
+ _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
 
532
  # Write error log
533
  if error_samples:
 
539
  print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")
540
 
541
  print("INFO: Building metadata...")
 
 
 
 
 
 
 
542
  with open(output_dir / "class_metadata.json", 'w') as f:
543
  json.dump({
544
  'file_class_map': file_class_map,
545
+ 'file_context_bucket_map': file_context_bucket_map,
546
+ 'file_context_summary_map': file_context_summary_map,
547
  'class_distribution': {str(k): v for k, v in class_distribution.items()},
548
  'num_workers': args.num_workers,
549
  'horizons_seconds': args.horizons_seconds,
550
  'quantiles': args.quantiles,
551
  'target_total': target_total,
552
+ 'target_contexts_per_class': args.target_contexts_per_class,
 
553
  'context_polarity_distribution': polarity_distribution,
554
  'class_targets': {str(k): v for k, v in class_targets.items()},
555
  'class_polarity_targets': {str(k): v for k, v in class_polarity_targets.items()},
556
+ 'accepted_counts': {str(k): v for k, v in accepted_counts.items()},
557
  'positive_balance_min_class': args.positive_balance_min_class,
558
  'positive_context_ratio': args.positive_context_ratio,
559
+ 'quant_feature_version': FEATURE_VERSION,
560
  }, f, indent=2)
561
 
562
  print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
scripts/dump_cache_sample.py CHANGED
@@ -121,7 +121,7 @@ def main():
121
  "__metadata__": {
122
  "source_file": str(filepath.absolute()),
123
  "dumped_at": datetime.now().isoformat(),
124
- "cache_mode": data.get("cache_mode", "unknown") if isinstance(data, dict) else "unknown"
125
  },
126
  "data": serializable_data
127
  }
@@ -143,7 +143,7 @@ def main():
143
  if isinstance(data, dict):
144
  print(f"\n=== Summary ===")
145
  print(f"Top-level keys: {list(data.keys())}")
146
- print(f"Cache mode: {data.get('cache_mode', 'not specified')}")
147
  if 'event_sequence' in data:
148
  print(f"Event count: {len(data['event_sequence'])}")
149
  if 'trades' in data:
@@ -152,6 +152,10 @@ def main():
152
  print(f"Source token: {data['source_token']}")
153
  if 'class_id' in data:
154
  print(f"Class ID: {data['class_id']}")
 
 
 
 
155
  if 'quality_score' in data:
156
  print(f"Quality score: {data['quality_score']}")
157
 
 
121
  "__metadata__": {
122
  "source_file": str(filepath.absolute()),
123
  "dumped_at": datetime.now().isoformat(),
124
+ "cache_format": "context" if isinstance(data, dict) and "event_sequence" in data else "legacy"
125
  },
126
  "data": serializable_data
127
  }
 
143
  if isinstance(data, dict):
144
  print(f"\n=== Summary ===")
145
  print(f"Top-level keys: {list(data.keys())}")
146
+ print(f"Cache format: {'context' if 'event_sequence' in data else 'legacy'}")
147
  if 'event_sequence' in data:
148
  print(f"Event count: {len(data['event_sequence'])}")
149
  if 'trades' in data:
 
152
  print(f"Source token: {data['source_token']}")
153
  if 'class_id' in data:
154
  print(f"Class ID: {data['class_id']}")
155
+ if 'context_bucket' in data:
156
+ print(f"Context bucket: {data['context_bucket']}")
157
+ if 'context_score' in data:
158
+ print(f"Context score: {data['context_score']}")
159
  if 'quality_score' in data:
160
  print(f"Quality score: {data['quality_score']}")
161
 
scripts/evaluate_sample.py CHANGED
@@ -23,8 +23,10 @@ from models.token_encoder import TokenEncoder
23
  from models.wallet_encoder import WalletEncoder
24
  from models.graph_updater import GraphUpdater
25
  from models.ohlc_embedder import OHLCEmbedder
 
26
  from models.model import Oracle
27
  import models.vocabulary as vocab
 
28
  from train import create_balanced_split
29
  from dotenv import load_dotenv
30
  from clickhouse_driver import Client as ClickHouseClient
@@ -43,6 +45,12 @@ ABLATION_SWEEP_MODES = [
43
  "trade",
44
  "onchain",
45
  "wallet_graph",
 
 
 
 
 
 
46
  ]
47
 
48
  OHLC_PROBE_MODES = [
@@ -77,7 +85,7 @@ def parse_args():
77
  "--ablation",
78
  type=str,
79
  default="none",
80
- choices=["none", "wallet", "graph", "wallet_graph", "social", "token", "holder", "ohlc", "ohlc_wallet", "trade", "onchain", "all", "sweep", "ohlc_probe"],
81
  help="Run inference with selected signal families removed, or use 'sweep' to rank multiple families.",
82
  )
83
  return parser.parse_args()
@@ -164,6 +172,23 @@ def apply_ablation(batch, mode, device):
164
  ablated["ohlc_price_tensors"] = torch.zeros_like(ablated["ohlc_price_tensors"])
165
  if "ohlc_interval_ids" in ablated:
166
  ablated["ohlc_interval_ids"] = torch.zeros_like(ablated["ohlc_interval_ids"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  if mode in {"trade", "all"}:
169
  for key in (
@@ -627,6 +652,11 @@ def main():
627
  wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
628
  graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
629
  ohlc_embedder = OHLCEmbedder(num_intervals=vocab.NUM_OHLC_INTERVALS, dtype=init_dtype)
 
 
 
 
 
630
 
631
  collator = MemecoinCollator(
632
  event_type_to_id=vocab.EVENT_TO_ID,
@@ -641,6 +671,7 @@ def main():
641
  wallet_encoder=wallet_encoder,
642
  graph_updater=graph_updater,
643
  ohlc_embedder=ohlc_embedder,
 
644
  time_encoder=time_encoder,
645
  num_event_types=vocab.NUM_EVENT_TYPES,
646
  multi_modal_dim=multi_modal_encoder.embedding_dim,
 
23
  from models.wallet_encoder import WalletEncoder
24
  from models.graph_updater import GraphUpdater
25
  from models.ohlc_embedder import OHLCEmbedder
26
+ from models.quant_ohlc_embedder import QuantOHLCEmbedder
27
  from models.model import Oracle
28
  import models.vocabulary as vocab
29
+ from data.quant_ohlc_feature_schema import FEATURE_GROUPS, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT, group_feature_indices
30
  from train import create_balanced_split
31
  from dotenv import load_dotenv
32
  from clickhouse_driver import Client as ClickHouseClient
 
45
  "trade",
46
  "onchain",
47
  "wallet_graph",
48
+ "quant_ohlc",
49
+ "quant_levels",
50
+ "quant_trendline",
51
+ "quant_breaks",
52
+ "quant_rolling",
53
+ "quant_patterns",
54
  ]
55
 
56
  OHLC_PROBE_MODES = [
 
85
  "--ablation",
86
  type=str,
87
  default="none",
88
+ choices=["none", "wallet", "graph", "wallet_graph", "social", "token", "holder", "ohlc", "ohlc_wallet", "trade", "onchain", "all", "sweep", "ohlc_probe", "quant_ohlc", "quant_levels", "quant_trendline", "quant_breaks", "quant_rolling", "quant_patterns"],
89
  help="Run inference with selected signal families removed, or use 'sweep' to rank multiple families.",
90
  )
91
  return parser.parse_args()
 
172
  ablated["ohlc_price_tensors"] = torch.zeros_like(ablated["ohlc_price_tensors"])
173
  if "ohlc_interval_ids" in ablated:
174
  ablated["ohlc_interval_ids"] = torch.zeros_like(ablated["ohlc_interval_ids"])
175
+ if "quant_ohlc_feature_tensors" in ablated:
176
+ ablated["quant_ohlc_feature_tensors"] = torch.zeros_like(ablated["quant_ohlc_feature_tensors"])
177
+ if "quant_ohlc_feature_mask" in ablated:
178
+ ablated["quant_ohlc_feature_mask"] = torch.zeros_like(ablated["quant_ohlc_feature_mask"])
179
+
180
+ quant_group_map = {
181
+ "quant_ohlc": list(FEATURE_GROUPS.keys()),
182
+ "quant_levels": ["levels_breaks"],
183
+ "quant_trendline": ["trendlines"],
184
+ "quant_breaks": ["relative_structure", "levels_breaks"],
185
+ "quant_rolling": ["rolling_quant"],
186
+ "quant_patterns": ["patterns"],
187
+ }
188
+ if mode in quant_group_map and "quant_ohlc_feature_tensors" in ablated:
189
+ idxs = group_feature_indices(quant_group_map[mode])
190
+ if idxs:
191
+ ablated["quant_ohlc_feature_tensors"][:, :, idxs] = 0
192
 
193
  if mode in {"trade", "all"}:
194
  for key in (
 
652
  wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
653
  graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
654
  ohlc_embedder = OHLCEmbedder(num_intervals=vocab.NUM_OHLC_INTERVALS, dtype=init_dtype)
655
+ quant_ohlc_embedder = QuantOHLCEmbedder(
656
+ num_features=NUM_QUANT_OHLC_FEATURES,
657
+ sequence_length=TOKENS_PER_SEGMENT,
658
+ dtype=init_dtype,
659
+ )
660
 
661
  collator = MemecoinCollator(
662
  event_type_to_id=vocab.EVENT_TO_ID,
 
671
  wallet_encoder=wallet_encoder,
672
  graph_updater=graph_updater,
673
  ohlc_embedder=ohlc_embedder,
674
+ quant_ohlc_embedder=quant_ohlc_embedder,
675
  time_encoder=time_encoder,
676
  num_event_types=vocab.NUM_EVENT_TYPES,
677
  multi_modal_dim=multi_modal_encoder.embedding_dim,
scripts/rebuild_metadata.py CHANGED
@@ -6,6 +6,7 @@ from pathlib import Path
6
  from tqdm import tqdm
7
  from collections import defaultdict
8
  from data.data_loader import summarize_context_window
 
9
 
10
  def rebuild_metadata(cache_dir="data/cache"):
11
  cache_path = Path(cache_dir)
@@ -52,6 +53,7 @@ def rebuild_metadata(cache_dir="data/cache"):
52
  'num_workers': 1,
53
  'horizons_seconds': [300, 900, 1800, 3600, 7200], # From user's pre_cache.sh
54
  'quantiles': [0.1, 0.5, 0.9],
 
55
  }
56
 
57
  out_file = cache_path / "class_metadata.json"
 
6
  from tqdm import tqdm
7
  from collections import defaultdict
8
  from data.data_loader import summarize_context_window
9
+ from data.quant_ohlc_feature_schema import FEATURE_VERSION
10
 
11
  def rebuild_metadata(cache_dir="data/cache"):
12
  cache_path = Path(cache_dir)
 
53
  'num_workers': 1,
54
  'horizons_seconds': [300, 900, 1800, 3600, 7200], # From user's pre_cache.sh
55
  'quantiles': [0.1, 0.5, 0.9],
56
+ 'quant_feature_version': FEATURE_VERSION,
57
  }
58
 
59
  out_file = cache_path / "class_metadata.json"
signals/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Deterministic chart signal extraction package.
signals/patterns.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple
2
+
3
+ import numpy as np
4
+ from scipy.signal import find_peaks
5
+
6
+ from data.quant_ohlc_feature_schema import PATTERN_NAMES
7
+
8
+
9
def _empty_pattern_output() -> Dict[str, float]:
    """Return a confidence dict with every known chart pattern zeroed.

    NOTE(review): ``pattern_available`` is set to 1.0 even for this empty
    baseline, while the trendline/SR extractors use 0.0 for their
    "unavailable" case — confirm this asymmetry is intended.
    """
    zeroed: Dict[str, float] = {}
    for pattern_name in PATTERN_NAMES:
        zeroed[f"pattern_{pattern_name}_confidence"] = 0.0
    zeroed["pattern_available"] = 1.0
    return zeroed
13
+
14
+
15
+ def _confidence_from_error(error: float, tolerance: float) -> float:
16
+ if tolerance <= 1e-8:
17
+ return 0.0
18
+ return float(max(0.0, min(1.0, 1.0 - (error / tolerance))))
19
+
20
+
21
+ def _recent_prominent_peaks(series: np.ndarray, distance: int, prominence: float) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
22
+ peaks, props = find_peaks(series, distance=distance, prominence=prominence)
23
+ if peaks.size == 0:
24
+ return peaks, props
25
+ order = np.argsort(props["prominences"])
26
+ keep = order[-5:]
27
+ keep_sorted = np.sort(keep)
28
+ peaks = peaks[keep_sorted]
29
+ props = {key: value[keep_sorted] for key, value in props.items()}
30
+ return peaks, props
31
+
32
+
33
def _double_top_confidence(highs: np.ndarray, current_price: float, tolerance: float) -> float:
    """Score a double-top pattern in the high series, in [0, 1].

    Four sub-scores are multiplied: how equal the two tops are, how far
    apart they sit, whether price has broken below the neckline, and how
    prominent the peaks are relative to *tolerance*.
    """
    peaks, props = _recent_prominent_peaks(highs, distance=3, prominence=tolerance * 0.5)
    if peaks.size < 2:
        return 0.0
    top1_idx, top2_idx = peaks[-2], peaks[-1]
    top1 = float(highs[top1_idx])
    top2 = float(highs[top2_idx])
    # "Neckline" proxy: the lowest high between the two tops.
    neckline = float(np.min(highs[top1_idx:top2_idx + 1])) if top2_idx > top1_idx else min(top1, top2)
    # Price making new highs above both tops invalidates the pattern.
    if current_price > max(top1, top2):
        return 0.0
    symmetry = _confidence_from_error(abs(top1 - top2), tolerance)
    # Separation saturates at 1.0 once the tops are >= 8 bars apart.
    separation = min(1.0, float(top2_idx - top1_idx) / 8.0)
    # Full credit after a neckline breakdown, partial (0.6) before it.
    breakdown = 1.0 if current_price <= neckline else 0.6
    prominence = min(1.0, float(np.mean(props["prominences"][-2:])) / max(tolerance, 1e-8))
    return float(max(0.0, min(1.0, symmetry * separation * breakdown * prominence)))
48
+
49
+
50
def _double_bottom_confidence(lows: np.ndarray, current_price: float, tolerance: float) -> float:
    """Score a double-bottom pattern in the low series, in [0, 1].

    Mirror image of the double-top score: equal lows, separation, a
    breakout above the intervening "ceiling", and trough prominence are
    multiplied together.
    """
    # Troughs of *lows* are peaks of the negated series.
    troughs, props = _recent_prominent_peaks(-lows, distance=3, prominence=tolerance * 0.5)
    if troughs.size < 2:
        return 0.0
    low1_idx, low2_idx = troughs[-2], troughs[-1]
    low1 = float(lows[low1_idx])
    low2 = float(lows[low2_idx])
    # "Ceiling" proxy: the highest low between the two bottoms.
    ceiling = float(np.max(lows[low1_idx:low2_idx + 1])) if low2_idx > low1_idx else max(low1, low2)
    # Price making new lows below both bottoms invalidates the pattern.
    if current_price < min(low1, low2):
        return 0.0
    symmetry = _confidence_from_error(abs(low1 - low2), tolerance)
    # Separation saturates at 1.0 once the bottoms are >= 8 bars apart.
    separation = min(1.0, float(low2_idx - low1_idx) / 8.0)
    # Full credit after a ceiling breakout, partial (0.6) before it.
    breakout = 1.0 if current_price >= ceiling else 0.6
    prominence = min(1.0, float(np.mean(props["prominences"][-2:])) / max(tolerance, 1e-8))
    return float(max(0.0, min(1.0, symmetry * separation * breakout * prominence)))
65
+
66
+
67
def _triangle_confidences(highs: np.ndarray, lows: np.ndarray, tolerance: float) -> Dict[str, float]:
    """Score ascending/descending triangle patterns, each in [0, 1].

    An ascending triangle needs a flat top (recent peak highs within
    *tolerance*) plus rising swing lows; a descending triangle is the
    mirror image.
    """
    out = {
        "ascending_triangle": 0.0,
        "descending_triangle": 0.0,
    }
    peak_idx, _ = _recent_prominent_peaks(highs, distance=3, prominence=tolerance * 0.5)
    trough_idx, _ = _recent_prominent_peaks(-lows, distance=3, prominence=tolerance * 0.5)
    if peak_idx.size < 2 or trough_idx.size < 2:
        return out

    # Fit a straight line through the last (up to) three pivots per side.
    peak_vals = highs[peak_idx[-3:]]
    trough_vals = lows[trough_idx[-3:]]
    peak_slope = np.polyfit(np.arange(len(peak_vals), dtype=np.float64), peak_vals.astype(np.float64), deg=1)[0]
    trough_slope = np.polyfit(np.arange(len(trough_vals), dtype=np.float64), trough_vals.astype(np.float64), deg=1)[0]
    peak_flatness = _confidence_from_error(float(np.max(peak_vals) - np.min(peak_vals)), tolerance)
    trough_flatness = _confidence_from_error(float(np.max(trough_vals) - np.min(trough_vals)), tolerance)

    # Slopes are normalised by *tolerance* before clamping to [0, 1].
    out["ascending_triangle"] = float(max(0.0, min(1.0, peak_flatness * max(0.0, trough_slope) / max(tolerance, 1e-8))))
    out["descending_triangle"] = float(max(0.0, min(1.0, trough_flatness * max(0.0, -peak_slope) / max(tolerance, 1e-8))))
    return out
87
+
88
+
89
def _head_shoulders_confidence(highs: np.ndarray, lows: np.ndarray, tolerance: float, inverse: bool = False) -> float:
    """Score a (possibly inverse) head-and-shoulders pattern in [0, 1].

    Uses the last three prominent pivots of the relevant series; the score
    multiplies shoulder symmetry, head dominance, pivot spacing, and pivot
    prominence.
    """
    # For the inverse pattern we negate the lows so that troughs become
    # peaks and the same pivot machinery applies.
    series = -lows if inverse else highs
    pivots, props = _recent_prominent_peaks(series, distance=3, prominence=tolerance * 0.5)
    if pivots.size < 3:
        return 0.0
    idxs = pivots[-3:]
    values = series[idxs]
    left, head, right = [float(v) for v in values]
    shoulders_match = _confidence_from_error(abs(left - right), tolerance)
    # *series* is already negated for the inverse case, so in BOTH cases a
    # valid head must stand above the shoulders of *series*. The previous
    # inverse branch applied un-negated semantics (min(shoulders) - head)
    # to the negated values, which made genuine inverse head-and-shoulders
    # patterns always score 0.
    head_margin = max(0.0, head - max(left, right))
    head_score = min(1.0, head_margin / max(tolerance, 1e-8))
    # Spacing saturates once consecutive pivots are >= 5 bars apart.
    spacing = min(1.0, float(min(idxs[1] - idxs[0], idxs[2] - idxs[1])) / 5.0)
    prominence = min(1.0, float(np.mean(props["prominences"][-3:])) / max(tolerance, 1e-8))
    return float(max(0.0, min(1.0, shoulders_match * head_score * spacing * prominence)))
106
+
107
+
108
def compute_pattern_features(closes, highs, lows, end_idx: int) -> Dict[str, float]:
    """Compute all chart-pattern confidence features up to bar *end_idx*.

    Only candles ``[0, end_idx]`` are used, so the features contain no
    look-ahead. Returns the zeroed baseline dict when fewer than 10
    candles are available.
    """
    out = _empty_pattern_output()
    closes_np = np.asarray(closes[: end_idx + 1], dtype=np.float64)
    highs_np = np.asarray(highs[: end_idx + 1], dtype=np.float64)
    lows_np = np.asarray(lows[: end_idx + 1], dtype=np.float64)
    if closes_np.size < 10:
        return out

    current_price = float(closes_np[-1])
    # Price tolerance: recent close volatility (last 20 bars when there
    # are enough), floored at 0.3% of price plus an absolute epsilon so a
    # flat series still gets a usable scale.
    tolerance = max(float(np.std(closes_np[-20:])) if closes_np.size >= 20 else float(np.std(closes_np)), current_price * 0.003, 1e-5)

    out["pattern_double_top_confidence"] = _double_top_confidence(highs_np, current_price, tolerance)
    out["pattern_double_bottom_confidence"] = _double_bottom_confidence(lows_np, current_price, tolerance)

    triangle = _triangle_confidences(highs_np, lows_np, tolerance)
    out["pattern_ascending_triangle_confidence"] = triangle["ascending_triangle"]
    out["pattern_descending_triangle_confidence"] = triangle["descending_triangle"]
    out["pattern_head_shoulders_confidence"] = _head_shoulders_confidence(highs_np, lows_np, tolerance, inverse=False)
    out["pattern_inverse_head_shoulders_confidence"] = _head_shoulders_confidence(highs_np, lows_np, tolerance, inverse=True)
    return out
signals/rolling_quant.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from ta.trend import ema_indicator, sma_indicator
6
+
7
+
8
+ def _finite_or_zero(value: float) -> float:
9
+ try:
10
+ value = float(value)
11
+ except Exception:
12
+ return 0.0
13
+ if not np.isfinite(value):
14
+ return 0.0
15
+ return value
16
+
17
+
18
def compute_rolling_quant_features(closes: List[float], end_idx: int) -> Dict[str, float]:
    """Compute rolling-average / mean-reversion features up to *end_idx*.

    All price-level features are normalised by the current close so they
    are scale-free across tokens. Returns an all-zero dict when no
    candles are available.
    """
    closes_np = np.asarray(closes[: end_idx + 1], dtype=np.float64)
    if closes_np.size == 0:
        return {
            "ema_fast": 0.0,
            "ema_medium": 0.0,
            "sma_fast": 0.0,
            "sma_medium": 0.0,
            "price_minus_ema_fast": 0.0,
            "price_minus_ema_medium": 0.0,
            "ema_spread": 0.0,
            "price_zscore": 0.0,
            "mean_reversion_score": 0.0,
            "rolling_vol_zscore": 0.0,
        }

    close_series = pd.Series(closes_np)
    current = float(closes_np[-1])
    # Guard against a zero close so divisions below stay finite.
    current_scale = max(abs(current), 1e-8)

    # ``fillna=True`` lets the indicators emit values even when the series
    # is shorter than the window.
    ema_fast = _finite_or_zero(ema_indicator(close_series, window=8, fillna=True).iloc[-1])
    ema_medium = _finite_or_zero(ema_indicator(close_series, window=21, fillna=True).iloc[-1])
    sma_fast = _finite_or_zero(sma_indicator(close_series, window=8, fillna=True).iloc[-1])
    sma_medium = _finite_or_zero(sma_indicator(close_series, window=21, fillna=True).iloc[-1])

    mean_all = _finite_or_zero(close_series.mean())
    std_all = _finite_or_zero(close_series.std(ddof=0))
    price_zscore = 0.0 if std_all <= 1e-8 else (current - mean_all) / std_all

    # Volatility proxy: z-score of recent log-return volatility against
    # the whole-history absolute log-return distribution.
    log_returns = np.diff(np.log(np.clip(closes_np, 1e-8, None)))
    if log_returns.size == 0:
        rolling_vol = 0.0
        mean_vol = 0.0
        std_vol = 0.0
    else:
        abs_log_returns = np.abs(log_returns)
        rolling_vol = _finite_or_zero(np.std(log_returns[-20:]))
        mean_vol = _finite_or_zero(np.mean(abs_log_returns))
        std_vol = _finite_or_zero(np.std(abs_log_returns))
    rolling_vol_zscore = 0.0 if std_vol <= 1e-8 else (rolling_vol - mean_vol) / std_vol

    denom = max(abs(sma_medium), 1e-8)
    return {
        "ema_fast": ema_fast / current_scale,
        "ema_medium": ema_medium / current_scale,
        "sma_fast": sma_fast / current_scale,
        "sma_medium": sma_medium / current_scale,
        "price_minus_ema_fast": (current - ema_fast) / current_scale,
        "price_minus_ema_medium": (current - ema_medium) / current_scale,
        "ema_spread": (ema_fast - ema_medium) / current_scale,
        "price_zscore": _finite_or_zero(price_zscore),
        # Positive when price is below its long-run mean (reversion up).
        "mean_reversion_score": (mean_all - current) / denom,
        "rolling_vol_zscore": _finite_or_zero(rolling_vol_zscore),
    }
signals/support_resistance.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional
2
+
3
+ import numpy as np
4
+ from scipy.signal import argrelextrema
5
+
6
+
7
+ def _compute_pivots(prices: np.ndarray, order: int = 2) -> Dict[str, List[int]]:
8
+ if prices.size < (2 * order + 1):
9
+ return {"highs": [], "lows": []}
10
+ highs = argrelextrema(prices, np.greater_equal, order=order, mode="clip")[0].tolist()
11
+ lows = argrelextrema(prices, np.less_equal, order=order, mode="clip")[0].tolist()
12
+ highs = [idx for idx in highs if order <= idx < len(prices) - order]
13
+ lows = [idx for idx in lows if order <= idx < len(prices) - order]
14
+ return {"highs": highs, "lows": lows}
15
+
16
+
17
+ def _cluster_levels(prices: np.ndarray, pivot_indices: List[int], tolerance: float) -> List[Dict[str, float]]:
18
+ levels: List[Dict[str, float]] = []
19
+ for pivot_idx in pivot_indices:
20
+ price = float(prices[pivot_idx])
21
+ matched = None
22
+ for level in levels:
23
+ if abs(price - level["price"]) <= tolerance:
24
+ matched = level
25
+ break
26
+ if matched is None:
27
+ levels.append({
28
+ "price": price,
29
+ "touches": 1.0,
30
+ "last_idx": float(pivot_idx),
31
+ "first_idx": float(pivot_idx),
32
+ })
33
+ continue
34
+ touches = matched["touches"] + 1.0
35
+ matched["price"] = ((matched["price"] * matched["touches"]) + price) / touches
36
+ matched["touches"] = touches
37
+ matched["last_idx"] = float(pivot_idx)
38
+ return levels
39
+
40
+
41
+ def _nearest_level(levels: List[Dict[str, float]], current_price: float, below: bool) -> Optional[Dict[str, float]]:
42
+ candidates = [level for level in levels if (level["price"] <= current_price if below else level["price"] >= current_price)]
43
+ if not candidates:
44
+ return None
45
+ return min(candidates, key=lambda level: abs(level["price"] - current_price))
46
+
47
+
48
def compute_support_resistance_features(
    closes: List[float],
    highs: List[float],
    lows: List[float],
    end_idx: int,
    window_start: int,
    window_end: int,
    timestamps: List[int],
) -> Dict[str, float]:
    """Compute support/resistance level features up to bar *end_idx*.

    Levels are clustered from swing pivots in the history ``[0, end_idx]``;
    sweep/reclaim flags are evaluated against the context window
    ``[window_start, window_end)``. All distances are normalised by the
    current close. Returns an all-zero dict when no candles exist.
    """
    closes_np = np.asarray(closes[: end_idx + 1], dtype=np.float64)
    highs_np = np.asarray(highs[: end_idx + 1], dtype=np.float64)
    lows_np = np.asarray(lows[: end_idx + 1], dtype=np.float64)
    if closes_np.size == 0:
        return {key: 0.0 for key in [
            "nearest_support_dist", "nearest_resistance_dist", "support_touch_count",
            "resistance_touch_count", "support_age_sec", "resistance_age_sec",
            "support_strength", "resistance_strength", "inside_support_zone",
            "inside_resistance_zone", "support_swept", "resistance_swept",
            "support_reclaim", "resistance_reject", "sr_available",
        ]}

    current_price = float(closes_np[-1])
    # Clustering tolerance: 8% of the visible high-low range, floored at
    # 0.25% of price so flat charts still cluster.
    local_range = max(float(np.max(highs_np) - np.min(lows_np)), current_price * 1e-3, 1e-8)
    tolerance = max(local_range * 0.08, current_price * 0.0025)

    # Resistance levels come from swing-high pivots, support from swing lows.
    pivots_high = _compute_pivots(highs_np, order=2)["highs"]
    pivots_low = _compute_pivots(lows_np, order=2)["lows"]
    support_levels = _cluster_levels(lows_np, pivots_low, tolerance)
    resistance_levels = _cluster_levels(highs_np, pivots_high, tolerance)

    support = _nearest_level(support_levels, current_price, below=True)
    resistance = _nearest_level(resistance_levels, current_price, below=False)
    # Fall back to the bar index as a pseudo-timestamp when none are given.
    current_ts = float(timestamps[min(end_idx, len(timestamps) - 1)]) if timestamps else float(end_idx)

    def _level_stats(level: Optional[Dict[str, float]], is_support: bool) -> Dict[str, float]:
        # Per-level sub-features; all zeros when no level exists on that side.
        if level is None:
            return {
                "dist": 0.0,
                "touch_count": 0.0,
                "age_sec": 0.0,
                "strength": 0.0,
                "inside_zone": 0.0,
                "swept": 0.0,
                "confirm": 0.0,
            }
        level_price = float(level["price"])
        zone_half_width = max(tolerance, abs(level_price) * 0.002)
        # Sweep detection uses the raw context-window slice, which is not
        # clipped to end_idx (the window may extend past it).
        window_prices_high = highs[window_start:window_end]
        window_prices_low = lows[window_start:window_end]
        swept = 0.0
        confirm = 0.0
        if is_support:
            # Support is "swept" when the window traded below its zone;
            # "reclaimed" when price is back at/above the level afterwards.
            min_low = min(window_prices_low) if window_prices_low else current_price
            swept = 1.0 if min_low < (level_price - zone_half_width) else 0.0
            confirm = 1.0 if swept > 0 and current_price >= level_price else 0.0
        else:
            # Resistance is "swept" on a poke above its zone; "rejected"
            # when price fell back to/below the level afterwards.
            max_high = max(window_prices_high) if window_prices_high else current_price
            swept = 1.0 if max_high > (level_price + zone_half_width) else 0.0
            confirm = 1.0 if swept > 0 and current_price <= level_price else 0.0

        return {
            # Signed so that a positive value means price is on the
            # "expected" side of the level (above support / below resistance).
            "dist": (current_price - level_price) / max(abs(current_price), 1e-8) if is_support else (level_price - current_price) / max(abs(current_price), 1e-8),
            "touch_count": float(level["touches"]),
            # Seconds since the level's last touch (bar-index delta when
            # timestamps are unavailable); ternary binds over the whole
            # subtraction, then max() floors at zero.
            "age_sec": max(0.0, current_ts - float(timestamps[int(level["last_idx"])]) if timestamps else current_ts - level["last_idx"]),
            "strength": float(level["touches"]) / max(1.0, float(end_idx + 1)),
            "inside_zone": 1.0 if abs(current_price - level_price) <= zone_half_width else 0.0,
            "swept": swept,
            "confirm": confirm,
        }

    support_stats = _level_stats(support, True)
    resistance_stats = _level_stats(resistance, False)
    return {
        "nearest_support_dist": support_stats["dist"],
        "nearest_resistance_dist": resistance_stats["dist"],
        "support_touch_count": support_stats["touch_count"],
        "resistance_touch_count": resistance_stats["touch_count"],
        "support_age_sec": support_stats["age_sec"],
        "resistance_age_sec": resistance_stats["age_sec"],
        "support_strength": support_stats["strength"],
        "resistance_strength": resistance_stats["strength"],
        "inside_support_zone": support_stats["inside_zone"],
        "inside_resistance_zone": resistance_stats["inside_zone"],
        "support_swept": support_stats["swept"],
        "resistance_swept": resistance_stats["swept"],
        "support_reclaim": support_stats["confirm"],
        "resistance_reject": resistance_stats["confirm"],
        "sr_available": 1.0 if support_levels or resistance_levels else 0.0,
    }
signals/trendlines.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Iterable, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import trendln
5
+
6
+
7
+ def _empty_trendline_features() -> Dict[str, float]:
8
+ return {
9
+ "lower_trendline_slope": 0.0,
10
+ "upper_trendline_slope": 0.0,
11
+ "dist_to_lower_line": 0.0,
12
+ "dist_to_upper_line": 0.0,
13
+ "trend_channel_width": 0.0,
14
+ "trend_convergence": 0.0,
15
+ "trend_breakout_upper": 0.0,
16
+ "trend_breakdown_lower": 0.0,
17
+ "trend_reentry": 0.0,
18
+ "trendline_available": 0.0,
19
+ }
20
+
21
+
22
+ def _extract_best_line(line_data: object) -> Optional[Tuple[float, float]]:
23
+ if not isinstance(line_data, Iterable):
24
+ return None
25
+ line_list = list(line_data)
26
+ if not line_list:
27
+ return None
28
+ best = line_list[0]
29
+ if not isinstance(best, tuple) or len(best) < 2 or not isinstance(best[1], tuple) or len(best[1]) < 2:
30
+ return None
31
+ slope = float(best[1][0])
32
+ intercept = float(best[1][1])
33
+ if not np.isfinite(slope) or not np.isfinite(intercept):
34
+ return None
35
+ return slope, intercept
36
+
37
+
38
+ def _extract_overall_line(summary_data: object) -> Optional[Tuple[float, float]]:
39
+ if isinstance(summary_data, (tuple, list)) and len(summary_data) >= 2:
40
+ slope = float(summary_data[0])
41
+ intercept = float(summary_data[1])
42
+ if np.isfinite(slope) and np.isfinite(intercept):
43
+ return slope, intercept
44
+ return None
45
+
46
+
47
def _fit_with_trendln(values: np.ndarray) -> Tuple[Optional[Tuple[float, float]], Optional[Tuple[float, float]]]:
    """Fit support/resistance trendlines with ``trendln``.

    Returns ``(support_line, resistance_line)`` as ``(slope, intercept)``
    tuples, preferring the best individual candidate line and falling back
    to the overall fitted line; either element may be None. May raise —
    callers wrap this in try/except.
    """
    # trendln needs a sane pivot window; clamp to [10, 125] bars.
    window = max(10, min(125, int(values.size)))
    out = trendln.calc_support_resistance(
        values,
        extmethod=trendln.METHOD_NAIVE,
        # (sic) METHOD_NSQUREDLOGN is trendln's actual constant spelling.
        method=trendln.METHOD_NSQUREDLOGN,
        window=window,
        errpct=0.01,
    )
    support_part, resistance_part = out
    # Index [2] holds candidate trend lines, [1] the overall best-fit line.
    # NOTE(review): result layout assumed from trendln's docs — confirm it
    # matches the installed trendln version.
    support_line = _extract_best_line(support_part[2]) or _extract_overall_line(support_part[1])
    resistance_line = _extract_best_line(resistance_part[2]) or _extract_overall_line(resistance_part[1])
    return support_line, resistance_line
60
+
61
+
62
def compute_trendline_features(closes, highs, lows, end_idx: int) -> Dict[str, float]:
    """Compute trendline/channel features up to bar *end_idx*.

    Fits support and resistance lines on the close series, falling back to
    lows/highs per side, and derives normalised distances, channel shape,
    and breakout/breakdown/reentry flags for the latest bar. Returns the
    zeroed dict when fewer than 5 candles exist or no line can be fitted.
    """
    closes_np = np.asarray(closes[: end_idx + 1], dtype=np.float64)
    highs_np = np.asarray(highs[: end_idx + 1], dtype=np.float64)
    lows_np = np.asarray(lows[: end_idx + 1], dtype=np.float64)
    if closes_np.size < 5:
        return _empty_trendline_features()

    current_price = float(closes_np[-1])
    # Lines are evaluated at the last bar (idx) and the one before it.
    idx = float(closes_np.size - 1)
    prev_idx = max(0.0, idx - 1.0)
    prev_close = float(closes_np[-2]) if closes_np.size > 1 else current_price
    price_scale = max(abs(current_price), 1e-8)

    try:
        support_line, resistance_line = _fit_with_trendln(closes_np)
    except Exception:
        # trendln can fail on degenerate series; treat as "no line".
        support_line, resistance_line = None, None

    # Per-side fallback: refit missing lines on the raw lows/highs.
    if support_line is None:
        try:
            support_line, _ = _fit_with_trendln(lows_np)
        except Exception:
            support_line = None
    if resistance_line is None:
        try:
            _, resistance_line = _fit_with_trendln(highs_np)
        except Exception:
            resistance_line = None

    if support_line is None and resistance_line is None:
        return _empty_trendline_features()

    # Project each available line onto the last two bars; a missing side
    # degenerates to the current price (zero distance).
    lower_pred = support_line[0] * idx + support_line[1] if support_line is not None else current_price
    upper_pred = resistance_line[0] * idx + resistance_line[1] if resistance_line is not None else current_price
    lower_prev = support_line[0] * prev_idx + support_line[1] if support_line is not None else lower_pred
    upper_prev = resistance_line[0] * prev_idx + resistance_line[1] if resistance_line is not None else upper_pred
    width = max(abs(upper_pred - lower_pred), 1e-8)

    # Breakout/breakdown fire only on the bar that crosses the line.
    breakout_upper = 1.0 if current_price > upper_pred and prev_close <= upper_prev else 0.0
    breakdown_lower = 1.0 if current_price < lower_pred and prev_close >= lower_prev else 0.0
    # Reentry: price was outside the channel last bar and is back inside now.
    reentry = 1.0 if (
        (prev_close > upper_prev and current_price <= upper_pred) or
        (prev_close < lower_prev and current_price >= lower_pred)
    ) else 0.0

    return {
        "lower_trendline_slope": 0.0 if support_line is None else support_line[0] / price_scale,
        "upper_trendline_slope": 0.0 if resistance_line is None else resistance_line[0] / price_scale,
        "dist_to_lower_line": (current_price - lower_pred) / price_scale,
        "dist_to_upper_line": (upper_pred - current_price) / price_scale,
        "trend_channel_width": width / price_scale,
        # Negative convergence means the channel is narrowing (wedge).
        "trend_convergence": 0.0 if support_line is None or resistance_line is None else (resistance_line[0] - support_line[0]) / price_scale,
        "trend_breakout_upper": breakout_upper,
        "trend_breakdown_lower": breakdown_lower,
        "trend_reentry": reentry,
        "trendline_available": 1.0,
    }
+ }
train.py CHANGED
@@ -59,8 +59,10 @@ from models.token_encoder import TokenEncoder
59
  from models.wallet_encoder import WalletEncoder
60
  from models.graph_updater import GraphUpdater
61
  from models.ohlc_embedder import OHLCEmbedder
 
62
  from models.model import Oracle
63
  import models.vocabulary as vocab
 
64
 
65
  # Setup Logger
66
  logger = get_logger(__name__)
@@ -671,6 +673,11 @@ def main() -> None:
671
  num_intervals=vocab.NUM_OHLC_INTERVALS,
672
  dtype=init_dtype
673
  )
 
 
 
 
 
674
 
675
  collator = MemecoinCollator(
676
  event_type_to_id=vocab.EVENT_TO_ID,
@@ -809,6 +816,7 @@ def main() -> None:
809
  wallet_encoder=wallet_encoder,
810
  graph_updater=graph_updater,
811
  ohlc_embedder=ohlc_embedder,
 
812
  time_encoder=time_encoder,
813
  num_event_types=vocab.NUM_EVENT_TYPES,
814
  multi_modal_dim=multi_modal_encoder.embedding_dim,
 
59
  from models.wallet_encoder import WalletEncoder
60
  from models.graph_updater import GraphUpdater
61
  from models.ohlc_embedder import OHLCEmbedder
62
+ from models.quant_ohlc_embedder import QuantOHLCEmbedder
63
  from models.model import Oracle
64
  import models.vocabulary as vocab
65
+ from data.quant_ohlc_feature_schema import NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
66
 
67
  # Setup Logger
68
  logger = get_logger(__name__)
 
673
  num_intervals=vocab.NUM_OHLC_INTERVALS,
674
  dtype=init_dtype
675
  )
676
+ quant_ohlc_embedder = QuantOHLCEmbedder(
677
+ num_features=NUM_QUANT_OHLC_FEATURES,
678
+ sequence_length=TOKENS_PER_SEGMENT,
679
+ dtype=init_dtype,
680
+ )
681
 
682
  collator = MemecoinCollator(
683
  event_type_to_id=vocab.EVENT_TO_ID,
 
816
  wallet_encoder=wallet_encoder,
817
  graph_updater=graph_updater,
818
  ohlc_embedder=ohlc_embedder,
819
+ quant_ohlc_embedder=quant_ohlc_embedder,
820
  time_encoder=time_encoder,
821
  num_event_types=vocab.NUM_EVENT_TYPES,
822
  multi_modal_dim=multi_modal_encoder.embedding_dim,