import os

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
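
# Gradio Space: chat UI for the befm/Be.FM-8B PEFT adapter on top of the gated
# meta-llama/Meta-Llama-3.1-8B-Instruct base model; an HF token with access is required.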
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"
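
# 'peft' is optional at import time; if it is missing, only the base model is served.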
USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig
except Exception:
    USE_PEFT = False
    print("[WARN] 'peft' not installed; running base model only.")
def load_model_and_tokenizer():
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
    if tok.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=dtype,
        token=HF_TOKEN,
    )

    if USE_PEFT:
        try:
            # Check that the adapter config is reachable before attaching it.
            _ = PeftConfig.from_pretrained(PEFT_MODEL_ID, token=HF_TOKEN)
            model = PeftModel.from_pretrained(base, PEFT_MODEL_ID, token=HF_TOKEN)
            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
            return model, tok
        except Exception as e:
            print(f"[WARN] Failed to load PEFT adapter: {e}")
            return base, tok
    return base, tok
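
# Load everything once at startup; Gradio callbacks reuse the same objects.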
model, tokenizer = load_model_and_tokenizer()
DEVICE = model.device
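
# Single-turn generation helper; sampling settings come from the UI sliders.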
@torch.inference_mode()
def generate_response(prompt: str, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    out = model.generate(
        **enc,
        # Gradio sliders may deliver floats; generation expects an integer count here.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = out[0][enc["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
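
# ChatInterface callback: build a role-prefixed prompt from the history and return the reply.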
def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
    # NOTE: this uses plain "role:" prefixes rather than the Llama 3.1 chat template.
    conv = []
    if system_prompt:
        conv.append(f"system: {system_prompt}")
    for u, a in (history or []):
        if u:
            conv.append(f"user: {u}")
        if a:
            conv.append(f"assistant: {a}")
    if message:
        conv.append(f"user: {message}")
    prompt = "\n".join(conv) + "\nassistant:"

    reply = generate_response(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # Fallback trim in case the model echoes the "assistant:" prefix in its output.
    if "assistant:" in reply:
        reply = reply.split("assistant:")[-1].strip()
    return reply
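
# The additional inputs below are appended, in order, to chat_fn's (message, history) arguments.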
demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[
        gr.Textbox(label="System prompt (optional)", placeholder="You are Be.FM assistant...", lines=2),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
    ],
    title="Be.FM-8B (PEFT) on Meta-Llama-3.1-8B-Instruct",
    description="Chat interface using Meta-Llama-3.1-8B-Instruct with the PEFT adapter befm/Be.FM-8B.",
)


if __name__ == "__main__":
    demo.launch()