gemma / main.py
gijl's picture
Update main.py
b386df6 verified
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import HTMLResponse
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch
from PIL import Image
import io
import os
# FastAPI application instance; route handlers below are registered on it.
app = FastAPI()
# Hugging Face Hub repository id of the image-text-to-text model served by this app.
model_id = "gijl/gemma-4-E4B-it"
# --- Model-loading section (comment translated from Arabic: "Line 13: replace the old part with this new part") ---
print("جاري تحميل المعالج...")  # Arabic: "Loading the processor..."
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
print("جاري تحميل النموذج (قد يستغرق وقتاً بسبب الحجم)...")  # Arabic: "Loading the model (may take a while due to its size)..."
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # half-precision weights to cut memory; assumes bfloat16-capable hardware — TODO confirm
    low_cpu_mem_usage=True,      # stream weights instead of materializing a full fp32 copy in RAM
    device_map="auto",           # let accelerate place layers on available devices (GPU/CPU)
    trust_remote_code=True       # NOTE(review): executes Python shipped in the model repo — only acceptable if the repo is trusted
)
# -------------------------------------------------------
@app.get("/")
async def read_index():
    """Serve the static ``index.html`` page at the application root."""
    with open("index.html", "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(content=markup)
@app.post("/generate")
async def generate_text(image: UploadFile = File(...), text: str = Form(...)):
    """Run the multimodal model on an uploaded image plus a text prompt.

    Parameters
    ----------
    image : UploadFile
        The uploaded image file (any format PIL can open; converted to RGB).
    text : str
        The text prompt to condition generation on.

    Returns
    -------
    dict
        ``{"result": <generated text>}`` containing only the newly generated
        tokens, without the echoed input prompt.

    Raises
    ------
    HTTPException
        With status 500 if decoding the image or running the model fails.
    """
    try:
        image_data = await image.read()
        pil_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        inputs = processor(text=text, images=pil_image, return_tensors="pt")
        # Move every tensor onto the same device the model was placed on.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=100)
        # BUG FIX: generate() returns the prompt tokens followed by the new
        # tokens; decoding the whole sequence echoed the user's prompt back.
        # Slice off the prompt length so only the generated text is returned.
        prompt_len = inputs["input_ids"].shape[-1]
        new_token_ids = generated_ids[:, prompt_len:]
        generated_text = processor.batch_decode(
            new_token_ids, skip_special_tokens=True
        )[0]
        return {"result": generated_text}
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors as opaque 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))