| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """PyTorch INFLM model.""" |
|
|
| import torch |
| from torch import nn |
| from transformers.models.llama.modeling_llama import ( |
| LlamaDecoderLayer, |
| LlamaModel, |
| LlamaForCausalLM |
| ) |
| from .configuration_inflm import INFLMConfig |
|
|
| _CONFIG_FOR_DOC = "INFLMConfig" |
|
|
|
|
class INFLMDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose two normalization modules are replaced.

    INFLM uses standard (affine, biased) ``nn.LayerNorm`` instead of the
    RMSNorm that ``LlamaDecoderLayer`` installs by default.
    """

    def __init__(self, config: INFLMConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Overwrite the norms built by the Llama base class; attention and
        # MLP submodules are kept exactly as the parent constructed them.
        for norm_name in ("input_layernorm", "post_attention_layernorm"):
            setattr(self, norm_name, nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps))
|
|
|
|
class INFLMModel(LlamaModel):
    """INFLM transformer backbone.

    A ``LlamaModel`` whose decoder stack is rebuilt from
    :class:`INFLMDecoderLayer` (LayerNorm-based) and whose final norm is a
    standard ``nn.LayerNorm`` rather than RMSNorm.
    """

    config_class = INFLMConfig
    _no_split_modules = ["INFLMDecoderLayer"]

    def __init__(self, config: INFLMConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Re-create the submodules so the stack uses INFLM decoder layers
        # instead of the defaults that LlamaModel.__init__ already built.
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=self.padding_idx
        )
        self.layers = nn.ModuleList(
            INFLMDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False
        # Re-run weight initialization/finalization over the rebuilt modules.
        self.post_init()
|
|
|
|
class INFLMForCausalLM(LlamaForCausalLM):
    """Causal language-modeling head on top of the INFLM backbone.

    ``lm_head`` projects hidden states to vocabulary logits; its weight may
    be tied to the input embeddings (see ``_tied_weights_keys``).
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: INFLMConfig):
        super().__init__(config)
        hidden_size, vocab_size = config.hidden_size, config.vocab_size
        self.vocab_size = vocab_size
        # Swap in the INFLM backbone and rebuild the (untied) output projection.
        self.model = INFLMModel(config)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Finalize: weight init and (if configured) embedding/lm_head tying.
        self.post_init()
|
|