Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # pyre-unsafe | |
| import inspect | |
| import logging | |
| import os | |
| from collections import defaultdict | |
| from dataclasses import field | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import torch.optim | |
| from accelerate import Accelerator | |
| from pytorch3d.implicitron.models.base_model import ImplicitronModelBase | |
| from pytorch3d.implicitron.tools import model_io | |
| from pytorch3d.implicitron.tools.config import ( | |
| registry, | |
| ReplaceableBase, | |
| run_auto_creation, | |
| ) | |
| logger = logging.getLogger(__name__) | |
class OptimizerFactoryBase(ReplaceableBase):
    """
    Pluggable base class for optimizer factories: subclasses build the
    optimizer and learning-rate scheduler for a training run.
    """

    def __call__(
        self, model: ImplicitronModelBase, **kwargs
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Build the optimizer and lr scheduler for `model`.

        Args:
            model: The model with optionally loaded weights.

        Returns:
            An optimizer module (optionally loaded from a checkpoint) and
            a learning rate scheduler module (should be a subclass of
            torch.optim's lr_scheduler._LRScheduler).
        """
        # Subclasses must provide the concrete implementation.
        raise NotImplementedError()
@registry.register
class ImplicitronOptimizerFactory(OptimizerFactoryBase):
    """
    A factory that initializes the optimizer and lr scheduler.

    Members:
        betas: Beta parameters for the Adam optimizer.
        breed: The type of optimizer to use. We currently support SGD, Adagrad
            and Adam.
        exponential_lr_step_size: With Exponential policy only,
            lr = lr * gamma ** (epoch/step_size)
        gamma: Multiplicative factor of learning rate decay.
        lr: The value for the initial learning rate.
        lr_policy: The policy to use for learning rate. We currently support
            MultiStepLR, Exponential and LinearExponential policies.
        momentum: Momentum factor for the SGD optimizer.
        multistep_lr_milestones: With MultiStepLR policy only: list of
            increasing epoch indices at which the learning rate is modified.
        weight_decay: The optimizer weight_decay (L2 penalty on model weights).
        linear_exponential_lr_milestone: With LinearExponential policy only:
            the epoch at which the linear ramp-up of the lr ends and
            exponential decay begins.
        linear_exponential_start_gamma: With LinearExponential policy only:
            the fraction of `lr` used at epoch 0; it is linearly interpolated
            to 1.0 at `linear_exponential_lr_milestone`.
        foreach: Whether to use new "foreach" implementation of optimizer where
            available (e.g. requires PyTorch 1.12.0 for Adam)
        group_learning_rates: Parameters or modules can be assigned to parameter
            groups. This dictionary has names of those parameter groups as keys
            and learning rates as values. All parameter group names have to be
            defined in this dictionary. Parameters which do not have predefined
            parameter group are put into "default" parameter group which has
            `lr` as its learning rate.
    """

    betas: Tuple[float, ...] = (0.9, 0.999)
    breed: str = "Adam"
    exponential_lr_step_size: int = 250
    gamma: float = 0.1
    lr: float = 0.0005
    lr_policy: str = "MultiStepLR"
    momentum: float = 0.9
    multistep_lr_milestones: tuple = ()
    weight_decay: float = 0.0
    linear_exponential_lr_milestone: int = 200
    linear_exponential_start_gamma: float = 0.1
    foreach: Optional[bool] = True
    group_learning_rates: Dict[str, float] = field(default_factory=dict)

    def __post_init__(self):
        run_auto_creation(self)

    def __call__(
        self,
        last_epoch: int,
        model: ImplicitronModelBase,
        accelerator: Optional[Accelerator] = None,
        exp_dir: Optional[str] = None,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Initialize the optimizer (optionally from a checkpoint) and the lr scheduler.

        Args:
            last_epoch: If the model was loaded from checkpoint this will be the
                number of the last epoch that was saved.
            model: The model with optionally loaded weights.
            accelerator: An optional Accelerator instance.
            exp_dir: Root experiment directory.
            resume: If True, attempt to load optimizer checkpoint from exp_dir.
                Failure to do so will return a newly initialized optimizer.
            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
                `resume_epoch` <= 0, then resume from the latest checkpoint.
        Returns:
            An optimizer module (optionally loaded from a checkpoint) and
            a learning rate scheduler module (should be a subclass of torch.optim's
            lr_scheduler._LRScheduler).
        Raises:
            ValueError: if `breed` or `lr_policy` is not one of the supported
                values.
        """
        # Get the parameters to optimize. If the model defines its own
        # grouping function, defer to it; otherwise build groups from the
        # `param_groups` annotations found on submodules.
        if hasattr(model, "_get_param_groups"):  # use the model function
            p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
        else:
            p_groups = [
                {"params": params, "lr": self._get_group_learning_rate(group)}
                for group, params in self._get_param_groups(model).items()
            ]

        # Initialize the optimizer
        optimizer_kwargs: Dict[str, Any] = {
            "lr": self.lr,
            "weight_decay": self.weight_decay,
        }
        if self.breed == "SGD":
            optimizer_class = torch.optim.SGD
            optimizer_kwargs["momentum"] = self.momentum
        elif self.breed == "Adagrad":
            optimizer_class = torch.optim.Adagrad
        elif self.breed == "Adam":
            optimizer_class = torch.optim.Adam
            optimizer_kwargs["betas"] = self.betas
        else:
            raise ValueError(f"No such solver type {self.breed}")

        # Only pass `foreach` where the installed torch version supports it
        # (e.g. Adam gained it in PyTorch 1.12.0).
        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
            optimizer_kwargs["foreach"] = self.foreach
        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
        logger.info(f"Solver type = {self.breed}")

        # Load state from checkpoint
        optimizer_state = self._get_optimizer_state(
            exp_dir,
            accelerator,
            resume_epoch=resume_epoch,
            resume=resume,
        )
        if optimizer_state is not None:
            logger.info("Setting loaded optimizer state.")
            optimizer.load_state_dict(optimizer_state)

        # Initialize the learning rate scheduler
        if self.lr_policy.casefold() == "MultiStepLR".casefold():
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=self.multistep_lr_milestones,
                gamma=self.gamma,
            )
        elif self.lr_policy.casefold() == "Exponential".casefold():
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size),
            )
        elif self.lr_policy.casefold() == "LinearExponential".casefold():
            # linear learning rate progression between epochs 0 to
            # self.linear_exponential_lr_milestone, followed by exponential
            # lr decay for the rest of the epochs
            def _get_lr(epoch: int):
                m = self.linear_exponential_lr_milestone
                if epoch < m:
                    # Ramp linearly from `linear_exponential_start_gamma` to 1.
                    w = (m - epoch) / m
                    gamma = w * self.linear_exponential_start_gamma + (1 - w)
                else:
                    epoch_rest = epoch - m
                    gamma = self.gamma ** (epoch_rest / self.exponential_lr_step_size)
                return gamma

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _get_lr)
        else:
            raise ValueError(f"no such lr policy {self.lr_policy}")

        # When loading from checkpoint, this will make sure that the
        # lr is correctly set even after returning.
        for _ in range(last_epoch):
            scheduler.step()

        optimizer.zero_grad()
        return optimizer, scheduler

    def _get_optimizer_state(
        self,
        exp_dir: Optional[str],
        accelerator: Optional[Accelerator] = None,
        resume: bool = True,
        resume_epoch: int = -1,
    ) -> Optional[Dict[str, Any]]:
        """
        Load an optimizer state from a checkpoint.

        Args:
            exp_dir: Root experiment directory; if None, nothing is loaded.
            accelerator: An optional Accelerator instance, used to remap
                checkpoint tensors onto the local process' device.
            resume: If True, attempt to load the last checkpoint from `exp_dir`
                passed to __call__. Failure to do so will return a newly initialized
                optimizer.
            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
                `resume_epoch` <= 0, then resume from the latest checkpoint.

        Returns:
            The loaded optimizer state dict, or None when not resuming or no
            previous checkpoint was found.

        Raises:
            FileNotFoundError: if the checkpoint for the requested
                `resume_epoch`, or the optimizer file of a found checkpoint,
                does not exist.
        """
        if exp_dir is None or not resume:
            return None
        if resume_epoch > 0:
            save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
            if not os.path.isfile(save_path):
                raise FileNotFoundError(
                    f"Cannot find optimizer from epoch {resume_epoch}."
                )
        else:
            save_path = model_io.find_last_checkpoint(exp_dir)
        optimizer_state = None
        if save_path is not None:
            logger.info(f"Found previous optimizer state {save_path} -> resuming.")
            opt_path = model_io.get_optimizer_path(save_path)
            if os.path.isfile(opt_path):
                map_location = None
                if accelerator is not None and not accelerator.is_local_main_process:
                    # Non-main processes remap tensors saved on cuda:0 to
                    # their own local device.
                    map_location = {
                        "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                    }
                # NOTE(review): torch.load unpickles arbitrary data; the
                # checkpoint path is assumed to be trusted experiment output.
                optimizer_state = torch.load(opt_path, map_location)
            else:
                raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
        return optimizer_state

    def _get_param_groups(
        self, module: torch.nn.Module
    ) -> Dict[str, List[torch.nn.Parameter]]:
        """
        Recursively visits all the modules inside the `module` and sorts all the
        parameters in parameter groups.

        Uses `param_groups` dictionary member, where keys are names of individual
        parameters or module members and values are the names of the parameter groups
        for those parameters or members. "self" key is used to denote the parameter groups
        at the module level. Possible keys, including the "self" key do not have to
        be defined. By default all parameters have the learning rate defined in the
        optimizer. This can be overridden by setting the parameter group in `param_groups`
        member of a specific module. Values are a parameter group name. The keys
        specify what parameters will be affected as follows:
            - "self": All the parameters of the module and its child modules
            - name of a parameter: A parameter with that name.
            - name of a module member: All the parameters of the module and its
                child modules.
                This is useful if members do not have `param_groups`, for
                example torch.nn.Linear.
            - <name of module member>.<something>: recursive. Same as if <something>
                was used in param_groups of that submodule/member.

        Args:
            module: module from which to extract the parameters and their parameter
                groups
        Returns:
            dictionary with parameter groups as keys and lists of parameters as values
        """

        param_groups = defaultdict(list)

        def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
            """
            Visitor for module to assign its parameters to the relevant member of
            param_groups.

            Args:
                module: the module being visited in a depth-first search
                default_group: the param group to assign parameters to unless
                    otherwise overridden.
                mapping: known mappings of parameters to groups for this module,
                    destructively modified by this function.
            """
            # If key self is defined in param_groups then change the default param
            # group for all parameters and children in the module.
            if hasattr(module, "param_groups") and "self" in module.param_groups:
                default_group = module.param_groups["self"]

            # Collect all the parameters that are directly inside the `module`,
            # they will be in the default param group if they don't have
            # defined group.
            if hasattr(module, "param_groups"):
                mapping.update(module.param_groups)

            for name, param in module.named_parameters(recurse=False):
                if param.requires_grad:
                    group_name = mapping.get(name, default_group)
                    logger.debug(f"Assigning {name} to param_group {group_name}")
                    param_groups[group_name].append(param)

            # If children have defined default param group then use it else pass
            # own default.
            for child_name, child in module.named_children():
                # Strip the "<child_name>." prefix so the child sees its own
                # relative mapping keys.
                mapping_to_add = {
                    name[len(child_name) + 1 :]: group
                    for name, group in mapping.items()
                    if name.startswith(child_name + ".")
                }
                traverse(child, mapping.get(child_name, default_group), mapping_to_add)

        traverse(module, "default", {})
        return param_groups

    def _get_group_learning_rate(self, group_name: str) -> float:
        """
        Wraps the `group_learning_rates` dictionary providing errors and returns
        `self.lr` for "default" group_name.

        Args:
            group_name: a string representing the name of the group
        Returns:
            learning rate for a specific group
        Raises:
            ValueError: if `group_name` is not "default" and has no entry in
                `group_learning_rates`.
        """
        if group_name == "default":
            return self.lr
        lr = self.group_learning_rates.get(group_name, None)
        if lr is None:
            raise ValueError(f"no learning rate given for group {group_name}")
        return lr