from typing import List, Optional, Tuple, Union, Callable, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from configuration_llama3_SAE import LLama3_SAE_Config
except ImportError:
    from .configuration_llama3_SAE import LLama3_SAE_Config

from transformers import (
    LlamaPreTrainedModel,
    LlamaModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LLama3_SAE(LlamaPreTrainedModel):
    config_class = LLama3_SAE_Config
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: LLama3_SAE_Config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # `activation_k` may arrive as a plain int or as a tensor; the TopK
        # variants all share the same handling.
        k = (
            torch.tensor(config.activation_k)
            if isinstance(config.activation_k, int)
            else config.activation_k
        )
        if config.activation == "topk":
            activation = TopK(k)
        elif config.activation == "topk-tanh":
            activation = TopK(k, nn.Tanh())
        elif config.activation == "topk-sigmoid":
            activation = TopK(k, nn.Sigmoid())
        elif config.activation == "jumprelu":
            activation = JumpReLu()
        elif config.activation == "relu":
            activation = nn.ReLU()
        elif config.activation == "identity":
            activation = nn.Identity()
        else:
            raise NotImplementedError(
                f"Activation '{config.activation}' not implemented."
            )

        self.SAE = Autoencoder(
            n_inputs=config.n_inputs,
            n_latents=config.n_latents,
            activation=activation,
            tied=False,
            normalize=True,
        )

        self.hook = HookedTransformer_with_SAE_suppresion(
            block=config.hook_block_num,
            sae=self.SAE,
            mod_features=config.mod_features,
            mod_threshold=config.mod_threshold,
            mod_replacement=config.mod_replacement,
            mod_scaling=config.mod_scaling,
        ).register_with(self.model, config.site)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(
                self.vocab_size // self.config.pretraining_tp, dim=0
            )
            logits = [
                F.linear(hidden_states, lm_head_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Per-token loss, averaged per sequence over the tokens that
            # actually contribute: positions labelled -100 yield a loss of
            # exactly 0 and are excluded from the denominator via the mask.
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            loss = loss.view(logits.size(0), -1)
            mask = loss != 0
            loss = loss.sum(dim=-1) / mask.sum(dim=-1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
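
    # Worked example of the per-sequence averaging above (a sketch): with two
    # sequences and per-token losses
    #     [[2.0, 4.0, 0.0],   # 0.0 at a position labelled -100
    #      [3.0, 0.0, 0.0]]
    # the mask keeps the non-zero entries, so the per-sequence means are
    # [3.0, 3.0].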
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = (
                    cache_position[0]
                    if cache_position is not None
                    else past_key_values.get_seq_length()
                )
                max_cache_length = (
                    torch.tensor(
                        past_key_values.get_max_length(), device=input_ids.device
                    )
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = (
                    past_length
                    if max_cache_length is None
                    else torch.min(max_cache_length, past_length)
                )
            # Legacy cache format: a tuple of (key, value) tensors per layer.
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the attention_mask is longer than input_ids, some inputs
            # were passed exclusively as part of the cache (e.g. when passing
            # inputs_embeds as input).
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If past_length is smaller than input_ids' length, input_ids
            # holds all input tokens; discard the ones already in the cache.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), assume
            # input_ids only has unprocessed tokens.

            # If we are about to exceed the maximum cache length, crop the
            # input attention mask accordingly.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # If `inputs_embeds` are passed, only use them in the first generation step.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # `contiguous()` keeps a static stride during decoding, which
            # avoids torchdynamo graph recompilations.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = (
            position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        )
        if cache_position is None:
            cache_position = torch.arange(
                past_length, past_length + input_length, device=input_ids.device
            )
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past


def LN(
    x: torch.Tensor, eps: float = 1e-5
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """LayerNorm-style normalization that also returns the statistics."""
    mu = x.mean(dim=-1, keepdim=True)
    x = x - mu
    std = x.std(dim=-1, keepdim=True)
    x = x / (std + eps)
    return x, mu, std
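
# Example (a minimal sketch): LN returns the normalized tensor together with
# the per-sample statistics so the normalization can be undone later, which is
# what `Autoencoder.decode` does via `ret * std + mu`:
#
#     x = torch.randn(4, 4096)
#     x_norm, mu, std = LN(x)
#     x_back = x_norm * (std + 1e-5) + mu  # recovers x; decode omits the eps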


class Autoencoder(nn.Module):
    """Sparse autoencoder

    Implements:
        latents = activation(encoder(x - pre_bias) + latent_bias)
        recons = decoder(latents) + pre_bias
    """

    def __init__(
        self,
        n_latents: int,
        n_inputs: int,
        activation: Callable = nn.ReLU(),
        tied: bool = False,
        normalize: bool = False,
    ) -> None:
        """
        :param n_latents: dimension of the autoencoder latent
        :param n_inputs: dimensionality of the original data (e.g. residual stream, number of MLP hidden units)
        :param activation: activation function
        :param tied: whether to tie the encoder and decoder weights
        :param normalize: whether to normalize inputs (the statistics are kept so `decode` can invert it)
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.n_latents = n_latents

        self.pre_bias = nn.Parameter(torch.zeros(n_inputs))
        self.encoder: nn.Module = nn.Linear(n_inputs, n_latents, bias=False)
        self.latent_bias = nn.Parameter(torch.zeros(n_latents))
        self.activation = activation

        if isinstance(activation, JumpReLu):
            # The threshold is stored in log-space; the forward pass applies
            # `torch.exp` so the effective threshold stays positive.
            self.threshold = nn.Parameter(torch.empty(n_latents))
            torch.nn.init.constant_(self.threshold, 0.001)
            self.forward = self.forward_jumprelu
        elif isinstance(activation, TopK):
            self.forward = self.forward_topk
        else:
            logger.warning(
                f"Using the TopK forward function even though activation is not TopK but {activation}"
            )
            self.forward = self.forward_topk

        if tied:
            # Not a true weight tie: the decoder is initialized as the
            # transpose of the encoder and the weights evolve independently.
            self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
            self.decoder.weight.data = self.encoder.weight.data.T.clone()
        else:
            self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
        self.normalize = normalize

    def encode_pre_act(
        self, x: torch.Tensor, latent_slice: slice = slice(None)
    ) -> torch.Tensor:
        """
        :param x: input data (shape: [batch, n_inputs])
        :param latent_slice: slice of latents to compute
            Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
        :return: autoencoder latents before activation (shape: [batch, n_latents])
        """
        x = x - self.pre_bias
        latents_pre_act = F.linear(
            x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
        )
        return latents_pre_act

    def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        if not self.normalize:
            return x, dict()
        x, mu, std = LN(x)
        return x, dict(mu=mu, std=std)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents (shape: [batch, n_latents])
        """
        x, info = self.preprocess(x)
        return self.activation(self.encode_pre_act(x)), info

    def decode(
        self, latents: torch.Tensor, info: dict[str, Any] | None = None
    ) -> torch.Tensor:
        """
        :param latents: autoencoder latents (shape: [batch, n_latents])
        :return: reconstructed data (shape: [batch, n_inputs])
        """
        ret = self.decoder(latents) + self.pre_bias
        if self.normalize:
            assert info is not None
            ret = ret * info["std"] + info["mu"]
        return ret

    def forward_topk(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents pre activation (shape: [batch, n_latents])
                 autoencoder latents (shape: [batch, n_latents])
                 reconstructed data (shape: [batch, n_inputs])
        """
        x, info = self.preprocess(x)
        latents_pre_act = self.encode_pre_act(x)
        latents = self.activation(latents_pre_act)
        recons = self.decode(latents, info)

        return latents_pre_act, latents, recons

    def forward_jumprelu(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents pre activation (shape: [batch, n_latents])
                 autoencoder latents (shape: [batch, n_latents])
                 reconstructed data (shape: [batch, n_inputs])
        """
        x, info = self.preprocess(x)
        latents_pre_act = self.encode_pre_act(x)
        latents = self.activation(F.relu(latents_pre_act), torch.exp(self.threshold))
        recons = self.decode(latents, info)

        return latents_pre_act, latents, recons
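
# Usage sketch (the dimensions and k are illustrative assumptions, not model
# defaults):
#
#     sae = Autoencoder(n_latents=8192, n_inputs=4096, activation=TopK(32))
#     x = torch.randn(2, 4096)                   # a batch of activations
#     latents_pre_act, latents, recons = sae(x)  # dispatches to forward_topk
#     assert recons.shape == x.shape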


class TiedTranspose(nn.Module):
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self.linear.bias is None
        return F.linear(x, self.linear.weight.t(), None)

    @property
    def weight(self) -> torch.Tensor:
        return self.linear.weight.t()

    @property
    def bias(self) -> torch.Tensor:
        return self.linear.bias


class TopK(nn.Module):
    def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
        super().__init__()
        self.k = k
        self.postact_fn = postact_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        topk = torch.topk(x, k=self.k, dim=-1)
        values = self.postact_fn(topk.values)
        result = torch.zeros_like(x)
        result.scatter_(-1, topk.indices, values)
        return result
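
# Worked example (a minimal sketch): TopK keeps the k largest pre-activations
# per row, applies `postact_fn` to them, and zeroes everything else:
#
#     topk = TopK(2)
#     topk(torch.tensor([[1.0, -3.0, 4.0, 0.5]]))  # -> tensor([[1., 0., 4., 0.]])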


class JumpReLu(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, threshold):
        return JumpReLUFunction.apply(input, threshold)


class HeavyStep(nn.Module):
    """Heaviside step function with a straight-through gradient estimate."""

    def __init__(self):
        super().__init__()

    def forward(self, input, threshold):
        return HeavyStepFunction.apply(input, threshold)


def rectangle(x):
    return (x > -0.5) & (x < 0.5)
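
# The rectangle window is the straight-through surrogate used in the backward
# passes below: it stands in for the Dirac delta of a hard threshold, so a
# gradient reaches the threshold only within ±0.5 of the (scaled) jump:
#
#     rectangle(torch.tensor([-0.6, 0.0, 0.4]))  # -> tensor([False, True, True])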


class JumpReLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(input, threshold):
        output = input * (input > threshold)
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, threshold = inputs
        ctx.save_for_backward(input, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        bandwidth = 0.001
        input, threshold = ctx.saved_tensors

        # Straight-through estimator: the input gradient passes through the
        # active (input > threshold) entries; the threshold only receives a
        # gradient inside a `bandwidth`-wide window around the jump.
        grad_input = (input > threshold) * grad_output
        grad_threshold = (
            -(threshold / bandwidth)
            * rectangle((input - threshold) / bandwidth)
            * grad_output
        )

        return grad_input, grad_threshold


class HeavyStepFunction(torch.autograd.Function):
    @staticmethod
    def forward(input, threshold):
        # Heaviside step: 1 where the input clears the threshold, else 0.
        output = (input > threshold).to(input.dtype)
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, threshold = inputs
        ctx.save_for_backward(input, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        bandwidth = 0.001
        input, threshold = ctx.saved_tensors

        # The step is flat almost everywhere, so the input gets no gradient;
        # the threshold gradient again uses the rectangle window.
        grad_input = torch.zeros_like(input)
        grad_threshold = (
            -(1.0 / bandwidth)
            * rectangle((input - threshold) / bandwidth)
            * grad_output
        )

        return grad_input, grad_threshold


ACTIVATIONS_CLASSES = {
    "ReLU": nn.ReLU,
    "Identity": nn.Identity,
    "TopK": TopK,
    "JumpReLU": JumpReLu,
}
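
# Note: `LLama3_SAE.__init__` instantiates the activation modules directly;
# this mapping offers an equivalent string-based resolution, e.g.:
#
#     activation = ACTIVATIONS_CLASSES["ReLU"]()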


class HookedTransformer_with_SAE:
    """Auxiliary class used to extract MLP activations from transformer models."""

    def __init__(self, block: int, sae) -> None:
        self.block = block
        self.sae = sae

        # Handle returned by `register_forward_hook`; call `.remove()` to detach.
        self.remove_handle = None

        self._features = None

    def register_with(self, model):
        self.remove_handle = model.layers[self.block].mlp.register_forward_hook(self)
        return self

    def pop(self) -> torch.Tensor:
        """Remove and return the extracted features from this hook.

        The features are only accessible this way so that no lingering
        references to them are kept around.
        """
        assert self._features is not None, "Feature extractor was not called yet!"
        features = self._features
        self._features = None
        return features

    def __call__(self, module, inp, outp) -> torch.Tensor:
        # Store the raw MLP output and replace it with the SAE reconstruction.
        self._features = outp
        return self.sae(outp)[2]
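
# Usage sketch (assuming `model` is a `LlamaModel` whose decoder layers live
# in `model.layers`, and `sae` is an `Autoencoder`):
#
#     hook = HookedTransformer_with_SAE(block=12, sae=sae).register_with(model)
#     _ = model(input_ids)         # the forward pass triggers the hook
#     feats = hook.pop()           # raw MLP output captured at block 12
#     hook.remove_handle.remove()  # detach the hook when done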


class HookedTransformer_with_SAE_suppresion:
    """Auxiliary class used to extract and modify MLP activations in transformer models."""

    def __init__(
        self,
        block: int,
        sae: Autoencoder,
        mod_features: Optional[Union[int, list]] = None,
        mod_threshold: Optional[Union[float, list]] = None,
        mod_replacement: Optional[Union[float, list]] = None,
        mod_scaling: Optional[Union[float, list]] = None,
        mod_balance: bool = False,
        multi_feature: bool = False,
    ) -> None:
        self.block = block
        self.sae = sae

        # Handle returned by `register_forward_hook`; call `.remove()` to detach.
        self.remove_handle = None

        self._features = None
        self.mod_features = mod_features
        self.mod_threshold = mod_threshold
        self.mod_replacement = mod_replacement
        self.mod_scaling = mod_scaling
        self.mod_balance = mod_balance
        self.mod_vector = None
        self.mod_vec_factor = 1.0

        if multi_feature:
            self.modify = self.modify_list
        else:
            self.modify = self.modify_single

        if isinstance(self.sae.activation, JumpReLu):
            logger.info("Setting __call__ function for JumpReLU.")
            setattr(self, "call", self.__call__jumprelu)
        elif isinstance(self.sae.activation, TopK):
            logger.info("Setting __call__ function for TopK.")
            setattr(self, "call", self.__call__topk)
        else:
            logger.warning(
                f"Using the TopK forward function even though activation is not TopK but {self.sae.activation}"
            )
            setattr(self, "call", self.__call__topk)

    def register_with(self, model, site="mlp"):
        self.site = site
        if site == "mlp":
            self.remove_handle = model.layers[self.block].mlp.register_forward_hook(
                self
            )
        elif site == "block":
            self.remove_handle = model.layers[self.block].register_forward_hook(self)
        elif site == "attention":
            raise NotImplementedError
        else:
            raise NotImplementedError

        return self

    def modify_list(self, latents: torch.Tensor) -> torch.Tensor:
        # Each modified feature gets its own threshold and replacement or
        # scaling value.
        if self.mod_replacement is not None:
            for feat, thresh, mod in zip(
                self.mod_features, self.mod_threshold, self.mod_replacement
            ):
                latents[:, :, feat][latents[:, :, feat] > thresh] = mod
        elif self.mod_scaling is not None:
            for feat, thresh, mod in zip(
                self.mod_features, self.mod_threshold, self.mod_scaling
            ):
                latents[:, :, feat][latents[:, :, feat] > thresh] *= mod
        elif self.mod_vector is not None:
            latents = latents + self.mod_vec_factor * self.mod_vector
        else:
            pass

        return latents

    def modify_single(self, latents: torch.Tensor) -> torch.Tensor:
        old_cond_feats = latents[:, :, self.mod_features]
        if self.mod_replacement is not None:
            # Hard-set the modified feature to the replacement value.
            latents[:, :, self.mod_features] = self.mod_replacement
        elif self.mod_scaling is not None:
            # Scale positive activations by `mod_scaling`; negative ones are
            # sign-flipped and scaled by the same factor.
            latents_scaled = latents.clone()
            latents_scaled[:, :, self.mod_features][
                latents[:, :, self.mod_features] > 0
            ] *= self.mod_scaling
            latents_scaled[:, :, self.mod_features][
                latents[:, :, self.mod_features] < 0
            ] *= -1 * self.mod_scaling
            latents = latents_scaled
        elif self.mod_vector is not None:
            latents = latents + self.mod_vec_factor * self.mod_vector
        else:
            pass

        if self.mod_balance:
            # Distribute the change applied to the modified feature evenly
            # across the remaining latents.
            num_feat = latents.shape[2] - 1
            diff = old_cond_feats - latents[:, :, self.mod_features]
            if self.mod_features != 0:
                latents[:, :, : self.mod_features] += (diff / num_feat)[:, :, None]
            latents[:, :, self.mod_features + 1 :] += (diff / num_feat)[:, :, None]

        return latents

    def pop(self) -> torch.Tensor:
        """Remove and return the extracted features from this hook.

        The features are only accessible this way so that no lingering
        references to them are kept around.
        """
        assert self._features is not None, "Feature extractor was not called yet!"
        if isinstance(self._features, tuple):
            features = self._features[0]
        else:
            features = self._features
        self._features = None
        return features

    def __call__topk(self, module, inp, outp) -> torch.Tensor:
        self._features = outp
        if isinstance(self._features, tuple):
            features = self._features[0]
        else:
            features = self._features

        if self.mod_features is None:
            recons = features
        else:
            x, info = self.sae.preprocess(features)
            latents_pre_act = self.sae.encode_pre_act(x)
            latents = self.sae.activation(latents_pre_act)
            # Apply the configured feature modification before decoding.
            mod_latents = self.modify(latents)

            recons = self.sae.decode(mod_latents, info)

        if isinstance(self._features, tuple):
            outp = list(outp)
            outp[0] = recons
            return tuple(outp)
        else:
            return recons

    def __call__jumprelu(self, module, inp, outp) -> torch.Tensor:
        # Note: unlike the TopK variant, this hook assumes `outp` is a plain
        # tensor (i.e. the "mlp" site rather than a whole decoder block).
        self._features = outp
        if self.mod_features is None:
            recons = outp
        else:
            x, info = self.sae.preprocess(outp)
            latents_pre_act = self.sae.encode_pre_act(x)
            latents = self.sae.activation(
                F.relu(latents_pre_act), torch.exp(self.sae.threshold)
            )
            # Bypass the JumpReLU threshold for the modified features by
            # copying over their raw pre-activations.
            latents[:, :, self.mod_features] = latents_pre_act[:, :, self.mod_features]
            mod_latents = self.modify(latents)

            recons = self.sae.decode(mod_latents, info)

        return recons

    def __call__(self, module, inp, outp) -> torch.Tensor:
        return self.call(module, inp, outp)
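

if __name__ == "__main__":
    # Minimal smoke test of the standalone SAE components. The sizes and k
    # below are illustrative assumptions, not values from any checkpoint.
    sae = Autoencoder(n_latents=64, n_inputs=16, activation=TopK(4), normalize=True)
    x = torch.randn(2, 3, 16)  # (batch, sequence, hidden)
    latents_pre_act, latents, recons = sae(x)
    print(latents_pre_act.shape, latents.shape, recons.shape)
    # Under TopK, each position keeps at most k=4 active latents (the
    # post-ReLU can zero a kept value if it is negative).
    print((latents != 0).sum(dim=-1))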