Samarthrr's picture
Update app.py
2ea4a6a verified
import ast
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import Optional, List
from transformers import (
T5ForConditionalGeneration,
RobertaTokenizer,
AutoModelForSequenceClassification,
AutoTokenizer
)
import pandas as pd
import os
import threading
import re
# Import the training function
from train_engine import train_on_devign
app = FastAPI(title="Revcode AI Precision Engine")
# Global State
training_lock = threading.Lock()
is_training = False
class CodeInput(BaseModel):
code: str
filename: Optional[str] = "snippet.js"
# ---------------------------------------------------------
# 1. PRECISION SCANNER (CodeBERT-Devign)
# ---------------------------------------------------------
class DeepVulnerabilityScanner:
def __init__(self):
# Prefer locally trained model if it exists
local_model = "./trained_model"
if os.path.exists(local_model):
self.model_name = local_model
self.tokenizer_name = local_model
else:
self.model_name = "mahdin70/codebert-devign-code-vulnerability-detector"
self.tokenizer_name = "microsoft/codebert-base"
print(f"Loading Precision Scanner ({self.model_name})...")
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
self.model.eval()
def scan(self, code: str) -> dict:
inputs = self.tokenizer(code, return_tensors="pt", truncation=True, padding=True, max_length=512)
with torch.no_grad():
logits = self.model(**inputs).logits
probs = torch.softmax(logits, dim=1)
vuln_prob = probs[0][1].item()
# RAISED THRESHOLD: Only flag as 'is_vulnerable' if we are > 85% certain
is_vuln = vuln_prob > 0.5
verdict = "SECURE"
if vuln_prob > 0.9: verdict = "CRITICAL"
elif vuln_prob > 0.7: verdict = "WARNING"
elif vuln_prob > 0.4: verdict = "POTENTIAL"
return {
"is_vulnerable": is_vuln,
"confidence": round(vuln_prob * 100, 2),
"threat_level": verdict,
"reasoning": self._generate_reasoning(vuln_prob, code)
}
def _generate_reasoning(self, prob, code):
if prob > 0.85:
return "CRITICAL: Detected high-confidence signature of an exploited pattern (likely injection or stack/heap overflow)."
if prob > 0.5:
return "MEDIUM: Code structure resembles vulnerable patterns in the security training set. Recommended audit."
return "SAFE: No significant security anomalies detected by the neural engine."
# ---------------------------------------------------------
# 2. RULE-BASED PATTERN FILTER (Hardened)
# ---------------------------------------------------------
class StructuralScanner:
@staticmethod
def scan(code: str, filename: str) -> List[dict]:
findings = []
# Rule 1: Code Injection (Detecting RAW eval, excluding json/safe wraps)
if "eval(" in code:
if not any(x in code for x in ["JSON.parse(", "safe_eval", "ast.literal_eval"]):
findings.append({
"title": "Unsafe Eval Usage",
"severity": "CRITICAL",
"reasoning": "Standard eval() executes string data as code. Use JSON.parse() or ast.literal_eval() for data."
})
# Rule 2: RAW Command Injection
if any(x in code for x in ["os.system(", "subprocess.Popen(..., shell=True)"]):
findings.append({
"title": "Direct Shell Execution",
"severity": "HIGH",
"reasoning": "Detected shell invocation with shell=True. This is highly susceptible to command injection."
})
return findings
# ---------------------------------------------------------
# 3. CONSERVATIVE REPAIR ENGINE (Minimal Changes)
# ---------------------------------------------------------
class AutomatedRepairEngine:
def __init__(self):
print("Loading Conservative Repair Engine (CodeT5+)...")
self.model_name = "Salesforce/codet5p-220m"
self.tokenizer = RobertaTokenizer.from_pretrained(self.model_name)
self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
self.model.eval()
def repair(self, buggy_code: str, filename: str) -> str:
# CONSTRAINED PROMPT: Focus only on the security fix
prompt = f"Fix the security scan vulnerability in this {filename} file accurately and with minimal changes: {buggy_code}"
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_length=512,
num_beams=5,
temperature=0.2, # LOWER TEMPERATURE for less creativity/more precision
early_stopping=True
)
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# ---------------------------------------------------------
# 4. ORCHESTRATION & API
# ---------------------------------------------------------
_scanner = None
_repairer = None
_struct = StructuralScanner()
def get_scanner(reload=False):
global _scanner
if _scanner is None or reload: _scanner = DeepVulnerabilityScanner()
return _scanner
def get_repairer():
global _repairer
if _repairer is None: _repairer = AutomatedRepairEngine()
return _repairer
@app.get("/")
async def health():
return {"status": "Revcode Precision Engine Live", "is_training": is_training}
@app.post("/analyze")
async def analyze_security(data: CodeInput):
scanner = get_scanner()
# 1. Neural Analysis
res = scanner.scan(data.code)
# 2. Structural Analysis
struct_findings = _struct.scan(data.code, data.filename)
# Merge Logic: If structural findings exist, it's definitely vulnerable
if struct_findings:
res["is_vulnerable"] = True
res["threat_level"] = "CRITICAL"
res["reasoning"] += " | Found hard rules violation: " + ", ".join([f['title'] for f in struct_findings])
return {
"is_vulnerable": res["is_vulnerable"],
"confidence": res["confidence"],
"threat_level": res["threat_level"],
"reasoning": res["reasoning"],
"structural_findings": struct_findings,
"is_training": is_training
}
@app.post("/fix")
async def fix_code(data: CodeInput):
repairer = get_repairer()
# 1. Primary generative fix
suggestion = repairer.repair(data.code, data.filename)
# 2. Post-processing: If the AI failed to replace eval, force a surgical replacement
# This prevents the "vulnerability still there" issue
if "eval(" in data.code and "eval(" in suggestion:
suggestion = suggestion.replace("eval(", "JSON.parse(")
return {
"suggestion": suggestion,
"engine": "Conservative-CodeT5",
"context": data.filename
}
@app.post("/train")
async def trigger_training(background_tasks: BackgroundTasks):
global is_training
if is_training: return {"status": "error", "message": "Training in progress"}
def run():
global is_training
is_training = True
try:
train_on_devign(output_dir="./trained_model")
get_scanner(reload=True)
finally: is_training = False
background_tasks.add_task(run)
return {"status": "success", "message": "Training started"}
@app.post("/feedback")
async def store_feedback(data: dict):
feedback_file = "feedback_dataset.csv"
pd.DataFrame([data]).to_csv(feedback_file, mode='a', header=not os.path.exists(feedback_file), index=False)
return {"status": "stored"}