import argparse
import os

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from chat import LUNAModel, format_prompt, generate


def parse_args():
    parser = argparse.ArgumentParser(description="Chat with the rag_mcp_full_sft checkpoint")
    parser.add_argument("--ckpt", default="Base/out/sft/rag_mcp_full_sft/final/model.pth",
                        help="local checkpoint path (downloaded from the Hub if missing)")
    parser.add_argument("--hf-repo", default="ASTERIZER/LUNA-100M", help="Hub repo for the fallback download")
    parser.add_argument("--hf-file", default="rag_mcp_full_sft/final/model.pth", help="file to fetch from --hf-repo")
    parser.add_argument("--tok-dir", default="Base/checkpoints/EleutherAI/pythia-160m", help="tokenizer directory")
    parser.add_argument("--max-new", type=int, default=150, help="maximum new tokens per response")
    parser.add_argument("--temp", type=float, default=0.7, help="sampling temperature")
    parser.add_argument("--top-p", type=float, default=0.9, help="nucleus sampling threshold")
    parser.add_argument("--top-k", type=int, default=40, help="top-k sampling cutoff")
    parser.add_argument("--rep-pen", type=float, default=1.0, help="repetition penalty (1.0 = off)")
    parser.add_argument("--device", default="auto", help="'cuda', 'cpu', or 'auto'")
    return parser.parse_args()


def resolve_checkpoint(args):
    """Return a usable checkpoint path: the local file if it exists, else a Hub download."""
    if os.path.exists(args.ckpt):
        return args.ckpt
    # Fall back to the Hugging Face Hub; HF_TOKEN is only needed if the repo is gated.
    return hf_hub_download(
        repo_id=args.hf_repo,
        filename=args.hf_file,
        token=os.environ.get("HF_TOKEN"),
    )


def main():
    args = parse_args()
    # Resolve "auto" to CUDA when available, otherwise fall back to CPU.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    ckpt_path = resolve_checkpoint(args)
    print(f"\nDevice: {device}")
    print(f"Loading: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # Training checkpoints wrap the weights under a "model" key; plain state dicts load as-is.
    state = ckpt["model"] if "model" in ckpt else ckpt
    model = LUNAModel()
    model.load_state_dict(state, strict=True)
    model = model.to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f"Tokenizer: {args.tok_dir} (vocab {tokenizer.vocab_size})")
| print("Type 'quit' to exit.") | |
| while True: | |
| try: | |
| user_input = input("You: ").strip() | |
| except (EOFError, KeyboardInterrupt): | |
| print("\nBye!") | |
| break | |
| if not user_input: | |
| continue | |
| if user_input.lower() in {"quit", "exit", "q"}: | |
| print("Bye!") | |
| break | |
        # format_prompt wraps the turn in the "### "-delimited template used
        # during SFT (inferred from the truncation applied below).
        prompt = format_prompt(user_input)
        ids = tokenizer.encode(prompt, return_tensors="pt")
        tokens = generate(
            model,
            ids,
            max_new=args.max_new,
            temperature=args.temp,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.rep_pen,
            device=device,
        )
        response = tokenizer.decode(tokens, skip_special_tokens=True).strip()
        # Truncate at the next "### " section marker if the model runs past its turn.
        if "### " in response:
            response = response.split("### ")[0].strip()
        print(f"\nLUNA: {response}\n")


if __name__ == "__main__":
    main()
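
# Example invocation (the script filename here is assumed, not from the repo;
# the flags map to the arguments defined in parse_args above):
#
#   HF_TOKEN=... python chat_sft.py --device cuda --max-new 200 --rep-pen 1.1
#
# With no flags, the script loads the local checkpoint at
# Base/out/sft/rag_mcp_full_sft/final/model.pth if present, otherwise it
# downloads rag_mcp_full_sft/final/model.pth from ASTERIZER/LUNA-100M.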