| |
import asyncio
import html
import logging
import os
import secrets
from dataclasses import dataclass, field
from typing import Any, Optional, Tuple

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline
|
|
| |
| load_dotenv() |
|
|
|
|
@dataclass
class Config:
    """Runtime configuration sourced from environment variables.

    Each field uses ``field(default_factory=...)`` so the environment is read
    when a ``Config()`` instance is created, not once at class-definition
    time — a ``Config`` built after the environment changes sees the new
    values. Explicit keyword overrides (``Config(MAX_TOKENS=5)``) still work.
    """

    # Hugging Face access token; empty string means "not configured".
    HF_TOKEN: str = field(default_factory=lambda: os.getenv("HF_TOKEN", ""))
    # Model repo id passed to the transformers pipeline.
    MODEL_NAME: str = field(default_factory=lambda: os.getenv("MODEL_NAME", "google/gemma-3-270m-it"))
    # Hard upper bound on max_new_tokens for any request.
    MAX_TOKENS: int = field(default_factory=lambda: int(os.getenv("MAX_TOKENS", "2048")))
    # Logging level name (e.g. "INFO", "DEBUG").
    LOG_LEVEL: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "INFO"))
|
|
|
|
class GenerationRequest(BaseModel):
    """Validated parameters for one text-generation call.

    Defaults mirror the Gradio slider defaults in the UI.
    """

    prompt: str                 # raw user prompt; whitespace-only prompts are rejected downstream
    max_tokens: int = 512       # requested new-token cap; further clamped to Config.MAX_TOKENS
    temperature: float = 0.7    # sampling temperature
    top_k: int = 50             # top-k sampling cutoff
    top_p: float = 0.95         # nucleus (top-p) sampling cutoff
|
|
|
|
class APIResponse(BaseModel):
    """Uniform response envelope returned by the service layer.

    Exactly one of ``data``/``error`` is meaningful: ``data`` on success,
    ``error`` otherwise.
    """

    success: bool                 # True when generation completed
    data: Any = None              # payload on success, e.g. {"generated_text": ..., "tokens_used": ...}
    error: Optional[str] = None   # human-readable (Portuguese) error message on failure
|
|
|
|
| |
def setup_logger() -> logging.Logger:
    """Create (once) and return the application-wide "gemma_saas" logger.

    The level comes from the ``LOG_LEVEL`` environment variable (default
    ``INFO``); unknown names fall back to ``INFO``. Reading the env var
    directly avoids instantiating a full ``Config`` just for one setting.
    Handlers (file ``gemma_saas.log`` + stderr stream) are attached only on
    the first call, so repeated calls never duplicate log output.
    """
    log_level = getattr(logging, os.getenv("LOG_LEVEL", "INFO").upper(), logging.INFO)
    logger = logging.getLogger("gemma_saas")
    if not logger.handlers:
        logger.setLevel(log_level)
        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        # Same format to both sinks: a persistent file and the console.
        for handler in (logging.FileHandler("gemma_saas.log"), logging.StreamHandler()):
            handler.setFormatter(formatter)
            logger.addHandler(handler)
    return logger
|
|
|
|
| logger = setup_logger() |
|
|
|
|
| |
class ModelManager:
    """Owns the transformers text-generation pipeline and runs inference.

    Heavy blocking work (model download/load and generation) is pushed to the
    default thread-pool executor so the asyncio event loop stays responsive.
    """

    def __init__(self, config: Config):
        self.config = config
        self.pipeline = None        # transformers pipeline, set by initialize()
        self.model_loaded = False   # True once the pipeline loaded successfully

    async def initialize(self) -> None:
        """Load the Hugging Face pipeline in a worker thread.

        Failures are logged but never raised: ``generate()`` reports the
        outage to callers instead of crashing the server.
        """
        if not self.config.HF_TOKEN:
            logger.error("Token do Hugging Face não encontrado. O carregamento do modelo poderá falhar.")
            return

        try:
            logger.info(f"A carregar o modelo: {self.config.MODEL_NAME}...")
            os.environ.setdefault("HF_TOKEN", self.config.HF_TOKEN)

            # get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated for this use since Python 3.10.
            loop = asyncio.get_running_loop()

            def load_pipeline():
                return pipeline(
                    "text-generation",
                    model=self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN,
                    torch_dtype="auto",
                    device_map="auto",
                )

            self.pipeline = await loop.run_in_executor(None, load_pipeline)
            self.model_loaded = True
            logger.info("✅ Modelo carregado com sucesso!")
        except Exception as e:
            logger.error(f"❌ Erro ao carregar o modelo: {e}", exc_info=True)

    async def generate(self, request: GenerationRequest) -> Tuple[bool, str, int]:
        """Run one generation request.

        Returns ``(success, text_or_error_message, tokens_used)``. The error
        cases (model unavailable, empty prompt) return ``success=False`` with
        a user-facing message; pipeline exceptions propagate to the caller.
        """
        if not self.model_loaded or self.pipeline is None:
            return False, "❌ O modelo não está disponível. Por favor, verifique os logs do servidor.", 0

        if not request.prompt.strip():
            return False, "⚠️ O prompt não pode estar vazio.", 0

        loop = asyncio.get_running_loop()
        messages = [{"role": "user", "content": request.prompt.strip()}]

        def do_generation():
            tokenizer = getattr(self.pipeline, "tokenizer", None)

            # Prefer the model's chat template when available; fall back to
            # the raw prompt for plain (non-chat) models.
            if tokenizer and hasattr(tokenizer, "apply_chat_template"):
                prompt_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            else:
                prompt_text = request.prompt.strip()

            outputs = self.pipeline(
                prompt_text,
                max_new_tokens=min(request.max_tokens, self.config.MAX_TOKENS),
                do_sample=True,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
            )

            # Strip the echoed prompt so only the completion is returned.
            generated_text = outputs[0].get("generated_text", "")
            if generated_text.startswith(prompt_text):
                generated_text = generated_text[len(prompt_text):]

            # Token count is best-effort; 0 when the tokenizer can't encode.
            tokens_used = 0
            if tokenizer and hasattr(tokenizer, "encode"):
                try:
                    tokens_used = len(tokenizer.encode(generated_text))
                except Exception:
                    tokens_used = 0

            return generated_text, tokens_used

        generated_text, tokens_used = await loop.run_in_executor(None, do_generation)
        return True, generated_text, tokens_used
|
|
|
|
| |
class GemmaService:
    """Facade tying configuration, model management and API-key checking together."""

    def __init__(self):
        self.config = Config()
        self.model_manager = ModelManager(self.config)

    async def initialize(self):
        """Warm up the underlying model manager."""
        await self.model_manager.initialize()

    async def generate_text(self, api_key: str, prompt: str, **kwargs) -> APIResponse:
        """Validate the API key, then delegate generation to the model manager."""
        # Guard clause: a usable key is a string starting with "gsk-".
        key_is_usable = isinstance(api_key, str) and api_key.startswith("gsk-")
        if not key_is_usable:
            return APIResponse(success=False, error="Chave de API inválida ou ausente.")
        try:
            outcome, payload, token_count = await self.model_manager.generate(
                GenerationRequest(prompt=prompt, **kwargs)
            )
            if outcome:
                return APIResponse(
                    success=True,
                    data={"generated_text": payload, "tokens_used": token_count},
                )
            return APIResponse(success=False, error=payload)
        except Exception as e:
            # Internal failures (validation or pipeline) are logged and masked.
            logger.error(f"Erro de serviço durante a geração de texto: {e}", exc_info=True)
            return APIResponse(success=False, error="Ocorreu um erro interno no serviço.")
|
|
|
|
| |
class GradioInterface:
    """Builds the Gradio Blocks UI that fronts a GemmaService."""

    def __init__(self, service: GemmaService):
        self.service = service

    def create_custom_css(self) -> str:
        """Return the custom dark-theme CSS injected into the Blocks app."""
        return """
@import url('https://fonts.googleapis.com/css2?family=Material+Icons&display=swap');
:root { --dark-bg:#0a0a0a; --panel-bg:#1a1a1a; --border-color:#333; --text-color:#f0f0f0; --text-light:#a0a0a0; --accent-orange:#FF4500; --accent-orange-hover:#FF6347; --code-bg:#282c34; }
.gradio-container { background: var(--dark-bg) !important; color: var(--text-color); }
/* ... rest of CSS (trimmed for brevity) ... */
#send_button::before { content: "send"; font-family: 'Material Icons', sans-serif; position:absolute; left:12px; top:50%; transform:translateY(-50%); font-size:18px; opacity:0.95; }
#generate_button::before { content: "auto_awesome"; font-family: 'Material Icons', sans-serif; position:absolute; left:12px; top:50%; transform:translateY(-50%); font-size:18px; opacity:0.95; }
"""

    def create_interface(self) -> gr.Blocks:
        """Assemble and return the full Blocks layout plus event wiring."""
        demo = gr.Blocks(css=self.create_custom_css(), theme=None)
        with demo:
            with gr.Row(elem_id="main_layout", equal_height=False):
                # Left column: model output panel plus API-key/prompt inputs.
                with gr.Column(scale=2):
                    with gr.Column(elem_id="left_panel"):
                        output_display = gr.Markdown(elem_id="output_display", value="<p style='color: #a0a0a0;'>A sua resposta aparecerá aqui...</p>")
                    with gr.Column(elem_id="input_area"):
                        api_key_input = gr.Textbox(label="A Sua Chave de API", placeholder="Cole a sua chave gsk-... aqui", type="password", elem_id="api_key_input")
                        with gr.Row():
                            prompt_input = gr.Textbox(show_label=False, placeholder="Digite a sua mensagem...", elem_id="prompt_input", scale=10)
                            send_button = gr.Button("➤ Enviar", elem_id="send_button", scale=2)

                # Right column: key generation, sampling parameters, API help.
                with gr.Column(scale=1):
                    with gr.Column(elem_id="right_panel"):
                        gr.Markdown("## Controlo")
                        key_button = gr.Button("✨ Gerar Nova Chave", elem_id="generate_button")

                        with gr.Accordion("Parâmetros Avançados", open=False):
                            temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperatura")
                            max_tokens_slider = gr.Slider(minimum=64, maximum=self.service.config.MAX_TOKENS, value=512, step=64, label="Max Tokens")
                            top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K")
                            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")

                        gr.Markdown("### Como Usar a API")
                        api_example_display = gr.HTML("<p style='color: #a0a0a0;'>Clique em 'Gerar Nova Chave' para ver um exemplo de código.</p>")

            def handle_key_generation():
                # New demo key: "gsk-" prefix plus URL-safe randomness with
                # '_' and '-' stripped out.
                key = f"gsk-{secrets.token_urlsafe(24).replace('_', '').replace('-', '')}"
                code_html = f"<div class='code-snippet'> ... </div>"
                return key, gr.update(value=code_html)

            async def handle_generation(api_key, prompt, temp, max_tokens, top_k, top_p, btn):
                # Generator handler: each yield updates (output HTML, send
                # button). `btn` receives the button's current value and is
                # unused — it is in the inputs list so the click wiring below
                # matches positionally.
                if not api_key:
                    yield "<p style='color: #FFCC00;'>Por favor, insira a sua chave de API para começar.</p>", gr.update(value="➤ Enviar", interactive=True)
                    return
                if not prompt:
                    yield "<p style='color: #FFCC00;'>Por favor, digite um prompt.</p>", gr.update(value="➤ Enviar", interactive=True)
                    return

                # Disable the button while the model is working.
                yield "<p style='color: #a0a0a0;'>A gerar resposta...</p>", gr.update(value="A gerar...", interactive=False)

                response = await self.service.generate_text(api_key=api_key, prompt=prompt, temperature=temp, max_tokens=int(max_tokens), top_k=int(top_k), top_p=top_p)
                if response.success:
                    # Escape model output before rendering it as HTML.
                    formatted_text = html.escape(response.data["generated_text"]).replace("\n", "<br>")
                    yield formatted_text, gr.update(value="➤ Enviar", interactive=True)
                else:
                    yield f"<p style='color: #FF4500;'>{response.error}</p>", gr.update(value="➤ Enviar", interactive=True)

            send_button.click(
                handle_generation,
                inputs=[api_key_input, prompt_input, temp_slider, max_tokens_slider, top_k_slider, top_p_slider, send_button],
                outputs=[output_display, send_button],
                api_name="generate",
            )
            key_button.click(handle_key_generation, outputs=[api_key_input, api_example_display])
            demo.load(lambda: gr.update(value="<p style='color: #a0a0a0;'>Clique em 'Gerar Nova Chave' para ver um exemplo de código.</p>"), [], [api_example_display])

        return demo
|
|
|
|
| |
| service = GemmaService() |
| gradio_interface = GradioInterface(service) |
| gradio_blocks = gradio_interface.create_interface() |
|
|
| app = FastAPI(title="Gemma Service (Gradio + API)") |
|
|
| |
| try: |
| gr.mount_gradio_app(app, gradio_blocks, path="/") |
| except Exception as exc: |
| logger.warning("Não foi possível montar Gradio automaticamente: %s", exc) |
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Kick off model loading in the background without blocking startup.

    The Task is stored on ``app.state`` so it is not garbage-collected while
    still running — asyncio keeps only a weak reference to scheduled tasks,
    so a fire-and-forget ``create_task`` result can be collected mid-flight.
    """
    app.state.model_init_task = asyncio.create_task(service.initialize())
|
|
|
|
@app.post("/api/generate")
async def api_generate(req: Request):
    """JSON endpoint.

    Expects a JSON object with ``api_key``, ``prompt`` and optional sampling
    parameters (``max_tokens``, ``temperature``, ``top_k``, ``top_p``).
    Returns the APIResponse envelope; 200 on success, 400 on any client
    error (bad JSON, non-object payload, non-numeric parameters, bad key).
    """
    try:
        body = await req.json()
    except Exception:
        return JSONResponse(status_code=400, content={"success": False, "error": "Payload inválido (JSON esperado)."})

    # A JSON array/string/number would crash body.get below — reject it.
    if not isinstance(body, dict):
        return JSONResponse(status_code=400, content={"success": False, "error": "Payload inválido (objeto JSON esperado)."})

    api_key = body.get("api_key")
    prompt = body.get("prompt", "")

    # Coerce numeric parameters defensively: a bad value is a client error
    # (400), not an unhandled server crash (500).
    try:
        max_tokens = int(body.get("max_tokens", 512))
        temperature = float(body.get("temperature", 0.7))
        top_k = int(body.get("top_k", 50))
        top_p = float(body.get("top_p", 0.95))
    except (TypeError, ValueError) as e:
        return JSONResponse(status_code=400, content={"success": False, "error": f"Parâmetros numéricos inválidos: {e}"})

    resp = await service.generate_text(api_key=api_key, prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_k=top_k, top_p=top_p)
    status = 200 if resp.success else 400
    return JSONResponse(status_code=status, content=resp.dict())
|
|
|
|
@app.post("/run/generate")
async def gradio_compatible_generate(req: Request):
    """Gradio-style endpoint: positional parameters arrive in a 'data' array.

    Expected order: [api_key, prompt, max_tokens, temperature, top_k, top_p];
    trailing entries are optional and fall back to the UI defaults.
    """
    try:
        payload = await req.json()
    except Exception:
        return JSONResponse(status_code=400, content={"success": False, "error": "Payload inválido (JSON esperado)."})

    items = payload.get("data")
    if not isinstance(items, list):
        return JSONResponse(status_code=400, content={"success": False, "error": "Campo 'data' inválido. Esperado array."})

    n = len(items)
    try:
        api_key = items[0]
        prompt = items[1] if n > 1 else ""
        max_tokens = int(items[2]) if n > 2 else 512
        temperature = float(items[3]) if n > 3 else 0.7
        top_k = int(items[4]) if n > 4 else 50
        top_p = float(items[5]) if n > 5 else 0.95
    except Exception as e:
        # Covers both a missing api_key (IndexError) and bad numeric values.
        return JSONResponse(status_code=400, content={"success": False, "error": f"Erro ao parsear 'data': {e}"})

    resp = await service.generate_text(api_key=api_key, prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_k=top_k, top_p=top_p)
    status = 200 if resp.success else 400
    return JSONResponse(status_code=status, content=resp.dict())
|
|