File size: 6,173 Bytes
"""
Inference script for parametric floorplan generation.
Generates a JSON floorplan from parametric plot constraints using a fine-tuned model.
"""
import json
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def build_prompt(params: dict) -> str:
    """Render ProjectCreate-style parameters as a multi-line text prompt.

    Required keys (a missing one raises ``KeyError``): ``plot_length``,
    ``plot_width``, ``setback_front``, ``setback_rear``, ``setback_left``,
    ``setback_right``, ``road_side``, ``num_bedrooms``, ``toilets``.

    Optional keys and their fallbacks: ``name`` ('Project'), ``plot_shape``
    ('rectangular'), ``north_direction`` ('N'), ``num_floors`` (1),
    ``city`` ('other'), ``municipality`` ('N/A' when missing or falsy).
    The flags ``parking``, ``has_pooja``, ``has_study``, ``has_balcony``,
    ``has_stilt``, ``has_basement`` and ``vastu_enabled`` each add one
    sentence when truthy.

    Returns:
        The prompt sentences joined with newlines.
    """
    get = params.get

    # NOTE(review): the closing period sits inside the quotes
    # ("project 'X.'"). Kept verbatim because the fine-tuned model may have
    # been trained on this exact wording -- confirm before changing it.
    project = get('name', 'Project')
    sentences = [f"Generate a floor plan for project '{project}.'"]

    sentences.append(
        f"Plot dimensions: {params['plot_length']}m x {params['plot_width']}m, "
        f"shape: {get('plot_shape', 'rectangular')}."
    )

    front, rear = params['setback_front'], params['setback_rear']
    left, right = params['setback_left'], params['setback_right']
    sentences.append(
        f"Setbacks: front={front}m, rear={rear}m, left={left}m, right={right}m."
    )

    sentences.append(
        f"Road side: {params['road_side']}, "
        f"North direction: {get('north_direction', 'N')}."
    )
    sentences.append(
        f"Requirements: {params['num_bedrooms']} bedrooms, {params['toilets']} toilets."
    )

    # Optional rooms/features, emitted in this fixed order when truthy.
    optional_features = (
        ("parking", "Parking is required."),
        ("has_pooja", "Include a Pooja room."),
        ("has_study", "Include a Study room."),
        ("has_balcony", "Include a Balcony."),
        ("has_stilt", "Stilt parking required."),
        ("has_basement", "Include a basement."),
    )
    sentences += [sentence for flag, sentence in optional_features if get(flag)]

    sentences.append(f"Number of floors: {get('num_floors', 1)} (1=G, 2=G+1, 3=G+2).")

    if get("vastu_enabled"):
        sentences.append("Vastu compliance is enabled.")

    city = get("city", "other")
    municipality = get("municipality") or "N/A"
    sentences.append(f"City: {city}, Municipality: {municipality}.")

    return "\n".join(sentences)
def generate_floorplan(model_id: str, prompt: str, max_new_tokens: int = 2048,
                       temperature: float = 0.7, top_p: float = 0.9) -> str:
    """Load *model_id* and generate a floorplan completion for *prompt*.

    The prompt is sent as the user turn of a chat, preceded by a fixed system
    message that describes the expected JSON floorplan output.

    Args:
        model_id: Hugging Face model id or local path passed to
            ``from_pretrained`` (with ``trust_remote_code=True``).
        prompt: User-turn text, e.g. the output of ``build_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding instead of sampling.
        top_p: Nucleus-sampling threshold; ignored for greedy decoding.

    Returns:
        Only the newly generated text (the prompt tokens are sliced off),
        decoded with special tokens removed. It is not guaranteed to be
        valid JSON.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers have no pad token; reuse EOS so that
        # generate() receives a valid pad_token_id.
        tokenizer.pad_token = tokenizer.eos_token

    system_msg = (
        "You are a parametric floorplan generator for Indian residential construction. "
        "Given plot dimensions, setbacks, road direction, number of bedrooms/toilets, "
        "and optional rooms (pooja, study, balcony, parking, basement, stilt), "
        "output a valid JSON floorplan with plot boundary, buildable boundary, rooms as polygons "
        "with dimensions and positions, doors, windows, and area summaries."
    )
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered chat template already contains the model's special tokens
    # (e.g. BOS); letting the tokenizer add them again would duplicate them.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=4096,
        add_special_tokens=False,
    ).to(model.device)

    if temperature > 0:
        decoding = {"do_sample": True, "temperature": temperature, "top_p": top_p}
    else:
        # Sampling requires a strictly positive temperature in transformers;
        # treat 0 (or below) as a request for deterministic greedy decoding.
        decoding = {"do_sample": False}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            **decoding,
        )

    # generate() returns prompt + completion; keep only the completion.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
def _extract_json(text: str):
    """Parse *text* as JSON, tolerating prose or Markdown fences around it.

    The whole string is tried first; if that fails, the span from the first
    ``{`` to the last ``}`` is tried instead.

    Raises:
        json.JSONDecodeError: if neither attempt yields valid JSON.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(text[start:end + 1])


def _print_summary(data) -> None:
    """Print the headline fields of a parsed floorplan.

    Raises:
        KeyError, TypeError: if *data* lacks the expected keys/structure.
    """
    print("\n--- Parsed JSON (summary) ---")
    print(f"Project: {data['project_name']}")
    print(f"Plot shape: {data['plot']['shape']}")
    print(f"Rooms: {len(data['rooms'])}")
    print(f"Doors: {len(data['doors'])}")
    print(f"Windows: {len(data['windows'])}")
    print(f"Total built-up area: {data['dimensions']['total_built_up_area_sqm']} m²")
    print(f"Total carpet area: {data['dimensions']['total_carpet_area_sqm']} m²")


def main():
    """CLI entry point: build a prompt, generate a floorplan, and summarize it."""
    parser = argparse.ArgumentParser(description="Generate a floorplan from parametric input")
    parser.add_argument("--model_id", type=str, default="Karthik8nitt/parametric-floorplan-generator")
    parser.add_argument("--name", type=str, default="MyHouse")
    parser.add_argument("--plot_length", type=float, default=15.0)
    parser.add_argument("--plot_width", type=float, default=12.0)
    parser.add_argument("--setback_front", type=float, default=1.5)
    parser.add_argument("--setback_rear", type=float, default=1.0)
    parser.add_argument("--setback_left", type=float, default=1.0)
    parser.add_argument("--setback_right", type=float, default=1.0)
    parser.add_argument("--road_side", type=str, default="N", choices=["N", "S", "E", "W"])
    parser.add_argument("--north_direction", type=str, default="N", choices=["N", "S", "E", "W"])
    parser.add_argument("--num_bedrooms", type=int, default=3)
    parser.add_argument("--toilets", type=int, default=3)
    parser.add_argument("--parking", action="store_true")
    parser.add_argument("--has_pooja", action="store_true")
    parser.add_argument("--has_study", action="store_true")
    parser.add_argument("--has_balcony", action="store_true")
    parser.add_argument("--has_stilt", action="store_true")
    parser.add_argument("--has_basement", action="store_true")
    parser.add_argument("--num_floors", type=int, default=1, choices=[1, 2, 3])
    parser.add_argument("--vastu_enabled", action="store_true")
    parser.add_argument("--city", type=str, default="Delhi")
    parser.add_argument("--municipality", type=str, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=2048)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.9)
    args = parser.parse_args()

    # build_prompt only reads the keys it needs; extra CLI keys are ignored.
    params = vars(args)
    prompt = build_prompt(params)
    print("Prompt:\n", prompt)
    print("\n--- Generating floorplan ---\n")
    result = generate_floorplan(
        args.model_id,
        prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )
    print(result)

    try:
        data = _extract_json(result)
    except json.JSONDecodeError as e:
        print(f"\nWarning: could not parse as JSON: {e}")
        return

    # Parsing succeeded; report schema mismatches separately so they are not
    # mislabelled as JSON syntax errors.
    try:
        _print_summary(data)
    except (KeyError, TypeError) as e:
        print(f"\nWarning: parsed JSON does not match the expected floorplan schema: {e!r}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|