| import math |
| import numpy as np |
| import torch |
| from torch import nn |
| from torch.distributions import Independent, Normal, MultivariateNormal |
| import torch.nn.functional as F |
|
|
| from transformers import AutoModel, AutoModelForCausalLM |
| from tqdm import tqdm |
| from tqdm.notebook import tqdm as tqdm_notebook |
|
|
|
|
class Res(nn.Module):
    """Constant-width block: a linear map followed by two residual
    ReLU bottleneck updates."""

    def __init__(self, H):
        super().__init__()
        # Submodule creation order kept stable so seeded init is reproducible.
        self.u1 = nn.Linear(H, H)
        self.u2 = nn.Linear(H, H)
        self.v1 = nn.Linear(H, H)
        self.v2 = nn.Linear(H, H)
        self.w = nn.Linear(H, H)

    def forward(self, x):
        h = self.w(x)
        h = h + torch.relu(self.v1(torch.relu(self.u1(h))))
        h = h + torch.relu(self.v2(torch.relu(self.u2(h))))
        return h
|
|
|
|
class MLP(nn.Module):
    """Three-layer perceptron H -> H -> H -> out; `out` falls back to H
    when falsy (matches `out or H`)."""

    def __init__(self, H, out=None):
        super().__init__()
        out = out or H
        stack = [
            nn.Linear(H, H),
            nn.ReLU(),
            nn.Linear(H, H),
            nn.ReLU(),
            nn.Linear(H, out),
        ]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        return self.mlp(x)
|
|
|
|
class Encoder(nn.Module):
    """Wraps a pretrained transformer encoder and returns the embedding of
    the first token (position 0) of the last hidden layer."""

    def __init__(self, tokenizer, model_name_or_path="roberta-base", **kwargs):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        # Account for any tokens added to the tokenizer after pretraining.
        self.encoder.resize_token_embeddings(len(tokenizer))
        self.dim = self.encoder.config.hidden_size

    @property
    def device(self):
        return self.encoder.device

    def forward(self, **inputs):
        model_inputs = {}
        for key in ("input_ids", "attention_mask"):
            model_inputs[key] = inputs[key].to(self.device)
        token_type_ids = inputs.get("token_type_ids", None)
        if token_type_ids is not None:
            model_inputs["token_type_ids"] = token_type_ids.to(self.device)
        out = self.encoder(**model_inputs)
        # First-position (CLS-style) embedding.
        return out.last_hidden_state[:, 0]
|
|
|
|
class PrefixDecoder(nn.Module):
    """Causal LM conditioned on a latent z via prefix (past key/value) tuning.

    A latent z of shape (B, D) is mapped to `prefix_length` key/value pairs
    per layer and head of a GPT-2 style decoder, then passed to the decoder
    as `past_key_values`, so all token scoring is conditioned on z.
    """

    def __init__(
        self,
        tokenizer,
        model_name_or_path="gpt2",
        prefix_length=1,
        ffn="res",
        **kwargs,
    ):
        super().__init__()
        self.decoder = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        self.hidden_dim = D = self.decoder.config.n_embd
        self.num_layers = L = self.decoder.config.n_layer
        self.num_heads = H = self.decoder.config.n_head
        self.prefix_length = K = prefix_length
        # NOTE(review): lin1 is never used by any method of this class; kept
        # only for state-dict compatibility with existing checkpoints.
        self.lin1 = nn.Linear(D, D * 2)
        # One key and one value vector per layer and prefix position.
        self.z_size = D * L * K * 2
        if ffn == "res":
            self.mlp = nn.Sequential(Res(D), nn.Linear(D, self.z_size))
        else:
            self.mlp = MLP(D, self.z_size)

    def get_prefix(self, z):
        """Map z (B, D) to a per-layer tuple of (key, value) tensors, each of
        shape (B, num_heads, prefix_length, hidden_dim // num_heads)."""
        B = z.shape[0]
        D, L, H, K = (
            self.hidden_dim,
            self.num_layers,
            self.num_heads,
            self.prefix_length,
        )
        z_up = self.mlp(z).reshape(B, H, K, D // H, L, 2)
        # Split the trailing axis into key vs value halves.
        keys, vals = (t.squeeze(-1) for t in z_up.chunk(2, dim=-1))
        layers = tuple(
            [
                (k.squeeze(-1), v.squeeze(-1))
                for k, v in zip(keys.chunk(L, -1), vals.chunk(L, -1))
            ]
        )
        return layers

    def forward(self, z, **inputs):
        """Score `input_ids` under the decoder conditioned on z.

        Previously this method duplicated the prefix construction inline;
        it now reuses get_prefix (identical computation)."""
        B = z.shape[0]
        K = self.prefix_length
        layers = self.get_prefix(z)
        input_ids = inputs["input_ids"].to(z.device)
        attention_mask = inputs["attention_mask"].to(z.device)
        # Prepend an always-on mask over the K prefix positions.
        attention_mask = torch.cat(
            [torch.ones(B, K, dtype=bool, device=z.device), attention_mask],
            1,
        )
        out = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=layers,
        )
        return out
|
|
|
|
def get_inputs(
    inputs, prefix, keys=("input_ids", "attention_mask", "token_type_ids")
):
    """Select `prefix`-namespaced entries of `inputs`, dropping the prefix.

    Missing keys map to None. The default `keys` is now a tuple — the
    previous mutable list default is a classic Python pitfall (callers may
    still pass a list)."""
    return {k: inputs.get(f"{prefix}{k}", None) for k in keys}
|
|
|
|
class VAE(nn.Module):
    """Text VAE: encoder posterior q(z|x) with a prefix decoder p(x|z).

    `beta` is stored for KL weighting by the training loop (not applied
    here); `do_sample` toggles reparameterized sampling vs posterior mean.
    """

    def __init__(self, encoder, decoder, beta=1.0, do_sample=True, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        D = decoder.hidden_dim
        # Projects the encoder embedding to interleaved (mu, logvar) pairs.
        self.lin = nn.Linear(D, D * 2)
        self.do_sample = do_sample

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, sample=True, **inputs):
        """Return (z, kl): one latent per example plus its KL to N(0, I)."""
        enc = self.encoder(**get_inputs(inputs, "enc_"))
        B, D = enc.shape
        mu, logvar = (
            t.squeeze(-1) for t in self.lin(enc).view(B, D, 2).chunk(2, -1)
        )
        # NOTE(review): Normal takes a *scale* (std). If `logvar` is truly
        # log-variance the scale should be (0.5 * logvar).exp(); as written
        # the head behaves as log-std. Confirm intended parameterization —
        # changing it would invalidate existing checkpoints.
        qz = Normal(mu, logvar.exp())
        pz = Normal(torch.zeros_like(mu[0]), torch.ones_like(mu[0]))
        # Sum per-dimension KL to one value per example.
        kl = torch.distributions.kl_divergence(qz, pz).sum(-1)
        if sample:
            z = qz.rsample()
        else:
            z = mu
        return z, kl

    def forward(self, **inputs):
        z, kl = self.get_z(sample=self.do_sample, **inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        out["kl"] = kl
        return out
|
|
|
|
class AAE(nn.Module):
    """Adversarial autoencoder: deterministic encoder plus a discriminator
    `D` that pushes encodings toward a standard-normal prior."""

    def __init__(self, encoder, decoder, _lambda=1.0, word_drop=None, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._lambda = _lambda
        dim = decoder.hidden_dim
        # Discriminator: "real" = prior samples, "fake" = encoder outputs.
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )
        self.word_drop = word_drop

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        """Encode; optionally drop encoder-input tokens at rate `word_drop`.

        NOTE(review): mutates inputs["enc_attention_mask"] in place on the
        caller's dict, and assumes the mask is a bool tensor (uses `&`)."""
        if self.word_drop is not None:
            m = inputs["enc_attention_mask"]
            b = torch.rand_like(m.float()) > self.word_drop
            inputs["enc_attention_mask"] = m & b
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_adv(self, z):
        """Return (loss_d, adv): discriminator loss (z detached so the
        encoder gets no gradient from it) and the encoder's adversarial
        loss (fool D into scoring z as a prior sample)."""
        zn = torch.randn_like(z)  # samples from the N(0, I) prior
        zeros = torch.zeros(len(z), 1, device=z.device)
        ones = torch.ones(len(z), 1, device=z.device)
        loss_d = F.binary_cross_entropy(
            self.D(z.detach()), zeros, reduction="none"
        ) + F.binary_cross_entropy(self.D(zn), ones, reduction="none")
        adv = F.binary_cross_entropy(self.D(z), ones, reduction="none")
        return loss_d, adv

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        # Teacher-forced likelihood: logits at position t score token t+1.
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        # Zero out padding positions (mask assumed bool).
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:], 0
        )
        out["l_rec"] = -log_probs.sum(-1)
        out["loss_d"], out["adv"] = self.loss_adv(z)
        return out
|
|
|
|
class AE(nn.Module):
    """Plain deterministic autoencoder baseline (loss_c is always zero)."""

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        dim = decoder.hidden_dim
        # NOTE(review): D is created but never referenced by any method of
        # this class as written; kept so existing checkpoints keep loading.
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        """Encode the `enc_`-prefixed inputs; None kept for API parity."""
        z = self.encoder(**get_inputs(inputs, "enc_"))
        return z, None

    def step(self, **inputs):
        """Encode, decode, and attach the per-example reconstruction loss."""
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        token_lp = out["logits"].log_softmax(-1)
        # Shifted teacher forcing: position t predicts token t+1.
        targets = inputs["dec_input_ids"][:, 1:].unsqueeze(-1)
        token_lp = torch.gather(token_lp[:, :-1], -1, targets).squeeze(-1)
        pad = ~inputs["dec_attention_mask"][:, 1:]
        token_lp = token_lp.masked_fill(pad, 0)
        out["loss_r"] = -token_lp.sum(-1)
        return z, out

    def forward(self, **inputs):
        z, out = self.step(**inputs)
        # No classification/regularization term for the plain AE.
        out["loss_c"] = torch.zeros_like(out["loss_r"])
        return out
|
|
|
|
class CDAE(nn.Module):
    """Contrastive denoising autoencoder: reconstructs from a clean pass and
    a word-dropped pass, plus a distance-based contrastive loss tying the
    two encodings together."""

    def __init__(
        self, encoder, decoder, _lambda=1.0, word_drop=None, tau=1.0, **kwargs
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._lambda = _lambda
        dim = decoder.hidden_dim
        # NOTE(review): D is defined but never used by any method below;
        # kept for state-dict compatibility.
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )
        self.word_drop = word_drop
        self.tau = tau  # temperature for the contrastive softmax

    @property
    def device(self):
        return self.encoder.device

    def do_mask(self, **inputs):
        """Apply word dropout in place to encoder AND decoder masks.

        The same Bernoulli draw `b` is reused for the decoder mask where
        lengths overlap, so both views drop the same positions; extra
        decoder positions (N > M) get fresh draws. Mutates the caller's
        dict values; assumes masks are bool tensors (uses `&`)."""
        m = inputs["enc_attention_mask"]
        b = torch.rand_like(m.float()) > self.word_drop
        inputs["enc_attention_mask"] = m & b


        B, N = inputs["dec_attention_mask"].shape
        _, M = m.shape
        m2 = inputs["dec_attention_mask"]
        if N <= M:
            b2 = b[:, :N]
        else:
            # Fresh draws for decoder positions beyond the encoder length.
            b_ = torch.rand((B, N - M), device=b.device) > self.word_drop
            b2 = torch.cat([b, b_], -1)
        inputs["dec_attention_mask"] = m2 & b2

    def get_z(self, **inputs):
        """Encode the `enc_`-prefixed inputs; None kept for API parity."""
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def step(self, **inputs):
        """Encode, decode, and attach the per-example reconstruction loss."""
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        # Teacher-forced likelihood: logits at position t score token t+1.
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        # Zero out padding positions (mask assumed bool).
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:], 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        return z, out

    def loss_c(self, z, z2):
        """Contrastive loss using negative squared Euclidean distance as
        similarity; positives are the matching (diagonal) pairs."""
        scores = -(torch.cdist(z, z2) ** 2)
        log_probs = (scores / self.tau).log_softmax(-1)
        loss = -torch.diagonal(log_probs)
        return loss

    def forward(self, **inputs):
        """Clean pass, then masked pass; sums the two reconstruction losses
        and adds the contrastive term between the two encodings."""
        z, out = self.step(**inputs)
        self.do_mask(**inputs)  # in-place: the second step sees dropped tokens
        z_, out_ = self.step(**inputs)
        out["loss_r"] = out["loss_r"] + out_["loss_r"]
        out["loss_c"] = self.loss_c(z, z_)
        return out
|
|
|
|
def run_aae_epoch(
    model,
    batches,
    opt,
    optD,
    num_samples=1,
    lambda_adv=1.0,
    desc="",
    notebook=True,
):
    """One AAE training epoch: alternates an encoder/decoder update
    (reconstruction + adversarial) with a discriminator update.

    Returns the per-example loss arrays concatenated over the epoch."""
    losses = {k: [] for k in ("l_rec", "adv", "loss_d")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            # isinstance (was `type(v) == torch.Tensor`) so Tensor
            # subclasses such as nn.Parameter are not silently dropped.
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        # Autoencoder/generator update.
        loss = (out["l_rec"] + lambda_adv * out["adv"]).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Discriminator update (loss_d was computed on detached z, so its
        # graph is independent of the one just freed above).
        loss_d = out["loss_d"].sum()
        optD.zero_grad()
        loss_d.backward()
        optD.step()

        d = {}
        for k in ("l_rec", "adv", "loss_d"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}
|
|
|
|
class GAE(nn.Module):
    """Autoencoder with an in-batch contrastive regularizer on z.

    Unlike CAE there is no second encoder pass; the contrastive term is
    computed between z and itself."""

    def __init__(self, encoder, decoder, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tau = tau  # contrastive softmax temperature

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        """Encode the `enc_`-prefixed inputs; None kept for API parity."""
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, z2):
        """InfoNCE over cosine similarities; positives on the diagonal."""
        scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
        log_probs = (scores / self.tau).log_softmax(-1)
        loss = -torch.diagonal(log_probs)
        return loss

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        # Teacher-forced likelihood: logits at position t score token t+1.
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:], 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        # BUG FIX: loss_c requires two arguments; previously called as
        # self.loss_c(z), which always raised TypeError. With only one
        # encoder pass the natural second view is z itself (cf. CAE, which
        # passes a no-grad second encoding) — confirm against training code.
        out["loss_c"] = self.loss_c(z, z)
        return out
|
|
|
|
class CAE(nn.Module):
    """Autoencoder with an in-batch contrastive term between two encodings
    of the same inputs; the second encoding runs under no_grad."""

    def __init__(self, encoder, decoder, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        """Encode the `enc_`-prefixed inputs; None kept for API parity."""
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, z2):
        """InfoNCE over cosine similarities; positives are the diagonal."""
        sim = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
        logp = (sim / self.tau).log_softmax(-1)
        return -torch.diagonal(logp)

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        # Second view; gradients flow only through the first encoding.
        with torch.no_grad():
            z2, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        lp = out["logits"].log_softmax(-1)
        # Shifted teacher forcing: position t predicts token t+1.
        nxt = inputs["dec_input_ids"][:, 1:].unsqueeze(-1)
        lp = torch.gather(lp[:, :-1], -1, nxt).squeeze(-1)
        lp = lp.masked_fill(~inputs["dec_attention_mask"][:, 1:], 0)
        out["loss_r"] = -lp.sum(-1)
        out["loss_c"] = self.loss_c(z, z2)
        return out
|
|
|
|
def run_cae_epoch(
    model,
    batches,
    opt,
    num_samples=1,
    lambda_c=1.0,
    desc="",
    notebook=True,
):
    """One training epoch for the contrastive autoencoders (CAE/CDAE/GAE).

    Minimizes sum(loss_r + lambda_c * loss_c) per batch and returns the
    per-example loss arrays concatenated over the epoch."""
    losses = {k: [] for k in ("loss_r", "loss_c")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            # isinstance (was `type(v) == torch.Tensor`) so Tensor
            # subclasses are not silently dropped.
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        loss = (out["loss_r"] + lambda_c * out["loss_c"]).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        d = {}
        for k in ("loss_r", "loss_c"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}
|
|
|
|
def batch_kl(l1, s1, l2=None, s2=None):
    """Unimplemented stub; returns None.

    The signature suggests a batched KL divergence between diagonal
    Gaussians N(l1, s1) and N(l2, s2) — presumably with (l2, s2) defaulting
    to the standard normal — but the body was never written and nothing in
    this file calls it. TODO(review): implement or remove."""
    return
|
|
|
|
class SubpopCondAE(nn.Module):
    """Conditional AE where each label owns `sublabels` diagonal-Gaussian
    components; a label's score is the logsumexp over its components'
    (optionally softmax-normalized) log-probabilities of z."""

    def __init__(
        self,
        encoder,
        decoder,
        num_labels,
        sublabels=4,
        tau=0.05,
        disc_loss=True,
        **kwargs,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = dim = decoder.hidden_dim
        # One (mean, log-scale) pair per (label, sublabel) component.
        self.locs = nn.Parameter(torch.randn(num_labels * sublabels, dim))
        self.log_scales = nn.Parameter(torch.zeros(num_labels * sublabels, dim))
        self.num_labels = num_labels
        self.sublabels = sublabels
        self.L = num_labels * sublabels  # total component count
        self.tau = tau  # NOTE(review): stored but unused in this class
        self.disc_loss = disc_loss

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        """Encode the `enc_`-prefixed inputs; None kept for API parity."""
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, **inputs):
        """Per-example NLL of the true label under the per-label mixtures.

        BUG FIX: the reshape below previously read `self.num_sublabels`, an
        attribute that is never set (__init__ stores `self.sublabels`), so
        this method always raised AttributeError."""
        scores = []
        for i in range(self.L):
            dist = Independent(
                Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1
            )
            scores.append(dist.log_prob(z))
        B = z.shape[0]
        sub_log_probs = torch.stack(scores, -1)
        if self.disc_loss:
            # Normalize over all components -> discriminative posterior.
            sub_log_probs = sub_log_probs.log_softmax(-1)
        log_probs = sub_log_probs.view(
            B, self.num_labels, self.sublabels
        ).logsumexp(-1)
        loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
        acc = log_probs.argmax(-1) == inputs["label"]
        return {
            "loss_c": loss,
            "log_probs": log_probs,
            "sub_log_probs": sub_log_probs,
            "acc": acc.float(),
        }

    def get_kl(self):
        """Sum of KL(component_i || N(0, I)) over all mixture components."""
        p = MultivariateNormal(
            torch.zeros(self.dim, device=self.device),
            torch.eye(self.dim, device=self.device),
        )
        kl = 0
        for i in range(self.L):
            q = MultivariateNormal(
                self.locs[i], torch.diag(self.log_scales[i].exp())
            )
            kl += torch.distributions.kl_divergence(q, p)
        return kl

    def forward(self, **inputs):
        """Reconstruction loss + label loss + component KL (batch dim 1)."""
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        # Teacher-forced likelihood: logits at position t score token t+1.
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:], 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        out_c = self.loss_c(z, **inputs)
        for k, v in out_c.items():
            out[k] = v
        out["kl"] = self.get_kl().unsqueeze(0)
        return out
|
|
|
|
def gaussian_prob_product(m1, s1, m2, s2, rho=1.0):
    """Probability product kernel between two diagonal Gaussians
    (Jebara & Kondor): K = integral of N(x; m1, diag s1)^rho *
    N(x; m2, diag s2)^rho dx, where s1/s2 are variances.

    For rho=1 this reduces to the Gaussian convolution identity
    K = N(m1 - m2; 0, diag(s1 + s2)).

    Fixes vs the previous version (derived from the closed form):
      * s_hat is the inverse summed precision 1/(1/s1 + 1/s2),
        not 1/(s1 + s2)
      * m_hat is the precision-weighted mean sum s1_inv*m1 + s2_inv*m2
        (previously s1_inv*s1 + s2_inv*s2, which is the constant 2)
      * the exponent coefficient is rho/2 (previously 1/rho)

    NOTE(review): the `a @ (w * a).T` quadratic forms are exact for 1-D
    inputs; for batched 2-D inputs only the diagonal of the result is a
    valid pairwise value — same caveat as the original code.
    """
    s1_inv = 1 / s1
    s2_inv = 1 / s2
    s_hat = 1 / (s1_inv + s2_inv)
    m_hat = s1_inv * m1 + s2_inv * m2
    dim = m1.shape[-1]
    return (
        ((2 * math.pi) ** ((1 - 2 * rho) * dim / 2))
        * (rho ** (-dim / 2))
        * torch.sqrt(s_hat.prod(-1))
        * ((s1.prod(-1) * s2.prod(-1)) ** (-rho / 2))
        * torch.exp(
            -(rho / 2)
            * (
                m1 @ (s1_inv * m1).T
                + m2 @ (s2_inv * m2).T
                - m_hat @ (s_hat * m_hat).T
            )
        )
    )
|
|
|
|
class CondAE(nn.Module):
    """Conditional AE with one learned diagonal Gaussian per label.

    The classification loss scores z under each label's Gaussian, optionally
    normalized across labels (`disc_loss`); `logdet`/`l2_reg` add optional
    regularizers on the Gaussian means."""

    def __init__(
        self,
        encoder,
        decoder,
        num_labels,
        logdet=False,
        l2_reg=False,
        disc_loss=True,
        tau=0.05,
        **kwargs,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = dim = decoder.hidden_dim
        # Per-label Gaussian parameters: mean and log std (diagonal).
        self.locs = nn.Parameter(torch.randn(num_labels, dim))
        self.log_scales = nn.Parameter(torch.zeros(num_labels, dim))
        self.num_labels = num_labels
        self.tau = tau  # NOTE(review): stored but unused in this class
        self.logdet = logdet
        self.l2_reg = l2_reg
        self.disc_loss = disc_loss

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        """Encode the `enc_`-prefixed inputs; None kept for API parity."""
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, **inputs):
        """Score z under each label Gaussian; NLL of the true label."""
        scores = []
        for i in range(self.num_labels):
            dist = Independent(
                Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1
            )
            scores.append(dist.log_prob(z))
        log_probs = torch.stack(scores, -1)
        if self.disc_loss:
            # Normalize across labels -> discriminative posterior.
            log_probs = log_probs.log_softmax(-1)
        loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
        acc = log_probs.argmax(-1) == inputs["label"]
        return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}

    def get_kl(self):
        """KL of each label Gaussian to N(0, I), plus an optional penalty on
        the means via an RBF kernel matrix over `locs`."""
        p = MultivariateNormal(
            torch.zeros(self.dim, device=self.device),
            torch.eye(self.dim, device=self.device),
        )
        kl = 0
        for i in range(self.num_labels):
            q = MultivariateNormal(
                self.locs[i], torch.diag(self.log_scales[i].exp())
            )
            kl += torch.distributions.kl_divergence(q, p)
        if self.logdet:
            # NOTE(review): logdet(K) <= 0 and grows toward 0 as means
            # spread apart; *adding* it to a minimized total therefore
            # pulls means together — confirm the intended sign.
            K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
            kl += torch.logdet(K)
        elif self.l2_reg:
            K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
            kl += torch.log(
                torch.linalg.norm(K / K.shape[0], dim=(-2, -1)) ** 2
            ).sum()
        return kl

    def forward(self, **inputs):
        """Reconstruction loss + label loss + prior KL (batch dim 1)."""
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        # Teacher-forced likelihood: logits at position t score token t+1.
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:], 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        out_c = self.loss_c(z, **inputs)
        for k, v in out_c.items():
            out[k] = v
        out["kl"] = self.get_kl().unsqueeze(0)
        return out
|
|
|
|
| class BasicCondAE(nn.Module): |
| def __init__(self, encoder, decoder, num_labels, tau=0.05, **kwargs): |
| super().__init__() |
| self.encoder = encoder |
| self.decoder = decoder |
| self.dim = dim = decoder.hidden_dim |
| self.linear = nn.Linear(dim, num_labels) |
| self.num_labels = num_labels |
| self.tau = tau |
|
|
| @property |
| def device(self): |
| return self.encoder.device |
|
|
| def get_z(self, **inputs): |
| return self.encoder(**get_inputs(inputs, "enc_")), None |
|
|
| def loss_c(self, z, **inputs): |
| log_probs = self.linear(z).log_softmax(-1) |
| loss = F.nll_loss(log_probs, inputs["label"], reduction="none") |
| acc = log_probs.argmax(-1) == inputs["label"] |
| return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()} |
|
|
| def forward(self, **inputs): |
| z, _ = self.get_z(**inputs) |
| out = self.decoder(z, **get_inputs(inputs, "dec_")) |
| b, n, _ = out["logits"].shape |
| log_probs = out["logits"].log_softmax(-1) |
| log_probs = torch.gather( |
| log_probs[:, :-1], |
| -1, |
| inputs["dec_input_ids"][:, 1:].unsqueeze(-1), |
| ).squeeze(-1) |
| log_probs = log_probs.masked_fill( |
| ~inputs["dec_attention_mask"][:, 1:], 0 |
| ) |
| out["loss_r"] = -log_probs.sum(-1) |
| out_c = self.loss_c(z, **inputs) |
| for k, v in out_c.items(): |
| out[k] = v |
| out["kl"] = torch.zeros_like(out["loss_r"]) |
| return out |
|
|
|
|
def run_cond_ae_epoch(
    model,
    batches,
    opt,
    num_samples=1,
    lambda_c=1.0,
    lambda_r=1.0,
    beta=1.0,
    desc="",
    notebook=True,
):
    """One training epoch for the conditional autoencoders.

    Minimizes sum(lambda_r * loss_r + lambda_c * loss_c) + beta * sum(kl)
    per batch; returns per-example arrays concatenated over the epoch."""
    losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            # isinstance (was `type(v) == torch.Tensor`) so Tensor
            # subclasses are not silently dropped.
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        loss = (
            lambda_r * out["loss_r"] + lambda_c * out["loss_c"]
        ).sum() + beta * out["kl"].sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        d = {}
        for k in ("loss_r", "loss_c", "kl", "acc"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}
|
|
|
|
def run_cond_ae_eval(
    model,
    batches,
    lambda_c=1.0,
    beta=1.0,
    desc="",
    notebook=True,
):
    """Evaluation pass for the conditional autoencoders (no optimizer).

    Returns per-example loss/accuracy arrays concatenated over the data."""
    losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.eval()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            # isinstance (was `type(v) == torch.Tensor`) so Tensor
            # subclasses are not silently dropped.
            if isinstance(v, torch.Tensor)
        }
        with torch.no_grad():
            out = model(**model_inputs)
            # Aggregate loss kept for parity with training; not used here.
            loss = (
                out["loss_r"] + lambda_c * out["loss_c"]
            ).sum() + beta * out["kl"].sum()
        d = {}
        for k in ("loss_r", "loss_c", "kl", "acc"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}
|
|
|
|
def generate(
    model,
    tokenizer,
    batch=None,
    z=None,
    do_sample=False,
    max_length=128,
    **kwargs,
):
    """Decode text conditioned on latents.

    Pass either `batch` (encoded to z via model.get_z with sample=False) or
    an explicit z array/tensor of shape (B, D). Returns B decoded strings
    with the GPT-2 end-of-text marker stripped.
    """
    if z is None:
        with torch.no_grad():
            z, _ = model.get_z(sample=False, **batch)
    else:
        z = torch.tensor(z, device=model.device)
    B = z.shape[0]
    K = model.decoder.prefix_length
    # Reuse the decoder's own prefix construction instead of duplicating
    # the reshape/chunk logic inline (previous version copied it verbatim).
    with torch.no_grad():
        layers = model.decoder.get_prefix(z)
    output = model.decoder.decoder.generate(
        input_ids=torch.tensor(
            [[tokenizer.bos_token_id]] * B, device=model.device
        ),
        attention_mask=torch.ones((B, K + 1), device=model.device),
        # NOTE(review): recent transformers versions renamed this kwarg to
        # `past_key_values`; confirm against the pinned library version.
        past=layers,
        do_sample=do_sample,
        max_length=max_length,
        **kwargs,
    )
    decoded = tokenizer.batch_decode(output[:, 1:])
    return [s.replace("<|endoftext|>", "") for s in decoded]
|
|
|
|
def get_embeddings(model, batches, desc="", notebook=True):
    """Encode every batch with model.get_z (no grad) and return the stacked
    latents as one numpy array of shape (num_examples, D).

    `sample=False` selects the posterior mean for VAE; other models' get_z
    accept it via **inputs and ignore it."""
    out = []
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.eval()
    for batch in t:
        with torch.no_grad():
            model_inputs = {
                k: v.to(model.device)
                for k, v in batch.items()
                # isinstance (was `type(v) == torch.Tensor`) so Tensor
                # subclasses are not silently dropped.
                if isinstance(v, torch.Tensor)
            }
            z, _ = model.get_z(sample=False, **model_inputs)
            out.append(z.detach().cpu().numpy())
    return np.concatenate(out, 0)
|
|
|
|
def interpolate(model, tokenizer, a, b, num_steps=10, **kwargs):
    """Decode texts along the straight line in latent space from a to b."""
    weights = np.linspace(0, 1.0, num_steps)
    z = np.stack([w * b + (1 - w) * a for w in weights], 0)
    return generate(model, tokenizer, z=z, **kwargs)
|
|