Karthik8nitt commited on
Commit
6b55b75
·
verified ·
1 Parent(s): b43faef

Add inference script

Browse files
Files changed (1) hide show
  1. generate.py +74 -0
generate.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for parametric floorplan generation.
3
+ Given parametric constraints, generates a JSON floorplan using a fine-tuned model.
4
+ Usage:
5
+ python generate.py --room_count 4 --total_area 100 --room_types Bedroom Bathroom Kitchen LivingRoom
6
+ """
7
+ import json
8
+ import argparse
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+
12
+ def build_prompt(room_count, total_area, room_types, room_details=None, edges=None):
13
+ lines = [
14
+ f"Generate a floor plan with {room_count} rooms and a total area of {total_area} square meters.",
15
+ f"The room types are: {', '.join(room_types)}."
16
+ ]
17
+ if room_details:
18
+ lines.append("Room details:")
19
+ for i, rd in enumerate(room_details):
20
+ lines.append(f" - Room {i+1} ({rd.get('room_type','unknown')}): area ~{rd.get('area','unspecified')} m², width ~{rd.get('width','unspecified')} m, height ~{rd.get('height','unspecified')} m")
21
+ if edges:
22
+ lines.append(f"Adjacency requirements (room indices): {edges}")
23
+ return "\n".join(lines)
24
+
25
+ def generate_floorplan(model_id, prompt, max_new_tokens=1024, temperature=0.7, top_p=0.9):
26
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True,
29
+ )
30
+ if tokenizer.pad_token is None:
31
+ tokenizer.pad_token = tokenizer.eos_token
32
+
33
+ messages = [
34
+ {"role": "system", "content": "You are a parametric floorplan generator. Given constraints about room count, area, room types, and adjacencies, output a valid JSON floorplan with room polygons, areas, and adjacency edges."},
35
+ {"role": "user", "content": prompt},
36
+ ]
37
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
38
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
39
+
40
+ with torch.no_grad():
41
+ outputs = model.generate(
42
+ **inputs,
43
+ max_new_tokens=max_new_tokens,
44
+ temperature=temperature,
45
+ top_p=top_p,
46
+ do_sample=True,
47
+ pad_token_id=tokenizer.pad_token_id,
48
+ )
49
+ return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
50
+
51
+ def main():
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument("--model_id", type=str, default="Karthik8nitt/parametric-floorplan-generator")
54
+ parser.add_argument("--room_count", type=int, default=4)
55
+ parser.add_argument("--total_area", type=float, default=100.0)
56
+ parser.add_argument("--room_types", nargs="+", default=["Bedroom", "Bathroom", "Kitchen", "LivingRoom"])
57
+ parser.add_argument("--max_new_tokens", type=int, default=1024)
58
+ parser.add_argument("--temperature", type=float, default=0.7)
59
+ parser.add_argument("--top_p", type=float, default=0.9)
60
+ args = parser.parse_args()
61
+
62
+ prompt = build_prompt(args.room_count, args.total_area, args.room_types)
63
+ print("Prompt:\n", prompt)
64
+ print("\n--- Generating floorplan ---\n")
65
+ result = generate_floorplan(args.model_id, prompt, args.max_new_tokens, args.temperature, args.top_p)
66
+ print(result)
67
+ try:
68
+ print("\n--- Parsed JSON ---")
69
+ print(json.dumps(json.loads(result), indent=2))
70
+ except Exception as e:
71
+ print(f"\nWarning: could not parse as JSON: {e}")
72
+
73
+ if __name__ == "__main__":
74
+ main()