zirobtc committed on
Commit
3780496
·
1 Parent(s): bb2313b

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  log.log filter=lfs diff=lfs merge=lfs -text
37
  store/74c/74c70007-cccd-4669-bfd4-e25f8348ad8c/all_1_35_2/primary.cidx filter=lfs diff=lfs merge=lfs -text
38
  data/quality_scores.jsonl filter=lfs diff=lfs merge=lfs -text
 
 
36
  log.log filter=lfs diff=lfs merge=lfs -text
37
  store/74c/74c70007-cccd-4669-bfd4-e25f8348ad8c/all_1_35_2/primary.cidx filter=lfs diff=lfs merge=lfs -text
38
  data/quality_scores.jsonl filter=lfs diff=lfs merge=lfs -text
39
+ data/backup_20260225_073238.log filter=lfs diff=lfs merge=lfs -text
audit_cache.py CHANGED
@@ -7,7 +7,7 @@ from collections import defaultdict
7
  import glob
8
  from tqdm import tqdm
9
 
10
- def audit_cache(cache_dir, num_samples=1000):
11
  files = glob.glob(os.path.join(cache_dir, "sample_*.pt"))
12
  if not files:
13
  print(f"No .pt files found in {cache_dir}")
 
7
  import glob
8
  from tqdm import tqdm
9
 
10
+ def audit_cache(cache_dir, num_samples=10000):
11
  files = glob.glob(os.path.join(cache_dir, "sample_*.pt"))
12
  if not files:
13
  print(f"No .pt files found in {cache_dir}")
cache_dataset.py CHANGED
@@ -2,7 +2,7 @@
2
  import os
3
  import sys
4
  import argparse
5
- import numpy as np
6
  import datetime
7
  import torch
8
  import json
@@ -45,7 +45,6 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
45
  data_fetcher=data_fetcher,
46
  max_samples=dataset_config['max_samples'],
47
  start_date=dataset_config['start_date'],
48
- ohlc_stats_path=dataset_config['ohlc_stats_path'],
49
  horizons_seconds=dataset_config['horizons_seconds'],
50
  quantiles=dataset_config['quantiles'],
51
  min_trade_usd=dataset_config['min_trade_usd'],
@@ -110,21 +109,6 @@ def _process_single_token_raw(args):
110
  return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
111
 
112
 
113
- def compute_save_ohlc_stats(client, output_path):
114
- print(f"INFO: Computing OHLC stats...")
115
- query = """SELECT AVG(t.price_usd), stddevPop(t.price_usd), AVG(t.price), stddevPop(t.price), AVG(t.total_usd), stddevPop(t.total_usd) FROM trades AS t WHERE t.price_usd > 0 AND t.total_usd > 0"""
116
- try:
117
- result = client.execute(query)
118
- if result and result[0]:
119
- row = result[0]
120
- stats = {"mean_price_usd": float(row[0] or 0), "std_price_usd": float(row[1] or 1), "mean_price_native": float(row[2] or 0), "std_price_native": float(row[3] or 1), "mean_trade_value_usd": float(row[4] or 0), "std_trade_value_usd": float(row[5] or 1)}
121
- else:
122
- stats = {"mean_price_usd": 0.0, "std_price_usd": 1.0, "mean_price_native": 0.0, "std_price_native": 1.0, "mean_trade_value_usd": 0.0, "std_trade_value_usd": 1.0}
123
- Path(output_path).parent.mkdir(parents=True, exist_ok=True)
124
- np.savez(output_path, **stats)
125
- print(f"INFO: Saved OHLC stats to {output_path}")
126
- except Exception as e:
127
- print(f"ERROR: Failed to compute OHLC stats: {e}")
128
 
129
 
130
  def main():
@@ -140,7 +124,7 @@ def main():
140
  parser.add_argument("--output_dir", type=str, default="data/cache")
141
  parser.add_argument("--max_samples", type=int, default=None)
142
  parser.add_argument("--start_date", type=str, default=None)
143
- parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
144
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
145
  parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"])
146
  parser.add_argument("--context_length", type=int, default=8192)
@@ -166,7 +150,6 @@ def main():
166
  neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
167
 
168
  try:
169
- compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)
170
 
171
  from data.data_loader import OracleDataset
172
  from data.data_fetcher import DataFetcher
@@ -180,7 +163,7 @@ def main():
180
  quality_scores_map = get_token_quality_scores(clickhouse_client)
181
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
182
 
183
- dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5], min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)
184
 
185
  if len(dataset) == 0:
186
  print("WARNING: No samples. Exiting.")
@@ -198,7 +181,7 @@ def main():
198
  print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")
199
 
200
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
201
- dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': [60, 180, 300, 600, 1800, 3600, 7200], 'quantiles': [0.5], 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
202
 
203
  # Build tasks from filtered_mints directly
204
  tasks = []
 
2
  import os
3
  import sys
4
  import argparse
5
+
6
  import datetime
7
  import torch
8
  import json
 
45
  data_fetcher=data_fetcher,
46
  max_samples=dataset_config['max_samples'],
47
  start_date=dataset_config['start_date'],
 
48
  horizons_seconds=dataset_config['horizons_seconds'],
49
  quantiles=dataset_config['quantiles'],
50
  min_trade_usd=dataset_config['min_trade_usd'],
 
109
  return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
  def main():
 
124
  parser.add_argument("--output_dir", type=str, default="data/cache")
125
  parser.add_argument("--max_samples", type=int, default=None)
126
  parser.add_argument("--start_date", type=str, default=None)
127
+
128
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
129
  parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"])
130
  parser.add_argument("--context_length", type=int, default=8192)
 
150
  neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
151
 
152
  try:
 
153
 
154
  from data.data_loader import OracleDataset
155
  from data.data_fetcher import DataFetcher
 
163
  quality_scores_map = get_token_quality_scores(clickhouse_client)
164
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
165
 
166
+ dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5], min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)
167
 
168
  if len(dataset) == 0:
169
  print("WARNING: No samples. Exiting.")
 
181
  print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")
182
 
183
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
184
+ dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'horizons_seconds': [60, 180, 300, 600, 1800, 3600, 7200], 'quantiles': [0.5], 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
185
 
186
  # Build tasks from filtered_mints directly
187
  tasks = []
data/all_files.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/backup_20260225_073238.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0454111eda3beeb85fd701110438e8dbced5f4440037a6964bcd0d5c527607d
3
+ size 13931583
data/batch_list_aa ADDED
The diff for this file is too large to render. See raw diff
 
data/batch_list_ab ADDED
The diff for this file is too large to render. See raw diff
 
data/batch_list_ac ADDED
The diff for this file is too large to render. See raw diff
 
data/batch_list_ad ADDED
The diff for this file is too large to render. See raw diff
 
data/batch_list_ae ADDED
The diff for this file is too large to render. See raw diff
 
data/data_loader.py CHANGED
@@ -122,7 +122,7 @@ class OracleDataset(Dataset):
122
  horizons_seconds: List[int] = [],
123
  quantiles: List[float] = [],
124
  max_samples: Optional[int] = None,
125
- ohlc_stats_path: Union[str, Path] = "./data/ohlc_stats.npz",
126
  token_allowlist: Optional[List[str]] = None,
127
  t_cutoff_seconds: int = 60,
128
  cache_dir: Optional[Union[str, Path]] = None,
@@ -136,7 +136,6 @@ class OracleDataset(Dataset):
136
  # --- P99 data-driven clamp values (replace hardcoded min/max) ---
137
  self.p99_clamps = {
138
  'slippage': 1.0,
139
- 'priority_fee': 0.1,
140
  'total_usd': 100000.0,
141
  'history_bought_cost_sol': 30.0,
142
  'realized_profit_sol': 150.0,
@@ -316,20 +315,7 @@ class OracleDataset(Dataset):
316
  else:
317
  self.max_cache_horizon_seconds = 3600
318
 
319
- # --- NEW: Load global OHLC normalization stats ---
320
- self.ohlc_price_mean = 0.0
321
- self.ohlc_price_std = 1.0
322
-
323
- if ohlc_stats_path:
324
- stats_path = Path(ohlc_stats_path)
325
- if stats_path.exists():
326
- stats = np.load(stats_path)
327
- self.ohlc_price_mean = float(stats.get('mean_price_usd', 0.0))
328
- self.ohlc_price_std = float(stats.get('std_price_usd', 1.0)) or 1.0
329
- else:
330
- print(f"WARNING: OHLC stats file not found at {stats_path}. Using default normalization (mean=0, std=1).")
331
- else:
332
- print("INFO: No OHLC stats path provided. Using default normalization.")
333
 
334
  self.min_trade_usd = min_trade_usd
335
  self._uri_fail_counts: Dict[str, int] = {}
 
122
  horizons_seconds: List[int] = [],
123
  quantiles: List[float] = [],
124
  max_samples: Optional[int] = None,
125
+
126
  token_allowlist: Optional[List[str]] = None,
127
  t_cutoff_seconds: int = 60,
128
  cache_dir: Optional[Union[str, Path]] = None,
 
136
  # --- P99 data-driven clamp values (replace hardcoded min/max) ---
137
  self.p99_clamps = {
138
  'slippage': 1.0,
 
139
  'total_usd': 100000.0,
140
  'history_bought_cost_sol': 30.0,
141
  'realized_profit_sol': 150.0,
 
315
  else:
316
  self.max_cache_horizon_seconds = 3600
317
 
318
+
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  self.min_trade_usd = min_trade_usd
321
  self._uri_fail_counts: Dict[str, int] = {}
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1d757990a0158118444be61f3d944dfb125237928809b4568ac209ab260f032e
3
  size 1660
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:667badf0d42d97e84ec60d58a1a4594f3141199325ab02b652adcf474d0a34f7
3
  size 1660
ingest.sh CHANGED
@@ -20,7 +20,7 @@ error() { echo -e "${RED}[ERROR]${NC} $1"; exit 1; }
20
  #===============================================================================
21
  header "Step 5-6/7: Processing Epochs (Download → Ingest → Delete)"
22
 
23
- EPOCHS=(844)
24
 
25
 
26
  log "Processing epochs one at a time to minimize disk usage..."
 
20
  #===============================================================================
21
  header "Step 5-6/7: Processing Epochs (Download → Ingest → Delete)"
22
 
23
+ EPOCHS=(844 845 846 847 848 849 850 851)
24
 
25
 
26
  log "Processing epochs one at a time to minimize disk usage..."
pre_cache.sh CHANGED
@@ -5,7 +5,7 @@ CONTEXT_LENGTH=4096
5
  MIN_TRADES=10
6
  SAMPLES_PER_TOKEN=1
7
  NUM_WORKERS=1
8
- OHLC_STATS_PATH="/workspace/apollo/data/ohlc_stats.npz"
9
  OUTPUT_DIR="data/cache"
10
 
11
  # Label horizons in seconds, relative to each sampled T_cutoff.
@@ -24,7 +24,7 @@ echo "Num Workers: $NUM_WORKERS"
24
  echo "Horizons (sec): ${HORIZONS_SECONDS[*]}"
25
  echo "Quantiles: ${QUANTILES[*]}"
26
  echo "Output Directory: $OUTPUT_DIR"
27
- echo "OHLC Stats Path: $OHLC_STATS_PATH"
28
  echo "========================================"
29
 
30
  echo "Starting dataset caching..."
@@ -32,7 +32,6 @@ echo "Starting dataset caching..."
32
  mkdir -p "$OUTPUT_DIR"
33
 
34
  python3 scripts/cache_dataset.py \
35
- --ohlc_stats_path "$OHLC_STATS_PATH" \
36
  --output_dir "$OUTPUT_DIR" \
37
  --context_length "$CONTEXT_LENGTH" \
38
  --min_trades "$MIN_TRADES" \
@@ -40,7 +39,7 @@ python3 scripts/cache_dataset.py \
40
  --num_workers "$NUM_WORKERS" \
41
  --horizons_seconds "${HORIZONS_SECONDS[@]}" \
42
  --quantiles "${QUANTILES[@]}" \
43
- --max_samples 150000 \
44
  "$@"
45
 
46
  echo "Done!"
 
5
  MIN_TRADES=10
6
  SAMPLES_PER_TOKEN=1
7
  NUM_WORKERS=1
8
+
9
  OUTPUT_DIR="data/cache"
10
 
11
  # Label horizons in seconds, relative to each sampled T_cutoff.
 
24
  echo "Horizons (sec): ${HORIZONS_SECONDS[*]}"
25
  echo "Quantiles: ${QUANTILES[*]}"
26
  echo "Output Directory: $OUTPUT_DIR"
27
+
28
  echo "========================================"
29
 
30
  echo "Starting dataset caching..."
 
32
  mkdir -p "$OUTPUT_DIR"
33
 
34
  python3 scripts/cache_dataset.py \
 
35
  --output_dir "$OUTPUT_DIR" \
36
  --context_length "$CONTEXT_LENGTH" \
37
  --min_trades "$MIN_TRADES" \
 
39
  --num_workers "$NUM_WORKERS" \
40
  --horizons_seconds "${HORIZONS_SECONDS[@]}" \
41
  --quantiles "${QUANTILES[@]}" \
42
+ --max_samples 300000 \
43
  "$@"
44
 
45
  echo "Done!"
sample_121MxrQDsaY35gC4_0.json ADDED
The diff for this file is too large to render. See raw diff
 
sample_14CfRkQ9CFP4o9nV_3.json ADDED
The diff for this file is too large to render. See raw diff
 
sample_2tYvBaQqXYy7Y5Qk_3.json ADDED
The diff for this file is too large to render. See raw diff
 
scripts/analyze_distribution.py CHANGED
@@ -37,7 +37,6 @@ def compute_p99_clamps(client):
37
  trade_query = """
38
  SELECT
39
  quantile(0.99)(abs(slippage)) AS p99_slippage,
40
- quantile(0.99)(priority_fee) AS p99_priority_fee,
41
  quantile(0.99)(total_usd) AS p99_total_usd
42
  FROM trades
43
  WHERE success = 1
@@ -56,7 +55,7 @@ def compute_p99_clamps(client):
56
  clamps = {
57
  # Defaults as fallback if queries return nothing
58
  'slippage': 1.0,
59
- 'priority_fee': 0.1,
60
  'total_usd': 100000.0,
61
  'history_bought_cost_sol': 30.0,
62
  'realized_profit_sol': 150.0,
@@ -65,8 +64,7 @@ def compute_p99_clamps(client):
65
  if trade_row and trade_row[0]:
66
  r = trade_row[0]
67
  clamps['slippage'] = max(float(r[0]), 0.01)
68
- clamps['priority_fee'] = max(float(r[1]), 1e-9)
69
- clamps['total_usd'] = max(float(r[2]), 1.0)
70
 
71
  if holdings_row and holdings_row[0]:
72
  r = holdings_row[0]
 
37
  trade_query = """
38
  SELECT
39
  quantile(0.99)(abs(slippage)) AS p99_slippage,
 
40
  quantile(0.99)(total_usd) AS p99_total_usd
41
  FROM trades
42
  WHERE success = 1
 
55
  clamps = {
56
  # Defaults as fallback if queries return nothing
57
  'slippage': 1.0,
58
+
59
  'total_usd': 100000.0,
60
  'history_bought_cost_sol': 30.0,
61
  'realized_profit_sol': 150.0,
 
64
  if trade_row and trade_row[0]:
65
  r = trade_row[0]
66
  clamps['slippage'] = max(float(r[0]), 0.01)
67
+ clamps['total_usd'] = max(float(r[1]), 1.0)
 
68
 
69
  if holdings_row and holdings_row[0]:
70
  r = holdings_row[0]
scripts/cache_dataset.py CHANGED
@@ -2,7 +2,7 @@
2
  import os
3
  import sys
4
  import argparse
5
- import numpy as np
6
  import datetime
7
  import torch
8
  import json
@@ -61,7 +61,6 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
61
  data_fetcher=data_fetcher,
62
  max_samples=dataset_config['max_samples'],
63
  start_date=dataset_config['start_date'],
64
- ohlc_stats_path=dataset_config['ohlc_stats_path'],
65
  horizons_seconds=dataset_config['horizons_seconds'],
66
  quantiles=dataset_config['quantiles'],
67
  min_trade_usd=dataset_config['min_trade_usd'],
@@ -112,21 +111,6 @@ def _process_single_token_context(args):
112
 
113
 
114
 
115
- def compute_save_ohlc_stats(client, output_path):
116
- print(f"INFO: Computing OHLC stats...")
117
- query = """SELECT AVG(t.price_usd), stddevPop(t.price_usd), AVG(t.price), stddevPop(t.price), AVG(t.total_usd), stddevPop(t.total_usd) FROM trades AS t WHERE t.price_usd > 0 AND t.total_usd > 0"""
118
- try:
119
- result = client.execute(query)
120
- if result and result[0]:
121
- row = result[0]
122
- stats = {"mean_price_usd": float(row[0] or 0), "std_price_usd": float(row[1] or 1), "mean_price_native": float(row[2] or 0), "std_price_native": float(row[3] or 1), "mean_trade_value_usd": float(row[4] or 0), "std_trade_value_usd": float(row[5] or 1)}
123
- else:
124
- stats = {"mean_price_usd": 0.0, "std_price_usd": 1.0, "mean_price_native": 0.0, "std_price_native": 1.0, "mean_trade_value_usd": 0.0, "std_trade_value_usd": 1.0}
125
- Path(output_path).parent.mkdir(parents=True, exist_ok=True)
126
- np.savez(output_path, **stats)
127
- print(f"INFO: Saved OHLC stats to {output_path}")
128
- except Exception as e:
129
- print(f"ERROR: Failed to compute OHLC stats: {e}")
130
 
131
 
132
  def main():
@@ -142,7 +126,7 @@ def main():
142
  parser.add_argument("--output_dir", type=str, default="data/cache")
143
  parser.add_argument("--max_samples", type=int, default=None)
144
  parser.add_argument("--start_date", type=str, default=None)
145
- parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
146
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
147
 
148
  parser.add_argument("--context_length", type=int, default=8192)
@@ -170,7 +154,6 @@ def main():
170
  neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
171
 
172
  try:
173
- compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)
174
 
175
  from data.data_loader import OracleDataset
176
  from data.data_fetcher import DataFetcher
@@ -187,7 +170,7 @@ def main():
187
  quality_scores_map = get_token_quality_scores(clickhouse_client)
188
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
189
 
190
- dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps)
191
 
192
  if len(dataset) == 0:
193
  print("WARNING: No samples. Exiting.")
@@ -223,7 +206,7 @@ def main():
223
  print(f"INFO: Workers: {args.num_workers}")
224
 
225
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
226
- dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps}
227
 
228
  # Build tasks with class-aware multi-sampling for balanced cache
229
  import random
 
2
  import os
3
  import sys
4
  import argparse
5
+
6
  import datetime
7
  import torch
8
  import json
 
61
  data_fetcher=data_fetcher,
62
  max_samples=dataset_config['max_samples'],
63
  start_date=dataset_config['start_date'],
 
64
  horizons_seconds=dataset_config['horizons_seconds'],
65
  quantiles=dataset_config['quantiles'],
66
  min_trade_usd=dataset_config['min_trade_usd'],
 
111
 
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  def main():
 
126
  parser.add_argument("--output_dir", type=str, default="data/cache")
127
  parser.add_argument("--max_samples", type=int, default=None)
128
  parser.add_argument("--start_date", type=str, default=None)
129
+
130
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
131
 
132
  parser.add_argument("--context_length", type=int, default=8192)
 
154
  neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
155
 
156
  try:
 
157
 
158
  from data.data_loader import OracleDataset
159
  from data.data_fetcher import DataFetcher
 
170
  quality_scores_map = get_token_quality_scores(clickhouse_client)
171
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
172
 
173
+ dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps)
174
 
175
  if len(dataset) == 0:
176
  print("WARNING: No samples. Exiting.")
 
206
  print(f"INFO: Workers: {args.num_workers}")
207
 
208
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
209
+ dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps}
210
 
211
  # Build tasks with class-aware multi-sampling for balanced cache
212
  import random
scripts/cache_parallel.py CHANGED
@@ -29,7 +29,6 @@ def cache_chunk(args):
29
 
30
  ds = OracleDataset(
31
  data_fetcher=fetcher,
32
- ohlc_stats_path=db_args['ohlc_stats_path'],
33
  horizons_seconds=[30, 60, 120, 240, 420],
34
  quantiles=[0.1, 0.5, 0.9],
35
  )
@@ -91,7 +90,7 @@ def main():
91
 
92
  fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
93
  return_map, _ = get_return_class_map(ch)
94
- ds = OracleDataset(data_fetcher=fetcher, ohlc_stats_path="data/ohlc_stats.npz",
95
  horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5])
96
  ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]
97
  total = len(ds)
@@ -108,7 +107,7 @@ def main():
108
  'neo4j_uri': os.getenv("NEO4J_URI", "bolt://localhost:7687"),
109
  'neo4j_user': os.getenv("NEO4J_USER", "neo4j"),
110
  'neo4j_password': os.getenv("NEO4J_PASSWORD", "neo4j123"),
111
- 'ohlc_stats_path': "data/ohlc_stats.npz",
112
  }
113
 
114
  tasks = [(i, i*chunk_size, (i+1)*chunk_size, args.output_dir, db_args) for i in range(args.workers)]
 
29
 
30
  ds = OracleDataset(
31
  data_fetcher=fetcher,
 
32
  horizons_seconds=[30, 60, 120, 240, 420],
33
  quantiles=[0.1, 0.5, 0.9],
34
  )
 
90
 
91
  fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
92
  return_map, _ = get_return_class_map(ch)
93
+ ds = OracleDataset(data_fetcher=fetcher,
94
  horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5])
95
  ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]
96
  total = len(ds)
 
107
  'neo4j_uri': os.getenv("NEO4J_URI", "bolt://localhost:7687"),
108
  'neo4j_user': os.getenv("NEO4J_USER", "neo4j"),
109
  'neo4j_password': os.getenv("NEO4J_PASSWORD", "neo4j123"),
110
+
111
  }
112
 
113
  tasks = [(i, i*chunk_size, (i+1)*chunk_size, args.output_dir, db_args) for i in range(args.workers)]
train.py CHANGED
@@ -328,7 +328,7 @@ def parse_args() -> argparse.Namespace:
328
  parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
329
  parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
330
  parser.add_argument("--max_samples", type=int, default=None)
331
- parser.add_argument("--ohlc_stats_path", type=str, default="./data/ohlc_stats.npz")
332
  parser.add_argument("--t_cutoff_seconds", type=int, default=60)
333
  parser.add_argument("--shuffle", dest="shuffle", action="store_true", default=True)
334
  parser.add_argument("--no-shuffle", dest="shuffle", action="store_false")
@@ -473,7 +473,6 @@ def main() -> None:
473
  horizons_seconds=horizons,
474
  quantiles=quantiles,
475
  max_samples=args.max_samples,
476
- ohlc_stats_path=args.ohlc_stats_path,
477
  t_cutoff_seconds=int(args.t_cutoff_seconds) if hasattr(args, 't_cutoff_seconds') else 60,
478
  cache_dir="/workspace/apollo/data/cache"
479
  )
 
328
  parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
329
  parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
330
  parser.add_argument("--max_samples", type=int, default=None)
331
+
332
  parser.add_argument("--t_cutoff_seconds", type=int, default=60)
333
  parser.add_argument("--shuffle", dest="shuffle", action="store_true", default=True)
334
  parser.add_argument("--no-shuffle", dest="shuffle", action="store_false")
 
473
  horizons_seconds=horizons,
474
  quantiles=quantiles,
475
  max_samples=args.max_samples,
 
476
  t_cutoff_seconds=int(args.t_cutoff_seconds) if hasattr(args, 't_cutoff_seconds') else 60,
477
  cache_dir="/workspace/apollo/data/cache"
478
  )
train.sh CHANGED
@@ -16,7 +16,6 @@ accelerate launch train.py \
16
  --max_seq_len 4096 \
17
  --horizons_seconds 300 900 1800 3600 7200 \
18
  --quantiles 0.1 0.5 0.9 \
19
- --ohlc_stats_path ./data/ohlc_stats.npz \
20
  --num_workers 0 \
21
  --val_samples_per_class 2 \
22
  --val_every 100 \
 
16
  --max_seq_len 4096 \
17
  --horizons_seconds 300 900 1800 3600 7200 \
18
  --quantiles 0.1 0.5 0.9 \
 
19
  --num_workers 0 \
20
  --val_samples_per_class 2 \
21
  --val_every 100 \
train.yaml CHANGED
@@ -14,7 +14,6 @@ data:
14
  quantiles: [0.1, 0.5, 0.9]
15
  max_seq_len: 4096
16
  ohlc_seq_len: 300
17
- ohlc_stats_path: ./data/ohlc_stats.npz
18
  t_cutoff_seconds: 60
19
  shuffle: true
20
  num_workers: 4
 
14
  quantiles: [0.1, 0.5, 0.9]
15
  max_seq_len: 4096
16
  ohlc_seq_len: 300
 
17
  t_cutoff_seconds: 60
18
  shuffle: true
19
  num_workers: 4
validate.py CHANGED
@@ -100,7 +100,6 @@ def main() -> None:
100
  ohlc_seq_len = data_cfg.get("ohlc_seq_len", 60)
101
  default_t_cutoff = int(data_cfg.get("t_cutoff_seconds", 60))
102
  t_cutoff_seconds = int(args.t_cutoff_seconds) if args.t_cutoff_seconds is not None else default_t_cutoff
103
- ohlc_stats_path = data_cfg.get("ohlc_stats_path", "./data/ohlc_stats.npz")
104
 
105
  multi_modal_encoder = MultiModalEncoder(dtype=dtype)
106
  time_encoder = ContextualTimeEncoder(dtype=dtype)
@@ -140,7 +139,6 @@ def main() -> None:
140
  horizons_seconds=horizons,
141
  quantiles=quantiles,
142
  max_samples=max_samples,
143
- ohlc_stats_path=ohlc_stats_path,
144
  token_allowlist=[args.token_address] if args.token_address else None,
145
  t_cutoff_seconds=t_cutoff_seconds
146
  )
 
100
  ohlc_seq_len = data_cfg.get("ohlc_seq_len", 60)
101
  default_t_cutoff = int(data_cfg.get("t_cutoff_seconds", 60))
102
  t_cutoff_seconds = int(args.t_cutoff_seconds) if args.t_cutoff_seconds is not None else default_t_cutoff
 
103
 
104
  multi_modal_encoder = MultiModalEncoder(dtype=dtype)
105
  time_encoder = ContextualTimeEncoder(dtype=dtype)
 
139
  horizons_seconds=horizons,
140
  quantiles=quantiles,
141
  max_samples=max_samples,
 
142
  token_allowlist=[args.token_address] if args.token_address else None,
143
  t_cutoff_seconds=t_cutoff_seconds
144
  )