zirobtc committed on
Commit
b441d51
·
verified ·
1 Parent(s): 54191e5

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/evaluate_sample.py +304 -0
scripts/evaluate_sample.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import random
5
+ import torch
6
+ from pathlib import Path
7
+
8
+ # Add project root to sys.path so we can import data and models
9
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+
11
+ # Provide standard defaults
12
+ from accelerate import Accelerator
13
+ from torch.utils.data import DataLoader, Subset
14
+
15
+ from data.data_loader import OracleDataset
16
+ from data.data_collator import MemecoinCollator
17
+ from models.multi_modal_processor import MultiModalEncoder
18
+ from models.helper_encoders import ContextualTimeEncoder
19
+ from models.token_encoder import TokenEncoder
20
+ from models.wallet_encoder import WalletEncoder
21
+ from models.graph_updater import GraphUpdater
22
+ from models.ohlc_embedder import OHLCEmbedder
23
+ from models.model import Oracle
24
+ import models.vocabulary as vocab
25
+ from train import create_balanced_split
26
+
27
def unlog_transform(tensor):
    """Invert the signed log1p transform applied to labels during training.

    Training transforms labels as ``sign(x) * log1p(abs(x))``; this reverses
    it with ``sign(x) * expm1(abs(x))``. ``torch.expm1`` is used instead of
    ``exp(x) - 1`` because it is numerically accurate for values near zero,
    where ``exp(x) - 1`` suffers catastrophic cancellation.

    Args:
        tensor: Log-transformed values (any shape).

    Returns:
        Tensor of the same shape with the transform undone.
    """
    return torch.sign(tensor) * torch.expm1(torch.abs(tensor))
31
+
32
def parse_args():
    """Build and parse the command-line arguments for single-sample evaluation."""
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, default="checkpoints/checkpoint-90000", help="Path to checkpoint dir")
    p.add_argument("--cache_dir", type=str, default="/workspace/apollo/data/cache", help="Path to dataset cache")
    p.add_argument("--sample_idx", type=int, default=None, help="Specific sample index to evaluate")
    p.add_argument("--mixed_precision", type=str, default="bf16")
    p.add_argument("--horizons_seconds", type=int, nargs="+", default=[300, 900, 1800, 3600, 7200])
    p.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--min_horizon", type=int, default=900, help="Ensure the sampled coin has ground truth for at least this horizon (in seconds)")
    return p.parse_args()
43
+
44
def get_latest_checkpoint(checkpoint_dir):
    """Return the most recently modified subdirectory of ``checkpoint_dir``.

    Args:
        checkpoint_dir: Directory expected to contain checkpoint
            subdirectories (str or Path).

    Returns:
        str path of the newest subdirectory (by mtime), or None if the
        directory does not exist or contains no subdirectories.
    """
    ckpt_dir = Path(checkpoint_dir)
    if not ckpt_dir.exists():
        return None
    subdirs = [d for d in ckpt_dir.iterdir() if d.is_dir()]
    if not subdirs:
        return None
    # max() with an mtime key picks the newest in one pass instead of
    # sorting the whole list just to read the last element.
    return str(max(subdirs, key=lambda d: d.stat().st_mtime))
53
+
54
def main():
    """Load the cached dataset and a checkpoint, run inference on a single
    sample, and print predicted vs. ground-truth returns for each horizon.

    Flow: parse args -> load dataset -> pick a (validation) sample, optionally
    filtered so ground truth exists for --min_horizon -> build encoders/model
    -> load checkpoint weights -> collate the sample -> forward pass -> print
    per-horizon quantile predictions alongside ground truth.
    """
    args = parse_args()

    accelerator = Accelerator(mixed_precision=args.mixed_precision)
    device = accelerator.device

    # Map the accelerate mixed-precision setting onto the dtype used to
    # construct the encoders/model.
    init_dtype = torch.float32
    if accelerator.mixed_precision == 'bf16':
        init_dtype = torch.bfloat16
    elif accelerator.mixed_precision == 'fp16':
        init_dtype = torch.float16

    print(f"Loading cached dataset from {args.cache_dir}...")
    # data_fetcher/fetcher_config are None: we rely entirely on the on-disk
    # cache; no network fetching happens here.
    dataset = OracleDataset(
        data_fetcher=None,
        fetcher_config=None,
        horizons_seconds=args.horizons_seconds,
        quantiles=args.quantiles,
        max_samples=None,
        t_cutoff_seconds=60,
        cache_dir=args.cache_dir
    )

    if len(dataset) == 0:
        raise ValueError("Dataset is empty!")

    # Recreate the same balanced train/val split used in training so we
    # evaluate on a held-out sample.
    print("Creating balanced train/val split to pick a validation sample...")
    _, val_indices, _ = create_balanced_split(dataset, n_val_per_class=10, seed=args.seed)

    # Re-seed with system time so repeated runs don't pick the same sample.
    import time
    random.seed(time.time())

    # --- Filter candidates by minimum horizon if requested ---
    if args.min_horizon is not None and args.min_horizon in args.horizons_seconds:
        print(f"Filtering dataset to find samples with ground truth >= {args.min_horizon}s...")
        h_idx = args.horizons_seconds.index(args.min_horizon)
        # NOTE(review): num_quantiles is unused in this branch; it is
        # recomputed before the printing loop below.
        num_quantiles = len(args.quantiles)

        valid_indices = []
        # Search a shuffled pool (validation indices if any, otherwise the
        # whole dataset) so we don't scan every sample.
        search_pool = val_indices.copy()
        random.shuffle(search_pool)
        if not search_pool:
            search_pool = list(range(len(dataset)))
            random.shuffle(search_pool)

        for idx in search_pool:
            sample = dataset[idx]
            if sample is None:
                continue
            mask = sample.get('labels_mask')
            if mask is not None:
                # Assumes the mask is shape [H] (one entry per horizon), so
                # it is indexed by h_idx directly — per raw file inspection;
                # confirm against the dataset builder if it changes.
                if h_idx < len(mask) and mask[h_idx] > 0.0:
                    valid_indices.append(idx)
            # Stop early once we have a handful of valid candidates.
            if len(valid_indices) >= 10:
                break

        if valid_indices:
            print(f"Found {len(valid_indices)} candidate samples with >= {args.min_horizon}s horizon.")
            val_indices = valid_indices
        else:
            print(f"WARNING: No samples found with ground truth for horizon {args.min_horizon}s. Reverting to random pick.")

    # Choose the sample: explicit index wins, otherwise pick from the
    # (possibly filtered) validation pool, otherwise anywhere in the dataset.
    if args.sample_idx is not None:
        if args.sample_idx >= len(dataset):
            raise ValueError(f"Sample index {args.sample_idx} out of range [0, {len(dataset)-1}]")
        sample_idx = args.sample_idx
    else:
        if len(val_indices) > 0:
            sample_idx = random.choice(val_indices)
        else:
            print("No validation indices found. Picking random sample from entire set.")
            sample_idx = random.randint(0, len(dataset) - 1)

    print(f"\nEvaluating on sample index: {sample_idx}")

    # Build the encoder stack; these must mirror the training configuration
    # so the checkpoint's weights line up.
    print("Initializing encoders...")
    multi_modal_encoder = MultiModalEncoder(dtype=init_dtype, device=device)
    time_encoder = ContextualTimeEncoder(dtype=init_dtype)
    token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
    wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
    graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
    ohlc_embedder = OHLCEmbedder(num_intervals=vocab.NUM_OHLC_INTERVALS, dtype=init_dtype)

    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=device,
        dtype=init_dtype,
        max_seq_len=4096
    )

    print("Initializing model...")
    # model_config_name must match the architecture the checkpoint was
    # trained with — TODO confirm against train.py if checkpoints change.
    model = Oracle(
        token_encoder=token_encoder,
        wallet_encoder=wallet_encoder,
        graph_updater=graph_updater,
        ohlc_embedder=ohlc_embedder,
        time_encoder=time_encoder,
        num_event_types=vocab.NUM_EVENT_TYPES,
        multi_modal_dim=multi_modal_encoder.embedding_dim,
        event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
        event_type_to_id=vocab.EVENT_TO_ID,
        model_config_name="llama3-12l-768d-gqa4-8k-random",
        quantiles=args.quantiles,
        horizons_seconds=args.horizons_seconds,
        dtype=init_dtype
    )

    # Drop the token-embedding table if present — presumably it is unused
    # because inputs are pre-embedded; verify against Oracle's forward.
    if hasattr(model.model, 'embed_tokens'):
        del model.model.embed_tokens

    # Resolve the checkpoint path; a trailing "latest" means "newest
    # subdirectory of the parent directory".
    ckpt_path = args.checkpoint
    if ckpt_path.endswith("latest"):
        base_dir = Path(ckpt_path).parent
        found = get_latest_checkpoint(base_dir)
        if found:
            ckpt_path = found

    if not os.path.exists(ckpt_path):
        print(f"Warning: Checkpoint {ckpt_path} not found. Running with random weights!")
        model = accelerator.prepare(model)
    else:
        print(f"Loading checkpoint from {ckpt_path}...")
        # accelerate checkpoints may be sharded or stored as
        # pytorch_model.bin / model.safetensors. The model must be wrapped
        # via accelerator.prepare() before accelerator.load_state() works.
        model = accelerator.prepare(model)
        try:
            accelerator.load_state(ckpt_path)
            print("Successfully loaded accelerator state.")
        except Exception as e:
            # Fallback: load the raw state dict file directly into the
            # unwrapped model.
            print(f"Could not load using accelerate.load_state: {e}")
            print("Trying to load model weights directly...")
            model_file = os.path.join(ckpt_path, "pytorch_model.bin")
            if not os.path.exists(model_file):
                model_file = os.path.join(ckpt_path, "model.safetensors")

            if os.path.exists(model_file):
                if model_file.endswith(".safetensors"):
                    from safetensors.torch import load_file
                    state_dict = load_file(model_file)
                else:
                    state_dict = torch.load(model_file, map_location="cpu")

                # strict=False tolerates keys we deleted (e.g. embed_tokens).
                uw_model = accelerator.unwrap_model(model)
                uw_model.load_state_dict(state_dict, strict=False)
                print("Successfully loaded weights directly.")
            else:
                print(f"Error: model weights not found in {ckpt_path}")

    model.eval()

    # Fetch and collate the single sample into a batch of size 1.
    raw_sample = dataset[sample_idx]
    if raw_sample is None:
        print("Sample is None!")
        return

    batch = collator([raw_sample])

    # Move tensors (and lists of tensors) to the accelerator device.
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device)
        elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
            batch[k] = [t.to(device) for t in v]

    # Backfill keys the model's safety checks expect even when the sample
    # has no textual events.
    if 'textual_event_indices' not in batch:
        B, L = batch['event_type_ids'].shape
        batch['textual_event_indices'] = torch.zeros((B, L), dtype=torch.long, device=device)
    if 'textual_event_data' not in batch:
        batch['textual_event_data'] = []

    print("\n--- Running Inference ---")
    with torch.no_grad():
        outputs = model(batch)

    # quantile_logits for the single batch element; assumed flat layout
    # [num_horizons * num_quantiles] — indexing below relies on this.
    preds = outputs["quantile_logits"][0].cpu()
    quality_preds = outputs["quality_logits"][0].cpu() if "quality_logits" in outputs else None

    # Raw labels from the dataset; these are NOT log-transformed.
    gt_labels = batch["labels"][0].cpu()
    gt_mask = batch["labels_mask"][0].cpu().bool()

    # Quality target, if the collated batch provides one.
    gt_quality = batch["quality_score"][0].item() if "quality_score" in batch else None

    # The model was trained on log1p-transformed returns (train.py applies
    # labels = sign(labels) * log1p(abs(labels))), so invert that here.
    real_preds = unlog_transform(preds)

    print("\n================== Results ==================")
    print(f"Token Address: {batch.get('token_addresses', ['Unknown'])[0]}")
    if gt_quality is not None:
        print(f"Quality Score: GT = {gt_quality:.4f} | Pred = {quality_preds.item() if quality_preds is not None else 'N/A'}")

    print("\nReturns per Horizon:")
    num_quantiles = len(args.quantiles)
    # The model outputs all configured horizons, but the cached labels may
    # have been generated with fewer — guard indexing on the mask length.
    num_gt_horizons = len(gt_mask)  # mask assumed shape [H]

    for h_idx, horizon in enumerate(args.horizons_seconds):
        horizon_min = horizon // 60
        print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")

        if h_idx >= num_gt_horizons:
            print(" [No Ground Truth Available for this Horizon - Not in Dataset]")
            valid = False
        else:
            # Mask format is [H]: one validity flag per horizon.
            valid = gt_mask[h_idx].item()

        if not valid:
            print(" [No Ground Truth Available for this Horizon - Masked]")
            # Predictions are still printed even when GT is masked/missing.
            print(" Predictions:")
            for q_idx, q in enumerate(args.quantiles):
                flat_idx = h_idx * num_quantiles + q_idx
                pred_ret = real_preds[flat_idx].item()
                log_pred = preds[flat_idx].item()
                print(f" - p{int(q*100):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})")
            continue

        # Ground truth return (raw fraction, printed as a percentage).
        gt_ret = gt_labels[h_idx].item()
        print(f" Ground Truth: {gt_ret * 100:.2f}%")

        # Predicted quantiles for this horizon.
        print(" Predictions:")
        for q_idx, q in enumerate(args.quantiles):
            flat_idx = h_idx * num_quantiles + q_idx
            pred_ret = real_preds[flat_idx].item()
            log_pred = preds[flat_idx].item()

            print(f" - p{int(q*100):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})")

    print("=============================================\n")

if __name__ == "__main__":
    main()