| import gradio as gr |
| from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
# Hugging Face Hub ids: the main (target) model and the smaller draft model
# used for assisted (speculative) decoding.
# NOTE(review): "gemma-4-E2B-it" looks unusual (Gemma 3n repos are named like
# "google/gemma-3n-E2B-it") — confirm these repo ids exist on the Hub.
TARGET_MODEL_ID = "google/gemma-4-E2B-it"
ASSISTANT_MODEL_ID = "google/gemma-4-E2B-it-assistant"

# Shared processor (tokenizer + chat template); built from the target model.
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)

# Target model: produces the final, verified tokens.
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    dtype="auto",       # use the dtype stored in the checkpoint
    device_map="auto",  # place weights on whatever device(s) are available
)

# Assistant (draft) model: proposes candidate tokens during generation.
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    dtype="auto",
    device_map="auto",
)
|
|
|
|
def greet(name):
    """Generate a chat reply to *name* using draft-assisted decoding.

    Args:
        name: Free-form user text from the Gradio textbox.

    Returns:
        The assistant's reply after processor-side response parsing.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": name},
    ]

    # Render the conversation through the model's chat template, appending
    # the generation prompt so the model continues as the assistant.
    prompt = processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = processor(text=prompt, return_tensors="pt").to(target_model.device)
    prompt_len = model_inputs["input_ids"].shape[-1]

    # Assisted generation: the assistant model drafts tokens that the target
    # model verifies, typically speeding up decoding without changing output.
    generated = target_model.generate(
        **model_inputs,
        assistant_model=assistant_model,
        max_new_tokens=256,
    )

    # Slice off the prompt, keeping only newly generated tokens.
    # skip_special_tokens=False — presumably so chat-template markers survive
    # for parse_response below; confirm against the processor's requirements.
    completion = processor.decode(generated[0][prompt_len:], skip_special_tokens=False)
    return processor.parse_response(completion)
|
|
# Minimal Gradio UI: a single text input and text output wired to greet().
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
# Start the local web server (blocks the script until stopped).
demo.launch()
|
|