# FastAPI inference service: serves Qwen/Qwen2.5-0.5B-Instruct to generate
# batched synthetic tabular data as JSON arrays-of-arrays.
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import json | |
| import re | |
| import time | |
| from contextlib import asynccontextmanager | |
# --- Performance Optimizations & Model Loading ---
# 1. Device Selection: Use CUDA GPU if available for a massive speed boost.
device = "cuda" if torch.cuda.is_available() else "cpu"
# 2. Data Type: Use float16 on GPU for faster computation and less memory usage.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"--- System Info ---")
print(f"Using device: {device}")
print(f"Using dtype: {torch_dtype}")
print("--------------------")
# --- App State and Model Placeholders ---
# HF Hub id of the checkpoint to serve.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# Populated by the lifespan handler at startup; None until the model is loaded.
tokenizer = None
model = None
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and tokenizer on startup; release them on shutdown.

    FastAPI's ``lifespan=`` argument expects an async context manager, so the
    ``@asynccontextmanager`` decorator (imported above but previously unused)
    is required for the ``yield``-based startup/shutdown split — passing a
    bare async generator relies on a deprecated Starlette auto-wrap.
    """
    global tokenizer, model
    print("Loading model and tokenizer...")
    start_time = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Generation needs a pad token; reuse EOS when the checkpoint defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    try:
        # 3. Attention Mechanism: Flash Attention 2 gives ~2x speedup on compatible GPUs.
        print("Attempting to load model with Flash Attention 2...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            attn_implementation="flash_attention_2"
        ).to(device)
        print("Successfully loaded model with Flash Attention 2.")
    except (ImportError, RuntimeError) as e:
        # flash-attn missing or unsupported hardware: fall back to default attention.
        print(f"Flash Attention 2 not available ({e}), falling back to default attention.")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
        ).to(device)

    # 4. Model Compilation (PyTorch 2.0+): JIT-compiles the model for faster execution.
    print("Compiling model with torch.compile()...")
    try:
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")
    except Exception as e:
        print(f"torch.compile() failed: {e}. Running with uncompiled model.")

    end_time = time.time()
    print(f"Model loading and compilation finished in {end_time - start_time:.2f} seconds.")

    yield  # the application serves requests while suspended here

    # Drop references so the weights can be garbage-collected on shutdown.
    print("Cleaning up and shutting down.")
    model = None
    tokenizer = None
# --- FastAPI App Initialization ---
app = FastAPI(lifespan=lifespan)

# Open CORS policy so browser clients from any origin can reach the API.
# NOTE(review): browsers reject a wildcard origin on credentialed requests —
# confirm whether allow_credentials=True is actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
)
# --- API Request and Response Models ---
class GenerationRequest(BaseModel):
    """Request body for a data-generation call."""
    # One natural-language instruction per output column.
    llm_commands: list[str]
    # Number of rows the model is asked to produce.
    batch_size: int = 50
class GenerationResponse(BaseModel):
    """Response schema for a data-generation call."""
    # Parsed rows (list of lists), filtered to the expected column count.
    data: list
    # Verbatim model output, returned for debugging.
    raw_output: str  # Added for debugging
    # Wall-clock generation time in seconds.
    duration_s: float  # Added for performance tracking
# --- Helper Functions ---
def extract_json_from_text(text: str):
    """Extract a JSON array-of-arrays from the model's raw text output.

    Returns the parsed list on success, a best-effort list of recovered rows
    when the JSON is truncated or malformed, or None when no array is found.
    """
    # Bound the candidate JSON by the first '[' and the last ']'.
    start_bracket = text.find('[')
    end_bracket = text.rfind(']')
    if start_bracket == -1 or end_bracket == -1:
        return None  # No JSON array found
    json_str = text[start_bracket : end_bracket + 1]
    try:
        # Attempt to parse the primary JSON string
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Fallback for malformed JSON: recover whatever complete rows exist.
        print("Warning: Initial JSON parsing failed. Attempting to recover partial data.")
        # Split on ']' , '[' allowing whitespace (e.g. '], ['). The previous
        # literal split on '],[' missed pretty-printed output and flattened
        # every row into one wide row.
        potential_rows = re.split(r'\]\s*,\s*\[', json_str.strip()[1:-1])
        valid_rows = []
        for row_str in potential_rows:
            try:
                # Reconstruct and parse each potential row
                clean_row_str = row_str.replace('[', '').replace(']', '').strip()
                if clean_row_str:
                    valid_rows.append(json.loads(f'[{clean_row_str}]'))
            except json.JSONDecodeError:
                continue  # Skip rows that are themselves truncated/garbled
        return valid_rows if valid_rows else None
def create_structured_prompt(commands: list[str], batch_size: int) -> str:
    """Build a strict instruction prompt that forces JSON-array-only output."""
    # One numbered bullet per requested column.
    column_lines = []
    for column_number, command in enumerate(commands, start=1):
        column_lines.append(f'- Column {column_number}: {command}')
    cols_description = '\n'.join(column_lines)
    return f"""
Generate exactly {batch_size} rows of data.
Each inner array must have exactly {len(commands)} columns.
The columns are defined as follows:
{cols_description}
Your entire response must be ONLY the JSON array of arrays, with no additional text, explanations, or markdown.
Example of a valid response:
[["value1", "value2"], ["value3", "value4"]]
"""
# --- API Endpoints ---
@app.post("/generate", response_model=GenerationResponse)
async def generate_data(request: GenerationRequest):
    """Generate ``request.batch_size`` rows of data, one column per command.

    Returns the parsed rows plus the raw model output and wall-clock duration.
    Raises 503 while the model is still loading and 500 on generation errors.

    NOTE(review): the route decorator was missing from this snippet; the
    "/generate" path is inferred from the function's purpose and the unused
    GenerationResponse model — confirm against the original app.
    """
    # Explicit None checks: truthiness on an nn.Module is not a load check.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again in a moment.")
    start_time = time.time()
    try:
        # Build a strict JSON-only prompt from the per-column commands.
        prompt = create_structured_prompt(request.llm_commands, request.batch_size)
        messages = [
            {"role": "system", "content": "You are a precise data generation machine. Your sole purpose is to return a valid JSON array of arrays. You will not deviate from this role."},
            {"role": "user", "content": prompt}
        ]
        # Apply the chat template
        text_input = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text_input], return_tensors="pt").to(device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            # Budget ~10 tokens per cell plus slack, capped to bound latency.
            max_new_tokens = int(request.batch_size * len(request.llm_commands) * 10 + 50)
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=min(4096, max_new_tokens),
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Keep only the newly generated tokens (drop the echoed prompt).
        response_text = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
        # Extract and validate JSON data
        json_data = extract_json_from_text(response_text)
        final_data = []
        if json_data and isinstance(json_data, list):
            expected_cols = len(request.llm_commands)
            # Keep only rows with the right column count, capped at batch_size.
            final_data = [
                row for row in json_data
                if isinstance(row, list) and len(row) == expected_cols
            ][:request.batch_size]
        else:
            print(f"Failed to parse JSON. Raw output: {response_text}")
        end_time = time.time()
        return {
            "data": final_data,
            "raw_output": response_text,
            "duration_s": round(end_time - start_time, 2)
        }
    except Exception as e:
        print(f"An error occurred during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# --- New Test Route ---
@app.get("/test")
async def test_generation():
    """
    A simple test endpoint that generates 10 rows of sample data with fixed commands.
    This allows for easy performance testing and validation.
    """
    # The route decorator was missing from this snippet; the "/test" path is
    # confirmed by the log line below.
    test_request = GenerationRequest(
        llm_commands=[
            "a common first name starting with the letter A",
            "an age as an integer between 20 and 30"
        ],
        batch_size=10
    )
    print("--- Running /test endpoint ---")
    return await generate_data(test_request)
# --- Health and Status Routes ---
@app.get("/")
def read_root():
    """Root status route: reports the served model name and compute device."""
    # NOTE(review): route decorator restored; "/" is inferred from the
    # conventional name read_root — confirm against the original app.
    return {"status": "ok", "model_name": model_name, "device": device}
@app.get("/health")
def health_check():
    """Readiness probe: reports whether the model and tokenizer have loaded."""
    # NOTE(review): route decorator restored; "/health" path inferred from the
    # function name — confirm against the original app.
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": device
    }