# AutoMathReasoner / train / sft_warm_start.py
# Origin: Hugging Face repo "AutoMathReasoner" (author Pratap-K, commit 98fc9b6)
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from unsloth import FastLanguageModel
def main():
    """SFT warm start: fine-tune a 4-bit Llama-3 on a small GSM8K slice so it
    learns the ``Problem: / Reasoning: / Answer:`` output format.

    Side effects: downloads model + dataset, writes checkpoints to
    ``sft_outputs/``, and prints progress to stdout.
    """
    max_seq_length = 1024

    # Load model and tokenizer (4-bit quantized; dtype=None lets unsloth
    # auto-select bf16/fp16 based on hardware support).
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "llama-3-8b-instruct",
        max_seq_length = max_seq_length,
        dtype = None,
        load_in_4bit = True,
    )

    # We use a subset of GSM8K style data to warm start the reasoning format.
    # In practice, this would load a custom generated dataset locally.
    try:
        dataset = load_dataset("gsm8k", "main", split="train[:5%]")
    except Exception:
        # Deliberate best-effort fallback: keep the script runnable offline
        # with a local dummy dataset rather than crashing.
        dataset = load_dataset("json", data_files={"train": ["dummy.json"]}, split="train")

    # FIX: append EOS to every training example so the model learns to stop
    # generating after the final answer; without it, sampled outputs tend to
    # run on indefinitely.
    eos_token = tokenizer.eos_token or ""

    def formatting_prompts_func(examples):
        """Map a batched slice of rows to a single 'text' field per row."""
        texts = []
        for q, a in zip(examples['question'], examples['answer']):
            # GSM8K answers look like '<step-by-step reasoning> #### <final answer>'
            parts = a.split("####")
            reasoning = parts[0].strip()
            # Guard against rows without the '####' marker (e.g. dummy data).
            final_answer = parts[1].strip() if len(parts) > 1 else ""
            texts.append(
                f"Problem: {q}\nReasoning: {reasoning}\nAnswer: {final_answer}{eos_token}"
            )
        return {"text": texts}

    dataset = dataset.map(formatting_prompts_func, batched=True)

    training_args = SFTConfig(
        output_dir="sft_outputs",
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        per_device_train_batch_size=2,
        max_steps=100,
        learning_rate=2e-5,
    )

    # FIX: pass the unsloth-patched tokenizer explicitly. The original loaded
    # it but never used it, letting SFTTrainer re-load a plain tokenizer.
    # NOTE(review): newer trl versions rename this kwarg to 'processing_class'
    # — confirm against the pinned trl version.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=training_args,
    )

    print("Starting SFT Warm-Start...")
    trainer.train()
# Script entry point: run the warm-start fine-tune only when executed
# directly (not when imported as a module).
if __name__ == "__main__":
    main()