| |
| """ |
| Experiment C: T5 Cross-modal sensor-to-text retrieval. |
| |
| Per-action-segment contrastive training: |
| - Sensor encoder: Transformer over the multimodal sensor window covering the |
| annotated segment (with 1s context padding each side). |
| - Text encoder: small Transformer trained from scratch over character tokens |
| of the segment's Chinese natural-language description. We treat the |
| segment's four description fields {task, left_hand, right_hand, |
| bimanual_interaction} as four "paraphrased variants" of the same segment, |
| as claimed by the paper. |
| |
| Loss: symmetric InfoNCE (CLIP-style). |
Eval: Recall@{1, 5, 10} over a pool of K=100 candidate captions per sensor query (the matching caption plus K-1 distractors sampled from the other test segments).
| |
| Annotations live in ${PULSE_ROOT}/annotations_v2/ (18 |
| volunteers, 127 files, 2,409 fine-grained segments with action_label). |
| Subject-independent split: test = v25, v26, v27, v3 (same as T1). |
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import random |
| import argparse |
| import re |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import ( |
| DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, TEST_VOLS, |
| load_modality_array, SCENE_LABELS, |
| ) |
|
|
# Root directory of the per-volunteer annotation JSON files.
# ``${PULSE_ROOT}`` is expanded from the environment; if PULSE_ROOT is unset,
# os.path.expandvars leaves the placeholder unchanged.
ANNOT_DIR = os.path.expandvars('${PULSE_ROOT}/annotations_v2')
|
|
|
|
| |
| |
| |
|
|
_TIMESTAMP_RE = re.compile(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)')


def parse_timestamp(ts):
    """Parse an ``'MM:SS-MM:SS'`` range into ``(start_sec, end_sec)``.

    Surrounding whitespace and whitespace around the dash are allowed. Any
    text after the second ``MM:SS`` is ignored (prefix match).

    Args:
        ts: the timestamp string from an annotation segment.

    Returns:
        ``(start_sec, end_sec)`` as ints, or None when ``ts`` is not a string,
        does not match the pattern, or describes a range that ends before it
        starts.

    >>> parse_timestamp('01:05 - 01:30')
    (65, 90)
    """
    if not isinstance(ts, str):
        return None
    m = _TIMESTAMP_RE.match(ts.strip())
    if m is None:
        return None
    sm, ss, em, es = map(int, m.groups())
    start, end = sm * 60 + ss, em * 60 + es
    # A reversed range is an annotation error; reject it rather than
    # producing a bogus window.
    if end < start:
        return None
    return start, end
|
|
|
|
def collect_segments(volunteers):
    """Scan annotation files and flatten them into per-segment dicts.

    For every volunteer ``vol``, each ``<scene>.json`` file in
    ``ANNOT_DIR/<vol>/`` whose scene name is in ``SCENE_LABELS`` is read. Each
    entry of its ``segments`` list that has a parseable ``timestamp`` and at
    least one non-empty description field becomes one output record.
    Unreadable or malformed files are skipped with a warning.

    Args:
        volunteers: iterable of volunteer directory names.

    Returns:
        list of dicts with keys ``vol``, ``scene``, ``t_start``, ``t_end``
        (seconds), ``texts`` (stripped, non-empty strings taken from
        ``task``, ``left_hand``, ``right_hand``, ``bimanual_interaction`` in
        that order) and ``action_label``.
    """
    text_fields = ('task', 'left_hand', 'right_hand', 'bimanual_interaction')
    out = []
    for vol in volunteers:
        vol_dir = os.path.join(ANNOT_DIR, vol)
        if not os.path.isdir(vol_dir):
            continue
        for fn in sorted(os.listdir(vol_dir)):
            if not fn.endswith('.json'):
                continue
            scene = fn[:-len('.json')]
            if scene not in SCENE_LABELS:
                continue
            path = os.path.join(vol_dir, fn)
            try:
                # Explicit UTF-8: descriptions are Chinese and the platform
                # default encoding is not guaranteed to be UTF-8.
                with open(path, encoding='utf-8') as fh:
                    d = json.load(fh)
            except (OSError, ValueError) as err:
                # Best effort: skip unreadable/malformed files, but say so.
                print(f"  [warn] skipping {path}: {err}")
                continue
            segments = (d.get('segments') or []) if isinstance(d, dict) else []
            for seg in segments:
                if not isinstance(seg, dict):
                    continue
                raw_ts = seg.get('timestamp', '')
                ts = parse_timestamp(raw_ts) if isinstance(raw_ts, str) else None
                if ts is None:
                    continue

                # Non-string values (e.g. JSON null) are treated as missing.
                texts = []
                for k in text_fields:
                    value = seg.get(k)
                    if isinstance(value, str) and value.strip():
                        texts.append(value.strip())
                if not texts:
                    continue
                out.append({
                    'vol': vol,
                    'scene': scene,
                    't_start': ts[0],
                    't_end': ts[1],
                    'texts': texts,
                    'action_label': seg.get('action_label', ''),
                })
    print(f" Collected {len(out)} annotated segments from "
          f"{len(set((s['vol'], s['scene']) for s in out))} recordings")
    return out
|
|
|
|
| |
| |
| |
|
|
# Reserved token ids: padding and out-of-vocabulary characters.
PAD = 0
UNK = 1


def build_vocab(segments, min_count=1):
    """Build a character -> id vocabulary from every caption in ``segments``.

    Ids 0 and 1 are reserved for ``<pad>`` and ``<unk>``. The remaining
    characters are numbered by descending frequency; ties keep first-seen
    order. Characters seen fewer than ``min_count`` times are left out.
    """
    from collections import Counter

    counts = Counter(
        ch for seg in segments for text in seg['texts'] for ch in text
    )
    vocab = {'<pad>': PAD, '<unk>': UNK}
    for ch, cnt in counts.most_common():
        if cnt < min_count:
            continue
        vocab[ch] = len(vocab)
    return vocab


def tokenize(text, vocab, max_len=64):
    """Map the first ``max_len`` characters of ``text`` to ids (unknown -> UNK)."""
    return [vocab.get(ch, UNK) for ch in text[:max_len]]
|
|
|
|
| |
| |
| |
|
|
class SegmentRetrievalDataset(Dataset):
    """Per-segment sensor window + Chinese caption variants.

    Each item pairs the multimodal sensor window covering one annotated
    segment with that segment's caption variants. The window is padded by
    ``context_pad_sec`` on each side, keeps every ``downsample``-th frame and
    is z-normalized per channel. ``__getitem__`` samples one caption variant
    at random.

    Args:
        segments: dicts with at least ``vol``, ``scene``, ``t_start``,
            ``t_end`` (seconds) and ``texts``; ``action_label`` is optional.
        modalities: modality names whose per-recording arrays are
            concatenated along the feature axis, in this order.
        vocab: character -> id mapping used by ``tokenize``.
        downsample: temporal stride applied to each window.
        context_pad_sec: seconds of context added before and after a segment.
        max_text_len: maximum caption length in tokens.
        stats: optional ``(mean, std)`` to reuse (e.g. from the training set);
            when None they are computed from this dataset's windows.
        sample_rate: sensor frame rate in Hz used to map seconds to frame
            indices (default 100).

    Raises:
        ValueError: if ``stats`` is None and no segment could be materialized,
            or if the supplied statistics' width differs from the windows'.
    """

    def __init__(self, segments, modalities, vocab, downsample=5,
                 context_pad_sec=1.0, max_text_len=64, stats=None,
                 sample_rate=100):
        self.modalities = modalities
        self.downsample = downsample
        self.max_text_len = max_text_len
        self.vocab = vocab

        # (vol, scene) -> concatenated recording (or None if unavailable);
        # only needed while windows are sliced below.
        self._sensor_cache = {}
        # modality -> feature width, fixed by the first recording that loads.
        self._modality_dims = {}
        self.items = []
        skipped = 0
        sr = sample_rate
        for seg in segments:
            vol, scene = seg['vol'], seg['scene']
            arr = self._load_recording(vol, scene)
            if arr is None:
                skipped += 1
                continue

            t0 = max(0, int((seg['t_start'] - context_pad_sec) * sr))
            t1 = min(arr.shape[0], int((seg['t_end'] + context_pad_sec) * sr))
            # Drop windows with under 0.3 s of data; this also rejects
            # t1 <= t0, e.g. a segment lying past the end of the recording.
            if t1 - t0 < sr * 0.3:
                skipped += 1
                continue
            window = arr[t0:t1:downsample]
            if window.shape[0] < 4:
                skipped += 1
                continue
            self.items.append({
                # astype() copies, so windows do not keep recordings alive.
                'window': window.astype(np.float32),
                'texts': seg['texts'],
                'action_label': seg.get('action_label', ''),
                'src': f"{vol}/{scene}@{seg['t_start']}-{seg['t_end']}",
            })
        # Release the full recordings now that every window has been copied.
        self._sensor_cache.clear()
        print(f" Materialized {len(self.items)} segments (skipped {skipped}), "
              f"feat dim {sum(self._modality_dims.values())}")

        if stats is not None:
            self.mean, self.std = stats
        else:
            if not self.items:
                raise ValueError(
                    "SegmentRetrievalDataset: no segments were materialized, "
                    "so normalization statistics cannot be computed")
            all_frames = np.concatenate([it['window'] for it in self.items],
                                        axis=0)
            self.mean, self.std = self._compute_stats(all_frames)

        # Width 1 is allowed so that scalar-like stats still broadcast.
        stat_dim = np.shape(self.mean)[-1] if np.ndim(self.mean) else None
        if (self.items and stat_dim is not None
                and stat_dim not in (1, self.feat_dim)):
            raise ValueError(
                f"SegmentRetrievalDataset: stats have {stat_dim} channels but "
                f"windows have {self.feat_dim}")

        for it in self.items:
            normed = (it['window'].astype(np.float64) - self.mean) / self.std
            # Remaining NaN/inf values become 0.0.
            it['window'] = np.nan_to_num(normed.astype(np.float32), nan=0.0,
                                         posinf=0.0, neginf=0.0)

    @staticmethod
    def _compute_stats(frames):
        """Per-channel ``(mean, std)`` over axis 0, ignoring non-finite samples.

        For a 2-D ``(N, D)`` input, returns two float64 arrays of shape
        ``(1, D)``. Channels without any finite sample get mean 0 and std 1.
        A std below 1e-8 (a near-constant channel) is replaced by 1, so
        normalization never divides by ~0. ``frames`` is not modified.
        """
        import warnings

        data = np.array(frames, dtype=np.float64)  # always a private copy
        data[~np.isfinite(data)] = np.nan
        with warnings.catch_warnings():
            # All-NaN channels trigger "empty slice" / "degrees of freedom"
            # RuntimeWarnings; they are handled explicitly below.
            warnings.simplefilter('ignore', category=RuntimeWarning)
            mean = np.nanmean(data, axis=0, keepdims=True)
            std = np.nanstd(data, axis=0, keepdims=True)
        mean = np.where(np.isfinite(mean), mean, 0.0)
        std = np.where(np.isfinite(std) & (std >= 1e-8), std, 1.0)
        return mean, std

    def _load_recording(self, vol, scene):
        """Load and concatenate all modalities of one recording (memoized).

        Returns the array built by concatenating the per-modality arrays
        along axis 1, after cropping every modality to the shortest length.
        Returns None if the scenario directory or any modality file is
        missing, or if ``load_modality_array`` returns None. A modality whose
        width differs from the first width seen for it is zero-padded or
        truncated to that width.
        """
        key = (vol, scene)
        if key in self._sensor_cache:
            return self._sensor_cache[key]
        scenario_dir = os.path.join(DATASET_DIR, vol, scene)
        if not os.path.isdir(scenario_dir):
            self._sensor_cache[key] = None
            return None
        parts = []
        for mod in self.modalities:
            if mod == 'mocap':
                fp = os.path.join(scenario_dir, f"aligned_{vol}{scene}_s_Q.tsv")
            else:
                fp = os.path.join(scenario_dir, MODALITY_FILES[mod])
            if not os.path.exists(fp):
                self._sensor_cache[key] = None
                return None
            arr = load_modality_array(fp, mod)
            if arr is None:
                self._sensor_cache[key] = None
                return None
            if mod in self._modality_dims and arr.shape[1] != self._modality_dims[mod]:
                expected = self._modality_dims[mod]
                if arr.shape[1] < expected:
                    pad = np.zeros((arr.shape[0], expected - arr.shape[1]),
                                   dtype=np.float32)
                    arr = np.concatenate([arr, pad], axis=1)
                else:
                    arr = arr[:, :expected]
            if mod not in self._modality_dims:
                self._modality_dims[mod] = arr.shape[1]
            parts.append(arr)
        T_min = min(p.shape[0] for p in parts)
        combined = np.concatenate([p[:T_min] for p in parts], axis=1)
        self._sensor_cache[key] = combined
        return combined

    @property
    def feat_dim(self):
        """Total feature width: the sum of the per-modality widths."""
        return sum(self._modality_dims.values())

    def get_stats(self):
        """Return the ``(mean, std)`` used to normalize this dataset."""
        return (self.mean, self.std)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return the window tensor, one randomly chosen caption's token ids,
        all caption variants, and the source id string."""
        it = self.items[idx]
        text = random.choice(it['texts'])
        tok = tokenize(text, self.vocab, max_len=self.max_text_len)
        return {
            'window': torch.from_numpy(it['window']),
            'text_ids': torch.LongTensor(tok),
            'all_texts': it['texts'],
            'src': it['src'],
        }
|
|
|
|
def retrieval_collate(batch):
    """Collate dataset items into end-padded tensors plus boolean masks.

    Sensor windows are padded with 0.0 and token id sequences with ``PAD`` up
    to the longest item in the batch. Each mask is True on real positions
    and False on padding. ``srcs`` and ``all_texts`` are passed through as
    plain lists.
    """
    def pad_and_mask(seqs, fill):
        lengths = torch.LongTensor([s.shape[0] for s in seqs])
        padded = pad_sequence(seqs, batch_first=True, padding_value=fill)
        positions = torch.arange(padded.shape[1]).unsqueeze(0)
        return padded, positions < lengths.unsqueeze(1)

    window, window_mask = pad_and_mask([item['window'] for item in batch], 0.0)
    text_ids, text_mask = pad_and_mask([item['text_ids'] for item in batch], PAD)
    return {
        'window': window,
        'window_mask': window_mask,
        'text_ids': text_ids,
        'text_mask': text_mask,
        'srcs': [item['src'] for item in batch],
        'all_texts': [item['all_texts'] for item in batch],
    }
|
|
|
|
| |
| |
| |
|
|
class SensorEncoder(nn.Module):
    """Transformer encoder that maps a padded sensor sequence to a unit-norm vector.

    Args:
        feat_dim: per-frame input feature width.
        hidden_dim: Transformer model width.
        n_layers, n_heads, dropout: Transformer hyper-parameters.
        emb_dim: output embedding size.
        max_len: number of learned positional embeddings (default 2048).
            Longer inputs are truncated to their first ``max_len`` frames
            instead of failing on the positional-encoding add.
    """

    def __init__(self, feat_dim, hidden_dim=128, n_layers=2, n_heads=4,
                 dropout=0.2, emb_dim=128, max_len=2048):
        super().__init__()
        self.max_len = max_len
        self.input_proj = nn.Linear(feat_dim, hidden_dim)
        self.pos_enc = nn.Parameter(torch.zeros(1, max_len, hidden_dim))
        nn.init.trunc_normal_(self.pos_enc, std=0.02)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads,
            dim_feedforward=4 * hidden_dim, dropout=dropout,
            batch_first=True, activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x, mask):
        """Encode ``x`` (B, T, feat_dim) with validity ``mask`` (B, T).

        ``mask`` is True on valid frames. Returns (B, emb_dim) L2-normalized
        embeddings from masked mean pooling.
        """
        if x.size(1) > self.max_len:
            # Keeps the leading frames; this assumes padding sits at the end
            # of each sequence (as retrieval_collate produces), so every row
            # keeps its valid prefix.
            x = x[:, :self.max_len]
            mask = mask[:, :self.max_len]
        T = x.size(1)
        h = self.input_proj(x) + self.pos_enc[:, :T, :]
        key_padding = ~mask  # nn.Transformer masks positions that are True
        h = self.encoder(h, src_key_padding_mask=key_padding)

        m = mask.unsqueeze(-1).float()
        pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        return F.normalize(self.proj(pooled), dim=-1)
|
|
|
|
class TextEncoder(nn.Module):
    """Character-level Transformer that maps token ids to a unit-norm vector.

    Args:
        vocab_size: number of token ids (``PAD`` is the embedding padding index).
        hidden_dim: Transformer model width.
        n_layers, n_heads, dropout: Transformer hyper-parameters.
        emb_dim: output embedding size.
        max_len: number of learned positional embeddings. Longer inputs are
            truncated to their first ``max_len`` tokens instead of failing on
            the positional-encoding add.
    """

    def __init__(self, vocab_size, hidden_dim=128, n_layers=2, n_heads=4,
                 dropout=0.2, emb_dim=128, max_len=64):
        super().__init__()
        self.max_len = max_len
        self.embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=PAD)
        self.pos_enc = nn.Parameter(torch.zeros(1, max_len, hidden_dim))
        nn.init.trunc_normal_(self.pos_enc, std=0.02)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads,
            dim_feedforward=4 * hidden_dim, dropout=dropout,
            batch_first=True, activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, ids, mask):
        """Encode ``ids`` (B, T) with validity ``mask`` (B, T, True = real token).

        Returns (B, emb_dim) L2-normalized embeddings from masked mean pooling.
        """
        if ids.size(1) > self.max_len:
            # Keep the leading tokens (padding is expected at the end).
            ids = ids[:, :self.max_len]
            mask = mask[:, :self.max_len]
        T = ids.size(1)
        h = self.embed(ids) + self.pos_enc[:, :T, :]
        key_padding = ~mask  # nn.Transformer masks positions that are True
        h = self.encoder(h, src_key_padding_mask=key_padding)
        m = mask.unsqueeze(-1).float()
        pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        return F.normalize(self.proj(pooled), dim=-1)
|
|
|
|
class TwoTowerRetrieval(nn.Module):
    """Sensor tower + text tower projecting into a shared embedding space.

    ``logit_scale`` is a learnable log inverse-temperature, initialized to
    log(1 / 0.07) as in CLIP, and consumed by the contrastive loss.
    """

    def __init__(self, feat_dim, vocab_size, hidden_dim=128, emb_dim=128,
                 max_text_len=64, dropout=0.2):
        super().__init__()
        # Construction order matters: it fixes the order in which parameter
        # initialization draws from the global RNG.
        self.sensor = SensorEncoder(feat_dim, hidden_dim=hidden_dim,
                                    emb_dim=emb_dim, dropout=dropout)
        self.text = TextEncoder(vocab_size, hidden_dim=hidden_dim,
                                emb_dim=emb_dim, dropout=dropout,
                                max_len=max_text_len)
        init_temperature = 0.07
        self.logit_scale = nn.Parameter(torch.ones(1) * np.log(1 / init_temperature))

    def forward(self, batch):
        """Return ``(sensor_embeddings, text_embeddings)`` for a collated batch."""
        return (self.sensor(batch['window'], batch['window_mask']),
                self.text(batch['text_ids'], batch['text_mask']))
|
|
|
|
| |
| |
| |
|
|
def info_nce(se, te, logit_scale):
    """Symmetric (CLIP-style) InfoNCE loss over a batch of paired embeddings.

    Row i of ``se`` and row i of ``te`` form the positive pair; every other
    row in the batch is a negative. ``logit_scale`` is exponentiated and
    capped at 100 before it scales the dot-product logits.

    Returns the mean of the sensor->text and text->sensor cross-entropies.
    """
    inv_temperature = torch.clamp(logit_scale.exp(), max=100.0)
    sim = (inv_temperature * se) @ te.t()
    labels = torch.arange(sim.size(0), device=sim.device)
    sensor_to_text = F.cross_entropy(sim, labels)
    text_to_sensor = F.cross_entropy(sim.t(), labels)
    return 0.5 * (sensor_to_text + text_to_sensor)
|
|
|
|
| |
| |
| |
|
|
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + all CUDA devices)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
|
|
|
|
def train_one_epoch(model, loader, optimizer, device):
    """Run one epoch of contrastive training.

    Tensor fields of each batch are moved to ``device``. Gradients are
    clipped to a global norm of 1.0 before each optimizer step.

    Returns the sample-weighted mean loss over the epoch (0.0 if the loader
    yields nothing).
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    for raw_batch in loader:
        batch = {}
        for key, value in raw_batch.items():
            batch[key] = value.to(device) if torch.is_tensor(value) else value
        optimizer.zero_grad()
        sensor_emb, text_emb = model(batch)
        loss = info_nce(sensor_emb, text_emb, model.logit_scale)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        batch_size = sensor_emb.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size
    return loss_sum / max(seen, 1)
|
|
|
|
| @torch.no_grad() |
| def evaluate_retrieval(model, loader, vocab, device, K=100, seed=0): |
| """Sensor -> text retrieval. For each sensor query, build pool of |
| 1 correct + K-1 distractors from other test segments, compute rank.""" |
| model.eval() |
| |
| all_se = [] |
| all_texts = [] |
| srcs = [] |
| for batch in loader: |
| dev_batch = {k: v.to(device) if torch.is_tensor(v) else v |
| for k, v in batch.items()} |
| se = model.sensor(dev_batch['window'], dev_batch['window_mask']) |
| all_se.append(se.cpu()) |
| |
| for texts in batch['all_texts']: |
| all_texts.append(texts[0]) |
| srcs.extend(batch['srcs']) |
| all_se = torch.cat(all_se, dim=0) |
| |
| text_embs = [] |
| for i in range(0, len(all_texts), 64): |
| chunk = all_texts[i:i + 64] |
| tok_lists = [tokenize(t, vocab, max_len=64) for t in chunk] |
| lens = [len(t) for t in tok_lists] |
| max_len = max(lens) |
| pad_ids = torch.zeros(len(chunk), max_len, dtype=torch.long) |
| mask = torch.zeros(len(chunk), max_len, dtype=torch.bool) |
| for j, t in enumerate(tok_lists): |
| pad_ids[j, :len(t)] = torch.LongTensor(t) |
| mask[j, :len(t)] = True |
| pad_ids = pad_ids.to(device) |
| mask = mask.to(device) |
| te = model.text(pad_ids, mask).cpu() |
| text_embs.append(te) |
| text_embs = torch.cat(text_embs, dim=0) |
|
|
| |
| rng = np.random.RandomState(seed) |
| N = all_se.shape[0] |
| ranks = [] |
| for i in range(N): |
| pool_size = min(K, N) |
| neg_candidates = [j for j in range(N) if j != i] |
| if len(neg_candidates) < pool_size - 1: |
| pool = [i] + neg_candidates |
| else: |
| neg = rng.choice(neg_candidates, size=pool_size - 1, replace=False) |
| pool = [i] + neg.tolist() |
| |
| q = all_se[i:i + 1] |
| pool_texts = text_embs[pool] |
| sims = (q @ pool_texts.t()).squeeze(0).numpy() |
| |
| order = np.argsort(-sims) |
| rank = int(np.where(order == 0)[0][0]) + 1 |
| ranks.append(rank) |
| ranks = np.array(ranks) |
| return { |
| 'N': int(N), |
| 'K': int(K), |
| 'recall@1': float((ranks <= 1).mean()), |
| 'recall@5': float((ranks <= 5).mean()), |
| 'recall@10': float((ranks <= 10).mean()), |
| 'median_rank': float(np.median(ranks)), |
| 'mean_rank': float(ranks.mean()), |
| } |
|
|
|
|
| |
| |
| |
|
|
def run_experiment(args):
    """Train and evaluate the sensor-to-text retrieval model for one config.

    Steps:
    1. Collect train/test segments and build the character vocabulary from
       the training captions.
    2. Materialize normalized datasets; the test split reuses the training
       statistics.
    3. Train with symmetric InfoNCE, evaluating every ``args.eval_every``
       epochs and at the last epoch, and keep the best-R@10 checkpoint.
    4. Re-evaluate that checkpoint with three fresh pool seeds.

    Writes ``model_best.pt`` and ``results.json`` under
    ``<args.output_dir>/<exp_name>/`` and returns the results dict.

    NOTE(review): the checkpoint is selected on the test split itself, so
    ``best_metrics`` and the final re-evaluation of that checkpoint are
    optimistically biased; a held-out validation split would remove this.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    modalities = args.modalities.split(',')
    print(f"Modalities: {modalities} | Seed: {args.seed}")

    print("Collecting train segments...")
    train_segs = collect_segments(TRAIN_VOLS)
    print("Collecting test segments...")
    test_segs = collect_segments(TEST_VOLS)

    # Vocabulary comes from training captions only; unseen test characters
    # are mapped to <unk> by tokenize().
    vocab = build_vocab(train_segs)
    print(f" Vocab size: {len(vocab)}")

    print("Building train dataset...")
    train_ds = SegmentRetrievalDataset(
        train_segs, modalities, vocab, downsample=args.downsample,
        context_pad_sec=args.context_pad_sec, max_text_len=args.max_text_len,
    )
    stats = train_ds.get_stats()
    print("Building test dataset...")
    test_ds = SegmentRetrievalDataset(
        test_segs, modalities, vocab, downsample=args.downsample,
        context_pad_sec=args.context_pad_sec, max_text_len=args.max_text_len,
        stats=stats,
    )
    # Each dataset fixes its modality widths from its own first recording,
    # while the model is sized from the training split, so check agreement.
    if len(test_ds) and test_ds.feat_dim != train_ds.feat_dim:
        raise ValueError(
            f"Feature dim mismatch: train {train_ds.feat_dim}, "
            f"test {test_ds.feat_dim}")

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=retrieval_collate, num_workers=0,
                              drop_last=True)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                             collate_fn=retrieval_collate, num_workers=0)

    model = TwoTowerRetrieval(
        train_ds.feat_dim, len(vocab),
        hidden_dim=args.hidden_dim, emb_dim=args.emb_dim,
        max_text_len=args.max_text_len, dropout=args.dropout,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {n_params:,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=1e-6,
    )

    mod_str = '-'.join(modalities)
    exp_name = f"retrieval_{mod_str}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)

    best_r10 = 0.0
    best_metrics = None
    best_state = None

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        loss = train_one_epoch(model, train_loader, optimizer, device)
        scheduler.step()
        if epoch % args.eval_every == 0 or epoch == args.epochs:
            m = evaluate_retrieval(model, test_loader, vocab, device,
                                   K=args.K, seed=args.seed)
            print(f" E{epoch:3d} | loss {loss:.4f} | R@1 {m['recall@1']:.3f} "
                  f"R@5 {m['recall@5']:.3f} R@10 {m['recall@10']:.3f} "
                  f"medR {m['median_rank']:.1f} | {time.time()-t0:.1f}s")
            # The first evaluation always records a checkpoint, so a run
            # whose R@10 never exceeds 0 still has a state to reload below.
            if best_state is None or m['recall@10'] > best_r10:
                best_r10 = m['recall@10']
                best_metrics = m
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            print(f" E{epoch:3d} | loss {loss:.4f} | {time.time()-t0:.1f}s")

    if best_state is not None:
        torch.save(best_state, os.path.join(out_dir, 'model_best.pt'))
        # Final numbers use the selected checkpoint, not the last epoch.
        model.load_state_dict(best_state)

    # Re-evaluate with pool seeds different from the one used for selection.
    final_metrics = [
        evaluate_retrieval(model, test_loader, vocab, device,
                           K=args.K, seed=1000 + s)
        for s in range(3)
    ]
    avg = {k: float(np.mean([fm[k] for fm in final_metrics]))
           for k in ['recall@1', 'recall@5', 'recall@10', 'median_rank', 'mean_rank']}
    std = {k: float(np.std([fm[k] for fm in final_metrics]))
           for k in ['recall@1', 'recall@5', 'recall@10']}

    results = {
        'experiment': exp_name,
        'modalities': modalities,
        'seed': args.seed,
        'K_pool': args.K,
        'n_train_segments': len(train_ds),
        'n_test_segments': len(test_ds),
        'vocab_size': len(vocab),
        'best_recall10': float(best_r10),
        'best_metrics': best_metrics,
        'final_avg_over_3_pool_seeds': avg,
        'final_std_over_3_pool_seeds': std,
        'args': vars(args),
    }
    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding.
    with open(os.path.join(out_dir, 'results.json'), 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Saved: {out_dir}/results.json")
    print(f"Final (avg over 3 pool seeds): R@1 {avg['recall@1']:.3f} "
          f"R@5 {avg['recall@5']:.3f} R@10 {avg['recall@10']:.3f}")
    return results
|
|
|
|
def main():
    """Parse command-line hyper-parameters and launch one retrieval experiment."""
    parser = argparse.ArgumentParser()
    # (flag, type, default) for each optional hyper-parameter, in CLI order.
    optional_args = (
        ('--modalities', str, 'mocap,emg,eyetrack,imu'),
        ('--epochs', int, 60),
        ('--batch_size', int, 64),
        ('--lr', float, 5e-4),
        ('--weight_decay', float, 1e-4),
        ('--hidden_dim', int, 128),
        ('--emb_dim', int, 128),
        ('--dropout', float, 0.2),
        ('--downsample', int, 5),
        ('--context_pad_sec', float, 1.0),
        ('--max_text_len', int, 64),
        ('--K', int, 100),
        ('--eval_every', int, 5),
        ('--seed', int, 42),
    )
    for flag, arg_type, default in optional_args:
        parser.add_argument(flag, type=arg_type, default=default)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--tag', type=str, default='')
    run_experiment(parser.parse_args())
|
|
|
|
# Script entry point: parse command-line arguments and run one experiment.
if __name__ == '__main__':
    main()
|
|