SecureBERT Vulnerability Retriever (CVSS Classifier & CWE Bi-Encoder)
This model analyzes raw vulnerability descriptions (e.g., CVEs, bug bounty reports) to predict 8 CVSS v3.1 metrics and retrieve CWE categories across 4 abstraction levels (Pillar, Class, Base, Variant).
The CWE component is trained using Metric Learning, specifically Online Hard-Negative Mining Triplet Loss and Hierarchical Loss. The model functions as a Bi-Encoder, mapping the vulnerability description against official MITRE CWE definitions in a shared 768-dimensional vector space.
Intended Use
- CVSS Prediction: Predicts nominal and ordinal metrics from unstructured text to determine the CVSS v3.1 Base Score components.
- CWE Mapping: Retrieves the Top-K relevant CWEs by calculating cosine similarity between the input text embedding and a pre-computed knowledge base of CWE definitions.
Model Architecture
Based on the cisco-ai/SecureBERT2.0 backbone, this model employs a Multi-Task Learning (MTL) architecture:
- CVSS Heads: 8 Multi-Layer Perceptron (MLP) heads (`LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout -> Linear -> Softmax`) mapped to the `[CLS]` token.
- CWE Heads: 4 MLP projection heads (`LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout -> Linear -> L2 Normalization`) mapped to the Mean-Pooled token embeddings. These heads project the text into a spherical vector space of 768 dimensions.
Repository Structure & Dependencies
This architecture requires the following specific files included in the repository:
- `config.json`: Contains the `cvss_map` defining the output tensor sizes and label decoders for the CVSS classification heads.
- `cwe_embeddings_new.pkl`: A serialized dictionary containing the pre-computed, L2-normalized 768-d vectors for all reference CWE definitions. This file is required for the CWE retrieval process.
Usage & Inference
Below is the standalone Python snippet to download the model, configuration, and the CWE embeddings knowledge base to perform predictions.
import json
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
# 1. Define the Custom Multi-Head Architecture
class SecureBERTMultiHead(nn.Module):
    """Multi-task model: a SecureBERT/ModernBERT backbone with 8 CVSS
    classification heads (read from the [CLS] token) and 4 CWE projection
    heads (read from the mean-pooled token embeddings).

    The CVSS head names and class counts come from ``cvss_map`` in the model
    config; the CWE heads project into an L2-normalized embedding space for
    cosine-similarity retrieval against pre-computed CWE definition vectors.
    """

    # Abstraction levels of the MITRE CWE hierarchy, one retrieval head each.
    CWE_LEVELS = ("pillar", "class", "base", "variant")

    def __init__(self, model_name_or_path):
        super().__init__()
        self.modernbert_config = AutoConfig.from_pretrained(model_name_or_path)
        # ModernBERT's compiled reference path can fail on some setups; disable it.
        if hasattr(self.modernbert_config, "reference_compile"):
            self.modernbert_config.reference_compile = False
        # from_config: architecture only — weights are loaded later by the caller.
        self.bert = AutoModel.from_config(self.modernbert_config)

        # Read the width from the config instead of hard-coding 768
        # (768 for ModernBERT-base, but this keeps the heads portable).
        hidden = getattr(self.modernbert_config, "hidden_size", 768)
        cvss_map = getattr(self.modernbert_config, "cvss_map", {})

        # One classification head per CVSS metric; Softmax emits probabilities.
        self.cvss_heads = nn.ModuleDict({
            metric: self._build_head(hidden, len(classes), softmax=True)
            for metric, classes in cvss_map.items()
        })
        # One projection head per CWE abstraction level (normalized in forward()).
        self.cwe_heads = nn.ModuleDict({
            level: self._build_head(hidden, hidden, softmax=False)
            for level in self.CWE_LEVELS
        })

    @staticmethod
    def _build_head(width, out_dim, *, softmax):
        """Build one MLP head; layer order matches the trained checkpoint's
        state-dict keys (Sequential indices must not change)."""
        layers = [
            nn.LayerNorm(width), nn.Dropout(0.1),
            nn.Linear(width, width), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(width, width), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(width, out_dim),
        ]
        if softmax:
            layers.append(nn.Softmax(dim=1))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask):
        """Return a dict: CVSS metric name -> class probabilities, and CWE
        level name -> L2-normalized embedding (one row per input)."""
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_emb = out[:, 0, :]  # [CLS] token embedding for classification
        # Masked mean pooling: average only over real (non-padding) tokens.
        mask = attention_mask.unsqueeze(-1).expand(out.size()).float()
        mean_emb = torch.sum(out * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
        res = {}
        for metric, head in self.cvss_heads.items():
            res[metric] = head(cls_emb)
        for level, head in self.cwe_heads.items():
            # Unit-normalize so dot products against the (also-normalized)
            # CWE knowledge base are cosine similarities.
            res[level] = F.normalize(head(mean_emb), p=2, dim=1)
        return res
# 2. Inference Wrapper
class VulnRetriever:
    """Downloads the checkpoint, config, and CWE embedding knowledge base
    from the Hugging Face Hub and serves CVSS + CWE predictions for raw
    vulnerability descriptions."""

    def __init__(self, repo_id):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        conf_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        emb_path = hf_hub_download(repo_id=repo_id, filename="cwe_embeddings_new.pkl")

        with open(conf_path, "r", encoding='utf-8') as f:
            self.config = json.load(f)
        self.cvss_map = self.config.get("cvss_map", {})
        base_model_name = self.config.get("base_model", "cisco-ai/SecureBERT2.0-biencoder")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        self.model = SecureBERTMultiHead(repo_id)
        # weights_only=True blocks arbitrary-code execution from a tampered
        # checkpoint; a plain state dict of tensors loads fine under it.
        state = torch.load(model_path, map_location=self.device, weights_only=True)
        # strict=False tolerates benign mismatches, but report what differed
        # instead of silently ignoring it.
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if missing or unexpected:
            print(f"[VulnRetriever] state_dict mismatch - missing: {missing}, unexpected: {unexpected}")
        self.model.to(self.device).eval()

        # SECURITY: pickle.load can execute arbitrary code — only load this
        # file from a repository you trust.
        with open(emb_path, 'rb') as f:
            self.embeddings_map = pickle.load(f)

        # Pre-stack the candidate vectors per CWE level into a single matrix
        # so retrieval is one matmul; re-normalize defensively.
        self.candidates = {}
        for level, data in self.embeddings_map.items():
            ids = list(data.keys())
            vecs = torch.stack([data[k]['vector'] for k in ids]).to(self.device)
            vecs = F.normalize(vecs, p=2, dim=1)
            self.candidates[level] = {'ids': ids, 'matrix': vecs}

    def predict(self, text, top_k=3):
        """Return ``{'cvss': {metric: {value, confidence}},
        'cwe': {level: [{id, name, score}, ...]}}`` for one description.

        ``score`` for CWEs is a softmax over the top-k cosine similarities,
        i.e. a relative ranking weight, not an absolute probability.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True,
                                max_length=512, padding=True).to(self.device)
        with torch.no_grad():
            out = self.model(inputs['input_ids'], inputs['attention_mask'])

        results = {'cvss': {}, 'cwe': {}}

        # CVSS: argmax over each head's class probabilities.
        for task, labels in self.cvss_map.items():
            score, idx = torch.max(out[task], dim=1)
            idx_val = idx.item()
            results['cvss'][task] = {
                # Guard against a config/checkpoint label-count mismatch.
                'value': labels[idx_val] if idx_val < len(labels) else "Unknown",
                'confidence': round(score.item(), 4)
            }

        # CWE: cosine similarity (both sides L2-normalized) against each level's
        # pre-computed definition matrix, then top-k retrieval.
        for level in ('pillar', 'class', 'base', 'variant'):
            if level not in self.candidates:
                continue
            query_vec = out[level]
            cand_matrix = self.candidates[level]['matrix']
            cand_ids = self.candidates[level]['ids']
            scores = torch.matmul(query_vec, cand_matrix.T).squeeze()
            if scores.dim() == 0:  # single candidate collapsed to a scalar
                scores = scores.unsqueeze(0)
            top_scores, top_indices = torch.topk(scores, k=min(top_k, scores.size(0)))
            probs = F.softmax(top_scores, dim=0)
            level_preds = []
            for score_val, idx in zip(probs, top_indices):
                cwe_id = cand_ids[idx.item()]
                cwe_name = self.embeddings_map[level][cwe_id]['name']
                level_preds.append({
                    'id': int(str(cwe_id).replace('CWE-', '')),
                    'name': cwe_name,
                    'score': round(score_val.item(), 4)
                })
            results['cwe'][level] = level_preds
        return results
# 3. Execution
# 3. Execution
if __name__ == "__main__":
    # Point this at the actual Hugging Face repository before running.
    REPO_ID = "YourUsername/Your-Triplet-Repo-Name"

    retriever = VulnRetriever(REPO_ID)
    sample_cve = (
        "A buffer overflow in the web server allows remote attackers "
        "to execute arbitrary code via a long URL."
    )
    # Pretty-print the combined CVSS + CWE prediction dictionary.
    print(json.dumps(retriever.predict(sample_cve), indent=2))
- Downloads last month
- 76
Model tree for bziemba/SecureBERT-triplets-new-arch
Base model
answerdotai/ModernBERT-base