import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_roberta_zinc_compression_encoder import RZCompressionConfig


def pairwise_cosine(x: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between every pair of rows of ``x`` (n x n matrix)."""
    x = F.normalize(x, p=2, dim=-1)
    return x @ x.t()


def drop_diag(M: torch.Tensor) -> torch.Tensor:
    """Drop the diagonal of a square matrix, returning shape (n, n - 1)."""
    n = M.size(0)
    return M.masked_select(~torch.eye(n, dtype=torch.bool, device=M.device)).view(n, n - 1)


def rowwise_pearson(ref: torch.Tensor, comp: torch.Tensor, rm_diag: bool = True) -> torch.Tensor:
    """One minus the mean row-wise Pearson correlation between ``ref`` and ``comp``."""
    if rm_diag:
        ref = drop_diag(ref)
        comp = drop_diag(comp)
    ref_z = F.normalize(ref - ref.mean(dim=1, keepdim=True), p=2, dim=1)
    cmp_z = F.normalize(comp - comp.mean(dim=1, keepdim=True), p=2, dim=1)
    return 1 - (ref_z * cmp_z).sum(dim=1).mean()


def compute_losses(
    embedding: torch.Tensor,
    compressed: Dict[int, torch.Tensor],
    recon_stack: Optional[torch.Tensor],
    cfg: RZCompressionConfig,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Return ``(total_loss, terms)``, where ``terms`` maps metric names to detached tensors."""
    device = embedding.device
    loss_total = torch.zeros((), device=device)
    terms: Dict[str, torch.Tensor] = {}

    # Reference cosine similarities and neighbour ranks from the full-size embedding.
    with torch.no_grad():
        base_sims = pairwise_cosine(embedding)
        ranks = base_sims.argsort(-1, descending=True)

    for size, z in compressed.items():
        tag = f"cmp{size}"
        comp_sims = pairwise_cosine(z)

        # MSE between the off-diagonal entries of the similarity matrices.
        if cfg.mse_loss_weight:
            mse = F.mse_loss(drop_diag(base_sims), drop_diag(comp_sims))
            loss_total += cfg.mse_loss_weight * mse
            terms[f"{tag}_mse"] = mse.detach()

        # MSE restricted to each row's top-k nearest neighbours.
        if cfg.topk_mse_loss_weight and cfg.topk_values:
            tk_vals = []
            for k in cfg.topk_values:
                idx = ranks[:, 1 : k + 1]  # skip rank 0 (self-similarity)
                ref_k = torch.gather(base_sims, 1, idx)
                cmp_k = torch.gather(comp_sims, 1, idx)
                tk_mse = F.mse_loss(ref_k, cmp_k)
                tk_vals.append(tk_mse)
                terms[f"{tag}_top{k}"] = tk_mse.detach()
            tk_agg = torch.stack(tk_vals).mean()
            loss_total += cfg.topk_mse_loss_weight * tk_agg
            terms[f"{tag}_topk_mean"] = tk_agg.detach()

        # Row-wise Pearson distance over the full similarity matrices.
        if cfg.pearson_loss_weight:
            pr = rowwise_pearson(base_sims, comp_sims)
            loss_total += cfg.pearson_loss_weight * pr
            terms[f"{tag}_pearson"] = pr.detach()

        # Row-wise Pearson distance restricted to the top-k neighbours.
        if cfg.pearson_loss_weight and cfg.topk_values:
            pr_vals = []
            for k in cfg.topk_values:
                idx = ranks[:, 1 : k + 1]
                ref_k = torch.gather(base_sims, 1, idx)
                cmp_k = torch.gather(comp_sims, 1, idx)
                pr = rowwise_pearson(ref_k, cmp_k, rm_diag=False)
                pr_vals.append(pr)
                terms[f"{tag}_pearson_top{k}"] = pr.detach()
            pr_agg = torch.stack(pr_vals).sum()
            loss_total += cfg.pearson_loss_weight * pr_agg

    # Cosine reconstruction loss between decoder outputs and the original embedding.
    if recon_stack is not None:
        if cfg.decoder_cosine_weight:
            cos_loss = 1 - F.cosine_similarity(
                recon_stack,
                embedding.unsqueeze(1).expand_as(recon_stack),
                dim=-1,
            ).mean()
            loss_total += cfg.decoder_cosine_weight * cos_loss
            terms["dec_cosine"] = cos_loss.detach()

    return loss_total, terms


class FeedForward(nn.Module):
    """SwiGLU-style gated feed-forward block."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        # Split the projection into a gate and a value branch, gate with SiLU, then project.
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=-1)
        return self.fc2(F.silu(x1) * x2)


class FeedForwardLayer(nn.Module):
    """Feed-forward block with skip connection, dropout, and optional LayerNorm."""

    def __init__(
        self, d_in: int, d_out: int, dropout: float = 0.1, layer_norm_eps: Optional[float] = 1e-12
    ):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.norm = (
            nn.LayerNorm(d_out, eps=layer_norm_eps)
            if layer_norm_eps is not None
            else nn.Identity()
        )

    def forward(self, x):
        y = self.ff(self.dropout(x)) + self.skip(x)
        return self.norm(y)


class CompressionModel(nn.Module):
    """Encoder → (optional) decoder for a single compression size."""

    def __init__(
        self,
        d_in: int,
        d_comp: int,
        encoder_layers: int,
        decoder_layers: int,
        dropout: float,
        layer_norm_eps: Optional[float],
    ):
        super().__init__()

        # Encoder: keep width d_in until the last layer, which projects down to d_comp.
        enc_layers: List[nn.Module] = []
        for i in range(encoder_layers):
            last = i == encoder_layers - 1
            enc_layers.append(
                FeedForwardLayer(
                    d_in,
                    d_comp if last else d_in,
                    dropout if not last else 0.0,
                    None if last else layer_norm_eps,
                )
            )
        self.encoder = nn.Sequential(*enc_layers)

        # Decoder (optional): the first layer maps d_comp back to d_in, the rest keep d_in.
        dec_layers: List[nn.Module] = []
        for i in range(decoder_layers):
            last = i == decoder_layers - 1
            d_prev = d_comp if i == 0 else d_in
            dec_layers.append(
                FeedForwardLayer(
                    d_prev,
                    d_in,
                    dropout if not last else 0.0,
                    None if last else layer_norm_eps,
                )
            )
        self.decoder = nn.Sequential(*dec_layers) if dec_layers else None

    def forward(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        z = self.encoder(x)
        x_recon = self.decoder(z) if self.decoder is not None else None
        return z, x_recon


@dataclass
class RZCompressionOutput(ModelOutput):
    loss: torch.FloatTensor
    loss_terms: Optional[Dict[str, torch.Tensor]] = None
    compressed: Optional[Dict[int, torch.FloatTensor]] = None
    reconstructed: Optional[torch.FloatTensor] = None


class RZCompressionModel(PreTrainedModel):
    config_class = RZCompressionConfig

    def __init__(self, config: RZCompressionConfig):
        super().__init__(config)

        # One CompressionModel per requested compression size, keyed by the size as a string.
        self.compressors = nn.ModuleDict(
            {
                str(size): CompressionModel(
                    d_in=config.input_size,
                    d_comp=size,
                    encoder_layers=config.encoder_layers,
                    decoder_layers=config.decoder_layers,
                    dropout=config.dropout,
                    layer_norm_eps=config.layer_norm_eps,
                )
                for size in config.compression_sizes
            }
        )

        self.post_init()

    def get_encoders(self, unpack_single=False):
        """Return a ModuleDict of the encoder stacks, keyed by compression size."""
        encoders = {}
        for k, v in self.compressors.items():
            v = v.encoder
            if len(v) == 1 and unpack_single:
                # Unwrap a single-layer nn.Sequential to the layer itself.
                v = v[0]
            encoders[k] = v
        encoders = nn.ModuleDict(encoders)
        return encoders

    def save_encoders(self, path, unpack_single=False):
        """Save only the encoder weights (no decoders) to ``path``."""
        encoders = self.get_encoders(unpack_single)
        torch.save(encoders.state_dict(), path)

    def compress(self, inputs: torch.Tensor, compression_sizes: List[int]):
        """Encode ``inputs`` with the encoders for the requested ``compression_sizes``."""
        compressed = {d: self.compressors[str(d)].encoder(inputs) for d in compression_sizes}
        return compressed

    def forward(self, embedding, return_dict=True, compute_loss=True):
        # Run every compressor; collect compressed embeddings and any decoder reconstructions.
        compressed, recons = {}, []
        for size, module in self.compressors.items():
            z, rec = module(embedding)
            compressed[int(size)] = z
            if rec is not None:
                recons.append(rec)
        recon_stack = torch.stack(recons, dim=1) if recons else None

        if compute_loss:
            loss_total, terms = compute_losses(embedding, compressed, recon_stack, self.config)
        else:
            loss_total, terms = torch.zeros((), device=embedding.device), {}

        if not return_dict:
            return compressed, recon_stack, loss_total, terms

        return RZCompressionOutput(
            loss=loss_total,
            loss_terms=terms,
            compressed=compressed,
            reconstructed=recon_stack,
        )
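

# Usage sketch. The configuration values and tensor shapes below are illustrative
# assumptions, not defaults defined in this file; RZCompressionConfig is assumed to
# accept these fields as keyword arguments, since they are read from ``config`` above.
#
#     config = RZCompressionConfig(
#         input_size=768,
#         compression_sizes=[64, 256],
#         encoder_layers=2,
#         decoder_layers=2,
#         dropout=0.1,
#         layer_norm_eps=1e-12,
#         mse_loss_weight=1.0,
#         topk_mse_loss_weight=1.0,
#         pearson_loss_weight=1.0,
#         decoder_cosine_weight=1.0,
#         topk_values=[5, 25],
#     )
#     model = RZCompressionModel(config)
#     embeddings = torch.randn(32, config.input_size)   # batch of full-size embeddings
#     out = model(embeddings)                           # RZCompressionOutput
#     z = out.compressed[64]                            # (32, 64) compressed embeddings
#     enc_only = model.compress(embeddings, [64])       # encoder-only inference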