| from typing import IO, Union, Callable |
| from collections import OrderedDict |
| import torch |
| from lightning_fabric.utilities.cloud_io import _load |
| from lightning_fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH |
| from lightning.pytorch.utilities.migration import pl_legacy_patch |
| from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint |
|
|
|
|
| def pl_load( |
| path_or_url: Union[IO, _PATH], |
| map_location: _MAP_LOCATION_TYPE = None, |
| ) -> OrderedDict[str, torch.Tensor]: |
| r""" |
| Load the `state_dict` only from a PyTorch-Lightning checkpoint. |
| Code is adopted from https://github.com/Lightning-AI/lightning/blob/255b18823e7da265e0e2e3996f55dcd0f78e9f3e/src/lightning/pytorch/core/saving.py |
| """ |
| with pl_legacy_patch(): |
| checkpoint = _load(path_or_url, map_location=map_location) |
| |
| checkpoint = _pl_migrate_checkpoint( |
| checkpoint, checkpoint_path=(path_or_url if isinstance(path_or_url, _PATH) else None) |
| ) |
| return checkpoint["state_dict"] |
|
|
|
|
| def pl_ckpt_to_state_dict( |
| checkpoint_path: str, |
| map_location: _MAP_LOCATION_TYPE = None, |
| key_fn: Callable = lambda x: x, |
| ): |
| r""" |
| Parameters |
| ---------- |
| checkpoint_path: str |
| map_location: _MAP_LOCATION_TYPE |
| A function, torch.device, string or a dict specifying how to remap storage locations. |
| The same as the arg `map_location` in `torch.load()`. |
| key_fn: Callable |
| A function to map the keys in the loaded checkpoint to the desired keys in the returned state_dict. |
| |
| Returns |
| ------- |
| state_dict: OrderedDict |
| """ |
| if map_location is None: |
| map_location = lambda storage, loc: storage |
| pl_ckpt_state_dict = pl_load(checkpoint_path, map_location=map_location) |
| state_dict = {key_fn(key): val for key, val in pl_ckpt_state_dict.items()} |
| return state_dict |
|
|