# training-scripts/train_smol_discharge.py
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "trackio",
#     "bitsandbytes",
# ]
# ///
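# The block above is PEP 723 inline script metadata: a tool that understands
# it (e.g. `uv run train_smol_discharge.py`) installs the listed dependencies
# before running the script.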
import os
from huggingface_hub import login
# Log in with a token from the environment (set via secrets)
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)
    print("Logged in to Hugging Face Hub")
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B-Base")
CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
tokenizer.chat_template = CHAT_TEMPLATE
special_tokens = {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
tokenizer.add_special_tokens(special_tokens)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
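# Quick sanity check of the template (hypothetical message, shown commented
# out rather than executed):
# tokenizer.apply_chat_template(
#     [{"role": "user", "content": "Summarise this admission."}],
#     tokenize=False, add_generation_prompt=True,
# )
# -> '<|im_start|>user\nSummarise this admission.<|im_end|>\n<|im_start|>assistant\n'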
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM3-3B-Base",
    torch_dtype="auto",
    device_map="auto",
)
# Grow the embedding matrix in case the ChatML tokens registered above were
# new additions to the vocabulary.
model.resize_token_embeddings(len(tokenizer))
print("Loading dataset...")
train_dataset = load_dataset("chrisvoncsefalvay/smol-discharge-notes-sft", split="train")
eval_dataset = load_dataset("chrisvoncsefalvay/smol-discharge-notes-sft", split="validation")
print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")
config = SFTConfig(
    output_dir="smollm3-discharge-notes-sft",
    push_to_hub=True,
    hub_model_id="chrisvoncsefalvay/smollm3-discharge-notes-sft",
    hub_strategy="every_save",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-5,
    max_length=2048,
    logging_steps=10,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=50,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    gradient_checkpointing=True,
    bf16=True,
    report_to="trackio",
    project="clinical-action-processing",
    run_name="smollm3-3b-discharge-sft-a100",
)
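# Effective batch size: 8 per device x 2 accumulation steps = 16 sequences
# per optimizer step on each GPU.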
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
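# Adapters attach to every attention and MLP projection; only the adapter
# weights are trained, while the 3B-parameter base model stays frozen.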
print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
)
print("Starting training...")
trainer.train()
print("Pushing to Hub...")
trainer.push_to_hub()
print("Complete!")