| """ |
| SPLADE Model for HuggingFace Hub |
| Adapted from: https://github.com/naver/splade |
| """ |
|
|
| import torch |
| from transformers import AutoModelForMaskedLM, PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import BaseModelOutput |
|
|
|
|
class SpladeConfig(PretrainedConfig):
    """Configuration for the SPLADE sparse-retrieval model.

    Args:
        base_model: Hub id of the masked-LM backbone to load.
        aggregation: Pooling applied over the sequence axis ("max" or "sum").
        fp16: Whether half precision is intended at inference time.
    """

    model_type = "splade"

    def __init__(self,
                 base_model="neuralmind/bert-base-portuguese-cased",
                 aggregation="max",
                 fp16=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.aggregation = aggregation
        self.fp16 = fp16
|
|
|
|
class Splade(PreTrainedModel):
    """
    SPLADE model for sparse retrieval.

    This model produces sparse representations by:
    1. Using a MLM head to get vocabulary-sized logits
    2. Applying log(1 + ReLU(logits))
    3. Max-pooling over sequence length

    Usage:
        from transformers import AutoTokenizer
        from modeling_splade import Splade

        model = Splade.from_pretrained("AxelPCG/splade-pt-br")
        tokenizer = AutoTokenizer.from_pretrained("AxelPCG/splade-pt-br")

        # Encode query
        query_tokens = tokenizer("Qual é a capital do Brasil?", return_tensors="pt")
        with torch.no_grad():
            query_vec = model(q_kwargs=query_tokens)["q_rep"]
    """
    config_class = SpladeConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # getattr fallbacks keep older configs (missing these fields) loadable.
        backbone = getattr(config, 'base_model', 'neuralmind/bert-base-portuguese-cased')
        self.transformer = AutoModelForMaskedLM.from_pretrained(backbone)
        self.aggregation = getattr(config, 'aggregation', 'max')
        self.fp16 = getattr(config, 'fp16', True)

    def encode(self, tokens):
        """Map tokenized input to a (batch, vocab_size) sparse representation.

        Args:
            tokens: dict with at least 'input_ids' and 'attention_mask'.

        Returns:
            Tensor of non-negative activations, pooled over the sequence axis.
        """
        logits = self.transformer(**tokens).logits

        # SPLADE activation: log-saturated ReLU keeps representations sparse.
        activations = torch.relu(logits).log1p()

        # Zero out padding positions so they cannot influence pooling.
        mask = tokens["attention_mask"].unsqueeze(-1)
        activations = activations * mask

        if self.aggregation == "max":
            return activations.max(dim=1).values
        return activations.sum(dim=1)

    def forward(self, q_kwargs=None, d_kwargs=None, **kwargs):
        """
        Forward pass supporting both query and document encoding.

        Args:
            q_kwargs: Query tokens (dict with input_ids, attention_mask)
            d_kwargs: Document tokens (dict with input_ids, attention_mask)
            **kwargs: Additional arguments (for compatibility)

        Returns:
            dict with 'q_rep' and/or 'd_rep' keys containing sparse vectors
        """
        reps = {}

        if q_kwargs is not None:
            reps["q_rep"] = self.encode(q_kwargs)

        if d_kwargs is not None:
            reps["d_rep"] = self.encode(d_kwargs)

        # Compatibility path: tokens passed directly as keyword arguments.
        if not reps and kwargs:
            reps["rep"] = self.encode(kwargs)

        return reps
|
|