Karthik8nitt committed
Commit d468913 · verified · Parent: 2d17d7b

Update training script for synthetic dataset

Files changed (1): train.py (+96 -107)
train.py CHANGED
@@ -1,126 +1,115 @@
 """
 Parametric Floorplan Generation Model Training
-Based on: DStruct2Design (arXiv:2407.15723)
-Approach: Fine-tune an instruction-tuned LLM (Qwen2.5-1.5B-Instruct) with LoRA
-to generate JSON floorplan structures from parametric constraint prompts.
-Dataset: ludolara/DStruct2Design (10k train, 1k val, 1k test)
+Based on: DStruct2Design (arXiv:2407.15723) approach
+Dataset: Custom synthetic dataset generated to match user's ProjectCreate schema
 """
 import os
 import json
 import torch
-from datasets import load_dataset, DatasetDict
+from datasets import load_dataset, load_from_disk, DatasetDict
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import LoraConfig, TaskType
 from trl import SFTTrainer, SFTConfig

-def format_constraints(example):
-    room_count = example["room_count"]
-    total_area = example["total_area"]
-    room_types = example["room_types"]
-    edges = example.get("edges", [])
-    rooms = example.get("rooms", [])
-    lines = [
-        f"Generate a floor plan with {room_count} rooms and a total area of {total_area} square meters.",
-        f"The room types are: {', '.join(room_types)}."
-    ]
-    if rooms:
-        lines.append("Room details:")
-        for i, room in enumerate(rooms):
-            lines.append(f" - Room {i+1} ({room.get('room_type','unknown')}): area ~{room.get('area','unspecified')} m², width ~{room.get('width','unspecified')} m, height ~{room.get('height','unspecified')} m")
-    if edges:
-        lines.append(f"Adjacency requirements (room indices): {edges}")
-    return "\n".join(lines)
-
-def format_floorplan_output(example):
-    return json.dumps({
-        "rooms": [{"room_type": r["room_type"], "area": r["area"], "width": r["width"],
-                   "height": r["height"], "floor_polygon": r["floor_polygon"],
-                   "is_regular": r.get("is_regular", 0)} for r in example["rooms"]],
-        "edges": example.get("edges", []),
-        "room_count": example["room_count"],
-        "total_area": example["total_area"],
-        "room_types": example["room_types"],
-    }, indent=2)
-
-def convert(example):
-    return {"prompt": format_constraints(example), "completion": format_floorplan_output(example)}
-
-def main():
-    model_id = "Qwen/Qwen2.5-1.5B-Instruct"
-    hub_model_id = os.environ.get("HF_TRAINER_HUB_MODEL_ID", "Karthik8nitt/parametric-floorplan-generator")
-    output_dir = "/app/floorplan-model"
-
-    print("Loading DStruct2Design dataset...")
-    dataset = load_dataset("ludolara/DStruct2Design")
-    processed = {split: dataset[split].map(convert, remove_columns=dataset[split].column_names) for split in dataset.keys()}
-
-    print("Loading tokenizer...")
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    print("Loading model...")
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-        trust_remote_code=True,
-    )
-
-    peft_config = LoraConfig(
-        r=16,
-        lora_alpha=32,
-        lora_dropout=0.05,
-        bias="none",
-        task_type=TaskType.CAUSAL_LM,
-        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
-    )
-
-    training_args = SFTConfig(
-        output_dir=output_dir,
-        num_train_epochs=5,
-        per_device_train_batch_size=4,
-        per_device_eval_batch_size=4,
-        gradient_accumulation_steps=4,
-        learning_rate=1e-4,
-        lr_scheduler_type="cosine",
-        warmup_ratio=0.1,
-        logging_steps=10,
-        eval_strategy="steps",
-        eval_steps=100,
-        save_strategy="steps",
-        save_steps=100,
-        save_total_limit=3,
-        max_seq_length=2048,
-        bf16=True,
-        gradient_checkpointing=True,
-        report_to="trackio",
-        run_name="floorplan-qwen1.5b-lora",
-        project="parametric-floorplan",
-        hub_model_id=hub_model_id,
-        push_to_hub=True,
-        completion_only_loss=True,
-        disable_tqdm=True,
-        logging_first_step=True,
-        seed=42,
-    )
-
-    trainer = SFTTrainer(
-        model=model,
-        args=training_args,
-        train_dataset=processed["train"],
-        eval_dataset=processed["validation"],
-        peft_config=peft_config,
-        processing_class=tokenizer,
-    )
-
-    print("Starting training...")
-    trainer.train()
-
-    print("Saving and pushing model...")
-    trainer.save_model(os.path.join(output_dir, "final"))
-    trainer.push_to_hub()
-    print(f"Done! Model at https://huggingface.co/{hub_model_id}")
-
-if __name__ == "__main__":
-    main()
+# -----------------------------------------------------------------------------
+# Configuration
+# -----------------------------------------------------------------------------
+MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
+OUTPUT_DIR = "/app/floorplan-model"
+HUB_MODEL_ID = os.environ.get("HF_TRAINER_HUB_MODEL_ID", "Karthik8nitt/parametric-floorplan-generator")
+DATASET_PATH = os.environ.get("DATASET_PATH", "/app/floorplan_synthetic_dataset")
+
+# -----------------------------------------------------------------------------
+# Load data
+# -----------------------------------------------------------------------------
+print("Loading dataset...")
+if os.path.exists(DATASET_PATH):
+    dataset = load_from_disk(DATASET_PATH)
+else:
+    # Fallback: load from HF if pre-uploaded
+    dataset = load_dataset("Karthik8nitt/floorplan-synthetic-dataset")
+
+print(f"Train: {len(dataset['train'])}, Val: {len(dataset['validation'])}, Test: {len(dataset['test'])}")
+
+# -----------------------------------------------------------------------------
+# Load tokenizer & model
+# -----------------------------------------------------------------------------
+print("Loading tokenizer...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+print("Loading model...")
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    trust_remote_code=True,
+)
+
+# -----------------------------------------------------------------------------
+# LoRA config
+# -----------------------------------------------------------------------------
+peft_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    bias="none",
+    task_type=TaskType.CAUSAL_LM,
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+)
+
+# -----------------------------------------------------------------------------
+# Training arguments
+# -----------------------------------------------------------------------------
+training_args = SFTConfig(
+    output_dir=OUTPUT_DIR,
+    num_train_epochs=5,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
+    gradient_accumulation_steps=4,
+    learning_rate=1e-4,
+    lr_scheduler_type="cosine",
+    warmup_ratio=0.1,
+    logging_steps=10,
+    eval_strategy="steps",
+    eval_steps=100,
+    save_strategy="steps",
+    save_steps=100,
+    save_total_limit=3,
+    max_seq_length=4096,
+    bf16=True,
+    gradient_checkpointing=True,
+    report_to="trackio",
+    run_name="floorplan-qwen1.5b-lora",
+    project="parametric-floorplan",
+    hub_model_id=HUB_MODEL_ID,
+    push_to_hub=True,
+    completion_only_loss=True,
+    disable_tqdm=True,
+    logging_first_step=True,
+    seed=42,
+)
+
+# -----------------------------------------------------------------------------
+# Trainer
+# -----------------------------------------------------------------------------
+trainer = SFTTrainer(
+    model=model,
+    args=training_args,
+    train_dataset=dataset["train"],
+    eval_dataset=dataset["validation"],
+    peft_config=peft_config,
+    processing_class=tokenizer,
+)
+
+# -----------------------------------------------------------------------------
+# Train
+# -----------------------------------------------------------------------------
+print("Starting training...")
+trainer.train()
+
+print("Saving and pushing model...")
+trainer.save_model(os.path.join(OUTPUT_DIR, "final"))
+trainer.push_to_hub()
+print(f"Done! Model at https://huggingface.co/{HUB_MODEL_ID}")