| import torch |
| import random |
| import numpy as np |
| import os |
| import os.path as osp |
| import glob |
| from tqdm import tqdm |
|
|
| from scalelsd.base import setup_logger, MetricLogger, show, WireframeGraph |
|
|
| from scalelsd.ssl.datasets import dataset_util |
| from scalelsd.ssl.models.detector import ScaleLSD |
| from scalelsd.ssl.misc.train_utils import load_scalelsd_model |
|
|
| from torch.utils.data import DataLoader |
| import torch.utils.data.dataloader as torch_loader |
|
|
| from pathlib import Path |
| import argparse, yaml, logging, time, datetime, cv2, copy, sys, json |
| from easydict import EasyDict |
| import accelerate |
| from accelerate import load_checkpoint_and_dispatch |
| import matplotlib |
| import matplotlib.pyplot as plt |
|
|
def parse_args():
    """Build the command-line parser, parse ``sys.argv`` and return the namespace.

    The parser is also handed to ``ScaleLSD.cli`` before parsing, and the
    parsed namespace is passed to ``ScaleLSD.configure`` before being
    returned (presumably so the model can register and consume its own
    options -- NOTE(review): confirm against ScaleLSD).
    """
    parser = argparse.ArgumentParser()

    # Checkpoint location and the confidence threshold used when drawing.
    parser.add_argument('-c', '--ckpt', type=str,
                        default='models/scalelsd-vitbase-v1-train-sa1b.pt',
                        help='the path for loading checkpoints')
    parser.add_argument('-t', '--threshold', type=float, default=10)

    # Input image (or directory of images) and the size it is resized to.
    parser.add_argument('-i', '--img', type=str, required=True)
    parser.add_argument('--width', type=int, default=512)
    parser.add_argument('--height', type=int, default=512)

    # Rendering / output destination / compute device.
    parser.add_argument('--whitebg', type=float, default=0.0)
    parser.add_argument('--saveto', type=str, default=None)
    parser.add_argument('-e', '--ext', type=str, default='pdf',
                        choices=['pdf', 'png', 'json', 'txt'])
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu', 'mps'])

    # Plain on/off switches, registered in their original order.
    for switch in ('--disable-show', '--draw-junctions-only', '--use_lsd', '--use_nms'):
        parser.add_argument(switch, action='store_true', default=False)

    ScaleLSD.cli(parser)
    parsed = parser.parse_args()
    ScaleLSD.configure(parsed)
    return parsed
|
|
|
|
# Output formats that _save_result knows how to write. The CLI also accepts
# 'txt', which has no writer, so it is rejected before any work is done.
_SUPPORTED_EXTS = ('png', 'pdf', 'json')


def _collect_images(path):
    """Return the list of image files designated by *path*.

    *path* is either a single image file or a directory whose immediate
    entries are scanned (non-recursively). Only ``.jpg``/``.jpeg``/``.png``
    names are accepted, compared case-insensitively. Directory results are
    sorted.

    Raises:
        ValueError: if *path* is neither a supported image file nor a directory.
    """
    image_exts = ('.jpg', '.jpeg', '.png')
    if os.path.isfile(path) and path.lower().endswith(image_exts):
        return [path]
    if os.path.isdir(path):
        return sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if name.lower().endswith(image_exts)
        )
    raise ValueError('Input must be a file or a directory containing images.')


def _preprocess(image, width, height, device):
    """Resize a 2-D image array and return it as a (1, 1, height, width) float tensor scaled by 1/255."""
    # cv2.resize takes (width, height) but returns an array shaped (height, width).
    resized = cv2.resize(image, (width, height))
    tensor = torch.from_numpy(resized).float() / 255.0
    # Add batch and channel axes.
    return tensor[None, None].to(device)


def _save_result(outputs, fname, args, painter, edge_color='orange', vertex_color='Cyan'):
    """Write the detection *outputs* for image *fname* into ``args.saveto``.

    The output file keeps the image's stem. For 'png'/'pdf' the wireframe (or
    only the junctions, with ``--draw-junctions-only``) is drawn over the
    image. For 'json' a serialized ``WireframeGraph`` is written.

    Raises:
        ValueError: if ``args.ext`` is not one of the supported formats.
    """
    image_path = Path(fname)
    if args.ext in ('png', 'pdf'):
        fig_file = osp.join(args.saveto, image_path.with_suffix('.' + args.ext).name)
        with show.image_canvas(fname, fig_file=fig_file) as ax:
            if args.draw_junctions_only:
                painter.draw_junctions(ax, outputs)
            else:
                painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
    elif args.ext == 'json':
        indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
        wireframe = WireframeGraph(
            outputs['juncs_pred'], outputs['juncs_score'], indices,
            outputs['lines_score'], outputs['width'], outputs['height'],
        )
        outpath = osp.join(args.saveto, image_path.with_suffix('.json').name)
        with open(outpath, 'w', encoding='utf-8') as f:
            json.dump(wireframe.jsonize(), f)
    else:
        raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))


def main():
    """Run ScaleLSD on the image(s) given by ``--img`` and save one result per image.

    Results go to ``--saveto``, which defaults to ``temp_output/ScaleLSD``, in
    the format selected by ``--ext``. Images that OpenCV cannot read are
    skipped with a warning instead of aborting the whole batch.
    """
    args = parse_args()

    # Validate cheap inputs before the model load, so mistakes fail fast.
    if args.ext not in _SUPPORTED_EXTS:
        raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))
    all_images = _collect_images(args.img)
    if not all_images:
        print('No .jpg/.jpeg/.png images found in: {}'.format(args.img))
        return

    model = load_scalelsd_model(args.ckpt, device=args.device)

    if args.saveto is None:
        print('No output directory specified, saving outputs to folder: temp_output/ScaleLSD')
        args.saveto = 'temp_output/ScaleLSD'
    os.makedirs(args.saveto, exist_ok=True)

    # Painter/canvas options are class-level attributes on the project's show module.
    show.painters.HAWPainter.confidence_threshold = args.threshold
    show.Canvas.show = not args.disable_show
    if args.whitebg > 0.0:
        show.Canvas.white_overlay = args.whitebg
    painter = show.painters.HAWPainter()

    for fname in tqdm(all_images):
        # Flag 0 == cv2.IMREAD_GRAYSCALE. imread returns None on failure rather than raising.
        image = cv2.imread(fname, 0)
        if image is None:
            tqdm.write('Skipping unreadable image: {}'.format(fname))
            continue

        # meta carries the ORIGINAL size, while the network sees the resized tensor.
        orig_height, orig_width = image.shape[:2]
        meta = {
            'width': orig_width,
            'height': orig_height,
            'filename': '',
            'use_lsd': args.use_lsd,
            'use_nms': args.use_nms,
        }
        batch = _preprocess(image, args.width, args.height, args.device)

        with torch.no_grad():
            outputs, _ = model(batch, meta)

        # The model returns per-image results; the batch holds exactly one image.
        _save_result(outputs[0], fname, args, painter)


if __name__ == "__main__":
    main()
|
|