devsheroubi commited on
Commit
d9cda46
·
verified ·
1 Parent(s): eb00bc1

Upload meshify.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. meshify.py +76 -0
meshify.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ import cv2 as cv
7
+
8
+ # Add TripoSR to path
9
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "TripoSR"))
10
+
11
+ from tsr.system import TSR
12
+ from tsr.utils import resize_foreground
13
+
14
+ class TripoMeshifier:
15
+ def __init__(self, device="cuda:0"):
16
+ self.device = device
17
+ if not torch.cuda.is_available():
18
+ self.device = "cpu"
19
+
20
+ print(f"Initializing TripoSR on {self.device}...")
21
+ self.model = TSR.from_pretrained(
22
+ "stabilityai/TripoSR",
23
+ config_name="config.yaml",
24
+ weight_name="model.ckpt",
25
+ )
26
+ self.model.renderer.set_chunk_size(8192)
27
+ self.model.to(self.device)
28
+
29
+ def preprocess_image(self, image_path):
30
+ # Load image
31
+ img = cv.imread(image_path)
32
+ if img is None:
33
+ raise ValueError(f"Could not load image from {image_path}")
34
+
35
+ img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
36
+
37
+ # Create alpha channel based on black background
38
+ # We assume the masked image has black background (0,0,0)
39
+ gray = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
40
+ _, mask = cv.threshold(gray, 1, 255, cv.THRESH_BINARY)
41
+
42
+ # Create RGBA
43
+ rgba = cv.cvtColor(img, cv.COLOR_RGB2RGBA)
44
+ rgba[:, :, 3] = mask
45
+
46
+ pil_image = Image.fromarray(rgba)
47
+
48
+ # Resize foreground
49
+ image = resize_foreground(pil_image, 0.85)
50
+
51
+ # Composite on gray background
52
+ image = np.array(image).astype(np.float32) / 255.0
53
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
54
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
55
+
56
+ return image
57
+
58
+ def meshify(self, image_path, output_path):
59
+ print(f"Processing {image_path}...")
60
+ image = self.preprocess_image(image_path)
61
+
62
+ print("Running model...")
63
+ with torch.no_grad():
64
+ scene_codes = self.model([image], device=self.device)
65
+
66
+ print("Extracting mesh...")
67
+ meshes = self.model.extract_mesh(scene_codes, has_vertex_color=True, resolution=256)
68
+ meshes[0].export(output_path)
69
+ print(f"Mesh saved to {output_path}")
70
+
71
+ if __name__ == "__main__":
72
+ meshifier = TripoMeshifier()
73
+ if os.path.exists("masked_image.png"):
74
+ meshifier.meshify("masked_image.png", "output_mesh.obj")
75
+ else:
76
+ print("masked_image.png not found. Please run segment.py first.")