SecureBERT Vulnerability Retriever (CVSS Classifier & CWE Bi-Encoder)

This model analyzes raw vulnerability descriptions (e.g., CVEs, bug bounty reports) to predict 8 CVSS v3.1 metrics and retrieve CWE categories across 4 abstraction levels (Pillar, Class, Base, Variant).

The CWE component is trained with metric learning, specifically a triplet loss with online hard-negative mining, combined with a hierarchical loss. The model functions as a bi-encoder, mapping vulnerability descriptions and official MITRE CWE definitions into a shared 768-dimensional vector space.
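As a rough illustration of the triplet objective mentioned above (the actual training hyperparameters, such as the margin, are not published here), the loss pushes a vulnerability description (anchor) closer to its true CWE definition (positive) than to a hard non-matching definition (negative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Random stand-ins for encoder outputs; all vectors L2-normalized, as in the model.
anchor   = F.normalize(torch.randn(4, 768), p=2, dim=1)  # encoded CVE texts
positive = F.normalize(torch.randn(4, 768), p=2, dim=1)  # matching CWE definitions
negative = F.normalize(torch.randn(4, 768), p=2, dim=1)  # hard non-matching definitions

margin = 0.3  # assumed value, for illustration only
d_pos = 1 - (anchor * positive).sum(dim=1)  # cosine distance to positive
d_neg = 1 - (anchor * negative).sum(dim=1)  # cosine distance to negative
loss = F.relu(d_pos - d_neg + margin).mean()
print(round(loss.item(), 4))
```

With online hard-negative mining, the negative for each anchor would be chosen as the most similar non-matching CWE definition in the batch rather than drawn at random.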

Intended Use

  • CVSS Prediction: Predicts the eight categorical metrics that make up the CVSS v3.1 Base Score from unstructured text.
  • CWE Mapping: Retrieves the Top-K relevant CWEs by calculating cosine similarity between the input text embedding and a pre-computed knowledge base of CWE definitions.
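The retrieval step reduces to a normalized matrix product. A minimal sketch of the principle, with random vectors standing in for the real encoder output and knowledge base:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
query = F.normalize(torch.randn(1, 768), p=2, dim=1)    # encoded vulnerability description
kb    = F.normalize(torch.randn(100, 768), p=2, dim=1)  # pre-computed CWE definition vectors

# With unit-norm vectors, the dot product equals cosine similarity.
scores = (query @ kb.T).squeeze(0)   # shape: (100,)
top_scores, top_idx = torch.topk(scores, k=3)
print(top_idx.tolist())
```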

Model Architecture

Based on the cisco-ai/SecureBERT2.0 backbone, this model employs a Multi-Task Learning (MTL) architecture:

  • CVSS Heads: 8 Multi-Layer Perceptron (MLP) heads (LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout -> Linear -> Softmax) applied to the [CLS] token embedding.
  • CWE Heads: 4 MLP projection heads (LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout -> Linear, followed by L2 normalization) applied to the mean-pooled token embeddings. These heads project the text onto the unit sphere of a 768-dimensional vector space.

Repository Structure & Dependencies

This architecture requires the following files, included in the repository:

  • config.json: Contains the cvss_map defining the output tensor sizes and label decoders for the CVSS classification heads.
  • cwe_embeddings_new.pkl: A serialized dictionary containing the pre-computed, L2-normalized 768-d vectors for all reference CWE definitions. This file is required for the CWE retrieval process.
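For reference, this is the layout cwe_embeddings_new.pkl is expected to follow, inferred from how the inference code below reads it. The entries here are illustrative placeholders (zero vectors), not the shipped embeddings:

```python
import pickle
import torch

# Assumed schema: {level: {cwe_id: {'vector': 768-d tensor, 'name': str}}}
kb = {
    "base": {
        "CWE-120": {"vector": torch.zeros(768),
                    "name": "Buffer Copy without Checking Size of Input"},
        "CWE-79":  {"vector": torch.zeros(768),
                    "name": "Improper Neutralization of Input During Web Page Generation"},
    }
}
restored = pickle.loads(pickle.dumps(kb))  # in-memory round trip
print(sorted(restored["base"].keys()))
```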

Usage & Inference

Below is a standalone Python snippet that downloads the model weights, configuration, and CWE embeddings knowledge base, then performs a prediction.

import json
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# 1. Define the Custom Multi-Head Architecture
class SecureBERTMultiHead(nn.Module):
    def __init__(self, model_name_or_path):
        super().__init__()
        self.modernbert_config = AutoConfig.from_pretrained(model_name_or_path)
        if hasattr(self.modernbert_config, "reference_compile"):
            self.modernbert_config.reference_compile = False

        self.bert = AutoModel.from_config(self.modernbert_config)
        cvss_map = getattr(self.modernbert_config, "cvss_map", {})

        self.cvss_heads = nn.ModuleDict({
            k: nn.Sequential(
                nn.LayerNorm(768), nn.Dropout(0.1), 
                nn.Linear(768, 768), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(768, 768), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(768, len(classes)), nn.Softmax(dim=1)
            ) for k, classes in cvss_map.items()
        })
                
        self.cwe_heads = nn.ModuleDict({
            k: nn.Sequential(
                nn.LayerNorm(768), nn.Dropout(0.1),
                nn.Linear(768, 768), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(768, 768), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(768, 768)
            ) for k in ['pillar', 'class', 'base', 'variant']
        })

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_emb = out[:, 0, :]
        
        mask = attention_mask.unsqueeze(-1).expand(out.size()).float()
        mean_emb = torch.sum(out * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

        res = {}
        for k, head in self.cvss_heads.items():
            res[k] = head(cls_emb)
            
        for k, head in self.cwe_heads.items():
            res[k] = F.normalize(head(mean_emb), p=2, dim=1)
            
        return res

# 2. Inference Wrapper
class VulnRetriever:
    def __init__(self, repo_id):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        conf_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        emb_path = hf_hub_download(repo_id=repo_id, filename="cwe_embeddings_new.pkl") 
        
        with open(conf_path, "r", encoding='utf-8') as f:
            self.config = json.load(f)
        
        self.cvss_map = self.config.get("cvss_map", {})
        base_model_name = self.config.get("base_model", "cisco-ai/SecureBERT2.0-biencoder")
        
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.model = SecureBERTMultiHead(repo_id)
        
        self.model.load_state_dict(torch.load(model_path, map_location=self.device), strict=False)
        self.model.to(self.device).eval()
        
        with open(emb_path, 'rb') as f:
            self.embeddings_map = pickle.load(f)
            
        self.candidates = {}
        for level, data in self.embeddings_map.items():
            ids = list(data.keys())
            vecs = torch.stack([data[k]['vector'] for k in ids]).to(self.device)
            vecs = F.normalize(vecs, p=2, dim=1)
            self.candidates[level] = {'ids': ids, 'matrix': vecs}

    def predict(self, text, top_k=3):
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(self.device)
        
        with torch.no_grad():
            out = self.model(inputs['input_ids'], inputs['attention_mask'])
            
        results = {'cvss': {}, 'cwe': {}}
        
        for task, labels in self.cvss_map.items():
            score, idx = torch.max(out[task], dim=1)
            idx_val = idx.item()
            results['cvss'][task] = {
                'value': labels[idx_val] if idx_val < len(labels) else "Unknown", 
                'confidence': round(score.item(), 4)
            }
            
        for level in ['pillar', 'class', 'base', 'variant']:
            if level not in self.candidates: continue
            
            query_vec = out[level]
            cand_matrix = self.candidates[level]['matrix']
            cand_ids = self.candidates[level]['ids']
            
            scores = torch.matmul(query_vec, cand_matrix.T).squeeze()
            if scores.dim() == 0: scores = scores.unsqueeze(0)
                
            top_scores, top_indices = torch.topk(scores, k=min(top_k, scores.size(0)))
            probs = F.softmax(top_scores, dim=0)  # normalize the Top-K similarities into relative scores
            
            level_preds = []
            for score_val, idx in zip(probs, top_indices):
                cwe_id = cand_ids[idx.item()]
                cwe_name = self.embeddings_map[level][cwe_id]['name']
                level_preds.append({
                    'id': int(str(cwe_id).replace('CWE-', '')),
                    'name': cwe_name,
                    'score': round(score_val.item(), 4)
                })
            
            results['cwe'][level] = level_preds
            
        return results

# 3. Execution
if __name__ == "__main__":
    REPO_ID = "bziemba/SecureBERT-triplets-new-arch"
    
    retriever = VulnRetriever(REPO_ID)
    
    sample_cve = "A buffer overflow in the web server allows remote attackers to execute arbitrary code via a long URL."
    results = retriever.predict(sample_cve)
    
    print(json.dumps(results, indent=2))