| import os |
| import cv2 |
| import torch |
| import argparse |
| from tqdm import tqdm |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
| from ddcolor_model import DDColor |
| from infer import ImageColorizationPipeline |
|
|
|
|
class DDColorHF(DDColor, PyTorchModelHubMixin):
    """DDColor network with Hugging Face Hub loading support.

    Mixing in ``PyTorchModelHubMixin`` gives the class ``from_pretrained``,
    which is how the CLI in this module builds the model.

    Args:
        config: mapping whose entries are passed as keyword arguments to the
            ``DDColor`` constructor. Presumably this is the configuration
            stored alongside the Hub checkpoint — TODO confirm.
    """

    def __init__(self, config):
        # Unpack the configuration mapping into DDColor's keyword arguments.
        super(DDColorHF, self).__init__(**config)
|
|
|
|
class ImageColorizationPipelineHF(ImageColorizationPipeline):
    """Colorization pipeline that wraps an already-constructed model.

    This subclass intentionally does NOT call the parent constructor. It only
    sets ``input_size``, ``device`` and ``model``.
    NOTE(review): this assumes the inherited ``process`` method needs no other
    attributes from ``ImageColorizationPipeline.__init__`` — confirm against
    ``infer.py``.

    Args:
        model: model instance to run. It is moved to the chosen device and
            switched to evaluation mode.
        input_size: input size stored on the pipeline for the inherited
            processing code.
    """

    def __init__(self, model, input_size):
        self.input_size = input_size
        # Use the GPU whenever CUDA is available; otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        # Inference only: disable training-mode behaviour.
        self.model.eval()
|
|
|
|
def main():
    """Colorize every readable image in a folder with a DDColor model.

    Command-line arguments:
        --model_name: an existing local path is used as-is; otherwise the name
            is resolved as ``piddnad/<model_name>`` on the Hugging Face Hub.
        --input: folder containing the input images.
        --output: folder for the colorized images (created if missing). Each
            output keeps its input file name.
        --input_size: input size passed to the colorization pipeline.

    Files that OpenCV cannot decode are skipped with a message instead of
    crashing the run. Exits via ``parser.error`` (status 2) when the input
    folder is missing or contains no files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="ddcolor_modelscope")
    parser.add_argument(
        "--input",
        type=str,
        default="figure/",
        help="input test image folder",
    )
    parser.add_argument(
        "--output", type=str, default="results", help="output folder"
    )
    parser.add_argument(
        "--input_size", type=int, default=512, help="input size for model"
    )
    args = parser.parse_args()

    # Validate the input folder before loading (and possibly downloading) the
    # model. Use parser.error rather than assert, because asserts are stripped
    # under `python -O`.
    if not os.path.isdir(args.input):
        parser.error(f"--input must be an existing folder: {args.input}")
    # Only regular files are candidates; sort them for a deterministic order.
    img_list = sorted(
        name
        for name in os.listdir(args.input)
        if os.path.isfile(os.path.join(args.input, name))
    )
    if not img_list:
        parser.error(f"no files found in input folder: {args.input}")

    # An existing local path is treated as a model directory; anything else is
    # looked up under the piddnad namespace on the Hub.
    if os.path.exists(args.model_name):
        model_name = args.model_name
    else:
        model_name = f"piddnad/{args.model_name}"
    ddcolor_model = DDColorHF.from_pretrained(model_name)

    print(f"Output path: {args.output}")
    os.makedirs(args.output, exist_ok=True)

    colorizer = ImageColorizationPipelineHF(
        model=ddcolor_model, input_size=args.input_size
    )

    for name in tqdm(img_list):
        img = cv2.imread(os.path.join(args.input, name))
        # cv2.imread signals an undecodable file by returning None, not by raising.
        if img is None:
            tqdm.write(f"Skipping unreadable image: {name}")
            continue
        image_out = colorizer.process(img)
        out_path = os.path.join(args.output, name)
        # cv2.imwrite reports failure through its boolean return value.
        if not cv2.imwrite(out_path, image_out):
            tqdm.write(f"Failed to write: {out_path}")
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|