# generate_api / main.py
# (Hugging Face Space file header: uploaded by aledraa, "Update main.py",
#  commit a7a61ee, verified — converted to comments so the file parses.)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import re
import time
from contextlib import asynccontextmanager
# --- Performance Optimizations & Model Loading ---
# 1. Device Selection: Use CUDA GPU if available for a massive speed boost.
device = "cuda" if torch.cuda.is_available() else "cpu"
# 2. Data Type: float16 on GPU halves memory and speeds up compute;
#    float32 on CPU, where half precision is slow or unsupported.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"--- System Info ---")
print(f"Using device: {device}")
print(f"Using dtype: {torch_dtype}")
print("--------------------")
# --- App State and Model Placeholders ---
# Small instruct checkpoint; the placeholders below are populated by the
# lifespan startup hook and reset to None on shutdown.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = None
model = None
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handle application startup and shutdown.

    Startup: load the tokenizer and model (preferring Flash Attention 2 and
    falling back to the default attention implementation), then attempt to
    JIT-compile the model with torch.compile(). Shutdown: drop the global
    references so the large objects can be garbage-collected.
    """
    global tokenizer, model
    print("Loading model and tokenizer...")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Generation requires a pad token; many causal LMs ship without one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    try:
        # 3. Attention Mechanism: Flash Attention 2 gives roughly a 2x
        # speedup on compatible GPUs.
        print("Attempting to load model with Flash Attention 2...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            attn_implementation="flash_attention_2"
        ).to(device)
        print("Successfully loaded model with Flash Attention 2.")
    except (ImportError, ValueError, RuntimeError) as e:
        # BUG FIX: transformers reports an unusable flash-attn setup (package
        # missing, or an unsupported device/dtype combination such as
        # CPU/float32) as a ValueError. The original except tuple of
        # (ImportError, RuntimeError) let that escape and crash startup
        # instead of falling back to the default attention implementation.
        print(f"Flash Attention 2 not available ({e}), falling back to default attention.")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
        ).to(device)
    # 4. Model Compilation (PyTorch 2.0+): JIT-compiles the model for faster
    # execution. A failure here is non-fatal; we simply run the eager model.
    print("Compiling model with torch.compile()...")
    try:
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")
    except Exception as e:
        print(f"torch.compile() failed: {e}. Running with uncompiled model.")
    end_time = time.time()
    print(f"Model loading and compilation finished in {end_time - start_time:.2f} seconds.")
    yield
    # Shutdown: release references to the model and tokenizer.
    print("Cleaning up and shutting down.")
    model = None
    tokenizer = None
# --- FastAPI App Initialization ---
app = FastAPI(lifespan=lifespan)
# NOTE(review): fully open CORS (wildcard origins with credentials enabled) is
# acceptable for a public demo, but should be restricted before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
# --- API Request and Response Models ---
class GenerationRequest(BaseModel):
    """Request body for /generate."""
    # One natural-language description per output column.
    llm_commands: list[str]
    # Number of rows requested (also the cap on rows returned).
    batch_size: int = 50
class GenerationResponse(BaseModel):
    """Response body for /generate and /test."""
    # Parsed rows: a list of lists, one inner list per generated row.
    data: list
    raw_output: str  # Raw model text, kept for debugging
    duration_s: float  # End-to-end generation time in seconds
# --- Helper Functions ---
def extract_json_from_text(text: str):
    """
    Extract a JSON array of rows from the model's raw text output.

    Returns the parsed list on success, a best-effort list of recovered rows
    when the JSON is truncated or malformed, or None when no array-like
    content is found at all.
    """
    # Bound the candidate JSON by the outermost square brackets.
    start_bracket = text.find('[')
    end_bracket = text.rfind(']')
    if start_bracket == -1 or end_bracket == -1:
        return None  # No JSON array found
    json_str = text[start_bracket : end_bracket + 1]
    try:
        # Attempt to parse the primary JSON string.
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass
    # Recovery path: the model often truncates the final row. Split the body
    # on row boundaries and re-parse each candidate row individually.
    # BUG FIX: the original split on the literal "],[" only, so separators
    # with whitespace ("], [") were not split — complete rows got merged into
    # one wide row, which the caller's column-count filter then discarded.
    print("Warning: Initial JSON parsing failed. Attempting to recover partial data.")
    potential_rows = re.split(r'\]\s*,\s*\[', json_str.strip()[1:-1])
    valid_rows = []
    for row_str in potential_rows:
        # NOTE: stripping brackets mangles cell values that themselves contain
        # '[' or ']' — accepted trade-off for this last-resort recovery.
        clean_row_str = row_str.replace('[', '').replace(']', '').strip()
        if not clean_row_str:
            continue
        try:
            valid_rows.append(json.loads(f'[{clean_row_str}]'))
        except json.JSONDecodeError:
            continue  # Skip malformed rows
    return valid_rows if valid_rows else None
def create_structured_prompt(commands: list[str], batch_size: int) -> str:
    """
    Build the user prompt sent to the LLM.

    Spells out the exact row count, column count, and per-column instructions,
    and demands a bare JSON array-of-arrays response so the output can be
    parsed mechanically downstream.
    """
    column_lines = '\n'.join(
        f'- Column {index + 1}: {command}'
        for index, command in enumerate(commands)
    )
    return f"""
Generate exactly {batch_size} rows of data.
Each inner array must have exactly {len(commands)} columns.
The columns are defined as follows:
{column_lines}
Your entire response must be ONLY the JSON array of arrays, with no additional text, explanations, or markdown.
Example of a valid response:
[["value1", "value2"], ["value3", "value4"]]
"""
# --- API Endpoints ---
@app.post("/generate", response_model=GenerationResponse)
async def generate_data(request: GenerationRequest):
    """
    Generate request.batch_size rows of synthetic data, one column per entry
    in request.llm_commands, by prompting the LLM for a JSON array of arrays.

    Returns the parsed rows, the raw model text (for debugging), and the
    wall-clock duration in seconds. Raises HTTP 503 while the model is still
    loading and HTTP 500 on any generation error.
    """
    # Guard: the lifespan startup hook may not have finished loading yet.
    if not model or not tokenizer:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again in a moment.")
    start_time = time.time()
    try:
        # Build a strict, structured prompt so the model emits parseable JSON.
        prompt = create_structured_prompt(request.llm_commands, request.batch_size)
        messages = [
            {"role": "system", "content": "You are a precise data generation machine. Your sole purpose is to return a valid JSON array of arrays. You will not deviate from this role."},
            {"role": "user", "content": prompt}
        ]
        # Render the chat turns into the model's expected prompt format.
        text_input = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text_input], return_tensors="pt").to(device)
        # Inference only -- disable autograd bookkeeping for speed/memory.
        with torch.no_grad():
            # Token budget heuristic: ~10 tokens per cell plus a 50-token
            # buffer, hard-capped at 4096 below to bound latency.
            max_new_tokens = int(request.batch_size * len(request.llm_commands) * 10 + 50)
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=min(4096, max_new_tokens),
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens (slice off the prompt).
        response_text = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
        # Extract and validate the JSON payload from the raw text.
        json_data = extract_json_from_text(response_text)
        final_data = []
        if json_data and isinstance(json_data, list):
            expected_cols = len(request.llm_commands)
            # Keep only rows with the correct column count, capped at batch_size.
            final_data = [
                row for row in json_data
                if isinstance(row, list) and len(row) == expected_cols
            ][:request.batch_size]
        else:
            # Parsing failed entirely; return empty data but surface raw output.
            print(f"Failed to parse JSON. Raw output: {response_text}")
        end_time = time.time()
        return {
            "data": final_data,
            "raw_output": response_text,
            "duration_s": round(end_time - start_time, 2)
        }
    except Exception as e:
        # Broad catch at the endpoint boundary: convert any failure to HTTP 500.
        print(f"An error occurred during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# --- New Test Route ---
@app.get("/test", response_model=GenerationResponse, summary="Run a predefined test generation")
async def test_generation():
    """
    Smoke-test endpoint: run the /generate pipeline against a fixed
    two-column request (name + age, 10 rows) for easy performance
    testing and validation.
    """
    canned_request = GenerationRequest(
        llm_commands=[
            "a common first name starting with the letter A",
            "an age as an integer between 20 and 30"
        ],
        batch_size=10,
    )
    print("--- Running /test endpoint ---")
    return await generate_data(canned_request)
# --- Health and Status Routes ---
@app.get("/", summary="Root status check")
def read_root():
    """Liveness probe: report service status plus model name and device."""
    status_payload = {"status": "ok"}
    status_payload["model_name"] = model_name
    status_payload["device"] = device
    return status_payload
@app.get("/health", summary="Health check for the service")
def health_check():
    """Readiness probe: report whether the model artifacts are loaded."""
    model_ready = model is not None
    tokenizer_ready = tokenizer is not None
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "tokenizer_loaded": tokenizer_ready,
        "device": device
    }