from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
import json

# ----------------------------
# 1. Configuration
# ----------------------------
MODEL_NAME = "Salesforce/codegen-350M-mono"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ----------------------------
# 2. FastAPI App Initialization
# ----------------------------
app = FastAPI(
    title="AI Code Review Service",
    description="An API to get AI-powered code reviews for pull request diffs.",
    version="1.0.0",
)

# ----------------------------
# 3. AI Model Loading
# ----------------------------
model = None
tokenizer = None


def load_model():
    """Load the model and tokenizer into memory (idempotent).

    BUG FIX: the original passed ``device_map="cpu"`` even though DEVICE may
    be "cuda", so the GPU was never used. The model is now placed on DEVICE.
    """
    global model, tokenizer
    if model is None:
        print(f"Loading model: {MODEL_NAME} on device: {DEVICE}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # CodeGen tokenizers ship without a pad token; alias it to EOS once
        # here so generate() never has to guess at call time.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
        ).to(DEVICE)
        model.eval()  # inference only: disable dropout etc.
        print("Model loaded successfully.")


@app.on_event("startup")
async def startup_event():
    """Trigger model loading once, when the server starts."""
    # NOTE(review): @app.on_event is deprecated in recent FastAPI; a lifespan
    # handler is the modern replacement, but on_event still works here.
    print("Server starting up...")
    load_model()


# ----------------------------
# 4. API Request/Response Models
# ----------------------------
class ReviewRequest(BaseModel):
    # Raw pull-request diff text to review.
    diff: str


class ReviewComment(BaseModel):
    # Path of the reviewed file the comment applies to.
    file_path: str
    # 1-based line number the comment is anchored at.
    line_number: int
    # The natural-language review text.
    comment_text: str


class ReviewResponse(BaseModel):
    comments: list[ReviewComment]


# ----------------------------
# 5. The AI Review Logic
# ----------------------------
def run_ai_inference(diff: str) -> str:
    """Generate a one-line natural-language review comment for ``diff``.

    Only the first 800 characters of the diff are fed to the model to keep
    the prompt within the 1024-token context window.

    Raises:
        RuntimeError: if the model has not been loaded yet.
    """
    if not model or not tokenizer:
        raise RuntimeError("Model is not loaded.")

    # BUG FIX: the original prompt literal contained a raw line break inside
    # the quotes (a syntax error); it is now a well-formed literal.
    prompt = (
        "Below is a Python function. Please provide a code review comment "
        "with suggestions for improvement, in natural language. "
        "Do not repeat the code.\n"
        f"{diff[:800]}\n"
        "Review comment:"
    )

    # BUG FIX: padding="max_length" padded every single prompt to 1024
    # tokens, wasting generation compute; no padding is needed for one
    # sequence, only truncation.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    # BUG FIX: inputs were never moved off the CPU; follow the model's device.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            num_return_sequences=1,
            # pad token is guaranteed set in load_model().
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response_text = tokenizer.decode(
        outputs[0][input_ids.shape[1]:], skip_special_tokens=True
    )

    # Post-process: keep only lines that read as prose (drop code-looking
    # lines), and fall back to a generic suggestion if nothing usable remains.
    review_lines = [
        line.strip() for line in response_text.strip().split("\n") if line.strip()
    ]
    comment_lines = [
        l
        for l in review_lines
        if not l.startswith(("def ", "class ", "#")) and not l.endswith(":")
    ]
    return (
        comment_lines[0]
        if comment_lines
        else "Consider adding a docstring and input validation."
    )


def parse_ai_response(response_text: str) -> list[ReviewComment]:
    """Wrap the model's free-text review into a single ReviewComment.

    The 350M codegen model cannot reliably emit structured JSON, so the whole
    response becomes one comment anchored at line 1 of a placeholder path.
    """
    return [
        ReviewComment(
            file_path="code_reviewed.py",
            line_number=1,
            comment_text=response_text.strip(),
        )
    ]
The API Endpoint # ---------------------------- @app.post("/review", response_model=ReviewResponse) async def get_code_review(request: ReviewRequest): if not request.diff: raise HTTPException(status_code=400, detail="Diff content cannot be empty.") import time start_time = time.time() print(f"Starting review request at {start_time}") try: print("Running AI inference...") ai_response_text = run_ai_inference(request.diff) print(f"AI inference completed in {time.time() - start_time:.2f} seconds") print("Parsing AI response...") parsed_comments = parse_ai_response(ai_response_text) print(f"Total processing time: {time.time() - start_time:.2f} seconds") return ReviewResponse(comments=parsed_comments) except Exception as e: print(f"An unexpected error occurred after {time.time() - start_time:.2f} seconds: {e}") raise HTTPException(status_code=500, detail="An internal error occurred while processing the review.") # ---------------------------- # 7. Health Check Endpoint # ---------------------------- @app.get("/health") async def health_check(): return {"status": "ok", "model_loaded": model is not None}