"""Train StorySupra-10M, a small Llama-style causal LM, on TinyStories.

Steps: load the first 200k TinyStories stories, train a byte-level BPE
tokenizer on them, encode every story as ``<s> story </s>``, train a
LlamaForCausalLM with the Hugging Face Trainer, save model + tokenizer, and
sample one story as a smoke test.
"""
import torch
from datasets import load_dataset
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
from transformers import (
    DataCollatorForLanguageModeling,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)

OUTPUT_DIR = "./StorySupra-10M"
VOCAB_SIZE = 8192
MAX_LENGTH = 256  # tokens per training example; also the model's context size
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Distinct special tokens. Using the same empty string "" for every special
# token made them indistinguishable (pad == bos == eos == unk), so the LM
# collator's padding mask also removed EOS from the labels.
PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN = "<pad>", "<unk>", "<s>", "</s>"
SPECIAL_TOKENS = [PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN]

dataset = load_dataset("roneneldan/TinyStories", split="train[:200000]")


def train_tokenizer(dataset, vocab_size=VOCAB_SIZE, batch_size=1000):
    """Train a byte-level BPE tokenizer on ``dataset["text"]``.

    Args:
        dataset: a ``datasets.Dataset`` with a ``"text"`` column.
        vocab_size: target vocabulary size, special tokens included.
        batch_size: number of texts handed to the BPE trainer per batch.

    Returns:
        A ``PreTrainedTokenizerFast`` that prepends ``<s>`` to every encoded
        text, returns only ``input_ids``/``attention_mask``, and decodes back
        to the original text.
    """
    tokenizer = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
    # Byte-level pre-tokenization plus the matching decoder keep whitespace and
    # punctuation recoverable. With Whitespace() and no decoder, decode()
    # joined every token (sub-word pieces included) with a space.
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )

    def batch_iterator():
        for start in range(0, len(dataset), batch_size):
            yield dataset[start : start + batch_size]["text"]

    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len(dataset))

    # Every encoded sequence starts with BOS, both for training data and for
    # generation prompts, so the two see the same input format.
    tokenizer.post_processor = processors.TemplateProcessing(
        single=f"{BOS_TOKEN} $A",
        special_tokens=[(BOS_TOKEN, tokenizer.token_to_id(BOS_TOKEN))],
    )
    return PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        bos_token=BOS_TOKEN,
        eos_token=EOS_TOKEN,
        unk_token=UNK_TOKEN,
        pad_token=PAD_TOKEN,
        # The default input names include token_type_ids, which Llama does not
        # accept; passing them to model.generate() raises a ValueError.
        model_input_names=["input_ids", "attention_mask"],
        clean_up_tokenization_spaces=False,
    )


tokenizer = train_tokenizer(dataset)


def tokenize_function(examples, max_length=MAX_LENGTH):
    """Encode a batch of stories as ``<s> story </s>``, capped at *max_length*.

    EOS is appended *before* truncating. Stories that fit keep their EOS, and
    a story that is cut short is not taught to "end" mid-sentence.
    """
    eos_id = tokenizer.eos_token_id
    encoded = tokenizer(examples["text"])  # BOS is added by the post-processor
    input_ids = [(ids + [eos_id])[:max_length] for ids in encoded["input_ids"]]
    return {
        "input_ids": input_ids,
        "attention_mask": [[1] * len(ids) for ids in input_ids],
    }


tokenized_dataset = dataset.map(
    tokenize_function, batched=True, remove_columns=dataset.column_names
)

config = LlamaConfig(
    # Use the trained vocabulary's real size; BPE may stop below VOCAB_SIZE.
    vocab_size=len(tokenizer),
    hidden_size=256,
    intermediate_size=1024,
    num_hidden_layers=8,
    num_attention_heads=8,
    max_position_embeddings=MAX_LENGTH,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = LlamaForCausalLM(config)
print(f"Model parameters: {model.num_parameters():,}")

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=32,
    num_train_epochs=3,
    save_steps=500,
    save_total_limit=2,  # otherwise every 500-step checkpoint stays on disk
    logging_steps=100,
    learning_rate=5e-4,
    weight_decay=0.01,
    fp16=torch.cuda.is_available(),  # only use fp16 mixed precision on CUDA
    push_to_hub=False,
    report_to="none",
    lr_scheduler_type="cosine",
)

# mlm=False -> causal LM labels; pad positions (now distinct from EOS) are masked.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
trainer.train()

# Save before sampling so a failure during generation cannot lose the weights.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)


def generate_story(prompt, max_length=100):
    """Sample a continuation of *prompt*, print it, and return it.

    Args:
        prompt: the story opening.
        max_length: total token budget, including the prompt and its BOS.

    Returns:
        The decoded story with special tokens removed.
    """
    model.to(DEVICE)
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=True,
        temperature=0.55,
        top_k=25,
        top_p=0.85,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(story)
    return story


generate_story("Once upon a time, a small bird")