ASTERIZER committed on
Commit
33c9cef
·
verified ·
1 Parent(s): aba5a87

Upload chat_full_sft.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. chat_full_sft.py +91 -0
chat_full_sft.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from huggingface_hub import hf_hub_download
6
+ from transformers import AutoTokenizer
7
+
8
+ from chat import LUNAModel, format_prompt, generate
9
+
10
+
11
def parse_args():
    """Build and parse the command-line options for the chat script.

    Returns an argparse.Namespace with checkpoint/tokenizer locations,
    sampling hyperparameters, and the target device.
    """
    p = argparse.ArgumentParser(description="Chat with the rag_mcp_full_sft checkpoint")
    # Checkpoint / tokenizer locations.
    p.add_argument("--ckpt", default="Base/out/sft/rag_mcp_full_sft/final/model.pth")
    p.add_argument("--hf-repo", default="ASTERIZER/LUNA-100M")
    p.add_argument("--hf-file", default="rag_mcp_full_sft/final/model.pth")
    p.add_argument("--tok-dir", default="Base/checkpoints/EleutherAI/pythia-160m")
    # Sampling hyperparameters.
    p.add_argument("--max-new", type=int, default=150)
    p.add_argument("--temp", type=float, default=0.7)
    p.add_argument("--top-p", type=float, default=0.9)
    p.add_argument("--top-k", type=int, default=40)
    p.add_argument("--rep-pen", type=float, default=1.0)
    # "auto" picks CUDA when available, else CPU (resolved in main()).
    p.add_argument("--device", default="auto")
    return p.parse_args()
24
+
25
+
26
def resolve_checkpoint(args):
    """Return a local filesystem path to the model checkpoint.

    Prefers ``args.ckpt`` when it already exists on disk; otherwise
    downloads ``args.hf_file`` from the ``args.hf_repo`` Hugging Face
    repository (authenticating with the HF_TOKEN environment variable,
    if set) and returns the cached download path.
    """
    if os.path.exists(args.ckpt):
        return args.ckpt
    # Local file missing: fetch from the Hub instead.
    return hf_hub_download(
        repo_id=args.hf_repo,
        filename=args.hf_file,
        token=os.environ.get("HF_TOKEN"),
    )
37
+
38
+
39
def _pick_device(requested):
    """Resolve the --device flag: 'auto' selects CUDA when available, else CPU."""
    if requested == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return requested


def _load_model(ckpt_path, device):
    """Load LUNAModel weights from *ckpt_path* and return it on *device* in eval mode."""
    # weights_only=True restricts unpickling to tensors/containers (safe load).
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # Checkpoints may wrap the state dict under a "model" key.
    state = ckpt["model"] if "model" in ckpt else ckpt
    model = LUNAModel()
    model.load_state_dict(state, strict=True)
    return model.to(device).eval()


def _generate_reply(model, tokenizer, args, device, user_input):
    """Run one generation round and return the cleaned response text."""
    prompt = format_prompt(user_input)
    ids = tokenizer.encode(prompt, return_tensors="pt")
    tokens = generate(
        model,
        ids,
        max_new=args.max_new,
        temperature=args.temp,
        top_p=args.top_p,
        top_k=args.top_k,
        repetition_penalty=args.rep_pen,
        device=device,
    )
    response = tokenizer.decode(tokens, skip_special_tokens=True).strip()
    # Presumably the SFT prompt format uses "### " section headers — keep
    # only the text before the first one. TODO(review): confirm against
    # format_prompt in chat.py.
    if "### " in response:
        response = response.split("### ")[0].strip()
    return response


def main():
    """Interactive chat REPL for the rag_mcp_full_sft checkpoint.

    Resolves the device and checkpoint, loads model and tokenizer, then
    reads user turns from stdin until 'quit'/'exit'/'q', EOF, or Ctrl-C.
    """
    args = parse_args()
    device = _pick_device(args.device)

    ckpt_path = resolve_checkpoint(args)
    print(f"\nDevice: {device}")
    print(f"Loading: {ckpt_path}")

    model = _load_model(ckpt_path, device)

    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f"Tokenizer: {args.tok_dir} (vocab {tokenizer.vocab_size})")
    print("Type 'quit' to exit.")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break

        if not user_input:
            continue
        if user_input.lower() in {"quit", "exit", "q"}:
            print("Bye!")
            break

        response = _generate_reply(model, tokenizer, args, device, user_input)
        print(f"\nLUNA: {response}\n")
88
+
89
+
90
# Script entry point: run the chat REPL only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()