Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- audit_cache.py +1 -1
- cache_dataset.py +4 -21
- data/all_files.txt +0 -0
- data/backup_20260225_073238.log +3 -0
- data/batch_list_aa +0 -0
- data/batch_list_ab +0 -0
- data/batch_list_ac +0 -0
- data/batch_list_ad +0 -0
- data/batch_list_ae +0 -0
- data/data_loader.py +2 -16
- data/ohlc_stats.npz +1 -1
- ingest.sh +1 -1
- pre_cache.sh +3 -4
- sample_121MxrQDsaY35gC4_0.json +0 -0
- sample_14CfRkQ9CFP4o9nV_3.json +0 -0
- sample_2tYvBaQqXYy7Y5Qk_3.json +0 -0
- scripts/analyze_distribution.py +2 -4
- scripts/cache_dataset.py +4 -21
- scripts/cache_parallel.py +2 -3
- train.py +1 -2
- train.sh +0 -1
- train.yaml +0 -1
- validate.py +0 -2
.gitattributes
CHANGED
|
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 36 |
log.log filter=lfs diff=lfs merge=lfs -text
|
| 37 |
store/74c/74c70007-cccd-4669-bfd4-e25f8348ad8c/all_1_35_2/primary.cidx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
data/quality_scores.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 36 |
log.log filter=lfs diff=lfs merge=lfs -text
|
| 37 |
store/74c/74c70007-cccd-4669-bfd4-e25f8348ad8c/all_1_35_2/primary.cidx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
data/quality_scores.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
data/backup_20260225_073238.log filter=lfs diff=lfs merge=lfs -text
|
audit_cache.py
CHANGED
|
@@ -7,7 +7,7 @@ from collections import defaultdict
|
|
| 7 |
import glob
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
-
def audit_cache(cache_dir, num_samples=
|
| 11 |
files = glob.glob(os.path.join(cache_dir, "sample_*.pt"))
|
| 12 |
if not files:
|
| 13 |
print(f"No .pt files found in {cache_dir}")
|
|
|
|
| 7 |
import glob
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
+
def audit_cache(cache_dir, num_samples=10000):
|
| 11 |
files = glob.glob(os.path.join(cache_dir, "sample_*.pt"))
|
| 12 |
if not files:
|
| 13 |
print(f"No .pt files found in {cache_dir}")
|
cache_dataset.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import argparse
|
| 5 |
-
|
| 6 |
import datetime
|
| 7 |
import torch
|
| 8 |
import json
|
|
@@ -45,7 +45,6 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
|
|
| 45 |
data_fetcher=data_fetcher,
|
| 46 |
max_samples=dataset_config['max_samples'],
|
| 47 |
start_date=dataset_config['start_date'],
|
| 48 |
-
ohlc_stats_path=dataset_config['ohlc_stats_path'],
|
| 49 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 50 |
quantiles=dataset_config['quantiles'],
|
| 51 |
min_trade_usd=dataset_config['min_trade_usd'],
|
|
@@ -110,21 +109,6 @@ def _process_single_token_raw(args):
|
|
| 110 |
return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
|
| 111 |
|
| 112 |
|
| 113 |
-
def compute_save_ohlc_stats(client, output_path):
|
| 114 |
-
print(f"INFO: Computing OHLC stats...")
|
| 115 |
-
query = """SELECT AVG(t.price_usd), stddevPop(t.price_usd), AVG(t.price), stddevPop(t.price), AVG(t.total_usd), stddevPop(t.total_usd) FROM trades AS t WHERE t.price_usd > 0 AND t.total_usd > 0"""
|
| 116 |
-
try:
|
| 117 |
-
result = client.execute(query)
|
| 118 |
-
if result and result[0]:
|
| 119 |
-
row = result[0]
|
| 120 |
-
stats = {"mean_price_usd": float(row[0] or 0), "std_price_usd": float(row[1] or 1), "mean_price_native": float(row[2] or 0), "std_price_native": float(row[3] or 1), "mean_trade_value_usd": float(row[4] or 0), "std_trade_value_usd": float(row[5] or 1)}
|
| 121 |
-
else:
|
| 122 |
-
stats = {"mean_price_usd": 0.0, "std_price_usd": 1.0, "mean_price_native": 0.0, "std_price_native": 1.0, "mean_trade_value_usd": 0.0, "std_trade_value_usd": 1.0}
|
| 123 |
-
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 124 |
-
np.savez(output_path, **stats)
|
| 125 |
-
print(f"INFO: Saved OHLC stats to {output_path}")
|
| 126 |
-
except Exception as e:
|
| 127 |
-
print(f"ERROR: Failed to compute OHLC stats: {e}")
|
| 128 |
|
| 129 |
|
| 130 |
def main():
|
|
@@ -140,7 +124,7 @@ def main():
|
|
| 140 |
parser.add_argument("--output_dir", type=str, default="data/cache")
|
| 141 |
parser.add_argument("--max_samples", type=int, default=None)
|
| 142 |
parser.add_argument("--start_date", type=str, default=None)
|
| 143 |
-
|
| 144 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
| 145 |
parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"])
|
| 146 |
parser.add_argument("--context_length", type=int, default=8192)
|
|
@@ -166,7 +150,6 @@ def main():
|
|
| 166 |
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
|
| 167 |
|
| 168 |
try:
|
| 169 |
-
compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)
|
| 170 |
|
| 171 |
from data.data_loader import OracleDataset
|
| 172 |
from data.data_fetcher import DataFetcher
|
|
@@ -180,7 +163,7 @@ def main():
|
|
| 180 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 181 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 182 |
|
| 183 |
-
dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt,
|
| 184 |
|
| 185 |
if len(dataset) == 0:
|
| 186 |
print("WARNING: No samples. Exiting.")
|
|
@@ -198,7 +181,7 @@ def main():
|
|
| 198 |
print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")
|
| 199 |
|
| 200 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 201 |
-
dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, '
|
| 202 |
|
| 203 |
# Build tasks from filtered_mints directly
|
| 204 |
tasks = []
|
|
|
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import argparse
|
| 5 |
+
|
| 6 |
import datetime
|
| 7 |
import torch
|
| 8 |
import json
|
|
|
|
| 45 |
data_fetcher=data_fetcher,
|
| 46 |
max_samples=dataset_config['max_samples'],
|
| 47 |
start_date=dataset_config['start_date'],
|
|
|
|
| 48 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 49 |
quantiles=dataset_config['quantiles'],
|
| 50 |
min_trade_usd=dataset_config['min_trade_usd'],
|
|
|
|
| 109 |
return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
def main():
|
|
|
|
| 124 |
parser.add_argument("--output_dir", type=str, default="data/cache")
|
| 125 |
parser.add_argument("--max_samples", type=int, default=None)
|
| 126 |
parser.add_argument("--start_date", type=str, default=None)
|
| 127 |
+
|
| 128 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
| 129 |
parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"])
|
| 130 |
parser.add_argument("--context_length", type=int, default=8192)
|
|
|
|
| 150 |
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
|
| 151 |
|
| 152 |
try:
|
|
|
|
| 153 |
|
| 154 |
from data.data_loader import OracleDataset
|
| 155 |
from data.data_fetcher import DataFetcher
|
|
|
|
| 163 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 164 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 165 |
|
| 166 |
+
dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5], min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)
|
| 167 |
|
| 168 |
if len(dataset) == 0:
|
| 169 |
print("WARNING: No samples. Exiting.")
|
|
|
|
| 181 |
print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")
|
| 182 |
|
| 183 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 184 |
+
dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'horizons_seconds': [60, 180, 300, 600, 1800, 3600, 7200], 'quantiles': [0.5], 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
|
| 185 |
|
| 186 |
# Build tasks from filtered_mints directly
|
| 187 |
tasks = []
|
data/all_files.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/backup_20260225_073238.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e0454111eda3beeb85fd701110438e8dbced5f4440037a6964bcd0d5c527607d
|
| 3 |
+
size 13931583
|
data/batch_list_aa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/batch_list_ab
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/batch_list_ac
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/batch_list_ad
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/batch_list_ae
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/data_loader.py
CHANGED
|
@@ -122,7 +122,7 @@ class OracleDataset(Dataset):
|
|
| 122 |
horizons_seconds: List[int] = [],
|
| 123 |
quantiles: List[float] = [],
|
| 124 |
max_samples: Optional[int] = None,
|
| 125 |
-
|
| 126 |
token_allowlist: Optional[List[str]] = None,
|
| 127 |
t_cutoff_seconds: int = 60,
|
| 128 |
cache_dir: Optional[Union[str, Path]] = None,
|
|
@@ -136,7 +136,6 @@ class OracleDataset(Dataset):
|
|
| 136 |
# --- P99 data-driven clamp values (replace hardcoded min/max) ---
|
| 137 |
self.p99_clamps = {
|
| 138 |
'slippage': 1.0,
|
| 139 |
-
'priority_fee': 0.1,
|
| 140 |
'total_usd': 100000.0,
|
| 141 |
'history_bought_cost_sol': 30.0,
|
| 142 |
'realized_profit_sol': 150.0,
|
|
@@ -316,20 +315,7 @@ class OracleDataset(Dataset):
|
|
| 316 |
else:
|
| 317 |
self.max_cache_horizon_seconds = 3600
|
| 318 |
|
| 319 |
-
|
| 320 |
-
self.ohlc_price_mean = 0.0
|
| 321 |
-
self.ohlc_price_std = 1.0
|
| 322 |
-
|
| 323 |
-
if ohlc_stats_path:
|
| 324 |
-
stats_path = Path(ohlc_stats_path)
|
| 325 |
-
if stats_path.exists():
|
| 326 |
-
stats = np.load(stats_path)
|
| 327 |
-
self.ohlc_price_mean = float(stats.get('mean_price_usd', 0.0))
|
| 328 |
-
self.ohlc_price_std = float(stats.get('std_price_usd', 1.0)) or 1.0
|
| 329 |
-
else:
|
| 330 |
-
print(f"WARNING: OHLC stats file not found at {stats_path}. Using default normalization (mean=0, std=1).")
|
| 331 |
-
else:
|
| 332 |
-
print("INFO: No OHLC stats path provided. Using default normalization.")
|
| 333 |
|
| 334 |
self.min_trade_usd = min_trade_usd
|
| 335 |
self._uri_fail_counts: Dict[str, int] = {}
|
|
|
|
| 122 |
horizons_seconds: List[int] = [],
|
| 123 |
quantiles: List[float] = [],
|
| 124 |
max_samples: Optional[int] = None,
|
| 125 |
+
|
| 126 |
token_allowlist: Optional[List[str]] = None,
|
| 127 |
t_cutoff_seconds: int = 60,
|
| 128 |
cache_dir: Optional[Union[str, Path]] = None,
|
|
|
|
| 136 |
# --- P99 data-driven clamp values (replace hardcoded min/max) ---
|
| 137 |
self.p99_clamps = {
|
| 138 |
'slippage': 1.0,
|
|
|
|
| 139 |
'total_usd': 100000.0,
|
| 140 |
'history_bought_cost_sol': 30.0,
|
| 141 |
'realized_profit_sol': 150.0,
|
|
|
|
| 315 |
else:
|
| 316 |
self.max_cache_horizon_seconds = 3600
|
| 317 |
|
| 318 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
self.min_trade_usd = min_trade_usd
|
| 321 |
self._uri_fail_counts: Dict[str, int] = {}
|
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1660
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:667badf0d42d97e84ec60d58a1a4594f3141199325ab02b652adcf474d0a34f7
|
| 3 |
size 1660
|
ingest.sh
CHANGED
|
@@ -20,7 +20,7 @@ error() { echo -e "${RED}[ERROR]${NC} $1"; exit 1; }
|
|
| 20 |
#===============================================================================
|
| 21 |
header "Step 5-6/7: Processing Epochs (Download → Ingest → Delete)"
|
| 22 |
|
| 23 |
-
EPOCHS=(844)
|
| 24 |
|
| 25 |
|
| 26 |
log "Processing epochs one at a time to minimize disk usage..."
|
|
|
|
| 20 |
#===============================================================================
|
| 21 |
header "Step 5-6/7: Processing Epochs (Download → Ingest → Delete)"
|
| 22 |
|
| 23 |
+
EPOCHS=(844 845 846 847 848 849 850 851)
|
| 24 |
|
| 25 |
|
| 26 |
log "Processing epochs one at a time to minimize disk usage..."
|
pre_cache.sh
CHANGED
|
@@ -5,7 +5,7 @@ CONTEXT_LENGTH=4096
|
|
| 5 |
MIN_TRADES=10
|
| 6 |
SAMPLES_PER_TOKEN=1
|
| 7 |
NUM_WORKERS=1
|
| 8 |
-
|
| 9 |
OUTPUT_DIR="data/cache"
|
| 10 |
|
| 11 |
# Label horizons in seconds, relative to each sampled T_cutoff.
|
|
@@ -24,7 +24,7 @@ echo "Num Workers: $NUM_WORKERS"
|
|
| 24 |
echo "Horizons (sec): ${HORIZONS_SECONDS[*]}"
|
| 25 |
echo "Quantiles: ${QUANTILES[*]}"
|
| 26 |
echo "Output Directory: $OUTPUT_DIR"
|
| 27 |
-
|
| 28 |
echo "========================================"
|
| 29 |
|
| 30 |
echo "Starting dataset caching..."
|
|
@@ -32,7 +32,6 @@ echo "Starting dataset caching..."
|
|
| 32 |
mkdir -p "$OUTPUT_DIR"
|
| 33 |
|
| 34 |
python3 scripts/cache_dataset.py \
|
| 35 |
-
--ohlc_stats_path "$OHLC_STATS_PATH" \
|
| 36 |
--output_dir "$OUTPUT_DIR" \
|
| 37 |
--context_length "$CONTEXT_LENGTH" \
|
| 38 |
--min_trades "$MIN_TRADES" \
|
|
@@ -40,7 +39,7 @@ python3 scripts/cache_dataset.py \
|
|
| 40 |
--num_workers "$NUM_WORKERS" \
|
| 41 |
--horizons_seconds "${HORIZONS_SECONDS[@]}" \
|
| 42 |
--quantiles "${QUANTILES[@]}" \
|
| 43 |
-
--max_samples
|
| 44 |
"$@"
|
| 45 |
|
| 46 |
echo "Done!"
|
|
|
|
| 5 |
MIN_TRADES=10
|
| 6 |
SAMPLES_PER_TOKEN=1
|
| 7 |
NUM_WORKERS=1
|
| 8 |
+
|
| 9 |
OUTPUT_DIR="data/cache"
|
| 10 |
|
| 11 |
# Label horizons in seconds, relative to each sampled T_cutoff.
|
|
|
|
| 24 |
echo "Horizons (sec): ${HORIZONS_SECONDS[*]}"
|
| 25 |
echo "Quantiles: ${QUANTILES[*]}"
|
| 26 |
echo "Output Directory: $OUTPUT_DIR"
|
| 27 |
+
|
| 28 |
echo "========================================"
|
| 29 |
|
| 30 |
echo "Starting dataset caching..."
|
|
|
|
| 32 |
mkdir -p "$OUTPUT_DIR"
|
| 33 |
|
| 34 |
python3 scripts/cache_dataset.py \
|
|
|
|
| 35 |
--output_dir "$OUTPUT_DIR" \
|
| 36 |
--context_length "$CONTEXT_LENGTH" \
|
| 37 |
--min_trades "$MIN_TRADES" \
|
|
|
|
| 39 |
--num_workers "$NUM_WORKERS" \
|
| 40 |
--horizons_seconds "${HORIZONS_SECONDS[@]}" \
|
| 41 |
--quantiles "${QUANTILES[@]}" \
|
| 42 |
+
--max_samples 300000 \
|
| 43 |
"$@"
|
| 44 |
|
| 45 |
echo "Done!"
|
sample_121MxrQDsaY35gC4_0.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_14CfRkQ9CFP4o9nV_3.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_2tYvBaQqXYy7Y5Qk_3.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/analyze_distribution.py
CHANGED
|
@@ -37,7 +37,6 @@ def compute_p99_clamps(client):
|
|
| 37 |
trade_query = """
|
| 38 |
SELECT
|
| 39 |
quantile(0.99)(abs(slippage)) AS p99_slippage,
|
| 40 |
-
quantile(0.99)(priority_fee) AS p99_priority_fee,
|
| 41 |
quantile(0.99)(total_usd) AS p99_total_usd
|
| 42 |
FROM trades
|
| 43 |
WHERE success = 1
|
|
@@ -56,7 +55,7 @@ def compute_p99_clamps(client):
|
|
| 56 |
clamps = {
|
| 57 |
# Defaults as fallback if queries return nothing
|
| 58 |
'slippage': 1.0,
|
| 59 |
-
|
| 60 |
'total_usd': 100000.0,
|
| 61 |
'history_bought_cost_sol': 30.0,
|
| 62 |
'realized_profit_sol': 150.0,
|
|
@@ -65,8 +64,7 @@ def compute_p99_clamps(client):
|
|
| 65 |
if trade_row and trade_row[0]:
|
| 66 |
r = trade_row[0]
|
| 67 |
clamps['slippage'] = max(float(r[0]), 0.01)
|
| 68 |
-
clamps['
|
| 69 |
-
clamps['total_usd'] = max(float(r[2]), 1.0)
|
| 70 |
|
| 71 |
if holdings_row and holdings_row[0]:
|
| 72 |
r = holdings_row[0]
|
|
|
|
| 37 |
trade_query = """
|
| 38 |
SELECT
|
| 39 |
quantile(0.99)(abs(slippage)) AS p99_slippage,
|
|
|
|
| 40 |
quantile(0.99)(total_usd) AS p99_total_usd
|
| 41 |
FROM trades
|
| 42 |
WHERE success = 1
|
|
|
|
| 55 |
clamps = {
|
| 56 |
# Defaults as fallback if queries return nothing
|
| 57 |
'slippage': 1.0,
|
| 58 |
+
|
| 59 |
'total_usd': 100000.0,
|
| 60 |
'history_bought_cost_sol': 30.0,
|
| 61 |
'realized_profit_sol': 150.0,
|
|
|
|
| 64 |
if trade_row and trade_row[0]:
|
| 65 |
r = trade_row[0]
|
| 66 |
clamps['slippage'] = max(float(r[0]), 0.01)
|
| 67 |
+
clamps['total_usd'] = max(float(r[1]), 1.0)
|
|
|
|
| 68 |
|
| 69 |
if holdings_row and holdings_row[0]:
|
| 70 |
r = holdings_row[0]
|
scripts/cache_dataset.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import argparse
|
| 5 |
-
|
| 6 |
import datetime
|
| 7 |
import torch
|
| 8 |
import json
|
|
@@ -61,7 +61,6 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
|
|
| 61 |
data_fetcher=data_fetcher,
|
| 62 |
max_samples=dataset_config['max_samples'],
|
| 63 |
start_date=dataset_config['start_date'],
|
| 64 |
-
ohlc_stats_path=dataset_config['ohlc_stats_path'],
|
| 65 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 66 |
quantiles=dataset_config['quantiles'],
|
| 67 |
min_trade_usd=dataset_config['min_trade_usd'],
|
|
@@ -112,21 +111,6 @@ def _process_single_token_context(args):
|
|
| 112 |
|
| 113 |
|
| 114 |
|
| 115 |
-
def compute_save_ohlc_stats(client, output_path):
|
| 116 |
-
print(f"INFO: Computing OHLC stats...")
|
| 117 |
-
query = """SELECT AVG(t.price_usd), stddevPop(t.price_usd), AVG(t.price), stddevPop(t.price), AVG(t.total_usd), stddevPop(t.total_usd) FROM trades AS t WHERE t.price_usd > 0 AND t.total_usd > 0"""
|
| 118 |
-
try:
|
| 119 |
-
result = client.execute(query)
|
| 120 |
-
if result and result[0]:
|
| 121 |
-
row = result[0]
|
| 122 |
-
stats = {"mean_price_usd": float(row[0] or 0), "std_price_usd": float(row[1] or 1), "mean_price_native": float(row[2] or 0), "std_price_native": float(row[3] or 1), "mean_trade_value_usd": float(row[4] or 0), "std_trade_value_usd": float(row[5] or 1)}
|
| 123 |
-
else:
|
| 124 |
-
stats = {"mean_price_usd": 0.0, "std_price_usd": 1.0, "mean_price_native": 0.0, "std_price_native": 1.0, "mean_trade_value_usd": 0.0, "std_trade_value_usd": 1.0}
|
| 125 |
-
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 126 |
-
np.savez(output_path, **stats)
|
| 127 |
-
print(f"INFO: Saved OHLC stats to {output_path}")
|
| 128 |
-
except Exception as e:
|
| 129 |
-
print(f"ERROR: Failed to compute OHLC stats: {e}")
|
| 130 |
|
| 131 |
|
| 132 |
def main():
|
|
@@ -142,7 +126,7 @@ def main():
|
|
| 142 |
parser.add_argument("--output_dir", type=str, default="data/cache")
|
| 143 |
parser.add_argument("--max_samples", type=int, default=None)
|
| 144 |
parser.add_argument("--start_date", type=str, default=None)
|
| 145 |
-
|
| 146 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
| 147 |
|
| 148 |
parser.add_argument("--context_length", type=int, default=8192)
|
|
@@ -170,7 +154,6 @@ def main():
|
|
| 170 |
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
|
| 171 |
|
| 172 |
try:
|
| 173 |
-
compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)
|
| 174 |
|
| 175 |
from data.data_loader import OracleDataset
|
| 176 |
from data.data_fetcher import DataFetcher
|
|
@@ -187,7 +170,7 @@ def main():
|
|
| 187 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 188 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 189 |
|
| 190 |
-
dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt,
|
| 191 |
|
| 192 |
if len(dataset) == 0:
|
| 193 |
print("WARNING: No samples. Exiting.")
|
|
@@ -223,7 +206,7 @@ def main():
|
|
| 223 |
print(f"INFO: Workers: {args.num_workers}")
|
| 224 |
|
| 225 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 226 |
-
dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, '
|
| 227 |
|
| 228 |
# Build tasks with class-aware multi-sampling for balanced cache
|
| 229 |
import random
|
|
|
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import argparse
|
| 5 |
+
|
| 6 |
import datetime
|
| 7 |
import torch
|
| 8 |
import json
|
|
|
|
| 61 |
data_fetcher=data_fetcher,
|
| 62 |
max_samples=dataset_config['max_samples'],
|
| 63 |
start_date=dataset_config['start_date'],
|
|
|
|
| 64 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 65 |
quantiles=dataset_config['quantiles'],
|
| 66 |
min_trade_usd=dataset_config['min_trade_usd'],
|
|
|
|
| 111 |
|
| 112 |
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
def main():
|
|
|
|
| 126 |
parser.add_argument("--output_dir", type=str, default="data/cache")
|
| 127 |
parser.add_argument("--max_samples", type=int, default=None)
|
| 128 |
parser.add_argument("--start_date", type=str, default=None)
|
| 129 |
+
|
| 130 |
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
| 131 |
|
| 132 |
parser.add_argument("--context_length", type=int, default=8192)
|
|
|
|
| 154 |
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
|
| 155 |
|
| 156 |
try:
|
|
|
|
| 157 |
|
| 158 |
from data.data_loader import OracleDataset
|
| 159 |
from data.data_fetcher import DataFetcher
|
|
|
|
| 170 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 171 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 172 |
|
| 173 |
+
dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps)
|
| 174 |
|
| 175 |
if len(dataset) == 0:
|
| 176 |
print("WARNING: No samples. Exiting.")
|
|
|
|
| 206 |
print(f"INFO: Workers: {args.num_workers}")
|
| 207 |
|
| 208 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 209 |
+
dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps}
|
| 210 |
|
| 211 |
# Build tasks with class-aware multi-sampling for balanced cache
|
| 212 |
import random
|
scripts/cache_parallel.py
CHANGED
|
@@ -29,7 +29,6 @@ def cache_chunk(args):
|
|
| 29 |
|
| 30 |
ds = OracleDataset(
|
| 31 |
data_fetcher=fetcher,
|
| 32 |
-
ohlc_stats_path=db_args['ohlc_stats_path'],
|
| 33 |
horizons_seconds=[30, 60, 120, 240, 420],
|
| 34 |
quantiles=[0.1, 0.5, 0.9],
|
| 35 |
)
|
|
@@ -91,7 +90,7 @@ def main():
|
|
| 91 |
|
| 92 |
fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
|
| 93 |
return_map, _ = get_return_class_map(ch)
|
| 94 |
-
ds = OracleDataset(data_fetcher=fetcher,
|
| 95 |
horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5])
|
| 96 |
ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]
|
| 97 |
total = len(ds)
|
|
@@ -108,7 +107,7 @@ def main():
|
|
| 108 |
'neo4j_uri': os.getenv("NEO4J_URI", "bolt://localhost:7687"),
|
| 109 |
'neo4j_user': os.getenv("NEO4J_USER", "neo4j"),
|
| 110 |
'neo4j_password': os.getenv("NEO4J_PASSWORD", "neo4j123"),
|
| 111 |
-
|
| 112 |
}
|
| 113 |
|
| 114 |
tasks = [(i, i*chunk_size, (i+1)*chunk_size, args.output_dir, db_args) for i in range(args.workers)]
|
|
|
|
| 29 |
|
| 30 |
ds = OracleDataset(
|
| 31 |
data_fetcher=fetcher,
|
|
|
|
| 32 |
horizons_seconds=[30, 60, 120, 240, 420],
|
| 33 |
quantiles=[0.1, 0.5, 0.9],
|
| 34 |
)
|
|
|
|
| 90 |
|
| 91 |
fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
|
| 92 |
return_map, _ = get_return_class_map(ch)
|
| 93 |
+
ds = OracleDataset(data_fetcher=fetcher,
|
| 94 |
horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5])
|
| 95 |
ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]
|
| 96 |
total = len(ds)
|
|
|
|
| 107 |
'neo4j_uri': os.getenv("NEO4J_URI", "bolt://localhost:7687"),
|
| 108 |
'neo4j_user': os.getenv("NEO4J_USER", "neo4j"),
|
| 109 |
'neo4j_password': os.getenv("NEO4J_PASSWORD", "neo4j123"),
|
| 110 |
+
|
| 111 |
}
|
| 112 |
|
| 113 |
tasks = [(i, i*chunk_size, (i+1)*chunk_size, args.output_dir, db_args) for i in range(args.workers)]
|
train.py
CHANGED
|
@@ -328,7 +328,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 328 |
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
|
| 329 |
parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
|
| 330 |
parser.add_argument("--max_samples", type=int, default=None)
|
| 331 |
-
|
| 332 |
parser.add_argument("--t_cutoff_seconds", type=int, default=60)
|
| 333 |
parser.add_argument("--shuffle", dest="shuffle", action="store_true", default=True)
|
| 334 |
parser.add_argument("--no-shuffle", dest="shuffle", action="store_false")
|
|
@@ -473,7 +473,6 @@ def main() -> None:
|
|
| 473 |
horizons_seconds=horizons,
|
| 474 |
quantiles=quantiles,
|
| 475 |
max_samples=args.max_samples,
|
| 476 |
-
ohlc_stats_path=args.ohlc_stats_path,
|
| 477 |
t_cutoff_seconds=int(args.t_cutoff_seconds) if hasattr(args, 't_cutoff_seconds') else 60,
|
| 478 |
cache_dir="/workspace/apollo/data/cache"
|
| 479 |
)
|
|
|
|
| 328 |
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
|
| 329 |
parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
|
| 330 |
parser.add_argument("--max_samples", type=int, default=None)
|
| 331 |
+
|
| 332 |
parser.add_argument("--t_cutoff_seconds", type=int, default=60)
|
| 333 |
parser.add_argument("--shuffle", dest="shuffle", action="store_true", default=True)
|
| 334 |
parser.add_argument("--no-shuffle", dest="shuffle", action="store_false")
|
|
|
|
| 473 |
horizons_seconds=horizons,
|
| 474 |
quantiles=quantiles,
|
| 475 |
max_samples=args.max_samples,
|
|
|
|
| 476 |
t_cutoff_seconds=int(args.t_cutoff_seconds) if hasattr(args, 't_cutoff_seconds') else 60,
|
| 477 |
cache_dir="/workspace/apollo/data/cache"
|
| 478 |
)
|
train.sh
CHANGED
|
@@ -16,7 +16,6 @@ accelerate launch train.py \
|
|
| 16 |
--max_seq_len 4096 \
|
| 17 |
--horizons_seconds 300 900 1800 3600 7200 \
|
| 18 |
--quantiles 0.1 0.5 0.9 \
|
| 19 |
-
--ohlc_stats_path ./data/ohlc_stats.npz \
|
| 20 |
--num_workers 0 \
|
| 21 |
--val_samples_per_class 2 \
|
| 22 |
--val_every 100 \
|
|
|
|
| 16 |
--max_seq_len 4096 \
|
| 17 |
--horizons_seconds 300 900 1800 3600 7200 \
|
| 18 |
--quantiles 0.1 0.5 0.9 \
|
|
|
|
| 19 |
--num_workers 0 \
|
| 20 |
--val_samples_per_class 2 \
|
| 21 |
--val_every 100 \
|
train.yaml
CHANGED
|
@@ -14,7 +14,6 @@ data:
|
|
| 14 |
quantiles: [0.1, 0.5, 0.9]
|
| 15 |
max_seq_len: 4096
|
| 16 |
ohlc_seq_len: 300
|
| 17 |
-
ohlc_stats_path: ./data/ohlc_stats.npz
|
| 18 |
t_cutoff_seconds: 60
|
| 19 |
shuffle: true
|
| 20 |
num_workers: 4
|
|
|
|
| 14 |
quantiles: [0.1, 0.5, 0.9]
|
| 15 |
max_seq_len: 4096
|
| 16 |
ohlc_seq_len: 300
|
|
|
|
| 17 |
t_cutoff_seconds: 60
|
| 18 |
shuffle: true
|
| 19 |
num_workers: 4
|
validate.py
CHANGED
|
@@ -100,7 +100,6 @@ def main() -> None:
|
|
| 100 |
ohlc_seq_len = data_cfg.get("ohlc_seq_len", 60)
|
| 101 |
default_t_cutoff = int(data_cfg.get("t_cutoff_seconds", 60))
|
| 102 |
t_cutoff_seconds = int(args.t_cutoff_seconds) if args.t_cutoff_seconds is not None else default_t_cutoff
|
| 103 |
-
ohlc_stats_path = data_cfg.get("ohlc_stats_path", "./data/ohlc_stats.npz")
|
| 104 |
|
| 105 |
multi_modal_encoder = MultiModalEncoder(dtype=dtype)
|
| 106 |
time_encoder = ContextualTimeEncoder(dtype=dtype)
|
|
@@ -140,7 +139,6 @@ def main() -> None:
|
|
| 140 |
horizons_seconds=horizons,
|
| 141 |
quantiles=quantiles,
|
| 142 |
max_samples=max_samples,
|
| 143 |
-
ohlc_stats_path=ohlc_stats_path,
|
| 144 |
token_allowlist=[args.token_address] if args.token_address else None,
|
| 145 |
t_cutoff_seconds=t_cutoff_seconds
|
| 146 |
)
|
|
|
|
| 100 |
ohlc_seq_len = data_cfg.get("ohlc_seq_len", 60)
|
| 101 |
default_t_cutoff = int(data_cfg.get("t_cutoff_seconds", 60))
|
| 102 |
t_cutoff_seconds = int(args.t_cutoff_seconds) if args.t_cutoff_seconds is not None else default_t_cutoff
|
|
|
|
| 103 |
|
| 104 |
multi_modal_encoder = MultiModalEncoder(dtype=dtype)
|
| 105 |
time_encoder = ContextualTimeEncoder(dtype=dtype)
|
|
|
|
| 139 |
horizons_seconds=horizons,
|
| 140 |
quantiles=quantiles,
|
| 141 |
max_samples=max_samples,
|
|
|
|
| 142 |
token_allowlist=[args.token_address] if args.token_address else None,
|
| 143 |
t_cutoff_seconds=t_cutoff_seconds
|
| 144 |
)
|