"""
Parametric Floorplan Generation Model Training
Based on: DStruct2Design (arXiv:2407.15723) approach
Dataset: Custom synthetic dataset generated to match user's ProjectCreate schema
"""
import os
import json
import torch
from datasets import load_dataset, load_from_disk, DatasetDict
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType
from trl import SFTTrainer, SFTConfig
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Base checkpoint that is loaded for both the tokenizer and the model.
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
# Local directory used for trainer output and the final saved model.
OUTPUT_DIR = "/app/floorplan-model"
# Hub repository and dataset location; both can be overridden through the
# environment, otherwise the defaults below are used.
HUB_MODEL_ID = os.getenv(
    "HF_TRAINER_HUB_MODEL_ID",
    "Karthik8nitt/parametric-floorplan-generator",
)
DATASET_PATH = os.getenv("DATASET_PATH", "/app/floorplan_synthetic_dataset")
# -----------------------------------------------------------------------------
# Load data
# -----------------------------------------------------------------------------
print("Loading dataset...")
# Use the on-disk copy when DATASET_PATH exists; otherwise fall back to the
# copy pre-uploaded to the Hugging Face Hub.
dataset = (
    load_from_disk(DATASET_PATH)
    if os.path.exists(DATASET_PATH)
    else load_dataset("Karthik8nitt/floorplan-synthetic-dataset")
)
# The summary indexes the train, validation and test splits, so all three
# are expected to be present in the loaded dataset.
print(
    f"Train: {len(dataset['train'])}, "
    f"Val: {len(dataset['validation'])}, "
    f"Test: {len(dataset['test'])}"
)
# -----------------------------------------------------------------------------
# Load tokenizer & model
# -----------------------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Reuse the end-of-sequence token for padding when the tokenizer does not
# define a dedicated pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading model...")
# Weights are loaded in bfloat16, matching bf16=True in the training
# arguments; device placement is delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# -----------------------------------------------------------------------------
# LoRA config
# -----------------------------------------------------------------------------
# LoRA adapter settings for causal-LM fine-tuning: rank 16 with alpha 32,
# 5% dropout on the adapter path, and no bias parameters trained.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # Module names that receive adapters (presumably the attention and MLP
    # projections of the base model -- confirm against the checkpoint).
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
# -----------------------------------------------------------------------------
# Training arguments
# -----------------------------------------------------------------------------
# Supervised fine-tuning hyperparameters.
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    seed=42,
    # Optimisation schedule: 5 epochs, cosine decay with 10% warmup.
    num_train_epochs=5,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # 4 samples per device x 4 accumulation steps = 16 samples per
    # optimizer step on each device.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    # `max_length` is the current SFTConfig name for the sequence-length cap;
    # the former `max_seq_length` spelling was deprecated and then removed
    # from TRL, where passing it raises TypeError.
    max_length=4096,
    bf16=True,
    gradient_checkpointing=True,
    # NOTE(review): completion-only loss presumably needs the dataset to
    # separate prompt from completion -- confirm against the dataset schema.
    completion_only_loss=True,
    # Evaluate and checkpoint on the same 100-step cadence, keeping at most
    # three checkpoints on disk.
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    # Logging / experiment tracking.
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="trackio",
    run_name="floorplan-qwen1.5b-lora",
    project="parametric-floorplan",
    # Hub upload target.
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
)
# -----------------------------------------------------------------------------
# Trainer
# -----------------------------------------------------------------------------
# Wire the base model, LoRA configuration, tokenizer and the train/validation
# splits into the SFT trainer.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    peft_config=peft_config,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)
# -----------------------------------------------------------------------------
# Train
# -----------------------------------------------------------------------------
print("Starting training...")
trainer.train()

print("Saving and pushing model...")
# The final model is written to a "final" subdirectory of OUTPUT_DIR before
# the upload to the Hub repository is triggered.
final_model_dir = os.path.join(OUTPUT_DIR, "final")
trainer.save_model(final_model_dir)
trainer.push_to_hub()
print(f"Done! Model at https://huggingface.co/{HUB_MODEL_ID}")
|