PULSE-code / experiments /tasks /train_exp_retrieval.py
velvet-pine-22's picture
Upload folder using huggingface_hub
b4b2877 verified
#!/usr/bin/env python3
"""
Experiment C: T5 Cross-modal sensor-to-text retrieval.
Per-action-segment contrastive training:
- Sensor encoder: Transformer over the multimodal sensor window covering the
annotated segment (with 1s context padding each side).
- Text encoder: small Transformer trained from scratch over character tokens
of the segment's Chinese natural-language description. We treat the
segment's four description fields {task, left_hand, right_hand,
bimanual_interaction} as four "paraphrased variants" of the same segment,
as claimed by the paper.
Loss: symmetric InfoNCE (CLIP-style).
Eval: Recall@{1, 5, 10} with K=100 distractors sampled from the test pool.
Annotations live in ${PULSE_ROOT}/annotations_v2/ (18
volunteers, 127 files, 2,409 fine-grained segments with action_label).
Subject-independent split: test = v25, v26, v27, v3 (same as T1).
"""
import os
import sys
import json
import time
import random
import argparse
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import (
DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, TEST_VOLS,
load_modality_array, SCENE_LABELS,
)
# Root directory holding one sub-directory of annotation JSON files per
# volunteer. The ``${PULSE_ROOT}`` placeholder is expanded from the
# environment; os.path.expandvars leaves it untouched when the variable is
# not set, so the value is unchanged in that case.
ANNOT_DIR = os.path.expandvars('${PULSE_ROOT}/annotations_v2')
# ---------------------------------------------------------------------------
# Annotation loading
# ---------------------------------------------------------------------------
def parse_timestamp(ts):
    """Parse an ``'MM:SS-MM:SS'`` range into ``(start_sec, end_sec)``.

    Args:
        ts: Timestamp text such as ``'01:05-02:10'``. Whitespace around the
            dash is tolerated; the pattern is anchored at the start of the
            string only, so trailing text is ignored.

    Returns:
        A ``(start_sec, end_sec)`` tuple of integer seconds, or ``None``
        when ``ts`` is not a string or does not start with a valid range.
    """
    # A JSON ``null`` (or numeric) timestamp would make re.match raise
    # TypeError; treat it like any other unparseable value.
    if not isinstance(ts, str):
        return None
    m = re.match(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)', ts)
    if m is None:
        return None
    sm, ss, em, es = map(int, m.groups())
    return sm * 60 + ss, em * 60 + es
def collect_segments(volunteers):
    """Scan the annotation files of ``volunteers`` and list their segments.

    For every ``<ANNOT_DIR>/<vol>/<scene>.json`` whose scene name is in
    ``SCENE_LABELS``, each entry of the file's ``'segments'`` list that has
    a parseable timestamp and at least one non-empty description becomes
    one output dict.

    Args:
        volunteers: Iterable of volunteer directory names.

    Returns:
        A list of dicts with keys ``vol``, ``scene``, ``t_start``,
        ``t_end`` (seconds), ``texts`` (the non-empty description fields,
        in the order task / left_hand / right_hand / bimanual_interaction)
        and ``action_label``.
    """
    # Four text views -- paper's "four paraphrased variants"
    text_keys = ('task', 'left_hand', 'right_hand', 'bimanual_interaction')
    suffix = '.json'
    out = []
    for vol in volunteers:
        vol_dir = os.path.join(ANNOT_DIR, vol)
        if not os.path.isdir(vol_dir):
            continue
        for fn in sorted(os.listdir(vol_dir)):
            if not fn.endswith(suffix):
                continue
            # Strip only the trailing extension to get the scene name.
            scene = fn[:-len(suffix)]
            if scene not in SCENE_LABELS:
                continue
            path = os.path.join(vol_dir, fn)
            try:
                # Explicit UTF-8: the captions are Chinese and must not
                # depend on the platform default encoding. The context
                # manager closes the handle deterministically.
                with open(path, encoding='utf-8') as f:
                    d = json.load(f)
            except (OSError, ValueError) as err:
                # ValueError covers JSON and Unicode decoding failures.
                # Keep the best-effort behaviour but make the skip visible.
                print(f" [warn] skipping unreadable annotation {path}: {err}")
                continue
            segments = d.get('segments', []) if isinstance(d, dict) else []
            for seg in segments:
                if not isinstance(seg, dict):
                    continue
                raw_ts = seg.get('timestamp', '')
                ts = parse_timestamp(raw_ts) if isinstance(raw_ts, str) else None
                if ts is None:
                    continue
                texts = []
                for k in text_keys:
                    value = seg.get(k, '')
                    # A null / non-string field counts as an empty view.
                    t = value.strip() if isinstance(value, str) else ''
                    if t:
                        texts.append(t)
                if not texts:
                    continue
                out.append({
                    'vol': vol,
                    'scene': scene,
                    't_start': ts[0],
                    't_end': ts[1],
                    'texts': texts,
                    'action_label': seg.get('action_label', ''),
                })
    print(f" Collected {len(out)} annotated segments from "
          f"{len(set((s['vol'], s['scene']) for s in out))} recordings")
    return out
# ---------------------------------------------------------------------------
# Vocabulary for Chinese character tokenization
# ---------------------------------------------------------------------------
# Reserved token ids: padding and out-of-vocabulary characters.
PAD = 0
UNK = 1


def build_vocab(segments, min_count=1):
    """Build a character-to-id mapping from the captions in ``segments``.

    Ids 0 and 1 are reserved for ``'<pad>'`` and ``'<unk>'``. The remaining
    characters receive consecutive ids in descending frequency order, and
    characters seen fewer than ``min_count`` times are left out.

    Args:
        segments: Iterable of dicts, each with a ``'texts'`` list of strings.
        min_count: Minimum number of occurrences for a character to be kept.

    Returns:
        A dict mapping each kept character (plus the two special tokens)
        to its integer id.
    """
    from collections import Counter

    char_counts = Counter(
        ch
        for seg in segments
        for text in seg['texts']
        for ch in text
    )
    vocab = {'<pad>': PAD, '<unk>': UNK}
    for ch, cnt in char_counts.most_common():
        if cnt < min_count:
            continue
        vocab[ch] = len(vocab)
    return vocab
def tokenize(text, vocab, max_len=64):
    """Convert ``text`` to a list of character ids, at most ``max_len`` long.

    Characters that are not keys of ``vocab`` are mapped to ``UNK``.
    """
    ids = []
    for ch in text:
        ids.append(vocab.get(ch, UNK))
    return ids[:max_len]
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class SegmentRetrievalDataset(Dataset):
    """Per-segment sensor window paired with its caption variants.

    Every input segment is cut out of its recording (with
    ``context_pad_sec`` seconds of context on each side), subsampled by
    ``downsample`` and z-normalised per feature. Items are fully
    materialised in memory at construction time.

    Args:
        segments: List of segment dicts as produced by ``collect_segments``
            (keys ``vol``, ``scene``, ``t_start``, ``t_end``, ``texts``).
        modalities: Modality names to load and concatenate feature-wise.
        vocab: Character-to-id mapping used by ``tokenize``.
        downsample: Keep every ``downsample``-th frame of the window.
        context_pad_sec: Seconds of context added before and after a segment.
        max_text_len: Maximum number of caption tokens returned per item.
        stats: Optional ``(mean, std)`` pair (e.g. from the training set's
            ``get_stats()``); when ``None`` the statistics are computed
            from this dataset's own frames.
        sample_rate: Frame rate in Hz assumed for the loaded arrays before
            downsampling (previously hard-coded to 100).

    Raises:
        ValueError: If no segment could be materialised.
    """

    def __init__(self, segments, modalities, vocab, downsample=5,
                 context_pad_sec=1.0, max_text_len=64, stats=None,
                 sample_rate=100):
        self.modalities = modalities
        self.downsample = downsample
        self.max_text_len = max_text_len
        self.vocab = vocab
        # Cache sensor data per recording to avoid re-loading
        self._sensor_cache = {}
        self._modality_dims = {}
        self.items = []
        skipped = 0
        sr = sample_rate  # Hz, before downsample
        for seg in segments:
            vol, scene = seg['vol'], seg['scene']
            arr = self._load_recording(vol, scene)
            if arr is None:
                skipped += 1
                continue
            # Compute sample window, clipped to the recording bounds
            t0 = max(0, int((seg['t_start'] - context_pad_sec) * sr))
            t1 = min(arr.shape[0], int((seg['t_end'] + context_pad_sec) * sr))
            if t1 - t0 < sr * 0.3:  # <0.3s, skip degenerate
                skipped += 1
                continue
            window = arr[t0:t1:downsample]  # downsampled sensor window
            if window.shape[0] < 4:
                skipped += 1
                continue
            self.items.append({
                # astype copies, so the item does not keep the whole
                # recording alive through a view.
                'window': window.astype(np.float32),
                'texts': seg['texts'],
                'action_label': seg.get('action_label', ''),
                'src': f"{vol}/{scene}@{seg['t_start']}-{seg['t_end']}",
            })
        print(f" Materialized {len(self.items)} segments (skipped {skipped}), "
              f"feat dim {sum(self._modality_dims.values())}")
        # The raw recordings are only needed while cutting windows; drop
        # them so they do not stay resident for the dataset's lifetime.
        self._sensor_cache.clear()
        if not self.items:
            raise ValueError(
                "No segments could be materialized; check the annotation "
                "and dataset paths and the requested modalities")
        # Normalize (using train stats if provided)
        if stats is not None:
            self.mean, self.std = stats
        else:
            # Only build the large float64 copy when statistics are needed.
            all_frames = np.concatenate(
                [it['window'] for it in self.items], axis=0).astype(np.float64)
            self.mean = all_frames.mean(axis=0, keepdims=True)
            self.std = all_frames.std(axis=0, keepdims=True)
            # Constant features would divide by ~0; leave them unscaled.
            self.std[self.std < 1e-8] = 1.0
            del all_frames
        for it in self.items:
            it['window'] = ((it['window'].astype(np.float64) - self.mean) /
                            self.std).astype(np.float32)
            it['window'] = np.nan_to_num(it['window'], nan=0.0, posinf=0.0, neginf=0.0)

    def _load_recording(self, vol, scene):
        """Load and concatenate all modalities of one recording.

        Returns the ``(T, feat_dim)`` array, or ``None`` when the recording
        directory, a modality file, or a modality array is missing. Results
        (including ``None``) are memoised in ``self._sensor_cache``.
        """
        key = (vol, scene)
        if key in self._sensor_cache:
            return self._sensor_cache[key]
        scenario_dir = os.path.join(DATASET_DIR, vol, scene)
        if not os.path.isdir(scenario_dir):
            self._sensor_cache[key] = None
            return None
        parts = []
        for mod in self.modalities:
            # mocap uses a per-recording file name; the other modalities
            # are looked up in MODALITY_FILES.
            if mod == 'mocap':
                fp = os.path.join(scenario_dir, f"aligned_{vol}{scene}_s_Q.tsv")
            else:
                fp = os.path.join(scenario_dir, MODALITY_FILES[mod])
            if not os.path.exists(fp):
                self._sensor_cache[key] = None
                return None
            arr = load_modality_array(fp, mod)
            if arr is None:
                self._sensor_cache[key] = None
                return None
            # Force every recording to the feature width first seen for
            # this modality: zero-pad narrower arrays, truncate wider ones.
            if mod in self._modality_dims and arr.shape[1] != self._modality_dims[mod]:
                expected = self._modality_dims[mod]
                if arr.shape[1] < expected:
                    pad = np.zeros((arr.shape[0], expected - arr.shape[1]),
                                   dtype=np.float32)
                    arr = np.concatenate([arr, pad], axis=1)
                else:
                    arr = arr[:, :expected]
            if mod not in self._modality_dims:
                self._modality_dims[mod] = arr.shape[1]
            parts.append(arr)
        # Modalities may differ in length; align on the shortest one.
        T_min = min(p.shape[0] for p in parts)
        combined = np.concatenate([p[:T_min] for p in parts], axis=1)
        self._sensor_cache[key] = combined
        return combined

    @property
    def feat_dim(self):
        """Total feature width across all loaded modalities."""
        return sum(self._modality_dims.values())

    def get_stats(self):
        """Return the ``(mean, std)`` pair used for normalisation."""
        return (self.mean, self.std)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return one sample with a randomly chosen caption variant."""
        it = self.items[idx]
        # Randomly pick one of the available captions on every access
        text = random.choice(it['texts'])
        tok = tokenize(text, self.vocab, max_len=self.max_text_len)
        return {
            'window': torch.from_numpy(it['window']),
            'text_ids': torch.LongTensor(tok),
            'all_texts': it['texts'],
            'src': it['src'],
        }
def retrieval_collate(batch):
    """Collate dataset samples into right-padded tensors with validity masks.

    Sensor windows are padded with 0.0 and token id sequences with ``PAD``.
    Each mask is a boolean tensor that is True on real (non-padded)
    positions. Sample sources and caption lists are passed through as
    plain Python lists.
    """
    def _pad_with_mask(seqs, pad_value):
        lengths = torch.LongTensor([s.shape[0] for s in seqs])
        padded = pad_sequence(seqs, batch_first=True, padding_value=pad_value)
        positions = torch.arange(padded.shape[1]).unsqueeze(0)
        return padded, positions < lengths.unsqueeze(1)

    sensor_seqs = [sample['window'] for sample in batch]
    padded_w, w_mask = _pad_with_mask(sensor_seqs, 0.0)
    token_seqs = [sample['text_ids'] for sample in batch]
    padded_t, t_mask = _pad_with_mask(token_seqs, PAD)
    return {
        'window': padded_w,
        'window_mask': w_mask,
        'text_ids': padded_t,
        'text_mask': t_mask,
        'srcs': [sample['src'] for sample in batch],
        'all_texts': [sample['all_texts'] for sample in batch],
    }
# ---------------------------------------------------------------------------
# Model: two-tower retrieval
# ---------------------------------------------------------------------------
class SensorEncoder(nn.Module):
    """Transformer encoder mapping a sensor window to a unit-norm embedding.

    Args:
        feat_dim: Number of input features per frame.
        hidden_dim: Transformer model width.
        n_layers: Number of encoder layers.
        n_heads: Number of attention heads.
        dropout: Dropout rate inside the encoder layers.
        emb_dim: Size of the output embedding.
        max_len: Number of learned positional slots (previously hard-coded
            to 2048). Longer inputs are truncated to this many frames.
    """

    def __init__(self, feat_dim, hidden_dim=128, n_layers=2, n_heads=4,
                 dropout=0.2, emb_dim=128, max_len=2048):
        super().__init__()
        self.input_proj = nn.Linear(feat_dim, hidden_dim)
        self.pos_enc = nn.Parameter(torch.zeros(1, max_len, hidden_dim))
        nn.init.trunc_normal_(self.pos_enc, std=0.02)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads,
            dim_feedforward=4 * hidden_dim, dropout=dropout,
            batch_first=True, activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x, mask):
        """Encode a padded batch of windows.

        Args:
            x: Float tensor indexed as (batch, time, feature).
            mask: Boolean tensor (batch, time), True on valid frames.

        Returns:
            L2-normalised embeddings of shape (batch, emb_dim).
        """
        # Windows longer than the positional table would otherwise fail
        # with an opaque broadcast error; keep the leading frames instead.
        max_len = self.pos_enc.size(1)
        if x.size(1) > max_len:
            x = x[:, :max_len]
            mask = mask[:, :max_len]
        T = x.size(1)
        h = self.input_proj(x) + self.pos_enc[:, :T, :]
        # The encoder expects True on positions to be ignored.
        key_padding = ~mask
        h = self.encoder(h, src_key_padding_mask=key_padding)
        # Masked mean pool over valid frames only
        m = mask.unsqueeze(-1).float()
        pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        return F.normalize(self.proj(pooled), dim=-1)
class TextEncoder(nn.Module):
    """Character-level Transformer producing a unit-norm caption embedding.

    Token ids are embedded (``PAD`` is the padding index), summed with a
    learned positional table of ``max_len`` slots, run through a
    Transformer encoder, mean-pooled over valid tokens and projected to
    ``emb_dim``.
    """

    def __init__(self, vocab_size, hidden_dim=128, n_layers=2, n_heads=4,
                 dropout=0.2, emb_dim=128, max_len=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=PAD)
        positions = torch.zeros(1, max_len, hidden_dim)
        self.pos_enc = nn.Parameter(positions)
        nn.init.trunc_normal_(self.pos_enc, std=0.02)
        block = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            dim_feedforward=4 * hidden_dim,
            dropout=dropout,
            batch_first=True,
            activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(block, num_layers=n_layers)
        self.proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, ids, mask):
        """Embed padded token ids; ``mask`` is True on real tokens."""
        n_tokens = ids.size(1)
        hidden = self.embed(ids) + self.pos_enc[:, :n_tokens, :]
        # The encoder wants True on the positions it should ignore.
        hidden = self.encoder(hidden, src_key_padding_mask=~mask)
        # Average over valid tokens only; the clamp guards empty rows.
        weights = mask.unsqueeze(-1).float()
        summed = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1.0)
        return F.normalize(self.proj(summed / counts), dim=-1)
class TwoTowerRetrieval(nn.Module):
    """Two-tower model: a sensor encoder and a text encoder plus a learned
    log-temperature (``logit_scale``, initialised to ``log(1 / 0.07)``)."""

    def __init__(self, feat_dim, vocab_size, hidden_dim=128, emb_dim=128,
                 max_text_len=64, dropout=0.2):
        super().__init__()
        self.sensor = SensorEncoder(
            feat_dim, hidden_dim, emb_dim=emb_dim, dropout=dropout,
        )
        self.text = TextEncoder(
            vocab_size, hidden_dim, emb_dim=emb_dim,
            max_len=max_text_len, dropout=dropout,
        )
        initial_log_scale = np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones(1) * initial_log_scale)

    def forward(self, batch):
        """Return ``(sensor_embeddings, text_embeddings)`` for a collated batch."""
        sensor_emb = self.sensor(batch['window'], batch['window_mask'])
        text_emb = self.text(batch['text_ids'], batch['text_mask'])
        return sensor_emb, text_emb
# ---------------------------------------------------------------------------
# Loss
# ---------------------------------------------------------------------------
def info_nce(se, te, logit_scale):
    """Symmetric InfoNCE loss over a batch of paired embeddings.

    Row ``i`` of ``se`` is the positive match of row ``i`` of ``te``. The
    similarity matrix is scaled by ``exp(logit_scale)`` capped at 100, and
    the loss is the mean of the sensor-to-text and text-to-sensor
    cross-entropies.
    """
    temperature = torch.clamp(torch.exp(logit_scale), max=100.0)
    similarity = torch.matmul(temperature * se, te.t())  # (B, B)
    labels = torch.arange(similarity.size(0), device=similarity.device)
    sensor_to_text = F.cross_entropy(similarity, labels)
    text_to_sensor = F.cross_entropy(similarity.t(), labels)
    return 0.5 * (sensor_to_text + text_to_sensor)
# ---------------------------------------------------------------------------
# Training / Eval
# ---------------------------------------------------------------------------
def set_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Each step moves the batch tensors to ``device``, computes the InfoNCE
    loss, clips the gradient norm to 1.0 and updates the parameters.

    Returns:
        The sample-weighted mean loss over the epoch (0.0 if no batches).
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for raw in loader:
        batch = {}
        for key, value in raw.items():
            batch[key] = value.to(device) if torch.is_tensor(value) else value
        optimizer.zero_grad()
        se, te = model(batch)
        loss = info_nce(se, te, model.logit_scale)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        batch_size = se.size(0)
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
    return loss_sum / max(n_seen, 1)
@torch.no_grad()
def evaluate_retrieval(model, loader, vocab, device, K=100, seed=0,
                       max_text_len=None):
    """Sensor -> text retrieval evaluation.

    For each sensor query a candidate pool is built from its own gold
    caption plus ``min(K, N) - 1`` distractor captions sampled without
    replacement from the other evaluated segments; the rank of the gold
    caption by dot-product similarity is recorded.

    Args:
        model: Two-tower model exposing ``sensor`` and ``text`` encoders.
        loader: DataLoader yielding batches from ``retrieval_collate``.
        vocab: Character-to-id mapping used to tokenise the captions.
        device: Device on which the encoders are run.
        K: Target pool size (gold + distractors).
        seed: Seed for the distractor sampling.
        max_text_len: Maximum caption length in tokens. When ``None`` it
            is taken from the text encoder's positional table so that
            evaluation matches the length the model was built with,
            falling back to 64 if the encoder has no ``pos_enc``.

    Returns:
        Dict with ``N``, ``K``, ``recall@1``, ``recall@5``, ``recall@10``,
        ``median_rank`` and ``mean_rank``.
    """
    model.eval()
    if max_text_len is None:
        # A hard-coded 64 would overflow the positional table of a model
        # built with a smaller max length, and would truncate differently
        # from training for a larger one.
        pos_enc = getattr(model.text, 'pos_enc', None)
        max_text_len = int(pos_enc.shape[1]) if pos_enc is not None else 64
    # Collect all embeddings
    all_se = []
    all_texts = []
    srcs = []
    for batch in loader:
        dev_batch = {k: v.to(device) if torch.is_tensor(v) else v
                     for k, v in batch.items()}
        se = model.sensor(dev_batch['window'], dev_batch['window_mask'])
        all_se.append(se.cpu())
        # For eval, use the first caption ("task") as the gold text
        for texts in batch['all_texts']:
            all_texts.append(texts[0])
        srcs.extend(batch['srcs'])
    all_se = torch.cat(all_se, dim=0)  # (N, D)
    # Encode all candidate texts once
    text_embs = []
    for i in range(0, len(all_texts), 64):
        chunk = all_texts[i:i + 64]
        tok_lists = [tokenize(t, vocab, max_len=max_text_len) for t in chunk]
        lens = [len(t) for t in tok_lists]
        max_len = max(lens)
        pad_ids = torch.zeros(len(chunk), max_len, dtype=torch.long)
        mask = torch.zeros(len(chunk), max_len, dtype=torch.bool)
        for j, t in enumerate(tok_lists):
            pad_ids[j, :len(t)] = torch.LongTensor(t)
            mask[j, :len(t)] = True
        pad_ids = pad_ids.to(device)
        mask = mask.to(device)
        te = model.text(pad_ids, mask).cpu()
        text_embs.append(te)
    text_embs = torch.cat(text_embs, dim=0)  # (N, D)
    # For each sensor query i, sample distractors from {0..N-1}\{i}
    rng = np.random.RandomState(seed)
    N = all_se.shape[0]
    # There are always N - 1 >= pool_size - 1 other segments to draw from,
    # so no special case for a short candidate list is needed.
    pool_size = min(K, N)
    all_idx = np.arange(N)
    ranks = []
    for i in range(N):
        neg_candidates = np.delete(all_idx, i)
        neg = rng.choice(neg_candidates, size=pool_size - 1, replace=False)
        pool = [i] + neg.tolist()
        # Compute similarity of query i with pool texts
        q = all_se[i:i + 1]  # (1, D)
        pool_texts = text_embs[pool]  # (pool_size, D)
        sims = (q @ pool_texts.t()).squeeze(0).numpy()  # (pool_size,)
        # rank of pool[0] (the correct one)
        # NOTE(review): exactly tied similarities (e.g. distractors with an
        # identical caption) are ordered by argsort's default sort.
        order = np.argsort(-sims)
        rank = int(np.where(order == 0)[0][0]) + 1
        ranks.append(rank)
    ranks = np.array(ranks)
    return {
        'N': int(N),
        'K': int(K),
        'recall@1': float((ranks <= 1).mean()),
        'recall@5': float((ranks <= 5).mean()),
        'recall@10': float((ranks <= 10).mean()),
        'median_rank': float(np.median(ranks)),
        'mean_rank': float(ranks.mean()),
    }
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def run_experiment(args):
    """Train the two-tower retrieval model and evaluate it on the test split.

    Builds the train/test datasets (test normalised with train statistics),
    trains with Adam and cosine LR annealing, evaluates every
    ``args.eval_every`` epochs and on the last epoch, keeps the weights
    with the best recall@10, then re-evaluates those weights with three
    distractor-pool seeds and writes ``results.json`` (and
    ``model_best.pt`` when a best state exists) under
    ``<args.output_dir>/<experiment name>``.

    Args:
        args: Parsed command-line namespace as built in ``main``.

    Returns:
        The results dict that was written to ``results.json``.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    modalities = args.modalities.split(',')
    print(f"Modalities: {modalities} | Seed: {args.seed}")
    print("Collecting train segments...")
    train_segs = collect_segments(TRAIN_VOLS)
    print("Collecting test segments...")
    test_segs = collect_segments(TEST_VOLS)
    # Build char vocab from train only
    vocab = build_vocab(train_segs)
    print(f" Vocab size: {len(vocab)}")
    print("Building train dataset...")
    train_ds = SegmentRetrievalDataset(
        train_segs, modalities, vocab, downsample=args.downsample,
        context_pad_sec=args.context_pad_sec, max_text_len=args.max_text_len,
    )
    stats = train_ds.get_stats()
    print("Building test dataset...")
    test_ds = SegmentRetrievalDataset(
        test_segs, modalities, vocab, downsample=args.downsample,
        context_pad_sec=args.context_pad_sec, max_text_len=args.max_text_len,
        stats=stats,
    )
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=retrieval_collate, num_workers=0,
                              drop_last=True)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                             collate_fn=retrieval_collate, num_workers=0)
    model = TwoTowerRetrieval(
        train_ds.feat_dim, len(vocab),
        hidden_dim=args.hidden_dim, emb_dim=args.emb_dim,
        max_text_len=args.max_text_len, dropout=args.dropout,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {n_params:,}")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=1e-6,
    )
    mod_str = '-'.join(modalities)
    exp_name = f"retrieval_{mod_str}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    best_r10 = 0.0
    best_metrics = None
    best_state = None
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        loss = train_one_epoch(model, train_loader, optimizer, device)
        scheduler.step()
        if epoch % args.eval_every == 0 or epoch == args.epochs:
            m = evaluate_retrieval(model, test_loader, vocab, device,
                                   K=args.K, seed=args.seed)
            print(f" E{epoch:3d} | loss {loss:.4f} | R@1 {m['recall@1']:.3f} "
                  f"R@5 {m['recall@5']:.3f} R@10 {m['recall@10']:.3f} "
                  f"medR {m['median_rank']:.1f} | {time.time()-t0:.1f}s")
            # Always keep the first evaluated state: with a strict ">"
            # alone, a run whose recall@10 never exceeds 0.0 would end
            # with no checkpoint and crash when reloading it below.
            if best_state is None or m['recall@10'] > best_r10:
                best_r10 = m['recall@10']
                best_metrics = m
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            print(f" E{epoch:3d} | loss {loss:.4f} | {time.time()-t0:.1f}s")
    if best_state is not None:
        torch.save(best_state, os.path.join(out_dir, 'model_best.pt'))
        model.load_state_dict(best_state)
    # Final eval with multiple distractor pool seeds for robustness. When
    # no evaluation ran (e.g. zero epochs) the current weights are used.
    final_metrics = []
    for s in range(3):
        m = evaluate_retrieval(model, test_loader, vocab, device,
                               K=args.K, seed=1000 + s)
        final_metrics.append(m)
    avg = {k: float(np.mean([fm[k] for fm in final_metrics]))
           for k in ['recall@1', 'recall@5', 'recall@10', 'median_rank', 'mean_rank']}
    std = {k: float(np.std([fm[k] for fm in final_metrics]))
           for k in ['recall@1', 'recall@5', 'recall@10']}
    results = {
        'experiment': exp_name,
        'modalities': modalities,
        'seed': args.seed,
        'K_pool': args.K,
        'n_train_segments': len(train_ds),
        'n_test_segments': len(test_ds),
        'vocab_size': len(vocab),
        'best_recall10': float(best_r10),
        'best_metrics': best_metrics,
        'final_avg_over_3_pool_seeds': avg,
        'final_std_over_3_pool_seeds': std,
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Saved: {out_dir}/results.json")
    print(f"Final (avg over 3 pool seeds): R@1 {avg['recall@1']:.3f} "
          f"R@5 {avg['recall@5']:.3f} R@10 {avg['recall@10']:.3f}")
    return results
def main():
    """Parse the command-line options and launch the experiment."""
    parser = argparse.ArgumentParser()
    # (flag, type, default) for every optional hyper-parameter, in the
    # order they should appear in ``--help``.
    optional_flags = [
        ('--modalities', str, 'mocap,emg,eyetrack,imu'),
        ('--epochs', int, 60),
        ('--batch_size', int, 64),
        ('--lr', float, 5e-4),
        ('--weight_decay', float, 1e-4),
        ('--hidden_dim', int, 128),
        ('--emb_dim', int, 128),
        ('--dropout', float, 0.2),
        ('--downsample', int, 5),
        ('--context_pad_sec', float, 1.0),
        ('--max_text_len', int, 64),
        ('--K', int, 100),
        ('--eval_every', int, 5),
        ('--seed', int, 42),
    ]
    for flag, flag_type, default in optional_flags:
        parser.add_argument(flag, type=flag_type, default=default)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--tag', type=str, default='')
    run_experiment(parser.parse_args())


if __name__ == '__main__':
    main()