Upload folder using huggingface_hub
Browse files
- data/data_collator.py +15 -45
- data/data_loader.py +65 -1
- data/ohlc_stats.npz +1 -1
- log.log +2 -2
- models/multi_modal_processor.py +18 -2
- pre_cache.sh +1 -8
- scripts/cache_dataset.py +26 -1
- scripts/dump_cache_sample.py +18 -1
- train.py +6 -2
data/data_collator.py
CHANGED
|
@@ -6,31 +6,7 @@ from torch.nn.utils.rnn import pad_sequence
|
|
| 6 |
from typing import List, Dict, Any, Tuple, Optional, Union
|
| 7 |
from collections import defaultdict
|
| 8 |
from PIL import Image
|
| 9 |
-
# --- GLOBAL SINGLETON FOR WORKER PROCESSES ---
|
| 10 |
-
_WORKER_ENCODER = None
|
| 11 |
-
|
| 12 |
-
def _set_worker_encoder(encoder):
|
| 13 |
-
"""
|
| 14 |
-
Pre-set the encoder for workers (called from main process before forking).
|
| 15 |
-
This avoids lazy-loading on first batch.
|
| 16 |
-
"""
|
| 17 |
-
global _WORKER_ENCODER
|
| 18 |
-
_WORKER_ENCODER = encoder
|
| 19 |
-
|
| 20 |
-
def _get_worker_encoder(model_id: str, dtype: torch.dtype, device: torch.device):
|
| 21 |
-
"""
|
| 22 |
-
Lazy-loads the encoder on the worker process.
|
| 23 |
-
FORCED TO CPU to save VRAM when using multiple workers.
|
| 24 |
-
"""
|
| 25 |
-
global _WORKER_ENCODER
|
| 26 |
-
if _WORKER_ENCODER is None:
|
| 27 |
-
print(f"[Worker] Initializing MultiModalEncoder (SigLIP) on CPU (VRAM optimization)...")
|
| 28 |
-
# Local import to avoid top-level dependency issues
|
| 29 |
-
from models.multi_modal_processor import MultiModalEncoder
|
| 30 |
-
# Explicitly pass device="cpu"
|
| 31 |
-
_WORKER_ENCODER = MultiModalEncoder(model_id=model_id, dtype=dtype, device="cpu")
|
| 32 |
-
|
| 33 |
-
return _WORKER_ENCODER
|
| 34 |
|
| 35 |
import models.vocabulary as vocab
|
| 36 |
from data.data_loader import EmbeddingPooler
|
|
@@ -235,27 +211,21 @@ class MemecoinCollator:
|
|
| 235 |
|
| 236 |
# --- 2. Create the single, batch-wide embedding pool tensor ---
|
| 237 |
all_items_sorted = batch_wide_pooler.get_all_items()
|
| 238 |
-
texts_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], str)]
|
| 239 |
-
images_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], Image.Image)]
|
| 240 |
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
elif isinstance(item_data['item'], Image.Image):
|
| 256 |
-
if image_embeds.numel() > 0:
|
| 257 |
-
batch_embedding_pool[i] = image_embeds[image_cursor]
|
| 258 |
-
image_cursor += 1
|
| 259 |
|
| 260 |
# --- 3. Remap all indices in the batch data ---
|
| 261 |
for i, item in enumerate(batch):
|
|
|
|
| 6 |
from typing import List, Dict, Any, Tuple, Optional, Union
|
| 7 |
from collections import defaultdict
|
| 8 |
from PIL import Image
|
| 9 |
+
# --- GLOBAL SINGLETON FOR WORKER PROCESSES REMOVED ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
import models.vocabulary as vocab
|
| 12 |
from data.data_loader import EmbeddingPooler
|
|
|
|
| 211 |
|
| 212 |
# --- 2. Create the single, batch-wide embedding pool tensor ---
|
| 213 |
all_items_sorted = batch_wide_pooler.get_all_items()
|
|
|
|
|
|
|
| 214 |
|
| 215 |
+
if not all_items_sorted:
|
| 216 |
+
# Handle edge case of absolutely no embeddings in batch
|
| 217 |
+
# Create a dummy empty tensor
|
| 218 |
+
batch_embedding_pool = torch.empty(0, 768, device=self.device, dtype=self.dtype) # Default SigLIP dim is 1152 actually, but standard is 768. Better to infer or default.
|
| 219 |
+
# Actually, if empty, it doesn't matter much as long as it's not accessed.
|
| 220 |
+
else:
|
| 221 |
+
first_item = all_items_sorted[0]['item']
|
| 222 |
+
if not isinstance(first_item, torch.Tensor):
|
| 223 |
+
raise RuntimeError(f"Collator expects pre-computed embeddings (torch.Tensor), found {type(first_item)}. Please rebuild cache.")
|
| 224 |
+
|
| 225 |
+
# Stack all embeddings
|
| 226 |
+
# They should already be CPU tensors from the loader
|
| 227 |
+
# Move to device and cast to dtype
|
| 228 |
+
batch_embedding_pool = torch.stack([d['item'] for d in all_items_sorted]).to(device=self.device, dtype=self.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
# --- 3. Remap all indices in the batch data ---
|
| 231 |
for i, item in enumerate(batch):
|
data/data_loader.py
CHANGED
|
@@ -90,6 +90,8 @@ class EmbeddingPooler:
|
|
| 90 |
key = item.strip() # use normalized text key
|
| 91 |
elif isinstance(item, Image.Image):
|
| 92 |
key = id(item) # unique memory address for images
|
|
|
|
|
|
|
| 93 |
else:
|
| 94 |
key = item # fallback: use object itself if hashable
|
| 95 |
|
|
@@ -142,6 +144,8 @@ class OracleDataset(Dataset):
|
|
| 142 |
# initialization falls through an unexpected branch.
|
| 143 |
self.cached_files = []
|
| 144 |
self.weights_list = []
|
|
|
|
|
|
|
| 145 |
|
| 146 |
# If a fetcher is provided, we can determine the number of samples.
|
| 147 |
# Otherwise, we are likely in a test mode where __len__ might not be called
|
|
@@ -2475,7 +2479,58 @@ class OracleDataset(Dataset):
|
|
| 2475 |
'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
|
| 2476 |
}
|
| 2477 |
|
| 2478 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2479 |
"""
|
| 2480 |
Generates fully processed training contexts for caching.
|
| 2481 |
|
|
@@ -2968,5 +3023,14 @@ class OracleDataset(Dataset):
|
|
| 2968 |
results.append(result)
|
| 2969 |
pass # Per-context verbose logging removed for caching speed
|
| 2970 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2971 |
# Final count logged via tqdm in cache_dataset.py
|
| 2972 |
return results
|
|
|
|
| 90 |
key = item.strip() # use normalized text key
|
| 91 |
elif isinstance(item, Image.Image):
|
| 92 |
key = id(item) # unique memory address for images
|
| 93 |
+
elif isinstance(item, torch.Tensor):
|
| 94 |
+
key = id(item) # unique memory address for tensors
|
| 95 |
else:
|
| 96 |
key = item # fallback: use object itself if hashable
|
| 97 |
|
|
|
|
| 144 |
# initialization falls through an unexpected branch.
|
| 145 |
self.cached_files = []
|
| 146 |
self.weights_list = []
|
| 147 |
+
|
| 148 |
+
|
| 149 |
|
| 150 |
# If a fetcher is provided, we can determine the number of samples.
|
| 151 |
# Otherwise, we are likely in a test mode where __len__ might not be called
|
|
|
|
| 2479 |
'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
|
| 2480 |
}
|
| 2481 |
|
| 2482 |
+
def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None:
|
| 2483 |
+
"""
|
| 2484 |
+
Helper to replace raw items in the embedding pooler with pre-computed embeddings
|
| 2485 |
+
using the provided encoder (on GPU).
|
| 2486 |
+
"""
|
| 2487 |
+
pooler = context.get('embedding_pooler')
|
| 2488 |
+
if not pooler:
|
| 2489 |
+
return
|
| 2490 |
+
|
| 2491 |
+
# Direct access to pool_map
|
| 2492 |
+
keys_to_embed_img = []
|
| 2493 |
+
images_to_embed = []
|
| 2494 |
+
|
| 2495 |
+
keys_to_embed_text = []
|
| 2496 |
+
texts_to_embed = []
|
| 2497 |
+
|
| 2498 |
+
for key, entry in pooler.pool_map.items():
|
| 2499 |
+
item = entry['item']
|
| 2500 |
+
if isinstance(item, str):
|
| 2501 |
+
# Strings (text)
|
| 2502 |
+
keys_to_embed_text.append(key)
|
| 2503 |
+
texts_to_embed.append(item)
|
| 2504 |
+
elif hasattr(item, 'resize') and not isinstance(item, torch.Tensor): # Duck typing to catch all PIL images
|
| 2505 |
+
keys_to_embed_img.append(key)
|
| 2506 |
+
images_to_embed.append(item)
|
| 2507 |
+
|
| 2508 |
+
# Batch encode images
|
| 2509 |
+
if images_to_embed:
|
| 2510 |
+
# print(f"DEBUG: Found {len(images_to_embed)} images to embed", flush=True)
|
| 2511 |
+
with torch.no_grad():
|
| 2512 |
+
img_embeddings = encoder(images_to_embed)
|
| 2513 |
+
|
| 2514 |
+
# Update pool_map directly for images
|
| 2515 |
+
for i, (key, emb) in enumerate(zip(keys_to_embed_img, img_embeddings)):
|
| 2516 |
+
if key in pooler.pool_map:
|
| 2517 |
+
old_entry = pooler.pool_map[key]
|
| 2518 |
+
pooler.pool_map[key] = {'item': emb.cpu().clone(), 'idx': old_entry['idx']}
|
| 2519 |
+
|
| 2520 |
+
# Batch encode text
|
| 2521 |
+
if texts_to_embed:
|
| 2522 |
+
# print(f"DEBUG: Found {len(texts_to_embed)} text items to embed", flush=True)
|
| 2523 |
+
with torch.no_grad():
|
| 2524 |
+
text_embeddings = encoder(texts_to_embed)
|
| 2525 |
+
|
| 2526 |
+
# Update pool_map directly for text
|
| 2527 |
+
for i, (key, emb) in enumerate(zip(keys_to_embed_text, text_embeddings)):
|
| 2528 |
+
if key in pooler.pool_map:
|
| 2529 |
+
old_entry = pooler.pool_map[key]
|
| 2530 |
+
pooler.pool_map[key] = {'item': emb.cpu().clone(), 'idx': old_entry['idx']}
|
| 2531 |
+
|
| 2532 |
+
|
| 2533 |
+
def __cacheitem_context__(self, idx: int, num_samples_per_token: int = 1, encoder: Optional[Any] = None) -> List[Optional[Dict[str, Any]]]:
|
| 2534 |
"""
|
| 2535 |
Generates fully processed training contexts for caching.
|
| 2536 |
|
|
|
|
| 3023 |
results.append(result)
|
| 3024 |
pass # Per-context verbose logging removed for caching speed
|
| 3025 |
|
| 3026 |
+
# --- OPTIONAL: Pre-compute Embeddings (if encoder provided) ---
|
| 3027 |
+
if encoder is not None:
|
| 3028 |
+
# print(f"DEBUG: Encoder provided to loader for {len(results)} contexts", flush=True)
|
| 3029 |
+
for ctx in results:
|
| 3030 |
+
self._embed_context(ctx, encoder)
|
| 3031 |
+
else:
|
| 3032 |
+
if idx == 0:
|
| 3033 |
+
print("DEBUG: No encoder provided to __cacheitem_context__", flush=True)
|
| 3034 |
+
|
| 3035 |
# Final count logged via tqdm in cache_dataset.py
|
| 3036 |
return results
|
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1660
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0ecfddd649b981eacb14b68ac183e53a273cab3571ec566eaa64700bbd871e42
|
| 3 |
size 1660
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6751984e91418c4e318a9a01327bc60f6d0ee18282e7f12b42fe35562611c0d
|
| 3 |
+
size 13919
|
models/multi_modal_processor.py
CHANGED
|
@@ -88,9 +88,24 @@ class MultiModalEncoder:
|
|
| 88 |
inputs = self.processor(text=x, return_tensors="pt", padding=True, truncation=True).to(self.device)
|
| 89 |
embeddings = self.model.get_text_features(**inputs)
|
| 90 |
else:
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
embeddings = self.model.get_image_features(**inputs)
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
# Normalize in float32 for numerical stability
|
| 95 |
embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
|
| 96 |
|
|
@@ -99,7 +114,8 @@ class MultiModalEncoder:
|
|
| 99 |
|
| 100 |
except Exception as e:
|
| 101 |
# Silently fail or log debug only if needed
|
| 102 |
-
|
|
|
|
| 103 |
return torch.empty(0, self.embedding_dim).to(self.device)
|
| 104 |
|
| 105 |
# --- Test block (SigLIP) ---
|
|
|
|
| 88 |
inputs = self.processor(text=x, return_tensors="pt", padding=True, truncation=True).to(self.device)
|
| 89 |
embeddings = self.model.get_text_features(**inputs)
|
| 90 |
else:
|
| 91 |
+
# Ensure all images are RGB to avoid "Unable to infer channel dimension format"
|
| 92 |
+
valid_images = [img.convert("RGB") for img in x]
|
| 93 |
+
inputs = self.processor(images=valid_images, return_tensors="pt").to(self.device)
|
| 94 |
embeddings = self.model.get_image_features(**inputs)
|
| 95 |
|
| 96 |
+
# EXTRACT TENSOR IF OUTPUT IS A MODEL OUTPUT OBJECT
|
| 97 |
+
if not isinstance(embeddings, torch.Tensor):
|
| 98 |
+
if hasattr(embeddings, 'pooler_output'):
|
| 99 |
+
embeddings = embeddings.pooler_output
|
| 100 |
+
elif hasattr(embeddings, 'last_hidden_state'):
|
| 101 |
+
# Fallback for models without pooler_output but with hidden state (e.g. usage of [CLS] or mean pooling needed?)
|
| 102 |
+
# For SigLIP/CLIP get_image_features, it should return the features.
|
| 103 |
+
# If it returns an object, it might be the raw output.
|
| 104 |
+
# Let's try to assume it matches the expected embedding dim.
|
| 105 |
+
embeddings = embeddings.last_hidden_state
|
| 106 |
+
elif isinstance(embeddings, (tuple, list)):
|
| 107 |
+
embeddings = embeddings[0]
|
| 108 |
+
|
| 109 |
# Normalize in float32 for numerical stability
|
| 110 |
embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
|
| 111 |
|
|
|
|
| 114 |
|
| 115 |
except Exception as e:
|
| 116 |
# Silently fail or log debug only if needed
|
| 117 |
+
print(f"ERROR in MultiModalEncoder: {e}", flush=True)
|
| 118 |
+
traceback.print_exc()
|
| 119 |
return torch.empty(0, self.embedding_dim).to(self.device)
|
| 120 |
|
| 121 |
# --- Test block (SigLIP) ---
|
pre_cache.sh
CHANGED
|
@@ -1,11 +1,3 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
# Pre-caches the dataset for training in context mode
|
| 3 |
-
#
|
| 4 |
-
# Usage:
|
| 5 |
-
# ./pre_cache.sh
|
| 6 |
-
|
| 7 |
-
set -euo pipefail
|
| 8 |
-
|
| 9 |
# =========================
|
| 10 |
# Hardcoded cache settings
|
| 11 |
# =========================
|
|
@@ -48,6 +40,7 @@ python3 scripts/cache_dataset.py \
|
|
| 48 |
--num_workers "$NUM_WORKERS" \
|
| 49 |
--horizons_seconds "${HORIZONS_SECONDS[@]}" \
|
| 50 |
--quantiles "${QUANTILES[@]}" \
|
|
|
|
| 51 |
"$@"
|
| 52 |
|
| 53 |
echo "Done!"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# =========================
|
| 2 |
# Hardcoded cache settings
|
| 3 |
# =========================
|
|
|
|
| 40 |
--num_workers "$NUM_WORKERS" \
|
| 41 |
--horizons_seconds "${HORIZONS_SECONDS[@]}" \
|
| 42 |
--quantiles "${QUANTILES[@]}" \
|
| 43 |
+
--max_samples 150000 \
|
| 44 |
"$@"
|
| 45 |
|
| 46 |
echo "Done!"
|
scripts/cache_dataset.py
CHANGED
|
@@ -14,6 +14,7 @@ import huggingface_hub
|
|
| 14 |
import logging
|
| 15 |
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 16 |
import multiprocessing as mp
|
|
|
|
| 17 |
|
| 18 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 19 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
|
@@ -30,6 +31,7 @@ from neo4j import GraphDatabase
|
|
| 30 |
_worker_dataset = None
|
| 31 |
_worker_return_class_map = None
|
| 32 |
_worker_quality_scores_map = None
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
|
|
@@ -40,6 +42,20 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
|
|
| 40 |
clickhouse_client = ClickHouseClient(host=db_config['clickhouse_host'], port=db_config['clickhouse_port'])
|
| 41 |
neo4j_driver = GraphDatabase.driver(db_config['neo4j_uri'], auth=(db_config['neo4j_user'], db_config['neo4j_password']))
|
| 42 |
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
_worker_dataset = OracleDataset(
|
| 45 |
data_fetcher=data_fetcher,
|
|
@@ -63,7 +79,15 @@ def _process_single_token_context(args):
|
|
| 63 |
class_id = _worker_return_class_map.get(mint_addr)
|
| 64 |
if class_id is None:
|
| 65 |
return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
if not contexts:
|
| 68 |
return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
|
| 69 |
q_score = _worker_quality_scores_map.get(mint_addr)
|
|
@@ -75,6 +99,7 @@ def _process_single_token_context(args):
|
|
| 75 |
ctx["class_id"] = class_id
|
| 76 |
filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
|
| 77 |
output_path = Path(output_dir) / filename
|
|
|
|
| 78 |
torch.save(ctx, output_path)
|
| 79 |
saved_files.append(filename)
|
| 80 |
return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'files': saved_files}
|
|
|
|
| 14 |
import logging
|
| 15 |
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 16 |
import multiprocessing as mp
|
| 17 |
+
from PIL import Image
|
| 18 |
|
| 19 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 20 |
logging.getLogger("transformers").setLevel(logging.ERROR)
|
|
|
|
| 31 |
_worker_dataset = None
|
| 32 |
_worker_return_class_map = None
|
| 33 |
_worker_quality_scores_map = None
|
| 34 |
+
_worker_encoder = None
|
| 35 |
|
| 36 |
|
| 37 |
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
|
|
|
|
| 42 |
clickhouse_client = ClickHouseClient(host=db_config['clickhouse_host'], port=db_config['clickhouse_port'])
|
| 43 |
neo4j_driver = GraphDatabase.driver(db_config['neo4j_uri'], auth=(db_config['neo4j_user'], db_config['neo4j_password']))
|
| 44 |
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 45 |
+
|
| 46 |
+
# --- NEW: Init Encoder on GPU ---
|
| 47 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 48 |
+
# Using float16 for efficiency on GPU
|
| 49 |
+
global _worker_encoder
|
| 50 |
+
try:
|
| 51 |
+
_worker_encoder = MultiModalEncoder(
|
| 52 |
+
model_id="google/siglip-so400m-patch16-256-i18n",
|
| 53 |
+
device="cuda",
|
| 54 |
+
dtype=torch.float16
|
| 55 |
+
)
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(f"WARN: Failed to initialize MultiModalEncoder on worker: {e}")
|
| 58 |
+
_worker_encoder = None
|
| 59 |
|
| 60 |
_worker_dataset = OracleDataset(
|
| 61 |
data_fetcher=data_fetcher,
|
|
|
|
| 79 |
class_id = _worker_return_class_map.get(mint_addr)
|
| 80 |
if class_id is None:
|
| 81 |
return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
|
| 82 |
+
|
| 83 |
+
# Pass the global encoder (if initialized) to pre-compute embeddings
|
| 84 |
+
global _worker_encoder
|
| 85 |
+
encoder = _worker_encoder
|
| 86 |
+
# print(f"DEBUG: Worker encoder status: {type(encoder)}", flush=True) # Commented out to reduce noise if it works
|
| 87 |
+
if encoder is None:
|
| 88 |
+
print(f"ERROR: Worker encoder is None for mint {mint_addr}!", flush=True)
|
| 89 |
+
|
| 90 |
+
contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token, encoder=encoder)
|
| 91 |
if not contexts:
|
| 92 |
return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
|
| 93 |
q_score = _worker_quality_scores_map.get(mint_addr)
|
|
|
|
| 99 |
ctx["class_id"] = class_id
|
| 100 |
filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
|
| 101 |
output_path = Path(output_dir) / filename
|
| 102 |
+
|
| 103 |
torch.save(ctx, output_path)
|
| 104 |
saved_files.append(filename)
|
| 105 |
return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'files': saved_files}
|
scripts/dump_cache_sample.py
CHANGED
|
@@ -36,7 +36,24 @@ def convert_to_serializable(obj):
|
|
| 36 |
if isinstance(obj, np.ndarray):
|
| 37 |
return {"__type__": "ndarray", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
|
| 38 |
if isinstance(obj, torch.Tensor):
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
if isinstance(obj, datetime):
|
| 41 |
return {"__type__": "datetime", "value": obj.isoformat()}
|
| 42 |
if isinstance(obj, bytes):
|
|
|
|
| 36 |
if isinstance(obj, np.ndarray):
|
| 37 |
return {"__type__": "ndarray", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
|
| 38 |
if isinstance(obj, torch.Tensor):
|
| 39 |
+
data = obj.tolist()
|
| 40 |
+
# Truncate large tensors for readability
|
| 41 |
+
if obj.numel() > 50:
|
| 42 |
+
flat = obj.flatten().tolist()
|
| 43 |
+
data = flat[:20] + [f"... ({obj.numel()} elements total)"]
|
| 44 |
+
return {"__type__": "tensor", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": data}
|
| 45 |
+
|
| 46 |
+
# Handle EmbeddingPooler specifically
|
| 47 |
+
if type(obj).__name__ == 'EmbeddingPooler':
|
| 48 |
+
try:
|
| 49 |
+
items = obj.get_all_items()
|
| 50 |
+
return {
|
| 51 |
+
"__type__": "EmbeddingPooler",
|
| 52 |
+
"count": len(items),
|
| 53 |
+
"items": [convert_to_serializable(item) for item in items]
|
| 54 |
+
}
|
| 55 |
+
except:
|
| 56 |
+
return {"__type__": "EmbeddingPooler", "repr": str(obj)}
|
| 57 |
if isinstance(obj, datetime):
|
| 58 |
return {"__type__": "datetime", "value": obj.isoformat()}
|
| 59 |
if isinstance(obj, bytes):
|
train.py
CHANGED
|
@@ -132,8 +132,12 @@ def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
|
|
| 132 |
# Group indices by class_id - use dataset's existing map if available
|
| 133 |
class_to_indices = defaultdict(list)
|
| 134 |
|
| 135 |
-
# Fast path: use dataset's
|
| 136 |
-
if hasattr(dataset, '
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
for idx, cached_file in enumerate(dataset.cached_files):
|
| 138 |
# file_class_map uses filename strings as keys, cached_files are Path objects
|
| 139 |
fname = cached_file.name if hasattr(cached_file, 'name') else str(cached_file)
|
|
|
|
| 132 |
# Group indices by class_id - use dataset's existing map if available
|
| 133 |
class_to_indices = defaultdict(list)
|
| 134 |
|
| 135 |
+
# Fast path: use dataset's sample_labels (aligned with __getitem__)
|
| 136 |
+
if hasattr(dataset, 'sample_labels') and dataset.sample_labels:
|
| 137 |
+
for idx, class_id in enumerate(dataset.sample_labels):
|
| 138 |
+
class_to_indices[class_id].append(idx)
|
| 139 |
+
# Legacy path: use dataset's file_class_map (for 1-file-1-sample datasets)
|
| 140 |
+
elif hasattr(dataset, 'file_class_map') and dataset.file_class_map:
|
| 141 |
for idx, cached_file in enumerate(dataset.cached_files):
|
| 142 |
# file_class_map uses filename strings as keys, cached_files are Path objects
|
| 143 |
fname = cached_file.name if hasattr(cached_file, 'name') else str(cached_file)
|