| import torch |
| from zoedepth.models.builder import build_model |
| from zoedepth.utils.config import get_config |
|
|
class ZoeDepth:
    """Thin wrapper around a ZoeDepth ``zoedepth_nk`` model for single-image inference.

    The model is built from the ``zoedepth_nk`` "infer" config at construction
    time and placed on CUDA when available, otherwise on the CPU.

    Attributes:
        model_zoe: Model returned by ``build_model``.
        zoe: Result of ``model_zoe.to(device)``; this is what inference uses.
        DEVICE: Device the model was last moved to (``"cuda"``/``"cpu"`` by default).
        width: Target width applied to the preprocessing resizer in ``predict``.
        height: Target height applied to the preprocessing resizer in ``predict``.
    """

    def __init__(self, width=512, height=512):
        """Build the model and move it to the default device.

        Args:
            width: Target input width used by the model's resizer.
            height: Target input height used by the model's resizer.
        """
        conf = get_config("zoedepth_nk", "infer")
        # ZoeDepth's PrepForMidas unpacks img_size as (net_h, net_w), so the
        # height must come first. predict() re-applies width/height on every
        # call, which is why a swapped order here did not surface there.
        conf.img_size = [height, width]
        self.model_zoe = build_model(conf)
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        # For a torch nn.Module, .to() moves in place and returns the same
        # module, so zoe and model_zoe presumably refer to one object.
        self.zoe = self.model_zoe.to(self.DEVICE)
        self.width = width
        self.height = height

    def predict(self, image):
        """Estimate depth for ``image``.

        Args:
            image: Input forwarded to ``infer_pil`` -- presumably a PIL image.

        Returns:
            Whatever ``infer_pil(image, output_type="tensor")`` returns --
            presumably a torch tensor depth map; confirm shape with callers.
        """
        resizer = self.zoe.core.prep.resizer
        # The Resize class keeps its target size in private ``self.__width`` /
        # ``self.__height`` attributes, which Python stores under the
        # name-mangled names below. They are re-applied on every call so this
        # instance's width/height is what each prediction uses.
        resizer._Resize__width = self.width
        resizer._Resize__height = self.height
        return self.zoe.infer_pil(image, output_type="tensor")

    def to(self, device):
        """Move the model to ``device`` and record it in ``DEVICE``.

        Args:
            device: Target device (e.g. ``"cpu"``, ``"cuda"`` or a ``torch.device``).

        Returns:
            ``self``, so the call can be chained like ``torch.nn.Module.to``.
        """
        self.DEVICE = device
        self.zoe = self.model_zoe.to(device)
        return self

    def save_raw_depth(self, depth, filepath):
        """Save ``depth`` as a PNG file at ``filepath``.

        Args:
            depth: Object exposing a PIL-style ``save(fp, format=..., **params)``
                method -- presumably a ``PIL.Image.Image``; confirm with callers.
            filepath: Destination path or file object.
        """
        # NOTE(review): Pillow forwards extra keyword arguments of Image.save()
        # to the format writer as encoder options; the PNG writer does not
        # appear to use ``mode``, so this likely does NOT convert the image to
        # 16-bit. Convert beforehand (or verify the written bit depth) if a
        # true I;16 PNG is required.
        depth.save(filepath, format='PNG', mode='I;16')

    def delete(self):
        """Drop this wrapper's references to the model.

        The wrapper is unusable afterwards, and calling ``delete`` a second
        time raises ``AttributeError`` because the attributes are already gone.
        Only references held by this object are removed; the model itself is
        freed only if nothing else still references it.
        """
        del self.model_zoe
        del self.zoe