import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from datasets import load_dataset


class HelloWorldConfig(PretrainedConfig):
    model_type = "hello_world"

    def __init__(
        self,
        vocab_size=13,
        hidden_size=64,
        num_hidden_layers=1,
        num_attention_heads=1,
        intermediate_size=128,
        hidden_act="gelu",
        max_position_embeddings=512,
        type_vocab_size=1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps


class HelloWorldModel(PreTrainedModel):
    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token embeddings plus learned absolute position embeddings.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # A single Transformer encoder layer stands in for the usual decoder stack.
        self.layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            batch_first=True
        )

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.init_weights()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        labels=None,
        use_cache=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        else:
            raise ValueError("You have to specify input_ids")

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        inputs_embeds = self.embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        hidden_states = inputs_embeds + position_embeds

        # Causal mask so each position only attends to earlier positions, plus a
        # key-padding mask derived from attention_mask (1 = keep, 0 = pad).
        causal_mask = torch.triu(
            torch.full((seq_length, seq_length), float("-inf"), device=input_ids.device),
            diagonal=1,
        )
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        hidden_states = self.layer(
            hidden_states,
            src_mask=causal_mask,
            src_key_padding_mask=key_padding_mask,
        )

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            # hidden_states is documented as a tuple of tensors, hence the wrapping.
            hidden_states=(hidden_states,) if output_hidden_states else None,
            attentions=None
        )

    def generate_hello_world(self):
        """Run a throwaway forward pass and return the canonical greeting.

        The forward output is ignored; this toy model always answers "Hello World!".
        """
        hello_token_id = 5
        world_token_id = 6

        input_ids = torch.tensor([[hello_token_id, world_token_id]])

        with torch.no_grad():
            self.forward(input_ids)

        return "Hello World!"

    @classmethod
    def load_dataset(cls, dataset_name="chiedo/hello-world", split=None):
        """
        Load the Hello World dataset.

        Args:
            dataset_name (str): Name of the dataset on the Hugging Face Hub.
            split (str, optional): Specific split to load ('train', 'validation', 'test').

        Returns:
            Dataset or DatasetDict depending on the split parameter, or None on failure.
        """
        try:
            # This resolves to the module-level datasets.load_dataset function,
            # not to this classmethod, despite the shared name.
            if split:
                return load_dataset(dataset_name, split=split)
            else:
                return load_dataset(dataset_name)
        except Exception as e:
            print(f"Error loading dataset: {e}")
            print(f"Make sure the dataset exists at: https://huggingface.co/datasets/{dataset_name}")
            return None

    def prepare_dataset_batch(self, texts, tokenizer, max_length=128):
        """
        Prepare a batch of texts from the dataset for model input.

        Args:
            texts (list): List of text strings.
            tokenizer: Tokenizer used to encode the texts.
            max_length (int): Maximum sequence length.

        Returns:
            dict: Dictionary with input_ids and attention_mask tensors.
        """
        return tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
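

# Minimal usage sketch (assumption: run as a script; the dummy token ids, batch
# shapes, and the commented-out dataset/tokenizer calls are illustrative only and
# not part of the classes above).
if __name__ == "__main__":
    config = HelloWorldConfig()
    model = HelloWorldModel(config)
    model.eval()

    # Tiny forward pass with labels to exercise the shifted causal-LM loss.
    dummy_input_ids = torch.randint(0, config.vocab_size, (2, 8))
    outputs = model(input_ids=dummy_input_ids, labels=dummy_input_ids)
    print("loss:", outputs.loss.item())
    print("logits shape:", tuple(outputs.logits.shape))

    # The model's one trick.
    print(model.generate_hello_world())

    # The dataset helpers need network access and a real tokenizer, so they are
    # only sketched here:
    # dataset = HelloWorldModel.load_dataset()
    # batch = model.prepare_dataset_batch(["hello world"], tokenizer)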