| |
| |
| |
| |
| |
|
|
| import os |
| import sys |
| import argparse |
| from datetime import datetime |
| import random |
| from typing import Any, List |
| from functools import partial |
|
|
| from accelerate.utils import set_seed |
| import omegaconf |
|
|
| sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0]) |
|
|
| from libs.engine import merge_and_update_config |
| from libs.utils.argparse import accelerate_parser, base_data_parser |
|
|
|
|
def render_batch_wrap(args: "omegaconf.DictConfig",
                      seed_range: List,
                      pipeline: Any,
                      **pipe_args):
    """Run one painterly rendering per seed, each with a freshly built pipeline.

    For every seed in ``seed_range`` (in order): ``args.seed`` is overwritten in
    place, a new pipeline is created with ``pipeline(args)``, and
    ``painterly_rendering(**pipe_args)`` is called on it.

    Args:
        args: run configuration; mutated on every iteration (``args.seed``), so
            after the call it holds the last seed rendered.
        seed_range: seeds to render; must support ``len()``.
        pipeline: callable (e.g. a pipeline class) taking ``args`` and returning
            an object with a ``painterly_rendering`` method.
        **pipe_args: keyword arguments forwarded unchanged to
            ``painterly_rendering``.
    """
    start_time = datetime.now()
    total = len(seed_range)
    # Report progress 1-based so the last iteration reads [total/total].
    for idx, seed in enumerate(seed_range, start=1):
        args.seed = seed
        # "current time" is the elapsed time since the batch started.
        print(f"\n-> [{idx}/{total}], "
              f"current seed: {seed}, "
              f"current time: {datetime.now() - start_time}\n")
        pipe = pipeline(args)
        pipe.painterly_rendering(**pipe_args)
        # Drop the reference before the next construction; otherwise the old
        # pipeline stays alive while the new one is built (NOTE(review):
        # presumably these hold large models, so this lowers peak memory).
        del pipe
|
|
|
|
def main(args, seed_range):
    """Build the pipeline selected by ``args.task`` and run painterly rendering.

    Args:
        args: merged run configuration. Mutated in place: ``batch_size`` is
            forced to 1 and ``width`` is converted to ``float``.
        seed_range: seeds for batch mode. Only used when ``args.render_batch``
            is truthy.

    Raises:
        ValueError: if ``args.task`` is neither ``"diffsketcher"`` nor
            ``"style-diffsketcher"``.
    """
    args.batch_size = 1
    # Normalize width to a float before it reaches the pipelines
    # (NOTE(review): presumably the config may hold an int/str here — confirm).
    args.width = float(args.width)

    if args.task == "diffsketcher":
        # Imported inside the branch so only the selected pipeline is loaded.
        from pipelines.painter.diffsketcher_pipeline import DiffSketcherPipeline

        if args.render_batch:
            render_batch_wrap(args=args, seed_range=seed_range,
                              pipeline=DiffSketcherPipeline,
                              prompt=args.prompt)
        else:
            pipe = DiffSketcherPipeline(args)
            pipe.painterly_rendering(args.prompt)

    elif args.task == "style-diffsketcher":
        from pipelines.painter.diffsketcher_stylized_pipeline import StylizedDiffSketcherPipeline

        if args.render_batch:
            render_batch_wrap(args=args, seed_range=seed_range,
                              pipeline=StylizedDiffSketcherPipeline,
                              prompt=args.prompt,
                              style_fpath=args.style_file)
        else:
            pipe = StylizedDiffSketcherPipeline(args)
            pipe.painterly_rendering(args.prompt, args.style_file)

    else:
        # Previously an unknown task silently did nothing; fail loudly instead.
        raise ValueError(
            f"unsupported task {args.task!r}; "
            f"expected 'diffsketcher' or 'style-diffsketcher'")
|
|
|
|
| if __name__ == '__main__': |
| parser = argparse.ArgumentParser( |
| description="vary style and content painterly rendering", |
| parents=[accelerate_parser(), base_data_parser()] |
| ) |
| |
| parser.add_argument("-tk", "--task", |
| default="diffsketcher", type=str, |
| choices=['diffsketcher', 'style-diffsketcher'], |
| help="choose a method.") |
| |
| parser.add_argument("-c", "--config", |
| required=True, type=str, |
| default="", |
| help="YAML/YML file for configuration.") |
| parser.add_argument("-style", "--style_file", |
| default="", type=str, |
| help="the path of style img place.") |
| |
| parser.add_argument("-pt", "--prompt", default="A horse is drinking water by the lake", type=str) |
| parser.add_argument("-npt", "--negative_prompt", default="", type=str) |
| |
| parser.add_argument("--print_timing", "-timing", action="store_true", |
| help="set print svg rendering timing.") |
| |
| parser.add_argument("--download", action="store_true", |
| help="download models from huggingface automatically.") |
| parser.add_argument("--force_download", "-download", action="store_true", |
| help="force the models to be downloaded from huggingface.") |
| parser.add_argument("--resume_download", "-dpm_resume", action="store_true", |
| help="download the models again from the breakpoint.") |
| |
| |
| parser.add_argument("--render_batch", "-rdbz", action="store_true") |
| parser.add_argument("-srange", "--seed_range", |
| required=False, nargs='+', |
| help="Sampling quantity.") |
| |
| parser.add_argument("-mv", "--make_video", action="store_true", |
| help="make a video of the rendering process.") |
| parser.add_argument("-frame_freq", "--video_frame_freq", |
| default=1, type=int, |
| help="video frame control.") |
| parser.add_argument("-framerate", "--video_frame_rate", |
| default=36, type=int, |
| help="by adjusting the frame rate, you can control the playback speed of the output video.") |
|
|
| args = parser.parse_args() |
|
|
| |
| seed_range = None |
| if args.render_batch: |
| |
| start_, end_ = 1, 1000000 |
| if args.seed_range is not None: |
| seed_range_ = list(args.seed_range) |
| assert len(seed_range_) == 2 and int(seed_range_[1]) > int(seed_range_[0]) |
| start_, end_ = int(seed_range_[0]), int(seed_range_[1]) |
| seed_range = [i for i in range(start_, end_)] |
| else: |
| |
| numbers = list(range(start_, end_)) |
| seed_range = random.sample(numbers, k=1000) |
|
|
| args = merge_and_update_config(args) |
|
|
| set_seed(args.seed) |
| main(args, seed_range) |
|
|