|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| import torch.nn as nn
|
| from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
|
|
| from .modeling_base import PreTrainedModelWrapper
|
|
|
|
|
class ValueHead(nn.Module):
    r"""
    The ValueHead class implements a head that returns a scalar for each output token.

    It applies an (optional) dropout followed by a single ``nn.Linear(hidden_size, 1)``
    projection. The input width ``hidden_size`` is read from the model config:

    - ``config.hidden_size`` by default;
    - ``config.word_embed_proj_dim`` when present (some models such as OPT have a
      projection layer before the word embeddings, e.g. OPT-350m);
    - otherwise ``config.decoder.hidden_size`` for encoder-decoder configs that nest
      a decoder config.

    Args:
        config:
            The wrapped model's configuration object.
        summary_dropout_prob (`float`, `optional`, defaults to `0.1`):
            Dropout probability, used only when ``config`` has no ``summary_dropout_prob``
            attribute (the config value wins otherwise). A falsy value (``0`` / ``None``)
            disables dropout entirely (``nn.Identity``).

    Raises:
        ValueError: If no hidden size can be derived from ``config``.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        if not hasattr(config, "summary_dropout_prob"):
            summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1)
        else:
            summary_dropout_prob = config.summary_dropout_prob

        self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()

        hidden_size = None
        # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m
        if hasattr(config, "hidden_size"):
            hidden_size = config.hidden_size
        if hasattr(config, "word_embed_proj_dim"):
            hidden_size = config.word_embed_proj_dim
        elif hasattr(config, "is_encoder_decoder"):
            if config.is_encoder_decoder and hasattr(config, "decoder"):
                if hasattr(config.decoder, "hidden_size"):
                    hidden_size = config.decoder.hidden_size

        if hidden_size is None:
            # Previously this surfaced as an opaque UnboundLocalError.
            raise ValueError(
                "Could not determine the hidden size for the ValueHead: the config has no "
                "`hidden_size`, `word_embed_proj_dim`, or `decoder.hidden_size` attribute."
            )

        self.summary = nn.Linear(hidden_size, 1)

        # Not used in `forward`; kept so existing code referencing `v_head.flatten` keeps working.
        self.flatten = nn.Flatten()

    def forward(self, hidden_states):
        r"""
        Project hidden states to one scalar per position.

        Args:
            hidden_states (`torch.Tensor`):
                Tensor whose last dimension is ``hidden_size`` (presumably
                ``(batch, seq_len, hidden_size)`` — confirm against callers).

        Returns:
            `torch.Tensor`: Same leading dimensions as the input with a trailing size-1
            dimension, in the dtype of ``self.summary.weight``.
        """
        output = self.dropout(hidden_states)

        # Cast to the head's dtype (e.g. half-precision backbone feeding an fp32 head) so the
        # matmul in `self.summary` does not fail on mismatched dtypes.
        if output.dtype != self.summary.weight.dtype:
            output = output.to(self.summary.weight.dtype)

        output = self.summary(output)
        return output
|
|
|
|
|
class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    r"""
    An autoregressive model with a value head in addition to the language model head.
    This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
    `transformers.PreTrainedModel` class. The wrapper class supports classic functions
    such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped
    model, simply manipulate the `pretrained_model` attribute of this class.

    Class attributes:
        - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
            should be set to `transformers.AutoModelForCausalLM` for this class.
        - **lm_head_namings** (`list`) -- A list of attribute names used to identify the language model head of the
            wrapped model. This is set to `["lm_head", "embed_out"]` for this class but can be changed for other models
            in the future.
        - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
            by the `ValueHead` class. Currently, the supported args are:
            - **summary_dropout_prob** (`float`, `optional`, defaults to `0.1`) -- The dropout probability for the
                `ValueHead` class. Ignored if the wrapped model's config already defines `summary_dropout_prob`.
            - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the
                `ValueHead` if a specific initialization strategy is selected.
            - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the
                `ValueHead`. Currently, the supported strategies are:
                - **`None`** -- Keeps the default (random) initialization of the `ValueHead`. This is the default
                    strategy.
                - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution and the bias
                    with zeros.
    """

    transformers_parent_class = AutoModelForCausalLM
    lm_head_namings = ["lm_head", "embed_out"]
    supported_args = (
        "summary_dropout_prob",
        "v_head_initializer_range",
        "v_head_init_strategy",
    )

    def __init__(self, pretrained_model, **kwargs):
        r"""
        Initializes the model.

        Args:
            pretrained_model (`transformers.PreTrainedModel`):
                The model to wrap. It should be a causal language model such as GPT2,
                or any model mapped inside the `AutoModelForCausalLM` class.
            kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the `ValueHead` class.

        Raises:
            ValueError: If the wrapped model exposes none of the attributes in `lm_head_namings`.
        """
        super().__init__(pretrained_model, **kwargs)
        # Only the ValueHead-related kwargs (see `supported_args`) are used here.
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)

        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
            raise ValueError("The model does not have a language model head, please use a model that has one.")

        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

        self._init_weights(**v_head_kwargs)

    def _init_weights(self, **kwargs):
        r"""
        Initializes the weights of the value head. The default initialization strategy is random.
        Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument
        when calling `.from_pretrained`. Supported strategies are:
            - `normal`: initializes the weights with a normal distribution and the bias with zeros.
        Any other value is currently ignored and the default initialization is kept.

        Args:
            **kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the `ValueHead` class. These arguments
                can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range`
                argument.
        """
        initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
        # random init by default
        init_strategy = kwargs.pop("v_head_init_strategy", None)
        if init_strategy is None:
            # Keep the default nn.Linear initialization.
            pass
        elif init_strategy == "normal":
            self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
            self.v_head.summary.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        r"""
        Applies a forward pass to the wrapped model and returns the logits of the value head.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
                Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
                (see `past_key_values` input) to speed up sequential decoding.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
                Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the wrapped model.

        Returns:
            `tuple`: ``(lm_logits, loss, value)`` where ``lm_logits`` are the wrapped model's logits cast to
            float32, ``loss`` is the wrapped model's ``loss`` output as-is, and ``value`` is the value-head
            output for the last hidden layer with its trailing size-1 dimension squeezed.
        """
        # The value head needs the last hidden layer, so hidden states are always requested.
        kwargs["output_hidden_states"] = True
        kwargs["past_key_values"] = past_key_values

        # With PEFT prefix tuning the PEFT model supplies its own past_key_values (the prefix),
        # so ours must not be forwarded.
        if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
            kwargs.pop("past_key_values")

        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        last_hidden_state = base_model_output.hidden_states[-1]
        lm_logits = base_model_output.logits
        loss = base_model_output.loss

        # When the model is sharded across devices, the last hidden state may live on a different
        # device than the value head.
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)

        value = self.v_head(last_hidden_state).squeeze(-1)

        # force upcast in fp32 if logits are in half-precision
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        return (lm_logits, loss, value)

    def generate(self, *args, **kwargs):
        r"""
        A simple wrapper around the `generate` method of the wrapped model.
        Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils)
        method of the wrapped model for more information about the supported arguments.

        Args:
            *args (`list`, *optional*):
                Positional arguments passed to the `generate` method of the wrapped model.
            **kwargs (`dict`, *optional*):
                Keyword arguments passed to the `generate` method of the wrapped model.
        """
        return self.pretrained_model.generate(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        r"""
        Returns the state dictionary of the model. We add the state dictionary of the value head
        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
        For PEFT models only the value-head entries are returned.
        """
        if not self.is_peft_model:
            pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
        else:
            # if it is a peft model, only save the v_head
            pretrained_model_state_dict = {}

        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
        for k, v in v_head_state_dict.items():
            pretrained_model_state_dict[f"v_head.{k}"] = v
        return pretrained_model_state_dict

    def push_to_hub(self, *args, **kwargs):
        r"""
        Pushes the wrapped model to the Hub. The value head is attached to the wrapped model as
        its `v_head` attribute first so that its weights are included in the upload.
        """
        setattr(self.pretrained_model, "v_head", self.v_head)

        return self.pretrained_model.push_to_hub(*args, **kwargs)

    def post_init(self, state_dict):
        r"""
        We add the state dictionary of the value head to the state dictionary of the wrapped model
        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
        keys of the value head state dictionary and loads them into the value head.

        If the wrapped model was dispatched with a device map (`hf_device_map`), the value head is
        moved to the first device listed in that map and a forward hook moves every tensor returned
        by `forward` to that device as well.

        Raises:
            ValueError: If any part of the model is offloaded to CPU or disk.
        """
        for k in list(state_dict.keys()):
            if "v_head." in k:
                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
        # strict=False: the dict also contains the wrapped model's (un-prefixed) keys.
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

        if hasattr(self.pretrained_model, "hf_device_map"):
            device_map = self.pretrained_model.hf_device_map
            if "cpu" in device_map.values() or "disk" in device_map.values():
                raise ValueError("The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models.")

            # Take the first device in map order. The previous `list(set(...))[0]` depended on set
            # iteration order and could pick a different device from run to run.
            # NOTE(review): assumes hf_device_map is an insertion-ordered dict.
            first_device = list(device_map.values())[0]

            self.v_head = self.v_head.to(first_device)

            def set_device_hook(module, input, outputs):
                # Move every tensor in the forward output tuple to `first_device`;
                # non-tensor entries (e.g. a `None` loss) are passed through unchanged.
                new_output = ()
                for output in outputs:
                    if isinstance(output, torch.Tensor):
                        new_output += (output.to(first_device),)
                    else:
                        new_output += (output,)
                return new_output

            self.register_forward_hook(set_device_hook)

            self.is_sequential_parallel = True
|
|
|
|
|
class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
    r"""
    A seq2seq model with a value head in addition to the language model head.
    This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
    `transformers.PreTrainedModel` class. The wrapper class supports classic functions
    such as `from_pretrained` and `push_to_hub` and also provides some additional
    functionalities such as `generate`.

    The value head is applied to the last decoder hidden state.

    Args:
        pretrained_model (`transformers.PreTrainedModel`):
            The model to wrap. It should be a sequence-to-sequence model such as T5,
            or any model mapped inside the `AutoModelForSeq2SeqLM` class.
        kwargs:
            Additional keyword arguments passed along to the `ValueHead` class.
    """

    transformers_parent_class = AutoModelForSeq2SeqLM
    lm_head_namings = ["lm_head", "embed_out", "output_projection"]
    supported_args = (
        "summary_dropout_prob",
        "v_head_initializer_range",
        "v_head_init_strategy",
    )

    def __init__(self, pretrained_model, **kwargs):
        r"""
        Initializes the model.

        Raises:
            ValueError: If no module of the wrapped model has a name containing one of `lm_head_namings`.
        """
        super().__init__(pretrained_model, **kwargs)
        # Only the ValueHead-related kwargs (see `supported_args`) are used here.
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)
        self.is_encoder_decoder = True

        if not self._has_lm_head():
            raise ValueError("The model does not have a language model head, please use a model that has one.")

        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

        self._init_weights(**v_head_kwargs)

    def _has_lm_head(self):
        r"""
        Returns True if any submodule name of the wrapped model contains one of `lm_head_namings`.
        Module names are searched (not only top-level attributes) because the head may be nested.
        """
        return any(
            any(attribute in name for attribute in self.lm_head_namings)
            for name, _module in self.pretrained_model.named_modules()
        )

    def post_init(self, state_dict):
        r"""
        We add the state dictionary of the value head to the state dictionary of the wrapped model
        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
        keys of the value head state dictionary and loads them into the value head.

        If the wrapped model was dispatched with a device map (`hf_device_map`), the value head is
        moved to the device of the language model head and a forward hook moves every tensor
        returned by `forward` to that device as well.

        Raises:
            ValueError: If any part of the model is offloaded to CPU or disk, or if no
                language model head with a weight tensor can be located.
        """
        for k in list(state_dict.keys()):
            if "v_head." in k:
                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
        # strict=False: the dict also contains the wrapped model's (un-prefixed) keys.
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

        if hasattr(self.pretrained_model, "hf_device_map"):
            device_map = self.pretrained_model.hf_device_map
            if "cpu" in device_map.values() or "disk" in device_map.values():
                raise ValueError("The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models.")

            # get the lm_head device: first name-matched module that actually carries a weight tensor
            lm_head_device = None
            for name, module in self.pretrained_model.named_modules():
                if not any(attribute in name for attribute in self.lm_head_namings):
                    continue
                weight = getattr(module, "weight", None)
                if isinstance(weight, torch.Tensor):
                    lm_head_device = weight.device
                    break

            if lm_head_device is None:
                raise ValueError("Could not find the device of the language model head of the wrapped model.")

            # put v_head on the same device as the lm_head to avoid issues
            self.v_head = self.v_head.to(lm_head_device)

            def set_device_hook(module, input, outputs):
                r"""
                A hook that moves every tensor in the model's output tuple to the device of the
                language model head; non-tensor entries (e.g. a `None` loss) pass through unchanged.

                Args:
                    module (`nn.Module`):
                        The module to which the hook is attached.
                    input (`tuple`):
                        The input to the module.
                    outputs (`tuple`):
                        The output of the module.
                """
                new_output = ()
                for output in outputs:
                    if isinstance(output, torch.Tensor):
                        new_output += (output.to(lm_head_device),)
                    else:
                        new_output += (output,)
                return new_output

            self.register_forward_hook(set_device_hook)
            self.is_sequential_parallel = True

    def state_dict(self, *args, **kwargs):
        r"""
        Returns the state dictionary of the model. We add the state dictionary of the value head
        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
        For PEFT models only the value-head entries are returned.
        """
        if not self.is_peft_model:
            pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
        else:
            # if it is a peft model, only save the v_head
            pretrained_model_state_dict = {}

        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
        for k, v in v_head_state_dict.items():
            pretrained_model_state_dict[f"v_head.{k}"] = v
        return pretrained_model_state_dict

    def push_to_hub(self, *args, **kwargs):
        r"""
        Pushes the wrapped model to the Hub. The value head is attached to the wrapped model as
        its `v_head` attribute first so that its weights are included in the upload.
        """
        setattr(self.pretrained_model, "v_head", self.v_head)

        return self.pretrained_model.push_to_hub(*args, **kwargs)

    def _init_weights(self, **kwargs):
        r"""
        Initializes the weights of the value head.

        `v_head_init_strategy=None` (default) keeps the default (random) initialization;
        `"normal"` draws the weights from a normal distribution with std
        `v_head_initializer_range` (default `0.2`) and zeroes the bias. Any other value is
        currently ignored.
        """
        initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
        # random init by default
        init_strategy = kwargs.pop("v_head_init_strategy", None)
        if init_strategy is None:
            # Keep the default nn.Linear initialization.
            pass
        elif init_strategy == "normal":
            self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
            self.v_head.summary.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        r"""
        Applies a forward pass to the wrapped model and returns the logits of the value head.

        Args:
            input_ids (`torch.LongTensor`, `optional`):
                Indices of input sequence tokens in the vocabulary.
            past_key_values (`optional`):
                Cached key/value states forwarded to the wrapped model (dropped for PEFT prefix tuning).
            attention_mask (`torch.Tensor`, `optional`):
                Mask to avoid performing attention on padding token indices.
            kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the wrapped model.

        Returns:
            `tuple`: ``(lm_logits, loss, value)`` where ``lm_logits`` are the wrapped model's logits cast to
            float32, ``loss`` is the wrapped model's ``loss`` output as-is, and ``value`` is the value-head
            output for the last decoder hidden state with its trailing size-1 dimension squeezed.
        """
        # The value head needs the decoder hidden states. Set this in kwargs rather than as a separate
        # keyword so a caller-supplied `output_hidden_states` cannot cause a "multiple values" TypeError.
        kwargs["output_hidden_states"] = True
        kwargs["past_key_values"] = past_key_values
        # With PEFT prefix tuning the PEFT model supplies its own past_key_values (the prefix).
        if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
            kwargs.pop("past_key_values")

        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        last_hidden_state = base_model_output.decoder_hidden_states[-1]
        lm_logits = base_model_output.logits
        loss = base_model_output.loss

        # When the model is sharded across devices, the last decoder hidden state may live on a
        # different device than the value head (same handling as the causal variant).
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)

        value = self.v_head(last_hidden_state).squeeze(-1)

        # force upcast in fp32 if logits are in half-precision
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        return (lm_logits, loss, value)

    def generate(self, *args, **kwargs):
        r"""
        We call `generate` on the wrapped model, forwarding all positional and keyword arguments.
        """
        return self.pretrained_model.generate(*args, **kwargs)
|
|
|