# LUNA-Training / chat_full_sft.py
# Uploaded by ASTERIZER with huggingface_hub (commit 33c9cef, verified)
import argparse
import os
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from chat import LUNAModel, format_prompt, generate
def parse_args(argv=None):
    """Parse command-line options for chatting with the rag_mcp_full_sft checkpoint.

    Args:
        argv: Optional list of argument strings. Defaults to None, which makes
            argparse read sys.argv[1:] — identical to the previous behavior.
            Passing an explicit list keeps the parser testable and embeddable.

    Returns:
        argparse.Namespace with the resolved options.
    """
    parser = argparse.ArgumentParser(description="Chat with the rag_mcp_full_sft checkpoint")
    parser.add_argument("--ckpt", default="Base/out/sft/rag_mcp_full_sft/final/model.pth",
                        help="Local checkpoint path; downloaded from the HF Hub when missing")
    parser.add_argument("--hf-repo", default="ASTERIZER/LUNA-100M",
                        help="Hugging Face repo used as fallback for the checkpoint")
    parser.add_argument("--hf-file", default="rag_mcp_full_sft/final/model.pth",
                        help="File path inside the HF repo")
    parser.add_argument("--tok-dir", default="Base/checkpoints/EleutherAI/pythia-160m",
                        help="Tokenizer directory")
    parser.add_argument("--max-new", type=int, default=150, help="Max new tokens to generate")
    parser.add_argument("--temp", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top-p", type=float, default=0.9, help="Nucleus sampling threshold")
    parser.add_argument("--top-k", type=int, default=40, help="Top-k sampling cutoff")
    parser.add_argument("--rep-pen", type=float, default=1.0, help="Repetition penalty")
    parser.add_argument("--device", default="auto", help="'auto', 'cpu', 'cuda', ...")
    return parser.parse_args(argv)
def resolve_checkpoint(args):
    """Return a usable checkpoint path.

    Prefers the local file named by ``args.ckpt``; when it does not exist,
    fetches ``args.hf_file`` from the ``args.hf_repo`` repository on the
    Hugging Face Hub, authenticating with the HF_TOKEN environment variable
    when it is set.
    """
    resolved = args.ckpt
    if not os.path.exists(resolved):
        # Local copy is missing — fall back to downloading from the Hub.
        resolved = hf_hub_download(
            repo_id=args.hf_repo,
            filename=args.hf_file,
            token=os.environ.get("HF_TOKEN"),
        )
    return resolved
def main():
    """Interactive chat loop: load the checkpoint, then read/generate until quit."""
    args = parse_args()

    # Resolve "auto" to CUDA when available, otherwise CPU; any explicit
    # device string is passed through untouched.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    ckpt_path = resolve_checkpoint(args)
    print(f"\nDevice: {device}")
    print(f"Loading: {ckpt_path}")

    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # Some checkpoints wrap the weights under a "model" key; fall back to the
    # raw object otherwise.
    state_dict = checkpoint.get("model", checkpoint)
    model = LUNAModel()
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device).eval()

    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f"Tokenizer: {args.tok_dir} (vocab {tokenizer.vocab_size})")
    print("Type 'quit' to exit.")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly.
            print("\nBye!")
            break
        if not user_input:
            continue
        if user_input.lower() in {"quit", "exit", "q"}:
            print("Bye!")
            break

        prompt_ids = tokenizer.encode(format_prompt(user_input), return_tensors="pt")
        out_tokens = generate(
            model,
            prompt_ids,
            max_new=args.max_new,
            temperature=args.temp,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.rep_pen,
            device=device,
        )
        response = tokenizer.decode(out_tokens, skip_special_tokens=True).strip()
        # Drop anything after a "### " marker — presumably the start of the
        # next prompt section the model hallucinated.
        response = response.partition("### ")[0].strip()
        print(f"\nLUNA: {response}\n")
# Standard script entry point: start the chat REPL only when run directly.
if __name__ == "__main__":
    main()