| |
| |
| ''' |
| @license: (C) Copyright 2025, Hey. |
| @author: Hey |
| @email: sanyuan.hy@alibaba-inc.com |
| @tel: 137****6540 |
| @datetime: 2025/12/30 11:35 |
| @project: lucaone |
| @file: modeling_lucaone |
| @desc: modeling_lucaone |
| ''' |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import BaseModelOutput |
| from transformers.modeling_outputs import MaskedLMOutput |
| from transformers.modeling_outputs import SequenceClassifierOutput |
| from transformers.modeling_outputs import TokenClassifierOutput |
| from typing import Optional, List, Union, Tuple |
| from .configuration_lucaone import LucaGPLMConfig |
# Use NVIDIA apex's fused LayerNorm when apex is importable; otherwise expose
# the stock torch LayerNorm under the same public name.
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class LucaGPLM1bLayerNorm(_FusedLayerNorm):
        """Apex ``FusedLayerNorm`` that runs CUDA inputs under their own device context."""

        @torch.jit.unused
        def forward(self, x):
            if x.is_cuda:
                # Make x's GPU the current CUDA device while the parent forward runs.
                with torch.cuda.device(x.device):
                    return super().forward(x)
            # CPU tensors are handed straight to the parent implementation.
            return super().forward(x)
except ImportError:
    from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
|
|
def gelu(x):
    """Exact (erf-based) GELU: ``x * Phi(x)``, where Phi is the standard normal CDF."""
    erf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * erf_term
|
|
def rotate_half(x):
    """Map the last dimension ``[a, b]`` to ``[-b, a]``.

    The split uses ``chunk(2)``, so for an odd size ``a`` holds the extra element.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
|
|
def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` by the rotary tables ``cos``/``sin``.

    Both tables are trimmed along dim 1 to ``x.shape[-2]`` (the sequence axis of
    ``x``) and then broadcast against ``x``. The half-rotation ``[a, b] -> [-b, a]``
    is the same computation as ``rotate_half``.
    """
    seq_len = x.shape[-2]
    cos_trim = cos[:, :seq_len, :]
    sin_trim = sin[:, :seq_len, :]
    front, back = x.chunk(2, dim=-1)
    x_rotated = torch.cat((-back, front), dim=-1)
    return (x * cos_trim) + (x_rotated * sin_trim)
|
|
class LucaGPLMRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding applied to query/key tensors.

    The cos/sin tables are cached and rebuilt when the sequence length, the
    input device, or the dtype of the ``inv_freq`` buffer changes.
    """

    def __init__(self, dim: int, *_, **__):
        super().__init__()
        # Inverse frequencies 1 / 10000^(2i/dim) for i = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # A registered buffer follows module.to()/.half() and is part of the state dict.
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None
        # dtype of inv_freq at the time the cached tables were built.
        self._dtype_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return ``(cos, sin)`` tables of shape ``(1, seq_len, dim)``.

        ``seq_len`` is ``x.shape[seq_dimension]``. Cached tables are reused unless
        the sequence length, the device of ``x``, or the dtype of ``inv_freq``
        changed since they were built.
        """
        seq_len = x.shape[seq_dimension]
        stale = (
            seq_len != self._seq_len_cached
            or self._cos_cached is None
            or self._sin_cached is None
            or self._cos_cached.device != x.device
            # The tables are built in inv_freq's dtype; after a cast such as
            # model.half() the old tables would silently promote q/k back to
            # the old dtype, so rebuild them.
            or self._dtype_cached != self.inv_freq.dtype
        )
        if stale:
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, :, :]
            self._sin_cached = emb.sin()[None, :, :]
            self._seq_len_cached = seq_len
            self._dtype_cached = self.inv_freq.dtype

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k``; the sequence length is read from ``k.shape[-2]``."""
        cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)
        return (
            apply_rotary_pos_emb(q, cos, sin),
            apply_rotary_pos_emb(k, cos, sin),
        )
|
|
class LucaGPLMGlobalMaskWeightedAttentionPooling1D(nn.Module):
    """Masked attention pooling driven by a single learned scoring vector.

    Each position is scored as ``x @ W`` (plus scalar ``b`` when ``use_bias``).
    The scores are softmax-normalised over the last score axis and used to take
    a weighted sum of ``x`` over dim 1.
    """

    def __init__(self, embed_size, use_bias=False):
        super().__init__()
        self.embed_size = embed_size
        self.use_bias = use_bias

        self.W = nn.Parameter(torch.Tensor(self.embed_size))
        nn.init.trunc_normal_(self.W, std=0.01)
        if self.use_bias:
            self.b = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b, std=0.01)

    def forward(self, x, mask=None):
        # assumes x is (batch, seq_len, embed_size) — TODO confirm with callers.
        # mask: 1 keeps a position; 0 adds a -10000 penalty before the softmax.
        scores = torch.matmul(x, self.W)
        if self.use_bias:
            scores = scores + self.b
        if mask is not None:
            scores = scores + (1.0 - mask) * -10000
        weights = scores.softmax(dim=-1)
        return (weights.unsqueeze(-1) * x).sum(dim=1)

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.embed_size}, bias={self.use_bias!r})"
|
|
class LucaGPLMGlobalMaskContextAttentionPooling1D(nn.Module):
    """Masked additive-attention pooling with a learned context vector.

    Scores are ``tanh(x @ U + x @ V [+ b1]) @ c [+ b2]``, one scalar per
    position. They are softmax-normalised over the last score axis and used to
    take a weighted sum of ``x`` over dim 1.
    """

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super().__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.units = units or embed_size

        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b2, std=0.01)

        self.c = nn.Parameter(torch.Tensor(self.units))

        # Initialisation order is kept (b1, b2, then U, V, c) so seeded runs reproduce.
        for param in (self.U, self.V, self.c):
            nn.init.trunc_normal_(param, std=0.01)

    def forward(self, x, mask=None):
        # assumes x is (batch, seq_len, embed_size) — TODO confirm with callers.
        pre_act = torch.matmul(x, self.U) + torch.matmul(x, self.V)
        if self.use_additive_bias:
            pre_act = pre_act + self.b1
        hidden = torch.tanh(pre_act)

        scores = torch.matmul(hidden, self.c)
        if self.use_attention_bias:
            scores = scores + self.b2
        if mask is not None:
            # mask: 1 keeps a position; 0 adds a -10000 penalty.
            scores = scores + (1.0 - mask) * -10000
        weights = scores.softmax(dim=-1)
        return (weights.unsqueeze(-1) * x).sum(dim=1)

    def __repr__(self):
        return (
            f"{self.__class__.__name__} ({self.embed_size} -> {self.units}, "
            f"bias=({self.use_additive_bias!r}, {self.use_attention_bias!r}))"
        )
|
|
class LucaGPLMGlobalMaskValueAttentionPooling1D(nn.Module):
    """Masked additive-attention pooling with a separate weight per feature.

    Scores are ``tanh(x @ U + x @ V [+ b1]) @ W [+ b2]``, i.e. one score per
    position *and* feature. They are softmax-normalised over dim 1 independently
    for each feature and used to take a weighted sum of ``x`` over dim 1.
    """

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super().__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.units = units or embed_size

        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(self.embed_size))
            nn.init.trunc_normal_(self.b2, std=0.01)

        self.W = nn.Parameter(torch.Tensor(self.units, self.embed_size))

        # Initialisation order is kept (b1, b2, then U, V, W) so seeded runs reproduce.
        for param in (self.U, self.V, self.W):
            nn.init.trunc_normal_(param, std=0.01)

    def forward(self, x, mask=None):
        # assumes x is (batch, seq_len, embed_size) — TODO confirm with callers.
        pre_act = torch.matmul(x, self.U) + torch.matmul(x, self.V)
        if self.use_additive_bias:
            pre_act = pre_act + self.b1
        hidden = torch.tanh(pre_act)

        scores = torch.matmul(hidden, self.W)
        if self.use_attention_bias:
            scores = scores + self.b2
        if mask is not None:
            # mask: 1 keeps a position; 0 adds -10000 to every feature score of it.
            penalty = torch.unsqueeze((1.0 - mask) * -10000, dim=-1)
            scores = scores + penalty
        weights = scores.softmax(dim=1)
        return (weights * x).sum(dim=1)

    def __repr__(self):
        return (
            f"{self.__class__.__name__} ({self.embed_size} -> {self.units}, "
            f"bias=({self.use_additive_bias!r}, {self.use_attention_bias!r}))"
        )
|
|
class LucaGPLM1LayerNorm(nn.Module):
    """Explicit layer normalisation over the trailing ``len(hidden_size)`` dims.

    Uses the biased variance and ``eps`` inside the square root. When ``affine``
    is set, a learned elementwise scale (``weight``) and shift (``bias``) follow.
    """

    def __init__(self, hidden_size, eps=1e-12, affine=True):
        super().__init__()
        if isinstance(hidden_size, int):
            self.hidden_size = (hidden_size,)
        else:
            self.hidden_size = tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if not self.affine:
            self.weight, self.bias = None, None
        else:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        # Normalise over the last len(hidden_size) axes: (-1, -2, ..., -n).
        norm_dims = tuple(range(-1, -len(self.hidden_size) - 1, -1))
        centered = x - x.mean(norm_dims, keepdim=True)
        variance = centered.pow(2).mean(norm_dims, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        if not self.affine:
            return normed
        return (self.weight * normed) + self.bias
|
|
class LucaGPLMMultiheadAttention(nn.Module):
    """Multi-head attention computed explicitly with ``torch.bmm``.

    Tensors are time-first: ``query`` is ``(tgt_len, bsz, embed_dim)``. When
    ``use_rotary_embeddings`` is set, rotary position embeddings are applied to
    the per-head queries and keys before the dot product.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # 1/sqrt(head_dim) dot-product scaling, applied to q before the bmm.
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            # Learned key/value appended as one extra source position.
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        # Kept for interface compatibility; forward() does not read it.
        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim)

    def reset_parameters(self):
        """Xavier-uniform (ReLU gain) projection weights, zero ``out_proj`` bias, Xavier-normal bias_k/v."""
        nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))

        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: ``(tgt_len, bsz, embed_dim)``.
            key, value: required unless ``self_attention``; in self-attention
                mode ``query`` is projected for q, k and v and these are ignored.
            key_padding_mask: ``(bsz, src_len)``; non-zero / True entries are
                excluded from attention (filled with ``-inf``).
            need_weights: also return the attention probabilities.
            attn_mask: additive mask, either ``(tgt_len, src_len)`` (shared by
                all batch items and heads) or ``(bsz * num_heads, tgt_len, src_len)``.
            need_head_weights: return per-head weights instead of the head mean;
                implies ``need_weights``.

        Returns:
            ``(attn, attn_weights)``. ``attn`` is ``(tgt_len, bsz, embed_dim)``.
            ``attn_weights`` is ``None`` unless requested; otherwise it is
            ``(num_heads, bsz, tgt_len, src_len)`` with ``need_head_weights``,
            else ``(bsz, tgt_len, src_len)``.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            # Give the appended bias position a zero ("may attend") column.
            # Padding along the last axis works for 2-D and 3-D attn_mask alike.
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(*attn_mask.shape[:-1], 1)], dim=-1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim)
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        src_len = k.size(1)

        if self.rot_emb is not None:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if attn_mask.dim() == 2:
                # (tgt_len, src_len): broadcast over every batch item and head.
                attn_mask = attn_mask.unsqueeze(0)
            # A 3-D mask already matches (bsz * num_heads, tgt_len, src_len).
            attn_weights += attn_mask

        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        attn_weights_output: Optional[torch.Tensor] = None
        if need_weights:
            # (bsz, heads, tgt, src) -> (heads, bsz, tgt, src)
            attn_weights_output = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # Average over heads -> (bsz, tgt, src).
                attn_weights_output = attn_weights_output.mean(dim=0)

        return attn, attn_weights_output
|
|
class LucaGPLMMultiheadAttentionWithSDPA(nn.Module):
    """Multi-head attention that uses ``F.scaled_dot_product_attention`` when it can.

    Tensors are time-first: ``query`` is ``(tgt_len, bsz, embed_dim)``. The SDPA
    path is taken whenever per-head weights are not requested and the running
    PyTorch provides SDPA. That path returns ``None`` for the attention weights
    even when ``need_weights`` is True. Otherwise attention is computed
    explicitly with ``torch.bmm`` and the weights are returned.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # 1/sqrt(head_dim): applied to q on the explicit path; SDPA applies the
        # same default scale internally.
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            # Learned key/value appended as one extra source position.
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        # Kept for interface compatibility; forward() does not read it.
        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = LucaGPLMRotaryEmbedding(dim=self.head_dim)

    def reset_parameters(self):
        """Xavier-uniform (ReLU gain) projection weights, zero ``out_proj`` bias, Xavier-normal bias_k/v."""
        nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))

        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: ``(tgt_len, bsz, embed_dim)``.
            key, value: required unless ``self_attention``; in self-attention
                mode ``query`` is projected for q, k and v and these are ignored.
            key_padding_mask: ``(bsz, src_len)``; non-zero / True entries are
                excluded from attention (``-inf``).
            need_weights: request attention probabilities. Honoured only on the
                explicit path; the SDPA path always returns ``None``.
            attn_mask: additive mask, either ``(tgt_len, src_len)`` (shared by
                all batch items and heads) or ``(bsz * num_heads, tgt_len, src_len)``.
            need_head_weights: forces the explicit path and returns per-head
                weights ``(num_heads, bsz, tgt_len, src_len)``; implies ``need_weights``.

        Returns:
            ``(attn, attn_weights)`` with ``attn`` of shape ``(tgt_len, bsz, embed_dim)``.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            # Give the appended bias position a zero ("may attend") column.
            # Padding along the last axis works for 2-D and 3-D attn_mask alike.
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(*attn_mask.shape[:-1], 1)], dim=-1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        if not need_head_weights and hasattr(F, "scaled_dot_product_attention"):
            # (len, bsz, embed_dim) -> (bsz, heads, len, head_dim). reshape (not
            # view) so a non-contiguous projection output cannot make this fail.
            q_sdpa = q.reshape(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
            k_sdpa = k.reshape(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)
            v_sdpa = v.reshape(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)

            if self.rot_emb is not None:
                q_sdpa, k_sdpa = self.rot_emb(q_sdpa, k_sdpa)
            src_len = k_sdpa.size(2)

            sdpa_mask = None
            if attn_mask is not None or key_padding_mask is not None:
                # Additive float mask, built in the post-rotary query dtype so it
                # matches the query as SDPA requires for non-boolean masks.
                sdpa_mask = torch.zeros(
                    (bsz, 1, tgt_len, src_len), device=q_sdpa.device, dtype=q_sdpa.dtype
                )
                if key_padding_mask is not None:
                    sdpa_mask = sdpa_mask.masked_fill(
                        key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                        float("-inf"),
                    )
                if attn_mask is not None:
                    if attn_mask.dim() == 2:
                        # (tgt_len, src_len) -> (1, 1, tgt_len, src_len)
                        attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
                    elif attn_mask.dim() == 3:
                        # (bsz * num_heads, tgt, src), the same layout the explicit
                        # path uses -> (bsz, num_heads, tgt, src). This case used to be
                        # silently ignored.
                        attn_mask = attn_mask.reshape(bsz, self.num_heads, tgt_len, src_len)
                    sdpa_mask = sdpa_mask + attn_mask.to(sdpa_mask.dtype)

            attn_output = F.scaled_dot_product_attention(
                q_sdpa,
                k_sdpa,
                v_sdpa,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False,
            )
            # (bsz, heads, tgt, head_dim) -> (tgt, bsz, embed_dim)
            attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(tgt_len, bsz, self.embed_dim)
            attn_output = self.out_proj(attn_output)
            return attn_output, None

        # Explicit path (per-head weights requested, or SDPA unavailable).
        q = q * self.scaling

        # (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim)
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        src_len = k.size(1)

        if self.rot_emb is not None:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if attn_mask.dim() == 2:
                # (tgt_len, src_len): broadcast over every batch item and head.
                attn_mask = attn_mask.unsqueeze(0)
            # A 3-D mask already matches (bsz * num_heads, tgt_len, src_len).
            attn_weights += attn_mask

        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        attn_weights_output: Optional[torch.Tensor] = None
        if need_weights:
            # (bsz, heads, tgt, src) -> (heads, bsz, tgt, src)
            attn_weights_output = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # Average over heads -> (bsz, tgt, src).
                attn_weights_output = attn_weights_output.mean(dim=0)

        return attn, attn_weights_output
|
|
class LucaGPLMRobertaLMHead(nn.Module):
    """Language-model head: dense -> GELU -> LayerNorm -> output projection plus bias.

    The output projection has no bias of its own; a separate ``bias`` parameter
    of size ``output_dim`` is added instead.
    """

    def __init__(self, embed_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)
        self.decoder = nn.Linear(embed_dim, output_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        hidden = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(hidden) + self.bias
|
|
class LucaGPLMTransformerLayer(nn.Module):
    """Pre-LayerNorm transformer block.

    The block is ``x + SelfAttn(LN(x))`` followed by ``x + FFN(LN(x))``, where the
    FFN is ``fc2(gelu(fc1(.)))``. Inputs are time-first ``(seq_len, batch,
    embed_dim)``, as required by ``LucaGPLMMultiheadAttentionWithSDPA``.
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_lucagplm1b_layer_norm=False,
        use_rotary_embeddings: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings

        norm_cls = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm

        # Submodules are created in the original order, so parameter
        # registration and seeded initialisation are unchanged.
        self.pre_layer_norm = norm_cls(self.embed_dim)
        self.self_attn = LucaGPLMMultiheadAttentionWithSDPA(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            self_attention=True,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.post_layer_norm = norm_cls(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_head_weights=False,
    ):
        """Return ``(x, attn)``; ``attn`` is whatever the attention module returned."""
        normed = self.pre_layer_norm(x)
        attn_out, attn = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = x + attn_out

        ffn_out = self.fc2(gelu(self.fc1(self.post_layer_norm(x))))
        x = x + ffn_out
        return x, attn
|
|
class LucaGPLMEmbeddings(nn.Module):
    """Input embeddings: scaled token embeddings plus optional extras.

    Optional extras are position embeddings, token-type embeddings, a
    LayerNorm, training-time token dropout, and zeroing of padding positions.
    """

    def __init__(self, config: LucaGPLMConfig):
        super().__init__()

        # Feature switches, with defaults for configs that do not define them.
        self.no_position_embeddings = getattr(config, 'no_position_embeddings', False)
        self.no_token_type_embeddings = getattr(config, 'no_token_type_embeddings', False)
        self.use_embed_layer_norm = getattr(config, 'use_embed_layer_norm', True)
        self.embed_scale = getattr(config, 'embed_scale', 1.0)
        self.token_dropout = getattr(config, 'token_dropout', False)

        self.mask_idx = getattr(config, 'mask_token_id', 4)
        self.padding_idx = getattr(config, 'pad_token_id', 0)

        # Submodules are created in the original order (tokens, positions,
        # types, LayerNorm) so registration and seeded init are unchanged.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.embed_pos = (
            None
            if self.no_position_embeddings
            else nn.Embedding(config.max_position_embeddings, config.hidden_size)
        )
        self.embed_type = (
            None
            if self.no_token_type_embeddings
            else nn.Embedding(config.type_vocab_size, config.hidden_size)
        )
        self.embed_layer_norm = (
            LucaGPLM1bLayerNorm(config.hidden_size) if self.use_embed_layer_norm else None
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` of shape ``(batch, seq_len)`` into ``(batch, seq_len, hidden)``."""
        batch_shape = input_ids.size()
        embeddings = self.embed_scale * self.embed_tokens(input_ids)

        if self.embed_pos is not None and not self.no_position_embeddings:
            if position_ids is None:
                # Default positions 0 .. seq_len-1, shared across the batch.
                position_ids = torch.arange(
                    batch_shape[1], dtype=torch.long, device=input_ids.device
                ).unsqueeze(0).expand(batch_shape)
            embeddings = embeddings + self.embed_scale * self.embed_pos(position_ids)

        if self.embed_type is not None and not self.no_token_type_embeddings:
            if token_type_ids is None:
                token_type_ids = torch.zeros(batch_shape, dtype=torch.long, device=input_ids.device)
            embeddings = embeddings + self.embed_scale * self.embed_type(token_type_ids)

        if self.embed_layer_norm is not None and self.use_embed_layer_norm:
            embeddings = self.embed_layer_norm(embeddings)

        is_pad = input_ids.eq(self.padding_idx)

        if self.token_dropout and self.training:
            is_mask = input_ids == self.mask_idx
            # Zero out [MASK] embeddings, then rescale each row by
            # (1 - expected mask ratio) / (1 - observed mask ratio).
            embeddings = embeddings.masked_fill(is_mask.unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8
            real_lengths = (~is_pad).sum(-1)
            observed_ratio = is_mask.sum(-1).to(embeddings.dtype) / real_lengths
            embeddings = embeddings * (1 - mask_ratio_train) / (1 - observed_ratio)[:, None, None]

        if is_pad.any():
            # Padding positions end up as zero vectors.
            embeddings = embeddings * (1 - is_pad.unsqueeze(-1).type_as(embeddings))

        return embeddings
|
|
class LucaGPLMEncoder(nn.Module):
    """Stack of ``LucaGPLMTransformerLayer`` blocks with an optional final LayerNorm.

    Takes batch-first hidden states, runs the layers time-first (the layout the
    attention modules expect), and returns batch-first outputs.
    """

    def __init__(self, config: LucaGPLMConfig):
        super().__init__()

        # Every layer: FFN width 4 * hidden_size, no key/value bias, the
        # LucaGPLM1bLayerNorm variant, and rotary position embeddings.
        self.layers = nn.ModuleList([
            LucaGPLMTransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_attention_heads,
                add_bias_kv=False,
                use_lucagplm1b_layer_norm=True,
                use_rotary_embeddings=True,
            )
            for _ in range(config.num_hidden_layers)
        ])

        self.use_last_layer_norm = getattr(config, 'use_last_layer_norm', True)
        if self.use_last_layer_norm:
            self.last_layer_norm = LucaGPLM1bLayerNorm(config.hidden_size)
        else:
            self.last_layer_norm = None

        # Kept for compatibility; forward() derives padding from attention_mask.
        self.padding_idx = config.pad_token_id
        # Toggled by Hugging Face gradient checkpointing; presumably
        # _gradient_checkpointing_func is installed by that machinery — TODO confirm.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        need_head_weights: bool = False,
        repr_layers: Optional[List[int]] = None,
        use_last_layer_norm: bool = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Run every layer over ``hidden_states``.

        Args:
            hidden_states: batch-first embeddings, presumably ``(batch, seq_len,
                hidden)``; transposed to time-first for the layers.
            attention_mask: ``(batch, seq_len)``; positions equal to 0 are padding.
                ``None`` means no padding at all.
            output_attentions: also return per-layer, per-head attention
                probabilities stacked as ``(batch, num_layers, num_heads, seq, seq)``.
            output_hidden_states: also return each layer's input plus the final output.
            return_dict: return a ``BaseModelOutput`` instead of a tuple.
            need_head_weights: request per-head weights from every layer.
            repr_layers: layer indices (negative allowed) to collect into
                ``hidden_representations``; index 0 is the input.
            use_last_layer_norm: apply ``last_layer_norm`` (when present) to the output.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        collect_attn = need_head_weights or output_attentions

        num_layers = len(self.layers)
        if repr_layers is None:
            repr_layers = [-1]
        # Map negative indices onto 0..num_layers (-1 -> num_layers, the final output).
        repr_layers = {(i + num_layers + 1) % (num_layers + 1) for i in repr_layers}
        # NOTE(review): hidden_representations is collected but never returned.
        hidden_representations = {}

        # Padding comes only from attention_mask. Without a mask nothing is
        # padding. (The previous zeros.eq(pad_token_id) marked every position
        # as padding when pad_token_id == 0, producing NaN attention.)
        padding_mask = None if attention_mask is None else attention_mask.eq(0)
        if padding_mask is not None and not padding_mask.any():
            padding_mask = None

        if 0 in repr_layers:
            hidden_representations[0] = hidden_states

        # Batch-first -> time-first for the layers.
        hidden_states = hidden_states.transpose(0, 1)

        attn_weights = [] if collect_attn else None

        for layer_idx, layer_module in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    None,
                    padding_mask,
                    collect_attn,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    self_attn_mask=None,
                    self_attn_padding_mask=padding_mask,
                    need_head_weights=collect_attn,
                )

            hidden_states, attn = layer_outputs

            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = hidden_states.transpose(0, 1)

            if collect_attn:
                # (num_heads, batch, tgt, src) -> (batch, num_heads, tgt, src)
                attn_weights.append(attn.transpose(1, 0))

        if self.last_layer_norm is not None and use_last_layer_norm:
            hidden_states = self.last_layer_norm(hidden_states)

        # Time-first -> batch-first.
        hidden_states = hidden_states.transpose(0, 1)

        # The normalised final output replaces the last layer's raw representation.
        # Using num_layers (not the loop variable) also works with zero layers.
        if num_layers in repr_layers:
            hidden_representations[num_layers] = hidden_states

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if collect_attn and attn_weights:
            # (batch, num_layers, num_heads, seq, seq)
            all_attentions = torch.stack(attn_weights, 1)
            if padding_mask is not None:
                # Zero attention entries whose query or key is padding.
                keep = 1 - padding_mask.type_as(all_attentions)
                keep = keep.unsqueeze(1) * keep.unsqueeze(2)
                all_attentions = all_attentions * keep[:, None, None, :, :]
        if not output_attentions:
            all_attentions = None

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
class LucaGPLMPreTrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` base for LucaGPLM.

    It declares the config class, the base-model prefix, gradient-checkpointing
    support, the module type that must not be split across devices, and the
    weight-initialisation hook.
    """

    config_class = LucaGPLMConfig
    base_model_prefix = "lucaone"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LucaGPLMTransformerLayer"]

    def _init_weights(self, module):
        """Initialise one submodule.

        Presumably this is applied to every submodule by ``PreTrainedModel``
        initialisation (e.g. ``post_init``) — TODO confirm.

        - Linear / Embedding weights: ``N(0, config.initializer_range)``.
        - Linear biases and the Embedding padding row: zero.
        - LayerNorm variants: weight = 1, bias = 0 (when present).
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, (LucaGPLM1LayerNorm, LucaGPLM1bLayerNorm)):
            # Either attribute may be absent or None (e.g. a non-affine norm).
            norm_weight = getattr(module, 'weight', None)
            norm_bias = getattr(module, 'bias', None)
            if norm_weight is not None:
                norm_weight.data.fill_(1.0)
            if norm_bias is not None:
                norm_bias.data.zero_()
|
|
class LucaGPLMModel(LucaGPLMPreTrainedModel):
    """
    The LucaGPLM model for extracting sequence representations and optionally predicting contacts.
    Based on the original LucaGPLM implementation but restructured to use modern transformers architecture.

    The model is ``LucaGPLMEmbeddings`` followed by ``LucaGPLMEncoder``. When
    ``return_contacts=True`` a symmetrized attention map, averaged over the
    layer and head axes, is attached to the output as ``contacts``.
    """

    def __init__(self, config: LucaGPLMConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = LucaGPLMEmbeddings(self.config)
        self.encoder = LucaGPLMEncoder(self.config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module (``embeddings.embed_tokens``)."""
        return self.embeddings.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module (``embeddings.embed_tokens``)."""
        self.embeddings.embed_tokens = value

    @staticmethod
    def _attentions_to_contacts(attentions):
        """Average attention maps over dims 1 and 2 and symmetrize the result.

        Args:
            attentions: either a single tensor whose dims 1 and 2 are averaged
                away (presumably ``(batch, layers, heads, seq, seq)`` — TODO
                confirm against ``LucaGPLMEncoder``), or a tuple/list of
                per-layer tensors, which are first stacked along a new dim 1.

        Returns:
            ``(A + A^T) / 2`` where ``A`` is the averaged attention map.
        """
        if isinstance(attentions, (tuple, list)):
            # HF-style encoders usually return one attention tensor per layer;
            # stacking at dim 1 produces the layout the mean below expects.
            attentions = torch.stack(tuple(attentions), dim=1)
        averaged_attention = attentions.mean(dim=(1, 2))
        return (averaged_attention + averaged_attention.transpose(-1, -2)) / 2

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_contacts: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        need_head_weights: Optional[bool] = None,
        repr_layers: Optional[List[int]] = None,
        use_last_layer_norm: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Embed the input (or use ``inputs_embeds``) and run the encoder.

        Args:
            input_ids: token ids. Exactly one of ``input_ids`` / ``inputs_embeds``
                must be given.
            attention_mask: defaults to all ones over the input shape.
            token_type_ids: forwarded to the embeddings (unused when
                ``inputs_embeds`` is given).
            position_ids: forwarded to the embeddings (unused when
                ``inputs_embeds`` is given).
            inputs_embeds: pre-computed embeddings fed directly to the encoder.
            output_attentions: defaults to ``config.output_attentions`` (False if absent).
            output_hidden_states: defaults to ``config.output_hidden_states`` (False if absent).
            return_contacts: when True, forces attentions and head weights on and
                computes ``contacts``.
            return_dict: defaults to ``config.use_return_dict`` (True if absent).
            need_head_weights: forwarded to the encoder; defaults to ``return_contacts``.
            repr_layers: forwarded to the encoder.
            use_last_layer_norm: forwarded to the encoder; None is treated as True.

        Returns:
            ``BaseModelOutput`` (with an extra ``contacts`` attribute when computed),
            or, if ``return_dict`` is False, the tuple
            ``(last_hidden_state, *encoder_extras[, contacts])``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else getattr(self.config, 'output_attentions', False)
        output_hidden_states = output_hidden_states if output_hidden_states is not None else getattr(self.config, 'output_hidden_states', False)
        return_contacts = return_contacts if return_contacts is not None else False
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        need_head_weights = need_head_weights if need_head_weights is not None else return_contacts
        use_last_layer_norm = use_last_layer_norm if use_last_layer_norm is not None else True

        # Contacts are derived from the attention maps, so both must be produced.
        if return_contacts:
            output_attentions = True
            need_head_weights = True

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Unpacking also rejects inputs that are not (batch, seq) shaped.
        _batch_size, _seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        if inputs_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
            )
        else:
            embedding_output = inputs_embeds

        # Reading ``.attentions`` for contacts needs a ModelOutput, so request
        # dict output from the encoder whenever contacts are wanted; otherwise
        # ``return_dict=False`` + ``return_contacts=True`` hits a plain tuple.
        encoder_return_dict = return_dict or return_contacts
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=encoder_return_dict,
            need_head_weights=need_head_weights,
            repr_layers=repr_layers,
            use_last_layer_norm=use_last_layer_norm,
        )

        sequence_output = encoder_outputs[0]

        contacts = None
        if return_contacts and encoder_outputs.attentions is not None:
            contacts = self._attentions_to_contacts(encoder_outputs.attentions)

        if not return_dict:
            # Slicing works for both a tuple and a ModelOutput (which slices
            # its non-None fields).
            outputs = (sequence_output, ) + encoder_outputs[1:]
            if contacts is not None:
                outputs = outputs + (contacts,)
            return outputs

        output = BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

        # ``contacts`` is not a BaseModelOutput field, so it is attached as a
        # plain attribute.
        if contacts is not None:
            output.contacts = contacts

        return output
|
|
class LucaGPLMForMaskedLM(LucaGPLMPreTrainedModel):
    """LucaGPLM backbone (``self.lucaone``) with a RoBERTa-style LM head for masked-token prediction."""

    def __init__(self, config):
        super().__init__(config)

        self.lucaone = LucaGPLMModel(config)

        self.lm_head = LucaGPLMRobertaLMHead(
            embed_dim=config.hidden_size,
            output_dim=config.vocab_size,
        )
        # Parameter names shared between the input embeddings and the LM
        # decoder; presumably consumed by transformers' weight-tying /
        # checkpoint-loading logic (semantics vary by transformers version).
        self._tied_weights_keys = [
            "lucaone.embeddings.embed_tokens.weight",
            "lm_head.decoder.weight",
        ]

        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.lucaone.get_input_embeddings()

    def get_output_embeddings(self):
        """Return the LM head's output projection (``lm_head.decoder``)."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's output projection (``lm_head.decoder``)."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Compute vocabulary logits for every position and, optionally, the MLM loss.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            output_attentions, output_hidden_states: forwarded to ``LucaGPLMModel``.
            labels: target token ids, same layout as ``input_ids``; positions
                set to -100 are ignored by the loss.
            return_dict: defaults to ``config.use_return_dict``.

        Returns:
            ``MaskedLMOutput``, or ``([loss,] logits, *backbone_extras)`` when
            ``return_dict`` is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.lucaone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            masked_lm_loss = loss_fct(
                prediction_scores.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
            )

        if not return_dict:
            # LucaGPLMModel's tuple is (sequence_output, *extras) with no pooled
            # output, so the extras (hidden states / attentions) start at index 1.
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
    """LucaGPLM backbone plus a pooling step and a linear head for sequence-level tasks.

    The config selects the behaviour:

    * ``task_level`` must be ``"seq_level"``.
    * ``task_type`` is one of ``multi_class``, ``binary_class``, ``regression``
      or ``multi_label`` and selects the loss.
    * ``classifier_pooling_type`` is a learned pooler (``value_attention``,
      ``context_attention``, ``weighted_attention``) or ``cls`` / ``mean``.
    """

    def __init__(self, config):
        # A positive ``classifier_num_labels`` overrides ``num_labels`` on the
        # config before the base class stores it.
        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
            config.num_labels = config.classifier_num_labels
        super().__init__(config)
        self.num_labels = config.num_labels
        self.task_level = config.task_level
        self.task_type = config.task_type
        assert self.task_level == "seq_level"
        self.classifier_pooling_type = config.classifier_pooling_type
        self.classifier_loss_type = config.classifier_loss_type
        self.classifier_loss_reduction = config.classifier_loss_reduction
        self.classifier_pos_weight = config.classifier_pos_weight
        self.classifier_weight = config.classifier_weight
        self.lucaone = LucaGPLMModel(config)
        # Learned poolers; "cls" and "mean" pooling are handled in forward().
        if self.classifier_pooling_type == "value_attention":
            self.pooler = LucaGPLMGlobalMaskValueAttentionPooling1D(config.hidden_size)
        elif self.classifier_pooling_type == "context_attention":
            self.pooler = LucaGPLMGlobalMaskContextAttentionPooling1D(embed_size=config.hidden_size)
        elif self.classifier_pooling_type == "weighted_attention":
            self.pooler = LucaGPLMGlobalMaskWeightedAttentionPooling1D(embed_size=config.hidden_size)
        else:
            self.pooler = None
        self.dropout = nn.Dropout(config.classifier_dropout_prob)

        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        if self.task_type == "multi_class":
            weight = None
            if self.classifier_weight:
                # A scalar (str/int/float) gives every class the same weight;
                # a list gives per-class weights.
                if isinstance(self.classifier_weight, (str, int)):
                    weight = torch.tensor([float(self.classifier_weight)] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_weight, float):
                    weight = torch.tensor([self.classifier_weight] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_weight, list):
                    weight = torch.tensor(self.classifier_weight, dtype=torch.float32)
            self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
        elif self.task_type == "binary_class":
            pos_weight = None
            if self.classifier_pos_weight:
                if isinstance(self.classifier_pos_weight, (str, int)):
                    pos_weight = torch.tensor([float(self.classifier_pos_weight)], dtype=torch.float32)
                elif isinstance(self.classifier_pos_weight, float):
                    pos_weight = torch.tensor([self.classifier_pos_weight], dtype=torch.float32)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
        elif self.task_type == "regression":
            if self.classifier_loss_type == "mae":
                self.loss_fct = nn.L1Loss(reduction="mean")
            else:
                self.loss_fct = nn.MSELoss(reduction="mean")
        elif self.task_type == "multi_label":
            pos_weight = None
            if self.classifier_pos_weight:
                if isinstance(self.classifier_pos_weight, (str, int)):
                    pos_weight = torch.tensor([float(self.classifier_pos_weight)] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_pos_weight, float):
                    pos_weight = torch.tensor([self.classifier_pos_weight] * self.num_labels, dtype=torch.float32)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
        else:
            raise ValueError("Invalid task type: %s" % self.task_type)
        self.post_init()

    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        """Mean of ``hidden_states`` over dim 1, counting only positions where ``attention_mask`` is non-zero.

        Falls back to a plain mean when ``attention_mask`` is None. Rows whose
        mask is all zero yield zeros instead of NaN.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        mask = attention_mask.to(hidden_states.dtype).unsqueeze(-1)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _compute_loss(self, logits, labels):
        """Apply ``self.loss_fct`` with the label reshaping / dtype that ``task_type`` requires."""
        if self.task_type == "multi_class":
            return self.loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1))
        if self.task_type == "binary_class":
            return self.loss_fct(logits.reshape(-1), labels.reshape(-1).float())
        if self.task_type == "regression":
            # L1/MSE losses require the target dtype to match the input dtype.
            return self.loss_fct(logits.reshape(-1), labels.reshape(-1).to(logits.dtype))
        if self.task_type == "multi_label":
            return self.loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1, self.num_labels).float())
        raise ValueError("Invalid task type: %s" % self.task_type)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        """Pool the backbone output, classify it and optionally compute the loss.

        Args:
            input_ids, token_type_ids, attention_mask: forwarded to ``LucaGPLMModel``;
                ``attention_mask`` is also used to exclude padding from "mean" pooling.
            labels: optional targets; layout depends on ``task_type``.
            return_dict: defaults to ``config.use_return_dict`` (True if absent).
            output_attentions, output_hidden_states: forwarded to the backbone.

        Returns:
            ``SequenceClassifierOutput``, or ``([loss,] logits, *backbone_extras)``
            when ``return_dict`` is False.

        Raises:
            ValueError: for an unknown pooling type, or an unknown task type when
                ``labels`` is given.
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        outputs = self.lucaone(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output)
        elif self.classifier_pooling_type == "cls":
            # First token of the sequence.
            pooled_output = sequence_output[:, 0, :]
        elif self.classifier_pooling_type == "mean":
            pooled_output = self._masked_mean(sequence_output, attention_mask)
        else:
            raise ValueError("Invalid classifier pooling type: %s" % self.classifier_pooling_type)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
class LucaGPLMForTokenClassification(LucaGPLMPreTrainedModel):
    """LucaGPLM backbone plus a per-token linear head for token-level tasks.

    ``task_level`` must be ``"token_level"``. ``task_type`` is one of
    ``multi_class``, ``binary_class``, ``regression`` or ``multi_label`` and
    selects the loss. The first and last positions of the backbone output are
    dropped before classification.
    """

    def __init__(self, config):
        # A positive ``classifier_num_labels`` overrides ``num_labels`` on the
        # config before the base class stores it.
        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
            config.num_labels = config.classifier_num_labels
        super().__init__(config)
        self.num_labels = config.num_labels
        self.task_level = config.task_level
        self.task_type = config.task_type
        assert self.task_level == "token_level"
        self.classifier_pooling_type = config.classifier_pooling_type
        self.classifier_loss_type = config.classifier_loss_type
        self.classifier_loss_reduction = config.classifier_loss_reduction
        self.classifier_pos_weight = config.classifier_pos_weight
        self.classifier_weight = config.classifier_weight
        self.lucaone = LucaGPLMModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        if self.task_type == "multi_class":
            weight = None
            if self.classifier_weight:
                # A scalar (str/int/float) gives every class the same weight;
                # a list gives per-class weights.
                if isinstance(self.classifier_weight, (str, int)):
                    weight = torch.tensor([float(self.classifier_weight)] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_weight, float):
                    weight = torch.tensor([self.classifier_weight] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_weight, list):
                    weight = torch.tensor(self.classifier_weight, dtype=torch.float32)
            self.loss_fct = nn.CrossEntropyLoss(weight=weight, reduction="mean")
        elif self.task_type == "binary_class":
            pos_weight = None
            if self.classifier_pos_weight:
                if isinstance(self.classifier_pos_weight, (str, int, float)):
                    pos_weight = torch.tensor([float(self.classifier_pos_weight)], dtype=torch.float32)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
        elif self.task_type == "regression":
            if self.classifier_loss_type == "mae":
                self.loss_fct = nn.L1Loss(reduction="mean")
            else:
                self.loss_fct = nn.MSELoss(reduction="mean")
        elif self.task_type == "multi_label":
            pos_weight = None
            if self.classifier_pos_weight:
                if isinstance(self.classifier_pos_weight, (str, int)):
                    pos_weight = torch.tensor([float(self.classifier_pos_weight)] * self.num_labels, dtype=torch.float32)
                elif isinstance(self.classifier_pos_weight, float):
                    pos_weight = torch.tensor([self.classifier_pos_weight] * self.num_labels, dtype=torch.float32)
            self.loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=self.classifier_loss_reduction)
        else:
            raise ValueError("Invalid task type: %s" % self.task_type)
        self.post_init()

    def _compute_loss(self, logits, labels):
        """Apply ``self.loss_fct`` with the label reshaping / dtype that ``task_type`` requires."""
        # ``reshape`` rather than ``view``: the logits come from a sliced,
        # possibly non-contiguous backbone output.
        if self.task_type == "multi_class":
            return self.loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1))
        if self.task_type == "binary_class":
            return self.loss_fct(logits.reshape(-1), labels.reshape(-1).float())
        if self.task_type == "regression":
            # L1/MSE losses require the target dtype to match the input dtype.
            return self.loss_fct(logits.reshape(-1), labels.reshape(-1).to(logits.dtype))
        if self.task_type == "multi_label":
            return self.loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1, self.num_labels).float())
        raise ValueError("Invalid task type: %s" % self.task_type)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        """Classify every interior position and optionally compute the loss.

        Args:
            input_ids, token_type_ids, attention_mask: forwarded to ``LucaGPLMModel``.
            labels: optional targets aligned with the positions that remain after
                dropping the first and last tokens (the layout depends on ``task_type``).
            return_dict: defaults to ``config.use_return_dict`` (True if absent).
            output_attentions, output_hidden_states: forwarded to the backbone.

        Returns:
            ``TokenClassifierOutput``, or ``([loss,] logits, *backbone_extras)``
            when ``return_dict`` is False.

        Raises:
            ValueError: for an unknown task type when ``labels`` is given.
        """
        return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', True)
        outputs = self.lucaone(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Drop the first and last positions (presumably the special start /
        # end tokens — TODO confirm against the tokenizer).
        sequence_output = outputs[0][:, 1:-1, :]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
# Public API exported by ``from modeling_lucaone import *``: the backbone,
# its pretrained base class, and the three task-specific heads.
__all__ = [
    "LucaGPLMModel",
    "LucaGPLMPreTrainedModel",
    "LucaGPLMForMaskedLM",
    "LucaGPLMForSequenceClassification",
    "LucaGPLMForTokenClassification",
]
|
|