import argparse
import os

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from chat import LUNAModel, format_prompt, generate


def parse_args():
    parser = argparse.ArgumentParser(description="Chat with the rag_mcp_full_sft checkpoint")
    parser.add_argument("--ckpt", default="Base/out/sft/rag_mcp_full_sft/final/model.pth")
    parser.add_argument("--hf-repo", default="ASTERIZER/LUNA-100M")
    parser.add_argument("--hf-file", default="rag_mcp_full_sft/final/model.pth")
    parser.add_argument("--tok-dir", default="Base/checkpoints/EleutherAI/pythia-160m")
    parser.add_argument("--max-new", type=int, default=150)
    parser.add_argument("--temp", type=float, default=0.7)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--top-k", type=int, default=40)
    parser.add_argument("--rep-pen", type=float, default=1.0)
    parser.add_argument("--device", default="auto")
    return parser.parse_args()


def resolve_checkpoint(args):
    """Use the local checkpoint if it exists; otherwise download it from the Hub."""
    ckpt_path = args.ckpt
    if os.path.exists(ckpt_path):
        return ckpt_path
    return hf_hub_download(
        repo_id=args.hf_repo,
        filename=args.hf_file,
        token=os.environ.get("HF_TOKEN"),
    )


def main():
    args = parse_args()

    # Resolve "auto" to CUDA when available, otherwise fall back to CPU.
    device = "cuda" if args.device == "auto" and torch.cuda.is_available() else args.device
    if device == "auto":
        device = "cpu"

    ckpt_path = resolve_checkpoint(args)
    print(f"\nDevice: {device}")
    print(f"Loading: {ckpt_path}")

    # The checkpoint may store the weights directly or nested under a "model" key.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    state = ckpt["model"] if "model" in ckpt else ckpt

    model = LUNAModel()
    model.load_state_dict(state, strict=True)
    model = model.to(device).eval()

    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f"Tokenizer: {args.tok_dir} (vocab {tokenizer.vocab_size})")
    print("Type 'quit' to exit.")

    # Simple REPL: read a prompt, generate a reply, repeat until the user quits.
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break
        if not user_input:
            continue
        if user_input.lower() in {"quit", "exit", "q"}:
            print("Bye!")
            break

        prompt = format_prompt(user_input)
        ids = tokenizer.encode(prompt, return_tensors="pt")
        tokens = generate(
            model,
            ids,
            max_new=args.max_new,
            temperature=args.temp,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.rep_pen,
            device=device,
        )
        response = tokenizer.decode(tokens, skip_special_tokens=True).strip()

        # Cut the response at the next "### " marker so the model does not spill into a new turn.
        if "### " in response:
            response = response.split("### ")[0].strip()
        print(f"\nLUNA: {response}\n")


if __name__ == "__main__":
    main()