esmith5594 commited on
Commit
1cef7c9
Β·
verified Β·
1 Parent(s): df19e66

Upload train_google_api.py

Browse files
Files changed (1) hide show
  1. train_google_api.py +74 -0
train_google_api.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fine-tune Qwen2.5-Coder-3B-Instruct on Google Classroom & Drive API code.
3
+ Uses LoRA via PEFT for memory-efficient training.
4
+ """
5
+ import os
6
+ from datasets import load_dataset
7
+ from trl import SFTTrainer, SFTConfig
8
+ from peft import LoraConfig
9
+
10
+ # ── Config ───────────────────────────────────────────────────────────
11
+ MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct"
12
+ DATASET_ID = "esmith5594/google-classroom-drive-api-code"
13
+ OUTPUT_DIR = "qwen25-coder-3b-google-api-lora"
14
+ HUB_MODEL_ID = "esmith5594/qwen25-coder-3b-google-api-lora"
15
+
16
+ # ── Load Dataset ─────────────────────────────────────────────────────
17
+ dataset = load_dataset(DATASET_ID, split="train")
18
+ print(f"Loaded {len(dataset)} training examples")
19
+
20
+ # ── LoRA Config (based on Octopus paper + TRL best practices) ─────────
21
+ peft_config = LoraConfig(
22
+ r=16,
23
+ lora_alpha=32,
24
+ target_modules=[
25
+ "q_proj", "k_proj", "v_proj", "o_proj",
26
+ "gate_proj", "up_proj", "down_proj",
27
+ ],
28
+ lora_dropout=0.05,
29
+ bias="none",
30
+ task_type="CAUSAL_LM",
31
+ )
32
+
33
+ # ── Training Config ──────────────────────────────────────────────────
34
+ training_args = SFTConfig(
35
+ output_dir=OUTPUT_DIR,
36
+ hub_model_id=HUB_MODEL_ID,
37
+ push_to_hub=True,
38
+ num_train_epochs=5,
39
+ per_device_train_batch_size=4,
40
+ gradient_accumulation_steps=128,
41
+ learning_rate=2e-5,
42
+ lr_scheduler_type="constant",
43
+ warmup_ratio=0.0,
44
+ bf16=True,
45
+ gradient_checkpointing=True,
46
+ max_seq_length=4096,
47
+ logging_steps=10,
48
+ logging_first_step=True,
49
+ disable_tqdm=True,
50
+ save_strategy="epoch",
51
+ save_total_limit=2,
52
+ report_to="trackio",
53
+ run_name="qwen25-coder-3b-google-api-lora",
54
+ project="google-api-coder",
55
+ assistant_only_loss=True,
56
+ packing=False,
57
+ )
58
+
59
+ # ── Trainer ──────────────────────────────────────────────────────────
60
+ trainer = SFTTrainer(
61
+ model=MODEL_ID,
62
+ train_dataset=dataset,
63
+ peft_config=peft_config,
64
+ args=training_args,
65
+ )
66
+
67
+ # ── Train ────────────────────────────────────────────────────────────
68
+ trainer.train()
69
+
70
+ # ── Save ─────────────────────────────────────────────────────────────
71
+ trainer.save_model(os.path.join(OUTPUT_DIR, "final"))
72
+ trainer.push_to_hub()
73
+
74
+ print(f"\nTraining complete! Model saved to {HUB_MODEL_ID}")