micas23 commited on
Commit
42c3583
·
verified ·
1 Parent(s): 1af664c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+
4
+ TARGET_MODEL_ID = "google/gemma-4-E2B-it"
5
+ ASSISTANT_MODEL_ID = "google/gemma-4-E2B-it-assistant"
6
+
7
+ # Target Model
8
+ processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
9
+ target_model = AutoModelForCausalLM.from_pretrained(
10
+ TARGET_MODEL_ID,
11
+ dtype="auto",
12
+ device_map="auto",
13
+
14
+ )
15
+
16
+ # Assistant Model (the drafter)
17
+ assistant_model = AutoModelForCausalLM.from_pretrained(
18
+ ASSISTANT_MODEL_ID,
19
+ dtype="auto",
20
+ device_map="auto",
21
+ )
22
+
23
+
24
+ def greet(name):
25
+ # Prompt
26
+ messages = [
27
+ {"role": "system", "content": "You are a helpful assistant."},
28
+ {"role": "user", "content": name},
29
+ ]
30
+
31
+ # Process input
32
+ text = processor.apply_chat_template(
33
+ messages,
34
+ tokenize=False,
35
+ add_generation_prompt=True,
36
+ )
37
+ inputs = processor(text=text, return_tensors="pt").to(target_model.device)
38
+ input_len = inputs["input_ids"].shape[-1]
39
+
40
+ # Generate output
41
+ outputs = target_model.generate(
42
+ **inputs,
43
+ assistant_model=assistant_model,
44
+ max_new_tokens=256,
45
+ )
46
+ response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
47
+
48
+ # Parse output
49
+ textofinal =processor.parse_response(response)
50
+ return textofinal
51
+
52
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
53
+ demo.launch()