import importlib
import json

import torch

from ..core.loader import load_model, hash_model_file
from ..core.vram import AutoWrappedModule
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS


class ModelPool:
    def __init__(self):
        # Three parallel lists: self.model[i] was loaded from
        # self.model_path[i] and is registered under self.model_name[i].
        self.model = []
        self.model_name = []
        self.model_path = []

    def import_model_class(self, model_class):
        # Resolve a dotted path such as "package.module.ClassName" into the
        # class object it names.
        split = model_class.rfind(".")
        model_resource, model_class = model_class[:split], model_class[split + 1:]
        model_class = getattr(importlib.import_module(model_resource), model_class)
        return model_class

    def need_to_enable_vram_management(self, vram_config):
        return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None

    def fetch_module_map(self, model_class, vram_config):
        # Map original module classes to their VRAM-managed replacements;
        # return None when offloading is disabled.
        if self.need_to_enable_vram_management(vram_config):
            if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
                module_map = {
                    self.import_model_class(source): self.import_model_class(target)
                    for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()
                }
            else:
                module_map = {self.import_model_class(model_class): AutoWrappedModule}
        else:
            module_map = None
        return module_map
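
    # For instance, when offloading is enabled and model_class has no entry in
    # VRAM_MANAGEMENT_MODULE_MAPS, the fallback wraps the whole model class
    # with AutoWrappedModule ("pkg.models.MyDiT" is a hypothetical path):
    #
    #     fetch_module_map("pkg.models.MyDiT", vram_config)
    #     # -> {pkg.models.MyDiT: AutoWrappedModule}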

    def load_model_file(self, config, path, vram_config, vram_limit=None):
        model_class = self.import_model_class(config["model_class"])
        model_config = config.get("extra_kwargs", {})
        if "state_dict_converter" in config:
            state_dict_converter = self.import_model_class(config["state_dict_converter"])
        else:
            state_dict_converter = None
        module_map = self.fetch_module_map(config["model_class"], vram_config)
        model = load_model(
            model_class, path, model_config,
            vram_config["computation_dtype"], vram_config["computation_device"],
            state_dict_converter,
            use_disk_map=True,
            vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
        )
        return model

    def default_vram_config(self):
        # CPU-only defaults: both offload fields are None, so
        # need_to_enable_vram_management() is False for this config.
        vram_config = {
            "offload_dtype": None,
            "offload_device": None,
            "onload_dtype": torch.bfloat16,
            "onload_device": "cpu",
            "preparing_dtype": torch.bfloat16,
            "preparing_device": "cpu",
            "computation_dtype": torch.bfloat16,
            "computation_device": "cpu",
        }
        return vram_config
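
    # A sketch of a config that enables VRAM management; the dtype and device
    # choices below are illustrative assumptions, not project defaults:
    #
    #     offload_config = {
    #         "offload_dtype": torch.bfloat16,
    #         "offload_device": "cpu",
    #         "onload_dtype": torch.bfloat16,
    #         "onload_device": "cuda",
    #         "preparing_dtype": torch.bfloat16,
    #         "preparing_device": "cuda",
    #         "computation_dtype": torch.bfloat16,
    #         "computation_device": "cuda",
    #     }
    #
    # With both offload fields set, fetch_module_map() returns a non-None map
    # and load_model() receives the wrapped module classes.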

    def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
        print(f"Loading models from: {json.dumps(path, indent=4)}")
        if vram_config is None:
            vram_config = self.default_vram_config()
        model_hash = hash_model_file(path)
        loaded = False
        for config in MODEL_CONFIGS:
            if config["model_hash"] == model_hash:
                model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
                if clear_parameters:
                    self.clear_parameters(model)
                self.model.append(model)
                model_name = config["model_name"]
                self.model_name.append(model_name)
                self.model_path.append(path)
                model_info = {
                    "model_name": model_name,
                    "model_class": config["model_class"],
                    "extra_kwargs": config.get("extra_kwargs"),
                }
                print(f"Loaded model: {json.dumps(model_info, indent=4)}")
                loaded = True
        if not loaded:
            raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}")
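
    # Shape of a MODEL_CONFIGS entry as consumed above; every value here is an
    # illustrative placeholder, and the last two keys may be absent:
    #
    #     {
    #         "model_name": "dit",
    #         "model_hash": "0123abcd...",
    #         "model_class": "pkg.models.MyDiT",
    #         "extra_kwargs": {},                        # optional
    #         "state_dict_converter": "pkg.convert.fn",  # optional
    #     }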

    def fetch_model(self, model_name, index=None):
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            print(f"No {model_name} models available. This is not an error.")
            model = None
        elif len(fetched_models) == 1:
            print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
            model = fetched_models[0]
        else:
            if index is None:
                model = fetched_models[0]
                print(f"Multiple {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
            elif isinstance(index, int):
                # An integer index returns the first `index` matches (a list
                # slice), not the single model at that position.
                model = fetched_models[:index]
                print(f"Multiple {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.")
            else:
                model = fetched_models
                print(f"Multiple {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.")
        return model

    def clear_parameters(self, model: torch.nn.Module):
        # Recursively set every registered parameter to None, freeing the
        # underlying tensors while keeping the module tree intact.
        for name, module in model.named_children():
            self.clear_parameters(module)
        for name, _ in model.named_parameters(recurse=False):
            setattr(model, name, None)
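

if __name__ == "__main__":
    # Minimal usage sketch. Run with `python -m <package>.<module>` so the
    # relative imports resolve; the file path and model name below are
    # hypothetical placeholders, and auto_load_model() raises ValueError for
    # files whose hash is not listed in MODEL_CONFIGS.
    pool = ModelPool()
    pool.auto_load_model("models/dit.safetensors")
    dit = pool.fetch_model("dit")  # None (with a notice) if nothing matched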