Karthik8nitt committed
Commit b43faef · verified · 1 Parent(s): adc95b4

Add training script

Files changed (1)
  1. train.py +126 -0
train.py ADDED
@@ -0,0 +1,126 @@
+ """
+ Parametric Floorplan Generation Model Training
+ Based on: DStruct2Design (arXiv:2407.15723)
+ Approach: Fine-tune an instruction-tuned LLM (Qwen2.5-1.5B-Instruct) with LoRA
+ to generate JSON floorplan structures from parametric constraint prompts.
+ Dataset: ludolara/DStruct2Design (10k train, 1k val, 1k test)
+ """
+ import os
+ import json
+ import torch
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import LoraConfig, TaskType
+ from trl import SFTTrainer, SFTConfig
+
+ def format_constraints(example):
+     room_count = example["room_count"]
+     total_area = example["total_area"]
+     room_types = example["room_types"]
+     edges = example.get("edges", [])
+     rooms = example.get("rooms", [])
+     lines = [
+         f"Generate a floor plan with {room_count} rooms and a total area of {total_area} square meters.",
+         f"The room types are: {', '.join(room_types)}."
+     ]
+     if rooms:
+         lines.append("Room details:")
+         for i, room in enumerate(rooms):
+             lines.append(f"  - Room {i+1} ({room.get('room_type','unknown')}): area ~{room.get('area','unspecified')} m², width ~{room.get('width','unspecified')} m, height ~{room.get('height','unspecified')} m")
+     if edges:
+         lines.append(f"Adjacency requirements (room indices): {edges}")
+     return "\n".join(lines)
+
+ def format_floorplan_output(example):
+     return json.dumps({
+         "rooms": [{"room_type": r["room_type"], "area": r["area"], "width": r["width"],
+                    "height": r["height"], "floor_polygon": r["floor_polygon"],
+                    "is_regular": r.get("is_regular", 0)} for r in example["rooms"]],
+         "edges": example.get("edges", []),
+         "room_count": example["room_count"],
+         "total_area": example["total_area"],
+         "room_types": example["room_types"],
+     }, indent=2)
+
+ def convert(example):
+     return {"prompt": format_constraints(example), "completion": format_floorplan_output(example)}
+
+ def main():
+     model_id = "Qwen/Qwen2.5-1.5B-Instruct"
+     hub_model_id = os.environ.get("HF_TRAINER_HUB_MODEL_ID", "Karthik8nitt/parametric-floorplan-generator")
+     output_dir = "/app/floorplan-model"
+
+     print("Loading DStruct2Design dataset...")
+     dataset = load_dataset("ludolara/DStruct2Design")
+     processed = {split: dataset[split].map(convert, remove_columns=dataset[split].column_names) for split in dataset.keys()}
+
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     print("Loading model...")
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+         trust_remote_code=True,
+     )
+
+     peft_config = LoraConfig(
+         r=16,
+         lora_alpha=32,
+         lora_dropout=0.05,
+         bias="none",
+         task_type=TaskType.CAUSAL_LM,
+         target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+     )
+
+     training_args = SFTConfig(
+         output_dir=output_dir,
+         num_train_epochs=5,
+         per_device_train_batch_size=4,
+         per_device_eval_batch_size=4,
+         gradient_accumulation_steps=4,
+         learning_rate=1e-4,
+         lr_scheduler_type="cosine",
+         warmup_ratio=0.1,
+         logging_steps=10,
+         eval_strategy="steps",
+         eval_steps=100,
+         save_strategy="steps",
+         save_steps=100,
+         save_total_limit=3,
+         max_seq_length=2048,
+         bf16=True,
+         gradient_checkpointing=True,
+         report_to="trackio",
+         run_name="floorplan-qwen1.5b-lora",
+         project="parametric-floorplan",
+         hub_model_id=hub_model_id,
+         push_to_hub=True,
+         completion_only_loss=True,
+         disable_tqdm=True,
+         logging_first_step=True,
+         seed=42,
+     )
+
+     trainer = SFTTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=processed["train"],
+         eval_dataset=processed["validation"],
+         peft_config=peft_config,
+         processing_class=tokenizer,
+     )
+
+     print("Starting training...")
+     trainer.train()
+
+     print("Saving and pushing model...")
+     trainer.save_model(os.path.join(output_dir, "final"))
+     trainer.push_to_hub()
+     print(f"Done! Model at https://huggingface.co/{hub_model_id}")
+
+ if __name__ == "__main__":
+     main()
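
For orientation, this is roughly what one converted training pair looks like. A minimal sketch, assuming train.py is importable from the working directory; the record below is invented for illustration (field values and polygon coordinates are made up), but its keys follow what convert() reads from the dataset rows:

# Hypothetical record shaped like a ludolara/DStruct2Design row; values are invented.
from train import convert

example = {
    "room_count": 2,
    "total_area": 34.5,
    "room_types": ["bedroom", "bathroom"],
    "edges": [[0, 1]],
    "rooms": [
        {"room_type": "bedroom", "area": 25.0, "width": 5.0, "height": 5.0,
         "floor_polygon": [[0, 0], [5, 0], [5, 5], [0, 5]], "is_regular": 1},
        {"room_type": "bathroom", "area": 9.5, "width": 2.5, "height": 3.8,
         "floor_polygon": [[5, 0], [7.5, 0], [7.5, 3.8], [5, 3.8]], "is_regular": 1},
    ],
}

pair = convert(example)
print(pair["prompt"])      # the natural-language constraint prompt
print(pair["completion"])  # the pretty-printed JSON floorplan target

Because the script feeds these prompt/completion columns to SFTTrainer with completion_only_loss=True, the loss is computed only on the JSON completion, not on the constraint prompt.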
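And a minimal inference sketch for the adapter the script pushes to the Hub. This is untested and rests on assumptions: the repo id is the script's default hub_model_id, the tokenizer was pushed alongside the adapter, and the room types in the prompt are invented. Prompts should mirror the format produced by format_constraints() above.

# Load the pushed LoRA adapter (repo id assumed from train.py's default) and generate.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_id = "Karthik8nitt/parametric-floorplan-generator"
model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_id, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(adapter_id)

# Plain-text prompt in the training format (room types here are invented).
prompt = (
    "Generate a floor plan with 3 rooms and a total area of 80 square meters.\n"
    "The room types are: living_room, bedroom, bathroom."
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
# Decode only the newly generated tokens, i.e. the JSON floorplan.
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))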