# Pin / train.py
# LH-Tech-AI's picture
# Create train.py
# 859492a verified
# Emit a startup message; it is printed before torch is imported below.
print("Loading...")
import torch
# Release unoccupied cached GPU memory held by PyTorch's caching allocator
# (see the torch.cuda.empty_cache documentation) before the run starts.
torch.cuda.empty_cache()
from datasets import load_dataset
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
MODEL_NAME = "Pin-25M"  # prefix for the checkpoint and final output folders
DATASET_ID = "starhopp3r/TinyChat"  # identifier handed to load_dataset
MAX_LENGTH = 256  # truncation length passed to the tokenizer
BATCH_SIZE = 32  # per-device training batch size

# ---------------------------------------------------------------------------
# Tokenizer: the GPT-2 tokenizer, with EOS reused as the padding token.
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# ---------------------------------------------------------------------------
# Model: start from the "gpt2" configuration and override its size fields.
# The model is then built from that config alone via ``from_config``.
# ---------------------------------------------------------------------------
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_embd=288,
    n_inner=1152,
    n_layer=12,
    n_head=12,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = AutoModelForCausalLM.from_config(config)
print(f"Model parameters: {model.num_parameters() / 1e6:.2f}M")

# ---------------------------------------------------------------------------
# Data: only the "train" split is loaded.
# ---------------------------------------------------------------------------
print("Loading dataset...")
dataset = load_dataset(DATASET_ID, split="train")
def tokenize_function(examples):
    """Tokenize the ``"text"`` field of a batch of dataset rows.

    Args:
        examples: Mapping with a ``"text"`` entry holding the text to encode.

    Returns:
        The tokenizer's output for that text, truncated to ``MAX_LENGTH``.
    """
    texts = examples["text"]
    return tokenizer(texts, max_length=MAX_LENGTH, truncation=True)
# Run the tokenizer over the dataset in batches using 4 worker processes.
# Every original column is removed, so only the tokenizer's outputs remain.
tokenized_datasets = dataset.map(
    tokenize_function,
    num_proc=4,
    batched=True,
    remove_columns=dataset.column_names,
)

# mlm=False turns off masked-language-model masking in the collator.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)
print("Setting up training arguments...")
training_args = TrainingArguments(
    output_dir="./" + MODEL_NAME + "_checkpoints",
    # Schedule. Per the Transformers documentation a positive max_steps
    # overrides num_train_epochs, so the run length is set by max_steps.
    num_train_epochs=1,
    max_steps=1500,
    warmup_steps=500,
    # Each optimizer step accumulates 2 batches of BATCH_SIZE per device.
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,
    # Optimizer.
    learning_rate=5e-4,
    weight_decay=0.01,
    # Logging / checkpointing.
    logging_steps=100,
    # NOTE(review): save_steps (2500) is larger than max_steps (1500), so the
    # periodic save interval is never reached during the run - confirm that
    # this is intended.
    save_steps=2500,
    # fp16 mixed precision needs a GPU; a hard-coded True makes
    # TrainingArguments raise on CPU-only hosts, so enable it only when CUDA
    # is actually available.
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    report_to="none",
)
# Assemble the Trainer from the model, arguments, data and collator built above.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=tokenized_datasets,
)

print("Starting training...")
trainer.train()

# Write the model and the tokenizer to the same "-Final" directory.
trainer.save_model(f"./{MODEL_NAME}-Final")
tokenizer.save_pretrained(f"./{MODEL_NAME}-Final")
def chat(prompt):
    """Generate a sampled reply to *prompt* with the module-level model.

    The prompt is wrapped as ``[INST] <prompt> [/INST]`` before tokenization.

    Args:
        prompt: The user message to respond to.

    Returns:
        The first sequence returned by ``model.generate``, decoded to text
        with special tokens skipped.
    """
    formatted_prompt = f"[INST] {prompt} [/INST]"

    # Use the GPU when present, otherwise fall back to CPU instead of
    # failing on the hard-coded "cuda" device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # After trainer.train() the model is still in training mode; switch to
    # eval mode so dropout is disabled while sampling.
    model.eval()

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        temperature=0.7,
        do_sample=True,
        # Explicit pad id: the tokenizer reuses EOS as its padding token.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Smoke test: run one fixed prompt through chat() and print the result.
print("\n--- Test Chat ---")
sample_reply = chat("Hello, how are you today?")
print(sample_reply)