| |
| """ |
| Example usage script for Pre-trained-v2 models |
| Demonstrates how to load and use the physics-based 3D object deformation models |
| """ |
|
|
| import torch |
| import numpy as np |
| import trimesh |
| import json |
| import os |
| from typing import Dict, List, Tuple |
|
|
class PhysicsDeformationModel:
    """Wrapper class for loading and using the pre-trained deformation models.

    On construction the encoder, decoder and reference mesh are loaded from
    ``<model_dir>/<model_name>-encoder.pt``, ``<model_dir>/<model_name>-decoder.pt``
    and ``<model_dir>/<model_name>.obj``. Both networks are plain MLPs whose
    layer widths are fixed by ``ENCODER_LAYERS`` / ``DECODER_LAYERS``; the
    checkpoints must hold state dicts with exactly those shapes, otherwise
    ``load_state_dict`` raises.

    NOTE(review): ``ENCODER_LAYERS`` expects 9 input features, but
    ``prepare_input_conditions`` only builds 7 (impact point, velocity, force),
    and the decoder emits a single 3-D point rather than one per reference
    vertex. One side of each pair is presumably out of date — confirm against
    the actual checkpoints before relying on this wrapper.
    """

    # Layer widths of the MLPs: (input, hidden..., output).
    ENCODER_LAYERS = (9, 512, 256, 128, 64)
    DECODER_LAYERS = (64, 128, 256, 512, 3)

    # Normalization applied to the raw conditions before they reach the
    # encoder. NOTE(review): must match the training-time preprocessing.
    IMPACT_CENTER = np.array([0.0, 0.5, 0.0])
    IMPACT_SCALE = 0.5
    VELOCITY_SCALE = 10.0
    FORCE_SCALE = 1000.0

    def __init__(self, model_dir: str, model_name: str):
        """
        Initialize the model

        Args:
            model_dir: Directory containing the model files
            model_name: Name of the model (e.g., 'base', 'pot')

        Raises:
            FileNotFoundError: If any of the encoder, decoder or mesh files
                is missing; the message lists the missing paths.
        """
        self.model_dir = model_dir
        self.model_name = model_name

        self.encoder_path = os.path.join(model_dir, f"{model_name}-encoder.pt")
        self.decoder_path = os.path.join(model_dir, f"{model_name}-decoder.pt")
        self.mesh_path = os.path.join(model_dir, f"{model_name}.obj")

        missing = [
            path
            for path in (self.encoder_path, self.decoder_path, self.mesh_path)
            if not os.path.exists(path)
        ]
        if missing:
            raise FileNotFoundError(
                f"Model files not found in {model_dir}: {', '.join(missing)}"
            )

        # force="mesh" makes trimesh return a single Trimesh (instead of a
        # Scene for multi-object OBJ files) so ``vertices`` can be replaced
        # as one array later on.
        self.reference_mesh = trimesh.load(self.mesh_path, force="mesh")

        self.encoder = self._load_encoder()
        self.decoder = self._load_decoder()

    @staticmethod
    def _build_mlp(widths):
        """Build ``Linear -> ReLU -> ... -> Linear`` over consecutive ``widths``.

        Linear layers end up at even indices (0, 2, 4, ...) of the Sequential,
        matching the hand-written Sequential this replaces, so state-dict keys
        such as ``0.weight`` / ``2.weight`` keep loading unchanged.
        """
        layers = []
        for index, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            if index:
                layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Linear(n_in, n_out))
        return torch.nn.Sequential(*layers)

    def _load_mlp(self, widths, path):
        """Build an MLP with ``widths``, load its weights from ``path``, set eval mode."""
        model = self._build_mlp(widths)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state dict needs, so a tampered
        # checkpoint cannot run arbitrary code while loading.
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def _load_encoder(self):
        """Load the encoder model"""
        return self._load_mlp(self.ENCODER_LAYERS, self.encoder_path)

    def _load_decoder(self):
        """Load the decoder model"""
        return self._load_mlp(self.DECODER_LAYERS, self.decoder_path)

    def _encoder_input_features(self):
        """Return ``in_features`` of the encoder's first Linear layer, or None."""
        for module in self.encoder.modules():
            if isinstance(module, torch.nn.Linear):
                return module.in_features
        return None

    def prepare_input_conditions(self, impact_point: List[float],
                                 velocity: List[float],
                                 force: float) -> torch.Tensor:
        """
        Prepare input conditions for the model

        Args:
            impact_point: [x, y, z] coordinates of impact point (any
                3-element array-like)
            velocity: [vx, vy, vz] velocity vector (any 3-element array-like)
            force: Impact force magnitude

        Returns:
            float32 tensor of shape (1, 7): the normalized impact point,
            velocity and force, with a leading batch dimension.

        Raises:
            ValueError: If impact_point or velocity does not have exactly
                3 components.
        """
        point = np.asarray(impact_point, dtype=np.float32).reshape(-1)
        vel = np.asarray(velocity, dtype=np.float32).reshape(-1)
        if point.shape != (3,) or vel.shape != (3,):
            raise ValueError(
                "impact_point and velocity must each have exactly 3 components, "
                f"got {point.size} and {vel.size}"
            )

        # Concatenate explicitly: ``impact_point + velocity`` would add the
        # vectors element-wise (not concatenate them) for NumPy array inputs.
        input_data = np.concatenate(
            [point, vel, np.array([force], dtype=np.float32)]
        )

        input_data[:3] = (input_data[:3] - self.IMPACT_CENTER) / self.IMPACT_SCALE
        input_data[3:6] = input_data[3:6] / self.VELOCITY_SCALE
        input_data[6] = input_data[6] / self.FORCE_SCALE

        return torch.from_numpy(input_data).unsqueeze(0)

    def predict_deformation(self, impact_point: List[float],
                            velocity: List[float],
                            force: float) -> np.ndarray:
        """
        Predict object deformation given impact conditions

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            Decoder output reshaped to (num_points, 3); treated as deformed
            vertex positions by apply_deformation_to_mesh.

        Raises:
            RuntimeError: If the prepared feature vector does not match the
                encoder's expected input width.
        """
        input_tensor = self.prepare_input_conditions(impact_point, velocity, force)

        # Fail with an actionable message instead of torch's generic
        # "mat1 and mat2 shapes cannot be multiplied" error.
        expected = self._encoder_input_features()
        if expected is not None and input_tensor.shape[-1] != expected:
            raise RuntimeError(
                f"Encoder expects {expected} input features but "
                f"{input_tensor.shape[-1]} were prepared; the encoder "
                "architecture and prepare_input_conditions are out of sync"
            )

        with torch.no_grad():
            latent = self.encoder(input_tensor)
            deformation = self.decoder(latent)

        return deformation.reshape(-1, 3).numpy()

    def apply_deformation_to_mesh(self, impact_point: List[float],
                                  velocity: List[float],
                                  force: float) -> "trimesh.Trimesh":
        """
        Apply deformation to the reference mesh

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            A copy of the reference mesh with its vertices replaced by the
            predicted positions; the reference mesh itself is not modified.

        Raises:
            ValueError: If the number of predicted vertices differs from the
                reference mesh's vertex count.
        """
        deformed_vertices = self.predict_deformation(impact_point, velocity, force)

        # The copied faces index into the reference vertex array, so a
        # different vertex count would produce a corrupt mesh.
        expected_shape = np.shape(self.reference_mesh.vertices)
        if deformed_vertices.shape != expected_shape:
            raise ValueError(
                f"Model predicted vertices of shape {deformed_vertices.shape} "
                f"but the reference mesh has shape {expected_shape}; refusing "
                "to replace vertices with a mismatched array"
            )

        deformed_mesh = self.reference_mesh.copy()
        deformed_mesh.vertices = deformed_vertices

        return deformed_mesh

    def save_deformed_mesh(self, output_path: str, impact_point: List[float],
                           velocity: List[float], force: float):
        """
        Save deformed mesh to file

        Args:
            output_path: Path to save the deformed mesh; trimesh picks the
                export format from it
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude
        """
        deformed_mesh = self.apply_deformation_to_mesh(impact_point, velocity, force)
        deformed_mesh.export(output_path)
|
|
def main():
    """Example usage of the PhysicsDeformationModel.

    Loads the ``base`` model from the ``base`` directory, predicts the
    deformation for one hard-coded impact, saves the deformed mesh to
    ``deformed_base.obj`` and prints displacement statistics.

    Returns:
        Process exit status: 0 on success, 1 if any step raised.
    """
    import sys  # local import: only needed to route diagnostics to stderr

    model_dir = "base"
    model_name = "base"

    # Example impact conditions; PhysicsDeformationModel normalizes them
    # before they reach the encoder.
    impact_point = [0.1, 0.8, 0.1]
    velocity = [0.0, -5.0, 0.0]
    force = 500.0

    try:
        print(f"Loading {model_name} model...")
        model = PhysicsDeformationModel(model_dir, model_name)
        print("Model loaded successfully!")

        print("Predicting deformation...")
        deformed_vertices = model.predict_deformation(impact_point, velocity, force)
        print(f"Deformation predicted. Output shape: {deformed_vertices.shape}")

        output_path = f"deformed_{model_name}.obj"
        model.save_deformed_mesh(output_path, impact_point, velocity, force)
        print(f"Deformed mesh saved to: {output_path}")

        original_vertices = np.asarray(model.reference_mesh.vertices)
        # Only compare when the shapes agree; otherwise NumPy broadcasting
        # (e.g. one predicted point against every reference vertex) would
        # silently produce meaningless statistics.
        if deformed_vertices.shape != original_vertices.shape:
            print(
                "Cannot compute deformation statistics: predicted shape "
                f"{deformed_vertices.shape} != reference shape {original_vertices.shape}",
                file=sys.stderr,
            )
        else:
            deformation_magnitude = np.linalg.norm(deformed_vertices - original_vertices, axis=1)
            print(f"Average deformation magnitude: {np.mean(deformation_magnitude):.4f}")
            print(f"Maximum deformation magnitude: {np.max(deformation_magnitude):.4f}")

    except Exception as e:  # top-level boundary of an example script
        print(f"Error: {e}", file=sys.stderr)
        print("Make sure you have the correct model files and dependencies installed.", file=sys.stderr)
        return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s status so failures give a non-zero process exit code.
    raise SystemExit(main())
|
|