| import argparse |
| import json |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoModel, AutoTokenizer, AutoConfig |
|
|
| from birwkv7 import BiRWKV7Layer, init_from_attention |
|
|
|
|
| def _find_encoder(model): |
| for attr in ['encoder', 'model']: |
| if hasattr(model, attr): |
| candidate = getattr(model, attr) |
| if hasattr(candidate, 'layers'): |
| return candidate |
| if hasattr(model, 'layers'): |
| return model |
| raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}") |
|
|
|
|
def find_attention_layers(model):
    """Enumerate attention submodules of the encoder's layer stack.

    Returns a list of ``(layer_index, attn_path, attn_module, is_global)``
    tuples; layers without a recognizable attention attribute are skipped.
    """
    encoder = _find_encoder(model)
    found = []

    for idx, layer in enumerate(encoder.layers):
        # Probe attribute names used by common HF architectures, first hit wins.
        attn, path = None, None
        for candidate in ('attn', 'attention', 'self_attn', 'self_attention'):
            if hasattr(layer, candidate):
                attn = getattr(layer, candidate)
                path = f"layers.{idx}.{candidate}"
                break

        if attn is None:
            continue

        # Classify global vs. local attention using whichever flag this
        # module exposes; the final fallback assumes an every-third-layer
        # global pattern (ModernBERT-style) — TODO confirm for other models.
        if hasattr(attn, 'local_attention'):
            is_global = not attn.local_attention
        elif hasattr(attn, 'is_global_attention'):
            is_global = attn.is_global_attention
        elif hasattr(attn, 'use_sliding_window'):
            is_global = not attn.use_sliding_window
        elif hasattr(attn, 'sliding_window'):
            is_global = attn.sliding_window is None
        else:
            is_global = idx % 3 == 2

        found.append((idx, path, attn, is_global))

    return found
|
|
|
|
def perform_surgery(model, variant, hidden_size, num_heads, replaced_layers=None):
    """Replace selected attention layers with fresh BiRWKV-7 layers, in place.

    Args:
        model: HuggingFace-style encoder model, modified in place.
        variant: 'conservative' (replace local layers only), 'aggressive'
            (keep first/last global layers), or 'pure' (replace everything).
            Selection is overridden when ``replaced_layers`` is given.
        hidden_size: model hidden dimension, forwarded to ``BiRWKV7Layer``.
        num_heads: attention head count, forwarded to ``BiRWKV7Layer``.
        replaced_layers: optional mapping whose keys are the layer indices
            to replace (e.g. a previously saved surgery report).

    Returns:
        dict mapping replaced layer index -> {'was_global', 'transferred'}.
    """
    layers = find_attention_layers(model)
    global_indices = [i for i, _, _, g in layers if g]
    local_indices = [i for i, _, _, g in layers if not g]

    print(f"\nFound {len(layers)} attention layers:")
    print(f" Global: {global_indices}")
    print(f" Local: {local_indices}")

    all_indices = {i for i, _, _, _ in layers}
    if replaced_layers is not None:
        replace_indices = {int(k) for k in replaced_layers}
    elif variant == 'conservative':
        replace_indices = set(local_indices)
    elif variant == 'aggressive':
        # Preserve only the outermost global layers.
        keep = {global_indices[0], global_indices[-1]} if global_indices else set()
        replace_indices = all_indices - keep
    elif variant == 'pure':
        replace_indices = all_indices
    else:
        raise ValueError(f"Unknown variant: {variant}")

    print(f"\nVariant '{variant}': replacing {len(replace_indices)} of {len(layers)} layers")

    encoder = _find_encoder(model)
    report = {}

    for layer_idx, attn_path, attn_module, is_global in layers:
        kind = 'global' if is_global else 'local'
        if layer_idx not in replace_indices:
            print(f" Layer {layer_idx}: KEEP ({kind})")
            continue

        # Build the replacement and copy over whatever weights transfer.
        birwkv = BiRWKV7Layer(hidden_size, num_heads)
        transferred = init_from_attention(birwkv, attn_module)

        # Match the original module's placement and precision.
        ref_param = next(attn_module.parameters())
        birwkv = birwkv.to(device=ref_param.device, dtype=ref_param.dtype)

        # Swap the module on the parent layer under its original name.
        attn_name = attn_path.rsplit('.', 1)[-1]
        setattr(encoder.layers[layer_idx], attn_name, birwkv)

        report[layer_idx] = {'was_global': is_global, 'transferred': transferred}
        print(f" Layer {layer_idx}: REPLACED ({kind}) "
              f"-> BiRWKV-7 [{', '.join(transferred)}]")

    return report
|
|
|
|
| def mean_pool(hidden_states, attention_mask): |
| mask = attention_mask.unsqueeze(-1).float() |
| return (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9) |
|
|
|
|
class HareWrapper(torch.nn.Module):
    """Thin wrapper exposing a sentence-embedding ``encode`` API.

    Wraps an encoder model and its tokenizer; embeddings are mean-pooled
    over real tokens and L2-normalized on the unit sphere.
    """

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        # Mirror the underlying config so downstream code can introspect it.
        self.config = model.config

    def encode(self, texts, batch_size=32, max_length=512, show_progress=False):
        """Embed ``texts`` into a (len(texts), hidden) CPU tensor.

        Args:
            texts: sequence of strings to embed.
            batch_size: number of texts tokenized and forwarded per step.
            max_length: tokenizer truncation length.
            show_progress: wrap the batch loop in a tqdm progress bar.

        Returns:
            L2-normalized mean-pooled embeddings, concatenated on dim 0.
            NOTE(review): an empty ``texts`` raises from ``torch.cat`` —
            callers are expected to pass at least one text.
        """
        all_embs = []
        iterator = range(0, len(texts), batch_size)
        if show_progress:
            from tqdm import tqdm
            iterator = tqdm(iterator, desc="Encoding")

        # The model's device is invariant across batches — hoisted out of
        # the loop instead of re-querying parameters() every iteration.
        device = next(self.model.parameters()).device

        for i in iterator:
            batch = texts[i:i + batch_size]
            enc = self.tokenizer(batch, padding=True, truncation=True,
                                 max_length=max_length, return_tensors='pt')
            enc = {k: v.to(device) for k, v in enc.items()}

            with torch.no_grad():
                hidden = self.model(**enc).last_hidden_state
                emb = mean_pool(hidden, enc['attention_mask'])
                all_embs.append(F.normalize(emb, p=2, dim=-1).cpu())

        return torch.cat(all_embs, dim=0)

    def forward(self, **kwargs):
        """Delegate directly to the wrapped model."""
        return self.model(**kwargs)
|
|
|
|
def main():
    """CLI entry point: load a base model, perform surgery, save artifacts."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model', default='answerdotai/ModernBERT-base')
    parser.add_argument('--variant', choices=['conservative', 'aggressive', 'pure'],
                        default='conservative')
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--inspect_only', action='store_true')
    args = parser.parse_args()

    # Fail fast: validate arguments BEFORE the slow (possibly network-bound)
    # model download/load. Previously this check ran only after loading.
    if not args.inspect_only and not args.output:
        parser.error("--output required for surgery (omit for --inspect_only)")

    print(f"Loading {args.base_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModel.from_pretrained(args.base_model, trust_remote_code=True)
    config = model.config
    hidden_size = config.hidden_size
    num_heads = config.num_attention_heads
    print(f" hidden_size={hidden_size}, num_heads={num_heads}, head_size={hidden_size // num_heads}")

    if args.inspect_only:
        # Report the attention layout and parameter counts, then exit.
        layers = find_attention_layers(model)
        print(f"\n{len(layers)} attention layers:")
        for idx, path, attn, is_g in layers:
            n = sum(p.numel() for p in attn.parameters())
            print(f" Layer {idx} ({'GLOBAL' if is_g else 'local'}): {type(attn).__name__} ({n:,}) @ {path}")
        return

    report = perform_surgery(model, args.variant, hidden_size, num_heads)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nPost-surgery: {total_params:,} params")

    # Smoke test: one forward pass to catch shape/device errors early.
    print("Sanity check :)")
    inputs = tokenizer("Hello world", return_tensors='pt')
    inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model(**inputs)
    print(f" Output: {out.last_hidden_state.shape}, norm={out.last_hidden_state.norm().item():.4f}")

    # Persist weights, tokenizer, config, and a machine-readable report.
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / 'model.pt')
    tokenizer.save_pretrained(output_dir)
    config.save_pretrained(output_dir)

    meta = {
        'base_model': args.base_model,
        'variant': args.variant,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'replaced_layers': {str(k): v for k, v in report.items()},
        'total_params': total_params,
    }
    with open(output_dir / 'surgery_meta.json', 'w') as f:
        json.dump(meta, f, indent=2)

    print(f"\nSaved to {output_dir}/ ({total_params:,} params)")
|
|
|
|
# Run the surgery CLI only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|