Karthik8nitt commited on
Commit
8acc209
·
verified ·
1 Parent(s): 2eb3eb4

Add Modal training script

Browse files
Files changed (1) hide show
  1. modal_train.py +429 -0
modal_train.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train the parametric floorplan model on Modal.com GPU instances.
3
+ Usage: modal deploy modal_train.py
4
+ modal run modal_train.py
5
+ """
6
+ import modal
7
+ import os
8
+
9
+ # ---------------------------------------------------------------------------
10
+ # Modal Image & App Setup
11
+ # ---------------------------------------------------------------------------
12
+ image = (
13
+ modal.Image.debian_slim(python_version="3.10")
14
+ .pip_install(
15
+ "transformers>=4.45.0",
16
+ "trl>=0.15.0",
17
+ "torch>=2.3.0",
18
+ "datasets>=2.20.0",
19
+ "peft>=0.12.0",
20
+ "accelerate>=0.33.0",
21
+ "trackio>=0.3.0",
22
+ "huggingface_hub",
23
+ )
24
+ .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
25
+ )
26
+
27
+ app = modal.App("floorplan-trainer", image=image)
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Persistent volume for dataset & model checkpoints
31
+ # ---------------------------------------------------------------------------
32
+ vol = modal.Volume.from_name("floorplan-data", create_if_missing=True)
33
+ MODEL_VOL = modal.Volume.from_name("floorplan-model", create_if_missing=True)
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Secrets
37
+ # ---------------------------------------------------------------------------
38
+ hf_secret = modal.Secret.from_name("huggingface-token")
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Dataset Generation (CPU)
42
+ # ---------------------------------------------------------------------------
43
+ @app.function(volumes={"/data": vol}, secrets=[hf_secret], timeout=1800)
44
+ def generate_dataset_on_modal():
45
+ """Generate the synthetic dataset inside Modal volume."""
46
+ import json, random
47
+ from typing import List, Dict, Any
48
+ from datasets import Dataset, DatasetDict
49
+
50
+ DIRECTIONS = ["N", "S", "E", "W"]
51
+ CITIES = ["Delhi", "Mumbai", "Bangalore", "Chennai", "Hyderabad", "Pune", "Kolkata", "Ahmedabad", "Jaipur", "other"]
52
+ MUNICIPALITIES = ["MC", "MDA", "PMA", "BDA", "GHMC", "BBMP", "MCD", "KMC", "JDA", None]
53
+
54
+ def rect_polygon(x, y, w, d):
55
+ return [[round(v,2) for v in p] for p in [[x,y],[x+w,y],[x+w,y+d],[x,y+d]]]
56
+ def polygon_area(poly):
57
+ n = len(poly); area = 0.0
58
+ for i in range(n):
59
+ x1,y1 = poly[i]; x2,y2 = poly[(i+1)%n]
60
+ area += x1*y2 - x2*y1
61
+ return abs(area)/2.0
62
+ def polygon_bbox(poly):
63
+ xs = [p[0] for p in poly]; ys = [p[1] for p in poly]
64
+ return min(xs), min(ys), max(xs), max(ys)
65
+
66
+ def make_plot_boundary(params):
67
+ shape = params.get("plot_shape","rectangular")
68
+ L,W = params["plot_length"], params["plot_width"]
69
+ if shape == "rectangular": return rect_polygon(0,0,L,W)
70
+ if shape == "l_shaped":
71
+ cw = params.get("cutout_width",L*0.3); ch = params.get("cutout_height",W*0.3)
72
+ corner = params.get("cutout_corner","NE")
73
+ if corner == "NE": return [[0,0],[L,0],[L,W-ch],[L-cw,W-ch],[L-cw,W],[0,W]]
74
+ if corner == "NW": return [[0,0],[L,0],[L,W],[cw,W],[cw,W-ch],[0,W-ch]]
75
+ if corner == "SE": return [[0,0],[L-cw,0],[L-cw,ch],[L,ch],[L,W],[0,W]]
76
+ if corner == "SW": return [[cw,0],[L,0],[L,W],[0,W],[0,ch],[cw,ch]]
77
+ return rect_polygon(0,0,L,W)
78
+ if shape == "trapezoid":
79
+ fw = params.get("plot_front_width",L); rw = params.get("plot_rear_width",L*0.8)
80
+ off = params.get("plot_side_offset",0.0)
81
+ return [[0,0],[fw,0],[fw-off,W],[-off,W]]
82
+ return rect_polygon(0,0,L,W)
83
+
84
+ def make_buildable_boundary(plot_poly, params):
85
+ sf = params.get("setback_front",1.5); sr = params.get("setback_rear",1.0)
86
+ sl = params.get("setback_left",1.0); srt = params.get("setback_right",1.0)
87
+ minx,miny,maxx,maxy = polygon_bbox(plot_poly)
88
+ return rect_polygon(minx+sl,miny+sf,maxx-minx-sl-srt,maxy-miny-sf-sr)
89
+
90
+ def distribute_width(total,n):
91
+ base = total/n; bays = []; remaining = total
92
+ for i in range(n):
93
+ if i == n-1: bays.append(round(remaining,2))
94
+ else:
95
+ bay = max(2.4, round(base,1))
96
+ bay = min(bay, remaining - 2.4*(n-i-1))
97
+ bays.append(bay); remaining -= bay
98
+ return bays
99
+
100
+ def generate_room_specs(params):
101
+ specs = []
102
+ num_bed = params["num_bedrooms"]; num_toi = params["toilets"]
103
+ has_pooja = params.get("has_pooja",False); has_study = params.get("has_study",False)
104
+ has_balc = params.get("has_balcony",False); parking = params.get("parking",False)
105
+ num_floors = params.get("num_floors",1); has_stilt = params.get("has_stilt",False)
106
+ has_basement = params.get("has_basement",False); custom = params.get("custom_room_config",None) or []
107
+ def add(rtype,rid,name,zone,floor="gf",target_area=None):
108
+ specs.append({"id":rid,"type":rtype,"name":name,"zone":zone,"floor":floor,"target_area":target_area})
109
+ add("living","living_1","Living Room","front","gf",18)
110
+ add("kitchen","kitchen_1","Kitchen","mid","gf",9)
111
+ add("dining","dining_1","Dining Area","front","gf",12)
112
+ add("toilet","toilet_common","Common Toilet","back","gf",3.5)
113
+ for i in range(num_bed):
114
+ floor = "gf" if i < num_bed-(num_floors-1) else ("ff" if num_floors>1 else "gf")
115
+ if i==0: add("master_bedroom",f"bedroom_{i+1}","Master Bedroom","back",floor,16)
116
+ else: add("bedroom",f"bedroom_{i+1}",f"Bedroom {i+1}","back",floor,12)
117
+ for i in range(num_toi):
118
+ floor = "gf" if i < num_toi-(num_floors-1) else ("ff" if num_floors>1 else "gf")
119
+ add("toilet",f"toilet_{i+1}",f"Toilet {i+1}","back",floor,3.5)
120
+ if has_pooja: add("pooja","pooja_1","Pooja Room","back","gf",2.5)
121
+ if has_study: floor = "ff" if num_floors>1 else "gf"; add("study","study_1","Study Room","back",floor,8)
122
+ if has_balc: add("balcony","balcony_1","Balcony","side","gf",5)
123
+ add("staircase","stairs_1","Staircase","mid","gf",10)
124
+ if has_stilt or parking: add("parking","parking_1","Parking","side","stilt",15); add("staircase","stairs_stilt","Staircase (Stilt)","mid","stilt",10)
125
+ if has_basement: add("store","store_base","Storage","back","basement",8); add("staircase","stairs_base","Staircase (Basement)","mid","basement",10)
126
+ if num_floors >= 2:
127
+ ff_beds = max(0,num_bed-1)
128
+ for i in range(ff_beds): add("bedroom",f"bedroom_ff_{i+1}",f"Bedroom {num_bed-ff_beds+i+1}","back","ff",12)
129
+ add("living","living_ff","Family Lounge","front","ff",14)
130
+ add("toilet","toilet_ff","Common Toilet (FF)","back","ff",3.5)
131
+ if has_study and num_floors>=2: add("study","study_ff","Study Room","back","ff",8)
132
+ if has_balc: add("balcony","balcony_ff","Balcony (FF)","side","ff",5)
133
+ if num_floors >= 3:
134
+ sf_beds = max(0,num_bed-2)
135
+ for i in range(sf_beds): add("bedroom",f"bedroom_sf_{i+1}",f"Bedroom {num_bed-sf_beds-ff_beds+i+1}","back","sf",12)
136
+ add("living","living_sf","Terrace Lounge","front","sf",12)
137
+ add("toilet","toilet_sf","Common Toilet (SF)","back","sf",3.5)
138
+ if has_balc: add("balcony","balcony_sf","Balcony (SF)","side","sf",5)
139
+ for i,cr in enumerate(custom):
140
+ rtype = cr.get("type","room").lower().replace(" ","_")
141
+ floor_map = {"basement":"basement","stilt":"stilt","gf":"gf","ff":"ff","sf":"sf","either":"gf"}
142
+ floor = floor_map.get(cr.get("floor_preference","either"),"gf")
143
+ add(rtype,f"custom_{i+1}",cr.get("name",f"Custom Room {i+1}"),"mid",floor,cr.get("min_area_sqm",10))
144
+ return specs
145
+
146
+ def place_rooms(buildable_poly, rooms_spec, vastu, road_side, north_dir):
147
+ minx,miny,maxx,maxy = polygon_bbox(buildable_poly)
148
+ bw,bd = maxx-minx,maxy-miny
149
+ num_bays = max(2,min(4,int(bw/3.0)))
150
+ bay_widths = distribute_width(bw, num_bays)
151
+ def place_row(room_list,row_depth,y_start):
152
+ x_cursor = miny
153
+ placed_in_row = []
154
+ for i,room in enumerate(room_list):
155
+ if i >= len(bay_widths): break
156
+ w = bay_widths[i]; d = row_depth
157
+ target = room.get("target_area",w*d)
158
+ if target>0 and w>0:
159
+ adj_d = min(max(target/w,2.4),row_depth)
160
+ d = round(adj_d,2)
161
+ poly = rect_polygon(round(x_cursor,2),round(y_start,2),round(w,2),round(d,2))
162
+ area = polygon_area(poly)
163
+ placed_in_row.append({
164
+ "id":room["id"],"type":room["type"],"name":room["name"],"floor":room.get("floor","gf"),
165
+ "polygon":poly,"area_sqm":round(area,2),
166
+ "dimensions":{"width":round(w,2),"depth":round(d,2)},
167
+ "position":{"x":round(x_cursor+w/2,2),"y":round(y_start+d/2,2)},
168
+ })
169
+ x_cursor += w
170
+ return placed_in_row
171
+ all_rows = []
172
+ front_types = [r for r in rooms_spec if r["type"] in ("living","dining")]
173
+ if front_types: all_rows.append((front_types, bd*0.35))
174
+ mid_types = [r for r in rooms_spec if r["type"] in ("kitchen","utility","staircase","corridor","store")]
175
+ if mid_types: all_rows.append((mid_types, bd*0.3))
176
+ back_types = [r for r in rooms_spec if r["type"] in ("bedroom","master_bedroom","toilet","pooja","study")]
177
+ if back_types: all_rows.append((back_types, bd*0.35))
178
+ y_cursor = miny
179
+ for room_list,row_depth in all_rows:
180
+ yield from place_row(room_list, row_depth, y_cursor)
181
+ y_cursor += row_depth
182
+ side_types = [r for r in rooms_spec if r["type"] in ("balcony","parking")]
183
+ for room in side_types:
184
+ if room["type"] == "balcony" and list(all_rows):
185
+ poly = rect_polygon(minx, maxy, min(bw/num_bays, 2.0), min(bd*0.15, 1.5))
186
+ yield {
187
+ "id": room["id"], "type": "balcony", "name": room.get("name","Balcony"), "floor": room.get("floor","gf"),
188
+ "polygon": [[round(v,2) for v in p] for p in poly], "area_sqm": round(polygon_area(poly),2),
189
+ "dimensions": {"width": round(min(bw/num_bays,2.0),2), "depth": round(min(bd*0.15,1.5),2)},
190
+ "position": {"x": round(minx+min(bw/num_bays,2.0)/2,2), "y": round(maxy+min(bd*0.15,1.5)/2,2)},
191
+ }
192
+ elif room["type"] == "parking":
193
+ poly = rect_polygon(minx, miny, min(bw,7.5), min(bd*0.25,6.0))
194
+ yield {
195
+ "id": room["id"], "type": "parking", "name": room.get("name","Parking"), "floor": room.get("floor","stilt"),
196
+ "polygon": [[round(v,2) for v in p] for p in poly], "area_sqm": round(polygon_area(poly),2),
197
+ "dimensions": {"width": round(min(bw,7.5),2), "depth": round(min(bd*0.25,6.0),2)},
198
+ "position": {"x": round(minx+min(bw,7.5)/2,2), "y": round(miny+min(bd*0.25,6.0)/2,2)},
199
+ }
200
+
201
+ def generate_openings(rooms, road_side):
202
+ doors,windows = [],[]
203
+ entrance = [r for r in rooms if r["type"]=="living" and r["floor"]=="gf"]
204
+ if entrance:
205
+ lr = entrance[0]; mx,my,Mx,My = polygon_bbox(lr["polygon"])
206
+ if road_side in ("N","S"):
207
+ x = round((mx+Mx)/2 - 0.45, 2); y = My if road_side=="N" else my
208
+ doors.append({"id":"door_main","type":"main_entrance","width":0.9,"from":"outside","to":lr["id"],"position":[x,y],"orientation":"horizontal"})
209
+ else:
210
+ x = Mx if road_side=="E" else mx; y = round((my+My)/2 - 0.45, 2)
211
+ doors.append({"id":"door_main","type":"main_entrance","width":0.9,"from":"outside","to":lr["id"],"position":[x,y],"orientation":"vertical"})
212
+ for i,r1 in enumerate(rooms):
213
+ for r2 in rooms[i+1:]:
214
+ if r1["floor"] != r2["floor"]: continue
215
+ m1x,m1y,M1x,M1y = polygon_bbox(r1["polygon"])
216
+ m2x,m2y,M2x,M2y = polygon_bbox(r2["polygon"])
217
+ share_x = not (M1x < m2x or M2x < m1x); share_y = not (M1y < m2y or M2y < m1y)
218
+ if share_x and abs(M1y-m2y) < 0.3:
219
+ x = round(max(m1x,m2x)+0.3,2); y = round(M1y,2)
220
+ doors.append({"id":f"door_{r1['id']}_{r2['id']}","type":"internal","width":0.75,"from":r1["id"],"to":r2["id"],"position":[x,y],"orientation":"horizontal"})
221
+ elif share_y and abs(M1x-m2x) < 0.3:
222
+ x = round(M1x,2); y = round(max(m1y,m2y)+0.3,2)
223
+ doors.append({"id":f"door_{r1['id']}_{r2['id']}","type":"internal","width":0.75,"from":r1["id"],"to":r2["id"],"position":[x,y],"orientation":"vertical"})
224
+ for r in rooms:
225
+ if r["type"] in ("living","bedroom","master_bedroom","dining","kitchen","study"):
226
+ mx,my,Mx,My = polygon_bbox(r["polygon"])
227
+ if (Mx-mx) >= (My-my):
228
+ y = round((my+My)/2,2); cx = (mx+Mx)/2; x = mx if abs(mx-cx) > abs(Mx-cx) else Mx
229
+ windows.append({"id":f"win_{r['id']}","room":r["id"],"width":1.2,"height":1.5,"position":[round(x,2),y],"orientation":"vertical"})
230
+ else:
231
+ x = round((mx+Mx)/2,2); cy = (my+My)/2; y = my if abs(my-cy) > abs(My-cy) else My
232
+ windows.append({"id":f"win_{r['id']}","room":r["id"],"width":1.5,"height":1.2,"position":[x,round(y,2)],"orientation":"horizontal"})
233
+ return doors,windows
234
+
235
+ def generate_example(seed=None):
236
+ if seed is not None: random.seed(seed)
237
+ plot_length = round(random.uniform(8.0,25.0),1)
238
+ plot_width = round(random.uniform(7.0,20.0),1)
239
+ setback_front = round(random.uniform(1.0,3.0),1); setback_rear = round(random.uniform(0.5,2.0),1)
240
+ setback_left = round(random.uniform(0.5,2.0),1); setback_right = round(random.uniform(0.5,2.0),1)
241
+ road_side = random.choice(DIRECTIONS); north_direction = random.choice(DIRECTIONS)
242
+ num_bedrooms = random.randint(1,4); toilets = random.randint(1,num_bedrooms+1)
243
+ parking = random.choice([True,False]); city = random.choice(CITIES)
244
+ vastu_enabled = random.choice([True,False]); road_width_m = round(random.uniform(6.0,18.0),1)
245
+ has_pooja = random.choice([True,False]); has_study = random.choice([True,False])
246
+ has_balcony = random.choice([True,False])
247
+ plot_shape = random.choice(["rectangular"]*8 + ["l_shaped"]*1 + ["trapezoid"]*1)
248
+ plot_front_width = plot_length if plot_shape!="trapezoid" else round(plot_length*random.uniform(0.8,1.0),1)
249
+ plot_rear_width = plot_length if plot_shape!="trapezoid" else round(plot_length*random.uniform(0.7,1.0),1)
250
+ plot_side_offset = 0.0 if plot_shape!="trapezoid" else round(random.uniform(-1.0,1.0),1)
251
+ cutout_corner = random.choice(["NE","NW","SE","SW"])
252
+ cutout_width = round(plot_length*random.uniform(0.15,0.35),1) if plot_shape=="l_shaped" else 0.0
253
+ cutout_height = round(plot_width*random.uniform(0.15,0.35),1) if plot_shape=="l_shaped" else 0.0
254
+ num_floors = random.choices([1,2,3], weights=[5,3,1])[0]
255
+ has_stilt = random.choice([True,False]) if num_floors>1 else False
256
+ has_basement = random.choice([True,False]); municipality = random.choice(MUNICIPALITIES)
257
+ min_buildable = 5.0
258
+ if plot_length - setback_front - setback_rear < min_buildable:
259
+ setback_front = min(setback_front,(plot_length-min_buildable)/2)
260
+ setback_rear = min(setback_rear,(plot_length-min_buildable)/2)
261
+ if plot_width - setback_left - setback_right < min_buildable:
262
+ setback_left = min(setback_left,(plot_width-min_buildable)/2)
263
+ setback_right = min(setback_right,(plot_width-min_buildable)/2)
264
+ params = {
265
+ "name": f"Project_{random.randint(1000,9999)}",
266
+ "plot_length": plot_length, "plot_width": plot_width,
267
+ "setback_front": round(setback_front,1), "setback_rear": round(setback_rear,1),
268
+ "setback_left": round(setback_left,1), "setback_right": round(setback_right,1),
269
+ "road_side": road_side, "north_direction": north_direction,
270
+ "num_bedrooms": num_bedrooms, "toilets": toilets, "parking": parking,
271
+ "city": city, "vastu_enabled": vastu_enabled, "road_width_m": road_width_m,
272
+ "has_pooja": has_pooja, "has_study": has_study, "has_balcony": has_balcony,
273
+ "plot_shape": plot_shape,
274
+ "plot_front_width": plot_front_width if plot_shape=="trapezoid" else None,
275
+ "plot_rear_width": plot_rear_width if plot_shape=="trapezoid" else None,
276
+ "plot_side_offset": plot_side_offset if plot_shape=="trapezoid" else None,
277
+ "plot_corners": None,
278
+ "cutout_corner": cutout_corner, "cutout_width": cutout_width, "cutout_height": cutout_height,
279
+ "num_floors": num_floors, "has_stilt": has_stilt, "has_basement": has_basement,
280
+ "municipality": municipality, "custom_room_config": None, "team_id": None,
281
+ }
282
+ plot_boundary = make_plot_boundary(params)
283
+ buildable_boundary = make_buildable_boundary(plot_boundary, params)
284
+ room_specs = generate_room_specs(params)
285
+ rooms = list(place_rooms(buildable_boundary, room_specs, vastu_enabled, road_side, north_direction))
286
+ doors, windows = generate_openings(rooms, road_side)
287
+ total_area = sum(r["area_sqm"] for r in rooms if r["floor"] not in ("stilt","basement"))
288
+ built_up_area = sum(r["area_sqm"] for r in rooms)
289
+ floorplan = {
290
+ "project_name": params["name"],
291
+ "plot": {
292
+ "shape": plot_shape, "outer_boundary": plot_boundary,
293
+ "setbacks": {"front":params["setback_front"],"rear":params["setback_rear"],"left":params["setback_left"],"right":params["setback_right"]},
294
+ "buildable_boundary": buildable_boundary,
295
+ "road_side": road_side, "north_direction": north_direction,
296
+ "plot_length": plot_length, "plot_width": plot_width,
297
+ },
298
+ "rooms": rooms, "doors": doors, "windows": windows,
299
+ "dimensions": {
300
+ "total_built_up_area_sqm": round(built_up_area,2),
301
+ "total_carpet_area_sqm": round(total_area,2),
302
+ "ground_floor_area_sqm": round(sum(r["area_sqm"] for r in rooms if r["floor"]=="gf"),2),
303
+ "first_floor_area_sqm": round(sum(r["area_sqm"] for r in rooms if r["floor"]=="ff"),2),
304
+ "second_floor_area_sqm": round(sum(r["area_sqm"] for r in rooms if r["floor"]=="sf"),2),
305
+ "stilt_area_sqm": round(sum(r["area_sqm"] for r in rooms if r["floor"]=="stilt"),2),
306
+ "basement_area_sqm": round(sum(r["area_sqm"] for r in rooms if r["floor"]=="basement"),2),
307
+ },
308
+ "meta": {"num_floors": num_floors, "has_stilt": has_stilt, "has_basement": has_basement,
309
+ "vastu_enabled": vastu_enabled, "city": city, "municipality": municipality},
310
+ }
311
+ lines = [
312
+ f"Generate a floor plan for project '{params['name']}'.",
313
+ f"Plot dimensions: {plot_length}m x {plot_width}m, shape: {plot_shape}.",
314
+ f"Setbacks: front={params['setback_front']}m, rear={params['setback_rear']}m, left={params['setback_left']}m, right={params['setback_right']}m.",
315
+ f"Road side: {road_side}, North direction: {north_direction}.",
316
+ f"Requirements: {num_bedrooms} bedrooms, {toilets} toilets.",
317
+ ]
318
+ if parking: lines.append("Parking is required.")
319
+ if has_pooja: lines.append("Include a Pooja room.")
320
+ if has_study: lines.append("Include a Study room.")
321
+ if has_balcony: lines.append("Include a Balcony.")
322
+ if has_stilt: lines.append("Stilt parking required.")
323
+ if has_basement: lines.append("Include a basement.")
324
+ lines.append(f"Number of floors: {num_floors} (1=G, 2=G+1, 3=G+2).")
325
+ if vastu_enabled: lines.append("Vastu compliance is enabled.")
326
+ lines.append(f"City: {city}, Municipality: {municipality or 'N/A'}.")
327
+ prompt = "\n".join(lines)
328
+ return {"prompt": prompt, "completion": json.dumps(floorplan, indent=2), "params": params}
329
+
330
+ ds = DatasetDict({
331
+ "train": Dataset.from_list([generate_example(seed=i) for i in range(5000)]),
332
+ "validation": Dataset.from_list([generate_example(seed=100000+i) for i in range(500)]),
333
+ "test": Dataset.from_list([generate_example(seed=200000+i) for i in range(500)]),
334
+ })
335
+ ds.save_to_disk("/data/floorplan_synthetic_dataset")
336
+ print("Dataset saved to /data/floorplan_synthetic_dataset")
337
+ print(f"Train: {len(ds['train'])}, Val: {len(ds['validation'])}, Test: {len(ds['test'])}")
338
+ return ds
339
+
340
+ # ---------------------------------------------------------------------------
341
+ # Training (GPU)
342
+ # ---------------------------------------------------------------------------
343
+ @app.function(
344
+ gpu="A10G", # or "T4", "A100", "H100"
345
+ volumes={"/data": vol, "/model": MODEL_VOL},
346
+ secrets=[hf_secret],
347
+ timeout=3600*4, # 4 hours
348
+ )
349
+ def train_model():
350
+ import os, json, torch
351
+ from datasets import load_from_disk
352
+ from transformers import AutoModelForCausalLM, AutoTokenizer
353
+ from peft import LoraConfig, TaskType
354
+ from trl import SFTTrainer, SFTConfig
355
+
356
+ MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
357
+ HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "Karthik8nitt/parametric-floorplan-generator")
358
+ OUTPUT_DIR = "/model/floorplan-model"
359
+
360
+ print("Loading dataset from /data...")
361
+ dataset = load_from_disk("/data/floorplan_synthetic_dataset")
362
+ print(f"Loaded: {len(dataset['train'])} train, {len(dataset['validation'])} val")
363
+
364
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
365
+ if tokenizer.pad_token is None:
366
+ tokenizer.pad_token = tokenizer.eos_token
367
+
368
+ print("Loading model...")
369
+ model = AutoModelForCausalLM.from_pretrained(
370
+ MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True,
371
+ )
372
+
373
+ peft_config = LoraConfig(
374
+ r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
375
+ task_type=TaskType.CAUSAL_LM,
376
+ target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
377
+ )
378
+
379
+ training_args = SFTConfig(
380
+ output_dir=OUTPUT_DIR,
381
+ num_train_epochs=5,
382
+ per_device_train_batch_size=4,
383
+ per_device_eval_batch_size=4,
384
+ gradient_accumulation_steps=4,
385
+ learning_rate=1e-4,
386
+ lr_scheduler_type="cosine",
387
+ warmup_ratio=0.1,
388
+ logging_steps=10,
389
+ eval_strategy="steps",
390
+ eval_steps=100,
391
+ save_strategy="steps",
392
+ save_steps=100,
393
+ save_total_limit=3,
394
+ max_seq_length=4096,
395
+ bf16=True,
396
+ gradient_checkpointing=True,
397
+ report_to="none",
398
+ hub_model_id=HUB_MODEL_ID,
399
+ push_to_hub=True,
400
+ completion_only_loss=True,
401
+ disable_tqdm=True,
402
+ logging_first_step=True,
403
+ seed=42,
404
+ )
405
+
406
+ trainer = SFTTrainer(
407
+ model=model, args=training_args,
408
+ train_dataset=dataset["train"],
409
+ eval_dataset=dataset["validation"],
410
+ peft_config=peft_config,
411
+ processing_class=tokenizer,
412
+ )
413
+
414
+ print("Starting training...")
415
+ trainer.train()
416
+ trainer.save_model(os.path.join(OUTPUT_DIR, "final"))
417
+ trainer.push_to_hub()
418
+ print(f"Done! Model pushed to https://huggingface.co/{HUB_MODEL_ID}")
419
+
420
+ # ---------------------------------------------------------------------------
421
+ # Local entrypoint
422
+ # ---------------------------------------------------------------------------
423
+ @app.local_entrypoint()
424
+ def main():
425
+ print("Step 1: Generate dataset...")
426
+ generate_dataset_on_modal.remote()
427
+ print("Step 2: Train model on GPU...")
428
+ train_model.remote()
429
+ print("All done!")