"""CASSANDRA model definitions.
This module defines the LabelAttentionClassifier — a CTI-BERT encoder with
per-label attention queries — used in the CASSANDRA paper (anonymous
submission to ACM CCS 2026).
The architecture is custom (not derived from PreTrainedModel), so loading
weights requires this module. The convenience function `load_seed()` wraps
the standard pattern: load encoder, instantiate classifier, restore
state_dict from safetensors, attach tokenizer.
Usage:
    import torch
    from modeling import load_seed

    model, tokenizer, config = load_seed("seeds/seed-42")
    model.eval()
    inputs = tokenizer(["The malware uses Registry Run Keys for persistence."],
                       return_tensors="pt", truncation=True, max_length=512)
    logits = model(**inputs).logits
    probs = torch.sigmoid(logits)
    preds = [config["labels"][i] for i in (probs[0] >= 0.5).nonzero(as_tuple=True)[0]]
"""
from __future__ import annotations
import json
import os
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput
class LabelAttentionClassifier(nn.Module):
    """CTI-BERT + per-label attention queries.

    Each ATT&CK technique gets a learned query vector (of the encoder's
    hidden size) that attends over the encoder's last_hidden_state. The
    attended representation is classified by a shared 1-output linear head,
    yielding one logit per technique. This replaces the standard
    CLS -> Linear head, removing the shared-representation bottleneck for
    multi-label classification with many rare classes.

    Args:
        encoder: a transformer encoder exposing ``config.hidden_size`` and
            returning an object with ``last_hidden_state`` when called with
            ``input_ids`` / ``attention_mask``.
        num_labels: number of labels (one query vector and one logit each).
        dropout: dropout probability applied to the attended representations.
    """

    def __init__(self, encoder, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.encoder = encoder
        hidden = encoder.config.hidden_size
        self.num_labels = num_labels
        # One query per label, small random init (std 0.02).
        self.label_queries = nn.Parameter(torch.randn(num_labels, hidden) * 0.02)
        # Shared across labels: maps each attended [H] vector to a single logit.
        self.classifier = nn.Linear(hidden, 1)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _mask_scores(scores, attention_mask):
        """Push attention scores of padding tokens to the dtype's minimum.

        Args:
            scores: attention scores shaped [B, L, S].
            attention_mask: [B, S] tensor where 0 marks padding, or None.

        Returns:
            ``scores`` unchanged when ``attention_mask`` is None, otherwise a
            new tensor with every padded position set to the most negative
            finite value of ``scores.dtype``.
        """
        if attention_mask is None:
            return scores
        # A fixed -1e9 cannot be represented in float16, and masked_fill raises
        # on overflow under half precision / autocast. finfo(dtype).min is always
        # representable and still gives exactly zero weight after softmax.
        fill_value = torch.finfo(scores.dtype).min
        # [B, S] -> [B, 1, S]; masked_fill broadcasts it across the label axis.
        return scores.masked_fill(attention_mask.unsqueeze(1) == 0, fill_value)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Compute one logit per label for each input sequence.

        Extra keyword arguments (e.g. ``token_type_ids`` produced by the
        tokenizer) are accepted and ignored; they are not passed to the encoder.

        Returns:
            SequenceClassifierOutput whose ``logits`` are shaped [B, L].
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        h = outputs.last_hidden_state  # [B, S, H]
        # [1, L, H] @ [B, H, S] broadcasts to [B, L, S]: a score per label per token.
        scores = torch.matmul(self.label_queries.unsqueeze(0), h.transpose(1, 2))
        scores = self._mask_scores(scores, attention_mask)
        weights = F.softmax(scores, dim=-1)  # [B, L, S]
        reps = self.dropout(torch.matmul(weights, h))  # [B, L, H]
        logits = self.classifier(reps).squeeze(-1)  # [B, L]
        return SequenceClassifierOutput(logits=logits)
def load_seed(seed_dir: str, device: str = "cpu") -> Tuple[LabelAttentionClassifier, AutoTokenizer, dict]:
    """Load a single CASSANDRA seed directory.

    Args:
        seed_dir: path containing model.safetensors, config.json, tokenizer.* files.
        device: 'cpu' or 'cuda'.

    Returns:
        (model, tokenizer, config) — config is the parsed config.json dict.
        The model is moved to ``device`` but is not switched to eval mode;
        call ``model.eval()`` before inference.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(os.path.join(seed_dir, "config.json"), encoding="utf-8") as f:
        config = json.load(f)
    # The base encoder supplies the architecture; all of its weights are
    # replaced by the seed's state_dict below (load_state_dict is strict).
    encoder = AutoModel.from_pretrained(config["encoder_model_name"])
    model = LabelAttentionClassifier(
        encoder,
        num_labels=config["num_labels"],
        dropout=config.get("dropout", 0.1),
    )
    state_dict = load_file(os.path.join(seed_dir, "model.safetensors"), device=device)
    model.load_state_dict(state_dict)
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(seed_dir)
    return model, tokenizer, config
def load_ensemble(seed_dirs, device: str = "cpu"):
    """Load every seed directory in ``seed_dirs`` for ensemble inference.

    Each directory is loaded with ``load_seed`` onto ``device``, in order.
    The result is a list with one (model, tokenizer, config) tuple per seed.
    Tokenizer and config are the same across seeds of one configuration, but
    each tuple carries its own copy for convenience.
    """
    loaded = []
    for seed_dir in seed_dirs:
        loaded.append(load_seed(seed_dir, device=device))
    return loaded
def _seed_probs(model, tokenizer, sentences, max_length, batch_size):
    """Return one model's sigmoid probabilities for ``sentences`` as a CPU tensor.

    Rows follow ``sentences``; columns follow the model's logits. The model is
    put in eval mode and run under ``torch.no_grad()``, ``batch_size``
    sentences at a time.
    """
    model.eval()
    # Feed inputs to the device this particular model lives on, so seeds
    # placed on different devices still work.
    device = next(model.parameters()).device
    chunks = []
    with torch.no_grad():
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            enc = tokenizer(batch, return_tensors="pt", truncation=True,
                            max_length=max_length, padding=True).to(device)
            logits = model(input_ids=enc["input_ids"],
                           attention_mask=enc["attention_mask"]).logits
            chunks.append(torch.sigmoid(logits).cpu())
    return torch.cat(chunks, dim=0)


def predict_ensemble(seeds, sentences, threshold: float = 0.5, max_length: int = 512,
                     batch_size: int = 32):
    """Average sigmoid probabilities across seeds, threshold to predicted labels.

    The tokenizer and config of the first seed are used for every seed.

    Args:
        seeds: list returned by load_ensemble().
        sentences: list of strings.
        threshold: per-class probability cutoff. The paper's headline numbers
            use either tau=0.5 (uniform) or a dev-tuned tau (see model card).
        max_length: tokenizer max_length (paper used 512).
        batch_size: per-model inference batch size (must be >= 1).

    Returns:
        list of (sentence, [predicted labels]) tuples, one per sentence, where
        the predicted labels are the entries of ``config["labels"]`` whose
        averaged probability is >= ``threshold``. Empty ``sentences`` gives [].

    Raises:
        ValueError: if ``seeds`` is empty or ``batch_size`` < 1.
    """
    if not seeds:
        raise ValueError("seeds is empty")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # len() rather than truthiness so array-like sentence containers also work.
    if len(sentences) == 0:
        return []
    tokenizer = seeds[0][1]
    labels = seeds[0][2]["labels"]
    total = None
    for model, _, _ in seeds:
        probs = _seed_probs(model, tokenizer, sentences, max_length, batch_size)
        total = probs if total is None else total + probs
    avg_probs = total / len(seeds)
    # A single tolist() instead of an .item() per (sentence, label) pair.
    # Comparing the resulting Python floats keeps the original threshold semantics.
    return [
        (sentence, [label for label, p in zip(labels, row) if p >= threshold])
        for sentence, row in zip(sentences, avg_probs.tolist())
    ]
|