Vasanthakumar R
feat: add ZeroGPU support via @spaces.GPU decorator
f244f86
"""
Achilles Code Scanner — HuggingFace Space
AI-powered SAST: paste code, find vulnerabilities.
Deploy:
1. Create Space on huggingface.co (Gradio SDK, T4 Small GPU)
2. Upload this directory
3. Set secrets: HF_MODEL (your fine-tuned SAST model or base model)
"""
import os
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ── Model ───────────────────────────────────────────────────────
# Model repository to load; overridden by the HF_MODEL environment variable.
MODEL_ID = os.environ.get("HF_MODEL", "Qwen/Qwen2.5-Coder-1.5B-Instruct")
# Optional adapter repository; the empty-string default disables adapter loading.
ADAPTER_ID = os.environ.get("HF_ADAPTER", "")

# System message placed at the start of every scan prompt.
# The em dashes were previously stored as mis-decoded bytes ("β€”"), so the
# model was being handed garbled punctuation in its instructions.
SYSTEM_PROMPT = (
    "You are Achilles, an elite AI Security Engineer. "
    "You ONLY report genuine vulnerabilities — you never raise false positives. "
    "You ALWAYS provide a response — never return empty output."
)
# Tokenizer and weights are loaded once at import time and shared by every request.
print(f"Loading {MODEL_ID}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Generation passes tokenizer.pad_token_id, so reuse EOS when no pad token exists.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# The adapter is optional; peft is only imported when HF_ADAPTER is set.
if ADAPTER_ID:
    from peft import PeftModel

    model = PeftModel.from_pretrained(model, ADAPTER_ID)

model.eval()
print("Model ready!")
# ── Inference (GPU allocated only during this call) ─────────────
# Upper bound, in tokens, for the fully rendered prompt.
_MAX_PROMPT_TOKENS = 4096
# Slack kept free when trimming, since re-tokenising shortened code can shift
# the token count by a few tokens.
_TRUNCATION_MARGIN = 8


def _build_prompt(language: str, code: str) -> str:
    """Render the system/user/assistant prompt for one scan request."""
    user_msg = f"Analyze the following {language} code for security vulnerabilities:\n\n```{language}\n{code}\n```"
    return (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )


def _fit_prompt(language: str, code: str) -> str:
    """Return the prompt for *code*, trimming the end of the code if it is too long.

    Truncating the whole prompt would cut off the closing code fence and the
    assistant header, so only the user-supplied code is shortened here.
    """
    prompt = _build_prompt(language, code)
    overflow = len(tokenizer(prompt)["input_ids"]) - _MAX_PROMPT_TOKENS
    if overflow <= 0:
        return prompt
    code_ids = tokenizer(code, add_special_tokens=False)["input_ids"]
    keep = max(len(code_ids) - overflow - _TRUNCATION_MARGIN, 0)
    return _build_prompt(language, tokenizer.decode(code_ids[:keep]))


@spaces.GPU(duration=120)
def scan_code(language: str, code: str, max_tokens: int = 1024) -> str:
    """Ask the model to analyse *code* for security vulnerabilities.

    Args:
        language: Language name inserted into the instruction and the code fence.
        code: Source text to analyse. ``None`` or blank input is rejected.
        max_tokens: Maximum number of new tokens to generate; coerced to ``int``
            because UI sliders may deliver a float.

    Returns:
        The model's response text with surrounding whitespace removed, or a
        short hint when no code was supplied.
    """
    # An empty editor may deliver None rather than "", so test both.
    if not code or not code.strip():
        return "Paste some code to scan."

    prompt = _fit_prompt(language, code)
    # truncation stays on as a backstop in case trimming left the prompt a
    # few tokens over budget.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=_MAX_PROMPT_TOKENS,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # Drop anything after an end-of-turn marker that survived decoding.
    response = response.split("<|im_end|>", 1)[0]
    return response.strip()
# ── Examples ────────────────────────────────────────────────────
# Sample inputs wired into gr.Examples; each entry is [language, code snippet].
# NOTE(review): the snippet bodies carry no indentation in this copy of the
# file — confirm against the deployed source whether that is intended.
EXAMPLES = [
    # Python: SQL statement assembled with an f-string from a command-line argument.
    ["python", '''import sqlite3, sys
def get_user(username):
conn = sqlite3.connect("app.db")
query = f"SELECT * FROM users WHERE username = '{username}'"
return conn.execute(query).fetchone()
print(get_user(sys.argv[1]))'''],
    # JavaScript: query parameter interpolated into a shell command passed to exec().
    ["javascript", '''const express = require('express');
const { exec } = require('child_process');
const app = express();
app.get('/ping', (req, res) => {
exec(`ping -c 3 ${req.query.host}`, (err, stdout) => {
res.send(`<pre>${stdout}</pre>`);
});
});
app.listen(3000);'''],
    # PHP: request parameter concatenated into an include() path.
    ["php", '''<?php
$file = $_GET["page"];
include("/var/www/templates/" . $file);
?>'''],
    # C: strcpy() of a command-line argument into a fixed 64-byte stack buffer.
    ["c", '''#include <stdio.h>
#include <string.h>
void process_input(char *input) {
char buffer[64];
strcpy(buffer, input);
printf("Processed: %s\\n", buffer);
}
int main(int argc, char *argv[]) {
if (argc > 1) process_input(argv[1]);
return 0;
}'''],
    # Java: request parameter concatenated into a filesystem path that is then read back.
    ["java", '''import java.io.*;
import javax.servlet.http.*;
public class FileServlet extends HttpServlet {
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
String filename = req.getParameter("file");
FileInputStream fis = new FileInputStream("/uploads/" + filename);
byte[] data = fis.readAllBytes();
resp.getOutputStream().write(data);
}
}'''],
    # Ruby: request parameter interpolated into a SQL LIKE fragment.
    ["ruby", '''class UsersController < ApplicationController
def search
@users = User.where("name LIKE '%#{params[:q]}%'")
render json: @users
end
end'''],
    # TypeScript: raw request-body fields used directly as a database query filter.
    ["typescript", '''import express from 'express';
const app = express();
app.use(express.json());
app.post('/api/login', async (req, res) => {
const user = await db.collection('users').findOne({
email: req.body.email,
password: req.body.password
});
res.json({ user });
});'''],
    # Go: SQL statement built with fmt.Sprintf from a URL query parameter.
    ["go", '''package main
import (
"database/sql"
"fmt"
"net/http"
)
func handler(w http.ResponseWriter, r *http.Request) {
name := r.URL.Query().Get("name")
query := fmt.Sprintf("SELECT * FROM users WHERE name = '%s'", name)
rows, _ := db.Query(query)
defer rows.Close()
}'''],
]
# ── UI ──────────────────────────────────────────────────────────
# Languages offered in the language dropdown.
LANG_CHOICES = [
    "python",
    "javascript",
    "typescript",
    "php",
    "c",
    "cpp",
    "java",
    "ruby",
    "go",
    "rust",
]

# Custom stylesheet passed to gr.Blocks: header, status bar, and hidden footer.
CSS = """
.header { text-align: center; padding: 24px 0 12px; }
.header h1 { color: #dc2626; font-size: 2.4em; margin: 0; letter-spacing: -0.02em; }
.header .sub { color: #94a3b8; margin: 4px 0 0; font-size: 1em; }
.header .brand { color: #475569; font-size: 0.8em; margin-top: 6px; }
.status-bar { background: #1e293b; border-radius: 8px; padding: 10px 16px; margin: 0 0 16px;
display: flex; justify-content: space-between; align-items: center; }
.status-bar span { color: #94a3b8; font-size: 0.85em; }
.status-bar .model { color: #22c55e; font-weight: 600; }
.status-bar .device { color: #f59e0b; }
footer { display: none !important; }
"""
# Device label for the status bar. The header markup referenced `device`
# without it ever being assigned, which raised NameError while the UI was built.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the component nesting below was reconstructed from a copy of
# this file whose indentation had been flattened — confirm the layout.
with gr.Blocks(
    title="Achilles Code Scanner",
    theme=gr.themes.Base(
        primary_hue="red",
        secondary_hue="slate",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("Inter"),
    ),
    css=CSS,
) as demo:
    # Title banner plus a status bar showing model name, device, and language count.
    # The count is derived from LANG_CHOICES so it cannot drift from the dropdown.
    gr.HTML(f"""
<div class="header">
<h1>ACHILLES</h1>
<p class="sub">AI-Powered Code Vulnerability Scanner</p>
<p class="brand">Built by HTS-ASPM</p>
</div>
<div class="status-bar">
<span>Model: <span class="model">{MODEL_ID.split('/')[-1]}</span></span>
<span>Device: <span class="device">{device.upper()}</span></span>
<span>Languages: {len(LANG_CHOICES)} supported</span>
</div>
""")

    with gr.Row(equal_height=True):
        # Left column: inputs and the scan trigger.
        with gr.Column(scale=1):
            lang = gr.Dropdown(choices=LANG_CHOICES, value="python", label="Language")
            code_input = gr.Code(label="Paste your code", language="python", lines=18)
            with gr.Row():
                max_tok = gr.Slider(256, 2048, value=1024, step=128, label="Max tokens")
                scan_btn = gr.Button("Scan for Vulnerabilities", variant="primary", size="lg")
        # Right column: the model's report, rendered as Markdown.
        with gr.Column(scale=1):
            output = gr.Markdown(label="Security Analysis")

    scan_btn.click(fn=scan_code, inputs=[lang, code_input, max_tok], outputs=output)

    # Clicking an example fills the language dropdown and the code editor.
    with gr.Accordion("Example Vulnerabilities", open=False):
        gr.Examples(
            examples=EXAMPLES,
            inputs=[lang, code_input],
            label="Click to load",
        )

    gr.HTML("""
<p style="text-align:center; color:#475569; font-size:0.78em; padding:12px;">
Achilles Code Scanner &mdash; Results are AI-generated. Always verify findings with manual review.
</p>
""")
# Script entry point: serve the UI on all interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )