engine / finetuning /lora.py
VeuReu's picture
Upload 5 files
e5dde7c verified
raw
history blame
7.64 kB
import os
import argparse
from pathlib import Path
from typing import List, Dict
from datasets import Dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
)
from peft import LoraConfig, get_peft_model
# Absolute directory of this script. The default data/output paths are anchored
# here, so they do not depend on the current working directory.
BASE_DIR: Path = Path(__file__).resolve().parent
# Default data root: find_training_pairs scans its sub-folders for example pairs.
DATA_DIR: Path = BASE_DIR.joinpath("data")
# Fixed Catalan instruction placed before every SRT. The model is trained to
# continue after the "### Narració lliure:" marker with the free narration.
_INSTRUCTION_HEADER = (
    "Converteix el següent fitxer SRT d'audiodescripció UNE (amb restriccions temporals) "
    "en una narració lliure detallada en català, sense límits de temps. "
    "Mantén tota la informació visual rellevant però amb un to fluid i natural.\n\n"
)


def find_training_pairs(data_dir: Path) -> List[Dict[str, str]]:
    """Collect (UNE SRT, free narration) training pairs from ``data_dir``.

    Each immediate sub-directory that contains both ``target_une_ad.srt`` and
    ``free_ad.txt`` yields one example. Sub-directories missing either file,
    and plain files directly under ``data_dir``, are skipped. Sub-folders are
    visited in sorted order, so the example order is deterministic.

    Each example is an instruction-style prompt: the SRT text is the input and
    the free narration is the expected output.

    Args:
        data_dir: Root directory to scan (``Path`` or ``str``).

    Returns:
        A list of ``{"prompt": ..., "output": ...}`` dicts.

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist.
        RuntimeError: If no complete pair is found.
    """
    data_dir = Path(data_dir)
    if not data_dir.exists():
        raise FileNotFoundError(f"Data dir not found: {data_dir}")

    examples: List[Dict[str, str]] = []
    for item in sorted(data_dir.iterdir()):
        if not item.is_dir():
            continue
        srt_path = item / "target_une_ad.srt"
        free_path = item / "free_ad.txt"
        # Use is_file() rather than exists(): a directory with one of these
        # names would otherwise pass the check and make read_text() fail.
        if not (srt_path.is_file() and free_path.is_file()):
            continue
        # "utf-8-sig" decodes plain UTF-8 identically, but it also drops a
        # leading byte-order mark (common in SRTs saved on Windows). str.strip()
        # would not remove the BOM, because U+FEFF is not whitespace.
        srt_text = srt_path.read_text(encoding="utf-8-sig")
        free_text = free_path.read_text(encoding="utf-8-sig")
        prompt = _INSTRUCTION_HEADER + "### SRT UNE\n" + srt_text.strip() + "\n\n### Narració lliure:"
        examples.append({"prompt": prompt, "output": free_text.strip()})

    if not examples:
        raise RuntimeError(f"No training pairs found in {data_dir} (expected target_une_ad.srt + free_ad.txt)")
    return examples
def _pad_and_mask(full_ids: List[int], prompt_len: int, max_length: int, pad_id: int) -> Dict[str, List[int]]:
    """Right-pad one tokenized example to ``max_length`` and build its mask and labels.

    Labels are -100 (ignored by the causal-LM loss) over the first
    ``prompt_len`` tokens and over every padding position, so only the output
    tokens (including the trailing EOS) are supervised. ``attention_mask`` is 1
    on real tokens and 0 on padding.
    """
    ids = list(full_ids[:max_length])
    n_prompt = min(prompt_len, len(ids))
    n_pad = max_length - len(ids)
    return {
        "input_ids": ids + [pad_id] * n_pad,
        "attention_mask": [1] * len(ids) + [0] * n_pad,
        "labels": [-100] * n_prompt + ids[n_prompt:] + [-100] * n_pad,
    }


def build_dataset(pairs: List[Dict[str, str]], tokenizer: AutoTokenizer, max_length: int = 2048) -> Dataset:
    """Build a tokenized Hugging Face ``Dataset`` from prompt/output pairs.

    Each pair becomes one causal-LM sequence ``prompt + "\\n" + output + eos``.
    The sequence is truncated to ``max_length`` and then right-padded to exactly
    ``max_length``. Prompt tokens and padding get label -100, so the loss is
    computed only on the output. Padding gets attention_mask 0.

    Args:
        pairs: Dicts with ``"prompt"`` and ``"output"`` keys.
        tokenizer: Hugging Face tokenizer. Its pad token id is used for padding,
            falling back to the EOS token id when no pad token is set.
        max_length: Fixed sequence length of every example.

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels``.

    Raises:
        ValueError: If the tokenizer has neither a pad nor an EOS token id.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    if pad_id is None:
        raise ValueError("Tokenizer has neither a pad_token_id nor an eos_token_id to pad with")

    def _gen():
        for ex in pairs:
            yield {"prompt": ex["prompt"], "output": ex["output"]}

    raw_ds = Dataset.from_generator(_gen)

    def tokenize_fn(batch):
        columns: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
        for prompt, output in zip(batch["prompt"], batch["output"]):
            prefix = prompt + "\n"
            # Tokenize without padding and pad manually on the right below.
            # This keeps the prompt at the start of the sequence whatever the
            # tokenizer's padding_side is, so index-based prompt masking holds.
            full_ids = tokenizer(
                prefix + output + tokenizer.eos_token,
                truncation=True,
                max_length=max_length,
            )["input_ids"]
            # The prompt is tokenized the same way (same special tokens), so its
            # length marks where the supervised output starts. If the prompt
            # alone fills the window, the example has no supervised tokens.
            prompt_len = len(tokenizer(prefix, truncation=True, max_length=max_length)["input_ids"])
            for key, values in _pad_and_mask(full_ids, prompt_len, max_length, pad_id).items():
                columns[key].append(values)
        return columns

    tokenized = raw_ds.map(tokenize_fn, batched=True, remove_columns=["prompt", "output"])
    return tokenized
def create_lora_model(base_model_name: str, r: int = 16, alpha: int = 32, dropout: float = 0.05):
    """Load a causal LM and wrap it with a LoRA adapter for training.

    Args:
        base_model_name: Hugging Face Hub id or local path of the base model.
        r: LoRA rank.
        alpha: LoRA scaling factor (``lora_alpha``).
        dropout: Dropout applied inside the LoRA layers.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype="auto",
        device_map="auto",
    )
    # get_peft_model freezes the base weights. main() turns on gradient
    # checkpointing, and with re-entrant checkpointing the checkpointed
    # segments need an input that requires grad; otherwise backward can fail
    # with "element 0 of tensors does not require grad". This hook makes the
    # embedding output require grad.
    model.enable_input_require_grads()
    # The generation KV cache is not used in training and is incompatible with
    # gradient checkpointing.
    model.config.use_cache = False
    lora_config = LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    return model
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface of the fine-tuning script.

    Returns:
        Namespace with the model/data/output locations, the optimisation
        hyper-parameters and the LoRA settings (``r``, ``alpha``, ``dropout``).
    """
    cli = argparse.ArgumentParser(
        description="Fine-tuning LoRA per a salamandra-instruct-7b amb dades UNE/free AD"
    )

    # (flag, default, help) for the string/path options.
    location_options = (
        (
            "--base_model",
            "projecte-aina/salamandra-instruct-7b",
            "Nom o ruta del model base (HF hub o path local)",
        ),
        (
            "--data_dir",
            str(DATA_DIR),
            "Directori base amb subcarpetes que contenen target_une_ad.srt i free_ad.txt",
        ),
        (
            "--output_dir",
            str(BASE_DIR / "lora_output"),
            "Directori on desar l'adapter LoRA",
        ),
    )
    for flag, default_value, help_text in location_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    # (flag, type, default) for the training hyper-parameters (no help text).
    training_options = (
        ("--batch_size", int, 1),
        ("--gradient_accumulation", int, 8),
        ("--epochs", int, 3),
        ("--lr", float, 2e-4),
        ("--max_length", int, 2048),
        ("--warmup_ratio", float, 0.03),
        ("--logging_steps", int, 10),
        ("--save_steps", int, 200),
        ("--eval_steps", int, 200),
    )
    for flag, value_type, default_value in training_options:
        cli.add_argument(flag, type=value_type, default=default_value)

    # (flag, type, default, help) for the LoRA adapter settings.
    lora_options = (
        ("--r", int, 16, "Rank de LoRA"),
        ("--alpha", int, 32, "Alpha de LoRA"),
        ("--dropout", float, 0.05, "Dropout de LoRA"),
    )
    for flag, value_type, default_value, help_text in lora_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return cli.parse_args()
def main():
    """Run LoRA fine-tuning end to end.

    The steps are:
    1. Parse the CLI.
    2. Collect the SRT/free-narration pairs and tokenize them.
    3. Wrap the base model with LoRA and train it with ``transformers.Trainer``.
    4. Save the adapter and tokenizer to ``--output_dir``.
    """
    import torch  # local import: only used to probe bf16 support

    args = parse_args()
    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"[lora] Buscant dades a: {data_dir}")
    pairs = find_training_pairs(data_dir)
    print(f"[lora] Nombre d'exemples trobats: {len(pairs)}")

    print(f"[lora] Carregant tokenizer de {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    # Some tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("[lora] Construint dataset tokenitzat...")
    dataset = build_dataset(pairs, tokenizer, max_length=args.max_length)

    print(f"[lora] Carregant model base {args.base_model} i aplicant LoRA...")
    model = create_lora_model(args.base_model, r=args.r, alpha=args.alpha, dropout=args.dropout)

    # A hard-coded bf16=True makes TrainingArguments raise on hardware without
    # bf16 support (e.g. pre-Ampere GPUs), so enable it only when available.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    if not use_bf16:
        print("[lora] bf16 no disponible en aquest maquinari; s'entrena sense bf16")

    # There is no evaluation split, so evaluation stays disabled. Step-wise
    # evaluation without an eval_dataset makes Trainer fail: at construction in
    # recent transformers, at the first evaluation step in older ones. The old
    # `evaluation_strategy` keyword was also renamed to `eval_strategy`, and
    # later removed. `--eval_steps` is still accepted but is currently unused.
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=2,
        bf16=use_bf16,
        gradient_checkpointing=True,
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    print("[lora] Iniciant entrenament...")
    trainer.train()

    print("[lora] Guardant adapter LoRA...")
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    print(f"[lora] Entrenament completat. Adapter guardat a {output_dir}")


if __name__ == "__main__":
    main()