Other
Transformers
Safetensors
PyTorch
English
vision-language-action
humanoid-robotics
telepathy
multimodal
robotics-control
lora
Instructions to use Veltraxor/Sigma with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Veltraxor/Sigma with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Veltraxor/Sigma", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # eval_sigma_vla_rollout.py | |
| # Offline closed-loop evaluation for Telepathy-augmented VLA on top of PI05 policy backbone. | |
| # | |
| # Key design: | |
| # - base_model_id is a LeRobot/OpenPI policy repo (e.g., lerobot/pi05_base or your fine-tuned Sigma repo). | |
| # - We load PI05Policy via LeRobot, NOT AutoModelForCausalLM. | |
| # - Text embeddings are taken from the PI05 internal text backbone so that TelepathyLanguageModule | |
| # receives the same type of inputs used during training. | |
| # | |
| # Hardened in this revision: | |
| # - Robust recursive shard discovery under any naming & subfolders. | |
| # - Shard content structure normalization (list-of-samples, or dict{samples/data}). | |
| # - Collate auto-adapts to real schema: vision/state/action/text, with time-dim collapse for vision. | |
| # - Action GT supports dict-style branches or a single tensor. | |
| # - Metrics tolerate missing multi-branch outputs (fallback to "action"). | |
| # - Text tokens dtype/device aligned to model dtype for mixed precision safety. | |
| # - Robot state time-dim collapse + pad/trim to state encoder expected dim. | |
| # - Dynamic projection to align vision/state token hidden size to vision backbone dim (768), | |
| # and project text to the same dim BEFORE feeding language module. | |
| # - Optional max_text_len to avoid tokenizer truncation warnings. | |
| # - action input contract hardening: | |
| # * high_level_rep 2D -> 3D | |
| # * tau None/2D -> 3D | |
| # * tau length aligned to high_level_rep length | |
| # * tau last-dim auto pad/trim so concat(high_level_rep, tau) matches action_condition_proj in_features | |
| # - tokenizer_id can be a LOCAL path; when it exists locally we load with local_files_only | |
| # - _align_target handles 2D<->3D mismatches (fixes MSE crashes) | |
| # - remove duplicated "high_level_rep/tau re-normalization" that overwrote the hardening | |
| # | |
| # NEW in this patch: | |
| # - cosine_alignment auto-aligns hidden sizes (fixes 256 vs 2048 crash). | |
| # - semantic pooling guard supports 2D/3D factors safely. | |
| # - alignment metric ignores zero-length cases robustly. | |
| # | |
| # EXTRA HARDENING (this patch for your baseline issue): | |
| # - Try strict load for PI05Policy if the LeRobot version supports it. | |
| # - Verify tokenizer vocab size and special-token ids match PI05 text embedding table. | |
| # - Fail fast with a clear message if mismatch is detected (unless explicitly overridden). | |
| # | |
| # NEW in this hard-set patch: | |
| # - Per-sample MSE is exposed from success proxy. | |
| # - A "hard set" is defined as samples whose branch-wise MSE exceeds hard thresholds. | |
| # - Hard-set averages (MSE and fraction of samples) are reported alongside global metrics. | |
| # | |
| # NEW in this adapter patch: | |
| # - sigma_telepathy_adapter is applied at eval time (when telepathy is enabled) to gate | |
| # Telepathy residuals based on their magnitude and tau strength, optionally using | |
| # offline base_action_* if present in the shards. | |
| from __future__ import annotations | |
| import os | |
| import glob | |
| import json | |
| import argparse | |
| import importlib | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from dotenv import load_dotenv | |
| from accelerate import Accelerator | |
| from accelerate.utils import set_seed | |
| try: | |
| from huggingface_hub import snapshot_download | |
| except Exception: | |
| snapshot_download = None # type: ignore | |
| from vision_sigma_vla import TelepathyVisionModule, VisionConfig | |
| from language_sigma_vla import TelepathyLanguageModule, LanguageConfig | |
| from action_sigma_vla import TelepathyActionModule, ActionConfig | |
| from sigma_telepathy_adapter import SigmaTelepathyAdapter, SigmaTelepathyAdapterConfig | |
def ensure_sigma_artifacts_from_hf(
    repo_id: str,
    hf_token: Optional[str],
    local_cache_root: str,
) -> Dict[str, str]:
    """
    Fetch the Sigma evaluation artifacts from a Hugging Face repo into a local cache.

    Only two subtrees of the repo are requested:
        storage/sigma_pickplace/**
        storage/sigma_lora_out/**

    Args:
        repo_id: Hugging Face repo id to download from.
        hf_token: Access token forwarded to ``snapshot_download`` (may be None).
        local_cache_root: Directory under which the per-repo snapshot folder is
            created (the directory itself is created if missing).

    Returns:
        Dict with ``local_repo_dir``, ``shard_dir`` and ``telepathy_heads_path``.
        The last two are built by joining path components onto the snapshot
        directory; this function does not check that they exist.

    Raises:
        ImportError: if ``huggingface_hub`` could not be imported when the
            module was loaded (``snapshot_download`` is then None).
    """
    if snapshot_download is None:
        raise ImportError(
            "huggingface_hub is not available but auto-download was requested. "
            "Please `pip install huggingface_hub` or download artifacts manually."
        )

    os.makedirs(local_cache_root, exist_ok=True)

    # One snapshot folder per repo; "/" in the repo id is flattened to "__".
    snapshot_target = os.path.join(local_cache_root, repo_id.replace("/", "__"))
    wanted_patterns = [
        "storage/sigma_pickplace/**",
        "storage/sigma_lora_out/**",
    ]
    repo_root = snapshot_download(
        repo_id=repo_id,
        token=hf_token,
        local_dir=snapshot_target,
        local_dir_use_symlinks=False,
        allow_patterns=wanted_patterns,
    )

    storage_root = os.path.join(repo_root, "storage")
    return {
        "local_repo_dir": repo_root,
        "shard_dir": os.path.join(storage_root, "sigma_pickplace"),
        "telepathy_heads_path": os.path.join(
            storage_root, "sigma_lora_out", "sigma_telepathy_heads.pt"
        ),
    }
def load_pi05_policy(
    repo_id: str,
    hf_token: Optional[str],
    device: torch.device,
    strict_load: bool = True,
):
    """
    Load PI05Policy from LeRobot, move it to ``device`` and switch it to eval mode.

    Several import paths are tried so the loader works across LeRobot versions.
    When ``strict_load`` is set, ``from_pretrained(..., strict=True)`` is tried
    first. If that attempt raises, the loader falls back to a plain
    ``from_pretrained`` call and prints a ``[WARN]`` line describing why, so a
    checkpoint that only loads non-strictly never goes unnoticed.

    Args:
        repo_id: Policy repo id or path handed to ``PI05Policy.from_pretrained``.
        hf_token: Access token forwarded as ``token=`` (may be None).
        device: Device the loaded policy is moved to.
        strict_load: Whether to attempt a strict load before the plain one.

    Returns:
        The loaded policy, on ``device`` and in eval mode.

    Raises:
        ImportError: if PI05Policy cannot be imported from any known path.
        RuntimeError: if ``from_pretrained`` returned None.
        Exception: whatever the non-strict ``from_pretrained`` call raises.
    """
    candidate_paths = [
        ("lerobot.policies.pi05.modeling_pi05", "PI05Policy"),
        ("lerobot.policies.pi05", "PI05Policy"),
    ]

    policy_cls = None
    import_errors = []
    for mod_name, cls_name in candidate_paths:
        try:
            policy_cls = getattr(importlib.import_module(mod_name), cls_name)
            break
        except Exception as e:
            # Importing a heavy third-party package can fail in many ways;
            # keep every failure so the final error message is actionable.
            import_errors.append(f"{mod_name}.{cls_name}: {type(e).__name__}: {e}")

    if policy_cls is None:
        raise ImportError(
            "Failed to import PI05Policy from LeRobot. Tried:\n - "
            + "\n - ".join(import_errors)
        )

    policy = None
    tried = []

    if strict_load:
        try:
            policy = policy_cls.from_pretrained(repo_id, token=hf_token, strict=True)
            tried.append("from_pretrained(..., strict=True)")
        except TypeError as e:
            # Most likely this LeRobot version has no `strict` keyword.
            tried.append("strict=True not supported")
            print(
                f"[WARN] Strict PI05Policy load raised TypeError ({e}); "
                "strict= is probably unsupported here. Falling back to a non-strict load."
            )
        except Exception as e:
            tried.append(f"strict=True failed: {type(e).__name__}: {e}")
            print(
                f"[WARN] Strict PI05Policy load failed ({type(e).__name__}: {e}). "
                "Falling back to a non-strict load; checkpoint and model may not match exactly."
            )

    if policy is None:
        try:
            policy = policy_cls.from_pretrained(repo_id, token=hf_token)
            tried.append("from_pretrained(repo_id, token=...)")
        except TypeError:
            # Some versions only accept the repo id as a keyword argument.
            policy = policy_cls.from_pretrained(pretrained_name_or_path=repo_id, token=hf_token)
            tried.append("from_pretrained(pretrained_name_or_path=..., token=...)")

    if policy is None:
        raise RuntimeError("PI05Policy loading returned None. Tried: " + "; ".join(tried))

    policy = policy.to(device)
    policy.eval()
    return policy
def get_policy_tokenizer(
    policy,
    repo_id: str,
    hf_token: Optional[str],
    forced_tokenizer_id: str = "",
):
    """
    Robust tokenizer getter for PI05Policy.

    Resolution order:
        1. ``forced_tokenizer_id`` if given: loaded from disk with
           ``local_files_only=True`` when the path exists, otherwise from the hub.
        2. A tokenizer/processor object found by a bounded recursive search
           through the policy's attributes.
        3. A tokenizer loaded from a backbone name found on the policy's
           config-like attributes.

    IMPORTANT:
        - ``AutoTokenizer.from_pretrained(repo_id)`` is never called, because
          repo_id is a policy repo; it is only used in the error message.
        - ``transformers`` is imported lazily, only when a tokenizer actually
          has to be loaded by name, so a policy that already carries its
          tokenizer works without that import.
        - NOTE(review): loading passes ``trust_remote_code=True``, which lets the
          tokenizer repo execute its own code; only point this at trusted ids.

    Args:
        policy: Loaded policy object to inspect.
        repo_id: Policy repo id (for error reporting only).
        hf_token: Access token forwarded to hub loads (may be None).
        forced_tokenizer_id: Optional local path or hub id that overrides discovery.

    Returns:
        A tokenizer-like object. When it has no pad token, the EOS token is
        assigned as pad token.

    Raises:
        ValueError: if no tokenizer and no backbone name can be found.
    """

    def _load_by_name(name_or_path: str, **kwargs):
        # Deferred import: only needed on the load-by-name paths.
        from transformers import AutoTokenizer

        loaded = AutoTokenizer.from_pretrained(
            name_or_path, trust_remote_code=True, **kwargs
        )
        if loaded.pad_token is None:
            loaded.pad_token = loaded.eos_token
        return loaded

    if forced_tokenizer_id:
        if os.path.exists(forced_tokenizer_id):
            return _load_by_name(forced_tokenizer_id, local_files_only=True)
        return _load_by_name(forced_tokenizer_id, token=hf_token)

    def _recursive_find_tokenizer(obj, max_depth: int = 4):
        if obj is None or max_depth <= 0:
            return None

        for key in ["tokenizer", "processor", "text_tokenizer", "language_tokenizer"]:
            if hasattr(obj, key):
                v = getattr(obj, key)
                if v is None:
                    continue
                # A processor usually wraps the tokenizer; prefer the inner one.
                if key == "processor" and hasattr(v, "tokenizer") and v.tokenizer is not None:
                    return v.tokenizer
                if callable(v):
                    return v

        nested_names = [
            "paligemma_with_expert",
            "paligemma",
            "gemma_expert",
            "language_model",
            "text_model",
            "model",
            "policy",
        ]
        for name in nested_names:
            if hasattr(obj, name):
                found = _recursive_find_tokenizer(
                    getattr(obj, name), max_depth=max_depth - 1
                )
                if found is not None:
                    return found
        return None

    tok = _recursive_find_tokenizer(policy)
    if tok is not None:
        if getattr(tok, "pad_token", None) is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token
        return tok

    def _try_get_name(cfg_obj):
        if cfg_obj is None:
            return None
        for k in [
            "_name_or_path",
            "text_backbone_id",
            "text_model_id",
            "language_model_id",
            "processor_name_or_path",
            "tokenizer_name_or_path",
        ]:
            if hasattr(cfg_obj, k):
                v = getattr(cfg_obj, k)
                if isinstance(v, str) and v:
                    return v
        if hasattr(cfg_obj, "config"):
            c = getattr(cfg_obj, "config")
            if hasattr(c, "_name_or_path") and isinstance(c._name_or_path, str) and c._name_or_path:
                return c._name_or_path
        return None

    config_candidates = [
        getattr(policy, attr)
        for attr in ["config", "model", "paligemma_with_expert", "paligemma"]
        if hasattr(policy, attr)
    ]

    backbone_name = None
    for c in config_candidates:
        backbone_name = _try_get_name(c)
        if backbone_name:
            break

    if backbone_name:
        return _load_by_name(backbone_name, token=hf_token)

    raise ValueError(
        f"Cannot obtain tokenizer from PI05Policy for repo '{repo_id}'. "
        "Your lerobot PI05 port does not expose tokenizer/processor nor backbone name. "
        "Please pass --tokenizer_id explicitly."
    )
def get_policy_text_embedding_layer(policy):
    """
    Find the text embedding layer somewhere inside a PI05Policy.

    At every visited object the search tries, in order:
        1. ``get_input_embeddings()`` (errors from the call are ignored),
        2. an ``nn.Module`` stored under a known embedding attribute name,
        3. recursion into a fixed list of container attribute names.
    The recursion is depth-limited.

    Returns:
        The first embedding layer found.

    Raises:
        AttributeError: if nothing is found within the depth limit.
    """
    direct_attrs = ("embed_tokens", "embeddings", "token_embedding")
    container_attrs = (
        "model",
        "paligemma_with_expert",
        "paligemma",
        "language_model",
        "gemma_expert",
        "text_model",
        "policy",
    )

    def _search(node, budget: int = 6):
        if node is None or budget <= 0:
            return None

        if hasattr(node, "get_input_embeddings"):
            try:
                candidate = node.get_input_embeddings()
            except Exception:
                # Best effort: a failing accessor just means "look elsewhere".
                candidate = None
            if candidate is not None:
                return candidate

        for attr in direct_attrs:
            value = getattr(node, attr, None)
            if isinstance(value, nn.Module):
                return value

        for attr in container_attrs:
            if not hasattr(node, attr):
                continue
            hit = _search(getattr(node, attr), budget=budget - 1)
            if hit is not None:
                return hit

        return None

    layer = _search(policy)
    if layer is not None:
        return layer

    raise AttributeError(
        "Cannot locate PI05 text embedding layer via recursive search. "
        "Your PI05Policy likely changed internal naming. "
        "Please inspect policy.model.* to confirm embed_tokens location."
    )
def verify_tokenizer_embedding_compat(
    tokenizer,
    text_embed_layer: nn.Module,
    allow_mismatch: bool = False,
):
    """
    Check that a tokenizer fits the PI05 text embedding table.

    Two checks are made:
        - the tokenizer vocab size equals the number of embedding rows;
        - pad/eos/bos/unk token ids (when set) index inside the table.

    If either size cannot be determined, a warning is printed and the check is
    skipped. Each violation raises ValueError, unless ``allow_mismatch`` is
    set, in which case it is printed as a warning and checking continues.

    Args:
        tokenizer: Object exposing ``vocab_size`` or ``__len__`` and the
            ``*_token_id`` attributes.
        text_embed_layer: Embedding module (``nn.Embedding`` or anything with a
            ``weight`` whose first dimension is the vocab size).
        allow_mismatch: Downgrade violations from errors to warnings.

    Raises:
        ValueError: on a violation when ``allow_mismatch`` is False.
    """

    def _complain(message: str) -> None:
        if not allow_mismatch:
            raise ValueError(message)
        print(message.replace("[ERROR]", "[WARN]") + " Proceeding due to --allow_tokenizer_mismatch.")

    # Number of rows in the embedding table.
    if isinstance(text_embed_layer, nn.Embedding):
        emb_vocab = int(text_embed_layer.num_embeddings)
    elif hasattr(text_embed_layer, "weight") and text_embed_layer.weight is not None:
        emb_vocab = int(text_embed_layer.weight.size(0))
    else:
        emb_vocab = None

    # Tokenizer vocabulary size, preferring the explicit attribute.
    tok_vocab = getattr(tokenizer, "vocab_size", None)
    if tok_vocab is None:
        try:
            tok_vocab = len(tokenizer)
        except Exception:
            tok_vocab = None

    if emb_vocab is None or tok_vocab is None:
        print("[WARN] Cannot infer tokenizer/embedding vocab sizes. Skipping compatibility check.")
        return

    if emb_vocab != tok_vocab:
        _complain(
            f"[ERROR] Tokenizer vocab size ({tok_vocab}) != PI05 embedding table size ({emb_vocab}). "
            "This will corrupt text embeddings and invalidate baseline. "
            "Fix by passing --tokenizer_id matching the PI05 text backbone "
            "(e.g., the original openpi/PI05 tokenizer) or re-exporting policy with aligned vocab."
        )

    for name in ("pad_token_id", "eos_token_id", "bos_token_id", "unk_token_id"):
        tid = getattr(tokenizer, name, None)
        if tid is None:
            continue
        if 0 <= int(tid) < emb_vocab:
            continue
        _complain(
            f"[ERROR] Tokenizer {name}={tid} out of embedding range [0, {emb_vocab-1}]. "
            "Your tokenizer does not belong to this PI05 backbone."
        )
class TelepathyVLA(nn.Module):
    """
    Vision, language and action modules wired into a single forward pass.

    Data flow of ``forward_once``:
        1. vision (without telepathy factors) -> vision/state tokens;
        2. language(text, vision, state, previous memory) -> ``high_level_rep``,
           ``telepathy_factors`` (tau) and the next memory ``m_t``;
        3. unless telepathy is disabled, vision is re-run with the scaled tau;
        4. action(high_level_rep, tau, state tokens) -> action outputs.

    Projection layers that bring text/state tokens to the vision token width
    are created lazily on the first call, because the widths are only known
    once the vision module has produced tokens. Any ``nn.Linear`` created that
    way starts from PyTorch's default initialisation; nothing in this class
    loads weights into it.
    """

    def __init__(
        self,
        v_cfg: VisionConfig,
        l_cfg: LanguageConfig,
        a_cfg: ActionConfig,
        disable_telepathy: bool = False,
    ):
        super().__init__()
        self.vision = TelepathyVisionModule(v_cfg)
        self.language = TelepathyLanguageModule(l_cfg)
        self.action = TelepathyActionModule(a_cfg)
        self.disable_telepathy = disable_telepathy

        # Memory returned by the language module, fed back in on the next call.
        self.register_buffer("_m_prev", None, persistent=False)

        # Projections are built on the first forward_once call.
        self._proj_inited = False
        self.text_proj: Optional[nn.Module] = None
        self.vision_proj: Optional[nn.Module] = None
        self.state_proj: Optional[nn.Module] = None

    def reset_memory(self):
        """Forget the language-module memory carried over from previous calls."""
        self._m_prev = None

    @staticmethod
    def _pad_or_trim_last_dim(x: torch.Tensor, width: int) -> torch.Tensor:
        """Zero-pad or truncate the last dimension of ``x`` to ``width``."""
        cur = x.size(-1)
        if cur == width:
            return x
        if cur < width:
            return F.pad(x, (0, width - cur))
        return x[..., :width]

    @staticmethod
    def _match_time_steps(tau: torch.Tensor, steps: int) -> torch.Tensor:
        """
        Make a [B, T, D] tau have exactly ``steps`` time steps.

        A single step is broadcast, a longer sequence is truncated, and a
        shorter multi-step sequence is zero-padded at the end (zeros are also
        what is used when there is no tau at all).
        """
        cur = tau.size(1)
        if cur == steps:
            return tau
        if cur == 1:
            return tau.expand(-1, steps, -1)
        if cur > steps:
            return tau[:, :steps, :]
        # F.pad pads from the last dim backwards: (D_left, D_right, T_left, T_right).
        return F.pad(tau, (0, 0, 0, steps - cur))

    def _expected_condition_width(self) -> Optional[int]:
        """Input width of ``action.action_condition_proj``, if it can be read."""
        acp = getattr(self.action, "action_condition_proj", None)
        if acp is None:
            return None
        if hasattr(acp, "in_features"):
            return int(acp.in_features)
        if hasattr(acp, "net") and len(acp.net) > 0 and hasattr(acp.net[0], "in_features"):
            return int(acp.net[0].in_features)
        return None

    def forward_once(
        self,
        vis_obs: torch.Tensor,
        robot_state: torch.Tensor,
        text_tokens: torch.Tensor,
        depth_obs: Optional[torch.Tensor] = None,
        audio_obs: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        return_intermediate: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Run one vision -> language -> vision -> action step.

        Args:
            vis_obs, robot_state, depth_obs, audio_obs: Passed through to the
                vision module unchanged.
            text_tokens: Text embeddings; their last dimension is projected to
                the vision token width before the language module sees them.
            attn_mask: Passed through to the language module.
            return_intermediate: Passed through to all three modules.

        Returns:
            The merged output dicts of the vision, language and action modules
            (later modules overwrite earlier keys of the same name).

        Raises:
            KeyError: if the language output has no ``high_level_rep``.
        """
        # First vision pass, without telepathy feedback.
        vis0 = self.vision(
            vis_obs=vis_obs,
            robot_state=robot_state,
            depth_obs=depth_obs,
            audio_obs=audio_obs,
            telepathy_factors=None,
            return_intermediate=return_intermediate,
        )

        vis_d = vis0["vision_tokens"].size(-1)
        state_d = vis0["state_tokens"].size(-1)
        target_d = vis_d  # everything is aligned to the vision token width

        if not self._proj_inited:
            text_d = text_tokens.size(-1)
            self.text_proj = (
                nn.Linear(text_d, target_d, bias=False) if text_d != target_d else nn.Identity()
            )
            self.vision_proj = nn.Identity() if vis_d == target_d else nn.Linear(vis_d, target_d, bias=False)
            self.state_proj = nn.Identity() if state_d == target_d else nn.Linear(state_d, target_d, bias=False)

            # Projections follow the device/dtype of the text embeddings.
            self.text_proj = self.text_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
            self.vision_proj = self.vision_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
            self.state_proj = self.state_proj.to(device=text_tokens.device, dtype=text_tokens.dtype)
            self._proj_inited = True

        assert self.text_proj is not None and self.vision_proj is not None and self.state_proj is not None

        text_tokens = self.text_proj(text_tokens)
        vision_tokens = self.vision_proj(vis0["vision_tokens"])
        state_tokens = self.state_proj(vis0["state_tokens"])

        lang_out = self.language(
            text_tokens=text_tokens,
            vision_tokens=vision_tokens,
            state_tokens=state_tokens,
            m_prev=self._m_prev,
            attn_mask=attn_mask,
            return_intermediate=return_intermediate,
        )

        raw_tau = lang_out.get("telepathy_factors", None)
        self._m_prev = lang_out.get("m_t", None)

        # `telepathy_scale` is optional and may be set on the instance from outside.
        telepathy_scale = float(getattr(self, "telepathy_scale", 1.0))

        if self.disable_telepathy:
            tau = None
            vis_out = vis0
        else:
            tau = raw_tau
            if tau is not None:
                tau = tau * telepathy_scale
            # Second vision pass, now conditioned on the telepathy factors.
            vis_out = self.vision(
                vis_obs=vis_obs,
                robot_state=robot_state,
                depth_obs=depth_obs,
                audio_obs=audio_obs,
                telepathy_factors=tau,
                return_intermediate=return_intermediate,
            )

        high_level_rep = lang_out.get("high_level_rep", None)
        if high_level_rep is None:
            raise KeyError("language output missing 'high_level_rep'.")

        if high_level_rep.dim() == 2:
            high_level_rep = high_level_rep.unsqueeze(1)

        if tau is None:
            # No telepathy signal: feed zeros of the language module's tau width.
            B, L, _ = high_level_rep.shape
            tau_dim = getattr(self.language, "tau_dim", 128)
            tau = torch.zeros(B, L, tau_dim, device=high_level_rep.device, dtype=high_level_rep.dtype)
        else:
            if tau.dim() == 2:
                tau = tau.unsqueeze(1)
            tau = self._match_time_steps(tau, high_level_rep.size(1))

        # Make concat(high_level_rep, tau) match the action conditioner's input width.
        expected_in = self._expected_condition_width()
        if expected_in is not None:
            target_tau = expected_in - high_level_rep.size(-1)
            if target_tau > 0:
                tau = self._pad_or_trim_last_dim(tau, target_tau)

        state_for_action = vis_out["state_tokens"]
        if state_for_action.dim() == 2:
            state_for_action = state_for_action.unsqueeze(1)
        elif state_for_action.dim() > 3:
            # reshape (not view) so non-contiguous state tensors are accepted too.
            state_for_action = state_for_action.reshape(
                state_for_action.size(0), -1, state_for_action.size(-1)
            )

        state_for_action = self._pad_or_trim_last_dim(state_for_action, high_level_rep.size(-1))

        act_out = self.action(
            high_level_rep=high_level_rep,
            telepathy_factors=tau,
            state_tokens=state_for_action,
            return_intermediate=return_intermediate,
        )

        out: Dict[str, torch.Tensor] = {}
        out.update(vis_out)
        out.update(lang_out)
        out.update(act_out)
        return out
class SigmaShardDataset(Dataset):
    """
    Map-style dataset over .pt shards produced by dataset_preprocess_sigma_vla.py.

    A shard is either a list/tuple of sample dicts, or a dict holding such a
    list under one of a few known keys. Shards are discovered recursively
    under ``shard_dir``. Every shard is loaded once at construction time to
    count its samples, then loaded again lazily (and kept in memory) the first
    time one of its samples is requested.

    NOTE(review): shards are read with ``torch.load``; depending on the
    installed torch version this may unpickle arbitrary objects, so only point
    this at trusted shard directories.
    """

    def __init__(self, shard_dir: str):
        super().__init__()
        if not os.path.isdir(shard_dir):
            raise FileNotFoundError(
                f"shard_dir does not exist: {shard_dir}. Double-check the path."
            )

        # The patterns overlap; a set removes the duplicates.
        discovered = set()
        for pattern in (
            os.path.join(shard_dir, "sigma_vla_shard_*.pt"),
            os.path.join(shard_dir, "*.pt"),
            os.path.join(shard_dir, "**", "*.pt"),
        ):
            discovered.update(glob.glob(pattern, recursive=True))
        self.shard_paths = sorted(discovered)

        if not self.shard_paths:
            raise FileNotFoundError(
                f"No .pt shards found under {shard_dir}. "
                "Your HF cache is empty or shards are not tracked by LFS."
            )
        print(f"[INFO] Found {len(self.shard_paths)} shard files. Example: {self.shard_paths[:3]}")

        # Global index -> (shard index, index inside that shard).
        self.index_map: List[Tuple[int, int]] = []
        self._shard_cache: Dict[int, List[Dict[str, Any]]] = {}
        for shard_id, shard_path in enumerate(self.shard_paths):
            samples = self._normalize_shard(
                torch.load(shard_path, map_location="cpu"), shard_path
            )
            self.index_map.extend(
                (shard_id, local_id) for local_id in range(len(samples))
            )
        self.total = len(self.index_map)

    def __len__(self):
        """Total number of samples across all shards."""
        return self.total

    def _normalize_shard(self, shard_obj: Any, path: str) -> List[Dict[str, Any]]:
        """Return the shard's samples as a list, or raise TypeError for unknown layouts."""
        if isinstance(shard_obj, (list, tuple)):
            return list(shard_obj)
        if isinstance(shard_obj, dict):
            for container_key in ("samples", "data", "items"):
                payload = shard_obj.get(container_key)
                if isinstance(payload, (list, tuple)):
                    return list(payload)
        raise TypeError(
            f"Unsupported shard format in {path}. "
            f"Expected list/tuple of samples or dict{{samples/data}}. "
            f"Got type: {type(shard_obj).__name__}"
        )

    def _get_shard(self, sid: int) -> List[Dict[str, Any]]:
        """Load shard ``sid`` on first use and keep it cached afterwards."""
        cached = self._shard_cache.get(sid)
        if cached is None:
            shard_path = self.shard_paths[sid]
            cached = self._normalize_shard(
                torch.load(shard_path, map_location="cpu"), shard_path
            )
            self._shard_cache[sid] = cached
        return cached

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the sample at global index ``idx``."""
        shard_id, local_id = self.index_map[idx]
        return self._get_shard(shard_id)[local_id]
def collate_sigma(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Robust collate for Sigma shards.

    The schema is inferred from the FIRST sample of the batch and then applied
    to every sample; several alternative key names are accepted per field.

    Fields:
        - vision: stacked to float; a 5-D stack keeps only index -1 of dim 1.
        - depth / audio: optional, stacked to float, else None.
        - state: stacked to float.
        - text: list of strings (missing entries become "").
        - action ground truth: either ``sample["action"]`` as a dict with
          vector/chunk/trajectory branches, ``sample["action"]`` as a single
          tensor reused for all three branches, or three top-level keys.
        - base actions (optional, for the adapter): top-level
          ``base_action_*`` keys, otherwise keys inside the action dict. Inside
          the action dict the short aliases ``base_vec`` / ``base_chunk`` /
          ``base_traj`` are honoured even when no canonical key is present.

    Returns:
        Dict with ``vis_obs``, ``depth_obs``, ``audio_obs``, ``robot_state``,
        ``texts``, ``gt_action_vector``, ``gt_action_chunk``,
        ``gt_action_trajectory`` and, only when available, ``base_action_*``.

    Raises:
        KeyError: if a required field is missing under every accepted name.
        IndexError: if ``batch_list`` is empty.
    """
    s0 = batch_list[0]

    def first_present(mapping, candidates):
        """First candidate key contained in ``mapping``, or None."""
        return next((k for k in candidates if k in mapping), None)

    def pick_key(sample: Dict[str, Any], candidates: List[str], field_name: str):
        key = first_present(sample, candidates)
        if key is None:
            raise KeyError(
                f"Shard sample missing required field '{field_name}'. "
                f"Tried keys: {candidates}. "
                f"Available keys: {list(sample.keys())}"
            )
        return key

    def pick_action_key(d, candidates, name):
        key = first_present(d, candidates)
        if key is None:
            raise KeyError(
                f"Action dict missing '{name}'. Tried {candidates}. "
                f"Available action keys: {list(d.keys())}"
            )
        return key

    def stack_top(key):
        return torch.stack([b[key] for b in batch_list], dim=0).float()

    def stack_action(key):
        return torch.stack([b["action"][key] for b in batch_list], dim=0).float()

    # ---- vision -----------------------------------------------------------
    if "vision" in s0:
        vis_k = "vision"
    else:
        vis_k = pick_key(s0, ["vis_obs", "rgb_obs", "image", "images", "obs"], "vision/vis_obs")
    vis_obs = stack_top(vis_k)
    if vis_obs.dim() == 5:
        # Collapse dim 1 by keeping its last entry.
        vis_obs = vis_obs[:, -1]

    # ---- optional depth / audio --------------------------------------------
    depth_k = first_present(s0, ["depth", "depth_obs", "depths"])
    depth_obs = stack_top(depth_k) if depth_k is not None else None

    audio_k = first_present(s0, ["audio", "audio_obs", "audios"])
    audio_obs = stack_top(audio_k) if audio_k is not None else None

    # ---- robot state --------------------------------------------------------
    if "state" in s0:
        state_k = "state"
    else:
        state_k = pick_key(s0, ["robot_state", "proprio", "proprio_obs"], "state/robot_state")
    robot_state = stack_top(state_k)

    # ---- text ---------------------------------------------------------------
    if "text" in s0:
        text_k = "text"
    else:
        text_k = pick_key(s0, ["text", "prompt", "instruction"], "text")
    texts = [b.get(text_k, "") for b in batch_list]

    # ---- action ground truth ------------------------------------------------
    action_is_dict = "action" in s0 and isinstance(s0["action"], dict)
    if action_is_dict:
        a0 = s0["action"]
        vec_k = pick_action_key(a0, ["gt_action_vector", "action_vector", "vector", "vec"], "gt_action_vector")
        chk_k = pick_action_key(a0, ["gt_action_chunk", "action_chunk", "chunk", "chk"], "gt_action_chunk")
        trj_k = pick_action_key(a0, ["gt_action_trajectory", "action_trajectory", "trajectory", "traj"], "gt_action_trajectory")
        gt_action_vector = stack_action(vec_k)
        gt_action_chunk = stack_action(chk_k)
        gt_action_trajectory = stack_action(trj_k)
    elif "action" in s0:
        # A single tensor serves as ground truth for all three branches.
        act = stack_top("action")
        gt_action_vector = act
        gt_action_chunk = act
        gt_action_trajectory = act
    else:
        gt_vec_k = pick_key(s0, ["gt_action_vector", "action_vector", "gt_vec"], "gt_action_vector")
        gt_chk_k = pick_key(s0, ["gt_action_chunk", "action_chunk", "gt_chunk"], "gt_action_chunk")
        gt_trj_k = pick_key(s0, ["gt_action_trajectory", "action_trajectory", "gt_traj"], "gt_action_trajectory")
        gt_action_vector = stack_top(gt_vec_k)
        gt_action_chunk = stack_top(gt_chk_k)
        gt_action_trajectory = stack_top(gt_trj_k)

    # ---- optional offline base actions for the adapter -----------------------
    # Output name -> accepted key names inside the action dict.
    base_in_action = {
        "base_action_vector": ["base_action_vector", "base_vec"],
        "base_action_chunk": ["base_action_chunk", "base_chunk"],
        "base_action_trajectory": ["base_action_trajectory", "base_traj"],
    }
    base_actions: Dict[str, torch.Tensor] = {}
    if any(name in s0 for name in base_in_action):
        # Top-level keys take precedence over anything inside the action dict.
        for name in base_in_action:
            if name in s0:
                base_actions[name] = stack_top(name)
    elif action_is_dict:
        for name, candidates in base_in_action.items():
            key = first_present(s0["action"], candidates)
            if key is not None:
                base_actions[name] = stack_action(key)

    batch: Dict[str, Any] = {
        "vis_obs": vis_obs,
        "depth_obs": depth_obs,
        "audio_obs": audio_obs,
        "robot_state": robot_state,
        "texts": texts,
        "gt_action_vector": gt_action_vector,
        "gt_action_chunk": gt_action_chunk,
        "gt_action_trajectory": gt_action_trajectory,
    }
    # Base actions are only included when present in the shards.
    batch.update(base_actions)
    return batch
| def _align_target(pred_t: torch.Tensor, gt_t: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Align GT to prediction for MSE: | |
| - handle 2D vs 3D mismatches by collapsing or expanding time dimension. | |
| - then align last-dim by pad/trim. | |
| """ | |
| if gt_t.dim() == 3 and pred_t.dim() == 2: | |
| gt_t = gt_t[:, -1, :] | |
| if pred_t.dim() == 3 and gt_t.dim() == 2: | |
| gt_t = gt_t.unsqueeze(1) | |
| if gt_t.size(1) != pred_t.size(1): | |
| gt_t = gt_t.expand(-1, pred_t.size(1), -1) | |
| if pred_t.dim() == 3 and gt_t.dim() == 3: | |
| Tp = pred_t.size(1) | |
| Tg = gt_t.size(1) | |
| if Tg < Tp: | |
| pad = torch.zeros( | |
| gt_t.size(0), Tp - Tg, gt_t.size(2), | |
| device=gt_t.device, dtype=gt_t.dtype | |
| ) | |
| gt_t = torch.cat([gt_t, pad], dim=1) | |
| elif Tg > Tp: | |
| gt_t = gt_t[:, :Tp, :] | |
| pd = pred_t.size(-1) | |
| gd = gt_t.size(-1) | |
| if gd < pd: | |
| gt_t = F.pad(gt_t, (0, pd - gd)) | |
| elif gd > pd: | |
| gt_t = gt_t[..., :pd] | |
| return gt_t | |
| def _pred_action(pred: Dict[str, torch.Tensor], key: str) -> torch.Tensor: | |
| if key in pred: | |
| return pred[key] | |
| if "action" in pred: | |
| return pred["action"] | |
| raise KeyError( | |
| f"Pred dict missing action key '{key}' and fallback 'action'. " | |
| f"Available pred keys: {list(pred.keys())}" | |
| ) | |
def compute_branch_mse(pred: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, float]:
    """
    Batch-level MSE of each action branch against its ground truth.

    Predictions are looked up with ``_pred_action`` (falling back to the
    generic "action" output) and targets are aligned with ``_align_target``.
    All targets are moved to the device of the vector prediction.

    Returns:
        ``{"mse_vector": ..., "mse_chunk": ..., "mse_traj": ...}`` as Python floats.
    """
    pred_keys = ("action_vector", "action_chunk", "action_trajectory")
    gt_keys = ("gt_action_vector", "gt_action_chunk", "gt_action_trajectory")
    out_keys = ("mse_vector", "mse_chunk", "mse_traj")

    # Resolve every prediction first, then align every target, then score.
    predictions = [_pred_action(pred, name) for name in pred_keys]
    device = predictions[0].device
    targets = [
        _align_target(p, batch[name].to(device))
        for p, name in zip(predictions, gt_keys)
    ]
    scores = [F.mse_loss(p, t).item() for p, t in zip(predictions, targets)]
    return dict(zip(out_keys, scores))
def compute_success_proxy(
    pred: Dict[str, torch.Tensor],
    batch: Dict[str, Any],
    thr_vec: float,
    thr_chk: float,
    thr_trj: float,
) -> Tuple[int, int, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Threshold-based success proxy with per-sample errors.

    A sample counts as a success when its MSE is strictly below the given
    threshold on all three branches (vector, chunk, trajectory).

    Returns:
        ``(num_success, num_total, mse_vec_per_sample, mse_chk_per_sample,
        mse_trj_per_sample)``; each per-sample MSE is the squared error
        averaged over every non-batch dimension.
    """
    pred_keys = ("action_vector", "action_chunk", "action_trajectory")
    gt_keys = ("gt_action_vector", "gt_action_chunk", "gt_action_trajectory")

    predictions = [_pred_action(pred, name) for name in pred_keys]
    device = predictions[0].device
    targets = [
        _align_target(p, batch[name].to(device))
        for p, name in zip(predictions, gt_keys)
    ]

    # Average the squared error over everything except the batch dimension.
    mse_vec_s, mse_chk_s, mse_trj_s = (
        (p - t).pow(2).mean(dim=list(range(1, p.dim())))
        for p, t in zip(predictions, targets)
    )

    success_mask = (mse_vec_s < thr_vec) & (mse_chk_s < thr_chk) & (mse_trj_s < thr_trj)
    num_success = int(success_mask.sum().item())
    num_total = int(success_mask.numel())
    return num_success, num_total, mse_vec_s, mse_chk_s, mse_trj_s
def compute_telepathy_stability(pred: Dict[str, torch.Tensor]) -> float:
    """
    Mean squared magnitude of the telepathy factors in ``pred``.

    Returns NaN when ``pred`` has no "telepathy_factors" entry (or it is None).
    """
    factors = pred.get("telepathy_factors")
    if factors is None:
        return float("nan")
    mean_square = factors.pow(2).mean()
    return float(mean_square.item())
def cosine_alignment(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Mean cosine similarity between two batches of feature vectors.

    Inputs may be [B, D] or [B, T, D]; a 3-D input is mean-pooled over dim 1
    first. When the feature widths differ, both sides are cropped to the
    leading min(Da, Db) features before comparison. Returns NaN if either
    pooled tensor is empty.
    """
    lhs = a.mean(dim=1) if a.dim() == 3 else a
    rhs = b.mean(dim=1) if b.dim() == 3 else b

    if lhs.numel() == 0 or rhs.numel() == 0:
        return float("nan")

    width_l = lhs.size(-1)
    width_r = rhs.size(-1)
    if width_l != width_r:
        shared = min(width_l, width_r)
        lhs, rhs = lhs[..., :shared], rhs[..., :shared]

    # Unit-normalise so the row-wise dot product is the cosine.
    unit_l = F.normalize(lhs, dim=-1)
    unit_r = F.normalize(rhs, dim=-1)
    per_row = torch.sum(unit_l * unit_r, dim=-1)
    return float(per_row.mean().item())
def build_text_tokens_from_policy(
    tokenizer,
    text_embed_layer: nn.Module,
    texts: List[str],
    device: torch.device,
    target_dtype: torch.dtype,
    max_text_len: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Tokenize prompts and map to embeddings using PI05 internal embedding layer.

    Args:
        tokenizer: callable invoked as ``tokenizer(texts, padding=..., ...)``;
            its result must expose ``input_ids`` either as an attribute or as
            a mapping key.
        text_embed_layer: module applied to the token ids to produce embeddings.
        texts: prompts to encode.
        device: device the ids and mask are moved to before embedding.
        target_dtype: dtype the embeddings are cast to.
        max_text_len: when > 0, truncate to this many tokens; otherwise
            truncation is disabled.

    Returns:
        (text_tokens, attn_mask). If the tokenizer output carries no attention
        mask, an all-ones mask shaped like ``input_ids`` is used.
    """
    tok_kwargs = {"padding": True, "truncation": False, "return_tensors": "pt"}
    if max_text_len and max_text_len > 0:
        tok_kwargs.update(truncation=True, max_length=max_text_len)
    tok = tokenizer(texts, **tok_kwargs)

    if hasattr(tok, "input_ids"):
        input_ids = tok.input_ids
        # Tolerate a missing mask here too, matching the mapping branch below,
        # instead of raising on attribute access.
        attn_mask = getattr(tok, "attention_mask", None)
    else:
        input_ids = tok["input_ids"]
        attn_mask = tok.get("attention_mask", None)
    if attn_mask is None:
        attn_mask = torch.ones_like(input_ids)

    input_ids = input_ids.to(device)
    attn_mask = attn_mask.to(device)
    text_tokens = text_embed_layer(input_ids).to(dtype=target_dtype)
    return text_tokens, attn_mask
def _build_eval_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the offline Sigma evaluator."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--sigma_env", type=str, default="sigma.env")
    parser.add_argument("--shard_dir", type=str, default="")
    parser.add_argument("--output_dir", type=str, default="./sigma_eval_out")
    parser.add_argument(
        "--base_model_id",
        type=str,
        required=True,
        help="LeRobot/OpenPI policy repo, e.g., lerobot/pi05_base or your Sigma policy repo.",
    )
    parser.add_argument(
        "--telepathy_heads_path",
        type=str,
        default="",
        help="Path to sigma_telepathy_heads.pt. If empty, auto-fetch may fill it.",
    )
    parser.add_argument(
        "--disable_telepathy",
        action="store_true",
        help="Disable telepathy injection (control run).",
    )
    parser.add_argument(
        "--tokenizer_id",
        type=str,
        default="",
        help="Explicit HF tokenizer id OR local tokenizer folder path.",
    )
    parser.add_argument("--max_text_len", type=int, default=0)
    parser.add_argument(
        "--artifacts_repo_id",
        type=str,
        default="",
        help="HF repo containing storage/sigma_pickplace and storage/sigma_lora_out.",
    )
    parser.add_argument(
        "--hf_cache_root",
        type=str,
        default="/workspace/.hf_sigma_cache",
    )
    parser.add_argument("--load_in_4bit", action="store_true")
    parser.add_argument("--dtype", type=str, default="bf16")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument("--max_batches", type=int, default=-1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="Shuffle dataset order to enable different random subsets per seed.",
    )
    parser.add_argument(
        "--telepathy_scale",
        type=float,
        default=1.0,
        help="Multiply telepathy_factors (tau) to control injection strength.",
    )
    parser.add_argument("--succ_thr_vec", type=float, default=0.05)
    parser.add_argument("--succ_thr_chk", type=float, default=0.10)
    parser.add_argument("--succ_thr_trj", type=float, default=0.10)
    # Hard-set thresholds: if <=0, they default to 2x the success thresholds.
    parser.add_argument(
        "--hard_thr_vec",
        type=float,
        default=-1.0,
        help="Per-sample MSE threshold for the 'hard' set on vector branch; <=0 means 2x succ_thr_vec.",
    )
    parser.add_argument(
        "--hard_thr_chk",
        type=float,
        default=-1.0,
        help="Per-sample MSE threshold for the 'hard' set on chunk branch; <=0 means 2x succ_thr_chk.",
    )
    parser.add_argument(
        "--hard_thr_trj",
        type=float,
        default=-1.0,
        help="Per-sample MSE threshold for the 'hard' set on trajectory branch; <=0 means 2x succ_thr_trj.",
    )
    parser.add_argument(
        "--strict_pi05_load",
        action="store_true",
        help="Try strict PI05Policy loading if supported by LeRobot.",
    )
    parser.add_argument(
        "--allow_tokenizer_mismatch",
        action="store_true",
        help="Do not fail on tokenizer/embedding mismatch (NOT recommended for baseline).",
    )
    # Simple flag to enable/disable the adapter without touching telepathy itself.
    parser.add_argument(
        "--use_telepathy_adapter",
        action="store_true",
        help="If set and telepathy is enabled, apply sigma_telepathy_adapter to actions in eval.",
    )
    return parser


def _mean_or_nan(total: float, count: int) -> float:
    """Return total / count, or NaN when nothing was accumulated."""
    return total / count if count > 0 else float("nan")


def main():
    """
    Run the offline evaluation and write ``sigma_eval_report.json``.

    Steps: parse CLI arguments, resolve shard and head artifacts (optionally
    fetched through ``ensure_sigma_artifacts_from_hf``), load the PI05 policy
    to obtain its tokenizer and text embedding layer, load the telepathy heads
    into ``TelepathyVLA``, iterate over the shard dataset accumulating
    per-branch MSE, telepathy-factor energy, semantic/text alignment, hard-set
    statistics and the success proxy, then dump a JSON report on the main
    process.

    Raises:
        FileNotFoundError: if the shard directory or the heads checkpoint
            cannot be found locally after the optional auto-download.
    """
    import math  # local import keeps this block self-contained

    args = _build_eval_arg_parser().parse_args()
    if os.path.exists(args.sigma_env):
        load_dotenv(args.sigma_env)
    hf_token = os.getenv("HF_TOKEN", None)
    accelerator = Accelerator(mixed_precision=args.dtype if args.dtype != "fp32" else "no")
    set_seed(args.seed)
    device = accelerator.device
    if args.load_in_4bit:
        print("[WARN] --load_in_4bit is ignored for PI05Policy evaluator.")

    # ---- Artifact resolution -------------------------------------------------
    artifacts_repo = args.artifacts_repo_id.strip()
    if not artifacts_repo and args.base_model_id.startswith("Veltraxor/"):
        artifacts_repo = args.base_model_id
    need_shards = (not args.shard_dir) or (not os.path.isdir(args.shard_dir))
    need_heads = (not args.telepathy_heads_path) or (not os.path.isfile(args.telepathy_heads_path))
    if artifacts_repo and (need_shards or need_heads):
        paths = ensure_sigma_artifacts_from_hf(
            repo_id=artifacts_repo,
            hf_token=hf_token,
            local_cache_root=args.hf_cache_root,
        )
        if need_shards:
            args.shard_dir = paths["shard_dir"]
            print(f"[INFO] Using cached shard_dir: {args.shard_dir}")
        if need_heads:
            args.telepathy_heads_path = paths["telepathy_heads_path"]
            print(f"[INFO] Using cached telepathy_heads_path: {args.telepathy_heads_path}")
    if not args.shard_dir or not os.path.isdir(args.shard_dir):
        raise FileNotFoundError(
            f"shard_dir not found locally: {args.shard_dir}. "
            "Either provide a valid local path or an artifacts_repo_id for auto-download."
        )
    if not args.telepathy_heads_path or not os.path.isfile(args.telepathy_heads_path):
        raise FileNotFoundError(
            f"telepathy_heads_path not found locally: {args.telepathy_heads_path}. "
            "Either provide a valid local path or store it under storage/sigma_lora_out/ "
            "in artifacts_repo_id for auto-download."
        )

    # ---- Policy backbone (tokenizer + text embeddings only) ------------------
    policy = load_pi05_policy(
        args.base_model_id,
        hf_token,
        device=device,
        strict_load=args.strict_pi05_load,
    )
    tokenizer = get_policy_tokenizer(
        policy,
        args.base_model_id,
        hf_token,
        forced_tokenizer_id=args.tokenizer_id,
    )
    text_embed_layer = get_policy_text_embedding_layer(policy)
    verify_tokenizer_embedding_compat(
        tokenizer=tokenizer,
        text_embed_layer=text_embed_layer,
        allow_mismatch=args.allow_tokenizer_mismatch,
    )

    # ---- Telepathy model + adapter -------------------------------------------
    v_cfg = VisionConfig()
    l_cfg = LanguageConfig()
    a_cfg = ActionConfig()
    telepathy_vla = TelepathyVLA(v_cfg, l_cfg, a_cfg, disable_telepathy=args.disable_telepathy)
    telepathy_vla.telepathy_scale = args.telepathy_scale
    # Instantiate Telepathy adapter (used only when telepathy is enabled and flag is set).
    # NOTE(review): no checkpoint is loaded into the adapter in this function, so
    # it runs with whatever weights its constructor produces - confirm intended.
    adapter_cfg = SigmaTelepathyAdapterConfig()
    telepathy_adapter = SigmaTelepathyAdapter(adapter_cfg).to(device)
    # Evaluation must not run train-mode layers (dropout, batch-norm updates).
    telepathy_adapter.eval()

    if accelerator.is_main_process:
        file_size_mb = os.path.getsize(args.telepathy_heads_path) / (1024 * 1024)
        print(f"[CHECK-A] disable_telepathy={args.disable_telepathy}")
        print(f"[CHECK-A] telepathy_heads_path={args.telepathy_heads_path} size={file_size_mb:.2f}MB")

    # NOTE(security): torch.load unpickles the file, so only point this at
    # trusted checkpoints. If the file is a plain tensor state dict, passing
    # weights_only=True would be safer - verify against the checkpoint format.
    sd = torch.load(args.telepathy_heads_path, map_location="cpu")
    tensor_list = [v.detach().float().reshape(-1) for v in sd.values() if torch.is_tensor(v)]
    if accelerator.is_main_process and len(tensor_list) > 0:
        # Cap each tensor so the sanity statistics stay cheap on large heads.
        capped = [t[:100000] for t in tensor_list]
        flat = torch.cat(capped, dim=0)
        rms = torch.sqrt((flat ** 2).mean()).item()
        print(f"[CHECK-A] heads_tensors={len(tensor_list)} mean={flat.mean().item():.6f} std={flat.std().item():.6f} rms={rms:.6f}")
    missing, unexpected = telepathy_vla.load_state_dict(sd, strict=False)
    if accelerator.is_main_process:
        if len(missing) > 0 or len(unexpected) > 0:
            print(f"[CHECK-A] loaded with strict=False. Missing={len(missing)} Unexpected={len(unexpected)}")
            print(f"[CHECK-A] Missing keys (first 20): {missing[:20]}")
            print(f"[CHECK-A] Unexpected keys (first 20): {unexpected[:20]}")
        else:
            print("[CHECK-A] heads fully matched (no missing/unexpected).")
    telepathy_vla.eval()

    # ---- Data ----------------------------------------------------------------
    ds = SigmaShardDataset(args.shard_dir)
    dl = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=args.shuffle,
        num_workers=args.num_workers,
        collate_fn=collate_sigma,
        drop_last=False,
        pin_memory=torch.cuda.is_available(),
    )
    telepathy_vla, dl = accelerator.prepare(telepathy_vla, dl)
    # A distributed wrapper does not expose custom methods/attributes such as
    # reset_memory, forward_once or vision, so resolve the underlying module
    # once and use it for every custom call below.
    model_ref = telepathy_vla.module if hasattr(telepathy_vla, "module") else telepathy_vla
    target_dtype = next(telepathy_vla.parameters()).dtype

    # Input width expected by the state encoder; loop-invariant. None means
    # "unknown", in which case the robot state is passed through unchanged.
    try:
        expected_state_dim = model_ref.vision.state_encoder.mlp[0].in_features
    except Exception:
        expected_state_dim = None

    # ---- Accumulators --------------------------------------------------------
    sum_mse_vec = 0.0
    sum_mse_chk = 0.0
    sum_mse_trj = 0.0
    sum_tau_l2 = 0.0
    sum_sem_align = 0.0
    n_tau_batches = 0  # batches that produced a finite tau_l2
    n_sem_batches = 0  # batches that produced a finite sem_align
    total_success = 0
    total_scored = 0
    # Hard-set aggregators
    hard_thr_vec = args.hard_thr_vec if args.hard_thr_vec > 0.0 else 2.0 * args.succ_thr_vec
    hard_thr_chk = args.hard_thr_chk if args.hard_thr_chk > 0.0 else 2.0 * args.succ_thr_chk
    hard_thr_trj = args.hard_thr_trj if args.hard_thr_trj > 0.0 else 2.0 * args.succ_thr_trj
    sum_hard_mse_vec = 0.0
    sum_hard_mse_chk = 0.0
    sum_hard_mse_trj = 0.0
    total_hard_samples = 0
    n_batches = 0
    n_samples = 0
    os.makedirs(args.output_dir, exist_ok=True)

    # ---- Evaluation loop -----------------------------------------------------
    # No gradients are needed for evaluation; disabling autograd avoids
    # building graphs and holding activations for every forward pass.
    with torch.no_grad():
        for bidx, batch in enumerate(dl):
            if args.max_batches > 0 and bidx >= args.max_batches:
                break
            model_ref.reset_memory()
            B = batch["vis_obs"].size(0)
            n_samples += B
            text_tokens, attn_mask = build_text_tokens_from_policy(
                tokenizer=tokenizer,
                text_embed_layer=text_embed_layer,
                texts=batch["texts"],
                device=device,
                target_dtype=target_dtype,
                max_text_len=args.max_text_len,
            )
            robot_state = batch["robot_state"].to(device)
            if robot_state.dim() == 3:
                # Collapse the time dimension by keeping the last step.
                robot_state = robot_state[:, -1]
            # Move optional base actions to device for the adapter.
            for key in ("base_action_vector", "base_action_chunk", "base_action_trajectory"):
                if key in batch:
                    batch[key] = batch[key].to(device)

            # Pad with zeros or trim so the state matches the encoder input width.
            cur_d = robot_state.size(-1)
            expected_d = expected_state_dim if expected_state_dim is not None else cur_d
            if cur_d < expected_d:
                robot_state = F.pad(robot_state, (0, expected_d - cur_d))
            elif cur_d > expected_d:
                robot_state = robot_state[..., :expected_d]

            vis_obs = batch["vis_obs"].to(device)
            depth_obs = batch["depth_obs"].to(device) if batch["depth_obs"] is not None else None
            audio_obs = batch["audio_obs"].to(device) if batch["audio_obs"] is not None else None

            pred = model_ref.forward_once(
                vis_obs=vis_obs,
                robot_state=robot_state,
                depth_obs=depth_obs,
                audio_obs=audio_obs,
                text_tokens=text_tokens,
                attn_mask=attn_mask,
                return_intermediate=True,
            )

            # CHECK-B: on the first batch, rerun with telepathy disabled and
            # report how much the action vector changes.
            if accelerator.is_main_process and bidx == 0:
                model_ref.reset_memory()
                prev_flag = bool(model_ref.disable_telepathy)
                model_ref.disable_telepathy = True
                try:
                    pred_ctrl = model_ref.forward_once(
                        vis_obs=vis_obs,
                        robot_state=robot_state,
                        depth_obs=depth_obs,
                        audio_obs=audio_obs,
                        text_tokens=text_tokens,
                        attn_mask=attn_mask,
                        return_intermediate=False,
                    )
                finally:
                    # Always restore the flag, even if the control pass fails.
                    model_ref.disable_telepathy = prev_flag
                try:
                    act_exp = _pred_action(pred, "action_vector")
                    act_ctl = _pred_action(pred_ctrl, "action_vector")
                    diff = (act_exp - act_ctl).abs().mean().item()
                    print(f"[CHECK-B] telepathy_effect_mean_abs_diff(action_vector)={diff:.6f}")
                except Exception as e:
                    print(f"[CHECK-B] action diff check failed: {type(e).__name__}: {e}")

            # Apply Telepathy adapter only when telepathy is enabled and the flag is set.
            if (not args.disable_telepathy) and args.use_telepathy_adapter:
                pred = telepathy_adapter(pred, batch)

            mse = compute_branch_mse(pred, batch)
            tau_l2 = compute_telepathy_stability(pred)
            (
                num_success,
                num_total,
                mse_vec_s,
                mse_chk_s,
                mse_trj_s,
            ) = compute_success_proxy(
                pred,
                batch,
                thr_vec=args.succ_thr_vec,
                thr_chk=args.succ_thr_chk,
                thr_trj=args.succ_thr_trj,
            )
            total_success += num_success
            total_scored += num_total

            # Hard-set accumulation: samples where any branch MSE exceeds hard thresholds
            hard_mask = (mse_vec_s > hard_thr_vec) | (mse_chk_s > hard_thr_chk) | (mse_trj_s > hard_thr_trj)
            hard_count = int(hard_mask.sum().item())
            if hard_count > 0:
                sum_hard_mse_vec += mse_vec_s[hard_mask].sum().item()
                sum_hard_mse_chk += mse_chk_s[hard_mask].sum().item()
                sum_hard_mse_trj += mse_trj_s[hard_mask].sum().item()
                total_hard_samples += hard_count

            sem_factors = pred.get("semantic_factors", None)
            if sem_factors is not None:
                if sem_factors.dim() == 3:
                    sem_pool = sem_factors.mean(dim=1)
                elif sem_factors.dim() == 2:
                    sem_pool = sem_factors
                else:
                    # reshape (not view) so non-contiguous inputs are accepted.
                    sem_pool = sem_factors.reshape(sem_factors.size(0), -1)
                txt_pool = text_tokens.mean(dim=1)
                sem_align = cosine_alignment(sem_pool, txt_pool)
            else:
                sem_align = float("nan")

            sum_mse_vec += mse["mse_vector"]
            sum_mse_chk += mse["mse_chunk"]
            sum_mse_trj += mse["mse_traj"]
            # NaN marks "metric unavailable for this batch"; keep a separate
            # count so such batches do not dilute the reported average.
            if not math.isnan(tau_l2):
                sum_tau_l2 += tau_l2
                n_tau_batches += 1
            if not math.isnan(sem_align):
                sum_sem_align += sem_align
                n_sem_batches += 1
            n_batches += 1

            if accelerator.is_main_process and bidx % 20 == 0:
                print(
                    f"batch={bidx} "
                    f"mse_vec={mse['mse_vector']:.4f} mse_chk={mse['mse_chunk']:.4f} mse_trj={mse['mse_traj']:.4f} "
                    f"tau_l2={tau_l2:.4f} sem_align={sem_align:.4f}"
                )

    # ---- Report --------------------------------------------------------------
    # NOTE(review): the accumulators are local to each process and only the main
    # process writes the report; in a multi-process launch the numbers would
    # presumably cover only the main process's share of the data - confirm.
    if accelerator.is_main_process:
        avg_mse_vec = sum_mse_vec / max(1, n_batches)
        avg_mse_chk = sum_mse_chk / max(1, n_batches)
        avg_mse_trj = sum_mse_trj / max(1, n_batches)
        avg_tau_l2 = _mean_or_nan(sum_tau_l2, n_tau_batches)
        avg_sem_align = _mean_or_nan(sum_sem_align, n_sem_batches)
        avg_hard_mse_vec = _mean_or_nan(sum_hard_mse_vec, total_hard_samples)
        avg_hard_mse_chk = _mean_or_nan(sum_hard_mse_chk, total_hard_samples)
        avg_hard_mse_trj = _mean_or_nan(sum_hard_mse_trj, total_hard_samples)
        hard_fraction = float(total_hard_samples / max(1, n_samples))
        report = {
            "num_samples": n_samples,
            "num_batches": n_batches,
            "avg_mse_vector": avg_mse_vec,
            "avg_mse_chunk": avg_mse_chk,
            "avg_mse_traj": avg_mse_trj,
            "avg_tau_l2": avg_tau_l2,
            "avg_semantic_text_alignment": avg_sem_align,
            "hard_thresholds": {
                "vec": hard_thr_vec,
                "chk": hard_thr_chk,
                "trj": hard_thr_trj,
            },
            "avg_hard_mse_vector": avg_hard_mse_vec,
            "avg_hard_mse_chunk": avg_hard_mse_chk,
            "avg_hard_mse_traj": avg_hard_mse_trj,
            "hard_sample_fraction": hard_fraction,
            "total_hard_samples": int(total_hard_samples),
            "success_thresholds": {
                "vec": args.succ_thr_vec,
                "chk": args.succ_thr_chk,
                "trj": args.succ_thr_trj,
            },
            "num_success": int(total_success),
            "success_proxy_rate": _mean_or_nan(float(total_success), total_scored),
        }
        with open(
            os.path.join(args.output_dir, "sigma_eval_report.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(report, f, indent=2)
        print("[DONE] Saved report:", report)


if __name__ == "__main__":
    main()