chrisvoncsefalvay committed on
Commit
468e6b8
·
verified ·
1 Parent(s): 2d8288a

Upload train_smol_discharge_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_smol_discharge_v2.py +113 -0
train_smol_discharge_v2.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.12.0",
4
+ # "peft>=0.7.0",
5
+ # "transformers>=4.36.0",
6
+ # "accelerate>=0.24.0",
7
+ # "trackio",
8
+ # "bitsandbytes",
9
+ # ]
10
+ # ///
11
+
import os

from huggingface_hub import login

# Authenticate against the Hugging Face Hub when a token is present in the
# environment; public assets still work without it, so absence is not fatal.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("Logged in to HuggingFace Hub")
19
+
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B-Base")

# ChatML-style template: each system/user/assistant turn is wrapped in
# <|im_start|>/<|im_end|> markers; add_generation_prompt opens an assistant
# turn for inference.
CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"

tokenizer.chat_template = CHAT_TEMPLATE
# Register the ChatML markers as special tokens (no-op for any already known).
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
)
# The base tokenizer may ship without a pad token; reuse EOS in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
35
+
print("Loading model...")
# NOTE(review): torch_dtype is deprecated in favour of dtype on recent
# transformers releases, but the >=4.36 floor keeps the older spelling —
# confirm before switching.
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM3-3B-Base",
    torch_dtype="auto",
    device_map="auto",
)
# Grow the embedding matrix to cover any special tokens added above.
model.resize_token_embeddings(len(tokenizer))
+
print("Loading dataset...")
# Fetch the dataset once and index both splits from the resulting DatasetDict.
_dataset = load_dataset("chrisvoncsefalvay/smol-discharge-notes-sft")
train_dataset = _dataset["train"]
eval_dataset = _dataset["validation"]
print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")
49
+
50
+ # IMPROVED CONFIG
51
+ config = SFTConfig(
52
+ output_dir="smollm3-discharge-notes-sft",
53
+ push_to_hub=True,
54
+ hub_model_id="chrisvoncsefalvay/smollm3-discharge-notes-sft",
55
+ hub_strategy="every_save",
56
+
57
+ # MORE EPOCHS
58
+ num_train_epochs=10,
59
+
60
+ # BATCH SETTINGS
61
+ per_device_train_batch_size=4,
62
+ per_device_eval_batch_size=2,
63
+ gradient_accumulation_steps=4, # effective batch = 16
64
+
65
+ # HIGHER LR WITH WARMUP
66
+ learning_rate=5e-5,
67
+ warmup_ratio=0.1,
68
+ lr_scheduler_type="cosine",
69
+
70
+ max_length=2048,
71
+ logging_steps=10,
72
+ save_strategy="steps",
73
+ save_steps=100,
74
+ save_total_limit=3,
75
+ eval_strategy="steps",
76
+ eval_steps=100,
77
+ load_best_model_at_end=True,
78
+ metric_for_best_model="eval_loss",
79
+ greater_is_better=False,
80
+ gradient_checkpointing=True,
81
+ bf16=True,
82
+ report_to="trackio",
83
+ project="clinical-action-processing",
84
+ run_name="smollm3-3b-discharge-sft-v2",
85
+ )
86
+
# LoRA adapter: rank 64 with alpha 128 (2x scaling) and 10% dropout across
# every attention and MLP projection.
# NOTE(review): embed_tokens/lm_head are not in target_modules, so embeddings
# for any genuinely NEW special tokens added above would stay at their random
# init while frozen; if the base vocab lacks <|im_start|>/<|im_end|>, consider
# modules_to_save=["embed_tokens", "lm_head"] — verify against the tokenizer.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=128,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
96
+
print("Initializing trainer...")
# SFTTrainer wraps the PEFT adapter around the model internally via
# peft_config; the tokenizer is passed as processing_class per current TRL.
trainer = SFTTrainer(
    model=model,
    args=config,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)

print("Starting training...")
trainer.train()

# Final explicit push in addition to hub_strategy="every_save" checkpoints.
print("Pushing to Hub...")
trainer.push_to_hub()

print("Complete!")