File size: 4,955 Bytes
8b306b3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
import os
import os.path as osp
import uuid
from typing import Any, Optional
import torch
from safetensors.torch import save_file as save_safetensors, save_model as save_safetensors_model
from .logging import get_logger
from .distributed import get_global_rank
logger = get_logger(__name__)
# 解决循环导入问题:延迟导入 is_hdfs_path, mkdir, copy
def _get_filesystem_funcs():
from ..io.filesystem import is_hdfs_path, mkdir, copy
return is_hdfs_path, mkdir, copy
_local_dir = None
def get_local_dir():
"""
Get a local directory for temporary storage for this process.
"""
global _local_dir
_, mkdir, _ = _get_filesystem_funcs()
if _local_dir is None:
_local_dir = os.path.join("persistence", "rank_" + str(get_global_rank()) + "_" + str(uuid.uuid4()))
mkdir(_local_dir)
return _local_dir
def set_local_dir(dirname):
"""
Set a local directory for temporary storage for this process.
"""
global _local_dir
_, mkdir, _ = _get_filesystem_funcs()
if dirname is None:
return
_local_dir = os.path.join(dirname, str(uuid.uuid4()))
mkdir(_local_dir)
def get_local_path(path: str) -> str:
"""
Get a local path for storing the file.
If the path is already a local path, directly return.
"""
is_hdfs_path, mkdir, _ = _get_filesystem_funcs()
if is_hdfs_path(path):
path = os.path.join(get_local_dir(), os.path.basename(path))
else:
mkdir(os.path.dirname(path))
return path
def convert_dtype(states: Any, dtype: Optional[torch.dtype] = None):
"""
Recursively convert the state_dict to device and dtype.
"""
if dtype is None:
return states
if torch.is_tensor(states):
return states.to("cpu", dtype)
if isinstance(states, dict):
return {k: convert_dtype(v, dtype) for k, v in states.items()}
if isinstance(states, list):
return [convert_dtype(v, dtype) for v in states]
return states
def save(data: Any, path: str, blocking: bool = True, persistence_dir: Optional[str] = None):
"""
安全地将数据保存到指定路径(本地或HDFS)。
此版本使用 get_local_dir 来处理临时文件。
"""
is_hdfs_path, _, copy = _get_filesystem_funcs()
if not is_hdfs_path(path):
if path.endswith(".safetensors"):
if isinstance(data, torch.nn.Module):
save_safetensors_model(data, path)
else:
save_safetensors(data, path)
else:
torch.save(data, path)
logger.info(f"Early saved to local path: {path}")
return
# --- HDFS 路径处理 ---
# 1. 获取一个唯一的本地临时文件路径
if persistence_dir is None:
persistence_dir = get_local_dir()
try:
# 2. 向临时文件写入数据
local_path = osp.join(persistence_dir, osp.basename(path))
if path.endswith(".safetensors"):
if isinstance(data, torch.nn.Module):
save_safetensors_model(data, local_path)
else:
save_safetensors(data, local_path)
else:
torch.save(data, local_path)
logger.info(f"Saved to local path: {local_path}")
# 3. 将本地临时文件复制到HDFS
copy(local_path, path, blocking=blocking)
logger.info(f"Copy {local_path} to HDFS or Local path: {path} done.")
finally:
# NOTE: 因为是重复写入,不需要清理了
pass
# # 4. 清理临时文件
# # NOTE: 暂时只在blocking为True的时候清理
# if osp.exists(persistence_path) and blocking:
# os.remove(persistence_path)
# logger.info(f"Removed temporary file: {persistence_path}")
def dummy_indexes_searchsorted(packed_text_indexes: torch.LongTensor, ce_loss_indexes: torch.LongTensor) -> torch.LongTensor:
"""
使用 searchsorted 方法:
- 对 packed_text_indexes 排序,得到排序值 sorted_vals 和原始下标 sorted_pos。
- 在 sorted_vals 中查找 ce_loss_indexes 的位置 loc。
- 根据 loc 索引 sorted_pos,得到 dummy_indexes。
"""
sorted_vals, sorted_pos = torch.sort(packed_text_indexes)
loc = torch.searchsorted(sorted_vals, ce_loss_indexes)
return sorted_pos[loc]
|