import os

import gradio as gr
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "google/gemma-3n-E2B-it"
hf_token = os.getenv("HF_TOKEN")
device = "cpu"

print("Loading Gemma 3n with Memory Optimizations...")
# 1. bfloat16 cuts weight RAM usage roughly in half compared to float32
# 2. low_cpu_mem_usage prevents the 'double loading' crash
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    token=hf_token,
    torch_dtype=torch.bfloat16,   # KEY FIX: half-precision weights for CPU
    low_cpu_mem_usage=True,       # KEY FIX: don't use double RAM on load
    device_map="auto",
)


def chat_function(message, history):
    # Rebuild the conversation in the chat-template format Gemma expects.
    msgs = []
    for user_msg, assistant_msg in history:
        if user_msg:
            msgs.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            msgs.append({"role": "model", "content": [{"type": "text", "text": assistant_msg}]})
    msgs.append({"role": "user", "content": [{"type": "text", "text": message}]})

    inputs = processor.apply_chat_template(
        msgs,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,      # needed so we get a dict with 'input_ids', not a bare tensor
        return_tensors="pt",
    ).to(device)

    # Note: inference on CPU with bfloat16 keeps peak RAM much lower
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=400,
            do_sample=True,
            temperature=0.4,
        )

    # Strip the prompt tokens so only the newly generated reply is returned
    response = processor.decode(
        outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )
    return response


demo = gr.ChatInterface(fn=chat_function, title="Gemma 3n E2B (RAM Optimized)")

if __name__ == "__main__":
    demo.launch()