"""Train a small GPT-2-style causal language model on the TinyChat dataset.

Builds a ~25M-parameter model from a scratch GPT-2 config, trains it for a
fixed number of steps with the Hugging Face Trainer, saves the weights and
tokenizer, and runs a quick smoke-test generation.
"""

print("Loading...")

import torch

torch.cuda.empty_cache()

from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

MODEL_NAME = "Pin-25M"
DATASET_ID = "starhopp3r/TinyChat"
MAX_LENGTH = 256
BATCH_SIZE = 32

# Resolve the device once. The original hard-coded "cuda", which raises on
# CPU-only machines; fall back to CPU so the script still runs everywhere.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Reuse GPT-2's tokenizer; GPT-2 ships without a pad token, so alias it to EOS
# (required by the padding collator below).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Scratch GPT-2 config: 12 layers x 12 heads, 288-dim embeddings,
# 1152-dim FFN — roughly 25M parameters (printed below).
config = AutoConfig.from_pretrained(
    "gpt2",
    n_layer=12,
    n_head=12,
    n_embd=288,
    n_inner=1152,
    vocab_size=len(tokenizer),
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = AutoModelForCausalLM.from_config(config)
print(f"Model parameters: {model.num_parameters() / 1e6:.2f}M")

print("Loading dataset...")
dataset = load_dataset(DATASET_ID, split="train")


def tokenize_function(examples):
    """Tokenize a batch of examples, truncating each to MAX_LENGTH tokens."""
    return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH)


tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names,
    num_proc=4,
)

# mlm=False -> standard causal-LM labels (inputs shifted right), with padding.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

print("Setting up training arguments...")
training_args = TrainingArguments(
    output_dir=f"./{MODEL_NAME}_checkpoints",
    num_train_epochs=1,
    max_steps=1500,  # overrides num_train_epochs: training stops at step 1500
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,  # effective batch size = 2 * BATCH_SIZE
    learning_rate=5e-4,
    weight_decay=0.01,
    logging_steps=100,
    # NOTE(review): save_steps > max_steps, so no intermediate checkpoints are
    # ever written — only the explicit save_model() below persists weights.
    save_steps=2500,
    fp16=torch.cuda.is_available(),  # fp16 needs a GPU backend; was True unconditionally
    push_to_hub=False,
    report_to="none",
    warmup_steps=500,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    data_collator=data_collator,
)

print("Starting training...")
trainer.train()

# Save weights AND tokenizer so the output folder is directly loadable
# with from_pretrained().
trainer.save_model(f"./{MODEL_NAME}-Final")
tokenizer.save_pretrained(f"./{MODEL_NAME}-Final")


def chat(prompt):
    """Generate a short sampled reply to *prompt* with the trained model.

    Wraps the prompt in the [INST] ... [/INST] template the dataset
    presumably uses (TODO confirm against the dataset format) and decodes
    up to 50 new tokens.
    """
    formatted_prompt = f"[INST] {prompt} [/INST]"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
    model.to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad warning
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


print("\n--- Test Chat ---")
print(chat("Hello, how are you today?"))