| import torch |
| import numpy as np |
| import opensr_model |
| from typing import Union |
|
|
def create_opensr_model(
    device: Union[str, torch.device] = "cpu",
    weights_path: str = "./weights/opensr_10m_v4_v5.ckpt",
) -> "opensr_model.SRLatentDiffusion":
    """Create the OpenSR latent-diffusion super-resolution model.

    Args:
        device: Device forwarded to ``opensr_model.SRLatentDiffusion``.
        weights_path: Checkpoint file passed to ``model.load_pretrained``.
            The default is the previously hard-coded path. It is resolved
            relative to the current working directory.

    Returns:
        The ``SRLatentDiffusion`` instance with the checkpoint loaded and
        ``model.eval()`` already called.
    """
    model = opensr_model.SRLatentDiffusion(device=device)
    model.load_pretrained(weights_path)
    # Inference-only use: switch off training-mode behaviour before returning.
    model.eval()
    return model
|
|
|
|
def _reflectance_to_uint16(img: torch.Tensor) -> np.ndarray:
    """Rescale the first three bands of a ``/10000``-scaled tensor to uint16 DNs."""
    arr = img.detach().cpu().numpy()[0:3] * 10000
    # Round instead of truncating: the float32 /10000 -> *10000 round-trip can
    # land just below an integer, and astype() would then drop one DN.
    # Clip because casting negative or >65535 floats to uint16 wraps around.
    arr = np.clip(np.round(arr), 0, np.iinfo(np.uint16).max)
    return arr.astype(np.uint16)


def run_opensr_model(
    model: "opensr_model.SRLatentDiffusion",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cpu",
) -> dict:
    """Super-resolve ``lr`` with ``model`` and return LR / SR / HR RGB arrays.

    Args:
        model: Super-resolution model, e.g. the one returned by
            ``create_opensr_model``. It is called with a ``(1, 4, H, W)``
            float32 tensor. Its output is assumed to be ``(1, C, 4*H, 4*W)``,
            i.e. a x4 upscale (TODO confirm against the model).
        lr: Low-resolution image, indexed band-first. Bands ``[3, 2, 1, 7]``
            are selected and divided by 10000. Presumably these are
            reflectance DNs (assumption, verify with the caller).
        hr: High-resolution reference image. Only ``hr[0:3]`` is kept.
        device: Device the LR tensor is moved to before inference.

    Returns:
        A dict with three keys:

        - ``"lr"``: the first three selected LR bands, rescaled by 10000 and
          rounded/clipped to ``uint16``.
        - ``"sr"``: the first three model output bands, converted the same way.
        - ``"hr"``: ``hr[0:3]`` unchanged.
    """
    scale = 4              # upscaling factor assumed when cropping the SR output
    pad_lo, pad_hi = 3, 4  # 121 + 3 + 4 = 128 pixels per padded spatial dim

    lr_img = torch.from_numpy(lr[[3, 2, 1, 7]] / 10000).to(device).float()
    hr_img = hr[0:3]

    # 121-pixel tiles are reflect-padded to 128 before inference and the
    # padding is cropped off afterwards. Only the height is checked, as
    # before. NOTE(review): presumably the model needs a 128-pixel input;
    # confirm.
    needs_pad = lr_img.shape[1] == 121
    if needs_pad:
        model_input = torch.nn.functional.pad(
            lr_img[None],
            pad=(pad_lo, pad_hi, pad_lo, pad_hi),
            mode="reflect",
        )
    else:
        model_input = lr_img[None]

    with torch.no_grad():
        # squeeze(0) drops only the batch dim, never a size-1 band/spatial dim.
        sr_img = model(model_input).squeeze(0)

    if needs_pad:
        sr_img = sr_img[
            :,
            pad_lo * scale:-pad_hi * scale,
            pad_lo * scale:-pad_hi * scale,
        ]

    return {
        "lr": _reflectance_to_uint16(lr_img),
        "sr": _reflectance_to_uint16(sr_img),
        "hr": hr_img,
    }