| import sys |
| import os |
| from pathlib import Path |
|
|
| import torch |
| import yaml |
| |
| |
| from .yaml_util import MyLoader |
| from dataclasses import dataclass |
| from transformers import BertModel, BertConfig, PretrainedConfig |
| from typing import Optional, Union |
|
|
|
|
@dataclass
class FoundationOutput:
    """Container for FoundationBert forward-pass outputs and losses.

    Every field defaults to ``None`` so partially-populated outputs
    (e.g. inference without labels, or with some losses disabled) can be
    represented.  Annotations are ``Optional[...]`` to match the ``None``
    defaults (the originals claimed a bare ``torch.Tensor``).
    """
    loss: Optional[torch.Tensor] = None            # total (combined) loss
    logits: Optional[torch.Tensor] = None          # token-level logits
    num_output: Optional[torch.Tensor] = None      # numeric-head predictions
    est_err_output: Optional[torch.Tensor] = None  # estimated-error head output
    hidden_states: Optional[torch.Tensor] = None   # encoder hidden states
    masked_loss: Optional[torch.Tensor] = None     # MLM component of the loss
    num_loss: Optional[torch.Tensor] = None        # numeric-regression component
    est_err_loss: Optional[torch.Tensor] = None    # error-estimation component
|
|
|
|
@dataclass
class FoundationBertConfig:
    """Hyper-parameters for :class:`FoundationBert`.

    Mirrors the subset of ``transformers.BertConfig`` fields the model
    consumes, plus loss-selection flags, the contrastive temperature, and
    per-loss weights.
    """
    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    pad_token_id: int
    classifier_dropout: float
    max_position_embeddings: int
    contrastive_temperature: float
    loss_weights: dict
    use_xval_loss: bool = True
    use_mlm_loss: bool = True
    use_regression_loss: bool = False
    use_contrastive_loss: bool = False
    transform_numeric: bool = False

    def to_dict(self):
        """Return every field (including defaults) as a plain dict.

        Iterates ``__dataclass_fields__`` directly (a dict iterates its
        keys; the original's ``.keys()`` call was redundant).  This keeps
        declaration order and, unlike ``dataclasses.asdict``, does not
        deep-copy field values such as ``loss_weights``.
        """
        return {name: getattr(self, name) for name in self.__dataclass_fields__}
|
|
class FoundationBert(BertModel):
    """BERT encoder over raw numeric (non-tokenized) multi-modal inputs.

    Each modality's raw values are linearly embedded into the hidden
    dimension.  Vector modalities (see ``vector_keys``) each get their own
    embedding and numeric-prediction head, while all scalar modalities are
    collapsed into a single shared ``'scalars'`` head.  Masked positions are
    replaced by ``mask_token`` before embedding, and per-position numeric
    predictions are produced by small MLP heads (``num_head``).

    Note: ``forward`` feeds ``self.encoder`` directly, bypassing BertModel's
    own token/position embedding layers; position information comes from the
    locally-defined ``position_embeddings``.
    """

    def __init__(self,
                 config: FoundationBertConfig = None,
                 use_mlm_loss: bool = False,
                 use_regression_loss: bool = True,
                 use_contrastive_loss: bool = False,
                 use_xval_loss: bool = False,
                 transform_numeric: bool = False,
                 *args,
                 **kwargs):
        """Build the encoder plus per-modality embeddings and heads.

        Parameters
        ----------
        config : FoundationBertConfig
            Model hyper-parameters; also carries the loss-selection flags.
        use_mlm_loss, use_regression_loss, use_contrastive_loss, use_xval_loss : bool
            Fallback loss flags, used only when ``config`` lacks usable
            flags of its own (see the try/except below).
        transform_numeric : bool
            Stored onto ``config``; its semantics are defined by downstream
            consumers of the config, not by this class.
        **kwargs
            Must contain ``'modalities'`` (list of modality names) and
            ``'mask_token'``; may contain ``'dataset_path'``.
        """
        self.gconfig = config

        bert_conf = BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            pad_token_id=config.pad_token_id,
            max_position_embeddings=config.max_position_embeddings,
            _attn_implementation='sdpa'  # force scaled-dot-product attention
        )
        self.gconfig.transform_numeric = transform_numeric
        super().__init__(bert_conf,)
        # Prefer the loss flags carried on the config.  Fall back to the
        # constructor arguments when the config lacks the flags
        # (AttributeError) or has every loss disabled (the ValueError raised
        # just below).  The original used a bare ``except:`` which also
        # swallowed unrelated errors; it is narrowed here while keeping both
        # intended fallback paths.
        try:
            if not self.gconfig.use_mlm_loss and not self.gconfig.use_regression_loss and not self.gconfig.use_contrastive_loss:
                raise ValueError("At least one loss must be enabled")
            self.loss_mod = float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss) + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss)
        except (AttributeError, ValueError):
            self.gconfig.use_mlm_loss = use_mlm_loss
            self.gconfig.use_regression_loss = use_regression_loss
            self.gconfig.use_contrastive_loss = use_contrastive_loss
            self.gconfig.use_xval_loss = use_xval_loss
            self.loss_mod = float(self.gconfig.use_mlm_loss) + float(self.gconfig.use_regression_loss) + float(self.gconfig.use_contrastive_loss) + float(self.gconfig.use_xval_loss)

        self.dataset_path = kwargs.get('dataset_path', None)

        self.modalities = kwargs['modalities']
        self.mask_token = kwargs['mask_token']

        # Modalities not listed in vector_keys are collapsed into a single
        # shared 'scalars' entry; dict.fromkeys dedupes while preserving
        # first-appearance order.
        self.scalar_keys = [
            'redshift',
            'halo_mass',
            'stellar_mass',
        ]
        self.vector_keys = [
            'SED',
            'SFH',
            'mag_{band}_spherex',
            'mag_{band}_lsst',
        ]
        self.modalscalars = [m if m in self.vector_keys else 'scalars' for m in self.modalities]
        self.modalscalars = list(dict.fromkeys(self.modalscalars))

        # One scalar->hidden embedding and one numeric-regression MLP head
        # per (collapsed) modality.
        self.embedding = torch.nn.ModuleDict()
        self.num_head = torch.nn.ModuleDict()

        for modality in self.modalscalars:
            self.embedding[modality] = torch.nn.Linear(1, config.hidden_size)
            self.num_head[modality] = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.LayerNorm(config.hidden_size),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size, config.hidden_size // 2),
                torch.nn.GELU(),
                torch.nn.Linear(config.hidden_size // 2, 1)
            )

        # Learned absolute position embeddings applied to the concatenated
        # modality sequence in forward() (BertModel's own embedding layer is
        # bypassed there).
        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.embed_dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        # lm_head / xval_loss are built unconditionally even when the
        # corresponding losses are disabled.
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.xval_loss = torch.nn.MSELoss(reduction='none')

        self.distributed_loss = False

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
                        *model_args,
                        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
                        cache_dir: Optional[Union[str, os.PathLike]] = None,
                        ignore_mismatched_sizes: bool = False,
                        force_download: bool = False,
                        local_files_only: bool = False,
                        token: Optional[Union[str, bool]] = None,
                        revision: str = "main",
                        use_safetensors: Optional[bool] = None,
                        **kwargs,
                        ):
        """Load a checkpoint from the Hub or a local directory.

        The training-time configuration (``train_config.yaml``) is fetched
        first — from the Hub, falling back to a local path on any download
        failure — and used to rebuild the constructor kwargs before
        delegating to ``BertModel.from_pretrained``.

        Fixes vs. the original: first parameter renamed ``self`` -> ``cls``
        (this is a classmethod), and the explicit ``revision`` argument is
        now actually used — ``kwargs.get('revision', ...)`` could never see
        it because the named parameter captures it.

        NOTE(review): the other explicit parameters (``cache_dir``,
        ``force_download``, ``token``, ...) are accepted but not forwarded
        to the superclass call, matching the original behavior — confirm
        whether they should be passed through.
        """
        from huggingface_hub import hf_hub_download

        try:
            model_config = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="train_config.yaml",
                revision=revision
            )
        except Exception:
            # Not a Hub repo (or download failed): treat the path as a
            # local checkpoint directory.
            model_config = os.path.join(pretrained_model_name_or_path, "train_config.yaml")

        # NOTE(review): yaml.load with a project-defined loader on a
        # downloaded file — assumed trusted; do not point this at
        # untrusted repositories.
        with open(model_config, 'r') as f:
            config = yaml.load(f, Loader=MyLoader)

        # Constructor kwargs that FoundationBert.__init__ requires.
        kwargs['modalities'] = config['modalities']
        kwargs['dataset_path'] = config['dataset_path']
        kwargs['mask_token'] = config['mask_token']

        # 'model_config' is expanded into __init__ keyword arguments —
        # its schema is defined by the training script's YAML.
        return super().from_pretrained(
            pretrained_model_name_or_path,
            **config['model_config'],
            **kwargs
        )

    def pool_output(self,
                    embeddings: torch.Tensor,
                    attention_mask: torch.Tensor,
                    use_last: bool = False
                    ) -> torch.Tensor:
        """Average pool the hidden states using the attention mask.

        The first position and the position at ``seq_len - sl_mod`` (the
        last attended token when ``use_last``, else the one before it) are
        zeroed out before pooling — i.e. boundary/special positions are
        excluded from the sum.

        Parameters
        ----------
        embeddings : torch.Tensor
            The hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).
        use_last : bool
            If True exclude only the final attended position; if False also
            exclude the one before it.

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        sl_mod = 1 if use_last else 2
        seq_lengths = attention_mask.sum(dim=1)

        # Zero the boundary positions per batch row (advanced indexing with
        # the per-row seq_lengths vector).
        new_attention = attention_mask.clone()
        new_attention[:, 0] = attention_mask[:, 0] * 0
        new_attention[:, seq_lengths - sl_mod] = 0 * attention_mask[:, seq_lengths - sl_mod]

        pool_mask = new_attention.unsqueeze(-1).expand(embeddings.shape).to(embeddings.device)

        sum_embeds = torch.sum(embeddings * pool_mask, 1)

        # NOTE(review): the denominator is the FULL attended length, even
        # though two positions were zeroed out of the numerator — confirm
        # this slight under-weighting is intended.
        seq_lengths = torch.clamp(seq_lengths, min=1).unsqueeze(-1)

        return sum_embeds / seq_lengths

    def last_token_pool(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Pool the last hidden states using the attention mask.

        If every row's final position is attended (left padding), the last
        column is returned directly; otherwise the last attended position of
        each row is gathered.

        Parameters
        ----------
        embeddings : torch.Tensor
            The last hidden states to pool (B, SeqLen, HiddenDim).
        attention_mask : torch.Tensor
            The attention mask for the hidden states (B, SeqLen).

        Returns
        -------
        torch.Tensor
            The pooled embeddings (B, HiddenDim).
        """
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return embeddings[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = embeddings.shape[0]
            return embeddings[
                torch.arange(batch_size, device=embeddings.device),
                sequence_lengths,
            ]

    def forward(self, inputs, return_input_label_mapping=False):
        """Forward pass computing per-modality numeric predictions.

        Parameters
        ----------
        inputs : dict
            Must provide ``input_<modality>`` and ``labels_<modality>``
            tensors for every entry of ``self.modalscalars``; the labels
            tensor is used as a boolean mask of positions to replace with
            ``mask_token`` before embedding.
        return_input_label_mapping : bool
            If True, also return the per-modality input/label dict.

        Returns
        -------
        dict (or (dict, dict))
            ``{"<modality>_logits": tensor}`` per modality, sliced from the
            encoder output at each modality's position range.
        """
        input_label_mapping = {}
        combined = []
        for src_modality in self.modalscalars:
            input_label_mapping[src_modality] = {
                'input': inputs[f"input_{src_modality}"],
                'labels': inputs[f"labels_{src_modality}"]
            }

            # Replace masked positions with the mask token, then embed each
            # raw value (treated as a length-1 feature) into hidden_size.
            input_data = input_label_mapping[src_modality]['input']
            label = input_label_mapping[src_modality]['labels']
            input_data = torch.where(label, self.mask_token, input_data)

            x = self.embedding[src_modality](input_data.unsqueeze(-1))
            x = torch.nn.functional.silu(x)
            combined.append(x)

        # Concatenate modalities along the sequence axis, add learned
        # absolute position embeddings, and feed the encoder directly
        # (BertModel's own embedding layer is bypassed).
        combined = torch.cat(combined, dim=1)

        self.position_ids = torch.arange(combined.size(1)).unsqueeze(0).to(combined.device)
        combined += self.position_embeddings(self.position_ids)
        combined = self.embed_dropout(combined)

        x = self.encoder(combined, output_hidden_states=True).last_hidden_state

        # Slice the encoder output back into per-modality position ranges
        # (same order/lengths as the concatenation above).
        start = 0
        outputs = {}

        for tgt_modality in self.modalscalars:
            length = input_label_mapping[tgt_modality]['input'].shape[1]
            x_t = x[:, start:start+length, :]
            outputs[f"{tgt_modality}_logits"] = self.num_head[tgt_modality](x_t)

            start += length

        # Optional embedding dump for UMAP visualization; save_umap_for is
        # never set in __init__, so this is off unless set externally.
        # NOTE(review): x_t here is the LAST modality's slice only — confirm
        # whole-sequence pooling was not intended.
        if getattr(self, 'save_umap_for', None):
            pooled = x_t.mean(dim=1)
            self.save_pooled_embedding(pooled)

        return (outputs, input_label_mapping) if return_input_label_mapping else outputs

    def save_pooled_embedding(self, features):
        """Append pooled embeddings to the HDF5 file at ``self.save_umap_for``.

        Creates the file (and parent directories) with a resizable
        'features' dataset on first call; subsequent calls resize the
        dataset and append.
        """
        import h5py
        fname = Path(self.save_umap_for)
        fname.parent.mkdir(parents=True, exist_ok=True)

        features = features.detach().cpu().numpy()

        if fname.exists():
            # Grow the existing dataset and write the new rows at the end.
            with h5py.File(fname, 'r+') as f:
                old_size = f['features'].shape[0]
                new_size = old_size + features.shape[0]

                f['features'].resize((new_size, features.shape[-1]))
                f['features'][old_size:] = features

        else:
            # maxshape=(None, dim) makes the first axis resizable later.
            with h5py.File(fname, 'w') as f:
                f.create_dataset('features', data=features, maxshape=(None, features.shape[-1]), chunks=True)
|
|