#!/usr/bin/env python3 """ Experiment C: T5 Cross-modal sensor-to-text retrieval. Per-action-segment contrastive training: - Sensor encoder: Transformer over the multimodal sensor window covering the annotated segment (with 1s context padding each side). - Text encoder: small Transformer trained from scratch over character tokens of the segment's Chinese natural-language description. We treat the segment's four description fields {task, left_hand, right_hand, bimanual_interaction} as four "paraphrased variants" of the same segment, as claimed by the paper. Loss: symmetric InfoNCE (CLIP-style). Eval: Recall@{1, 5, 10} with K=100 distractors sampled from the test pool. Annotations live in ${PULSE_ROOT}/annotations_v2/ (18 volunteers, 127 files, 2,409 fine-grained segments with action_label). Subject-independent split: test = v25, v26, v27, v3 (same as T1). """ import os import sys import json import time import random import argparse import re import numpy as np import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torch.nn.utils.rnn import pad_sequence sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from data.dataset import ( DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, TEST_VOLS, load_modality_array, SCENE_LABELS, ) ANNOT_DIR = '${PULSE_ROOT}/annotations_v2' # --------------------------------------------------------------------------- # Annotation loading # --------------------------------------------------------------------------- def parse_timestamp(ts): """Parse 'MM:SS-MM:SS' -> (start_sec, end_sec).""" m = re.match(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)', ts) if not m: return None sm, ss, em, es = map(int, m.groups()) return sm * 60 + ss, em * 60 + es def collect_segments(volunteers): """Scan annotation files and return a list of per-segment dicts with timestamp, 4 text views, scene, volunteer.""" out = [] for vol in volunteers: vol_dir = os.path.join(ANNOT_DIR, vol) if not os.path.isdir(vol_dir): continue for fn in sorted(os.listdir(vol_dir)): if not fn.endswith('.json'): continue scene = fn.replace('.json', '') if scene not in SCENE_LABELS: continue try: d = json.load(open(os.path.join(vol_dir, fn))) except Exception: continue for seg in d.get('segments', []): ts = parse_timestamp(seg.get('timestamp', '')) if ts is None: continue # Four text views -- paper's "four paraphrased variants" texts = [] for k in ['task', 'left_hand', 'right_hand', 'bimanual_interaction']: t = seg.get(k, '').strip() if t: texts.append(t) if len(texts) == 0: continue out.append({ 'vol': vol, 'scene': scene, 't_start': ts[0], 't_end': ts[1], 'texts': texts, 'action_label': seg.get('action_label', ''), }) print(f" Collected {len(out)} annotated segments from " f"{len(set((s['vol'], s['scene']) for s in out))} recordings") return out # --------------------------------------------------------------------------- # Vocabulary for Chinese character tokenization # --------------------------------------------------------------------------- PAD, UNK = 0, 1 def build_vocab(segments, min_count=1): from collections import Counter c = Counter() for s in segments: for t in s['texts']: for ch in t: c[ch] += 1 vocab = {'': PAD, '': UNK} for ch, cnt in c.most_common(): if cnt >= min_count: vocab[ch] = len(vocab) return vocab def tokenize(text, vocab, max_len=64): ids = [vocab.get(ch, UNK) for ch in text][:max_len] return ids # --------------------------------------------------------------------------- # Dataset # --------------------------------------------------------------------------- class SegmentRetrievalDataset(Dataset): """Per-segment sensor window + 4 Chinese caption variants.""" def __init__(self, segments, modalities, vocab, downsample=5, context_pad_sec=1.0, max_text_len=64, stats=None): self.modalities = modalities self.downsample = downsample self.max_text_len = max_text_len self.vocab = vocab # Cache sensor data per recording to avoid re-loading self._sensor_cache = {} self._modality_dims = {} self.items = [] skipped = 0 for seg in segments: vol, scene = seg['vol'], seg['scene'] arr = self._load_recording(vol, scene) if arr is None: skipped += 1 continue # Compute sample window sr = 100 # Hz, before downsample t0 = max(0, int((seg['t_start'] - context_pad_sec) * sr)) t1 = min(arr.shape[0], int((seg['t_end'] + context_pad_sec) * sr)) if t1 - t0 < sr * 0.3: # <0.3s, skip degenerate skipped += 1 continue window = arr[t0:t1:downsample] # downsampled sensor window if window.shape[0] < 4: skipped += 1 continue self.items.append({ 'window': window.astype(np.float32), 'texts': seg['texts'], 'action_label': seg.get('action_label', ''), 'src': f"{vol}/{scene}@{seg['t_start']}-{seg['t_end']}", }) print(f" Materialized {len(self.items)} segments (skipped {skipped}), " f"feat dim {sum(self._modality_dims.values())}") # Normalize (using train stats if provided) all_frames = np.concatenate([it['window'] for it in self.items], axis=0).astype(np.float64) if stats is not None: self.mean, self.std = stats else: self.mean = all_frames.mean(axis=0, keepdims=True) self.std = all_frames.std(axis=0, keepdims=True) self.std[self.std < 1e-8] = 1.0 for it in self.items: it['window'] = ((it['window'].astype(np.float64) - self.mean) / self.std).astype(np.float32) it['window'] = np.nan_to_num(it['window'], nan=0.0, posinf=0.0, neginf=0.0) def _load_recording(self, vol, scene): key = (vol, scene) if key in self._sensor_cache: return self._sensor_cache[key] scenario_dir = os.path.join(DATASET_DIR, vol, scene) if not os.path.isdir(scenario_dir): self._sensor_cache[key] = None return None parts = [] for mod in self.modalities: if mod == 'mocap': fp = os.path.join(scenario_dir, f"aligned_{vol}{scene}_s_Q.tsv") else: fp = os.path.join(scenario_dir, MODALITY_FILES[mod]) if not os.path.exists(fp): self._sensor_cache[key] = None return None arr = load_modality_array(fp, mod) if arr is None: self._sensor_cache[key] = None return None if mod in self._modality_dims and arr.shape[1] != self._modality_dims[mod]: expected = self._modality_dims[mod] if arr.shape[1] < expected: pad = np.zeros((arr.shape[0], expected - arr.shape[1]), dtype=np.float32) arr = np.concatenate([arr, pad], axis=1) else: arr = arr[:, :expected] if mod not in self._modality_dims: self._modality_dims[mod] = arr.shape[1] parts.append(arr) T_min = min(p.shape[0] for p in parts) combined = np.concatenate([p[:T_min] for p in parts], axis=1) self._sensor_cache[key] = combined return combined @property def feat_dim(self): return sum(self._modality_dims.values()) def get_stats(self): return (self.mean, self.std) def __len__(self): return len(self.items) def __getitem__(self, idx): it = self.items[idx] # Randomly pick one of the 4 captions at training time text = random.choice(it['texts']) tok = tokenize(text, self.vocab, max_len=self.max_text_len) return { 'window': torch.from_numpy(it['window']), 'text_ids': torch.LongTensor(tok), 'all_texts': it['texts'], 'src': it['src'], } def retrieval_collate(batch): windows = [b['window'] for b in batch] seq_lens = torch.LongTensor([w.shape[0] for w in windows]) padded_w = pad_sequence(windows, batch_first=True, padding_value=0.0) max_w = padded_w.shape[1] w_mask = torch.arange(max_w).unsqueeze(0) < seq_lens.unsqueeze(1) text_ids = [b['text_ids'] for b in batch] tok_lens = torch.LongTensor([t.shape[0] for t in text_ids]) padded_t = pad_sequence(text_ids, batch_first=True, padding_value=PAD) max_t = padded_t.shape[1] t_mask = torch.arange(max_t).unsqueeze(0) < tok_lens.unsqueeze(1) return { 'window': padded_w, 'window_mask': w_mask, 'text_ids': padded_t, 'text_mask': t_mask, 'srcs': [b['src'] for b in batch], 'all_texts': [b['all_texts'] for b in batch], } # --------------------------------------------------------------------------- # Model: two-tower retrieval # --------------------------------------------------------------------------- class SensorEncoder(nn.Module): def __init__(self, feat_dim, hidden_dim=128, n_layers=2, n_heads=4, dropout=0.2, emb_dim=128): super().__init__() self.input_proj = nn.Linear(feat_dim, hidden_dim) self.pos_enc = nn.Parameter(torch.zeros(1, 2048, hidden_dim)) nn.init.trunc_normal_(self.pos_enc, std=0.02) enc_layer = nn.TransformerEncoderLayer( d_model=hidden_dim, nhead=n_heads, dim_feedforward=4 * hidden_dim, dropout=dropout, batch_first=True, activation='gelu', ) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) self.proj = nn.Sequential( nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, emb_dim), ) def forward(self, x, mask): T = x.size(1) h = self.input_proj(x) + self.pos_enc[:, :T, :] key_padding = ~mask h = self.encoder(h, src_key_padding_mask=key_padding) # Masked mean pool m = mask.unsqueeze(-1).float() pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0) return F.normalize(self.proj(pooled), dim=-1) class TextEncoder(nn.Module): def __init__(self, vocab_size, hidden_dim=128, n_layers=2, n_heads=4, dropout=0.2, emb_dim=128, max_len=64): super().__init__() self.embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=PAD) self.pos_enc = nn.Parameter(torch.zeros(1, max_len, hidden_dim)) nn.init.trunc_normal_(self.pos_enc, std=0.02) enc_layer = nn.TransformerEncoderLayer( d_model=hidden_dim, nhead=n_heads, dim_feedforward=4 * hidden_dim, dropout=dropout, batch_first=True, activation='gelu', ) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) self.proj = nn.Sequential( nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, emb_dim), ) def forward(self, ids, mask): T = ids.size(1) h = self.embed(ids) + self.pos_enc[:, :T, :] key_padding = ~mask h = self.encoder(h, src_key_padding_mask=key_padding) m = mask.unsqueeze(-1).float() pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0) return F.normalize(self.proj(pooled), dim=-1) class TwoTowerRetrieval(nn.Module): def __init__(self, feat_dim, vocab_size, hidden_dim=128, emb_dim=128, max_text_len=64, dropout=0.2): super().__init__() self.sensor = SensorEncoder(feat_dim, hidden_dim, emb_dim=emb_dim, dropout=dropout) self.text = TextEncoder(vocab_size, hidden_dim, emb_dim=emb_dim, max_len=max_text_len, dropout=dropout) self.logit_scale = nn.Parameter(torch.ones(1) * np.log(1 / 0.07)) def forward(self, batch): se = self.sensor(batch['window'], batch['window_mask']) te = self.text(batch['text_ids'], batch['text_mask']) return se, te # --------------------------------------------------------------------------- # Loss # --------------------------------------------------------------------------- def info_nce(se, te, logit_scale): """Symmetric InfoNCE.""" scale = logit_scale.exp().clamp(max=100.0) logits = scale * se @ te.t() # (B, B) B = logits.size(0) targets = torch.arange(B, device=logits.device) loss_s2t = F.cross_entropy(logits, targets) loss_t2s = F.cross_entropy(logits.t(), targets) return 0.5 * (loss_s2t + loss_t2s) # --------------------------------------------------------------------------- # Training / Eval # --------------------------------------------------------------------------- def set_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def train_one_epoch(model, loader, optimizer, device): model.train() total = 0.0 n = 0 for batch in loader: batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()} optimizer.zero_grad() se, te = model(batch) loss = info_nce(se, te, model.logit_scale) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() total += loss.item() * se.size(0) n += se.size(0) return total / max(n, 1) @torch.no_grad() def evaluate_retrieval(model, loader, vocab, device, K=100, seed=0): """Sensor -> text retrieval. For each sensor query, build pool of 1 correct + K-1 distractors from other test segments, compute rank.""" model.eval() # Collect all embeddings all_se = [] all_texts = [] srcs = [] for batch in loader: dev_batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()} se = model.sensor(dev_batch['window'], dev_batch['window_mask']) all_se.append(se.cpu()) # For eval, use the first caption ("task") as the gold text for texts in batch['all_texts']: all_texts.append(texts[0]) srcs.extend(batch['srcs']) all_se = torch.cat(all_se, dim=0) # (N, D) # Encode all candidate texts once text_embs = [] for i in range(0, len(all_texts), 64): chunk = all_texts[i:i + 64] tok_lists = [tokenize(t, vocab, max_len=64) for t in chunk] lens = [len(t) for t in tok_lists] max_len = max(lens) pad_ids = torch.zeros(len(chunk), max_len, dtype=torch.long) mask = torch.zeros(len(chunk), max_len, dtype=torch.bool) for j, t in enumerate(tok_lists): pad_ids[j, :len(t)] = torch.LongTensor(t) mask[j, :len(t)] = True pad_ids = pad_ids.to(device) mask = mask.to(device) te = model.text(pad_ids, mask).cpu() text_embs.append(te) text_embs = torch.cat(text_embs, dim=0) # (N, D) # For each sensor query i, sample K-1 distractors from {0..N}\{i} rng = np.random.RandomState(seed) N = all_se.shape[0] ranks = [] for i in range(N): pool_size = min(K, N) neg_candidates = [j for j in range(N) if j != i] if len(neg_candidates) < pool_size - 1: pool = [i] + neg_candidates else: neg = rng.choice(neg_candidates, size=pool_size - 1, replace=False) pool = [i] + neg.tolist() # Compute similarity of query i with pool texts q = all_se[i:i + 1] # (1, D) pool_texts = text_embs[pool] # (K, D) sims = (q @ pool_texts.t()).squeeze(0).numpy() # (K,) # rank of pool[0] (the correct one) order = np.argsort(-sims) rank = int(np.where(order == 0)[0][0]) + 1 ranks.append(rank) ranks = np.array(ranks) return { 'N': int(N), 'K': int(K), 'recall@1': float((ranks <= 1).mean()), 'recall@5': float((ranks <= 5).mean()), 'recall@10': float((ranks <= 10).mean()), 'median_rank': float(np.median(ranks)), 'mean_rank': float(ranks.mean()), } # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def run_experiment(args): set_seed(args.seed) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") modalities = args.modalities.split(',') print(f"Modalities: {modalities} | Seed: {args.seed}") print("Collecting train segments...") train_segs = collect_segments(TRAIN_VOLS) print("Collecting test segments...") test_segs = collect_segments(TEST_VOLS) # Build char vocab from train only vocab = build_vocab(train_segs) print(f" Vocab size: {len(vocab)}") print("Building train dataset...") train_ds = SegmentRetrievalDataset( train_segs, modalities, vocab, downsample=args.downsample, context_pad_sec=args.context_pad_sec, max_text_len=args.max_text_len, ) stats = train_ds.get_stats() print("Building test dataset...") test_ds = SegmentRetrievalDataset( test_segs, modalities, vocab, downsample=args.downsample, context_pad_sec=args.context_pad_sec, max_text_len=args.max_text_len, stats=stats, ) train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=retrieval_collate, num_workers=0, drop_last=True) test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=retrieval_collate, num_workers=0) model = TwoTowerRetrieval( train_ds.feat_dim, len(vocab), hidden_dim=args.hidden_dim, emb_dim=args.emb_dim, max_text_len=args.max_text_len, dropout=args.dropout, ).to(device) n_params = sum(p.numel() for p in model.parameters()) print(f"Params: {n_params:,}") optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs, eta_min=1e-6, ) mod_str = '-'.join(modalities) exp_name = f"retrieval_{mod_str}_seed{args.seed}" if args.tag: exp_name += f"_{args.tag}" out_dir = os.path.join(args.output_dir, exp_name) os.makedirs(out_dir, exist_ok=True) best_r10 = 0.0 best_metrics = None best_state = None for epoch in range(1, args.epochs + 1): t0 = time.time() loss = train_one_epoch(model, train_loader, optimizer, device) scheduler.step() if epoch % args.eval_every == 0 or epoch == args.epochs: m = evaluate_retrieval(model, test_loader, vocab, device, K=args.K, seed=args.seed) print(f" E{epoch:3d} | loss {loss:.4f} | R@1 {m['recall@1']:.3f} " f"R@5 {m['recall@5']:.3f} R@10 {m['recall@10']:.3f} " f"medR {m['median_rank']:.1f} | {time.time()-t0:.1f}s") if m['recall@10'] > best_r10: best_r10 = m['recall@10'] best_metrics = m best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} else: print(f" E{epoch:3d} | loss {loss:.4f} | {time.time()-t0:.1f}s") if best_state is not None: torch.save(best_state, os.path.join(out_dir, 'model_best.pt')) # Final eval with multiple distractor pool seeds for robustness model.load_state_dict(best_state) final_metrics = [] for s in range(3): m = evaluate_retrieval(model, test_loader, vocab, device, K=args.K, seed=1000 + s) final_metrics.append(m) avg = {k: float(np.mean([fm[k] for fm in final_metrics])) for k in ['recall@1', 'recall@5', 'recall@10', 'median_rank', 'mean_rank']} std = {k: float(np.std([fm[k] for fm in final_metrics])) for k in ['recall@1', 'recall@5', 'recall@10']} results = { 'experiment': exp_name, 'modalities': modalities, 'seed': args.seed, 'K_pool': args.K, 'n_train_segments': len(train_ds), 'n_test_segments': len(test_ds), 'vocab_size': len(vocab), 'best_recall10': float(best_r10), 'best_metrics': best_metrics, 'final_avg_over_3_pool_seeds': avg, 'final_std_over_3_pool_seeds': std, 'args': vars(args), } with open(os.path.join(out_dir, 'results.json'), 'w') as f: json.dump(results, f, indent=2, ensure_ascii=False) print(f"Saved: {out_dir}/results.json") print(f"Final (avg over 3 pool seeds): R@1 {avg['recall@1']:.3f} " f"R@5 {avg['recall@5']:.3f} R@10 {avg['recall@10']:.3f}") return results def main(): p = argparse.ArgumentParser() p.add_argument('--modalities', type=str, default='mocap,emg,eyetrack,imu') p.add_argument('--epochs', type=int, default=60) p.add_argument('--batch_size', type=int, default=64) p.add_argument('--lr', type=float, default=5e-4) p.add_argument('--weight_decay', type=float, default=1e-4) p.add_argument('--hidden_dim', type=int, default=128) p.add_argument('--emb_dim', type=int, default=128) p.add_argument('--dropout', type=float, default=0.2) p.add_argument('--downsample', type=int, default=5) p.add_argument('--context_pad_sec', type=float, default=1.0) p.add_argument('--max_text_len', type=int, default=64) p.add_argument('--K', type=int, default=100) p.add_argument('--eval_every', type=int, default=5) p.add_argument('--seed', type=int, default=42) p.add_argument('--output_dir', type=str, required=True) p.add_argument('--tag', type=str, default='') args = p.parse_args() run_experiment(args) if __name__ == '__main__': main()