Sgridda committed on
Commit
39b69d9
·
1 Parent(s): 9d8ec9c

Re-enable TinyLlama model for actual inference

Browse files
Files changed (1) hide show
  1. main.py +102 -27
main.py CHANGED
@@ -1,34 +1,66 @@
1
-
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
 
 
4
  import re
5
  import json
6
 
7
  # ----------------------------
8
- # 1. FastAPI App Initialization
 
 
 
 
 
 
 
9
  # ----------------------------
10
 
11
  app = FastAPI(
12
- title="AI Code Review Service (Test Mode)",
13
- description="A test version of the API without a live AI model.",
14
  version="1.0.0",
15
  )
16
 
17
  # ----------------------------
18
- # 2. Mock AI Model Loading (Simulated)
19
  # ----------------------------
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @app.on_event("startup")
22
  async def startup_event():
23
  """
24
- In this test version, we just print a message.
25
- We are not loading any real model.
26
  """
27
- print("Server starting up in test mode.")
28
- print("Model loading is disabled.")
29
 
30
  # ----------------------------
31
- # 3. API Request/Response Models
32
  # ----------------------------
33
 
34
  class ReviewRequest(BaseModel):
@@ -43,35 +75,78 @@ class ReviewResponse(BaseModel):
43
  comments: list[ReviewComment]
44
 
45
  # ----------------------------
46
- # 4. The API Endpoint (with Mocked Response)
47
  # ----------------------------
48
 
49
- @app.post("/review", response_model=ReviewResponse)
50
- async def get_code_review(request: ReviewRequest):
51
  """
52
- This endpoint now returns a hardcoded, successful response.
53
- It does not call an AI model.
54
  """
55
- print("Received request for /review. Returning mocked response.")
56
- if not request.diff:
57
- raise HTTPException(status_code=400, detail="Diff content cannot be empty.")
58
 
59
- # Create a fake response to prove the endpoint is working.
60
- mock_comments = [
 
 
 
61
  {
62
- "file_path": "src/mock/test.py",
63
- "line_number": 10,
64
- "comment_text": "This is a test comment from the mock server. If you see this, the API is working!"
65
  }
66
  ]
 
 
 
 
67
 
68
- return ReviewResponse(comments=[ReviewComment(**c) for c in mock_comments])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  # ----------------------------
71
- # 5. Health Check Endpoint
72
  # ----------------------------
73
 
74
  @app.get("/health")
75
  async def health_check():
76
- """A simple endpoint to confirm the server is running."""
77
- return {"status": "ok", "model_loaded": False} # Model is not loaded in test mode
 
 
1
import asyncio
import json
import re

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
 
8
# ----------------------------
# 1. Configuration
# ----------------------------

# Hugging Face Hub identifier of the chat model used for reviews.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Prefer GPU when available; fall back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
15
# ----------------------------
# 2. FastAPI App Initialization
# ----------------------------

# Single application instance; routes below are registered on it.
app = FastAPI(
    title="AI Code Review Service",
    description="An API to get AI-powered code reviews for pull request diffs.",
    version="1.0.0",
)
24
 
25
# ----------------------------
# 3. AI Model Loading
# ----------------------------

# Module-level singletons populated once by load_model() at startup.
model = None
tokenizer = None

def load_model():
    """Load the tokenizer and model into the module-level globals.

    Idempotent: if the model is already loaded this is a no-op.
    On CUDA hosts the model is loaded with 4-bit NF4 quantization;
    on CPU-only hosts it is loaded unquantized, because bitsandbytes
    4-bit loading requires a GPU and would crash at startup otherwise.
    """
    global model, tokenizer
    if model is not None:
        return

    print(f"Loading model: {MODEL_NAME} on device: {DEVICE}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

    if DEVICE == "cuda":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            quantization_config=quantization_config,
            device_map="auto",
        )
    else:
        # Fix: the original applied the 4-bit quantization config
        # unconditionally, which fails on machines without CUDA even
        # though DEVICE was computed above and never consulted.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
    print("Model loaded successfully.")
+
54
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favor of lifespan handlers — confirm the pinned FastAPI version.
@app.on_event("startup")
async def startup_event():
    """
    On server startup, we trigger the model loading.
    """
    print("Server starting up...")
    # Blocks startup until the model (and its first-time download) is ready.
    load_model()
61
 
62
  # ----------------------------
63
+ # 4. API Request/Response Models
64
  # ----------------------------
65
 
66
  class ReviewRequest(BaseModel):
 
75
  comments: list[ReviewComment]
76
 
77
# ----------------------------
# 5. The AI Review Logic
# ----------------------------

def run_ai_inference(diff: str) -> str:
    """Run the chat model over *diff* and return its raw text response.

    Args:
        diff: The unified diff text to review.

    Returns:
        The model's decoded completion (prompt tokens stripped), trimmed.

    Raises:
        RuntimeError: if load_model() has not populated the globals yet.
    """
    if not model or not tokenizer:
        raise RuntimeError("Model is not loaded.")

    messages = [
        {
            "role": "system",
            "content": """You are an expert code reviewer. Your task is to analyze a pull request diff and provide constructive feedback.\nAnalyze the provided diff and identify potential issues, suggest improvements, or point out good practices.\n\nIMPORTANT: Respond with a JSON array of comment objects. Each object must have three fields: 'file_path', 'line_number', and 'comment_text'.\nThe 'file_path' should be the full path of the file being changed.\nThe 'line_number' must be an integer corresponding to the line number in the *new* version of the file where the comment applies.\nThe 'comment_text' should be your concise and clear review comment.\n\nExample response format:\n[\n {\n "file_path": "src/utils/helpers.py",\n "line_number": 42,\n "comment_text": "This function could be simplified by using a list comprehension."\n }\n]\n\nDo not add any introductory text or explanations outside of the JSON array.\n"""
        },
        {
            "role": "user",
            "content": f"Here is the diff to review:\n\n```diff\n{diff}\n```"
        }
    ]

    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    # Fix: with device_map="auto" the model weights may sit on GPU while
    # the tokenized inputs default to CPU; generate() then fails with a
    # device mismatch. Move the inputs to wherever the model landed.
    inputs = inputs.to(model.device)

    # Greedy decoding (do_sample=False). The original also passed
    # top_k/top_p/num_return_sequences, which are ignored when sampling
    # is off and only trigger generation-config warnings, so they are
    # dropped here; the generated tokens are unchanged.
    outputs = model.generate(
        inputs,
        max_new_tokens=1024,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
    return response_text.strip()
105
+
106
def parse_ai_response(response_text: str) -> list[ReviewComment]:
    """Parse the raw model output and extract validated review comments.

    Searches the text for the first [...] span, decodes it as JSON, and
    validates each element as a ReviewComment. Any parse or validation
    failure degrades to an empty list rather than raising, so a badly
    formatted model response never becomes a server error.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        Validated comments, or [] when no usable JSON array is found.
    """
    print(f"Raw AI Response:\n---\n{response_text}\n---")

    # Greedy match grabs from the first '[' to the last ']' so nested
    # arrays inside the payload stay intact.
    json_match = re.search(r'\[.*\]', response_text, re.DOTALL)
    if not json_match:
        print("Warning: Could not find a JSON array in the AI response.")
        return []

    json_string = json_match.group(0)

    try:
        comments_data = json.loads(json_string)
        if not isinstance(comments_data, list):
            # Fix: a top-level JSON object would otherwise be iterated as
            # its keys and produce confusing TypeErrors downstream.
            print("Warning: AI response JSON is not an array.")
            return []
        return [ReviewComment(**item) for item in comments_data]
    except (ValueError, TypeError, KeyError) as e:
        # Fix: ValueError covers both json.JSONDecodeError and pydantic's
        # ValidationError (a ValueError subclass the original did not
        # catch), so schema violations from the model yield [] instead of
        # escaping to the endpoint as a 500.
        print(f"Error parsing JSON from AI response: {e}")
        print(f"Invalid JSON string: {json_string}")
        return []
127
+
128
# ----------------------------
# 6. The API Endpoint
# ----------------------------

@app.post("/review", response_model=ReviewResponse)
async def get_code_review(request: ReviewRequest):
    """Review a pull-request diff and return structured comments.

    Raises:
        HTTPException 400: when the submitted diff is empty.
        HTTPException 500: when inference or parsing fails unexpectedly.
    """
    if not request.diff:
        raise HTTPException(status_code=400, detail="Diff content cannot be empty.")

    try:
        # Fix: model.generate() is long-running, blocking CPU/GPU work;
        # calling it directly in an async endpoint froze the event loop
        # for every other request. Run it in a worker thread instead.
        ai_response_text = await asyncio.to_thread(run_ai_inference, request.diff)
        parsed_comments = parse_ai_response(ai_response_text)
        return ReviewResponse(comments=parsed_comments)

    except HTTPException:
        # Never convert deliberate HTTP errors into a generic 500.
        raise
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        # Chain the cause so tracebacks in the server log stay useful.
        raise HTTPException(status_code=500, detail="An internal error occurred while processing the review.") from e
145
 
146
# ----------------------------
# 7. Health Check Endpoint
# ----------------------------

@app.get("/health")
async def health_check():
    """Liveness probe: reports server status and whether the model global is populated."""
    return {"status": "ok", "model_loaded": model is not None}