File size: 6,064 Bytes
5fc9222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""CASSANDRA model definitions.

This module defines the LabelAttentionClassifier — a CTI-BERT encoder with
per-label attention queries — used in the CASSANDRA paper (anonymous
submission to ACM CCS 2026).

The architecture is custom (not derived from PreTrainedModel), so loading
weights requires this module. The convenience function `load_seed()` wraps
the standard pattern: load encoder, instantiate classifier, restore
state_dict from safetensors, attach tokenizer.

Usage:
    import torch
    from modeling import load_seed
    model, tokenizer, config = load_seed("seeds/seed-42")
    model.eval()
    inputs = tokenizer(["The malware uses Registry Run Keys for persistence."],
                       return_tensors="pt", truncation=True, max_length=512)
    logits = model(**inputs).logits
    probs = torch.sigmoid(logits)
    preds = [config["labels"][i] for i in (probs[0] >= 0.5).nonzero(as_tuple=True)[0]]
"""
from __future__ import annotations

import json
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput


class LabelAttentionClassifier(nn.Module):
    """CTI-BERT + per-label attention queries.

    Each ATT&CK technique gets a learned query vector of size
    ``encoder.config.hidden_size`` (768 for CTI-BERT) that attends over the
    encoder's last_hidden_state. The attended representation is classified by
    a shared 1-output linear head, yielding one logit per technique. This
    replaces the standard CLS -> Linear head, removing the shared-representation
    bottleneck for multi-label classification with many rare classes.

    Args:
        encoder: encoder module exposing ``config.hidden_size`` and returning an
            object with a ``last_hidden_state`` attribute when called with
            ``input_ids`` / ``attention_mask``.
        num_labels: number of output labels (one query vector per label).
        dropout: dropout probability applied to the per-label attended
            representations before the shared classifier.
    """

    def __init__(self, encoder, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.encoder = encoder
        hidden = encoder.config.hidden_size
        self.num_labels = num_labels
        # One query row per label, initialised from N(0, 0.02^2).
        self.label_queries = nn.Parameter(torch.randn(num_labels, hidden) * 0.02)
        # Shared across labels: maps each attended representation to one logit.
        self.classifier = nn.Linear(hidden, 1)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _mask_padding(scores, attention_mask):
        """Exclude padding positions from attention.

        Args:
            scores: ``[B, L, S]`` raw attention scores (labels x tokens).
            attention_mask: ``[B, S]`` mask with 0 at padding positions, or None.

        Returns:
            ``scores`` itself when ``attention_mask`` is None; otherwise a new
            tensor where every padded token column holds the most negative
            finite value of ``scores.dtype``.
        """
        if attention_mask is None:
            return scores
        # torch.finfo(dtype).min rather than a fixed -1e9: -1e9 is outside the
        # float16 range, so masked_fill would raise after model.half(). In fp32
        # both values underflow to exactly 0 after softmax, so results match.
        pad = (attention_mask == 0).unsqueeze(1)                       # [B, 1, S]
        return scores.masked_fill(pad, torch.finfo(scores.dtype).min)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Compute one logit per label.

        Args:
            input_ids: token ids forwarded to the encoder.
            attention_mask: forwarded to the encoder; positions equal to 0 are
                also excluded from every label's attention.
            **kwargs: accepted and ignored (for example ``token_type_ids`` from a
                tokenizer's output), so ``model(**inputs)`` works.

        Returns:
            SequenceClassifierOutput whose ``logits`` are ``[B, num_labels]``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        h = outputs.last_hidden_state                                  # [B, S, H]
        scores = torch.matmul(self.label_queries.unsqueeze(0),
                              h.transpose(1, 2))                        # [B, L, S]
        scores = self._mask_padding(scores, attention_mask)
        weights = F.softmax(scores, dim=-1)                            # [B, L, S]
        reps = self.dropout(torch.matmul(weights, h))                  # [B, L, H]
        logits = self.classifier(reps).squeeze(-1)                     # [B, L]
        return SequenceClassifierOutput(logits=logits)


def load_seed(seed_dir: str, device: str = "cpu") -> Tuple[LabelAttentionClassifier, AutoTokenizer, dict]:
    """Load a single CASSANDRA seed directory.

    Reads ``config.json`` from *seed_dir*, builds the encoder named by
    ``config["encoder_model_name"]`` with ``AutoModel.from_pretrained``, wraps
    it in a LabelAttentionClassifier, restores ``model.safetensors`` into it
    (strict key matching), and loads the tokenizer saved in *seed_dir*.
    This function does not call ``model.eval()``; do so before inference.

    Args:
        seed_dir: path containing model.safetensors, config.json, tokenizer.* files.
        device: 'cpu' or 'cuda'.

    Returns:
        (model, tokenizer, config) — config is the parsed config.json dict.

    Raises:
        KeyError: if config.json lacks ``encoder_model_name`` or ``num_labels``.
        RuntimeError: raised by ``load_state_dict`` when the checkpoint's keys
            or shapes do not match the constructed model.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(os.path.join(seed_dir, "config.json"), encoding="utf-8") as f:
        config = json.load(f)

    encoder = AutoModel.from_pretrained(config["encoder_model_name"])
    model = LabelAttentionClassifier(
        encoder,
        num_labels=config["num_labels"],
        dropout=config.get("dropout", 0.1),
    )

    # Move the model first, then load the checkpoint onto the same device:
    # load_state_dict copies in place, so the weights are not bounced through
    # CPU parameters on their way to `device`.
    model.to(device)
    state_dict = load_file(os.path.join(seed_dir, "model.safetensors"), device=device)
    model.load_state_dict(state_dict)

    tokenizer = AutoTokenizer.from_pretrained(seed_dir)
    return model, tokenizer, config


def load_ensemble(seed_dirs, device: str = "cpu"):
    """Load every seed directory in *seed_dirs* for ensemble inference.

    Each directory is passed to ``load_seed`` with the given *device*, in
    iteration order. The result is a list holding one
    ``(model, tokenizer, config)`` tuple per directory. Tokenizer and config
    are identical across the seeds of one configuration; they are still kept
    per seed for convenience.
    """
    members = []
    for seed_dir in seed_dirs:
        members.append(load_seed(seed_dir, device=device))
    return members


def predict_ensemble(seeds, sentences, threshold: float = 0.5, max_length: int = 512,
                     batch_size: int = 32):
    """Average sigmoid probabilities across seeds, threshold to predicted labels.

    Every model is switched to eval mode and run under ``torch.no_grad()``.
    Sentences are tokenized with the first seed's tokenizer; tokenizers are
    identical across the seeds of one configuration. Each batch is sent to the
    device of the model evaluating it. Per-seed probabilities are summed on
    CPU and divided by the number of seeds.

    Args:
        seeds: list returned by load_ensemble().
        sentences: sequence of strings (must support ``len`` and slicing).
        threshold: per-class probability cutoff; a label is predicted when its
            averaged probability is ``>= threshold``. The paper's headline
            numbers use either tau=0.5 (uniform) or a dev-tuned tau (see model card).
        max_length: tokenizer max_length (paper used 512).
        batch_size: per-model inference batch size.

    Returns:
        list of ``(sentence, [predicted_label_names])`` tuples, one per input
        sentence and in input order. Label names come from the first seed's
        ``config["labels"]`` and are listed in label-index order. An empty
        ``sentences`` yields ``[]``.

    Raises:
        ValueError: if ``seeds`` is empty, or if the number of model outputs
            differs from ``len(config["labels"])``.
    """
    if not seeds:
        raise ValueError("seeds is empty")
    if len(sentences) == 0:
        # Nothing to score; torch.cat below would fail on an empty batch list.
        return []

    tokenizer = seeds[0][1]
    labels = seeds[0][2]["labels"]

    summed = None
    for model, _, _ in seeds:
        model.eval()
        # Use each model's own device, so seeds may live on different devices.
        device = next(model.parameters()).device
        batch_probs = []
        with torch.no_grad():
            for start in range(0, len(sentences), batch_size):
                batch = sentences[start:start + batch_size]
                enc = tokenizer(batch, return_tensors="pt", truncation=True,
                                max_length=max_length, padding=True).to(device)
                logits = model(input_ids=enc["input_ids"],
                               attention_mask=enc["attention_mask"]).logits
                batch_probs.append(torch.sigmoid(logits).cpu())
        seed_probs = torch.cat(batch_probs, dim=0)
        summed = seed_probs if summed is None else summed + seed_probs

    avg_probs = summed / len(seeds)
    if avg_probs.shape[-1] != len(labels):
        raise ValueError(
            f"model produced {avg_probs.shape[-1]} outputs but config lists "
            f"{len(labels)} labels"
        )

    # Compare in float64 so a Python-float threshold is not rounded to float32
    # before the comparison.
    hits = avg_probs.double() >= threshold                         # [N, num_labels]
    return [
        (sentence, [labels[j] for j in row.nonzero(as_tuple=True)[0].tolist()])
        for sentence, row in zip(sentences, hits)
    ]