Upload chat_full_sft.py with huggingface_hub
Browse files- chat_full_sft.py +91 -0
chat_full_sft.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
|
| 8 |
+
from chat import LUNAModel, format_prompt, generate
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def parse_args():
    """Build and evaluate the command line for the chat script.

    Returns an ``argparse.Namespace`` carrying checkpoint/tokenizer
    locations and the sampling hyper-parameters used by ``generate``.
    """
    parser = argparse.ArgumentParser(
        description="Chat with the rag_mcp_full_sft checkpoint"
    )
    # (flag, add_argument keyword args) pairs — kept as one table so the
    # full option surface can be scanned at a glance.
    options = [
        ("--ckpt", {"default": "Base/out/sft/rag_mcp_full_sft/final/model.pth"}),
        ("--hf-repo", {"default": "ASTERIZER/LUNA-100M"}),
        ("--hf-file", {"default": "rag_mcp_full_sft/final/model.pth"}),
        ("--tok-dir", {"default": "Base/checkpoints/EleutherAI/pythia-160m"}),
        ("--max-new", {"type": int, "default": 150}),
        ("--temp", {"type": float, "default": 0.7}),
        ("--top-p", {"type": float, "default": 0.9}),
        ("--top-k", {"type": int, "default": 40}),
        ("--rep-pen", {"type": float, "default": 1.0}),
        ("--device", {"default": "auto"}),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def resolve_checkpoint(args):
    """Return a local filesystem path to the model checkpoint.

    Prefers ``args.ckpt`` when it already exists on disk; otherwise
    downloads ``args.hf_file`` from the ``args.hf_repo`` Hub repository
    (passing the ``HF_TOKEN`` environment variable, if set) and returns
    the cached download path.
    """
    # Fast path: a local checkpoint takes precedence over the Hub copy.
    if os.path.exists(args.ckpt):
        return args.ckpt

    return hf_hub_download(
        repo_id=args.hf_repo,
        filename=args.hf_file,
        token=os.environ.get("HF_TOKEN"),
    )
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main():
    """Run an interactive chat REPL against the rag_mcp_full_sft checkpoint."""
    args = parse_args()

    # Resolve the compute device; "auto" prefers CUDA and falls back to CPU.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    ckpt_path = resolve_checkpoint(args)
    print(f"\nDevice: {device}")
    print(f"Loading: {ckpt_path}")

    # The checkpoint may wrap the weights under a "model" key or be a bare
    # state dict; accept either layout.
    payload = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    state_dict = payload["model"] if "model" in payload else payload
    net = LUNAModel()
    net.load_state_dict(state_dict, strict=True)
    net = net.to(device).eval()

    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f"Tokenizer: {args.tok_dir} (vocab {tokenizer.vocab_size})")
    print("Type 'quit' to exit.")

    while True:
        try:
            user_text = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly.
            print("\nBye!")
            break

        if not user_text:
            continue
        if user_text.lower() in {"quit", "exit", "q"}:
            print("Bye!")
            break

        prompt = format_prompt(user_text)
        prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
        out_tokens = generate(
            net,
            prompt_ids,
            max_new=args.max_new,
            temperature=args.temp,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.rep_pen,
            device=device,
        )
        reply = tokenizer.decode(out_tokens, skip_special_tokens=True).strip()
        # Truncate at the first "### " marker — presumably the start of the
        # next prompt-template section; confirm against format_prompt.
        marker_at = reply.find("### ")
        if marker_at != -1:
            reply = reply[:marker_at].strip()
        print(f"\nLUNA: {reply}\n")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Script entry point: start the interactive chat loop when run directly.
if __name__ == "__main__":
    main()
|