Spaces:
Running
Running
| from typing import * | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from PIL import Image | |
| import trimesh | |
| from .base import Pipeline | |
| from . import samplers, rembg | |
| from ..modules.sparse import SparseTensor | |
| from ..modules import image_feature_extractor | |
| import o_voxel | |
| import cumesh | |
| import nvdiffrast.torch as dr | |
| import cv2 | |
| import flex_gemm | |
class Trellis2TexturingPipeline(Pipeline):
    """
    Pipeline for texturing an existing mesh from an image prompt with Trellis2 models.

    The mesh is normalized and encoded into a shape structured latent (SLat);
    a flow model conditioned on the image samples a texture SLat, which is
    decoded into PBR voxel attributes and baked into a UV texture.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
        tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
        shape_slat_normalization (dict): 'mean'/'std' used to normalize the shape latent.
        tex_slat_normalization (dict): 'mean'/'std' used to de-normalize the texture latent.
        image_cond_model (Callable): The image conditioning model.
        rembg_model (Callable): The model for removing background.
        low_vram (bool): If True, each model is moved to ``self.device`` only
            while it runs and is moved back to the CPU afterwards.
    """
    model_names_to_load = [
        'shape_slat_encoder',
        'tex_slat_decoder',
        'tex_slat_flow_model_512',
        'tex_slat_flow_model_1024'
    ]

    # Channel layout of the decoded PBR voxel features (see postprocess_mesh).
    _DEFAULT_PBR_ATTR_LAYOUT = {
        'base_color': slice(0, 3),
        'metallic': slice(3, 4),
        'roughness': slice(4, 5),
        'alpha': slice(5, 6),
    }

    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        tex_slat_sampler: samplers.Sampler = None,
        tex_slat_sampler_params: dict = None,
        shape_slat_normalization: dict = None,
        tex_slat_normalization: dict = None,
        image_cond_model: Callable = None,
        rembg_model: Callable = None,
        low_vram: bool = True,
    ):
        # Constructing without models leaves the instance uninitialized
        # (presumably the loading path fills the attributes in afterwards).
        if models is None:
            return
        super().__init__(models)
        self.tex_slat_sampler = tex_slat_sampler
        self.tex_slat_sampler_params = tex_slat_sampler_params
        self.shape_slat_normalization = shape_slat_normalization
        self.tex_slat_normalization = tex_slat_normalization
        self.image_cond_model = image_cond_model
        self.rembg_model = rembg_model
        self.low_vram = low_vram
        self.pbr_attr_layout = dict(self._DEFAULT_PBR_ATTR_LAYOUT)
        self._device = 'cpu'

    @classmethod
    def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2TexturingPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.
            config_file (str): Pipeline configuration file name, forwarded to
                ``Pipeline.from_pretrained``.

        Returns:
            Trellis2TexturingPipeline: The loaded pipeline (device set to 'cpu').
        """
        pipeline = super().from_pretrained(path, config_file)
        args = pipeline._pretrained_args

        sampler_cfg = args['tex_slat_sampler']
        pipeline.tex_slat_sampler = getattr(samplers, sampler_cfg['name'])(**sampler_cfg['args'])
        pipeline.tex_slat_sampler_params = sampler_cfg['params']

        pipeline.shape_slat_normalization = args['shape_slat_normalization']
        pipeline.tex_slat_normalization = args['tex_slat_normalization']

        cond_cfg = args['image_cond_model']
        pipeline.image_cond_model = getattr(image_feature_extractor, cond_cfg['name'])(**cond_cfg['args'])
        rembg_cfg = args['rembg_model']
        pipeline.rembg_model = getattr(rembg, rembg_cfg['name'])(**rembg_cfg['args'])

        pipeline.low_vram = args.get('low_vram', True)
        pipeline.pbr_attr_layout = dict(cls._DEFAULT_PBR_ATTR_LAYOUT)
        pipeline._device = 'cpu'
        return pipeline

    def to(self, device: torch.device) -> None:
        """
        Set the inference device.

        Outside low-VRAM mode all models are moved to ``device`` right away.
        In low-VRAM mode only the target device is recorded; each model is
        moved there on demand while it runs (see ``_call_offloaded``).
        """
        self._device = device
        if self.low_vram:
            return
        super().to(device)
        self.image_cond_model.to(device)
        if self.rembg_model is not None:
            self.rembg_model.to(device)

    def _call_offloaded(self, modules, fn: Callable, *args, **kwargs):
        """
        Return ``fn(*args, **kwargs)``, temporarily placing ``modules`` on ``self.device``.

        Only in low-VRAM mode are the modules moved to ``self.device`` before the
        call and back to the CPU afterwards; the move back happens in a
        ``finally`` so an exception in ``fn`` does not leave them on the GPU.
        ``modules`` may be a single object with ``.to``/``.cpu`` or a list/tuple of them.
        """
        if not self.low_vram:
            return fn(*args, **kwargs)
        if not isinstance(modules, (list, tuple)):
            modules = [modules]
        for module in modules:
            module.to(self.device)
        try:
            return fn(*args, **kwargs)
        finally:
            for module in modules:
                module.cpu()

    def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
        """
        Preprocess the input mesh.

        The mesh is centered on its bounding-box center, uniformly scaled so its
        longest side is just under 1 (fits in [-0.5, 0.5]^3), and its axes are
        remapped (x, y, z) -> (x, -z, y). A copy of the input's visual (e.g. its
        UVs) is kept so ``postprocess_mesh`` can reuse existing UVs.

        Raises:
            ValueError: If the mesh has no vertices or zero spatial extent.
        """
        vertices = np.array(mesh.vertices, dtype=np.float64)
        if len(vertices) == 0:
            raise ValueError('mesh has no vertices')
        vertices_min = vertices.min(axis=0)
        vertices_max = vertices.max(axis=0)
        extent = (vertices_max - vertices_min).max()
        if not extent > 0:  # also rejects NaN
            raise ValueError('mesh has zero spatial extent')
        center = (vertices_min + vertices_max) / 2
        scale = 0.99999 / extent
        vertices = (vertices - center) * scale
        # (x, y, z) -> (x, -z, y); postprocess_mesh applies the inverse.
        vertices = np.stack([vertices[:, 0], -vertices[:, 2], vertices[:, 1]], axis=-1)
        assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range'
        # Copy the visual so the caller's mesh is not re-bound to the new mesh.
        visual = mesh.visual.copy() if getattr(mesh, 'visual', None) is not None else None
        return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, visual=visual, process=False)

    def preprocess_image(self, input: Image.Image) -> Image.Image:
        """
        Preprocess the input image.

        Images whose longest side exceeds 1024 px are downscaled. If the image is
        RGBA with a non-opaque alpha channel, that alpha is used as the foreground
        mask; otherwise the background is removed with ``rembg_model``. The result
        is cropped to a square around pixels with alpha > 0.8 and returned as an
        RGB image premultiplied by alpha.

        Raises:
            ValueError: If background removal is needed but no ``rembg_model`` is
                configured, or if no foreground pixels are found.
        """
        # If the image has a meaningful alpha channel, use it directly.
        has_alpha = False
        if input.mode == 'RGBA':
            alpha = np.array(input)[:, :, 3]
            has_alpha = bool(not np.all(alpha == 255))

        scale = min(1, 1024 / max(input.size))
        if scale < 1:
            input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)

        if has_alpha:
            output = input
        else:
            if self.rembg_model is None:
                raise ValueError('image has no alpha channel and no rembg_model is configured')
            input = input.convert('RGB')
            output = self._call_offloaded(self.rembg_model, self.rembg_model, input)

        alpha = np.array(output)[:, :, 3]
        foreground = np.argwhere(alpha > 0.8 * 255)  # rows of (y, x)
        if foreground.size == 0:
            raise ValueError('no foreground found in image (alpha > 0.8 nowhere)')
        y_min, x_min = foreground.min(axis=0)
        y_max, x_max = foreground.max(axis=0)
        center_x = (x_min + x_max) / 2
        center_y = (y_min + y_max) / 2
        half = int(max(x_max - x_min, y_max - y_min)) // 2
        bbox = center_x - half, center_y - half, center_x + half, center_y + half
        output = output.crop(bbox)  # type: ignore

        rgba = np.array(output).astype(np.float32) / 255
        rgb = rgba[:, :, :3] * rgba[:, :, 3:4]  # premultiply by alpha
        return Image.fromarray((rgb * 255).astype(np.uint8))

    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
            resolution (int): Value assigned to ``image_cond_model.image_size``.
            include_neg_cond (bool): Also return an all-zeros ``'neg_cond'``.

        Returns:
            dict: ``{'cond': ...}``, plus ``'neg_cond'`` when requested.
        """
        self.image_cond_model.image_size = resolution
        cond = self._call_offloaded(self.image_cond_model, self.image_cond_model, image)
        if not include_neg_cond:
            return {'cond': cond}
        return {
            'cond': cond,
            'neg_cond': torch.zeros_like(cond),
        }

    def encode_shape_slat(
        self,
        mesh: trimesh.Trimesh,
        resolution: int = 1024,
    ) -> SparseTensor:
        """
        Encode the mesh to a shape structured latent.

        Args:
            mesh (trimesh.Trimesh): The mesh to encode; its vertices are expected
                to lie in [-0.5, 0.5]^3 (as produced by ``preprocess_mesh``).
            resolution (int): Grid size of the flexible dual grid.

        Returns:
            SparseTensor: The encoded structured latent.
        """
        vertices = torch.from_numpy(np.asarray(mesh.vertices)).float()
        faces = torch.from_numpy(np.asarray(mesh.faces)).long()
        voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
            vertices, faces,
            grid_size=resolution,
            aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
            face_weight=1.0,
            boundary_weight=0.2,
            regularization_weight=1e-2,
            timing=True,
        )
        # Features are dual-vertex offsets relative to their voxel; coords get a
        # leading batch index of 0.
        dual_feats = SparseTensor(
            feats=dual_vertices * resolution - voxel_indices,
            coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1)
        ).to(self.device)
        intersected = dual_feats.replace(intersected).to(self.device)
        encoder = self.models['shape_slat_encoder']
        return self._call_offloaded(encoder, encoder, dual_feats, intersected)

    def sample_tex_slat(
        self,
        cond: dict,
        flow_model,
        shape_slat: SparseTensor,
        sampler_params: Optional[dict] = None,
    ) -> SparseTensor:
        """
        Sample a texture structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information (from ``get_cond``).
            flow_model: The flow model (or a list/tuple of models) to sample with.
            shape_slat (SparseTensor): The structured latent for shape; it is
                normalized and passed to the sampler as ``concat_cond``.
            sampler_params (dict, optional): Overrides merged on top of
                ``self.tex_slat_sampler_params``.

        Returns:
            SparseTensor: The de-normalized texture latent.
        """
        std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
        mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
        shape_slat = (shape_slat - mean) / std

        # Noise fills the input channels not taken by the concatenated shape latent.
        in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
        noise = shape_slat.replace(
            feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device)
        )
        params = {**self.tex_slat_sampler_params, **(sampler_params or {})}

        slat = self._call_offloaded(
            flow_model,
            lambda: self.tex_slat_sampler.sample(
                flow_model,
                noise,
                concat_cond=shape_slat,
                **cond,
                **params,
                verbose=True,
                tqdm_desc="Sampling texture SLat",
            ).samples,
        )

        std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
        return slat * std + mean

    def decode_tex_slat(
        self,
        slat: SparseTensor,
    ) -> SparseTensor:
        """
        Decode the texture structured latent.

        Args:
            slat (SparseTensor): The structured latent.

        Returns:
            SparseTensor: The decoded texture voxels, with the decoder output
            affinely mapped by ``x * 0.5 + 0.5``.
        """
        decoder = self.models['tex_slat_decoder']
        return self._call_offloaded(decoder, decoder, slat) * 0.5 + 0.5

    def postprocess_mesh(
        self,
        mesh: trimesh.Trimesh,
        pbr_voxel: SparseTensor,
        resolution: int = 1024,
        texture_size: int = 1024,
    ) -> trimesh.Trimesh:
        """
        Bake the PBR voxel attributes into UV textures and build a textured mesh.

        Existing UVs on ``mesh.visual`` are reused; otherwise the mesh is
        UV-unwrapped with cumesh. Each texel is mapped to its 3D surface position,
        the voxel attributes are sampled trilinearly there, and empty texels are
        filled with OpenCV inpainting. The axis remap of ``preprocess_mesh`` is
        undone; the normalization scale/offset is not. ``mesh`` is not modified.

        Args:
            mesh (trimesh.Trimesh): Mesh with vertices in [-0.5, 0.5]^3.
            pbr_voxel (SparseTensor): Decoded PBR attributes laid out per
                ``self.pbr_attr_layout``.
            resolution (int): Voxel grid resolution used to map positions in
                [-0.5, 0.5] to grid coordinates.
            texture_size (int): Width/height of the generated textures.

        Returns:
            trimesh.Trimesh: Mesh with a ``PBRMaterial`` texture visual.
        """
        device = self.device
        # Copies: the axis swap at the end writes into these arrays in place.
        vertices = np.array(mesh.vertices, dtype=np.float64)
        faces = np.array(mesh.faces)
        normals = np.array(mesh.vertex_normals, dtype=np.float64)
        vertices_torch = torch.from_numpy(vertices).float().to(device)
        faces_torch = torch.from_numpy(faces).int().to(device)

        existing_uv = getattr(getattr(mesh, 'visual', None), 'uv', None)
        if existing_uv is not None:
            uvs = np.array(existing_uv)
            uvs[:, 1] = 1 - uvs[:, 1]  # flipped back before building the output
            uvs_torch = torch.from_numpy(uvs).float().to(device)
        else:
            _cumesh = cumesh.CuMesh()
            _cumesh.init(vertices_torch, faces_torch)
            vertices_torch, faces_torch, uvs_torch, vmap = _cumesh.uv_unwrap(return_vmaps=True)
            vertices_torch = vertices_torch.to(device)
            faces_torch = faces_torch.to(device)
            uvs_torch = uvs_torch.to(device)
            vertices = vertices_torch.cpu().numpy()
            faces = faces_torch.cpu().numpy()
            uvs = uvs_torch.cpu().numpy()
            # vmap maps each unwrapped vertex back to its source vertex.
            normals = normals[vmap.cpu().numpy()]

        # Rasterize the mesh in UV space: UVs in [0, 1] become clip-space x/y.
        ctx = dr.RasterizeCudaContext(device=device)
        uv_clip = torch.cat(
            [uvs_torch * 2 - 1, torch.zeros_like(uvs_torch[:, :1]), torch.ones_like(uvs_torch[:, :1])],
            dim=-1,
        ).unsqueeze(0)
        rast, _ = dr.rasterize(
            ctx, uv_clip, faces_torch,
            resolution=[texture_size, texture_size],
        )
        mask = rast[0, ..., 3] > 0  # texels covered by some triangle
        pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0]
        attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=device)
        attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d(
            pbr_voxel.feats,
            pbr_voxel.coords,
            shape=torch.Size([*pbr_voxel.shape, *pbr_voxel.spatial_shape]),
            grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3),
            mode='trilinear',
        )

        # Convert each attribute to uint8 with a single device->host transfer.
        attrs_np = attrs.cpu().numpy()

        def to_uint8(name: str) -> np.ndarray:
            return np.clip(attrs_np[..., self.pbr_attr_layout[name]] * 255, 0, 255).astype(np.uint8)

        base_color = to_uint8('base_color')
        metallic = to_uint8('metallic')
        roughness = to_uint8('roughness')
        alpha = to_uint8('alpha')

        # Extend colors into uncovered texels to avoid seams when sampling.
        inpaint_mask = (~mask.cpu().numpy()).astype(np.uint8)
        base_color = cv2.inpaint(base_color, inpaint_mask, 3, cv2.INPAINT_TELEA)
        metallic = cv2.inpaint(metallic, inpaint_mask, 1, cv2.INPAINT_TELEA)[..., None]
        roughness = cv2.inpaint(roughness, inpaint_mask, 1, cv2.INPAINT_TELEA)[..., None]
        alpha = cv2.inpaint(alpha, inpaint_mask, 1, cv2.INPAINT_TELEA)[..., None]

        material = trimesh.visual.material.PBRMaterial(
            baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)),
            baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
            # glTF packs roughness in G and metallic in B.
            metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)),
            metallicFactor=1.0,
            roughnessFactor=1.0,
            alphaMode='OPAQUE',
            doubleSided=True,
        )

        # Undo preprocess_mesh's remap: (x, y, z) -> (x, z, -y).
        for arr in (vertices, normals):
            old_y = arr[:, 1].copy()
            arr[:, 1] = arr[:, 2]
            arr[:, 2] = -old_y
        uvs[:, 1] = 1 - uvs[:, 1]  # Flip UV V-coordinate

        return trimesh.Trimesh(
            vertices=vertices,
            faces=faces,
            vertex_normals=normals,
            process=False,
            visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)
        )

    @torch.no_grad()
    def run(
        self,
        mesh: trimesh.Trimesh,
        image: Image.Image,
        seed: int = 42,
        tex_slat_sampler_params: Optional[dict] = None,
        preprocess_image: bool = True,
        resolution: int = 1024,
        texture_size: int = 2048,
    ) -> trimesh.Trimesh:
        """
        Run the pipeline.

        Args:
            mesh (trimesh.Trimesh): The mesh to texture.
            image (Image.Image): The image prompt.
            seed (int): The random seed.
            tex_slat_sampler_params (dict, optional): Additional parameters for the texture latent sampler.
            preprocess_image (bool): Whether to preprocess the image.
            resolution (int): Voxel resolution. 512 selects the 512 conditioning
                and flow model; any other value selects the 1024 variants.
            texture_size (int): Width/height of the output textures.

        Returns:
            trimesh.Trimesh: The textured mesh (in normalized coordinates).
        """
        if preprocess_image:
            image = self.preprocess_image(image)
        mesh = self.preprocess_mesh(mesh)
        torch.manual_seed(seed)
        model_resolution = 512 if resolution == 512 else 1024
        cond = self.get_cond([image], model_resolution)
        shape_slat = self.encode_shape_slat(mesh, resolution)
        tex_model = self.models[f'tex_slat_flow_model_{model_resolution}']
        tex_slat = self.sample_tex_slat(cond, tex_model, shape_slat, tex_slat_sampler_params)
        pbr_voxel = self.decode_tex_slat(tex_slat)
        return self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size)