| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import os.path as osp |
| import uuid |
| from typing import Any, Optional |
| import torch |
| from safetensors.torch import save_file as save_safetensors, save_model as save_safetensors_model |
| from .logging import get_logger |
| from .distributed import get_global_rank |
|
|
| logger = get_logger(__name__) |
|
|
| |
| def _get_filesystem_funcs(): |
| from ..io.filesystem import is_hdfs_path, mkdir, copy |
| return is_hdfs_path, mkdir, copy |
|
|
| _local_dir = None |
|
|
|
|
| def get_local_dir(): |
| """ |
| Get a local directory for temporary storage for this process. |
| """ |
| global _local_dir |
| _, mkdir, _ = _get_filesystem_funcs() |
| if _local_dir is None: |
| _local_dir = os.path.join("persistence", "rank_" + str(get_global_rank()) + "_" + str(uuid.uuid4())) |
| mkdir(_local_dir) |
| return _local_dir |
|
|
|
|
| def set_local_dir(dirname): |
| """ |
| Set a local directory for temporary storage for this process. |
| """ |
| global _local_dir |
| _, mkdir, _ = _get_filesystem_funcs() |
| if dirname is None: |
| return |
| _local_dir = os.path.join(dirname, str(uuid.uuid4())) |
| mkdir(_local_dir) |
|
|
|
|
| def get_local_path(path: str) -> str: |
| """ |
| Get a local path for storing the file. |
| If the path is already a local path, directly return. |
| """ |
| is_hdfs_path, mkdir, _ = _get_filesystem_funcs() |
| if is_hdfs_path(path): |
| path = os.path.join(get_local_dir(), os.path.basename(path)) |
| else: |
| mkdir(os.path.dirname(path)) |
| return path |
|
|
|
|
| def convert_dtype(states: Any, dtype: Optional[torch.dtype] = None): |
| """ |
| Recursively convert the state_dict to device and dtype. |
| """ |
| if dtype is None: |
| return states |
| if torch.is_tensor(states): |
| return states.to("cpu", dtype) |
| if isinstance(states, dict): |
| return {k: convert_dtype(v, dtype) for k, v in states.items()} |
| if isinstance(states, list): |
| return [convert_dtype(v, dtype) for v in states] |
| return states |
|
|
|
|
| def save(data: Any, path: str, blocking: bool = True, persistence_dir: Optional[str] = None): |
| """ |
| 安全地将数据保存到指定路径(本地或HDFS)。 |
| 此版本使用 get_local_dir 来处理临时文件。 |
| """ |
| is_hdfs_path, _, copy = _get_filesystem_funcs() |
| if not is_hdfs_path(path): |
| if path.endswith(".safetensors"): |
| if isinstance(data, torch.nn.Module): |
| save_safetensors_model(data, path) |
| else: |
| save_safetensors(data, path) |
| else: |
| torch.save(data, path) |
|
|
| logger.info(f"Early saved to local path: {path}") |
| return |
|
|
| |
| |
| if persistence_dir is None: |
| persistence_dir = get_local_dir() |
|
|
| try: |
| |
| local_path = osp.join(persistence_dir, osp.basename(path)) |
| if path.endswith(".safetensors"): |
| if isinstance(data, torch.nn.Module): |
| save_safetensors_model(data, local_path) |
| else: |
| save_safetensors(data, local_path) |
| else: |
| torch.save(data, local_path) |
| logger.info(f"Saved to local path: {local_path}") |
|
|
| |
| copy(local_path, path, blocking=blocking) |
| logger.info(f"Copy {local_path} to HDFS or Local path: {path} done.") |
|
|
| finally: |
| |
| pass |
|
|
| |
| |
| |
| |
| |
|
|
| def dummy_indexes_searchsorted(packed_text_indexes: torch.LongTensor, ce_loss_indexes: torch.LongTensor) -> torch.LongTensor: |
| """ |
| 使用 searchsorted 方法: |
| - 对 packed_text_indexes 排序,得到排序值 sorted_vals 和原始下标 sorted_pos。 |
| - 在 sorted_vals 中查找 ce_loss_indexes 的位置 loc。 |
| - 根据 loc 索引 sorted_pos,得到 dummy_indexes。 |
| """ |
| sorted_vals, sorted_pos = torch.sort(packed_text_indexes) |
| loc = torch.searchsorted(sorted_vals, ce_loss_indexes) |
| return sorted_pos[loc] |
|
|