from unsloth import FastLanguageModel  # import unsloth before trl so its patches apply
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig


def main():
    max_seq_length = 1024

    # Load model and tokenizer in 4-bit to fit a single GPU.
    # A fully qualified repo id is required here; "unsloth/llama-3-8b-Instruct-bnb-4bit"
    # is Unsloth's pre-quantized Llama 3 8B Instruct checkpoint.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit",
        max_seq_length = max_seq_length,
        dtype = None,  # auto-detect: bfloat16 on Ampere+, float16 otherwise
        load_in_4bit = True,
    )
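
    # The 4-bit base weights are frozen, so trainable LoRA adapters must be
    # attached before SFT (standard Unsloth QLoRA setup; the rank/alpha/module
    # choices below are illustrative assumptions, not values from the original):
    model = FastLanguageModel.get_peft_model(
        model,
        r = 16,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"],
        lora_alpha = 16,
        lora_dropout = 0,
        bias = "none",
    )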

    # We use a subset of GSM8K-style data to warm-start the reasoning format.
    # In practice, this would load a custom generated dataset locally.
    try:
        dataset = load_dataset("gsm8k", "main", split="train[:5%]")
    except Exception:
        # Fallback dummy dataset
        dataset = load_dataset("json", data_files={"train": ["dummy.json"]}, split="train")

    def formatting_prompts_func(examples):
        texts = []
        for q, a in zip(examples['question'], examples['answer']):
            # GSM8K answers contain step-by-step reasoning followed by '#### <final answer>'
            parts = a.split("####")
            reasoning = parts[0].strip()
            final_answer = parts[1].strip() if len(parts) > 1 else ""
            # Append EOS so the model learns to terminate its answers
            text = f"Problem: {q}\nReasoning: {reasoning}\nAnswer: {final_answer}" + tokenizer.eos_token
            texts.append(text)
        return { "text" : texts }

    dataset = dataset.map(formatting_prompts_func, batched = True)
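
    # Each mapped row now carries a "text" field shaped like (schematic, not a
    # real dataset row):
    #   Problem: <question>
    #   Reasoning: <step-by-step solution>
    #   Answer: <final numeric answer><eos>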

    training_args = SFTConfig(
        output_dir="sft_outputs",
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        per_device_train_batch_size=2,
        max_steps=100,  # short warm-start run, not a full epoch
        learning_rate=2e-5,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=training_args,
    )

    print("Starting SFT Warm-Start...")
    trainer.train()
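
    # Persist the trained adapters and tokenizer for the next stage
    # (the output path here is an illustrative assumption):
    model.save_pretrained("sft_outputs/adapter")
    tokenizer.save_pretrained("sft_outputs/adapter")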


if __name__ == "__main__":
    main()