|
|
| """
|
| Generation script for Circuit Transformer.
|
|
|
| Usage:
|
python -m circuits.generate --checkpoint circuits/checkpoints/latest.pt --prompt "Once upon a time"
|
| """
|
|
|
| import argparse
|
|
|
| import torch
|
| import torch.nn as nn
|
|
|
| from transformers import AutoTokenizer
|
|
|
| from .config import CircuitConfig
|
| from .model import CircuitTransformer
|
| from .mirrored import MirroredConfig, MirroredTransformer
|
| from .graft_g2lu import load_g2lu_model
|
| from .layers import build_word_start_table
|
| from .data import get_tokenizer
|
|
|
|
|
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for text generation.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour; passing a list lets
            tests or other scripts drive the CLI without patching ``sys.argv``.

    Returns:
        Namespace with ``checkpoint``, ``prompt``, ``max_tokens``, ``temperature``,
        ``top_k``, ``top_p``, ``repetition_penalty``, ``gpu`` and ``no_cache``.
    """
    parser = argparse.ArgumentParser(description="Generate text with Circuit Transformer")

    parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint")
    parser.add_argument("--prompt", type=str, default="", help="Prompt text")
    parser.add_argument("--max-tokens", type=int, default=100, help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    parser.add_argument("--top-k", type=int, default=50, help="Top-k filtering")
    parser.add_argument("--top-p", type=float, default=0.9, help="Nucleus sampling threshold")
    parser.add_argument("--repetition-penalty", type=float, default=1.0, help="Repetition penalty (1.0=off, 1.3=default for slot models)")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    parser.add_argument("--no-cache", action="store_true", help="Disable KV cache")

    return parser.parse_args(argv)
|
|
|
| def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
|
| """Migrate checkpoint state_dict to match current model architecture.
|
|
|
| Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
|
| """
|
| if any(k.startswith("_orig_mod.") for k in state_dict):
|
| state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
|
|
|
| model_keys = set(model.state_dict().keys())
|
| ckpt_keys = set(state_dict.keys())
|
|
|
| missing = model_keys - ckpt_keys
|
| unexpected = ckpt_keys - model_keys
|
|
|
| print(unexpected)
|
|
|
| if not missing and not unexpected:
|
| return state_dict
|
|
|
| migrated = dict(state_dict)
|
| migrations = []
|
|
|
|
|
| for key in list(unexpected):
|
| if ".ffn.gate_expand.weight" in key:
|
| new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
|
| if new_key in missing:
|
| migrated[new_key] = migrated.pop(key)
|
| missing.discard(new_key)
|
| unexpected.discard(key)
|
| migrations.append(f" {key} → {new_key}")
|
| if ".ffn.gate_compress.weight" in key:
|
| new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
|
| if new_key in missing:
|
| migrated[new_key] = migrated.pop(key)
|
| missing.discard(new_key)
|
| unexpected.discard(key)
|
| migrations.append(f" {key} → {new_key}")
|
|
|
| if migrations:
|
| print(f"State dict migration ({len(migrations)} keys renamed):")
|
| for m in migrations:
|
| print(m)
|
|
|
| still_missing = model_keys - set(migrated.keys())
|
| if still_missing:
|
| print(f" New parameters (freshly initialized): {len(still_missing)}")
|
| for k in sorted(still_missing):
|
| print(f" {k}")
|
|
|
| return migrated
|
|
|
def _load_model_and_tokenizer(checkpoint: dict, model_type: str, checkpoint_path: str, device: torch.device):
    """Instantiate the model and tokenizer described by a loaded checkpoint.

    Args:
        checkpoint: Dict returned by ``torch.load`` on *checkpoint_path*.
        model_type: ``checkpoint["model_type"]`` (``"standard"`` when absent).
        checkpoint_path: Path of the checkpoint file. The G²LU-graft and folded
            loaders re-read the file themselves.
        device: Device the model is placed on.

    Returns:
        ``(model, tokenizer)``, with the model already switched to eval mode.
    """
    if model_type == "graft_g2lu":
        model = load_g2lu_model(checkpoint_path, device=device)
        model.eval()
        pretrained_name = checkpoint.get("pretrained_name", "unknown")
        print(f"Architecture: G²LU Graft ({pretrained_name}, {len(model.g2lu_mlps)}L)")
        tokenizer = get_tokenizer(checkpoint.get("tokenizer_name", pretrained_name))
        return model, tokenizer

    if model_type == "folded":
        # Imported here so the `grafting` package is only required for folded checkpoints.
        from grafting.fold_llama import FoldedLlama

        model = FoldedLlama.load_from_checkpoint(checkpoint_path, device=device)
        model.eval()
        fold_cfg = model.config
        print(f"Architecture: FoldedLlama ({fold_cfg.model_name}, "
              f"{fold_cfg.n_expand}E+{fold_cfg.n_middle}M+{fold_cfg.n_compress}C)")
        tokenizer = AutoTokenizer.from_pretrained(fold_cfg.model_name, trust_remote_code=True)
        return model, tokenizer

    # Native CircuitTransformer / MirroredTransformer checkpoints: rebuild from config.
    config_dict = checkpoint["config"]
    if model_type == "mirrored":
        # NOTE(review): a truthy `dual_gate_middle` flag is removed (in place)
        # before building the config. Presumably the model then uses
        # MirroredConfig's default, and the legacy gate weights are renamed by
        # _migrate_state_dict. Confirm this is still intended.
        if config_dict.get("dual_gate_middle"):
            config_dict.pop("dual_gate_middle")
        config = MirroredConfig.from_dict(config_dict)
        model = MirroredTransformer(config).to(device)
        print(f"Architecture: MirroredTransformer ({model.total_virtual_layers} virtual layers)")
    else:
        config = CircuitConfig.from_dict(config_dict)
        model = CircuitTransformer(config).to(device)
        print(f"Architecture: CircuitTransformer ({config.num_layers} layers)")

    state_dict = _migrate_state_dict(checkpoint["model"], model)
    # Strict load: any key _migrate_state_dict could not reconcile raises here.
    model.load_state_dict(state_dict)
    model.eval()
    tokenizer = get_tokenizer(checkpoint.get("tokenizer_name", "gpt2"))
    return model, tokenizer


def _encode_prompt(tokenizer, prompt: str, device: torch.device) -> torch.Tensor:
    """Encode *prompt* into a batch-of-one id tensor on *device*.

    An empty prompt is seeded with the tokenizer's EOS token, falling back to
    BOS, so the model has at least one token to condition on.

    Raises:
        ValueError: if the prompt is empty and the tokenizer defines neither an
            EOS nor a BOS token.
    """
    if prompt:
        return tokenizer.encode(prompt, return_tensors="pt").to(device)

    seed_id = tokenizer.eos_token_id
    if seed_id is None:
        seed_id = tokenizer.bos_token_id
    if seed_id is None:
        raise ValueError(
            "Empty prompt: tokenizer has no EOS or BOS token to seed generation; pass --prompt"
        )
    return torch.tensor([[seed_id]], device=device)


def _build_generation_kwargs(args: argparse.Namespace, model_type: str, word_start_table) -> dict:
    """Translate CLI sampling options into keyword arguments for ``model.generate``.

    For ``graft_g2lu`` models, ``generate`` is assumed to be Hugging Face style,
    which needs an explicit ``do_sample`` flag. There, a temperature <= 0 means
    greedy decoding and the sampling-only options are dropped. A non-positive
    temperature combined with ``do_sample=True`` is rejected by Hugging Face's
    temperature warper.
    """
    gen_kwargs = dict(
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        use_cache=not args.no_cache,
    )
    # Only forward the penalty when enabled (1.0 means "off").
    if args.repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = args.repetition_penalty

    if model_type == "graft_g2lu":
        if args.temperature <= 0:
            gen_kwargs["do_sample"] = False
            for name in ("temperature", "top_k", "top_p"):
                del gen_kwargs[name]
        elif args.temperature != 1.0 or args.top_p < 1.0 or args.top_k > 0:
            gen_kwargs["do_sample"] = True

    if word_start_table is not None:
        gen_kwargs["word_start_table"] = word_start_table
    return gen_kwargs


def generate():
    """CLI entry point: load a checkpoint, generate a continuation, and print it.

    Reads options via ``parse_args``, builds the model and tokenizer for the
    checkpoint's ``model_type``, runs ``model.generate`` under ``torch.no_grad``
    and prints the decoded text plus the total token count.
    """
    args = parse_args()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    print(f"Loading checkpoint: {args.checkpoint}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    model_type = checkpoint.get("model_type", "standard")

    model, tokenizer = _load_model_and_tokenizer(checkpoint, model_type, args.checkpoint, device)

    # Word-position RoPE needs a per-token word-start lookup table; only native
    # checkpoints carry `word_rope_dims` in their config.
    word_start_table_device = None
    if model_type not in ("graft_g2lu", "folded"):
        word_rope_dims = checkpoint.get("config", {}).get("word_rope_dims") or 0
        if word_rope_dims > 0:
            word_start_table_device = build_word_start_table(tokenizer, len(tokenizer)).to(device)
            print(f"Word-position RoPE: {word_rope_dims} dims")

    prompt_ids = _encode_prompt(tokenizer, args.prompt, device)

    print(f"\nPrompt: {args.prompt or '<empty>'}")
    print(f"Prompt tokens: {prompt_ids.shape[1]}")
    print(f"Generating {args.max_tokens} tokens...")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Top-p: {args.top_p}")
    print("-" * 50)

    gen_kwargs = _build_generation_kwargs(args, model_type, word_start_table_device)
    with torch.no_grad():
        output_ids = model.generate(prompt_ids, **gen_kwargs)

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(generated_text)
    print("-" * 50)
    print(f"Total tokens: {output_ids.shape[1]}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. The package-relative imports at the top of the file
    # mean this must be run as a module (`python -m <package>.generate`), not by path.
    generate()
|
|
|