rufatronics committed on
Commit
d74a673
·
verified ·
1 Parent(s): cd49dba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py CHANGED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForImageTextToText, AutoProcessor
3
+ import torch
4
+ import os
5
+
6
+ # 1. Setup Model & Token
7
+ model_id = "google/gemma-3n-E2B-it"
8
+ hf_token = os.getenv("HF_TOKEN")
9
+ device = "cpu"
10
+
11
+ print("Loading Gemma 3n (10GB)... This takes a few minutes.")
12
+
13
+ # We add low_cpu_mem_usage=True to prevent crashing on load
14
+ processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
15
+ model = AutoModelForImageTextToText.from_pretrained(
16
+ model_id,
17
+ token=hf_token,
18
+ torch_dtype=torch.float32,
19
+ low_cpu_mem_usage=True,
20
+ device_map="auto"
21
+ )
22
+
23
+ def chat_function(message, history):
24
+ # Prepare history for the model
25
+ msgs = []
26
+ for user_msg, assistant_msg in history:
27
+ if user_msg: msgs.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
28
+ if assistant_msg: msgs.append({"role": "model", "content": [{"type": "text", "text": assistant_msg}]})
29
+
30
+ # Add new message
31
+ msgs.append({"role": "user", "content": [{"type": "text", "text": message}]})
32
+
33
+ # Apply template
34
+ inputs = processor.apply_chat_template(
35
+ msgs,
36
+ add_generation_prompt=True,
37
+ tokenize=True,
38
+ return_tensors="pt"
39
+ ).to(device)
40
+
41
+ # Generate
42
+ with torch.no_grad(): # Saves memory during generation
43
+ outputs = model.generate(
44
+ **inputs,
45
+ max_new_tokens=400,
46
+ do_sample=True,
47
+ temperature=0.4
48
+ )
49
+
50
+ response = processor.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
51
+ return response
52
+
53
+ # 5. The Interface
54
+ demo = gr.ChatInterface(
55
+ fn=chat_function,
56
+ title="Gemma 3n E2B (Fixed)",
57
+ description="Now with 'timm' installed and optimized for 16GB RAM!",
58
+ )
59
+
60
+ if __name__ == "__main__":
61
+ demo.launch()