zirobtc commited on
Commit
41e0423
·
1 Parent(s): 16f4534

Upload folder using huggingface_hub

Browse files
data/data_collator.py CHANGED
@@ -6,31 +6,7 @@ from torch.nn.utils.rnn import pad_sequence
6
  from typing import List, Dict, Any, Tuple, Optional, Union
7
  from collections import defaultdict
8
  from PIL import Image
9
- # --- GLOBAL SINGLETON FOR WORKER PROCESSES ---
10
- _WORKER_ENCODER = None
11
-
12
- def _set_worker_encoder(encoder):
13
- """
14
- Pre-set the encoder for workers (called from main process before forking).
15
- This avoids lazy-loading on first batch.
16
- """
17
- global _WORKER_ENCODER
18
- _WORKER_ENCODER = encoder
19
-
20
- def _get_worker_encoder(model_id: str, dtype: torch.dtype, device: torch.device):
21
- """
22
- Lazy-loads the encoder on the worker process.
23
- FORCED TO CPU to save VRAM when using multiple workers.
24
- """
25
- global _WORKER_ENCODER
26
- if _WORKER_ENCODER is None:
27
- print(f"[Worker] Initializing MultiModalEncoder (SigLIP) on CPU (VRAM optimization)...")
28
- # Local import to avoid top-level dependency issues
29
- from models.multi_modal_processor import MultiModalEncoder
30
- # Explicitly pass device="cpu"
31
- _WORKER_ENCODER = MultiModalEncoder(model_id=model_id, dtype=dtype, device="cpu")
32
-
33
- return _WORKER_ENCODER
34
 
35
  import models.vocabulary as vocab
36
  from data.data_loader import EmbeddingPooler
@@ -235,27 +211,21 @@ class MemecoinCollator:
235
 
236
  # --- 2. Create the single, batch-wide embedding pool tensor ---
237
  all_items_sorted = batch_wide_pooler.get_all_items()
238
- texts_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], str)]
239
- images_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], Image.Image)]
240
 
241
- # LAZY LOAD ENCODER
242
- encoder = _get_worker_encoder(self.model_id, self.dtype, self.device)
243
-
244
- text_embeds = encoder(texts_to_encode).to(self.device) if texts_to_encode else torch.empty(0)
245
- image_embeds = encoder(images_to_encode).to(self.device) if images_to_encode else torch.empty(0)
246
-
247
- # Create the final lookup tensor and fill it based on original item type
248
- batch_embedding_pool = torch.zeros(len(all_items_sorted), encoder.embedding_dim, device=self.device, dtype=self.dtype)
249
- text_cursor, image_cursor = 0, 0
250
- for i, item_data in enumerate(all_items_sorted):
251
- if isinstance(item_data['item'], str):
252
- if text_embeds.numel() > 0:
253
- batch_embedding_pool[i] = text_embeds[text_cursor]
254
- text_cursor += 1
255
- elif isinstance(item_data['item'], Image.Image):
256
- if image_embeds.numel() > 0:
257
- batch_embedding_pool[i] = image_embeds[image_cursor]
258
- image_cursor += 1
259
 
260
  # --- 3. Remap all indices in the batch data ---
261
  for i, item in enumerate(batch):
 
6
  from typing import List, Dict, Any, Tuple, Optional, Union
7
  from collections import defaultdict
8
  from PIL import Image
9
+ # --- GLOBAL SINGLETON FOR WORKER PROCESSES REMOVED ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  import models.vocabulary as vocab
12
  from data.data_loader import EmbeddingPooler
 
211
 
212
  # --- 2. Create the single, batch-wide embedding pool tensor ---
213
  all_items_sorted = batch_wide_pooler.get_all_items()
 
 
214
 
215
+ if not all_items_sorted:
216
+ # Handle edge case of absolutely no embeddings in batch
217
+ # Create a dummy empty tensor
218
+ batch_embedding_pool = torch.empty(0, 768, device=self.device, dtype=self.dtype) # Default SigLIP dim is 1152 actually, but standard is 768. Better to infer or default.
219
+ # Actually, if empty, it doesn't matter much as long as it's not accessed.
220
+ else:
221
+ first_item = all_items_sorted[0]['item']
222
+ if not isinstance(first_item, torch.Tensor):
223
+ raise RuntimeError(f"Collator expects pre-computed embeddings (torch.Tensor), found {type(first_item)}. Please rebuild cache.")
224
+
225
+ # Stack all embeddings
226
+ # They should already be CPU tensors from the loader
227
+ # Move to device and cast to dtype
228
+ batch_embedding_pool = torch.stack([d['item'] for d in all_items_sorted]).to(device=self.device, dtype=self.dtype)
 
 
 
 
229
 
230
  # --- 3. Remap all indices in the batch data ---
231
  for i, item in enumerate(batch):
data/data_loader.py CHANGED
@@ -90,6 +90,8 @@ class EmbeddingPooler:
90
  key = item.strip() # use normalized text key
91
  elif isinstance(item, Image.Image):
92
  key = id(item) # unique memory address for images
 
 
93
  else:
94
  key = item # fallback: use object itself if hashable
95
 
@@ -142,6 +144,8 @@ class OracleDataset(Dataset):
142
  # initialization falls through an unexpected branch.
143
  self.cached_files = []
144
  self.weights_list = []
 
 
145
 
146
  # If a fetcher is provided, we can determine the number of samples.
147
  # Otherwise, we are likely in a test mode where __len__ might not be called
@@ -2475,7 +2479,58 @@ class OracleDataset(Dataset):
2475
  'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
2476
  }
2477
 
2478
- def __cacheitem_context__(self, idx: int, num_samples_per_token: int = 1) -> List[Optional[Dict[str, Any]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2479
  """
2480
  Generates fully processed training contexts for caching.
2481
 
@@ -2968,5 +3023,14 @@ class OracleDataset(Dataset):
2968
  results.append(result)
2969
  pass # Per-context verbose logging removed for caching speed
2970
 
 
 
 
 
 
 
 
 
 
2971
  # Final count logged via tqdm in cache_dataset.py
2972
  return results
 
90
  key = item.strip() # use normalized text key
91
  elif isinstance(item, Image.Image):
92
  key = id(item) # unique memory address for images
93
+ elif isinstance(item, torch.Tensor):
94
+ key = id(item) # unique memory address for tensors
95
  else:
96
  key = item # fallback: use object itself if hashable
97
 
 
144
  # initialization falls through an unexpected branch.
145
  self.cached_files = []
146
  self.weights_list = []
147
+
148
+
149
 
150
  # If a fetcher is provided, we can determine the number of samples.
151
  # Otherwise, we are likely in a test mode where __len__ might not be called
 
2479
  'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
2480
  }
2481
 
2482
+ def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None:
2483
+ """
2484
+ Helper to replace raw items in the embedding pooler with pre-computed embeddings
2485
+ using the provided encoder (on GPU).
2486
+ """
2487
+ pooler = context.get('embedding_pooler')
2488
+ if not pooler:
2489
+ return
2490
+
2491
+ # Direct access to pool_map
2492
+ keys_to_embed_img = []
2493
+ images_to_embed = []
2494
+
2495
+ keys_to_embed_text = []
2496
+ texts_to_embed = []
2497
+
2498
+ for key, entry in pooler.pool_map.items():
2499
+ item = entry['item']
2500
+ if isinstance(item, str):
2501
+ # Strings (text)
2502
+ keys_to_embed_text.append(key)
2503
+ texts_to_embed.append(item)
2504
+ elif hasattr(item, 'resize') and not isinstance(item, torch.Tensor): # Duck typing to catch all PIL images
2505
+ keys_to_embed_img.append(key)
2506
+ images_to_embed.append(item)
2507
+
2508
+ # Batch encode images
2509
+ if images_to_embed:
2510
+ # print(f"DEBUG: Found {len(images_to_embed)} images to embed", flush=True)
2511
+ with torch.no_grad():
2512
+ img_embeddings = encoder(images_to_embed)
2513
+
2514
+ # Update pool_map directly for images
2515
+ for i, (key, emb) in enumerate(zip(keys_to_embed_img, img_embeddings)):
2516
+ if key in pooler.pool_map:
2517
+ old_entry = pooler.pool_map[key]
2518
+ pooler.pool_map[key] = {'item': emb.cpu().clone(), 'idx': old_entry['idx']}
2519
+
2520
+ # Batch encode text
2521
+ if texts_to_embed:
2522
+ # print(f"DEBUG: Found {len(texts_to_embed)} text items to embed", flush=True)
2523
+ with torch.no_grad():
2524
+ text_embeddings = encoder(texts_to_embed)
2525
+
2526
+ # Update pool_map directly for text
2527
+ for i, (key, emb) in enumerate(zip(keys_to_embed_text, text_embeddings)):
2528
+ if key in pooler.pool_map:
2529
+ old_entry = pooler.pool_map[key]
2530
+ pooler.pool_map[key] = {'item': emb.cpu().clone(), 'idx': old_entry['idx']}
2531
+
2532
+
2533
+ def __cacheitem_context__(self, idx: int, num_samples_per_token: int = 1, encoder: Optional[Any] = None) -> List[Optional[Dict[str, Any]]]:
2534
  """
2535
  Generates fully processed training contexts for caching.
2536
 
 
3023
  results.append(result)
3024
  pass # Per-context verbose logging removed for caching speed
3025
 
3026
+ # --- OPTIONAL: Pre-compute Embeddings (if encoder provided) ---
3027
+ if encoder is not None:
3028
+ # print(f"DEBUG: Encoder provided to loader for {len(results)} contexts", flush=True)
3029
+ for ctx in results:
3030
+ self._embed_context(ctx, encoder)
3031
+ else:
3032
+ if idx == 0:
3033
+ print("DEBUG: No encoder provided to __cacheitem_context__", flush=True)
3034
+
3035
  # Final count logged via tqdm in cache_dataset.py
3036
  return results
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:92f50d146182941b8b01be19b4699c1b0ebe37bac1ff155580b20a8755994070
3
  size 1660
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ecfddd649b981eacb14b68ac183e53a273cab3571ec566eaa64700bbd871e42
3
  size 1660
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e4ceaef802908dd650ce1ade210a0e827ec433b904adc3bf17c3d8a877e59ae6
3
- size 2854
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6751984e91418c4e318a9a01327bc60f6d0ee18282e7f12b42fe35562611c0d
3
+ size 13919
models/multi_modal_processor.py CHANGED
@@ -88,9 +88,24 @@ class MultiModalEncoder:
88
  inputs = self.processor(text=x, return_tensors="pt", padding=True, truncation=True).to(self.device)
89
  embeddings = self.model.get_text_features(**inputs)
90
  else:
91
- inputs = self.processor(images=x, return_tensors="pt").to(self.device)
 
 
92
  embeddings = self.model.get_image_features(**inputs)
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # Normalize in float32 for numerical stability
95
  embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
96
 
@@ -99,7 +114,8 @@ class MultiModalEncoder:
99
 
100
  except Exception as e:
101
  # Silently fail or log debug only if needed
102
- # traceback.print_exc()
 
103
  return torch.empty(0, self.embedding_dim).to(self.device)
104
 
105
  # --- Test block (SigLIP) ---
 
88
  inputs = self.processor(text=x, return_tensors="pt", padding=True, truncation=True).to(self.device)
89
  embeddings = self.model.get_text_features(**inputs)
90
  else:
91
+ # Ensure all images are RGB to avoid "Unable to infer channel dimension format"
92
+ valid_images = [img.convert("RGB") for img in x]
93
+ inputs = self.processor(images=valid_images, return_tensors="pt").to(self.device)
94
  embeddings = self.model.get_image_features(**inputs)
95
 
96
+ # EXTRACT TENSOR IF OUTPUT IS A MODEL OUTPUT OBJECT
97
+ if not isinstance(embeddings, torch.Tensor):
98
+ if hasattr(embeddings, 'pooler_output'):
99
+ embeddings = embeddings.pooler_output
100
+ elif hasattr(embeddings, 'last_hidden_state'):
101
+ # Fallback for models without pooler_output but with hidden state (e.g. usage of [CLS] or mean pooling needed?)
102
+ # For SigLIP/CLIP get_image_features, it should return the features.
103
+ # If it returns an object, it might be the raw output.
104
+ # Let's try to assume it matches the expected embedding dim.
105
+ embeddings = embeddings.last_hidden_state
106
+ elif isinstance(embeddings, (tuple, list)):
107
+ embeddings = embeddings[0]
108
+
109
  # Normalize in float32 for numerical stability
110
  embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
111
 
 
114
 
115
  except Exception as e:
116
  # Silently fail or log debug only if needed
117
+ print(f"ERROR in MultiModalEncoder: {e}", flush=True)
118
+ traceback.print_exc()
119
  return torch.empty(0, self.embedding_dim).to(self.device)
120
 
121
  # --- Test block (SigLIP) ---
pre_cache.sh CHANGED
@@ -1,11 +1,3 @@
1
- #!/bin/bash
2
- # Pre-caches the dataset for training in context mode
3
- #
4
- # Usage:
5
- # ./pre_cache.sh
6
-
7
- set -euo pipefail
8
-
9
  # =========================
10
  # Hardcoded cache settings
11
  # =========================
@@ -48,6 +40,7 @@ python3 scripts/cache_dataset.py \
48
  --num_workers "$NUM_WORKERS" \
49
  --horizons_seconds "${HORIZONS_SECONDS[@]}" \
50
  --quantiles "${QUANTILES[@]}" \
 
51
  "$@"
52
 
53
  echo "Done!"
 
 
 
 
 
 
 
 
 
1
  # =========================
2
  # Hardcoded cache settings
3
  # =========================
 
40
  --num_workers "$NUM_WORKERS" \
41
  --horizons_seconds "${HORIZONS_SECONDS[@]}" \
42
  --quantiles "${QUANTILES[@]}" \
43
+ --max_samples 150000 \
44
  "$@"
45
 
46
  echo "Done!"
scripts/cache_dataset.py CHANGED
@@ -14,6 +14,7 @@ import huggingface_hub
14
  import logging
15
  from concurrent.futures import ProcessPoolExecutor, as_completed
16
  import multiprocessing as mp
 
17
 
18
  logging.getLogger("httpx").setLevel(logging.WARNING)
19
  logging.getLogger("transformers").setLevel(logging.ERROR)
@@ -30,6 +31,7 @@ from neo4j import GraphDatabase
30
  _worker_dataset = None
31
  _worker_return_class_map = None
32
  _worker_quality_scores_map = None
 
33
 
34
 
35
  def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
@@ -40,6 +42,20 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
40
  clickhouse_client = ClickHouseClient(host=db_config['clickhouse_host'], port=db_config['clickhouse_port'])
41
  neo4j_driver = GraphDatabase.driver(db_config['neo4j_uri'], auth=(db_config['neo4j_user'], db_config['neo4j_password']))
42
  data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  _worker_dataset = OracleDataset(
45
  data_fetcher=data_fetcher,
@@ -63,7 +79,15 @@ def _process_single_token_context(args):
63
  class_id = _worker_return_class_map.get(mint_addr)
64
  if class_id is None:
65
  return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
66
- contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token)
 
 
 
 
 
 
 
 
67
  if not contexts:
68
  return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
69
  q_score = _worker_quality_scores_map.get(mint_addr)
@@ -75,6 +99,7 @@ def _process_single_token_context(args):
75
  ctx["class_id"] = class_id
76
  filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
77
  output_path = Path(output_dir) / filename
 
78
  torch.save(ctx, output_path)
79
  saved_files.append(filename)
80
  return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'files': saved_files}
 
14
  import logging
15
  from concurrent.futures import ProcessPoolExecutor, as_completed
16
  import multiprocessing as mp
17
+ from PIL import Image
18
 
19
  logging.getLogger("httpx").setLevel(logging.WARNING)
20
  logging.getLogger("transformers").setLevel(logging.ERROR)
 
31
  _worker_dataset = None
32
  _worker_return_class_map = None
33
  _worker_quality_scores_map = None
34
+ _worker_encoder = None
35
 
36
 
37
  def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
 
42
  clickhouse_client = ClickHouseClient(host=db_config['clickhouse_host'], port=db_config['clickhouse_port'])
43
  neo4j_driver = GraphDatabase.driver(db_config['neo4j_uri'], auth=(db_config['neo4j_user'], db_config['neo4j_password']))
44
  data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
45
+
46
+ # --- NEW: Init Encoder on GPU ---
47
+ from models.multi_modal_processor import MultiModalEncoder
48
+ # Using float16 for efficiency on GPU
49
+ global _worker_encoder
50
+ try:
51
+ _worker_encoder = MultiModalEncoder(
52
+ model_id="google/siglip-so400m-patch16-256-i18n",
53
+ device="cuda",
54
+ dtype=torch.float16
55
+ )
56
+ except Exception as e:
57
+ print(f"WARN: Failed to initialize MultiModalEncoder on worker: {e}")
58
+ _worker_encoder = None
59
 
60
  _worker_dataset = OracleDataset(
61
  data_fetcher=data_fetcher,
 
79
  class_id = _worker_return_class_map.get(mint_addr)
80
  if class_id is None:
81
  return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
82
+
83
+ # Pass the global encoder (if initialized) to pre-compute embeddings
84
+ global _worker_encoder
85
+ encoder = _worker_encoder
86
+ # print(f"DEBUG: Worker encoder status: {type(encoder)}", flush=True) # Commented out to reduce noise if it works
87
+ if encoder is None:
88
+ print(f"ERROR: Worker encoder is None for mint {mint_addr}!", flush=True)
89
+
90
+ contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token, encoder=encoder)
91
  if not contexts:
92
  return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
93
  q_score = _worker_quality_scores_map.get(mint_addr)
 
99
  ctx["class_id"] = class_id
100
  filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
101
  output_path = Path(output_dir) / filename
102
+
103
  torch.save(ctx, output_path)
104
  saved_files.append(filename)
105
  return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'files': saved_files}
scripts/dump_cache_sample.py CHANGED
@@ -36,7 +36,24 @@ def convert_to_serializable(obj):
36
  if isinstance(obj, np.ndarray):
37
  return {"__type__": "ndarray", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
38
  if isinstance(obj, torch.Tensor):
39
- return {"__type__": "tensor", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if isinstance(obj, datetime):
41
  return {"__type__": "datetime", "value": obj.isoformat()}
42
  if isinstance(obj, bytes):
 
36
  if isinstance(obj, np.ndarray):
37
  return {"__type__": "ndarray", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
38
  if isinstance(obj, torch.Tensor):
39
+ data = obj.tolist()
40
+ # Truncate large tensors for readability
41
+ if obj.numel() > 50:
42
+ flat = obj.flatten().tolist()
43
+ data = flat[:20] + [f"... ({obj.numel()} elements total)"]
44
+ return {"__type__": "tensor", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": data}
45
+
46
+ # Handle EmbeddingPooler specifically
47
+ if type(obj).__name__ == 'EmbeddingPooler':
48
+ try:
49
+ items = obj.get_all_items()
50
+ return {
51
+ "__type__": "EmbeddingPooler",
52
+ "count": len(items),
53
+ "items": [convert_to_serializable(item) for item in items]
54
+ }
55
+ except:
56
+ return {"__type__": "EmbeddingPooler", "repr": str(obj)}
57
  if isinstance(obj, datetime):
58
  return {"__type__": "datetime", "value": obj.isoformat()}
59
  if isinstance(obj, bytes):
train.py CHANGED
@@ -132,8 +132,12 @@ def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
132
  # Group indices by class_id - use dataset's existing map if available
133
  class_to_indices = defaultdict(list)
134
 
135
- # Fast path: use dataset's file_class_map (already loaded during init)
136
- if hasattr(dataset, 'file_class_map') and dataset.file_class_map:
 
 
 
 
137
  for idx, cached_file in enumerate(dataset.cached_files):
138
  # file_class_map uses filename strings as keys, cached_files are Path objects
139
  fname = cached_file.name if hasattr(cached_file, 'name') else str(cached_file)
 
132
  # Group indices by class_id - use dataset's existing map if available
133
  class_to_indices = defaultdict(list)
134
 
135
+ # Fast path: use dataset's sample_labels (aligned with __getitem__)
136
+ if hasattr(dataset, 'sample_labels') and dataset.sample_labels:
137
+ for idx, class_id in enumerate(dataset.sample_labels):
138
+ class_to_indices[class_id].append(idx)
139
+ # Legacy path: use dataset's file_class_map (for 1-file-1-sample datasets)
140
+ elif hasattr(dataset, 'file_class_map') and dataset.file_class_map:
141
  for idx, cached_file in enumerate(dataset.cached_files):
142
  # file_class_map uses filename strings as keys, cached_files are Path objects
143
  fname = cached_file.name if hasattr(cached_file, 'name') else str(cached_file)