File size: 6,173 Bytes
6b55b75
 
fb95e51
6b55b75
 
 
 
 
 
fb95e51
 
6b55b75
fb95e51
 
 
 
 
6b55b75
fb95e51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b55b75
 
fb95e51
 
6b55b75
 
fb95e51
 
 
 
6b55b75
 
 
 
fb95e51
 
 
 
 
 
 
 
6b55b75
fb95e51
6b55b75
 
fb95e51
6b55b75
fb95e51
6b55b75
 
 
 
 
 
 
 
 
 
fb95e51
 
 
6b55b75
 
fb95e51
6b55b75
fb95e51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b55b75
 
 
 
fb95e51
 
6b55b75
 
fb95e51
6b55b75
 
fb95e51
6b55b75
fb95e51
 
 
 
 
 
 
 
 
6b55b75
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
Inference script for parametric floorplan generation.
Generates a JSON floorplan from parametric constraints using a fine-tuned model.

Command-line arguments describe the plot (dimensions, shape, setbacks, road side,
north direction), the room requirements (bedrooms, toilets, optional pooja/study/
balcony/parking/stilt/basement, number of floors, Vastu), and the location. They are
rendered into a natural-language prompt. A causal LM is loaded through Hugging Face
``transformers`` and asked to produce the floorplan. The raw completion is printed,
followed by a short summary if the completion parses as JSON.

Example:
    python <script>.py --plot_length 15 --plot_width 12 --num_bedrooms 3 --has_pooja
"""
import json
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def build_prompt(params: dict) -> str:
    """Render plot parameters as the newline-separated user prompt for the model.

    ``params`` is a ProjectCreate-like mapping. The keys ``plot_length``,
    ``plot_width``, ``setback_front``, ``setback_rear``, ``setback_left``,
    ``setback_right``, ``road_side``, ``num_bedrooms`` and ``toilets`` are
    required; a missing one raises ``KeyError``. Every other key is optional
    and falls back to a default, or is simply omitted when falsy.
    """
    name = params.get("name", "Project")
    length = params["plot_length"]
    width = params["plot_width"]
    shape = params.get("plot_shape", "rectangular")
    front = params["setback_front"]
    rear = params["setback_rear"]
    left = params["setback_left"]
    right = params["setback_right"]
    road = params["road_side"]
    north = params.get("north_direction", "N")
    bedrooms = params["num_bedrooms"]
    toilets = params["toilets"]

    # NOTE(review): the period sits inside the closing quote ("'name.'"). It is kept
    # verbatim because the fine-tuned model was presumably trained on prompts in
    # exactly this form. Confirm against the training data before "fixing" it.
    sentences = [
        f"Generate a floor plan for project '{name}.'",
        f"Plot dimensions: {length}m x {width}m, shape: {shape}.",
        f"Setbacks: front={front}m, rear={rear}m, left={left}m, right={right}m.",
        f"Road side: {road}, North direction: {north}.",
        f"Requirements: {bedrooms} bedrooms, {toilets} toilets.",
    ]

    # Each truthy flag contributes its sentence, in this fixed order.
    optional_rooms = (
        ("parking", "Parking is required."),
        ("has_pooja", "Include a Pooja room."),
        ("has_study", "Include a Study room."),
        ("has_balcony", "Include a Balcony."),
        ("has_stilt", "Stilt parking required."),
        ("has_basement", "Include a basement."),
    )
    sentences.extend(text for flag, text in optional_rooms if params.get(flag))

    floors = params.get("num_floors", 1)
    sentences.append(f"Number of floors: {floors} (1=G, 2=G+1, 3=G+2).")

    if params.get("vastu_enabled"):
        sentences.append("Vastu compliance is enabled.")

    city = params.get("city", "other")
    # Any falsy municipality (None, "") is rendered as N/A.
    municipality = params.get("municipality") or "N/A"
    sentences.append(f"City: {city}, Municipality: {municipality}.")

    return "\n".join(sentences)

def generate_floorplan(model_id: str, prompt: str, max_new_tokens: int = 2048,
                       temperature: float = 0.7, top_p: float = 0.9):
    """Load a causal LM and generate a floorplan completion for ``prompt``.

    The tokenizer and model are loaded on every call (bf16 weights,
    ``device_map="auto"``, remote code trusted). The prompt is wrapped in a chat
    with a fixed system message, and only the newly generated text is returned.

    Args:
        model_id: Model identifier or local path passed to ``from_pretrained``.
        prompt: User message text, e.g. the output of ``build_prompt``.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A value <= 0 selects greedy decoding,
            because sampling requires a strictly positive temperature.
        top_p: Nucleus-sampling probability mass; only used when sampling.

    Returns:
        The decoded completion. Prompt tokens are excluded and special tokens
        are skipped.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        # Reuse EOS as padding so generate() receives a valid pad_token_id.
        tokenizer.pad_token = tokenizer.eos_token

    system_msg = (
        "You are a parametric floorplan generator for Indian residential construction. "
        "Given plot dimensions, setbacks, road direction, number of bedrooms/toilets, "
        "and optional rooms (pooja, study, balcony, parking, basement, stilt), "
        "output a valid JSON floorplan with plot boundary, buildable boundary, rooms as polygons "
        "with dimensions and positions, doors, windows, and area summaries."
    )

    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": prompt},
    ]

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered chat text may already contain the model's special tokens
    # (e.g. BOS). Tokenizing it with the default add_special_tokens=True would
    # insert them a second time.
    # NOTE(review): with the usual right-side truncation, a prompt longer than
    # 4096 tokens would lose its trailing generation prompt. Prompts from
    # build_prompt should be well below that limit.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=4096,
        add_special_tokens=False,
    ).to(model.device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Sampling with a non-positive temperature is an error; decode greedily instead.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + completion; keep only the completion.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

def _strip_code_fence(text: str) -> str:
    """Return ``text`` stripped of whitespace and of a surrounding Markdown code fence, if any."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line; it may carry a language tag such as "json".
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def _summary_lines(data) -> list:
    """Build the summary lines for a parsed floorplan; raises KeyError/TypeError on missing fields."""
    return [
        f"Project: {data['project_name']}",
        f"Plot shape: {data['plot']['shape']}",
        f"Rooms: {len(data['rooms'])}",
        f"Doors: {len(data['doors'])}",
        f"Windows: {len(data['windows'])}",
        f"Total built-up area: {data['dimensions']['total_built_up_area_sqm']} m²",
        f"Total carpet area: {data['dimensions']['total_carpet_area_sqm']} m²",
    ]


def main():
    """Parse CLI arguments, generate a floorplan, and print it with a short summary.

    The raw model output is always printed. A summary follows only when the
    output (optionally wrapped in a Markdown code fence) parses as JSON and
    contains the expected fields. Otherwise a warning names the actual problem.
    """
    parser = argparse.ArgumentParser(description="Generate a floorplan from parametric input")
    parser.add_argument("--model_id", type=str, default="Karthik8nitt/parametric-floorplan-generator")
    parser.add_argument("--name", type=str, default="MyHouse")
    parser.add_argument("--plot_length", type=float, default=15.0)
    parser.add_argument("--plot_width", type=float, default=12.0)
    parser.add_argument("--setback_front", type=float, default=1.5)
    parser.add_argument("--setback_rear", type=float, default=1.0)
    parser.add_argument("--setback_left", type=float, default=1.0)
    parser.add_argument("--setback_right", type=float, default=1.0)
    parser.add_argument("--road_side", type=str, default="N", choices=["N","S","E","W"])
    parser.add_argument("--north_direction", type=str, default="N", choices=["N","S","E","W"])
    parser.add_argument("--num_bedrooms", type=int, default=3)
    parser.add_argument("--toilets", type=int, default=3)
    parser.add_argument("--parking", action="store_true")
    parser.add_argument("--has_pooja", action="store_true")
    parser.add_argument("--has_study", action="store_true")
    parser.add_argument("--has_balcony", action="store_true")
    parser.add_argument("--has_stilt", action="store_true")
    parser.add_argument("--has_basement", action="store_true")
    parser.add_argument("--num_floors", type=int, default=1, choices=[1,2,3])
    parser.add_argument("--vastu_enabled", action="store_true")
    parser.add_argument("--city", type=str, default="Delhi")
    parser.add_argument("--municipality", type=str, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=2048)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.9)
    args = parser.parse_args()

    # The namespace also carries generation settings (model_id, top_p, ...);
    # build_prompt only reads the keys it needs.
    params = vars(args)
    prompt = build_prompt(params)
    print("Prompt:\n", prompt)
    print("\n--- Generating floorplan ---\n")

    result = generate_floorplan(args.model_id, prompt, args.max_new_tokens, args.temperature, args.top_p)
    print(result)

    try:
        data = json.loads(_strip_code_fence(result))
    except json.JSONDecodeError as e:
        print(f"\nWarning: could not parse as JSON: {e}")
        return

    # Valid JSON can still lack fields (KeyError) or have the wrong shape,
    # e.g. a list at top level (TypeError). Validate before printing any summary.
    try:
        summary = _summary_lines(data)
    except (KeyError, TypeError) as e:
        print(f"\nWarning: parsed JSON is missing expected fields: {e!r}")
        return

    print("\n--- Parsed JSON (summary) ---")
    for line in summary:
        print(line)

# Run the CLI only when this file is executed as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()