Spaces:
Running on Zero
Running on Zero
| import torch, glob, os | |
| from typing import Optional, Union | |
| from dataclasses import dataclass | |
| # from modelscope import snapshot_download | |
| from huggingface_hub import snapshot_download as hf_snapshot_download | |
| from typing import Optional | |
@dataclass
class ModelConfig:
    """Describe where a model's weight files live and how to obtain them.

    A config is resolved either from an explicit ``path`` or from a
    ``model_id`` (plus an optional ``origin_file_pattern``) whose files live
    under ``<local_model_path>/<model_id>``. Call ``download_if_necessary()``
    to validate the inputs, download when required, and fill in ``path``.

    Settings left as ``None`` fall back to environment variables:
    ``DIFFSYNTH_DOWNLOAD_SOURCE`` and ``DIFFSYNTH_SKIP_DOWNLOAD``. In addition,
    ``DIFFSYNTH_MODEL_BASE_PATH`` overrides ``local_model_path`` even when it
    was set explicitly.
    """

    # Explicit local file path(s); when set, nothing is downloaded.
    path: Optional[Union[str, list[str]]] = None
    # Remote repository id (e.g. "org/name"); also used as the local sub-directory.
    model_id: Optional[str] = None
    # Glob pattern, relative to the repository, selecting the files to use.
    # None/"" means the whole repository; a trailing "/" selects a directory's contents.
    # NOTE(review): annotated as allowing a list, but the methods below call str
    # methods / os.path.join on it, so only a single string works today.
    origin_file_pattern: Optional[Union[str, list[str]]] = None
    # Download backend; only "huggingface" is accepted by download() in this build.
    download_source: Optional[str] = None
    # Base directory for model files; defaults to "./models".
    local_model_path: Optional[str] = None
    # When true, never download and use whatever already exists locally.
    skip_download: Optional[bool] = None
    # Device/dtype settings returned as-is by vram_config(); their meaning is
    # defined by the code that consumes that dict (not visible here).
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
    onload_device: Optional[Union[str, torch.device]] = None
    onload_dtype: Optional[torch.dtype] = None
    preparing_device: Optional[Union[str, torch.device]] = None
    preparing_dtype: Optional[torch.dtype] = None
    computation_device: Optional[Union[str, torch.device]] = None
    computation_dtype: Optional[torch.dtype] = None
    # Not read by any method of this class; presumably consumed by the model
    # loader — TODO confirm.
    clear_parameters: bool = False

    def check_input(self):
        """Raise ``ValueError`` unless either ``path`` or ``model_id`` is given."""
        if self.path is None and self.model_id is None:
            raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")

    def parse_original_file_pattern(self):
        """Return the glob pattern used to select files inside the repository.

        ``None`` or ``""`` becomes ``"*"`` (everything at the top level). A
        pattern ending in ``"/"`` gets ``"*"`` appended so that it matches the
        directory's contents.
        """
        pattern = self.origin_file_pattern
        if pattern is None or pattern == "":
            return "*"
        if pattern.endswith("/"):
            return pattern + "*"
        return pattern

    def parse_download_source(self):
        """Return the name of the download backend.

        Precedence is: the ``download_source`` field, then the
        ``DIFFSYNTH_DOWNLOAD_SOURCE`` environment variable, then ``"modelscope"``.
        ``download()`` only accepts ``"huggingface"`` in this build, so the
        default must be overridden for downloads to succeed.
        """
        if self.download_source is not None:
            return self.download_source
        return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE', "modelscope")

    def parse_skip_download(self):
        """Return whether downloading should be skipped, always as a bool.

        The ``skip_download`` field wins when set. Otherwise, only a
        case-insensitive ``"true"`` in ``DIFFSYNTH_SKIP_DOWNLOAD`` enables
        skipping. An unset variable or any other value means "do not skip";
        unrecognised values previously leaked out as ``None``.
        """
        if self.skip_download is not None:
            return self.skip_download
        env_value = os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD')
        return env_value is not None and env_value.strip().lower() == "true"

    def download(self):
        """Fetch the files matching the file pattern into ``<local_model_path>/<model_id>``.

        Files that already exist locally and match the same pattern are passed
        as ``ignore_patterns``, so they are not requested again.

        Raises:
            ValueError: if the resolved download source is not ``huggingface``.
                The ModelScope backend is disabled in this build, so the
                ``"modelscope"`` default is rejected with a dedicated message.
        """
        download_source = self.parse_download_source()
        source = download_source.lower()
        if source == "modelscope":
            # NOTE: the ModelScope snapshot_download branch (and its import) is
            # commented out in this build, so this source cannot be served.
            raise ValueError(
                "`download_source` is `modelscope`, but ModelScope downloads are disabled in this build. "
                "Pass `download_source=\"huggingface\"` or set DIFFSYNTH_DOWNLOAD_SOURCE=huggingface."
            )
        if source != "huggingface":
            raise ValueError(f"`download_source` should be `huggingface`, got {download_source!r}.")
        origin_file_pattern = self.parse_original_file_pattern()
        local_dir = os.path.join(self.local_model_path, self.model_id)
        # Paths are relative to local_dir (glob root_dir), which matches the
        # repo-relative form that ignore_patterns expects.
        downloaded_files = glob.glob(origin_file_pattern, root_dir=local_dir)
        hf_snapshot_download(
            self.model_id,
            local_dir=local_dir,
            allow_patterns=origin_file_pattern,
            ignore_patterns=downloaded_files,
            local_files_only=False,
        )

    def require_downloading(self):
        """Return True when files must be fetched.

        That is the case when no explicit ``path`` is set and skipping was not
        requested.
        """
        if self.path is not None:
            return False
        return not self.parse_skip_download()

    def reset_local_model_path(self):
        """Resolve ``local_model_path``.

        ``DIFFSYNTH_MODEL_BASE_PATH`` takes priority over any value set on the
        config. Otherwise an unset path defaults to ``"./models"``.
        """
        env_base = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
        if env_base is not None:
            self.local_model_path = env_base
        elif self.local_model_path is None:
            self.local_model_path = "./models"

    def download_if_necessary(self):
        """Validate the config, download if required, and resolve ``self.path``.

        Afterwards, ``path`` is one of the following:

        - the caller's explicit path;
        - the repository directory, when no ``origin_file_pattern`` is set;
        - a single matching file;
        - a sorted list of matching files.

        A one-element list is unwrapped to its element. An empty list means
        nothing matched locally.

        Raises:
            ValueError: from ``check_input()`` or ``download()``.
        """
        self.check_input()
        self.reset_local_model_path()
        if self.require_downloading():
            self.download()
        if self.path is None:
            model_dir = os.path.join(self.local_model_path, self.model_id)
            if self.origin_file_pattern is None or self.origin_file_pattern == "":
                self.path = model_dir
            else:
                # glob.glob returns matches in arbitrary order; sort so that
                # multi-file results are deterministic across runs and machines.
                self.path = sorted(glob.glob(os.path.join(model_dir, self.origin_file_pattern)))
        if isinstance(self.path, list) and len(self.path) == 1:
            self.path = self.path[0]

    def vram_config(self):
        """Return the eight device/dtype settings as a plain dict keyed by field name."""
        return {
            "offload_device": self.offload_device,
            "offload_dtype": self.offload_dtype,
            "onload_device": self.onload_device,
            "onload_dtype": self.onload_dtype,
            "preparing_device": self.preparing_device,
            "preparing_dtype": self.preparing_dtype,
            "computation_device": self.computation_device,
            "computation_dtype": self.computation_dtype,
        }