import os
from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

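# Gemma weights are gated on the Hugging Face Hub, so a read token is needed
# to download them (for example, as a Space secret named HF_TOKEN).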
MODEL_ID = "google/gemma-3-270m-it"
HF_TOKEN = os.getenv("HF_TOKEN")

print("--- [1] Loading Assets ---")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

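# The 270M-parameter model fits comfortably in CPU memory; bfloat16 halves
# the footprint of float32, and low_cpu_mem_usage avoids materializing a
# second full copy of the weights during loading.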
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
)
print("--- [2] Model Ready ---")


def chat(message, history):
    # With type="messages", Gradio passes history as a list of
    # {"role": ..., "content": ...} dicts, so the new user turn can be
    # appended and handed straight to the model's chat template. This
    # inserts Gemma's instruction-tuning turn markers, which a plain
    # tokenizer(message) call would skip.
    messages = history + [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
    ).to("cpu")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
    )
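
    # model.generate() blocks until generation completes, so run it on a
    # worker thread and stream partial text from the streamer on this one.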
    thread = Thread(target=model.generate, kwargs=kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
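    # Make sure the generation thread has fully finished before returning.
    thread.join()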
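

# type="messages" keeps the chat history as OpenAI-style role/content dicts,
# which is the format chat() forwards to the chat template.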
demo = gr.ChatInterface(fn=chat, type="messages")


if __name__ == "__main__":
    print("--- [3] Launching on Port 7860 ---")
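    # 0.0.0.0:7860 is the host/port a Hugging Face Space expects.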
    demo.launch(server_name="0.0.0.0", server_port=7860)