Upload folder using huggingface_hub
Browse files- audit_cache.py +231 -133
- data/data_collator.py +41 -2
- data/data_loader.py +222 -344
- data/quant_ohlc_feature_schema.py +182 -0
- inference.py +9 -2
- log.log +2 -2
- models/model.py +57 -6
- models/ohlc_embedder.py +3 -4
- models/quant_ohlc_embedder.py +74 -0
- pre_cache.sh +3 -1
- sample_2kGqvM18kGLby9bY_5.json +0 -0
- scripts/cache_dataset.py +193 -211
- scripts/dump_cache_sample.py +6 -2
- scripts/evaluate_sample.py +32 -1
- scripts/rebuild_metadata.py +2 -0
- signals/__init__.py +1 -0
- signals/patterns.py +127 -0
- signals/rolling_quant.py +71 -0
- signals/support_resistance.py +136 -0
- signals/trendlines.py +118 -0
- train.py +8 -0
audit_cache.py
CHANGED
|
@@ -1,151 +1,249 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import math
|
| 4 |
import argparse
|
|
|
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
-
|
| 7 |
-
import
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
if not files:
|
| 13 |
-
print(f"No .pt files found in {
|
| 14 |
return
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
stats = {
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
}
|
| 37 |
|
| 38 |
-
for
|
| 39 |
try:
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
continue
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
wallets = data.get("wallets")
|
| 108 |
-
if not wallets:
|
| 109 |
-
stats['missing_wallets'] += 1
|
| 110 |
-
else:
|
| 111 |
-
has_nan_wallet = False
|
| 112 |
-
for w_addr, w_data in wallets.items():
|
| 113 |
-
profile = w_data.get("profile", {})
|
| 114 |
-
for k, v in profile.items():
|
| 115 |
-
if isinstance(v, float) and math.isnan(v):
|
| 116 |
-
has_nan_wallet = True
|
| 117 |
-
if has_nan_wallet:
|
| 118 |
-
stats['nan_in_wallet_profile'] += 1
|
| 119 |
-
|
| 120 |
-
# 5. Pool index out of bounds
|
| 121 |
-
pool = data.get("embedding_pool", [])
|
| 122 |
-
pool_idxs = [item.get("idx") for item in pool]
|
| 123 |
-
max_idx = max(pool_idxs) if pool_idxs else 0
|
| 124 |
-
|
| 125 |
-
for event in events:
|
| 126 |
-
for k, v in event.items():
|
| 127 |
-
if k.endswith("_idx") and isinstance(v, int):
|
| 128 |
-
if v > max_idx:
|
| 129 |
-
stats['invalid_pool_idx'] += 1
|
| 130 |
-
break
|
| 131 |
-
|
| 132 |
-
except Exception as e:
|
| 133 |
-
issues[f'processing_error_{type(e).__name__}'] += 1
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
print("\n--- Cache Audit Report (New Issues) ---")
|
| 137 |
-
print(f"Audited {total_files} files.")
|
| 138 |
-
for k, v in stats.items():
|
| 139 |
-
print(f"{k}: {v}")
|
| 140 |
-
|
| 141 |
-
print("\nIssues encountered:")
|
| 142 |
-
for k, v in issues.items():
|
| 143 |
-
print(f"{k}: {v}")
|
| 144 |
|
| 145 |
if __name__ == "__main__":
|
| 146 |
parser = argparse.ArgumentParser()
|
| 147 |
parser.add_argument("--cache_dir", type=str, default="/workspace/apollo/data/cache")
|
| 148 |
-
parser.add_argument("--num", type=int, default=
|
| 149 |
args = parser.parse_args()
|
| 150 |
-
|
| 151 |
audit_cache(args.cache_dir, args.num)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import math
|
| 3 |
+
from collections import Counter, defaultdict
|
| 4 |
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
from tqdm import tqdm
|
| 8 |
|
| 9 |
+
from data.data_loader import summarize_context_window
|
| 10 |
+
from data.quant_ohlc_feature_schema import FEATURE_VERSION, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
REQUIRED_CONTEXT_FIELDS = [
|
| 14 |
+
"event_sequence",
|
| 15 |
+
"wallets",
|
| 16 |
+
"tokens",
|
| 17 |
+
"labels",
|
| 18 |
+
"labels_mask",
|
| 19 |
+
"quality_score",
|
| 20 |
+
"class_id",
|
| 21 |
+
"source_token",
|
| 22 |
+
"context_bucket",
|
| 23 |
+
"context_score",
|
| 24 |
+
"quant_ohlc_features",
|
| 25 |
+
"quant_feature_version",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _to_list(value):
|
| 30 |
+
if value is None:
|
| 31 |
+
return []
|
| 32 |
+
if isinstance(value, torch.Tensor):
|
| 33 |
+
return value.tolist()
|
| 34 |
+
return list(value)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _safe_float(value):
|
| 38 |
+
if isinstance(value, torch.Tensor):
|
| 39 |
+
if value.numel() != 1:
|
| 40 |
+
raise ValueError("Expected scalar tensor.")
|
| 41 |
+
return float(value.item())
|
| 42 |
+
return float(value)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def audit_cache(cache_dir, num_samples=None):
|
| 46 |
+
cache_path = Path(cache_dir)
|
| 47 |
+
files = sorted(cache_path.glob("sample_*.pt"))
|
| 48 |
if not files:
|
| 49 |
+
print(f"No sample_*.pt files found in {cache_path}")
|
| 50 |
return
|
| 51 |
|
| 52 |
+
if num_samples is not None and num_samples > 0:
|
| 53 |
+
files = files[:num_samples]
|
| 54 |
+
|
| 55 |
+
issues = Counter()
|
| 56 |
+
class_counts = Counter()
|
| 57 |
+
bucket_counts = Counter()
|
| 58 |
+
class_bucket_counts = defaultdict(Counter)
|
| 59 |
+
token_counts_by_class = defaultdict(Counter)
|
| 60 |
+
samples_per_token = Counter()
|
| 61 |
+
missing_fields = Counter()
|
| 62 |
+
|
| 63 |
stats = {
|
| 64 |
+
"files_audited": len(files),
|
| 65 |
+
"empty_event_sequence": 0,
|
| 66 |
+
"missing_wallets": 0,
|
| 67 |
+
"missing_tokens": 0,
|
| 68 |
+
"nan_labels": 0,
|
| 69 |
+
"nan_masks": 0,
|
| 70 |
+
"nan_quality_score": 0,
|
| 71 |
+
"negative_quality_score": 0,
|
| 72 |
+
"max_label_return": -float("inf"),
|
| 73 |
+
"min_label_return": float("inf"),
|
| 74 |
+
"max_events": 0,
|
| 75 |
+
"min_events": float("inf"),
|
| 76 |
+
"contexts_with_no_valid_horizons": 0,
|
| 77 |
+
"context_bucket_mismatch": 0,
|
| 78 |
+
"context_score_mismatch": 0,
|
| 79 |
+
"quant_feature_version_mismatch": 0,
|
| 80 |
+
"chart_events_missing_quant": 0,
|
| 81 |
+
"quant_segments_total": 0,
|
| 82 |
}
|
| 83 |
|
| 84 |
+
for filepath in tqdm(files, desc="Auditing cache", unit="file"):
|
| 85 |
try:
|
| 86 |
+
data = torch.load(filepath, map_location="cpu", weights_only=False)
|
| 87 |
+
except Exception:
|
| 88 |
+
issues["load_error"] += 1
|
| 89 |
+
continue
|
| 90 |
+
|
| 91 |
+
if not isinstance(data, dict):
|
| 92 |
+
issues["not_dict"] += 1
|
| 93 |
+
continue
|
| 94 |
+
|
| 95 |
+
missing_for_file = []
|
| 96 |
+
for field in REQUIRED_CONTEXT_FIELDS:
|
| 97 |
+
if field not in data:
|
| 98 |
+
missing_for_file.append(field)
|
| 99 |
+
missing_fields[field] += 1
|
| 100 |
+
|
| 101 |
+
if missing_for_file:
|
| 102 |
+
issues["missing_required_fields"] += 1
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
class_id = int(data["class_id"])
|
| 106 |
+
source_token = str(data["source_token"])
|
| 107 |
+
context_bucket = str(data["context_bucket"])
|
| 108 |
+
|
| 109 |
+
class_counts[class_id] += 1
|
| 110 |
+
bucket_counts[context_bucket] += 1
|
| 111 |
+
class_bucket_counts[class_id][context_bucket] += 1
|
| 112 |
+
token_counts_by_class[class_id][source_token] += 1
|
| 113 |
+
samples_per_token[source_token] += 1
|
| 114 |
+
|
| 115 |
+
events = data.get("event_sequence") or []
|
| 116 |
+
wallets = data.get("wallets") or {}
|
| 117 |
+
tokens = data.get("tokens") or {}
|
| 118 |
+
labels = _to_list(data.get("labels"))
|
| 119 |
+
masks = _to_list(data.get("labels_mask"))
|
| 120 |
+
|
| 121 |
+
if not events:
|
| 122 |
+
stats["empty_event_sequence"] += 1
|
| 123 |
+
stats["max_events"] = max(stats["max_events"], len(events))
|
| 124 |
+
stats["min_events"] = min(stats["min_events"], len(events))
|
| 125 |
+
|
| 126 |
+
if not wallets:
|
| 127 |
+
stats["missing_wallets"] += 1
|
| 128 |
+
if not tokens:
|
| 129 |
+
stats["missing_tokens"] += 1
|
| 130 |
+
|
| 131 |
+
has_nan_label = False
|
| 132 |
+
for value in labels:
|
| 133 |
+
if math.isnan(float(value)):
|
| 134 |
+
has_nan_label = True
|
| 135 |
+
break
|
| 136 |
+
stats["max_label_return"] = max(stats["max_label_return"], float(value))
|
| 137 |
+
stats["min_label_return"] = min(stats["min_label_return"], float(value))
|
| 138 |
+
if has_nan_label:
|
| 139 |
+
stats["nan_labels"] += 1
|
| 140 |
+
|
| 141 |
+
has_nan_mask = False
|
| 142 |
+
for value in masks:
|
| 143 |
+
if math.isnan(float(value)):
|
| 144 |
+
has_nan_mask = True
|
| 145 |
+
break
|
| 146 |
+
if has_nan_mask:
|
| 147 |
+
stats["nan_masks"] += 1
|
| 148 |
+
|
| 149 |
+
try:
|
| 150 |
+
quality_score = _safe_float(data.get("quality_score"))
|
| 151 |
+
if math.isnan(quality_score):
|
| 152 |
+
stats["nan_quality_score"] += 1
|
| 153 |
+
elif quality_score < 0:
|
| 154 |
+
stats["negative_quality_score"] += 1
|
| 155 |
+
except Exception:
|
| 156 |
+
issues["invalid_quality_score"] += 1
|
| 157 |
+
|
| 158 |
+
try:
|
| 159 |
+
summary = summarize_context_window(data.get("labels"), data.get("labels_mask"))
|
| 160 |
+
if summary["valid_horizons"] == 0:
|
| 161 |
+
stats["contexts_with_no_valid_horizons"] += 1
|
| 162 |
+
if summary["context_bucket"] != context_bucket:
|
| 163 |
+
stats["context_bucket_mismatch"] += 1
|
| 164 |
+
stored_score = _safe_float(data.get("context_score"))
|
| 165 |
+
if not math.isclose(summary["context_score"], stored_score, rel_tol=1e-6, abs_tol=1e-6):
|
| 166 |
+
stats["context_score_mismatch"] += 1
|
| 167 |
+
except Exception:
|
| 168 |
+
issues["context_summary_error"] += 1
|
| 169 |
+
|
| 170 |
+
if data.get("quant_feature_version") != FEATURE_VERSION:
|
| 171 |
+
stats["quant_feature_version_mismatch"] += 1
|
| 172 |
+
|
| 173 |
+
chart_events = [event for event in events if event.get("event_type") == "Chart_Segment"]
|
| 174 |
+
stats["quant_segments_total"] += len(chart_events)
|
| 175 |
+
for event in chart_events:
|
| 176 |
+
quant_payload = event.get("quant_ohlc_features")
|
| 177 |
+
if not isinstance(quant_payload, list):
|
| 178 |
+
stats["chart_events_missing_quant"] += 1
|
| 179 |
continue
|
| 180 |
+
if len(quant_payload) > TOKENS_PER_SEGMENT:
|
| 181 |
+
issues["quant_too_many_tokens"] += 1
|
| 182 |
+
for token_payload in quant_payload:
|
| 183 |
+
vec = token_payload.get("feature_vector")
|
| 184 |
+
if not isinstance(vec, list) or len(vec) != NUM_QUANT_OHLC_FEATURES:
|
| 185 |
+
issues["quant_bad_vector_shape"] += 1
|
| 186 |
+
break
|
| 187 |
+
|
| 188 |
+
if stats["min_events"] == float("inf"):
|
| 189 |
+
stats["min_events"] = 0
|
| 190 |
+
if stats["min_label_return"] == float("inf"):
|
| 191 |
+
stats["min_label_return"] = 0.0
|
| 192 |
+
if stats["max_label_return"] == -float("inf"):
|
| 193 |
+
stats["max_label_return"] = 0.0
|
| 194 |
+
|
| 195 |
+
unique_tokens_total = len(samples_per_token)
|
| 196 |
+
duplicate_tokens_total = sum(1 for count in samples_per_token.values() if count > 1)
|
| 197 |
+
|
| 198 |
+
print("\n=== Cache Audit ===")
|
| 199 |
+
print(f"Cache dir: {cache_path}")
|
| 200 |
+
print(f"Files audited: {stats['files_audited']}")
|
| 201 |
+
print(f"Unique source tokens: {unique_tokens_total}")
|
| 202 |
+
print(f"Tokens with >1 cached context: {duplicate_tokens_total}")
|
| 203 |
+
print(f"Samples per token max: {max(samples_per_token.values()) if samples_per_token else 0}")
|
| 204 |
+
|
| 205 |
+
print("\n--- Class Counts ---")
|
| 206 |
+
for class_id in sorted(class_counts):
|
| 207 |
+
unique_tokens = len(token_counts_by_class[class_id])
|
| 208 |
+
print(f"Class {class_id}: samples={class_counts[class_id]} unique_tokens={unique_tokens}")
|
| 209 |
+
|
| 210 |
+
print("\n--- Context Buckets ---")
|
| 211 |
+
for bucket, count in sorted(bucket_counts.items()):
|
| 212 |
+
print(f"{bucket}: {count}")
|
| 213 |
+
|
| 214 |
+
print("\n--- Class x Context Bucket ---")
|
| 215 |
+
for class_id in sorted(class_bucket_counts):
|
| 216 |
+
bucket_summary = dict(sorted(class_bucket_counts[class_id].items()))
|
| 217 |
+
print(f"Class {class_id}: {bucket_summary}")
|
| 218 |
+
|
| 219 |
+
print("\n--- General Stats ---")
|
| 220 |
+
for key, value in stats.items():
|
| 221 |
+
print(f"{key}: {value}")
|
| 222 |
+
|
| 223 |
+
print("\n--- Missing Fields ---")
|
| 224 |
+
if missing_fields:
|
| 225 |
+
for field, count in sorted(missing_fields.items()):
|
| 226 |
+
print(f"{field}: {count}")
|
| 227 |
+
else:
|
| 228 |
+
print("none")
|
| 229 |
+
|
| 230 |
+
print("\n--- Issues ---")
|
| 231 |
+
if issues:
|
| 232 |
+
for key, value in sorted(issues.items()):
|
| 233 |
+
print(f"{key}: {value}")
|
| 234 |
+
else:
|
| 235 |
+
print("none")
|
| 236 |
+
|
| 237 |
+
print("\n--- Duplicate-Heavy Tokens ---")
|
| 238 |
+
heavy_tokens = sorted(samples_per_token.items(), key=lambda item: (-item[1], item[0]))[:20]
|
| 239 |
+
for token, count in heavy_tokens:
|
| 240 |
+
print(f"{token}: {count}")
|
| 241 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
if __name__ == "__main__":
|
| 244 |
parser = argparse.ArgumentParser()
|
| 245 |
parser.add_argument("--cache_dir", type=str, default="/workspace/apollo/data/cache")
|
| 246 |
+
parser.add_argument("--num", type=int, default=None, help="Audit only the first N files.")
|
| 247 |
args = parser.parse_args()
|
| 248 |
+
|
| 249 |
audit_cache(args.cache_dir, args.num)
|
data/data_collator.py
CHANGED
|
@@ -10,6 +10,7 @@ from PIL import Image
|
|
| 10 |
|
| 11 |
import models.vocabulary as vocab
|
| 12 |
from data.data_loader import EmbeddingPooler
|
|
|
|
| 13 |
|
| 14 |
NATIVE_MINT = "So11111111111111111111111111111111111111112"
|
| 15 |
QUOTE_MINTS = {
|
|
@@ -40,6 +41,8 @@ class MemecoinCollator:
|
|
| 40 |
self.device = device
|
| 41 |
self.dtype = dtype
|
| 42 |
self.ohlc_seq_len = 300 # HARDCODED
|
|
|
|
|
|
|
| 43 |
self.max_seq_len = max_seq_len
|
| 44 |
|
| 45 |
def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
|
|
@@ -86,10 +89,16 @@ class MemecoinCollator:
|
|
| 86 |
if not chart_events:
|
| 87 |
return {
|
| 88 |
'price_tensor': torch.empty(0, 2, self.ohlc_seq_len, device=self.device, dtype=self.dtype),
|
| 89 |
-
'interval_ids': torch.empty(0, device=self.device, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
| 90 |
}
|
| 91 |
ohlc_tensors = []
|
| 92 |
interval_ids_list = []
|
|
|
|
|
|
|
|
|
|
| 93 |
seq_len = self.ohlc_seq_len
|
| 94 |
unknown_id = vocab.INTERVAL_TO_ID.get("Unknown", 0)
|
| 95 |
for segment_data in chart_events:
|
|
@@ -105,9 +114,36 @@ class MemecoinCollator:
|
|
| 105 |
ohlc_tensors.append(torch.stack([o, c]))
|
| 106 |
interval_id = vocab.INTERVAL_TO_ID.get(interval_str, unknown_id)
|
| 107 |
interval_ids_list.append(interval_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
return {
|
| 109 |
'price_tensor': torch.stack(ohlc_tensors).to(self.device),
|
| 110 |
-
'interval_ids': torch.tensor(interval_ids_list, device=self.device, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
| 111 |
}
|
| 112 |
|
| 113 |
def _collate_graph_links(self,
|
|
@@ -677,6 +713,9 @@ class MemecoinCollator:
|
|
| 677 |
'wallet_encoder_inputs': wallet_encoder_inputs, # ADDED BACK
|
| 678 |
'ohlc_price_tensors': ohlc_inputs_dict['price_tensor'],
|
| 679 |
'ohlc_interval_ids': ohlc_inputs_dict['interval_ids'],
|
|
|
|
|
|
|
|
|
|
| 680 |
'graph_updater_links': graph_updater_links,
|
| 681 |
'wallet_addr_to_batch_idx': wallet_addr_to_batch_idx, # NEW: Pass the mapping
|
| 682 |
|
|
|
|
| 10 |
|
| 11 |
import models.vocabulary as vocab
|
| 12 |
from data.data_loader import EmbeddingPooler
|
| 13 |
+
from data.quant_ohlc_feature_schema import FEATURE_VERSION, FEATURE_VERSION_ID, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
|
| 14 |
|
| 15 |
NATIVE_MINT = "So11111111111111111111111111111111111111112"
|
| 16 |
QUOTE_MINTS = {
|
|
|
|
| 41 |
self.device = device
|
| 42 |
self.dtype = dtype
|
| 43 |
self.ohlc_seq_len = 300 # HARDCODED
|
| 44 |
+
self.quant_ohlc_tokens = TOKENS_PER_SEGMENT
|
| 45 |
+
self.quant_ohlc_num_features = NUM_QUANT_OHLC_FEATURES
|
| 46 |
self.max_seq_len = max_seq_len
|
| 47 |
|
| 48 |
def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
|
|
|
|
| 89 |
if not chart_events:
|
| 90 |
return {
|
| 91 |
'price_tensor': torch.empty(0, 2, self.ohlc_seq_len, device=self.device, dtype=self.dtype),
|
| 92 |
+
'interval_ids': torch.empty(0, device=self.device, dtype=torch.long),
|
| 93 |
+
'quant_feature_tensors': torch.empty(0, self.quant_ohlc_tokens, self.quant_ohlc_num_features, device=self.device, dtype=self.dtype),
|
| 94 |
+
'quant_feature_mask': torch.empty(0, self.quant_ohlc_tokens, device=self.device, dtype=self.dtype),
|
| 95 |
+
'quant_feature_version_ids': torch.empty(0, device=self.device, dtype=torch.long),
|
| 96 |
}
|
| 97 |
ohlc_tensors = []
|
| 98 |
interval_ids_list = []
|
| 99 |
+
quant_feature_tensors = []
|
| 100 |
+
quant_feature_masks = []
|
| 101 |
+
quant_feature_version_ids = []
|
| 102 |
seq_len = self.ohlc_seq_len
|
| 103 |
unknown_id = vocab.INTERVAL_TO_ID.get("Unknown", 0)
|
| 104 |
for segment_data in chart_events:
|
|
|
|
| 114 |
ohlc_tensors.append(torch.stack([o, c]))
|
| 115 |
interval_id = vocab.INTERVAL_TO_ID.get(interval_str, unknown_id)
|
| 116 |
interval_ids_list.append(interval_id)
|
| 117 |
+
quant_payload = segment_data.get('quant_ohlc_features')
|
| 118 |
+
if quant_payload is None:
|
| 119 |
+
raise RuntimeError("Chart_Segment missing quant_ohlc_features. Rebuild cache with quantitative chart features.")
|
| 120 |
+
if not isinstance(quant_payload, list):
|
| 121 |
+
raise RuntimeError("Chart_Segment quant_ohlc_features must be a list.")
|
| 122 |
+
feature_rows = []
|
| 123 |
+
feature_mask = []
|
| 124 |
+
for token_idx in range(self.quant_ohlc_tokens):
|
| 125 |
+
if token_idx < len(quant_payload):
|
| 126 |
+
payload = quant_payload[token_idx]
|
| 127 |
+
vec = payload.get('feature_vector')
|
| 128 |
+
if not isinstance(vec, list) or len(vec) != self.quant_ohlc_num_features:
|
| 129 |
+
raise RuntimeError(
|
| 130 |
+
f"Chart_Segment quant feature vector must have length {self.quant_ohlc_num_features}."
|
| 131 |
+
)
|
| 132 |
+
feature_rows.append(vec)
|
| 133 |
+
feature_mask.append(1.0)
|
| 134 |
+
else:
|
| 135 |
+
feature_rows.append([0.0] * self.quant_ohlc_num_features)
|
| 136 |
+
feature_mask.append(0.0)
|
| 137 |
+
quant_feature_tensors.append(torch.tensor(feature_rows, device=self.device, dtype=self.dtype))
|
| 138 |
+
quant_feature_masks.append(torch.tensor(feature_mask, device=self.device, dtype=self.dtype))
|
| 139 |
+
version = segment_data.get('quant_feature_version', FEATURE_VERSION)
|
| 140 |
+
quant_feature_version_ids.append(FEATURE_VERSION_ID if version == FEATURE_VERSION else 0)
|
| 141 |
return {
|
| 142 |
'price_tensor': torch.stack(ohlc_tensors).to(self.device),
|
| 143 |
+
'interval_ids': torch.tensor(interval_ids_list, device=self.device, dtype=torch.long),
|
| 144 |
+
'quant_feature_tensors': torch.stack(quant_feature_tensors).to(self.device),
|
| 145 |
+
'quant_feature_mask': torch.stack(quant_feature_masks).to(self.device),
|
| 146 |
+
'quant_feature_version_ids': torch.tensor(quant_feature_version_ids, device=self.device, dtype=torch.long),
|
| 147 |
}
|
| 148 |
|
| 149 |
def _collate_graph_links(self,
|
|
|
|
| 713 |
'wallet_encoder_inputs': wallet_encoder_inputs, # ADDED BACK
|
| 714 |
'ohlc_price_tensors': ohlc_inputs_dict['price_tensor'],
|
| 715 |
'ohlc_interval_ids': ohlc_inputs_dict['interval_ids'],
|
| 716 |
+
'quant_ohlc_feature_tensors': ohlc_inputs_dict['quant_feature_tensors'],
|
| 717 |
+
'quant_ohlc_feature_mask': ohlc_inputs_dict['quant_feature_mask'],
|
| 718 |
+
'quant_ohlc_feature_version_ids': ohlc_inputs_dict['quant_feature_version_ids'],
|
| 719 |
'graph_updater_links': graph_updater_links,
|
| 720 |
'wallet_addr_to_batch_idx': wallet_addr_to_batch_idx, # NEW: Pass the mapping
|
| 721 |
|
data/data_loader.py
CHANGED
|
@@ -18,6 +18,19 @@ import models.vocabulary as vocab
|
|
| 18 |
from models.multi_modal_processor import MultiModalEncoder
|
| 19 |
from data.data_fetcher import DataFetcher # NEW: Import the DataFetcher
|
| 20 |
from data.context_targets import derive_movement_targets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
from requests.adapters import HTTPAdapter
|
| 22 |
from urllib3.util.retry import Retry
|
| 23 |
|
|
@@ -1278,14 +1291,186 @@ class OracleDataset(Dataset):
|
|
| 1278 |
|
| 1279 |
return full_ohlc
|
| 1280 |
|
| 1281 |
-
def
|
| 1282 |
-
|
| 1283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1284 |
|
| 1285 |
-
|
| 1286 |
-
- CONTEXT MODE: Loads pre-computed training context directly (fully offline)
|
| 1287 |
|
| 1288 |
-
|
|
|
|
|
|
|
| 1289 |
"""
|
| 1290 |
import time as _time
|
| 1291 |
_timings = {}
|
|
@@ -1309,23 +1494,17 @@ class OracleDataset(Dataset):
|
|
| 1309 |
if not cached_data:
|
| 1310 |
raise RuntimeError(f"No data loaded for index {idx}")
|
| 1311 |
|
| 1312 |
-
|
| 1313 |
-
|
| 1314 |
-
|
| 1315 |
-
|
| 1316 |
-
|
| 1317 |
-
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
-
'wallets' in cached_data and
|
| 1321 |
-
'labels' in cached_data and
|
| 1322 |
-
'labels_mask' in cached_data
|
| 1323 |
-
)
|
| 1324 |
-
cache_mode = 'context' if has_context_shape else 'raw'
|
| 1325 |
|
| 1326 |
-
if
|
| 1327 |
-
#
|
| 1328 |
-
# This is fully deterministic - no runtime sampling or processing
|
| 1329 |
_timings['total'] = _time.perf_counter() - _total_start
|
| 1330 |
|
| 1331 |
if 'movement_class_targets' not in cached_data and 'labels' in cached_data and 'labels_mask' in cached_data:
|
|
@@ -1346,318 +1525,16 @@ class OracleDataset(Dataset):
|
|
| 1346 |
)
|
| 1347 |
|
| 1348 |
if idx % 100 == 0:
|
| 1349 |
-
print(f"[Sample {idx}]
|
| 1350 |
f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}")
|
| 1351 |
|
| 1352 |
return cached_data
|
| 1353 |
|
| 1354 |
-
|
| 1355 |
-
|
| 1356 |
-
|
| 1357 |
-
required_keys = [
|
| 1358 |
-
"mint_timestamp",
|
| 1359 |
-
"max_limit_time",
|
| 1360 |
-
"token_address",
|
| 1361 |
-
"creator_address",
|
| 1362 |
-
"trades",
|
| 1363 |
-
"transfers",
|
| 1364 |
-
"pool_creations",
|
| 1365 |
-
"liquidity_changes",
|
| 1366 |
-
"fee_collections",
|
| 1367 |
-
"burns",
|
| 1368 |
-
"supply_locks",
|
| 1369 |
-
"migrations",
|
| 1370 |
-
"quality_score"
|
| 1371 |
-
]
|
| 1372 |
-
missing_keys = [key for key in required_keys if key not in raw_data]
|
| 1373 |
-
if missing_keys:
|
| 1374 |
-
raise RuntimeError(
|
| 1375 |
-
f"Cached sample missing raw fields ({missing_keys}). Rebuild cache with raw caching enabled."
|
| 1376 |
-
)
|
| 1377 |
-
|
| 1378 |
-
# --- CHECK: Determine if we have new-style complete cache ---
|
| 1379 |
-
has_complete_cache = 'cached_wallet_data' in raw_data and 'cached_graph_data' in raw_data
|
| 1380 |
-
|
| 1381 |
-
# --- TIMING: T_cutoff sampling prep ---
|
| 1382 |
-
_t0 = _time.perf_counter()
|
| 1383 |
-
|
| 1384 |
-
def _timestamp_to_order_value(ts_value: Any) -> float:
|
| 1385 |
-
if isinstance(ts_value, datetime.datetime):
|
| 1386 |
-
if ts_value.tzinfo is None:
|
| 1387 |
-
ts_value = ts_value.replace(tzinfo=datetime.timezone.utc)
|
| 1388 |
-
return ts_value.timestamp()
|
| 1389 |
-
try:
|
| 1390 |
-
return float(ts_value)
|
| 1391 |
-
except (TypeError, ValueError):
|
| 1392 |
-
return 0.0
|
| 1393 |
-
|
| 1394 |
-
# --- DYNAMIC SAMPLING LOGIC ---
|
| 1395 |
-
mint_timestamp = raw_data['mint_timestamp']
|
| 1396 |
-
if isinstance(mint_timestamp, datetime.datetime) and mint_timestamp.tzinfo is None:
|
| 1397 |
-
mint_timestamp = mint_timestamp.replace(tzinfo=datetime.timezone.utc)
|
| 1398 |
-
|
| 1399 |
-
min_window = 30 # seconds
|
| 1400 |
-
horizons = sorted(self.horizons_seconds)
|
| 1401 |
-
first_horizon = horizons[0] if horizons else 60
|
| 1402 |
-
min_label = max(60, first_horizon)
|
| 1403 |
-
preferred_horizon = horizons[1] if len(horizons) > 1 else min_label
|
| 1404 |
-
|
| 1405 |
-
mint_ts_value = _timestamp_to_order_value(mint_timestamp)
|
| 1406 |
-
|
| 1407 |
-
# ============================================================================
|
| 1408 |
-
# T_CUTOFF SAMPLING: Index-based with Successful Trade Guarantee
|
| 1409 |
-
# ============================================================================
|
| 1410 |
-
# 1. Use ALL trades (sorted by timestamp) for context
|
| 1411 |
-
# 2. Find indices of SUCCESSFUL trades (needed for label computation)
|
| 1412 |
-
# 3. Sample interval: [min_context_trades-1, last_successful_idx - 1]
|
| 1413 |
-
# 4. This guarantees: N trades for context, 1+ successful trade for labels
|
| 1414 |
-
# ============================================================================
|
| 1415 |
-
|
| 1416 |
-
all_trades_raw = raw_data.get('trades', [])
|
| 1417 |
-
if not all_trades_raw:
|
| 1418 |
-
return None
|
| 1419 |
-
|
| 1420 |
-
# Sort ALL trades by timestamp
|
| 1421 |
-
all_trades_sorted = sorted(
|
| 1422 |
-
[t for t in all_trades_raw if t.get('timestamp') is not None],
|
| 1423 |
-
key=lambda t: _timestamp_to_order_value(t.get('timestamp'))
|
| 1424 |
)
|
| 1425 |
|
| 1426 |
-
min_context_trades = self.min_trades
|
| 1427 |
-
if len(all_trades_sorted) < (min_context_trades + 1): # context + 1 trade after cutoff
|
| 1428 |
-
return None
|
| 1429 |
-
|
| 1430 |
-
# Find indices of SUCCESSFUL trades (valid for label computation)
|
| 1431 |
-
successful_indices = [
|
| 1432 |
-
i for i, t in enumerate(all_trades_sorted)
|
| 1433 |
-
if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0
|
| 1434 |
-
]
|
| 1435 |
-
|
| 1436 |
-
if len(successful_indices) < 2: # Need at least 2 successful: anchor + future
|
| 1437 |
-
return None
|
| 1438 |
-
|
| 1439 |
-
max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
|
| 1440 |
-
# Define sampling interval
|
| 1441 |
-
min_idx = min_context_trades - 1 # At least N trades for context
|
| 1442 |
-
max_idx = len(all_trades_sorted) - 2 # Need at least 1 trade after cutoff
|
| 1443 |
-
|
| 1444 |
-
if max_idx < min_idx:
|
| 1445 |
-
return None
|
| 1446 |
-
|
| 1447 |
-
# Precompute last successful index <= i and next successful index > i
|
| 1448 |
-
last_successful_before = [-1] * len(all_trades_sorted)
|
| 1449 |
-
last_seen = -1
|
| 1450 |
-
succ_set = set(successful_indices)
|
| 1451 |
-
for i in range(len(all_trades_sorted)):
|
| 1452 |
-
if i in succ_set:
|
| 1453 |
-
last_seen = i
|
| 1454 |
-
last_successful_before[i] = last_seen
|
| 1455 |
-
|
| 1456 |
-
next_successful_after = [-1] * len(all_trades_sorted)
|
| 1457 |
-
next_seen = -1
|
| 1458 |
-
for i in range(len(all_trades_sorted) - 1, -1, -1):
|
| 1459 |
-
if i in succ_set:
|
| 1460 |
-
next_seen = i
|
| 1461 |
-
next_successful_after[i] = next_seen
|
| 1462 |
-
|
| 1463 |
-
# Build eligible indices that guarantee:
|
| 1464 |
-
# 1) anchor successful trade at or before cutoff
|
| 1465 |
-
# 2) successful trade within max_horizon_seconds after cutoff
|
| 1466 |
-
eligible_indices = []
|
| 1467 |
-
for i in range(min_idx, max_idx + 1):
|
| 1468 |
-
anchor_idx = last_successful_before[i]
|
| 1469 |
-
next_idx = next_successful_after[i + 1] if i + 1 < len(all_trades_sorted) else -1
|
| 1470 |
-
if anchor_idx < 0 or next_idx < 0:
|
| 1471 |
-
continue
|
| 1472 |
-
cutoff_ts = _timestamp_to_order_value(all_trades_sorted[i].get('timestamp'))
|
| 1473 |
-
next_ts = _timestamp_to_order_value(all_trades_sorted[next_idx].get('timestamp'))
|
| 1474 |
-
if next_ts <= cutoff_ts + max_horizon_seconds:
|
| 1475 |
-
eligible_indices.append(i)
|
| 1476 |
-
|
| 1477 |
-
if not eligible_indices:
|
| 1478 |
-
return None
|
| 1479 |
-
|
| 1480 |
-
# Sample random index within eligible interval
|
| 1481 |
-
sample_idx = random.choice(eligible_indices)
|
| 1482 |
-
|
| 1483 |
-
# T_cutoff = timestamp of the sampled trade
|
| 1484 |
-
sample_trade = all_trades_sorted[sample_idx]
|
| 1485 |
-
sample_offset_ts = _timestamp_to_order_value(sample_trade.get('timestamp'))
|
| 1486 |
-
T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc)
|
| 1487 |
-
_timings['t_cutoff_sampling'] = _time.perf_counter() - _t0
|
| 1488 |
-
|
| 1489 |
-
# --- TIMING: Wallet collection ---
|
| 1490 |
-
_t0 = _time.perf_counter()
|
| 1491 |
-
token_address = raw_data['token_address']
|
| 1492 |
-
creator_address = raw_data['creator_address']
|
| 1493 |
-
cutoff_ts = _timestamp_to_order_value(T_cutoff)
|
| 1494 |
-
|
| 1495 |
-
def _add_wallet(addr: Optional[str], wallet_set: set):
|
| 1496 |
-
if addr:
|
| 1497 |
-
wallet_set.add(addr)
|
| 1498 |
-
|
| 1499 |
-
wallets_to_fetch = set()
|
| 1500 |
-
_add_wallet(creator_address, wallets_to_fetch)
|
| 1501 |
-
|
| 1502 |
-
for trade in raw_data.get('trades', []):
|
| 1503 |
-
if _timestamp_to_order_value(trade.get('timestamp')) <= cutoff_ts:
|
| 1504 |
-
_add_wallet(trade.get('maker'), wallets_to_fetch)
|
| 1505 |
-
|
| 1506 |
-
for transfer in raw_data.get('transfers', []):
|
| 1507 |
-
if _timestamp_to_order_value(transfer.get('timestamp')) <= cutoff_ts:
|
| 1508 |
-
_add_wallet(transfer.get('source'), wallets_to_fetch)
|
| 1509 |
-
_add_wallet(transfer.get('destination'), wallets_to_fetch)
|
| 1510 |
-
|
| 1511 |
-
for pool in raw_data.get('pool_creations', []):
|
| 1512 |
-
if _timestamp_to_order_value(pool.get('timestamp')) <= cutoff_ts:
|
| 1513 |
-
_add_wallet(pool.get('creator_address'), wallets_to_fetch)
|
| 1514 |
-
|
| 1515 |
-
for liq in raw_data.get('liquidity_changes', []):
|
| 1516 |
-
if _timestamp_to_order_value(liq.get('timestamp')) <= cutoff_ts:
|
| 1517 |
-
_add_wallet(liq.get('lp_provider'), wallets_to_fetch)
|
| 1518 |
-
|
| 1519 |
-
# Offline Holder Lookup using raw_data['holder_snapshots_list']
|
| 1520 |
-
# We need the snapshot corresponding to T_cutoff.
|
| 1521 |
-
# Intervals are every 300s from mint_ts.
|
| 1522 |
-
# idx = (T_cutoff - mint) // 300
|
| 1523 |
-
elapsed = (T_cutoff - mint_timestamp).total_seconds()
|
| 1524 |
-
snap_idx = int(elapsed // 300)
|
| 1525 |
-
holder_records = []
|
| 1526 |
-
cached_holders_list = raw_data.get('holder_snapshots_list')
|
| 1527 |
-
if not isinstance(cached_holders_list, list):
|
| 1528 |
-
raise RuntimeError("Invalid cache: holder_snapshots_list must be a list.")
|
| 1529 |
-
if not (0 <= snap_idx < len(cached_holders_list)):
|
| 1530 |
-
raise RuntimeError(
|
| 1531 |
-
f"Invalid cache: holder_snapshots_list index out of range (snap_idx={snap_idx}, len={len(cached_holders_list)})."
|
| 1532 |
-
)
|
| 1533 |
-
snapshot_data = cached_holders_list[snap_idx]
|
| 1534 |
-
if not isinstance(snapshot_data, dict) or not isinstance(snapshot_data.get('holders'), list):
|
| 1535 |
-
raise RuntimeError("Invalid cache: holder snapshot entry must be a dict with list field 'holders'.")
|
| 1536 |
-
holder_records = snapshot_data['holders']
|
| 1537 |
-
for holder in holder_records:
|
| 1538 |
-
if not isinstance(holder, dict) or 'wallet_address' not in holder or 'current_balance' not in holder:
|
| 1539 |
-
raise RuntimeError("Invalid cache: each holder record must include wallet_address and current_balance.")
|
| 1540 |
-
_add_wallet(holder['wallet_address'], wallets_to_fetch)
|
| 1541 |
-
_timings['wallet_collection'] = _time.perf_counter() - _t0
|
| 1542 |
-
_timings['num_wallets'] = len(wallets_to_fetch)
|
| 1543 |
-
|
| 1544 |
-
pooler = EmbeddingPooler()
|
| 1545 |
-
|
| 1546 |
-
# --- TIMING: Token data (OFFLINE - uses cached image bytes) ---
|
| 1547 |
-
_t0 = _time.perf_counter()
|
| 1548 |
-
|
| 1549 |
-
# Build minimal main token metadata from cache (no HTTP calls)
|
| 1550 |
-
offline_token_data = {token_address: self._build_main_token_seed(token_address, raw_data)}
|
| 1551 |
-
|
| 1552 |
-
# If we have cached image bytes, convert to PIL Image for the pooler
|
| 1553 |
-
cached_image_bytes = raw_data.get('cached_image_bytes')
|
| 1554 |
-
if cached_image_bytes:
|
| 1555 |
-
try:
|
| 1556 |
-
cached_image = Image.open(BytesIO(cached_image_bytes))
|
| 1557 |
-
offline_token_data[token_address]['_cached_image_pil'] = cached_image
|
| 1558 |
-
except Exception as e:
|
| 1559 |
-
pass # Image decoding failed, will use None
|
| 1560 |
-
|
| 1561 |
-
main_token_data = self._process_token_data_offline(
|
| 1562 |
-
[token_address], pooler, T_cutoff, token_data=offline_token_data
|
| 1563 |
-
)
|
| 1564 |
-
_timings['token_data_and_images'] = _time.perf_counter() - _t0
|
| 1565 |
-
|
| 1566 |
-
if not main_token_data:
|
| 1567 |
-
return None
|
| 1568 |
-
|
| 1569 |
-
# --- TIMING: Wallet data (OFFLINE - uses pre-cached profiles/socials/holdings) ---
|
| 1570 |
-
_t0 = _time.perf_counter()
|
| 1571 |
-
|
| 1572 |
-
if has_complete_cache:
|
| 1573 |
-
# Use new complete cache format
|
| 1574 |
-
cached_wallet_bundle = raw_data.get('cached_wallet_data', {})
|
| 1575 |
-
offline_profiles = cached_wallet_bundle.get('profiles', {})
|
| 1576 |
-
offline_socials = cached_wallet_bundle.get('socials', {})
|
| 1577 |
-
offline_holdings = cached_wallet_bundle.get('holdings', {})
|
| 1578 |
-
else:
|
| 1579 |
-
# Fallback to old partial cache format
|
| 1580 |
-
cached_social_bundle = raw_data.get('socials', {})
|
| 1581 |
-
offline_profiles = cached_social_bundle.get('profiles', {})
|
| 1582 |
-
offline_socials = cached_social_bundle.get('socials', {})
|
| 1583 |
-
offline_holdings = {}
|
| 1584 |
-
|
| 1585 |
-
wallet_data, all_token_data = self._process_wallet_data(
|
| 1586 |
-
list(wallets_to_fetch),
|
| 1587 |
-
main_token_data.copy(),
|
| 1588 |
-
pooler,
|
| 1589 |
-
T_cutoff,
|
| 1590 |
-
profiles_override=offline_profiles,
|
| 1591 |
-
socials_override=offline_socials,
|
| 1592 |
-
holdings_override=offline_holdings
|
| 1593 |
-
)
|
| 1594 |
-
_timings['wallet_data'] = _time.perf_counter() - _t0
|
| 1595 |
-
_timings['num_tokens_in_holdings'] = len(all_token_data) - 1 if all_token_data else 0
|
| 1596 |
-
|
| 1597 |
-
# --- TIMING: Graph links (OFFLINE - uses pre-cached graph data) ---
|
| 1598 |
-
_t0 = _time.perf_counter()
|
| 1599 |
-
|
| 1600 |
-
if has_complete_cache:
|
| 1601 |
-
# Use new complete cache format
|
| 1602 |
-
cached_graph_bundle = raw_data.get('cached_graph_data', {})
|
| 1603 |
-
graph_entities = cached_graph_bundle.get('entities', {})
|
| 1604 |
-
graph_links = cached_graph_bundle.get('links', {})
|
| 1605 |
-
else:
|
| 1606 |
-
# No graph data in old cache format
|
| 1607 |
-
graph_entities = {}
|
| 1608 |
-
graph_links = {}
|
| 1609 |
-
|
| 1610 |
-
_timings['graph_links'] = _time.perf_counter() - _t0
|
| 1611 |
-
_timings['num_graph_entities'] = len(graph_entities)
|
| 1612 |
-
|
| 1613 |
-
# --- TIMING: Generate dataset item ---
|
| 1614 |
-
_t0 = _time.perf_counter()
|
| 1615 |
-
# Generate the item
|
| 1616 |
-
result = self._generate_dataset_item(
|
| 1617 |
-
token_address=token_address,
|
| 1618 |
-
t0=mint_timestamp,
|
| 1619 |
-
T_cutoff=T_cutoff,
|
| 1620 |
-
mint_event={ # Reconstruct simplified mint event
|
| 1621 |
-
'event_type': 'Mint',
|
| 1622 |
-
'timestamp': int(mint_timestamp.timestamp()),
|
| 1623 |
-
'relative_ts': 0,
|
| 1624 |
-
'wallet_address': creator_address,
|
| 1625 |
-
'token_address': token_address,
|
| 1626 |
-
'protocol_id': raw_data.get('protocol_id', 0)
|
| 1627 |
-
},
|
| 1628 |
-
trade_records=raw_data['trades'],
|
| 1629 |
-
transfer_records=raw_data['transfers'],
|
| 1630 |
-
pool_creation_records=raw_data['pool_creations'],
|
| 1631 |
-
liquidity_change_records=raw_data['liquidity_changes'],
|
| 1632 |
-
fee_collection_records=raw_data['fee_collections'],
|
| 1633 |
-
burn_records=raw_data['burns'],
|
| 1634 |
-
supply_lock_records=raw_data['supply_locks'],
|
| 1635 |
-
migration_records=raw_data['migrations'],
|
| 1636 |
-
wallet_data=wallet_data,
|
| 1637 |
-
all_token_data=all_token_data,
|
| 1638 |
-
graph_links=graph_links,
|
| 1639 |
-
graph_seed_entities=wallets_to_fetch,
|
| 1640 |
-
all_graph_entities=graph_entities,
|
| 1641 |
-
future_trades_for_labels=raw_data['trades'], # We utilize full trade history for labels!
|
| 1642 |
-
pooler=pooler,
|
| 1643 |
-
sample_idx=idx,
|
| 1644 |
-
cached_holders_list=raw_data.get('holder_snapshots_list'),
|
| 1645 |
-
cached_ohlc_1s=raw_data.get('ohlc_1s'),
|
| 1646 |
-
quality_score=raw_data.get('quality_score')
|
| 1647 |
-
)
|
| 1648 |
-
_timings['generate_item'] = _time.perf_counter() - _t0
|
| 1649 |
-
|
| 1650 |
-
# --- TIMING: Total and summary ---
|
| 1651 |
-
_timings['total'] = _time.perf_counter() - _total_start
|
| 1652 |
-
|
| 1653 |
-
# Only print timing summary occasionally to reduce log spam
|
| 1654 |
-
if idx % 100 == 0:
|
| 1655 |
-
print(f"[Sample {idx}] OFFLINE mode | cache_load: {_timings['cache_load']*1000:.1f}ms | "
|
| 1656 |
-
f"total: {_timings['total']*1000:.1f}ms | wallets: {_timings['num_wallets']} | "
|
| 1657 |
-
f"graph_entities: {_timings['num_graph_entities']}")
|
| 1658 |
-
|
| 1659 |
-
return result
|
| 1660 |
-
|
| 1661 |
def _process_token_data_offline(self, token_addresses: List[str], pooler: EmbeddingPooler,
|
| 1662 |
T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Dict[str, Dict[str, Any]]:
|
| 1663 |
"""
|
|
@@ -2301,7 +2178,9 @@ class OracleDataset(Dataset):
|
|
| 2301 |
'relative_ts': int(last_ts) - int(t0_timestamp),
|
| 2302 |
'opens': self._normalize_price_series(opens_raw),
|
| 2303 |
'closes': self._normalize_price_series(closes_raw),
|
| 2304 |
-
'i': interval_label
|
|
|
|
|
|
|
| 2305 |
}
|
| 2306 |
emitted_events.append(chart_event)
|
| 2307 |
return emitted_events
|
|
@@ -2601,22 +2480,22 @@ class OracleDataset(Dataset):
|
|
| 2601 |
|
| 2602 |
if not all_trades:
|
| 2603 |
# No valid trades for label computation
|
| 2604 |
-
|
| 2605 |
-
|
| 2606 |
-
|
| 2607 |
-
|
| 2608 |
-
|
| 2609 |
return {
|
| 2610 |
'event_sequence': event_sequence,
|
| 2611 |
'wallets': wallet_data,
|
| 2612 |
'tokens': all_token_data,
|
| 2613 |
'graph_links': graph_links,
|
| 2614 |
'embedding_pooler': pooler,
|
|
|
|
|
|
|
| 2615 |
'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
|
| 2616 |
'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
|
| 2617 |
'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
|
| 2618 |
-
'movement_class_targets': torch.tensor(movement_targets['movement_class_targets'], dtype=torch.long),
|
| 2619 |
-
'movement_class_mask': torch.tensor(movement_targets['movement_class_mask'], dtype=torch.long),
|
| 2620 |
}
|
| 2621 |
|
| 2622 |
# Ensure sorted
|
|
@@ -2696,12 +2575,11 @@ class OracleDataset(Dataset):
|
|
| 2696 |
|
| 2697 |
# DEBUG: Mask summaries removed after validation
|
| 2698 |
|
| 2699 |
-
|
| 2700 |
-
|
| 2701 |
-
|
| 2702 |
-
|
| 2703 |
-
|
| 2704 |
-
|
| 2705 |
return {
|
| 2706 |
'sample_idx': sample_idx if sample_idx is not None else -1, # Debug trace
|
| 2707 |
'token_address': token_address, # For debugging
|
|
@@ -2711,11 +2589,11 @@ class OracleDataset(Dataset):
|
|
| 2711 |
'tokens': all_token_data,
|
| 2712 |
'graph_links': graph_links,
|
| 2713 |
'embedding_pooler': pooler,
|
|
|
|
|
|
|
| 2714 |
'labels': torch.tensor(label_values, dtype=torch.float32),
|
| 2715 |
'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
|
| 2716 |
'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
|
| 2717 |
-
'movement_class_targets': torch.tensor(movement_targets['movement_class_targets'], dtype=torch.long),
|
| 2718 |
-
'movement_class_mask': torch.tensor(movement_targets['movement_class_mask'], dtype=torch.long),
|
| 2719 |
}
|
| 2720 |
|
| 2721 |
def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None:
|
|
|
|
| 18 |
from models.multi_modal_processor import MultiModalEncoder
|
| 19 |
from data.data_fetcher import DataFetcher # NEW: Import the DataFetcher
|
| 20 |
from data.context_targets import derive_movement_targets
|
| 21 |
+
from data.quant_ohlc_feature_schema import (
|
| 22 |
+
FEATURE_VERSION,
|
| 23 |
+
FEATURE_VERSION_ID,
|
| 24 |
+
LOOKBACK_SECONDS,
|
| 25 |
+
TOKENS_PER_SEGMENT,
|
| 26 |
+
WINDOW_SECONDS,
|
| 27 |
+
empty_feature_dict,
|
| 28 |
+
feature_dict_to_vector,
|
| 29 |
+
)
|
| 30 |
+
from signals.patterns import compute_pattern_features
|
| 31 |
+
from signals.rolling_quant import compute_rolling_quant_features
|
| 32 |
+
from signals.support_resistance import compute_support_resistance_features
|
| 33 |
+
from signals.trendlines import compute_trendline_features
|
| 34 |
from requests.adapters import HTTPAdapter
|
| 35 |
from urllib3.util.retry import Retry
|
| 36 |
|
|
|
|
| 1291 |
|
| 1292 |
return full_ohlc
|
| 1293 |
|
| 1294 |
+
def _compute_quant_rolling_features(
|
| 1295 |
+
self,
|
| 1296 |
+
closes: List[float],
|
| 1297 |
+
end_idx: int,
|
| 1298 |
+
) -> Dict[str, float]:
|
| 1299 |
+
return compute_rolling_quant_features(closes, end_idx)
|
| 1300 |
+
|
| 1301 |
+
def _compute_support_resistance_features(
|
| 1302 |
+
self,
|
| 1303 |
+
closes: List[float],
|
| 1304 |
+
highs: List[float],
|
| 1305 |
+
lows: List[float],
|
| 1306 |
+
end_idx: int,
|
| 1307 |
+
window_start: int,
|
| 1308 |
+
window_end: int,
|
| 1309 |
+
timestamps: List[int],
|
| 1310 |
+
) -> Dict[str, float]:
|
| 1311 |
+
return compute_support_resistance_features(
|
| 1312 |
+
closes=closes,
|
| 1313 |
+
highs=highs,
|
| 1314 |
+
lows=lows,
|
| 1315 |
+
end_idx=end_idx,
|
| 1316 |
+
window_start=window_start,
|
| 1317 |
+
window_end=window_end,
|
| 1318 |
+
timestamps=timestamps,
|
| 1319 |
+
)
|
| 1320 |
+
|
| 1321 |
+
def _compute_trendline_features(
|
| 1322 |
+
self,
|
| 1323 |
+
closes: List[float],
|
| 1324 |
+
highs: List[float],
|
| 1325 |
+
lows: List[float],
|
| 1326 |
+
end_idx: int,
|
| 1327 |
+
) -> Dict[str, float]:
|
| 1328 |
+
return compute_trendline_features(closes, highs, lows, end_idx)
|
| 1329 |
+
|
| 1330 |
+
def _compute_optional_pattern_flags(
|
| 1331 |
+
self,
|
| 1332 |
+
closes: List[float],
|
| 1333 |
+
highs: List[float],
|
| 1334 |
+
lows: List[float],
|
| 1335 |
+
end_idx: int,
|
| 1336 |
+
) -> Dict[str, float]:
|
| 1337 |
+
return compute_pattern_features(closes, highs, lows, end_idx)
|
| 1338 |
+
|
| 1339 |
+
def _extract_quant_ohlc_features_for_segment(
|
| 1340 |
+
self,
|
| 1341 |
+
segment: List[tuple],
|
| 1342 |
+
interval_label: str,
|
| 1343 |
+
) -> List[Dict[str, Any]]:
|
| 1344 |
+
del interval_label
|
| 1345 |
+
if not segment:
|
| 1346 |
+
return []
|
| 1347 |
+
|
| 1348 |
+
timestamps = [int(row[0]) for row in segment]
|
| 1349 |
+
opens = [float(row[1]) for row in segment]
|
| 1350 |
+
closes = [float(row[2]) for row in segment]
|
| 1351 |
+
highs = [max(o, c) for o, c in zip(opens, closes)]
|
| 1352 |
+
lows = [min(o, c) for o, c in zip(opens, closes)]
|
| 1353 |
+
log_closes = np.log(np.clip(np.asarray(closes, dtype=np.float64), 1e-8, None))
|
| 1354 |
+
one_sec_returns = np.diff(log_closes)
|
| 1355 |
+
feature_windows: List[Dict[str, Any]] = []
|
| 1356 |
+
|
| 1357 |
+
for window_idx in range(TOKENS_PER_SEGMENT):
|
| 1358 |
+
window_start = window_idx * WINDOW_SECONDS
|
| 1359 |
+
if window_start >= len(segment):
|
| 1360 |
+
break
|
| 1361 |
+
window_end = min(len(segment), window_start + WINDOW_SECONDS)
|
| 1362 |
+
current_end_idx = window_end - 1
|
| 1363 |
+
window_returns = one_sec_returns[window_start:max(window_start, current_end_idx)]
|
| 1364 |
+
window_closes = closes[window_start:window_end]
|
| 1365 |
+
window_highs = highs[window_start:window_end]
|
| 1366 |
+
window_lows = lows[window_start:window_end]
|
| 1367 |
+
features = empty_feature_dict()
|
| 1368 |
+
|
| 1369 |
+
if window_closes:
|
| 1370 |
+
window_close_arr = np.asarray(window_closes, dtype=np.float64)
|
| 1371 |
+
window_return_sum = float(np.sum(window_returns)) if window_returns.size > 0 else 0.0
|
| 1372 |
+
range_width = max(max(window_highs) - min(window_lows), 0.0)
|
| 1373 |
+
first_close = float(window_close_arr[0])
|
| 1374 |
+
last_close = float(window_close_arr[-1])
|
| 1375 |
+
accel_proxy = 0.0
|
| 1376 |
+
if window_returns.size >= 2:
|
| 1377 |
+
accel_proxy = float(window_returns[-1] - window_returns[0])
|
| 1378 |
+
features.update({
|
| 1379 |
+
"cum_log_return": window_return_sum,
|
| 1380 |
+
"mean_log_return_1s": float(np.mean(window_returns)) if window_returns.size > 0 else 0.0,
|
| 1381 |
+
"std_log_return_1s": float(np.std(window_returns)) if window_returns.size > 0 else 0.0,
|
| 1382 |
+
"max_up_1s": float(np.max(window_returns)) if window_returns.size > 0 else 0.0,
|
| 1383 |
+
"max_down_1s": float(np.min(window_returns)) if window_returns.size > 0 else 0.0,
|
| 1384 |
+
"realized_vol": float(np.sqrt(np.sum(np.square(window_returns)))) if window_returns.size > 0 else 0.0,
|
| 1385 |
+
"window_range_frac": range_width / max(abs(last_close), 1e-8),
|
| 1386 |
+
"close_to_close_slope": (last_close - first_close) / max(abs(first_close), 1e-8),
|
| 1387 |
+
"accel_proxy": accel_proxy,
|
| 1388 |
+
"frac_pos_1s": float(np.mean(window_returns > 0)) if window_returns.size > 0 else 0.0,
|
| 1389 |
+
"frac_neg_1s": float(np.mean(window_returns < 0)) if window_returns.size > 0 else 0.0,
|
| 1390 |
+
})
|
| 1391 |
+
|
| 1392 |
+
current_price = closes[current_end_idx]
|
| 1393 |
+
current_high = highs[current_end_idx]
|
| 1394 |
+
current_low = lows[current_end_idx]
|
| 1395 |
+
for lookback in LOOKBACK_SECONDS:
|
| 1396 |
+
prefix = f"lb_{lookback}s"
|
| 1397 |
+
lookback_start = max(0, current_end_idx - lookback + 1)
|
| 1398 |
+
hist_closes = closes[lookback_start: current_end_idx + 1]
|
| 1399 |
+
hist_highs = highs[lookback_start: current_end_idx + 1]
|
| 1400 |
+
hist_lows = lows[lookback_start: current_end_idx + 1]
|
| 1401 |
+
hist_range = max(max(hist_highs) - min(hist_lows), 1e-8)
|
| 1402 |
+
rolling_high = max(hist_highs)
|
| 1403 |
+
rolling_low = min(hist_lows)
|
| 1404 |
+
hist_returns = np.diff(np.log(np.clip(np.asarray(hist_closes, dtype=np.float64), 1e-8, None)))
|
| 1405 |
+
current_width = max(max(window_highs) - min(window_lows), 0.0)
|
| 1406 |
+
prev_hist_width = max(max(hist_highs[:-len(window_highs)]) - min(hist_lows[:-len(window_lows)]), 0.0) if len(hist_highs) > len(window_highs) else current_width
|
| 1407 |
+
prev_close = closes[current_end_idx - 1] if current_end_idx > 0 else current_price
|
| 1408 |
+
|
| 1409 |
+
features.update({
|
| 1410 |
+
f"{prefix}_dist_high": (rolling_high - current_price) / max(abs(current_price), 1e-8),
|
| 1411 |
+
f"{prefix}_dist_low": (current_price - rolling_low) / max(abs(current_price), 1e-8),
|
| 1412 |
+
f"{prefix}_drawdown_high": (current_price - rolling_high) / max(abs(rolling_high), 1e-8),
|
| 1413 |
+
f"{prefix}_rebound_low": (current_price - rolling_low) / max(abs(rolling_low), 1e-8),
|
| 1414 |
+
f"{prefix}_pos_in_range": (current_price - rolling_low) / hist_range,
|
| 1415 |
+
f"{prefix}_range_width": hist_range / max(abs(current_price), 1e-8),
|
| 1416 |
+
f"{prefix}_compression_ratio": current_width / max(prev_hist_width, 1e-8),
|
| 1417 |
+
f"{prefix}_breakout_high": 1.0 if current_high > rolling_high and prev_close <= rolling_high else 0.0,
|
| 1418 |
+
f"{prefix}_breakdown_low": 1.0 if current_low < rolling_low and prev_close >= rolling_low else 0.0,
|
| 1419 |
+
f"{prefix}_reclaim_breakdown": 1.0 if current_low < rolling_low and current_price >= rolling_low else 0.0,
|
| 1420 |
+
f"{prefix}_rejection_breakout": 1.0 if current_high > rolling_high and current_price <= rolling_high else 0.0,
|
| 1421 |
+
})
|
| 1422 |
+
|
| 1423 |
+
features.update(self._compute_support_resistance_features(
|
| 1424 |
+
closes=closes,
|
| 1425 |
+
highs=highs,
|
| 1426 |
+
lows=lows,
|
| 1427 |
+
end_idx=current_end_idx,
|
| 1428 |
+
window_start=window_start,
|
| 1429 |
+
window_end=window_end,
|
| 1430 |
+
timestamps=timestamps,
|
| 1431 |
+
))
|
| 1432 |
+
features.update(self._compute_trendline_features(
|
| 1433 |
+
closes=closes,
|
| 1434 |
+
highs=highs,
|
| 1435 |
+
lows=lows,
|
| 1436 |
+
end_idx=current_end_idx,
|
| 1437 |
+
))
|
| 1438 |
+
features.update(self._compute_quant_rolling_features(
|
| 1439 |
+
closes=closes,
|
| 1440 |
+
end_idx=current_end_idx,
|
| 1441 |
+
))
|
| 1442 |
+
features.update(self._compute_optional_pattern_flags(
|
| 1443 |
+
closes=closes,
|
| 1444 |
+
highs=highs,
|
| 1445 |
+
lows=lows,
|
| 1446 |
+
end_idx=current_end_idx,
|
| 1447 |
+
))
|
| 1448 |
+
|
| 1449 |
+
feature_windows.append({
|
| 1450 |
+
"start_ts": timestamps[window_start],
|
| 1451 |
+
"end_ts": timestamps[current_end_idx],
|
| 1452 |
+
"window_seconds": WINDOW_SECONDS,
|
| 1453 |
+
"feature_vector": feature_dict_to_vector(features),
|
| 1454 |
+
"feature_names_version": FEATURE_VERSION,
|
| 1455 |
+
"feature_version_id": FEATURE_VERSION_ID,
|
| 1456 |
+
"level_snapshot": {
|
| 1457 |
+
"support_distance": features.get("nearest_support_dist", 0.0),
|
| 1458 |
+
"resistance_distance": features.get("nearest_resistance_dist", 0.0),
|
| 1459 |
+
"support_strength": features.get("support_strength", 0.0),
|
| 1460 |
+
"resistance_strength": features.get("resistance_strength", 0.0),
|
| 1461 |
+
},
|
| 1462 |
+
"pattern_flags": {
|
| 1463 |
+
key.replace("pattern_", "").replace("_confidence", ""): features[key]
|
| 1464 |
+
for key in features.keys()
|
| 1465 |
+
if key.startswith("pattern_") and key.endswith("_confidence")
|
| 1466 |
+
},
|
| 1467 |
+
})
|
| 1468 |
|
| 1469 |
+
return feature_windows
|
|
|
|
| 1470 |
|
| 1471 |
+
def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
|
| 1472 |
+
"""
|
| 1473 |
+
Loads data from cache.
|
| 1474 |
"""
|
| 1475 |
import time as _time
|
| 1476 |
_timings = {}
|
|
|
|
| 1494 |
if not cached_data:
|
| 1495 |
raise RuntimeError(f"No data loaded for index {idx}")
|
| 1496 |
|
| 1497 |
+
has_context_shape = (
|
| 1498 |
+
isinstance(cached_data, dict) and
|
| 1499 |
+
'event_sequence' in cached_data and
|
| 1500 |
+
'tokens' in cached_data and
|
| 1501 |
+
'wallets' in cached_data and
|
| 1502 |
+
'labels' in cached_data and
|
| 1503 |
+
'labels_mask' in cached_data
|
| 1504 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1505 |
|
| 1506 |
+
if has_context_shape:
|
| 1507 |
+
# Return pre-computed training context directly.
|
|
|
|
| 1508 |
_timings['total'] = _time.perf_counter() - _total_start
|
| 1509 |
|
| 1510 |
if 'movement_class_targets' not in cached_data and 'labels' in cached_data and 'labels_mask' in cached_data:
|
|
|
|
| 1525 |
)
|
| 1526 |
|
| 1527 |
if idx % 100 == 0:
|
| 1528 |
+
print(f"[Sample {idx}] context cache | cache_load: {_timings['cache_load']*1000:.1f}ms | "
|
| 1529 |
f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}")
|
| 1530 |
|
| 1531 |
return cached_data
|
| 1532 |
|
| 1533 |
+
raise RuntimeError(
|
| 1534 |
+
f"Cached item at {filepath} is not a valid context cache. "
|
| 1535 |
+
"Rebuild the cache with scripts/cache_dataset.py."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1536 |
)
|
| 1537 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1538 |
def _process_token_data_offline(self, token_addresses: List[str], pooler: EmbeddingPooler,
|
| 1539 |
T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Dict[str, Dict[str, Any]]:
|
| 1540 |
"""
|
|
|
|
| 2178 |
'relative_ts': int(last_ts) - int(t0_timestamp),
|
| 2179 |
'opens': self._normalize_price_series(opens_raw),
|
| 2180 |
'closes': self._normalize_price_series(closes_raw),
|
| 2181 |
+
'i': interval_label,
|
| 2182 |
+
'quant_ohlc_features': self._extract_quant_ohlc_features_for_segment(segment, interval_label) if interval_label == "1s" else [],
|
| 2183 |
+
'quant_feature_version': FEATURE_VERSION,
|
| 2184 |
}
|
| 2185 |
emitted_events.append(chart_event)
|
| 2186 |
return emitted_events
|
|
|
|
| 2480 |
|
| 2481 |
if not all_trades:
|
| 2482 |
# No valid trades for label computation
|
| 2483 |
+
quant_payload = [
|
| 2484 |
+
event.get('quant_ohlc_features', [])
|
| 2485 |
+
for event in event_sequence
|
| 2486 |
+
if event.get('event_type') == 'Chart_Segment'
|
| 2487 |
+
]
|
| 2488 |
return {
|
| 2489 |
'event_sequence': event_sequence,
|
| 2490 |
'wallets': wallet_data,
|
| 2491 |
'tokens': all_token_data,
|
| 2492 |
'graph_links': graph_links,
|
| 2493 |
'embedding_pooler': pooler,
|
| 2494 |
+
'quant_ohlc_features': quant_payload,
|
| 2495 |
+
'quant_feature_version': FEATURE_VERSION,
|
| 2496 |
'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
|
| 2497 |
'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
|
| 2498 |
'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
|
|
|
|
|
|
|
| 2499 |
}
|
| 2500 |
|
| 2501 |
# Ensure sorted
|
|
|
|
| 2575 |
|
| 2576 |
# DEBUG: Mask summaries removed after validation
|
| 2577 |
|
| 2578 |
+
quant_payload = [
|
| 2579 |
+
event.get('quant_ohlc_features', [])
|
| 2580 |
+
for event in event_sequence
|
| 2581 |
+
if event.get('event_type') == 'Chart_Segment'
|
| 2582 |
+
]
|
|
|
|
| 2583 |
return {
|
| 2584 |
'sample_idx': sample_idx if sample_idx is not None else -1, # Debug trace
|
| 2585 |
'token_address': token_address, # For debugging
|
|
|
|
| 2589 |
'tokens': all_token_data,
|
| 2590 |
'graph_links': graph_links,
|
| 2591 |
'embedding_pooler': pooler,
|
| 2592 |
+
'quant_ohlc_features': quant_payload,
|
| 2593 |
+
'quant_feature_version': FEATURE_VERSION,
|
| 2594 |
'labels': torch.tensor(label_values, dtype=torch.float32),
|
| 2595 |
'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
|
| 2596 |
'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
|
|
|
|
|
|
|
| 2597 |
}
|
| 2598 |
|
| 2599 |
def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None:
|
data/quant_ohlc_feature_schema.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Dict, Iterable, List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
FEATURE_VERSION = "qohlc_v1"
|
| 6 |
+
FEATURE_VERSION_ID = 1
|
| 7 |
+
WINDOW_SECONDS = 5
|
| 8 |
+
SEGMENT_SECONDS = 300
|
| 9 |
+
TOKENS_PER_SEGMENT = SEGMENT_SECONDS // WINDOW_SECONDS
|
| 10 |
+
PATTERN_NAMES = [
|
| 11 |
+
"double_top",
|
| 12 |
+
"double_bottom",
|
| 13 |
+
"ascending_triangle",
|
| 14 |
+
"descending_triangle",
|
| 15 |
+
"head_shoulders",
|
| 16 |
+
"inverse_head_shoulders",
|
| 17 |
+
]
|
| 18 |
+
LOOKBACK_SECONDS = [15, 30, 60, 120]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
FEATURE_NAMES: List[str] = [
|
| 22 |
+
"cum_log_return",
|
| 23 |
+
"mean_log_return_1s",
|
| 24 |
+
"std_log_return_1s",
|
| 25 |
+
"max_up_1s",
|
| 26 |
+
"max_down_1s",
|
| 27 |
+
"realized_vol",
|
| 28 |
+
"window_range_frac",
|
| 29 |
+
"close_to_close_slope",
|
| 30 |
+
"accel_proxy",
|
| 31 |
+
"frac_pos_1s",
|
| 32 |
+
"frac_neg_1s",
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
for lookback in LOOKBACK_SECONDS:
|
| 36 |
+
prefix = f"lb_{lookback}s"
|
| 37 |
+
FEATURE_NAMES.extend([
|
| 38 |
+
f"{prefix}_dist_high",
|
| 39 |
+
f"{prefix}_dist_low",
|
| 40 |
+
f"{prefix}_drawdown_high",
|
| 41 |
+
f"{prefix}_rebound_low",
|
| 42 |
+
f"{prefix}_pos_in_range",
|
| 43 |
+
f"{prefix}_range_width",
|
| 44 |
+
f"{prefix}_compression_ratio",
|
| 45 |
+
f"{prefix}_breakout_high",
|
| 46 |
+
f"{prefix}_breakdown_low",
|
| 47 |
+
f"{prefix}_reclaim_breakdown",
|
| 48 |
+
f"{prefix}_rejection_breakout",
|
| 49 |
+
])
|
| 50 |
+
|
| 51 |
+
FEATURE_NAMES.extend([
|
| 52 |
+
"nearest_support_dist",
|
| 53 |
+
"nearest_resistance_dist",
|
| 54 |
+
"support_touch_count",
|
| 55 |
+
"resistance_touch_count",
|
| 56 |
+
"support_age_sec",
|
| 57 |
+
"resistance_age_sec",
|
| 58 |
+
"support_strength",
|
| 59 |
+
"resistance_strength",
|
| 60 |
+
"inside_support_zone",
|
| 61 |
+
"inside_resistance_zone",
|
| 62 |
+
"support_swept",
|
| 63 |
+
"resistance_swept",
|
| 64 |
+
"support_reclaim",
|
| 65 |
+
"resistance_reject",
|
| 66 |
+
"lower_trendline_slope",
|
| 67 |
+
"upper_trendline_slope",
|
| 68 |
+
"dist_to_lower_line",
|
| 69 |
+
"dist_to_upper_line",
|
| 70 |
+
"trend_channel_width",
|
| 71 |
+
"trend_convergence",
|
| 72 |
+
"trend_breakout_upper",
|
| 73 |
+
"trend_breakdown_lower",
|
| 74 |
+
"trend_reentry",
|
| 75 |
+
"ema_fast",
|
| 76 |
+
"ema_medium",
|
| 77 |
+
"sma_fast",
|
| 78 |
+
"sma_medium",
|
| 79 |
+
"price_minus_ema_fast",
|
| 80 |
+
"price_minus_ema_medium",
|
| 81 |
+
"ema_spread",
|
| 82 |
+
"price_zscore",
|
| 83 |
+
"mean_reversion_score",
|
| 84 |
+
"rolling_vol_zscore",
|
| 85 |
+
])
|
| 86 |
+
|
| 87 |
+
for pattern_name in PATTERN_NAMES:
|
| 88 |
+
FEATURE_NAMES.append(f"pattern_{pattern_name}_confidence")
|
| 89 |
+
|
| 90 |
+
FEATURE_NAMES.extend([
|
| 91 |
+
"sr_available",
|
| 92 |
+
"trendline_available",
|
| 93 |
+
"pattern_available",
|
| 94 |
+
])
|
| 95 |
+
|
| 96 |
+
FEATURE_INDEX = {name: idx for idx, name in enumerate(FEATURE_NAMES)}
|
| 97 |
+
NUM_QUANT_OHLC_FEATURES = len(FEATURE_NAMES)
|
| 98 |
+
|
| 99 |
+
FEATURE_GROUPS = OrderedDict([
|
| 100 |
+
("price_path", [
|
| 101 |
+
"cum_log_return",
|
| 102 |
+
"mean_log_return_1s",
|
| 103 |
+
"std_log_return_1s",
|
| 104 |
+
"max_up_1s",
|
| 105 |
+
"max_down_1s",
|
| 106 |
+
"realized_vol",
|
| 107 |
+
"window_range_frac",
|
| 108 |
+
"close_to_close_slope",
|
| 109 |
+
"accel_proxy",
|
| 110 |
+
"frac_pos_1s",
|
| 111 |
+
"frac_neg_1s",
|
| 112 |
+
]),
|
| 113 |
+
("relative_structure", [name for name in FEATURE_NAMES if name.startswith("lb_")]),
|
| 114 |
+
("levels_breaks", [
|
| 115 |
+
"nearest_support_dist",
|
| 116 |
+
"nearest_resistance_dist",
|
| 117 |
+
"support_touch_count",
|
| 118 |
+
"resistance_touch_count",
|
| 119 |
+
"support_age_sec",
|
| 120 |
+
"resistance_age_sec",
|
| 121 |
+
"support_strength",
|
| 122 |
+
"resistance_strength",
|
| 123 |
+
"inside_support_zone",
|
| 124 |
+
"inside_resistance_zone",
|
| 125 |
+
"support_swept",
|
| 126 |
+
"resistance_swept",
|
| 127 |
+
"support_reclaim",
|
| 128 |
+
"resistance_reject",
|
| 129 |
+
]),
|
| 130 |
+
("trendlines", [
|
| 131 |
+
"lower_trendline_slope",
|
| 132 |
+
"upper_trendline_slope",
|
| 133 |
+
"dist_to_lower_line",
|
| 134 |
+
"dist_to_upper_line",
|
| 135 |
+
"trend_channel_width",
|
| 136 |
+
"trend_convergence",
|
| 137 |
+
"trend_breakout_upper",
|
| 138 |
+
"trend_breakdown_lower",
|
| 139 |
+
"trend_reentry",
|
| 140 |
+
]),
|
| 141 |
+
("rolling_quant", [
|
| 142 |
+
"ema_fast",
|
| 143 |
+
"ema_medium",
|
| 144 |
+
"sma_fast",
|
| 145 |
+
"sma_medium",
|
| 146 |
+
"price_minus_ema_fast",
|
| 147 |
+
"price_minus_ema_medium",
|
| 148 |
+
"ema_spread",
|
| 149 |
+
"price_zscore",
|
| 150 |
+
"mean_reversion_score",
|
| 151 |
+
"rolling_vol_zscore",
|
| 152 |
+
]),
|
| 153 |
+
("patterns", [name for name in FEATURE_NAMES if name.startswith("pattern_")]),
|
| 154 |
+
("availability", [
|
| 155 |
+
"sr_available",
|
| 156 |
+
"trendline_available",
|
| 157 |
+
"pattern_available",
|
| 158 |
+
]),
|
| 159 |
+
])
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def empty_feature_dict() -> Dict[str, float]:
    """Return a fresh feature dict with every known feature set to 0.0."""
    return dict.fromkeys(FEATURE_NAMES, 0.0)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def feature_dict_to_vector(features: Dict[str, float]) -> List[float]:
    """Flatten *features* into the canonical FEATURE_NAMES ordering.

    Missing features and values that cannot be coerced to ``float`` are
    emitted as ``0.0`` so the output length always equals
    ``len(FEATURE_NAMES)``.
    """
    out: List[float] = []
    for name in FEATURE_NAMES:
        value = features.get(name, 0.0)
        try:
            out.append(float(value))
        # Only the errors float() can actually raise; a bare `except
        # Exception` would also hide unrelated bugs.
        except (TypeError, ValueError, OverflowError):
            out.append(0.0)
    return out
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def group_feature_indices(group_names: Iterable[str]) -> List[int]:
    """Return the sorted, de-duplicated feature indices for the given groups.

    Raises KeyError if a group name is unknown.
    """
    unique_indices = {
        FEATURE_INDEX[feature_name]
        for group_name in group_names
        for feature_name in FEATURE_GROUPS[group_name]
    }
    return sorted(unique_indices)
|
inference.py
CHANGED
|
@@ -15,7 +15,9 @@ from models.token_encoder import TokenEncoder
|
|
| 15 |
from models.wallet_encoder import WalletEncoder
|
| 16 |
from models.graph_updater import GraphUpdater
|
| 17 |
from models.ohlc_embedder import OHLCEmbedder
|
|
|
|
| 18 |
import models.vocabulary as vocab
|
|
|
|
| 19 |
|
| 20 |
# --- NEW: Import database clients ---
|
| 21 |
from clickhouse_driver import Client as ClickHouseClient
|
|
@@ -56,7 +58,11 @@ if __name__ == "__main__":
|
|
| 56 |
|
| 57 |
real_ohlc_emb = OHLCEmbedder(
|
| 58 |
num_intervals=vocab.NUM_OHLC_INTERVALS,
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
dtype=dtype
|
| 61 |
)
|
| 62 |
|
|
@@ -96,7 +102,8 @@ if __name__ == "__main__":
|
|
| 96 |
quantiles=_test_quantiles,
|
| 97 |
horizons_seconds=_test_horizons,
|
| 98 |
dtype=dtype,
|
| 99 |
-
ohlc_embedder=real_ohlc_emb
|
|
|
|
| 100 |
).to(device)
|
| 101 |
model.eval()
|
| 102 |
print(f"Oracle d_model: {model.d_model}")
|
|
|
|
| 15 |
from models.wallet_encoder import WalletEncoder
|
| 16 |
from models.graph_updater import GraphUpdater
|
| 17 |
from models.ohlc_embedder import OHLCEmbedder
|
| 18 |
+
from models.quant_ohlc_embedder import QuantOHLCEmbedder
|
| 19 |
import models.vocabulary as vocab
|
| 20 |
+
from data.quant_ohlc_feature_schema import NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
|
| 21 |
|
| 22 |
# --- NEW: Import database clients ---
|
| 23 |
from clickhouse_driver import Client as ClickHouseClient
|
|
|
|
| 58 |
|
| 59 |
real_ohlc_emb = OHLCEmbedder(
|
| 60 |
num_intervals=vocab.NUM_OHLC_INTERVALS,
|
| 61 |
+
dtype=dtype
|
| 62 |
+
)
|
| 63 |
+
real_quant_ohlc_emb = QuantOHLCEmbedder(
|
| 64 |
+
num_features=NUM_QUANT_OHLC_FEATURES,
|
| 65 |
+
sequence_length=TOKENS_PER_SEGMENT,
|
| 66 |
dtype=dtype
|
| 67 |
)
|
| 68 |
|
|
|
|
| 102 |
quantiles=_test_quantiles,
|
| 103 |
horizons_seconds=_test_horizons,
|
| 104 |
dtype=dtype,
|
| 105 |
+
ohlc_embedder=real_ohlc_emb,
|
| 106 |
+
quant_ohlc_embedder=real_quant_ohlc_emb
|
| 107 |
).to(device)
|
| 108 |
model.eval()
|
| 109 |
print(f"Oracle d_model: {model.d_model}")
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:49c4d500d58301a7c158851716c6cf0c7e6bc60cdf59f7a166fcb92b4e77b04a
|
| 3 |
+
size 390587
|
models/model.py
CHANGED
|
@@ -14,6 +14,7 @@ from models.token_encoder import TokenEncoder
|
|
| 14 |
from models.wallet_encoder import WalletEncoder
|
| 15 |
from models.graph_updater import GraphUpdater
|
| 16 |
from models.ohlc_embedder import OHLCEmbedder
|
|
|
|
| 17 |
from models.HoldersEncoder import HolderDistributionEncoder # NEW
|
| 18 |
from models.SocialEncoders import SocialEncoder # NEW
|
| 19 |
import models.vocabulary as vocab # For vocab sizes
|
|
@@ -28,6 +29,7 @@ class Oracle(nn.Module):
|
|
| 28 |
wallet_encoder: WalletEncoder,
|
| 29 |
graph_updater: GraphUpdater,
|
| 30 |
ohlc_embedder: OHLCEmbedder, # NEW
|
|
|
|
| 31 |
time_encoder: ContextualTimeEncoder,
|
| 32 |
num_event_types: int,
|
| 33 |
multi_modal_dim: int,
|
|
@@ -124,7 +126,8 @@ class Oracle(nn.Module):
|
|
| 124 |
self.token_encoder = token_encoder
|
| 125 |
self.wallet_encoder = wallet_encoder
|
| 126 |
self.graph_updater = graph_updater
|
| 127 |
-
self.ohlc_embedder = ohlc_embedder
|
|
|
|
| 128 |
self.time_encoder = time_encoder # Store time_encoder
|
| 129 |
|
| 130 |
self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype) # Now self.d_model is defined
|
|
@@ -140,7 +143,7 @@ class Oracle(nn.Module):
|
|
| 140 |
# --- 5. Define Entity Padding (Learnable) ---
|
| 141 |
self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model))
|
| 142 |
self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim))
|
| 143 |
-
self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.
|
| 144 |
self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim)) # NEW: For text/images
|
| 145 |
|
| 146 |
# --- NEW: Instantiate HolderDistributionEncoder internally ---
|
|
@@ -157,7 +160,16 @@ class Oracle(nn.Module):
|
|
| 157 |
self.rel_ts_norm = nn.LayerNorm(1)
|
| 158 |
self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model)
|
| 159 |
self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model)
|
| 160 |
-
self.ohlc_proj = nn.Linear(self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
# self.holder_snapshot_proj is no longer needed as HolderDistributionEncoder outputs directly to d_model
|
| 162 |
|
| 163 |
|
|
@@ -309,7 +321,7 @@ class Oracle(nn.Module):
|
|
| 309 |
|
| 310 |
@classmethod
|
| 311 |
def from_pretrained(cls, load_directory: str,
|
| 312 |
-
token_encoder, wallet_encoder, graph_updater, ohlc_embedder, time_encoder):
|
| 313 |
"""
|
| 314 |
Loads the Oracle model from a saved directory.
|
| 315 |
Note: You must still provide the initialized sub-encoders (or we can refactor to save them too).
|
|
@@ -329,6 +341,7 @@ class Oracle(nn.Module):
|
|
| 329 |
wallet_encoder=wallet_encoder,
|
| 330 |
graph_updater=graph_updater,
|
| 331 |
ohlc_embedder=ohlc_embedder,
|
|
|
|
| 332 |
time_encoder=time_encoder,
|
| 333 |
num_event_types=config["num_event_types"],
|
| 334 |
multi_modal_dim=config["multi_modal_dim"],
|
|
@@ -431,6 +444,14 @@ class Oracle(nn.Module):
|
|
| 431 |
neginf=0.0
|
| 432 |
)
|
| 433 |
ohlc_interval_ids = batch['ohlc_interval_ids'].to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
graph_updater_links = batch['graph_updater_links']
|
| 435 |
|
| 436 |
# 1a. Encode Tokens
|
|
@@ -490,9 +511,39 @@ class Oracle(nn.Module):
|
|
| 490 |
|
| 491 |
# 1c. Encode OHLC
|
| 492 |
if ohlc_price_tensors.shape[0] > 0:
|
| 493 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
else:
|
| 495 |
-
batch_ohlc_embeddings_raw = torch.empty(0, self.
|
| 496 |
|
| 497 |
# 1d. Run Graph Updater
|
| 498 |
pad_wallet_raw = self.pad_wallet_emb.to(self.dtype)
|
|
|
|
| 14 |
from models.wallet_encoder import WalletEncoder
|
| 15 |
from models.graph_updater import GraphUpdater
|
| 16 |
from models.ohlc_embedder import OHLCEmbedder
|
| 17 |
+
from models.quant_ohlc_embedder import QuantOHLCEmbedder
|
| 18 |
from models.HoldersEncoder import HolderDistributionEncoder # NEW
|
| 19 |
from models.SocialEncoders import SocialEncoder # NEW
|
| 20 |
import models.vocabulary as vocab # For vocab sizes
|
|
|
|
| 29 |
wallet_encoder: WalletEncoder,
|
| 30 |
graph_updater: GraphUpdater,
|
| 31 |
ohlc_embedder: OHLCEmbedder, # NEW
|
| 32 |
+
quant_ohlc_embedder: QuantOHLCEmbedder,
|
| 33 |
time_encoder: ContextualTimeEncoder,
|
| 34 |
num_event_types: int,
|
| 35 |
multi_modal_dim: int,
|
|
|
|
| 126 |
self.token_encoder = token_encoder
|
| 127 |
self.wallet_encoder = wallet_encoder
|
| 128 |
self.graph_updater = graph_updater
|
| 129 |
+
self.ohlc_embedder = ohlc_embedder
|
| 130 |
+
self.quant_ohlc_embedder = quant_ohlc_embedder
|
| 131 |
self.time_encoder = time_encoder # Store time_encoder
|
| 132 |
|
| 133 |
self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype) # Now self.d_model is defined
|
|
|
|
| 143 |
# --- 5. Define Entity Padding (Learnable) ---
|
| 144 |
self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model))
|
| 145 |
self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim))
|
| 146 |
+
self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.quant_ohlc_embedder.output_dim))
|
| 147 |
self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim)) # NEW: For text/images
|
| 148 |
|
| 149 |
# --- NEW: Instantiate HolderDistributionEncoder internally ---
|
|
|
|
| 160 |
self.rel_ts_norm = nn.LayerNorm(1)
|
| 161 |
self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model)
|
| 162 |
self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model)
|
| 163 |
+
self.ohlc_proj = nn.Linear(self.quant_ohlc_embedder.output_dim, self.d_model)
|
| 164 |
+
self.chart_interval_fusion_embedding = nn.Embedding(vocab.NUM_OHLC_INTERVALS, 32, padding_idx=0)
|
| 165 |
+
fusion_input_dim = self.ohlc_embedder.output_dim + self.quant_ohlc_embedder.output_dim + 32
|
| 166 |
+
self.chart_fusion = nn.Sequential(
|
| 167 |
+
nn.Linear(fusion_input_dim, self.quant_ohlc_embedder.output_dim),
|
| 168 |
+
nn.GELU(),
|
| 169 |
+
nn.LayerNorm(self.quant_ohlc_embedder.output_dim),
|
| 170 |
+
nn.Linear(self.quant_ohlc_embedder.output_dim, self.quant_ohlc_embedder.output_dim),
|
| 171 |
+
nn.LayerNorm(self.quant_ohlc_embedder.output_dim),
|
| 172 |
+
)
|
| 173 |
# self.holder_snapshot_proj is no longer needed as HolderDistributionEncoder outputs directly to d_model
|
| 174 |
|
| 175 |
|
|
|
|
| 321 |
|
| 322 |
@classmethod
|
| 323 |
def from_pretrained(cls, load_directory: str,
|
| 324 |
+
token_encoder, wallet_encoder, graph_updater, ohlc_embedder, quant_ohlc_embedder, time_encoder):
|
| 325 |
"""
|
| 326 |
Loads the Oracle model from a saved directory.
|
| 327 |
Note: You must still provide the initialized sub-encoders (or we can refactor to save them too).
|
|
|
|
| 341 |
wallet_encoder=wallet_encoder,
|
| 342 |
graph_updater=graph_updater,
|
| 343 |
ohlc_embedder=ohlc_embedder,
|
| 344 |
+
quant_ohlc_embedder=quant_ohlc_embedder,
|
| 345 |
time_encoder=time_encoder,
|
| 346 |
num_event_types=config["num_event_types"],
|
| 347 |
multi_modal_dim=config["multi_modal_dim"],
|
|
|
|
| 444 |
neginf=0.0
|
| 445 |
)
|
| 446 |
ohlc_interval_ids = batch['ohlc_interval_ids'].to(device)
|
| 447 |
+
quant_ohlc_feature_tensors = torch.nan_to_num(
|
| 448 |
+
batch['quant_ohlc_feature_tensors'].to(device, self.dtype),
|
| 449 |
+
nan=0.0,
|
| 450 |
+
posinf=0.0,
|
| 451 |
+
neginf=0.0
|
| 452 |
+
)
|
| 453 |
+
quant_ohlc_feature_mask = batch['quant_ohlc_feature_mask'].to(device)
|
| 454 |
+
quant_ohlc_feature_version_ids = batch['quant_ohlc_feature_version_ids'].to(device)
|
| 455 |
graph_updater_links = batch['graph_updater_links']
|
| 456 |
|
| 457 |
# 1a. Encode Tokens
|
|
|
|
| 511 |
|
| 512 |
# 1c. Encode OHLC
|
| 513 |
if ohlc_price_tensors.shape[0] > 0:
|
| 514 |
+
raw_chart_embeddings = self.ohlc_embedder(ohlc_price_tensors, ohlc_interval_ids)
|
| 515 |
+
else:
|
| 516 |
+
raw_chart_embeddings = torch.empty(0, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype)
|
| 517 |
+
if quant_ohlc_feature_tensors.shape[0] > 0:
|
| 518 |
+
quant_chart_embeddings = self.quant_ohlc_embedder(
|
| 519 |
+
quant_ohlc_feature_tensors,
|
| 520 |
+
quant_ohlc_feature_mask,
|
| 521 |
+
quant_ohlc_feature_version_ids,
|
| 522 |
+
)
|
| 523 |
+
else:
|
| 524 |
+
quant_chart_embeddings = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype)
|
| 525 |
+
num_chart_segments = max(raw_chart_embeddings.shape[0], quant_chart_embeddings.shape[0])
|
| 526 |
+
if num_chart_segments > 0:
|
| 527 |
+
if raw_chart_embeddings.shape[0] == 0:
|
| 528 |
+
raw_chart_embeddings = torch.zeros(
|
| 529 |
+
num_chart_segments,
|
| 530 |
+
self.ohlc_embedder.output_dim,
|
| 531 |
+
device=device,
|
| 532 |
+
dtype=self.dtype,
|
| 533 |
+
)
|
| 534 |
+
if quant_chart_embeddings.shape[0] == 0:
|
| 535 |
+
quant_chart_embeddings = torch.zeros(
|
| 536 |
+
num_chart_segments,
|
| 537 |
+
self.quant_ohlc_embedder.output_dim,
|
| 538 |
+
device=device,
|
| 539 |
+
dtype=self.dtype,
|
| 540 |
+
)
|
| 541 |
+
interval_embeds = self.chart_interval_fusion_embedding(ohlc_interval_ids[:num_chart_segments]).to(self.dtype)
|
| 542 |
+
batch_ohlc_embeddings_raw = self.chart_fusion(
|
| 543 |
+
torch.cat([raw_chart_embeddings, quant_chart_embeddings, interval_embeds], dim=-1)
|
| 544 |
+
)
|
| 545 |
else:
|
| 546 |
+
batch_ohlc_embeddings_raw = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype)
|
| 547 |
|
| 548 |
# 1d. Run Graph Updater
|
| 549 |
pad_wallet_raw = self.pad_wallet_emb.to(self.dtype)
|
models/ohlc_embedder.py
CHANGED
|
@@ -19,11 +19,11 @@ class OHLCEmbedder(nn.Module):
|
|
| 19 |
num_intervals: int,
|
| 20 |
input_channels: int = 2, # Open, Close
|
| 21 |
# sequence_length: int = 300, # REMOVED: HARDCODED
|
| 22 |
-
cnn_channels: List[int] = [
|
| 23 |
kernel_sizes: List[int] = [3, 3, 3],
|
| 24 |
# --- NEW: Interval embedding dim ---
|
| 25 |
-
interval_embed_dim: int =
|
| 26 |
-
output_dim: int =
|
| 27 |
dtype: torch.dtype = torch.float16
|
| 28 |
):
|
| 29 |
super().__init__()
|
|
@@ -116,4 +116,3 @@ class OHLCEmbedder(nn.Module):
|
|
| 116 |
# Shape: [batch_size, output_dim]
|
| 117 |
|
| 118 |
return x
|
| 119 |
-
|
|
|
|
| 19 |
num_intervals: int,
|
| 20 |
input_channels: int = 2, # Open, Close
|
| 21 |
# sequence_length: int = 300, # REMOVED: HARDCODED
|
| 22 |
+
cnn_channels: List[int] = [8, 16, 32],
|
| 23 |
kernel_sizes: List[int] = [3, 3, 3],
|
| 24 |
# --- NEW: Interval embedding dim ---
|
| 25 |
+
interval_embed_dim: int = 16,
|
| 26 |
+
output_dim: int = 512,
|
| 27 |
dtype: torch.dtype = torch.float16
|
| 28 |
):
|
| 29 |
super().__init__()
|
|
|
|
| 116 |
# Shape: [batch_size, output_dim]
|
| 117 |
|
| 118 |
return x
|
|
|
models/quant_ohlc_embedder.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class QuantOHLCEmbedder(nn.Module):
    """Transformer encoder over per-timestep quant OHLC feature tokens.

    Takes a fixed-length sequence of engineered feature vectors
    ``[B, T, F]`` plus a validity mask and a feature-schema version id,
    and returns one pooled embedding per segment ``[B, output_dim]``.
    Segments with no valid timesteps produce an all-zero embedding.
    """

    def __init__(
        self,
        num_features: int,
        sequence_length: int = 60,
        version_vocab_size: int = 4,
        hidden_dim: int = 320,
        num_layers: int = 3,
        num_heads: int = 8,
        output_dim: int = 1536,
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        self.num_features = num_features
        self.sequence_length = sequence_length
        self.output_dim = output_dim
        self.dtype = dtype

        # Per-timestep projection of raw features into the transformer width.
        self.feature_proj = nn.Sequential(
            nn.LayerNorm(num_features),
            nn.Linear(num_features, hidden_dim),
            nn.GELU(),
        )
        # Learned absolute positions; the sequence length is fixed.
        self.position_embedding = nn.Parameter(torch.zeros(1, sequence_length, hidden_dim))
        # Feature-schema version tag, broadcast over the whole sequence.
        self.version_embedding = nn.Embedding(version_vocab_size, hidden_dim, padding_idx=0)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.LayerNorm(hidden_dim * 2),
            nn.Linear(hidden_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )
        self.to(dtype)

    def forward(
        self,
        feature_tokens: torch.Tensor,
        feature_mask: torch.Tensor,
        version_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed a batch of feature-token sequences.

        Args:
            feature_tokens: ``[B, T, F]`` float features.
            feature_mask: ``[B, T]``; values > 0 mark valid timesteps.
            version_ids: ``[B]`` long ids of the feature-schema version.

        Returns:
            ``[B, output_dim]``; rows with no valid timestep are all zeros.

        Raises:
            ValueError: if ``feature_tokens`` is not ``[B, T, F]`` with the
                configured ``T`` and ``F``.
        """
        if feature_tokens.ndim != 3:
            raise ValueError(f"Expected [B, T, F], got {feature_tokens.shape}")
        if feature_tokens.shape[1] != self.sequence_length:
            raise ValueError(f"Expected T={self.sequence_length}, got {feature_tokens.shape[1]}")
        if feature_tokens.shape[2] != self.num_features:
            raise ValueError(f"Expected F={self.num_features}, got {feature_tokens.shape[2]}")

        x = self.feature_proj(feature_tokens.to(self.dtype))
        version_embed = self.version_embedding(version_ids).unsqueeze(1)
        x = x + self.position_embedding[:, : x.shape[1], :].to(x.dtype) + version_embed

        valid = feature_mask > 0
        has_valid = valid.any(dim=1, keepdim=True)  # [B, 1]
        # BUGFIX: with a fully-masked row the previous `~(feature_mask > 0)`
        # padding mask made attention softmax over an all -inf row, producing
        # NaNs that survived the final `* valid_any` (NaN * 0 == NaN).  For
        # rows with no valid timestep we let attention see every position
        # instead, and rely on the explicit zeroing below.
        key_padding_mask = (~valid) & has_valid
        x = self.encoder(x, src_key_padding_mask=key_padding_mask)

        # Mean-pool over the valid timesteps only.
        mask = valid.to(x.dtype).unsqueeze(-1)
        valid_any = has_valid.to(x.dtype)
        denom = mask.sum(dim=1).clamp_min(1.0)
        pooled = (x * mask).sum(dim=1) / denom
        out = self.output_head(pooled)
        # Zero out segments that had no valid data at all.
        return out * valid_any
|
pre_cache.sh
CHANGED
|
@@ -4,6 +4,7 @@
|
|
| 4 |
CONTEXT_LENGTH=4096
|
| 5 |
MIN_TRADES=10
|
| 6 |
SAMPLES_PER_TOKEN=1
|
|
|
|
| 7 |
NUM_WORKERS=1
|
| 8 |
|
| 9 |
OUTPUT_DIR="data/cache"
|
|
@@ -20,6 +21,7 @@ echo "========================================"
|
|
| 20 |
echo "Context Length (H/B/H threshold): $CONTEXT_LENGTH"
|
| 21 |
echo "Min Trades (T_cutoff threshold): $MIN_TRADES"
|
| 22 |
echo "Samples per Token: $SAMPLES_PER_TOKEN"
|
|
|
|
| 23 |
echo "Num Workers: $NUM_WORKERS"
|
| 24 |
echo "Horizons (sec): ${HORIZONS_SECONDS[*]}"
|
| 25 |
echo "Quantiles: ${QUANTILES[*]}"
|
|
@@ -36,10 +38,10 @@ python3 scripts/cache_dataset.py \
|
|
| 36 |
--context_length "$CONTEXT_LENGTH" \
|
| 37 |
--min_trades "$MIN_TRADES" \
|
| 38 |
--samples_per_token "$SAMPLES_PER_TOKEN" \
|
|
|
|
| 39 |
--num_workers "$NUM_WORKERS" \
|
| 40 |
--horizons_seconds "${HORIZONS_SECONDS[@]}" \
|
| 41 |
--quantiles "${QUANTILES[@]}" \
|
| 42 |
-
--max_samples 10 \
|
| 43 |
"$@"
|
| 44 |
|
| 45 |
echo "Done!"
|
|
|
|
| 4 |
CONTEXT_LENGTH=4096
|
| 5 |
MIN_TRADES=10
|
| 6 |
SAMPLES_PER_TOKEN=1
|
| 7 |
+
TARGET_CONTEXTS_PER_CLASS=10
|
| 8 |
NUM_WORKERS=1
|
| 9 |
|
| 10 |
OUTPUT_DIR="data/cache"
|
|
|
|
| 21 |
echo "Context Length (H/B/H threshold): $CONTEXT_LENGTH"
|
| 22 |
echo "Min Trades (T_cutoff threshold): $MIN_TRADES"
|
| 23 |
echo "Samples per Token: $SAMPLES_PER_TOKEN"
|
| 24 |
+
echo "Target Contexts per Class: $TARGET_CONTEXTS_PER_CLASS"
|
| 25 |
echo "Num Workers: $NUM_WORKERS"
|
| 26 |
echo "Horizons (sec): ${HORIZONS_SECONDS[*]}"
|
| 27 |
echo "Quantiles: ${QUANTILES[*]}"
|
|
|
|
| 38 |
--context_length "$CONTEXT_LENGTH" \
|
| 39 |
--min_trades "$MIN_TRADES" \
|
| 40 |
--samples_per_token "$SAMPLES_PER_TOKEN" \
|
| 41 |
+
--target_contexts_per_class "$TARGET_CONTEXTS_PER_CLASS" \
|
| 42 |
--num_workers "$NUM_WORKERS" \
|
| 43 |
--horizons_seconds "${HORIZONS_SECONDS[@]}" \
|
| 44 |
--quantiles "${QUANTILES[@]}" \
|
|
|
|
| 45 |
"$@"
|
| 46 |
|
| 47 |
echo "Done!"
|
sample_2kGqvM18kGLby9bY_5.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/cache_dataset.py
CHANGED
|
@@ -6,15 +6,13 @@ import argparse
|
|
| 6 |
import datetime
|
| 7 |
import torch
|
| 8 |
import json
|
| 9 |
-
import math
|
| 10 |
from pathlib import Path
|
| 11 |
from tqdm import tqdm
|
| 12 |
from dotenv import load_dotenv
|
| 13 |
import huggingface_hub
|
| 14 |
import logging
|
| 15 |
-
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 16 |
import multiprocessing as mp
|
| 17 |
-
from
|
| 18 |
|
| 19 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 20 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
|
@@ -23,7 +21,9 @@ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
|
|
| 23 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 24 |
|
| 25 |
from scripts.analyze_distribution import get_return_class_map, compute_p99_clamps
|
| 26 |
-
from scripts.compute_quality_score import get_token_quality_scores
|
|
|
|
|
|
|
| 27 |
|
| 28 |
from clickhouse_driver import Client as ClickHouseClient
|
| 29 |
from neo4j import GraphDatabase
|
|
@@ -58,6 +58,32 @@ def _representative_context_polarity(context):
|
|
| 58 |
return "positive" if max(valid_returns) > 0.0 else "negative"
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desired_negative=None):
|
| 62 |
if len(contexts) <= max_keep:
|
| 63 |
polarity_counts = {}
|
|
@@ -113,68 +139,6 @@ def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desi
|
|
| 113 |
return selected[:max_keep], polarity_counts
|
| 114 |
|
| 115 |
|
| 116 |
-
def _allocate_class_targets(mints_by_class, target_total, positive_balance_min_class, positive_ratio):
|
| 117 |
-
from collections import defaultdict
|
| 118 |
-
import random
|
| 119 |
-
|
| 120 |
-
class_ids = sorted(mints_by_class.keys())
|
| 121 |
-
if not class_ids:
|
| 122 |
-
return {}, {}, {}
|
| 123 |
-
|
| 124 |
-
target_per_class = target_total // len(class_ids)
|
| 125 |
-
remainder = target_total % len(class_ids)
|
| 126 |
-
|
| 127 |
-
token_plans = {}
|
| 128 |
-
class_targets = {}
|
| 129 |
-
class_polarity_targets = {}
|
| 130 |
-
|
| 131 |
-
for pos, class_id in enumerate(class_ids):
|
| 132 |
-
class_target = target_per_class + (1 if pos < remainder else 0)
|
| 133 |
-
class_targets[class_id] = class_target
|
| 134 |
-
|
| 135 |
-
token_list = list(mints_by_class[class_id])
|
| 136 |
-
random.shuffle(token_list)
|
| 137 |
-
if not token_list or class_target <= 0:
|
| 138 |
-
class_polarity_targets[class_id] = {"positive": 0, "negative": 0}
|
| 139 |
-
continue
|
| 140 |
-
|
| 141 |
-
if class_id >= positive_balance_min_class:
|
| 142 |
-
positive_target = int(round(class_target * positive_ratio))
|
| 143 |
-
positive_target = min(max(positive_target, 0), class_target)
|
| 144 |
-
else:
|
| 145 |
-
positive_target = 0
|
| 146 |
-
negative_target = class_target - positive_target
|
| 147 |
-
class_polarity_targets[class_id] = {
|
| 148 |
-
"positive": positive_target,
|
| 149 |
-
"negative": negative_target,
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
assigned_positive = 0
|
| 153 |
-
assigned_negative = 0
|
| 154 |
-
token_count = len(token_list)
|
| 155 |
-
for sample_num in range(class_target):
|
| 156 |
-
token_idx, mint_record = token_list[sample_num % token_count]
|
| 157 |
-
mint_addr = mint_record["mint_address"]
|
| 158 |
-
plan_key = (token_idx, mint_addr)
|
| 159 |
-
if plan_key not in token_plans:
|
| 160 |
-
token_plans[plan_key] = {
|
| 161 |
-
"samples_to_keep": 0,
|
| 162 |
-
"desired_positive": 0,
|
| 163 |
-
"desired_negative": 0,
|
| 164 |
-
"class_id": class_id,
|
| 165 |
-
}
|
| 166 |
-
token_plans[plan_key]["samples_to_keep"] += 1
|
| 167 |
-
|
| 168 |
-
if assigned_positive < positive_target:
|
| 169 |
-
token_plans[plan_key]["desired_positive"] += 1
|
| 170 |
-
assigned_positive += 1
|
| 171 |
-
else:
|
| 172 |
-
token_plans[plan_key]["desired_negative"] += 1
|
| 173 |
-
assigned_negative += 1
|
| 174 |
-
|
| 175 |
-
return token_plans, class_targets, class_polarity_targets
|
| 176 |
-
|
| 177 |
-
|
| 178 |
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
|
| 179 |
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 180 |
from data.data_loader import OracleDataset
|
|
@@ -200,7 +164,6 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
|
|
| 200 |
|
| 201 |
_worker_dataset = OracleDataset(
|
| 202 |
data_fetcher=data_fetcher,
|
| 203 |
-
max_samples=dataset_config['max_samples'],
|
| 204 |
start_date=dataset_config['start_date'],
|
| 205 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 206 |
quantiles=dataset_config['quantiles'],
|
|
@@ -214,7 +177,7 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
|
|
| 214 |
|
| 215 |
|
| 216 |
def _process_single_token_context(args):
|
| 217 |
-
idx, mint_addr, samples_per_token,
|
| 218 |
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 219 |
try:
|
| 220 |
class_id = _worker_return_class_map.get(mint_addr)
|
|
@@ -244,15 +207,6 @@ def _process_single_token_context(args):
|
|
| 244 |
q_score = _worker_quality_scores_map.get(mint_addr)
|
| 245 |
if q_score is None:
|
| 246 |
return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
|
| 247 |
-
saved_files = []
|
| 248 |
-
for ctx_idx, ctx in enumerate(contexts):
|
| 249 |
-
ctx["quality_score"] = q_score
|
| 250 |
-
ctx["class_id"] = class_id
|
| 251 |
-
filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
|
| 252 |
-
output_path = Path(output_dir) / filename
|
| 253 |
-
|
| 254 |
-
torch.save(ctx, output_path)
|
| 255 |
-
saved_files.append(filename)
|
| 256 |
return {
|
| 257 |
'status': 'success',
|
| 258 |
'mint': mint_addr,
|
|
@@ -260,7 +214,7 @@ def _process_single_token_context(args):
|
|
| 260 |
'q_score': q_score,
|
| 261 |
'n_contexts': len(contexts),
|
| 262 |
'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
|
| 263 |
-
'
|
| 264 |
'polarity_counts': polarity_counts,
|
| 265 |
}
|
| 266 |
except Exception as e:
|
|
@@ -284,7 +238,6 @@ def main():
|
|
| 284 |
|
| 285 |
parser = argparse.ArgumentParser()
|
| 286 |
parser.add_argument("--output_dir", type=str, default="data/cache")
|
| 287 |
-
parser.add_argument("--max_samples", type=int, default=None)
|
| 288 |
parser.add_argument("--start_date", type=str, default=None)
|
| 289 |
|
| 290 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
|
@@ -292,8 +245,8 @@ def main():
|
|
| 292 |
parser.add_argument("--context_length", type=int, default=8192)
|
| 293 |
parser.add_argument("--min_trades", type=int, default=10)
|
| 294 |
parser.add_argument("--samples_per_token", type=int, default=1)
|
|
|
|
| 295 |
parser.add_argument("--context_oversample_factor", type=int, default=4)
|
| 296 |
-
parser.add_argument("--cache_balance_mode", type=str, default="hybrid", choices=["class", "uniform", "hybrid"])
|
| 297 |
parser.add_argument("--positive_balance_min_class", type=int, default=2)
|
| 298 |
parser.add_argument("--positive_context_ratio", type=float, default=0.5)
|
| 299 |
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
|
|
@@ -308,6 +261,10 @@ def main():
|
|
| 308 |
|
| 309 |
if args.num_workers == 0:
|
| 310 |
args.num_workers = max(1, mp.cpu_count() - 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
output_dir = Path(args.output_dir)
|
| 313 |
output_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -334,7 +291,7 @@ def main():
|
|
| 334 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 335 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 336 |
|
| 337 |
-
dataset = OracleDataset(data_fetcher=data_fetcher,
|
| 338 |
|
| 339 |
if len(dataset) == 0:
|
| 340 |
print("WARNING: No samples. Exiting.")
|
|
@@ -370,93 +327,73 @@ def main():
|
|
| 370 |
print(f"INFO: Workers: {args.num_workers}")
|
| 371 |
|
| 372 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 373 |
-
dataset_config = {'
|
| 374 |
|
| 375 |
-
# Build tasks with class-aware multi-sampling for balanced cache
|
| 376 |
import random
|
| 377 |
-
from collections import Counter, defaultdict
|
| 378 |
|
| 379 |
-
# Count eligible tokens per class
|
| 380 |
eligible_class_counts = Counter()
|
| 381 |
-
mints_by_class = defaultdict(list)
|
| 382 |
for i, m in enumerate(filtered_mints):
|
| 383 |
cid = return_class_map.get(m['mint_address'])
|
| 384 |
if cid is not None:
|
| 385 |
eligible_class_counts[cid] += 1
|
| 386 |
-
mints_by_class[cid].append((i, m))
|
| 387 |
|
| 388 |
print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")
|
| 389 |
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
|
| 404 |
-
|
|
|
|
| 405 |
print(f"INFO: Exact class targets: {dict(sorted(class_targets.items()))}")
|
| 406 |
print(f"INFO: Class polarity targets: {dict(sorted(class_polarity_targets.items()))}")
|
| 407 |
|
| 408 |
-
|
| 409 |
-
tasks
|
| 410 |
-
|
| 411 |
-
target_tokens = len(filtered_mints)
|
| 412 |
-
if args.max_samples:
|
| 413 |
-
target_tokens = min(len(filtered_mints), max(1, math.ceil(args.max_samples / max(args.samples_per_token, 1))))
|
| 414 |
-
mint_pool = list(enumerate(filtered_mints))
|
| 415 |
-
random.shuffle(mint_pool)
|
| 416 |
-
for i, m in mint_pool[:target_tokens]:
|
| 417 |
-
tasks.append((i, m['mint_address'], args.samples_per_token, str(output_dir), args.context_oversample_factor, 0, args.samples_per_token))
|
| 418 |
-
else:
|
| 419 |
-
for (token_idx, mint_addr), plan in token_plans.items():
|
| 420 |
-
tasks.append((
|
| 421 |
-
token_idx,
|
| 422 |
-
mint_addr,
|
| 423 |
-
plan["samples_to_keep"],
|
| 424 |
-
str(output_dir),
|
| 425 |
-
args.context_oversample_factor,
|
| 426 |
-
plan["desired_positive"],
|
| 427 |
-
plan["desired_negative"],
|
| 428 |
-
))
|
| 429 |
-
|
| 430 |
-
random.shuffle(tasks) # Shuffle tasks for even load distribution across workers
|
| 431 |
-
expected_files = sum(task[2] for task in tasks)
|
| 432 |
-
print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")
|
| 433 |
|
| 434 |
success_count, skipped_count, error_count = 0, 0, 0
|
| 435 |
-
class_distribution =
|
| 436 |
-
polarity_distribution =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
-
#
|
| 439 |
existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
|
| 440 |
if existing_files:
|
| 441 |
-
pre_resume = len(tasks)
|
| 442 |
-
filtered_tasks = []
|
| 443 |
already_cached = 0
|
| 444 |
-
for
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
|
| 461 |
print(f"INFO: Starting to cache {len(tasks)} tokens...")
|
| 462 |
process_fn = _process_single_token_context
|
|
@@ -484,64 +421,113 @@ def main():
|
|
| 484 |
error_log_path = Path(args.output_dir) / "cache_errors.log"
|
| 485 |
error_samples = [] # First 20 unique error messages
|
| 486 |
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
success_count += 1
|
| 501 |
-
class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
|
| 502 |
-
for polarity, count in result.get('polarity_counts', {}).items():
|
| 503 |
-
polarity_distribution[polarity] = polarity_distribution.get(polarity, 0) + count
|
| 504 |
-
elif result['status'] == 'skipped':
|
| 505 |
-
skipped_count += 1
|
| 506 |
else:
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
start_time
|
| 516 |
-
recent_times = []
|
| 517 |
-
with ProcessPoolExecutor(max_workers=args.num_workers, initializer=_init_worker, initargs=(db_config, dataset_config, return_class_map, quality_scores_map)) as executor:
|
| 518 |
-
futures = {executor.submit(process_fn, task): task for task in tasks}
|
| 519 |
-
for task_num, future in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Caching", unit="tok")):
|
| 520 |
-
t0 = _time.perf_counter()
|
| 521 |
-
try:
|
| 522 |
-
result = future.result(timeout=300)
|
| 523 |
-
elapsed = _time.perf_counter() - t0
|
| 524 |
-
recent_times.append(elapsed)
|
| 525 |
-
if len(recent_times) > 50:
|
| 526 |
-
recent_times.pop(0)
|
| 527 |
-
if result['status'] == 'success':
|
| 528 |
-
success_count += 1
|
| 529 |
-
class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
|
| 530 |
-
for polarity, count in result.get('polarity_counts', {}).items():
|
| 531 |
-
polarity_distribution[polarity] = polarity_distribution.get(polarity, 0) + count
|
| 532 |
-
elif result['status'] == 'skipped':
|
| 533 |
-
skipped_count += 1
|
| 534 |
-
else:
|
| 535 |
-
error_count += 1
|
| 536 |
-
err_msg = result.get('error', 'unknown')
|
| 537 |
-
if len(error_samples) < 20:
|
| 538 |
-
error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
|
| 539 |
-
if error_count <= 5:
|
| 540 |
-
tqdm.write(f"ERROR: {result.get('mint', '?')[:16]} - {err_msg}")
|
| 541 |
-
except Exception as e:
|
| 542 |
-
error_count += 1
|
| 543 |
-
tqdm.write(f"WORKER ERROR: {e}")
|
| 544 |
-
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 545 |
|
| 546 |
# Write error log
|
| 547 |
if error_samples:
|
|
@@ -553,28 +539,24 @@ def main():
|
|
| 553 |
print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")
|
| 554 |
|
| 555 |
print("INFO: Building metadata...")
|
| 556 |
-
file_class_map = {}
|
| 557 |
-
for f in sorted(output_dir.glob("sample_*.pt")):
|
| 558 |
-
try:
|
| 559 |
-
file_class_map[f.name] = torch.load(f, map_location="cpu", weights_only=False).get("class_id", 0)
|
| 560 |
-
except:
|
| 561 |
-
pass
|
| 562 |
-
|
| 563 |
with open(output_dir / "class_metadata.json", 'w') as f:
|
| 564 |
json.dump({
|
| 565 |
'file_class_map': file_class_map,
|
|
|
|
|
|
|
| 566 |
'class_distribution': {str(k): v for k, v in class_distribution.items()},
|
| 567 |
'num_workers': args.num_workers,
|
| 568 |
'horizons_seconds': args.horizons_seconds,
|
| 569 |
'quantiles': args.quantiles,
|
| 570 |
'target_total': target_total,
|
| 571 |
-
'
|
| 572 |
-
'cache_balance_mode': args.cache_balance_mode,
|
| 573 |
'context_polarity_distribution': polarity_distribution,
|
| 574 |
'class_targets': {str(k): v for k, v in class_targets.items()},
|
| 575 |
'class_polarity_targets': {str(k): v for k, v in class_polarity_targets.items()},
|
|
|
|
| 576 |
'positive_balance_min_class': args.positive_balance_min_class,
|
| 577 |
'positive_context_ratio': args.positive_context_ratio,
|
|
|
|
| 578 |
}, f, indent=2)
|
| 579 |
|
| 580 |
print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
|
|
|
|
| 6 |
import datetime
|
| 7 |
import torch
|
| 8 |
import json
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
from tqdm import tqdm
|
| 11 |
from dotenv import load_dotenv
|
| 12 |
import huggingface_hub
|
| 13 |
import logging
|
|
|
|
| 14 |
import multiprocessing as mp
|
| 15 |
+
from collections import Counter, defaultdict
|
| 16 |
|
| 17 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 18 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
|
|
|
| 21 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 22 |
|
| 23 |
from scripts.analyze_distribution import get_return_class_map, compute_p99_clamps
|
| 24 |
+
from scripts.compute_quality_score import get_token_quality_scores
|
| 25 |
+
from data.data_loader import summarize_context_window
|
| 26 |
+
from data.quant_ohlc_feature_schema import FEATURE_VERSION
|
| 27 |
|
| 28 |
from clickhouse_driver import Client as ClickHouseClient
|
| 29 |
from neo4j import GraphDatabase
|
|
|
|
| 58 |
return "positive" if max(valid_returns) > 0.0 else "negative"
|
| 59 |
|
| 60 |
|
| 61 |
+
def _class_polarity_targets(class_id, target_contexts_per_class, positive_balance_min_class, positive_ratio):
|
| 62 |
+
if class_id >= positive_balance_min_class:
|
| 63 |
+
positive_target = int(round(target_contexts_per_class * positive_ratio))
|
| 64 |
+
positive_target = min(max(positive_target, 0), target_contexts_per_class)
|
| 65 |
+
else:
|
| 66 |
+
positive_target = 0
|
| 67 |
+
return {
|
| 68 |
+
"positive": positive_target,
|
| 69 |
+
"negative": max(0, target_contexts_per_class - positive_target),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _remaining_polarity_targets(class_id, accepted_counts, target_contexts_per_class, positive_balance_min_class, positive_ratio):
|
| 74 |
+
targets = _class_polarity_targets(
|
| 75 |
+
class_id=class_id,
|
| 76 |
+
target_contexts_per_class=target_contexts_per_class,
|
| 77 |
+
positive_balance_min_class=positive_balance_min_class,
|
| 78 |
+
positive_ratio=positive_ratio,
|
| 79 |
+
)
|
| 80 |
+
class_counts = accepted_counts[class_id]
|
| 81 |
+
return {
|
| 82 |
+
"positive": max(0, targets["positive"] - class_counts["positive"]),
|
| 83 |
+
"negative": max(0, targets["negative"] - class_counts["negative"]),
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desired_negative=None):
|
| 88 |
if len(contexts) <= max_keep:
|
| 89 |
polarity_counts = {}
|
|
|
|
| 139 |
return selected[:max_keep], polarity_counts
|
| 140 |
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
|
| 143 |
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 144 |
from data.data_loader import OracleDataset
|
|
|
|
| 164 |
|
| 165 |
_worker_dataset = OracleDataset(
|
| 166 |
data_fetcher=data_fetcher,
|
|
|
|
| 167 |
start_date=dataset_config['start_date'],
|
| 168 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 169 |
quantiles=dataset_config['quantiles'],
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
def _process_single_token_context(args):
|
| 180 |
+
idx, mint_addr, samples_per_token, oversample_factor, desired_positive, desired_negative = args
|
| 181 |
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 182 |
try:
|
| 183 |
class_id = _worker_return_class_map.get(mint_addr)
|
|
|
|
| 207 |
q_score = _worker_quality_scores_map.get(mint_addr)
|
| 208 |
if q_score is None:
|
| 209 |
return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
return {
|
| 211 |
'status': 'success',
|
| 212 |
'mint': mint_addr,
|
|
|
|
| 214 |
'q_score': q_score,
|
| 215 |
'n_contexts': len(contexts),
|
| 216 |
'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
|
| 217 |
+
'contexts': contexts,
|
| 218 |
'polarity_counts': polarity_counts,
|
| 219 |
}
|
| 220 |
except Exception as e:
|
|
|
|
| 238 |
|
| 239 |
parser = argparse.ArgumentParser()
|
| 240 |
parser.add_argument("--output_dir", type=str, default="data/cache")
|
|
|
|
| 241 |
parser.add_argument("--start_date", type=str, default=None)
|
| 242 |
|
| 243 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
|
|
|
| 245 |
parser.add_argument("--context_length", type=int, default=8192)
|
| 246 |
parser.add_argument("--min_trades", type=int, default=10)
|
| 247 |
parser.add_argument("--samples_per_token", type=int, default=1)
|
| 248 |
+
parser.add_argument("--target_contexts_per_class", type=int, default=2500)
|
| 249 |
parser.add_argument("--context_oversample_factor", type=int, default=4)
|
|
|
|
| 250 |
parser.add_argument("--positive_balance_min_class", type=int, default=2)
|
| 251 |
parser.add_argument("--positive_context_ratio", type=float, default=0.5)
|
| 252 |
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
|
|
|
|
| 261 |
|
| 262 |
if args.num_workers == 0:
|
| 263 |
args.num_workers = max(1, mp.cpu_count() - 4)
|
| 264 |
+
if args.num_workers != 1:
|
| 265 |
+
raise RuntimeError("Quota-based caching requires --num_workers 1 so class counters remain exact.")
|
| 266 |
+
if args.target_contexts_per_class <= 0:
|
| 267 |
+
raise RuntimeError("--target_contexts_per_class must be positive.")
|
| 268 |
|
| 269 |
output_dir = Path(args.output_dir)
|
| 270 |
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 291 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 292 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 293 |
|
| 294 |
+
dataset = OracleDataset(data_fetcher=data_fetcher, start_date=start_date_dt, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps)
|
| 295 |
|
| 296 |
if len(dataset) == 0:
|
| 297 |
print("WARNING: No samples. Exiting.")
|
|
|
|
| 327 |
print(f"INFO: Workers: {args.num_workers}")
|
| 328 |
|
| 329 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 330 |
+
dataset_config = {'start_date': start_date_dt, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps}
|
| 331 |
|
|
|
|
| 332 |
import random
|
|
|
|
| 333 |
|
|
|
|
| 334 |
eligible_class_counts = Counter()
|
|
|
|
| 335 |
for i, m in enumerate(filtered_mints):
|
| 336 |
cid = return_class_map.get(m['mint_address'])
|
| 337 |
if cid is not None:
|
| 338 |
eligible_class_counts[cid] += 1
|
|
|
|
| 339 |
|
| 340 |
print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")
|
| 341 |
|
| 342 |
+
class_targets = {
|
| 343 |
+
int(class_id): int(args.target_contexts_per_class)
|
| 344 |
+
for class_id in sorted(eligible_class_counts.keys())
|
| 345 |
+
}
|
| 346 |
+
class_polarity_targets = {
|
| 347 |
+
class_id: _class_polarity_targets(
|
| 348 |
+
class_id=class_id,
|
| 349 |
+
target_contexts_per_class=args.target_contexts_per_class,
|
| 350 |
+
positive_balance_min_class=args.positive_balance_min_class,
|
| 351 |
+
positive_ratio=args.positive_context_ratio,
|
| 352 |
+
)
|
| 353 |
+
for class_id in class_targets
|
| 354 |
+
}
|
| 355 |
|
| 356 |
+
target_total = args.target_contexts_per_class * len(class_targets)
|
| 357 |
+
print(f"INFO: Target total: {target_total}, Target per class: {args.target_contexts_per_class}")
|
| 358 |
print(f"INFO: Exact class targets: {dict(sorted(class_targets.items()))}")
|
| 359 |
print(f"INFO: Class polarity targets: {dict(sorted(class_polarity_targets.items()))}")
|
| 360 |
|
| 361 |
+
tasks = list(enumerate(filtered_mints))
|
| 362 |
+
random.shuffle(tasks)
|
| 363 |
+
print(f"INFO: Total candidate tokens: {len(tasks)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
success_count, skipped_count, error_count = 0, 0, 0
|
| 366 |
+
class_distribution = defaultdict(int)
|
| 367 |
+
polarity_distribution = defaultdict(int)
|
| 368 |
+
file_class_map = {}
|
| 369 |
+
file_context_bucket_map = {}
|
| 370 |
+
file_context_summary_map = {}
|
| 371 |
+
accepted_counts = defaultdict(lambda: {"total": 0, "positive": 0, "negative": 0})
|
| 372 |
|
| 373 |
+
# Resume support: count existing files toward quotas.
|
| 374 |
existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
|
| 375 |
if existing_files:
|
|
|
|
|
|
|
| 376 |
already_cached = 0
|
| 377 |
+
for f in sorted(output_dir.glob("sample_*.pt")):
|
| 378 |
+
try:
|
| 379 |
+
cached = torch.load(f, map_location="cpu", weights_only=False)
|
| 380 |
+
except Exception:
|
| 381 |
+
continue
|
| 382 |
+
class_id = cached.get("class_id")
|
| 383 |
+
if class_id is None or int(class_id) not in class_targets:
|
| 384 |
+
continue
|
| 385 |
+
class_id = int(class_id)
|
| 386 |
+
context_summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask"))
|
| 387 |
+
polarity = _representative_context_polarity(cached)
|
| 388 |
+
file_class_map[f.name] = class_id
|
| 389 |
+
file_context_bucket_map[f.name] = context_summary["context_bucket"]
|
| 390 |
+
file_context_summary_map[f.name] = context_summary
|
| 391 |
+
class_distribution[class_id] += 1
|
| 392 |
+
polarity_distribution[polarity] += 1
|
| 393 |
+
accepted_counts[class_id]["total"] += 1
|
| 394 |
+
accepted_counts[class_id][polarity] += 1
|
| 395 |
+
already_cached += 1
|
| 396 |
+
print(f"INFO: Resume: counted {already_cached} cached contexts toward quotas.")
|
| 397 |
|
| 398 |
print(f"INFO: Starting to cache {len(tasks)} tokens...")
|
| 399 |
process_fn = _process_single_token_context
|
|
|
|
| 421 |
error_log_path = Path(args.output_dir) / "cache_errors.log"
|
| 422 |
error_samples = [] # First 20 unique error messages
|
| 423 |
|
| 424 |
+
print("INFO: Single-threaded mode...")
|
| 425 |
+
_init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
|
| 426 |
+
start_time = _time.perf_counter()
|
| 427 |
+
recent_times = []
|
| 428 |
+
completed_classes = set()
|
| 429 |
+
for task_num, (idx, mint_record) in enumerate(tqdm(tasks, desc="Caching", unit="tok")):
|
| 430 |
+
mint_addr = mint_record["mint_address"]
|
| 431 |
+
class_id = return_class_map.get(mint_addr)
|
| 432 |
+
if class_id is None:
|
| 433 |
+
skipped_count += 1
|
| 434 |
+
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 435 |
+
continue
|
| 436 |
+
if accepted_counts[class_id]["total"] >= class_targets[class_id]:
|
| 437 |
+
if class_id not in completed_classes:
|
| 438 |
+
completed_classes.add(class_id)
|
| 439 |
+
tqdm.write(f"INFO: Class {class_id} quota filled. Skipping remaining tokens for this class.")
|
| 440 |
+
skipped_count += 1
|
| 441 |
+
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 442 |
+
continue
|
| 443 |
+
|
| 444 |
+
remaining_total = class_targets[class_id] - accepted_counts[class_id]["total"]
|
| 445 |
+
remaining_polarity = _remaining_polarity_targets(
|
| 446 |
+
class_id=class_id,
|
| 447 |
+
accepted_counts=accepted_counts,
|
| 448 |
+
target_contexts_per_class=args.target_contexts_per_class,
|
| 449 |
+
positive_balance_min_class=args.positive_balance_min_class,
|
| 450 |
+
positive_ratio=args.positive_context_ratio,
|
| 451 |
+
)
|
| 452 |
+
desired_positive = min(remaining_polarity["positive"], args.samples_per_token, remaining_total)
|
| 453 |
+
desired_negative = min(
|
| 454 |
+
remaining_polarity["negative"],
|
| 455 |
+
max(0, min(args.samples_per_token, remaining_total) - desired_positive),
|
| 456 |
+
)
|
| 457 |
+
samples_to_keep = min(args.samples_per_token, remaining_total)
|
| 458 |
+
task = (
|
| 459 |
+
idx,
|
| 460 |
+
mint_addr,
|
| 461 |
+
samples_to_keep,
|
| 462 |
+
args.context_oversample_factor,
|
| 463 |
+
desired_positive,
|
| 464 |
+
desired_negative,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
t0 = _time.perf_counter()
|
| 468 |
+
result = process_fn(task)
|
| 469 |
+
elapsed = _time.perf_counter() - t0
|
| 470 |
+
recent_times.append(elapsed)
|
| 471 |
+
if len(recent_times) > 50:
|
| 472 |
+
recent_times.pop(0)
|
| 473 |
+
|
| 474 |
+
if result["status"] == "success":
|
| 475 |
+
saved_contexts = 0
|
| 476 |
+
for ctx in result.get("contexts", []):
|
| 477 |
+
if accepted_counts[class_id]["total"] >= class_targets[class_id]:
|
| 478 |
+
break
|
| 479 |
+
|
| 480 |
+
polarity = _representative_context_polarity(ctx)
|
| 481 |
+
remaining_polarity = _remaining_polarity_targets(
|
| 482 |
+
class_id=class_id,
|
| 483 |
+
accepted_counts=accepted_counts,
|
| 484 |
+
target_contexts_per_class=args.target_contexts_per_class,
|
| 485 |
+
positive_balance_min_class=args.positive_balance_min_class,
|
| 486 |
+
positive_ratio=args.positive_context_ratio,
|
| 487 |
+
)
|
| 488 |
+
other_polarity = "negative" if polarity == "positive" else "positive"
|
| 489 |
+
if remaining_polarity[polarity] <= 0 and remaining_polarity[other_polarity] > 0:
|
| 490 |
+
continue
|
| 491 |
+
|
| 492 |
+
ctx["quality_score"] = result["q_score"]
|
| 493 |
+
ctx["class_id"] = class_id
|
| 494 |
+
ctx["source_token"] = mint_addr
|
| 495 |
+
context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
|
| 496 |
+
ctx["context_bucket"] = context_summary["context_bucket"]
|
| 497 |
+
ctx["context_score"] = context_summary["context_score"]
|
| 498 |
+
file_idx = accepted_counts[class_id]["total"]
|
| 499 |
+
filename = f"sample_{mint_addr[:16]}_{file_idx}.pt"
|
| 500 |
+
output_path = output_dir / filename
|
| 501 |
+
torch.save(ctx, output_path)
|
| 502 |
+
|
| 503 |
+
file_class_map[filename] = class_id
|
| 504 |
+
file_context_bucket_map[filename] = context_summary["context_bucket"]
|
| 505 |
+
file_context_summary_map[filename] = context_summary
|
| 506 |
+
class_distribution[class_id] += 1
|
| 507 |
+
polarity_distribution[polarity] += 1
|
| 508 |
+
accepted_counts[class_id]["total"] += 1
|
| 509 |
+
accepted_counts[class_id][polarity] += 1
|
| 510 |
+
saved_contexts += 1
|
| 511 |
+
|
| 512 |
+
if saved_contexts > 0:
|
| 513 |
success_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
else:
|
| 515 |
+
skipped_count += 1
|
| 516 |
+
elif result["status"] == "skipped":
|
| 517 |
+
skipped_count += 1
|
| 518 |
+
else:
|
| 519 |
+
error_count += 1
|
| 520 |
+
err_msg = result.get("error", "unknown")
|
| 521 |
+
tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}")
|
| 522 |
+
if len(error_samples) < 20:
|
| 523 |
+
error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
|
| 524 |
+
|
| 525 |
+
if all(accepted_counts[cid]["total"] >= class_targets[cid] for cid in class_targets):
|
| 526 |
+
tqdm.write("INFO: All class quotas filled. Stopping early.")
|
| 527 |
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 528 |
+
break
|
| 529 |
+
|
| 530 |
+
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
|
| 532 |
# Write error log
|
| 533 |
if error_samples:
|
|
|
|
| 539 |
print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")
|
| 540 |
|
| 541 |
print("INFO: Building metadata...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
with open(output_dir / "class_metadata.json", 'w') as f:
|
| 543 |
json.dump({
|
| 544 |
'file_class_map': file_class_map,
|
| 545 |
+
'file_context_bucket_map': file_context_bucket_map,
|
| 546 |
+
'file_context_summary_map': file_context_summary_map,
|
| 547 |
'class_distribution': {str(k): v for k, v in class_distribution.items()},
|
| 548 |
'num_workers': args.num_workers,
|
| 549 |
'horizons_seconds': args.horizons_seconds,
|
| 550 |
'quantiles': args.quantiles,
|
| 551 |
'target_total': target_total,
|
| 552 |
+
'target_contexts_per_class': args.target_contexts_per_class,
|
|
|
|
| 553 |
'context_polarity_distribution': polarity_distribution,
|
| 554 |
'class_targets': {str(k): v for k, v in class_targets.items()},
|
| 555 |
'class_polarity_targets': {str(k): v for k, v in class_polarity_targets.items()},
|
| 556 |
+
'accepted_counts': {str(k): v for k, v in accepted_counts.items()},
|
| 557 |
'positive_balance_min_class': args.positive_balance_min_class,
|
| 558 |
'positive_context_ratio': args.positive_context_ratio,
|
| 559 |
+
'quant_feature_version': FEATURE_VERSION,
|
| 560 |
}, f, indent=2)
|
| 561 |
|
| 562 |
print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
|
scripts/dump_cache_sample.py
CHANGED
|
@@ -121,7 +121,7 @@ def main():
|
|
| 121 |
"__metadata__": {
|
| 122 |
"source_file": str(filepath.absolute()),
|
| 123 |
"dumped_at": datetime.now().isoformat(),
|
| 124 |
-
"
|
| 125 |
},
|
| 126 |
"data": serializable_data
|
| 127 |
}
|
|
@@ -143,7 +143,7 @@ def main():
|
|
| 143 |
if isinstance(data, dict):
|
| 144 |
print(f"\n=== Summary ===")
|
| 145 |
print(f"Top-level keys: {list(data.keys())}")
|
| 146 |
-
print(f"Cache
|
| 147 |
if 'event_sequence' in data:
|
| 148 |
print(f"Event count: {len(data['event_sequence'])}")
|
| 149 |
if 'trades' in data:
|
|
@@ -152,6 +152,10 @@ def main():
|
|
| 152 |
print(f"Source token: {data['source_token']}")
|
| 153 |
if 'class_id' in data:
|
| 154 |
print(f"Class ID: {data['class_id']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
if 'quality_score' in data:
|
| 156 |
print(f"Quality score: {data['quality_score']}")
|
| 157 |
|
|
|
|
| 121 |
"__metadata__": {
|
| 122 |
"source_file": str(filepath.absolute()),
|
| 123 |
"dumped_at": datetime.now().isoformat(),
|
| 124 |
+
"cache_format": "context" if isinstance(data, dict) and "event_sequence" in data else "legacy"
|
| 125 |
},
|
| 126 |
"data": serializable_data
|
| 127 |
}
|
|
|
|
| 143 |
if isinstance(data, dict):
|
| 144 |
print(f"\n=== Summary ===")
|
| 145 |
print(f"Top-level keys: {list(data.keys())}")
|
| 146 |
+
print(f"Cache format: {'context' if 'event_sequence' in data else 'legacy'}")
|
| 147 |
if 'event_sequence' in data:
|
| 148 |
print(f"Event count: {len(data['event_sequence'])}")
|
| 149 |
if 'trades' in data:
|
|
|
|
| 152 |
print(f"Source token: {data['source_token']}")
|
| 153 |
if 'class_id' in data:
|
| 154 |
print(f"Class ID: {data['class_id']}")
|
| 155 |
+
if 'context_bucket' in data:
|
| 156 |
+
print(f"Context bucket: {data['context_bucket']}")
|
| 157 |
+
if 'context_score' in data:
|
| 158 |
+
print(f"Context score: {data['context_score']}")
|
| 159 |
if 'quality_score' in data:
|
| 160 |
print(f"Quality score: {data['quality_score']}")
|
| 161 |
|
scripts/evaluate_sample.py
CHANGED
|
@@ -23,8 +23,10 @@ from models.token_encoder import TokenEncoder
|
|
| 23 |
from models.wallet_encoder import WalletEncoder
|
| 24 |
from models.graph_updater import GraphUpdater
|
| 25 |
from models.ohlc_embedder import OHLCEmbedder
|
|
|
|
| 26 |
from models.model import Oracle
|
| 27 |
import models.vocabulary as vocab
|
|
|
|
| 28 |
from train import create_balanced_split
|
| 29 |
from dotenv import load_dotenv
|
| 30 |
from clickhouse_driver import Client as ClickHouseClient
|
|
@@ -43,6 +45,12 @@ ABLATION_SWEEP_MODES = [
|
|
| 43 |
"trade",
|
| 44 |
"onchain",
|
| 45 |
"wallet_graph",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
]
|
| 47 |
|
| 48 |
OHLC_PROBE_MODES = [
|
|
@@ -77,7 +85,7 @@ def parse_args():
|
|
| 77 |
"--ablation",
|
| 78 |
type=str,
|
| 79 |
default="none",
|
| 80 |
-
choices=["none", "wallet", "graph", "wallet_graph", "social", "token", "holder", "ohlc", "ohlc_wallet", "trade", "onchain", "all", "sweep", "ohlc_probe"],
|
| 81 |
help="Run inference with selected signal families removed, or use 'sweep' to rank multiple families.",
|
| 82 |
)
|
| 83 |
return parser.parse_args()
|
|
@@ -164,6 +172,23 @@ def apply_ablation(batch, mode, device):
|
|
| 164 |
ablated["ohlc_price_tensors"] = torch.zeros_like(ablated["ohlc_price_tensors"])
|
| 165 |
if "ohlc_interval_ids" in ablated:
|
| 166 |
ablated["ohlc_interval_ids"] = torch.zeros_like(ablated["ohlc_interval_ids"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
if mode in {"trade", "all"}:
|
| 169 |
for key in (
|
|
@@ -627,6 +652,11 @@ def main():
|
|
| 627 |
wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
|
| 628 |
graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
|
| 629 |
ohlc_embedder = OHLCEmbedder(num_intervals=vocab.NUM_OHLC_INTERVALS, dtype=init_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
|
| 631 |
collator = MemecoinCollator(
|
| 632 |
event_type_to_id=vocab.EVENT_TO_ID,
|
|
@@ -641,6 +671,7 @@ def main():
|
|
| 641 |
wallet_encoder=wallet_encoder,
|
| 642 |
graph_updater=graph_updater,
|
| 643 |
ohlc_embedder=ohlc_embedder,
|
|
|
|
| 644 |
time_encoder=time_encoder,
|
| 645 |
num_event_types=vocab.NUM_EVENT_TYPES,
|
| 646 |
multi_modal_dim=multi_modal_encoder.embedding_dim,
|
|
|
|
| 23 |
from models.wallet_encoder import WalletEncoder
|
| 24 |
from models.graph_updater import GraphUpdater
|
| 25 |
from models.ohlc_embedder import OHLCEmbedder
|
| 26 |
+
from models.quant_ohlc_embedder import QuantOHLCEmbedder
|
| 27 |
from models.model import Oracle
|
| 28 |
import models.vocabulary as vocab
|
| 29 |
+
from data.quant_ohlc_feature_schema import FEATURE_GROUPS, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT, group_feature_indices
|
| 30 |
from train import create_balanced_split
|
| 31 |
from dotenv import load_dotenv
|
| 32 |
from clickhouse_driver import Client as ClickHouseClient
|
|
|
|
| 45 |
"trade",
|
| 46 |
"onchain",
|
| 47 |
"wallet_graph",
|
| 48 |
+
"quant_ohlc",
|
| 49 |
+
"quant_levels",
|
| 50 |
+
"quant_trendline",
|
| 51 |
+
"quant_breaks",
|
| 52 |
+
"quant_rolling",
|
| 53 |
+
"quant_patterns",
|
| 54 |
]
|
| 55 |
|
| 56 |
OHLC_PROBE_MODES = [
|
|
|
|
| 85 |
"--ablation",
|
| 86 |
type=str,
|
| 87 |
default="none",
|
| 88 |
+
choices=["none", "wallet", "graph", "wallet_graph", "social", "token", "holder", "ohlc", "ohlc_wallet", "trade", "onchain", "all", "sweep", "ohlc_probe", "quant_ohlc", "quant_levels", "quant_trendline", "quant_breaks", "quant_rolling", "quant_patterns"],
|
| 89 |
help="Run inference with selected signal families removed, or use 'sweep' to rank multiple families.",
|
| 90 |
)
|
| 91 |
return parser.parse_args()
|
|
|
|
| 172 |
ablated["ohlc_price_tensors"] = torch.zeros_like(ablated["ohlc_price_tensors"])
|
| 173 |
if "ohlc_interval_ids" in ablated:
|
| 174 |
ablated["ohlc_interval_ids"] = torch.zeros_like(ablated["ohlc_interval_ids"])
|
| 175 |
+
if "quant_ohlc_feature_tensors" in ablated:
|
| 176 |
+
ablated["quant_ohlc_feature_tensors"] = torch.zeros_like(ablated["quant_ohlc_feature_tensors"])
|
| 177 |
+
if "quant_ohlc_feature_mask" in ablated:
|
| 178 |
+
ablated["quant_ohlc_feature_mask"] = torch.zeros_like(ablated["quant_ohlc_feature_mask"])
|
| 179 |
+
|
| 180 |
+
quant_group_map = {
|
| 181 |
+
"quant_ohlc": list(FEATURE_GROUPS.keys()),
|
| 182 |
+
"quant_levels": ["levels_breaks"],
|
| 183 |
+
"quant_trendline": ["trendlines"],
|
| 184 |
+
"quant_breaks": ["relative_structure", "levels_breaks"],
|
| 185 |
+
"quant_rolling": ["rolling_quant"],
|
| 186 |
+
"quant_patterns": ["patterns"],
|
| 187 |
+
}
|
| 188 |
+
if mode in quant_group_map and "quant_ohlc_feature_tensors" in ablated:
|
| 189 |
+
idxs = group_feature_indices(quant_group_map[mode])
|
| 190 |
+
if idxs:
|
| 191 |
+
ablated["quant_ohlc_feature_tensors"][:, :, idxs] = 0
|
| 192 |
|
| 193 |
if mode in {"trade", "all"}:
|
| 194 |
for key in (
|
|
|
|
| 652 |
wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
|
| 653 |
graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
|
| 654 |
ohlc_embedder = OHLCEmbedder(num_intervals=vocab.NUM_OHLC_INTERVALS, dtype=init_dtype)
|
| 655 |
+
quant_ohlc_embedder = QuantOHLCEmbedder(
|
| 656 |
+
num_features=NUM_QUANT_OHLC_FEATURES,
|
| 657 |
+
sequence_length=TOKENS_PER_SEGMENT,
|
| 658 |
+
dtype=init_dtype,
|
| 659 |
+
)
|
| 660 |
|
| 661 |
collator = MemecoinCollator(
|
| 662 |
event_type_to_id=vocab.EVENT_TO_ID,
|
|
|
|
| 671 |
wallet_encoder=wallet_encoder,
|
| 672 |
graph_updater=graph_updater,
|
| 673 |
ohlc_embedder=ohlc_embedder,
|
| 674 |
+
quant_ohlc_embedder=quant_ohlc_embedder,
|
| 675 |
time_encoder=time_encoder,
|
| 676 |
num_event_types=vocab.NUM_EVENT_TYPES,
|
| 677 |
multi_modal_dim=multi_modal_encoder.embedding_dim,
|
scripts/rebuild_metadata.py
CHANGED
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|
| 6 |
from tqdm import tqdm
|
| 7 |
from collections import defaultdict
|
| 8 |
from data.data_loader import summarize_context_window
|
|
|
|
| 9 |
|
| 10 |
def rebuild_metadata(cache_dir="data/cache"):
|
| 11 |
cache_path = Path(cache_dir)
|
|
@@ -52,6 +53,7 @@ def rebuild_metadata(cache_dir="data/cache"):
|
|
| 52 |
'num_workers': 1,
|
| 53 |
'horizons_seconds': [300, 900, 1800, 3600, 7200], # From user's pre_cache.sh
|
| 54 |
'quantiles': [0.1, 0.5, 0.9],
|
|
|
|
| 55 |
}
|
| 56 |
|
| 57 |
out_file = cache_path / "class_metadata.json"
|
|
|
|
| 6 |
from tqdm import tqdm
|
| 7 |
from collections import defaultdict
|
| 8 |
from data.data_loader import summarize_context_window
|
| 9 |
+
from data.quant_ohlc_feature_schema import FEATURE_VERSION
|
| 10 |
|
| 11 |
def rebuild_metadata(cache_dir="data/cache"):
|
| 12 |
cache_path = Path(cache_dir)
|
|
|
|
| 53 |
'num_workers': 1,
|
| 54 |
'horizons_seconds': [300, 900, 1800, 3600, 7200], # From user's pre_cache.sh
|
| 55 |
'quantiles': [0.1, 0.5, 0.9],
|
| 56 |
+
'quant_feature_version': FEATURE_VERSION,
|
| 57 |
}
|
| 58 |
|
| 59 |
out_file = cache_path / "class_metadata.json"
|
signals/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Deterministic chart signal extraction package.
|
signals/patterns.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy.signal import find_peaks
|
| 5 |
+
|
| 6 |
+
from data.quant_ohlc_feature_schema import PATTERN_NAMES
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _empty_pattern_output() -> Dict[str, float]:
|
| 10 |
+
out = {f"pattern_{name}_confidence": 0.0 for name in PATTERN_NAMES}
|
| 11 |
+
out["pattern_available"] = 1.0
|
| 12 |
+
return out
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _confidence_from_error(error: float, tolerance: float) -> float:
|
| 16 |
+
if tolerance <= 1e-8:
|
| 17 |
+
return 0.0
|
| 18 |
+
return float(max(0.0, min(1.0, 1.0 - (error / tolerance))))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _recent_prominent_peaks(series: np.ndarray, distance: int, prominence: float) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
|
| 22 |
+
peaks, props = find_peaks(series, distance=distance, prominence=prominence)
|
| 23 |
+
if peaks.size == 0:
|
| 24 |
+
return peaks, props
|
| 25 |
+
order = np.argsort(props["prominences"])
|
| 26 |
+
keep = order[-5:]
|
| 27 |
+
keep_sorted = np.sort(keep)
|
| 28 |
+
peaks = peaks[keep_sorted]
|
| 29 |
+
props = {key: value[keep_sorted] for key, value in props.items()}
|
| 30 |
+
return peaks, props
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _double_top_confidence(highs: np.ndarray, current_price: float, tolerance: float) -> float:
|
| 34 |
+
peaks, props = _recent_prominent_peaks(highs, distance=3, prominence=tolerance * 0.5)
|
| 35 |
+
if peaks.size < 2:
|
| 36 |
+
return 0.0
|
| 37 |
+
top1_idx, top2_idx = peaks[-2], peaks[-1]
|
| 38 |
+
top1 = float(highs[top1_idx])
|
| 39 |
+
top2 = float(highs[top2_idx])
|
| 40 |
+
neckline = float(np.min(highs[top1_idx:top2_idx + 1])) if top2_idx > top1_idx else min(top1, top2)
|
| 41 |
+
if current_price > max(top1, top2):
|
| 42 |
+
return 0.0
|
| 43 |
+
symmetry = _confidence_from_error(abs(top1 - top2), tolerance)
|
| 44 |
+
separation = min(1.0, float(top2_idx - top1_idx) / 8.0)
|
| 45 |
+
breakdown = 1.0 if current_price <= neckline else 0.6
|
| 46 |
+
prominence = min(1.0, float(np.mean(props["prominences"][-2:])) / max(tolerance, 1e-8))
|
| 47 |
+
return float(max(0.0, min(1.0, symmetry * separation * breakdown * prominence)))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _double_bottom_confidence(lows: np.ndarray, current_price: float, tolerance: float) -> float:
|
| 51 |
+
troughs, props = _recent_prominent_peaks(-lows, distance=3, prominence=tolerance * 0.5)
|
| 52 |
+
if troughs.size < 2:
|
| 53 |
+
return 0.0
|
| 54 |
+
low1_idx, low2_idx = troughs[-2], troughs[-1]
|
| 55 |
+
low1 = float(lows[low1_idx])
|
| 56 |
+
low2 = float(lows[low2_idx])
|
| 57 |
+
ceiling = float(np.max(lows[low1_idx:low2_idx + 1])) if low2_idx > low1_idx else max(low1, low2)
|
| 58 |
+
if current_price < min(low1, low2):
|
| 59 |
+
return 0.0
|
| 60 |
+
symmetry = _confidence_from_error(abs(low1 - low2), tolerance)
|
| 61 |
+
separation = min(1.0, float(low2_idx - low1_idx) / 8.0)
|
| 62 |
+
breakout = 1.0 if current_price >= ceiling else 0.6
|
| 63 |
+
prominence = min(1.0, float(np.mean(props["prominences"][-2:])) / max(tolerance, 1e-8))
|
| 64 |
+
return float(max(0.0, min(1.0, symmetry * separation * breakout * prominence)))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _triangle_confidences(highs: np.ndarray, lows: np.ndarray, tolerance: float) -> Dict[str, float]:
    """Score ascending/descending triangle formations, each in 0..1.

    Ascending = flat recent peaks with rising troughs; descending = flat
    recent troughs with falling peaks. Slopes are normalised by
    ``tolerance`` before clamping into [0, 1].
    """
    scores = {
        "ascending_triangle": 0.0,
        "descending_triangle": 0.0,
    }
    peak_indices, _ = _recent_prominent_peaks(highs, distance=3, prominence=tolerance * 0.5)
    trough_indices, _ = _recent_prominent_peaks(-lows, distance=3, prominence=tolerance * 0.5)
    if peak_indices.size < 2 or trough_indices.size < 2:
        return scores

    recent_peaks = highs[peak_indices[-3:]]
    recent_troughs = lows[trough_indices[-3:]]

    def _slope(values: np.ndarray) -> float:
        # First-order least-squares slope over the pivot sequence.
        xs = np.arange(len(values), dtype=np.float64)
        return np.polyfit(xs, values.astype(np.float64), deg=1)[0]

    def _flatness(values: np.ndarray) -> float:
        # How tightly the pivot values cluster relative to tolerance.
        return _confidence_from_error(float(np.max(values) - np.min(values)), tolerance)

    slope_norm = max(tolerance, 1e-8)
    ascending = _flatness(recent_peaks) * max(0.0, _slope(recent_troughs)) / slope_norm
    descending = _flatness(recent_troughs) * max(0.0, -_slope(recent_peaks)) / slope_norm
    scores["ascending_triangle"] = float(max(0.0, min(1.0, ascending)))
    scores["descending_triangle"] = float(max(0.0, min(1.0, descending)))
    return scores
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _head_shoulders_confidence(highs: np.ndarray, lows: np.ndarray, tolerance: float, inverse: bool = False) -> float:
    """Score a (inverse) head-and-shoulders formation on a 0..1 scale.

    The regular pattern is detected on ``highs``; the inverse pattern is
    detected by running the identical peak logic on ``-lows``, which maps
    an inverse head-and-shoulders onto a regular one.

    Bug fix: the previous inverse branch applied the un-negated-domain
    formula ``min(shoulders) - head`` to the already negated series. For
    every genuine inverse pattern (head low deeper than both shoulders)
    that expression is negative in negated space, so the inverse
    confidence was always clipped to 0. Because negation already inverts
    the geometry, the regular-pattern margin ``head - max(shoulders)``
    is correct in ``series`` space for both cases.
    """
    series = -lows if inverse else highs
    pivots, props = _recent_prominent_peaks(series, distance=3, prominence=tolerance * 0.5)
    if pivots.size < 3:
        return 0.0
    idxs = pivots[-3:]
    # Chronological order assumed: left shoulder, head, right shoulder.
    left, head, right = (float(v) for v in series[idxs])
    shoulders_match = _confidence_from_error(abs(left - right), tolerance)
    # In `series` space the head must stand above both shoulders,
    # regardless of `inverse` (negation already handled the flip).
    head_margin = max(0.0, head - max(left, right))
    head_score = min(1.0, head_margin / max(tolerance, 1e-8))
    spacing = min(1.0, float(min(idxs[1] - idxs[0], idxs[2] - idxs[1])) / 5.0)
    prominence = min(1.0, float(np.mean(props["prominences"][-3:])) / max(tolerance, 1e-8))
    return float(max(0.0, min(1.0, shoulders_match * head_score * spacing * prominence)))
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def compute_pattern_features(closes, highs, lows, end_idx: int) -> Dict[str, float]:
    """Compute chart-pattern confidence features on bars up to ``end_idx``.

    Returns the zeroed template from ``_empty_pattern_output`` when fewer
    than 10 bars are available. Tolerance is derived from recent close
    volatility with a floor relative to the current price.
    """
    features = _empty_pattern_output()
    close_arr = np.asarray(closes[: end_idx + 1], dtype=np.float64)
    high_arr = np.asarray(highs[: end_idx + 1], dtype=np.float64)
    low_arr = np.asarray(lows[: end_idx + 1], dtype=np.float64)
    if close_arr.size < 10:
        return features

    latest_close = float(close_arr[-1])
    # Std of the last 20 closes (or all, if shorter), floored at 0.3% of
    # price and a tiny absolute epsilon.
    recent_std = float(np.std(close_arr[-20:])) if close_arr.size >= 20 else float(np.std(close_arr))
    tolerance = max(recent_std, latest_close * 0.003, 1e-5)

    features["pattern_double_top_confidence"] = _double_top_confidence(
        high_arr, latest_close, tolerance
    )
    features["pattern_double_bottom_confidence"] = _double_bottom_confidence(
        low_arr, latest_close, tolerance
    )

    triangle_scores = _triangle_confidences(high_arr, low_arr, tolerance)
    features["pattern_ascending_triangle_confidence"] = triangle_scores["ascending_triangle"]
    features["pattern_descending_triangle_confidence"] = triangle_scores["descending_triangle"]
    features["pattern_head_shoulders_confidence"] = _head_shoulders_confidence(
        high_arr, low_arr, tolerance, inverse=False
    )
    features["pattern_inverse_head_shoulders_confidence"] = _head_shoulders_confidence(
        high_arr, low_arr, tolerance, inverse=True
    )
    return features
|
signals/rolling_quant.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from ta.trend import ema_indicator, sma_indicator
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _finite_or_zero(value: float) -> float:
|
| 9 |
+
try:
|
| 10 |
+
value = float(value)
|
| 11 |
+
except Exception:
|
| 12 |
+
return 0.0
|
| 13 |
+
if not np.isfinite(value):
|
| 14 |
+
return 0.0
|
| 15 |
+
return value
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def compute_rolling_quant_features(closes: List[float], end_idx: int) -> Dict[str, float]:
    """Rolling moving-average / z-score features for ``closes[: end_idx + 1]``.

    Price-level outputs are normalised by the latest close so the features
    stay scale-free across tokens; every output is a finite float, with
    0.0 used as the fallback for degenerate inputs.
    """
    feature_keys = (
        "ema_fast",
        "ema_medium",
        "sma_fast",
        "sma_medium",
        "price_minus_ema_fast",
        "price_minus_ema_medium",
        "ema_spread",
        "price_zscore",
        "mean_reversion_score",
        "rolling_vol_zscore",
    )
    window_prices = np.asarray(closes[: end_idx + 1], dtype=np.float64)
    if window_prices.size == 0:
        return {key: 0.0 for key in feature_keys}

    series = pd.Series(window_prices)
    latest = float(window_prices[-1])
    scale = max(abs(latest), 1e-8)

    ema_short = _finite_or_zero(ema_indicator(series, window=8, fillna=True).iloc[-1])
    ema_long = _finite_or_zero(ema_indicator(series, window=21, fillna=True).iloc[-1])
    sma_short = _finite_or_zero(sma_indicator(series, window=8, fillna=True).iloc[-1])
    sma_long = _finite_or_zero(sma_indicator(series, window=21, fillna=True).iloc[-1])

    overall_mean = _finite_or_zero(series.mean())
    overall_std = _finite_or_zero(series.std(ddof=0))
    zscore = (latest - overall_mean) / overall_std if overall_std > 1e-8 else 0.0

    # Log returns with prices clipped away from zero to keep log() defined.
    log_rets = np.diff(np.log(np.clip(window_prices, 1e-8, None)))
    if log_rets.size:
        abs_rets = np.abs(log_rets)
        recent_vol = _finite_or_zero(np.std(log_rets[-20:]))
        typical_vol = _finite_or_zero(np.mean(abs_rets))
        vol_spread = _finite_or_zero(np.std(abs_rets))
    else:
        recent_vol = 0.0
        typical_vol = 0.0
        vol_spread = 0.0
    vol_zscore = (recent_vol - typical_vol) / vol_spread if vol_spread > 1e-8 else 0.0

    reversion_denom = max(abs(sma_long), 1e-8)
    return {
        "ema_fast": ema_short / scale,
        "ema_medium": ema_long / scale,
        "sma_fast": sma_short / scale,
        "sma_medium": sma_long / scale,
        "price_minus_ema_fast": (latest - ema_short) / scale,
        "price_minus_ema_medium": (latest - ema_long) / scale,
        "ema_spread": (ema_short - ema_long) / scale,
        "price_zscore": _finite_or_zero(zscore),
        "mean_reversion_score": (overall_mean - latest) / reversion_denom,
        "rolling_vol_zscore": _finite_or_zero(vol_zscore),
    }
|
signals/support_resistance.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy.signal import argrelextrema
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _compute_pivots(prices: np.ndarray, order: int = 2) -> Dict[str, List[int]]:
|
| 8 |
+
if prices.size < (2 * order + 1):
|
| 9 |
+
return {"highs": [], "lows": []}
|
| 10 |
+
highs = argrelextrema(prices, np.greater_equal, order=order, mode="clip")[0].tolist()
|
| 11 |
+
lows = argrelextrema(prices, np.less_equal, order=order, mode="clip")[0].tolist()
|
| 12 |
+
highs = [idx for idx in highs if order <= idx < len(prices) - order]
|
| 13 |
+
lows = [idx for idx in lows if order <= idx < len(prices) - order]
|
| 14 |
+
return {"highs": highs, "lows": lows}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _cluster_levels(prices: np.ndarray, pivot_indices: List[int], tolerance: float) -> List[Dict[str, float]]:
|
| 18 |
+
levels: List[Dict[str, float]] = []
|
| 19 |
+
for pivot_idx in pivot_indices:
|
| 20 |
+
price = float(prices[pivot_idx])
|
| 21 |
+
matched = None
|
| 22 |
+
for level in levels:
|
| 23 |
+
if abs(price - level["price"]) <= tolerance:
|
| 24 |
+
matched = level
|
| 25 |
+
break
|
| 26 |
+
if matched is None:
|
| 27 |
+
levels.append({
|
| 28 |
+
"price": price,
|
| 29 |
+
"touches": 1.0,
|
| 30 |
+
"last_idx": float(pivot_idx),
|
| 31 |
+
"first_idx": float(pivot_idx),
|
| 32 |
+
})
|
| 33 |
+
continue
|
| 34 |
+
touches = matched["touches"] + 1.0
|
| 35 |
+
matched["price"] = ((matched["price"] * matched["touches"]) + price) / touches
|
| 36 |
+
matched["touches"] = touches
|
| 37 |
+
matched["last_idx"] = float(pivot_idx)
|
| 38 |
+
return levels
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _nearest_level(levels: List[Dict[str, float]], current_price: float, below: bool) -> Optional[Dict[str, float]]:
|
| 42 |
+
candidates = [level for level in levels if (level["price"] <= current_price if below else level["price"] >= current_price)]
|
| 43 |
+
if not candidates:
|
| 44 |
+
return None
|
| 45 |
+
return min(candidates, key=lambda level: abs(level["price"] - current_price))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def compute_support_resistance_features(
    closes: List[float],
    highs: List[float],
    lows: List[float],
    end_idx: int,
    window_start: int,
    window_end: int,
    timestamps: List[int],
) -> Dict[str, float]:
    """Support/resistance level features for bars up to and including ``end_idx``.

    Pipeline: find swing pivots in highs/lows, cluster them into price
    levels, pick the nearest support (below price) and resistance (above
    price), then describe each with distance, touch count, age, strength
    and sweep/reclaim flags. Sweep detection only inspects the slice
    ``[window_start:window_end]`` of the raw highs/lows lists.

    Returns an all-zero feature dict when no bars are available.
    """
    closes_np = np.asarray(closes[: end_idx + 1], dtype=np.float64)
    highs_np = np.asarray(highs[: end_idx + 1], dtype=np.float64)
    lows_np = np.asarray(lows[: end_idx + 1], dtype=np.float64)
    if closes_np.size == 0:
        return {key: 0.0 for key in [
            "nearest_support_dist", "nearest_resistance_dist", "support_touch_count",
            "resistance_touch_count", "support_age_sec", "resistance_age_sec",
            "support_strength", "resistance_strength", "inside_support_zone",
            "inside_resistance_zone", "support_swept", "resistance_swept",
            "support_reclaim", "resistance_reject", "sr_available",
        ]}

    current_price = float(closes_np[-1])
    # Clustering tolerance: 8% of the visible high-low range, floored at
    # 0.25% of the current price.
    local_range = max(float(np.max(highs_np) - np.min(lows_np)), current_price * 1e-3, 1e-8)
    tolerance = max(local_range * 0.08, current_price * 0.0025)

    pivots_high = _compute_pivots(highs_np, order=2)["highs"]
    pivots_low = _compute_pivots(lows_np, order=2)["lows"]
    support_levels = _cluster_levels(lows_np, pivots_low, tolerance)
    resistance_levels = _cluster_levels(highs_np, pivots_high, tolerance)

    support = _nearest_level(support_levels, current_price, below=True)
    resistance = _nearest_level(resistance_levels, current_price, below=False)
    # Fall back to the bar index as a pseudo-timestamp when none are given.
    current_ts = float(timestamps[min(end_idx, len(timestamps) - 1)]) if timestamps else float(end_idx)

    def _level_stats(level: Optional[Dict[str, float]], is_support: bool) -> Dict[str, float]:
        # Describe one level (or all zeros when the side has no level).
        if level is None:
            return {
                "dist": 0.0,
                "touch_count": 0.0,
                "age_sec": 0.0,
                "strength": 0.0,
                "inside_zone": 0.0,
                "swept": 0.0,
                "confirm": 0.0,
            }
        level_price = float(level["price"])
        zone_half_width = max(tolerance, abs(level_price) * 0.002)
        window_prices_high = highs[window_start:window_end]
        window_prices_low = lows[window_start:window_end]
        swept = 0.0
        confirm = 0.0
        if is_support:
            # Swept: some low in the window pierced below the support zone;
            # confirm (reclaim): price is back at/above the level afterwards.
            min_low = min(window_prices_low) if window_prices_low else current_price
            swept = 1.0 if min_low < (level_price - zone_half_width) else 0.0
            confirm = 1.0 if swept > 0 and current_price >= level_price else 0.0
        else:
            # Swept: some high pierced above the resistance zone;
            # confirm (reject): price is back at/below the level afterwards.
            max_high = max(window_prices_high) if window_prices_high else current_price
            swept = 1.0 if max_high > (level_price + zone_half_width) else 0.0
            confirm = 1.0 if swept > 0 and current_price <= level_price else 0.0

        return {
            # Signed distance normalised by price; positive on the expected
            # side (price above support / below resistance).
            "dist": (current_price - level_price) / max(abs(current_price), 1e-8) if is_support else (level_price - current_price) / max(abs(current_price), 1e-8),
            "touch_count": float(level["touches"]),
            # Ternary binds looser than `-`: this is max(0, (now - ts) if
            # timestamps else (now - last_idx)) — index-based when no
            # timestamps exist.
            "age_sec": max(0.0, current_ts - float(timestamps[int(level["last_idx"])]) if timestamps else current_ts - level["last_idx"]),
            # Touches as a fraction of the bars seen so far.
            "strength": float(level["touches"]) / max(1.0, float(end_idx + 1)),
            "inside_zone": 1.0 if abs(current_price - level_price) <= zone_half_width else 0.0,
            "swept": swept,
            "confirm": confirm,
        }

    support_stats = _level_stats(support, True)
    resistance_stats = _level_stats(resistance, False)
    return {
        "nearest_support_dist": support_stats["dist"],
        "nearest_resistance_dist": resistance_stats["dist"],
        "support_touch_count": support_stats["touch_count"],
        "resistance_touch_count": resistance_stats["touch_count"],
        "support_age_sec": support_stats["age_sec"],
        "resistance_age_sec": resistance_stats["age_sec"],
        "support_strength": support_stats["strength"],
        "resistance_strength": resistance_stats["strength"],
        "inside_support_zone": support_stats["inside_zone"],
        "inside_resistance_zone": resistance_stats["inside_zone"],
        "support_swept": support_stats["swept"],
        "resistance_swept": resistance_stats["swept"],
        "support_reclaim": support_stats["confirm"],
        "resistance_reject": resistance_stats["confirm"],
        "sr_available": 1.0 if support_levels or resistance_levels else 0.0,
    }
|
signals/trendlines.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Iterable, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import trendln
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _empty_trendline_features() -> Dict[str, float]:
|
| 8 |
+
return {
|
| 9 |
+
"lower_trendline_slope": 0.0,
|
| 10 |
+
"upper_trendline_slope": 0.0,
|
| 11 |
+
"dist_to_lower_line": 0.0,
|
| 12 |
+
"dist_to_upper_line": 0.0,
|
| 13 |
+
"trend_channel_width": 0.0,
|
| 14 |
+
"trend_convergence": 0.0,
|
| 15 |
+
"trend_breakout_upper": 0.0,
|
| 16 |
+
"trend_breakdown_lower": 0.0,
|
| 17 |
+
"trend_reentry": 0.0,
|
| 18 |
+
"trendline_available": 0.0,
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _extract_best_line(line_data: object) -> Optional[Tuple[float, float]]:
|
| 23 |
+
if not isinstance(line_data, Iterable):
|
| 24 |
+
return None
|
| 25 |
+
line_list = list(line_data)
|
| 26 |
+
if not line_list:
|
| 27 |
+
return None
|
| 28 |
+
best = line_list[0]
|
| 29 |
+
if not isinstance(best, tuple) or len(best) < 2 or not isinstance(best[1], tuple) or len(best[1]) < 2:
|
| 30 |
+
return None
|
| 31 |
+
slope = float(best[1][0])
|
| 32 |
+
intercept = float(best[1][1])
|
| 33 |
+
if not np.isfinite(slope) or not np.isfinite(intercept):
|
| 34 |
+
return None
|
| 35 |
+
return slope, intercept
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _extract_overall_line(summary_data: object) -> Optional[Tuple[float, float]]:
|
| 39 |
+
if isinstance(summary_data, (tuple, list)) and len(summary_data) >= 2:
|
| 40 |
+
slope = float(summary_data[0])
|
| 41 |
+
intercept = float(summary_data[1])
|
| 42 |
+
if np.isfinite(slope) and np.isfinite(intercept):
|
| 43 |
+
return slope, intercept
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _fit_with_trendln(values: np.ndarray) -> Tuple[Optional[Tuple[float, float]], Optional[Tuple[float, float]]]:
    """Fit support and resistance lines over *values* with trendln.

    Returns ``(support, resistance)`` as ``(slope, intercept)`` tuples;
    either side may be None when no usable line is produced.
    """
    fit_window = max(10, min(125, int(values.size)))
    support_part, resistance_part = trendln.calc_support_resistance(
        values,
        extmethod=trendln.METHOD_NAIVE,
        method=trendln.METHOD_NSQUREDLOGN,  # (sic) trendln's constant name
        window=fit_window,
        errpct=0.01,
    )
    # Prefer the ranked best line; fall back to the aggregate fit.
    support = _extract_best_line(support_part[2]) or _extract_overall_line(support_part[1])
    resistance = _extract_best_line(resistance_part[2]) or _extract_overall_line(resistance_part[1])
    return support, resistance
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def compute_trendline_features(closes, highs, lows, end_idx: int) -> Dict[str, float]:
    """Trendline channel features for bars up to and including ``end_idx``.

    Fits support/resistance lines on closes (falling back to lows/highs
    when a side is missing) and reports slopes, distances to each line,
    channel width/convergence, plus breakout/breakdown/re-entry flags for
    the latest bar. All price-based outputs are normalised by the latest
    close. Returns the zeroed template when fewer than 5 bars exist or no
    line can be fitted.
    """
    close_arr = np.asarray(closes[: end_idx + 1], dtype=np.float64)
    high_arr = np.asarray(highs[: end_idx + 1], dtype=np.float64)
    low_arr = np.asarray(lows[: end_idx + 1], dtype=np.float64)
    if close_arr.size < 5:
        return _empty_trendline_features()

    current_price = float(close_arr[-1])
    last_x = float(close_arr.size - 1)
    prev_x = max(0.0, last_x - 1.0)
    prev_close = float(close_arr[-2]) if close_arr.size > 1 else current_price
    price_scale = max(abs(current_price), 1e-8)

    def _safe_fit(series: np.ndarray):
        # trendln can raise on degenerate input; treat that as "no lines".
        try:
            return _fit_with_trendln(series)
        except Exception:
            return None, None

    support_line, resistance_line = _safe_fit(close_arr)
    if support_line is None:
        support_line = _safe_fit(low_arr)[0]
    if resistance_line is None:
        resistance_line = _safe_fit(high_arr)[1]
    if support_line is None and resistance_line is None:
        return _empty_trendline_features()

    def _line_at(line, x, fallback):
        # Evaluate slope*x + intercept; use fallback when the side is missing.
        return fallback if line is None else line[0] * x + line[1]

    lower_now = _line_at(support_line, last_x, current_price)
    upper_now = _line_at(resistance_line, last_x, current_price)
    lower_before = _line_at(support_line, prev_x, lower_now)
    upper_before = _line_at(resistance_line, prev_x, upper_now)
    channel_width = max(abs(upper_now - lower_now), 1e-8)

    crossed_up = current_price > upper_now and prev_close <= upper_before
    crossed_down = current_price < lower_now and prev_close >= lower_before
    came_back = (
        (prev_close > upper_before and current_price <= upper_now)
        or (prev_close < lower_before and current_price >= lower_now)
    )

    return {
        "lower_trendline_slope": 0.0 if support_line is None else support_line[0] / price_scale,
        "upper_trendline_slope": 0.0 if resistance_line is None else resistance_line[0] / price_scale,
        "dist_to_lower_line": (current_price - lower_now) / price_scale,
        "dist_to_upper_line": (upper_now - current_price) / price_scale,
        "trend_channel_width": channel_width / price_scale,
        "trend_convergence": 0.0 if support_line is None or resistance_line is None else (resistance_line[0] - support_line[0]) / price_scale,
        "trend_breakout_upper": 1.0 if crossed_up else 0.0,
        "trend_breakdown_lower": 1.0 if crossed_down else 0.0,
        "trend_reentry": 1.0 if came_back else 0.0,
        "trendline_available": 1.0,
    }
|
train.py
CHANGED
|
@@ -59,8 +59,10 @@ from models.token_encoder import TokenEncoder
|
|
| 59 |
from models.wallet_encoder import WalletEncoder
|
| 60 |
from models.graph_updater import GraphUpdater
|
| 61 |
from models.ohlc_embedder import OHLCEmbedder
|
|
|
|
| 62 |
from models.model import Oracle
|
| 63 |
import models.vocabulary as vocab
|
|
|
|
| 64 |
|
| 65 |
# Setup Logger
|
| 66 |
logger = get_logger(__name__)
|
|
@@ -671,6 +673,11 @@ def main() -> None:
|
|
| 671 |
num_intervals=vocab.NUM_OHLC_INTERVALS,
|
| 672 |
dtype=init_dtype
|
| 673 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
|
| 675 |
collator = MemecoinCollator(
|
| 676 |
event_type_to_id=vocab.EVENT_TO_ID,
|
|
@@ -809,6 +816,7 @@ def main() -> None:
|
|
| 809 |
wallet_encoder=wallet_encoder,
|
| 810 |
graph_updater=graph_updater,
|
| 811 |
ohlc_embedder=ohlc_embedder,
|
|
|
|
| 812 |
time_encoder=time_encoder,
|
| 813 |
num_event_types=vocab.NUM_EVENT_TYPES,
|
| 814 |
multi_modal_dim=multi_modal_encoder.embedding_dim,
|
|
|
|
| 59 |
from models.wallet_encoder import WalletEncoder
|
| 60 |
from models.graph_updater import GraphUpdater
|
| 61 |
from models.ohlc_embedder import OHLCEmbedder
|
| 62 |
+
from models.quant_ohlc_embedder import QuantOHLCEmbedder
|
| 63 |
from models.model import Oracle
|
| 64 |
import models.vocabulary as vocab
|
| 65 |
+
from data.quant_ohlc_feature_schema import NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
|
| 66 |
|
| 67 |
# Setup Logger
|
| 68 |
logger = get_logger(__name__)
|
|
|
|
| 673 |
num_intervals=vocab.NUM_OHLC_INTERVALS,
|
| 674 |
dtype=init_dtype
|
| 675 |
)
|
| 676 |
+
quant_ohlc_embedder = QuantOHLCEmbedder(
|
| 677 |
+
num_features=NUM_QUANT_OHLC_FEATURES,
|
| 678 |
+
sequence_length=TOKENS_PER_SEGMENT,
|
| 679 |
+
dtype=init_dtype,
|
| 680 |
+
)
|
| 681 |
|
| 682 |
collator = MemecoinCollator(
|
| 683 |
event_type_to_id=vocab.EVENT_TO_ID,
|
|
|
|
| 816 |
wallet_encoder=wallet_encoder,
|
| 817 |
graph_updater=graph_updater,
|
| 818 |
ohlc_embedder=ohlc_embedder,
|
| 819 |
+
quant_ohlc_embedder=quant_ohlc_embedder,
|
| 820 |
time_encoder=time_encoder,
|
| 821 |
num_event_types=vocab.NUM_EVENT_TYPES,
|
| 822 |
multi_modal_dim=multi_modal_encoder.embedding_dim,
|