"""FastAPI service that batch-generates structured tabular data (JSON arrays
of arrays) with a small instruct-tuned LLM (Qwen2.5-0.5B-Instruct).

Endpoints:
    POST /generate  - generate rows from a list of per-column instructions
    GET  /test      - canned generation request for smoke/perf testing
    GET  /          - status
    GET  /health    - readiness probe
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import re  # NOTE(review): currently unused; kept in case other chunks rely on it
import time
from contextlib import asynccontextmanager

# --- Performance Optimizations & Model Loading ---

# 1. Device Selection: Use CUDA GPU if available for a massive speed boost.
device = "cuda" if torch.cuda.is_available() else "cpu"
# 2. Data Type: Use float16 on GPU for faster computation and less memory usage.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

print("--- System Info ---")
print(f"Using device: {device}")
print(f"Using dtype: {torch_dtype}")
print("--------------------")

# --- App State and Model Placeholders ---
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = None
model = None


# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle startup and shutdown events.

    Loads the ML model and tokenizer on startup (with Flash Attention 2 and
    torch.compile when available) and releases them on shutdown.
    """
    global tokenizer, model

    print("Loading model and tokenizer...")
    start_time = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Set pad token if it's not already set; generation needs a pad id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    try:
        # 3. Attention Mechanism: Use Flash Attention 2 for a ~2x speedup on
        # compatible GPUs. Depending on the transformers version, an
        # unavailable backend surfaces as ImportError, RuntimeError, or
        # ValueError, so all three fall through to the default attention.
        print("Attempting to load model with Flash Attention 2...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            attn_implementation="flash_attention_2",
        ).to(device)
        print("Successfully loaded model with Flash Attention 2.")
    except (ImportError, RuntimeError, ValueError) as e:
        print(f"Flash Attention 2 not available ({e}), falling back to default attention.")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
        ).to(device)

    # 4. Model Compilation (PyTorch 2.0+): JIT-compiles the model for faster
    # execution. Failure is non-fatal; we just run the eager model.
    print("Compiling model with torch.compile()...")
    try:
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")
    except Exception as e:
        print(f"torch.compile() failed: {e}. Running with uncompiled model.")

    end_time = time.time()
    print(f"Model loading and compilation finished in {end_time - start_time:.2f} seconds.")

    yield

    # Clean up resources on shutdown.
    print("Cleaning up and shutting down.")
    model = None
    tokenizer = None
    if device == "cuda":
        # Release cached GPU memory so the process exits clean.
        torch.cuda.empty_cache()


# --- FastAPI App Initialization ---
app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)


# --- API Request and Response Models ---
class GenerationRequest(BaseModel):
    # One natural-language instruction per output column.
    llm_commands: list[str]
    # Number of rows to request from the model.
    batch_size: int = 50


class GenerationResponse(BaseModel):
    data: list
    raw_output: str   # Added for debugging
    duration_s: float  # Added for performance tracking


# --- Helper Functions ---
def extract_json_from_text(text: str):
    """Extract a JSON array from the model's raw text output.

    This version is more robust and handles incomplete JSON at the end:
    it first tries to parse the span between the outermost brackets, then
    falls back to recovering individual rows from malformed output.

    Returns the parsed list, or None when nothing recoverable is found.
    """
    # Find the first '[' and the last ']' to bound the JSON content.
    start_bracket = text.find('[')
    end_bracket = text.rfind(']')
    if start_bracket == -1 or end_bracket == -1:
        return None  # No JSON array found

    json_str = text[start_bracket:end_bracket + 1]
    try:
        # Attempt to parse the primary JSON string.
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Fallback for malformed JSON: try to parse row by row.
        print("Warning: Initial JSON parsing failed. Attempting to recover partial data.")
        potential_rows = json_str.strip()[1:-1].split('],[')
        valid_rows = []
        for row_str in potential_rows:
            try:
                # Reconstruct and parse each potential row. NOTE: stripping
                # every bracket flattens any nested arrays inside a cell;
                # acceptable here because rows are flat lists of scalars.
                clean_row_str = row_str.replace('[', '').replace(']', '').strip()
                if clean_row_str:
                    valid_rows.append(json.loads(f'[{clean_row_str}]'))
            except json.JSONDecodeError:
                continue  # Skip malformed rows
        return valid_rows if valid_rows else None


def create_structured_prompt(commands: list[str], batch_size: int) -> str:
    """Create a structured, forceful prompt so the model returns clean JSON.

    Args:
        commands: one description per output column.
        batch_size: number of rows the model must produce.
    """
    cols_description = '\n'.join(
        f'- Column {i+1}: {cmd}' for i, cmd in enumerate(commands)
    )
    return f"""
Generate exactly {batch_size} rows of data.
Each inner array must have exactly {len(commands)} columns.
The columns are defined as follows:
{cols_description}

Your entire response must be ONLY the JSON array of arrays, with no additional text, explanations, or markdown.
Example of a valid response:
[["value1", "value2"], ["value3", "value4"]]
"""


# --- API Endpoints ---
@app.post("/generate", response_model=GenerationResponse)
async def generate_data(request: GenerationRequest):
    """Generate `request.batch_size` rows matching `request.llm_commands`."""
    # Explicit `is None` checks: truthiness of an nn.Module is not a
    # reliable readiness signal.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again in a moment.")

    start_time = time.time()
    try:
        # Create a more reliable prompt.
        prompt = create_structured_prompt(request.llm_commands, request.batch_size)
        messages = [
            {
                "role": "system",
                "content": "You are a precise data generation machine. Your sole purpose is to return a valid JSON array of arrays. You will not deviate from this role.",
            },
            {"role": "user", "content": prompt},
        ]

        # Apply the chat template.
        text_input = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text_input], return_tensors="pt").to(device)

        # Generate with no_grad context for better performance.
        with torch.no_grad():
            # Dynamically size max_new_tokens from the expected output
            # (rows x cols x ~10 tokens per cell) plus a small buffer,
            # capped at 4096.
            max_new_tokens = int(request.batch_size * len(request.llm_commands) * 10 + 50)
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=min(4096, max_new_tokens),
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens (strip the prompt).
        response_text = tokenizer.batch_decode(
            generated_ids[:, model_inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )[0]

        # Extract and validate JSON data.
        json_data = extract_json_from_text(response_text)

        final_data = []
        if json_data and isinstance(json_data, list):
            expected_cols = len(request.llm_commands)
            # Filter for valid rows and cap at the requested batch size.
            final_data = [
                row for row in json_data
                if isinstance(row, list) and len(row) == expected_cols
            ][:request.batch_size]
        else:
            print(f"Failed to parse JSON. Raw output: {response_text}")

        end_time = time.time()
        return {
            "data": final_data,
            "raw_output": response_text,
            "duration_s": round(end_time - start_time, 2),
        }

    except Exception as e:
        print(f"An error occurred during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# --- New Test Route ---
@app.get("/test", response_model=GenerationResponse, summary="Run a predefined test generation")
async def test_generation():
    """A simple test endpoint that generates 10 rows of sample data with
    fixed commands. This allows for easy performance testing and validation.
    """
    test_request = GenerationRequest(
        llm_commands=[
            "a common first name starting with the letter A",
            "an age as an integer between 20 and 30",
        ],
        batch_size=10,
    )
    print("--- Running /test endpoint ---")
    return await generate_data(test_request)


# --- Health and Status Routes ---
@app.get("/", summary="Root status check")
def read_root():
    """Basic liveness/status info."""
    return {"status": "ok", "model_name": model_name, "device": device}


@app.get("/health", summary="Health check for the service")
def health_check():
    """Readiness probe: reports whether model and tokenizer are loaded."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": device,
    }