File size: 4,955 Bytes
8b306b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8

import os
import os.path as osp
import uuid
from typing import Any, Optional
import torch
from safetensors.torch import save_file as save_safetensors, save_model as save_safetensors_model
from .logging import get_logger
from .distributed import get_global_rank

logger = get_logger(__name__)

# Work around a circular import: is_hdfs_path, mkdir and copy are imported lazily at call time.
def _get_filesystem_funcs():
    """
    Return the ``(is_hdfs_path, mkdir, copy)`` helpers from ``..io.filesystem``.

    The import happens inside the function body, i.e. on each call rather than
    when this module is loaded, which sidesteps a circular import between the
    two modules.
    """
    from ..io.filesystem import copy, is_hdfs_path, mkdir

    helpers = (is_hdfs_path, mkdir, copy)
    return helpers

# Per-process scratch directory. Stays None until get_local_dir() creates a
# default one or set_local_dir() assigns one explicitly.
_local_dir = None


def get_local_dir():
    """
    Return this process's local scratch directory, creating it on first use.

    The first call builds ``persistence/rank_<global rank>_<uuid4>`` (relative
    to the current working directory), creates it with the filesystem ``mkdir``
    helper and caches the path in the module-level ``_local_dir``. Later calls
    return the cached path.
    """
    global _local_dir
    make_dir = _get_filesystem_funcs()[1]

    if _local_dir is not None:
        return _local_dir

    leaf_name = f"rank_{get_global_rank()!s}_{uuid.uuid4()!s}"
    _local_dir = os.path.join("persistence", leaf_name)
    make_dir(_local_dir)
    return _local_dir


def set_local_dir(dirname):
    """
    Point this process's local scratch directory at a fresh folder in ``dirname``.

    A new ``<dirname>/<uuid4>`` directory is created and stored in the
    module-level ``_local_dir``. Passing ``None`` leaves the current setting
    untouched.
    """
    global _local_dir
    make_dir = _get_filesystem_funcs()[1]

    if dirname is not None:
        _local_dir = os.path.join(dirname, str(uuid.uuid4()))
        make_dir(_local_dir)


def get_local_path(path: str) -> str:
    """
    Map ``path`` to a location on the local filesystem.

    Args:
        path: Destination path, either an HDFS path or a local one.

    Returns:
        For an HDFS path, ``<local scratch dir>/<basename of path>`` where the
        scratch dir comes from ``get_local_dir()``. For a local path, ``path``
        itself, unchanged.

    Side effects:
        For a local path, the parent directory is created via the filesystem
        ``mkdir`` helper so the caller can write to ``path`` right away.
    """
    is_hdfs_path, mkdir, _ = _get_filesystem_funcs()
    if is_hdfs_path(path):
        return os.path.join(get_local_dir(), os.path.basename(path))

    parent_dir = os.path.dirname(path)
    # A bare filename such as "model.pt" has an empty dirname; the file then
    # lives in the current working directory, which already exists, so there
    # is nothing to create (and mkdir("") would be an invalid request).
    if parent_dir:
        mkdir(parent_dir)
    return path


def convert_dtype(states: Any, dtype: Optional[torch.dtype] = None):
    """
    Recursively cast the tensors inside ``states`` to ``dtype`` on the CPU.

    Args:
        states: A tensor, or a (possibly nested) ``dict`` / ``list`` holding
            tensors and arbitrary other values.
        dtype: Target dtype. When ``None`` the input is returned untouched.

    Returns:
        A tensor moved to ``"cpu"`` with the requested dtype, or a new plain
        ``dict`` / ``list`` mirroring the input with every contained tensor
        converted. Values of any other type (including tuples) are returned
        as-is.
    """
    if dtype is None:
        return states

    if torch.is_tensor(states):
        return states.to(device="cpu", dtype=dtype)

    if isinstance(states, dict):
        converted = {}
        for key, value in states.items():
            converted[key] = convert_dtype(value, dtype)
        return converted

    if isinstance(states, list):
        converted_items = []
        for item in states:
            converted_items.append(convert_dtype(item, dtype))
        return converted_items

    return states


def _write_local(data: Any, local_path: str, use_safetensors: bool) -> None:
    """Serialize ``data`` to ``local_path`` in safetensors or ``torch.save`` format."""
    if not use_safetensors:
        torch.save(data, local_path)
    elif isinstance(data, torch.nn.Module):
        # Modules go through the model-aware safetensors writer.
        save_safetensors_model(data, local_path)
    else:
        save_safetensors(data, local_path)


def save(data: Any, path: str, blocking: bool = True, persistence_dir: Optional[str] = None):
    """
    Save ``data`` to ``path``, which may be a local path or an HDFS path.

    The on-disk format is chosen from the suffix of ``path``: ``.safetensors``
    uses the safetensors writers (``save_model`` for ``torch.nn.Module``
    instances, ``save_file`` otherwise); anything else uses ``torch.save``.

    Args:
        data: Object to serialize.
        path: Destination path. Local paths are written directly. HDFS paths
            are first written to a local staging file and then copied.
        blocking: Forwarded to the filesystem ``copy`` helper for HDFS
            destinations; not used for local destinations.
            NOTE(review): presumably ``False`` makes the upload asynchronous
            -- confirm against ``copy``.
        persistence_dir: Directory for the local staging file. Defaults to
            ``get_local_dir()``. It is not created here.
    """
    is_hdfs_path, _, copy = _get_filesystem_funcs()

    if not is_hdfs_path(path):
        _write_local(data, path, path.endswith(".safetensors"))
        logger.info(f"Early saved to local path: {path}")
        return

    # --- HDFS destination ---
    # 1. Pick the local staging location (same basename as the destination).
    if persistence_dir is None:
        persistence_dir = get_local_dir()
    local_path = osp.join(persistence_dir, osp.basename(path))

    # 2. Write the data to the staging file.
    _write_local(data, local_path, path.endswith(".safetensors"))
    logger.info(f"Saved to local path: {local_path}")

    # 3. Copy the staging file to its final destination.
    copy(local_path, path, blocking=blocking)
    logger.info(f"Copy {local_path} to HDFS or Local path: {path} done.")

    # NOTE: the staging file is intentionally left in place. Repeated saves of
    # the same basename overwrite it, so no cleanup is performed here.

def dummy_indexes_searchsorted(packed_text_indexes: torch.LongTensor, ce_loss_indexes: torch.LongTensor) -> torch.LongTensor:
    """
    Locate each entry of ``ce_loss_indexes`` inside ``packed_text_indexes``.

    Steps:
    - Sort ``packed_text_indexes`` to get the sorted values and, for each
      sorted slot, its position in the original tensor.
    - Use ``torch.searchsorted`` to find where every ``ce_loss_indexes`` entry
      falls among the sorted values.
    - Map those sorted slots back to the original positions.

    NOTE(review): the result is only meaningful when every value of
    ``ce_loss_indexes`` actually occurs in ``packed_text_indexes``; no check
    is made here -- confirm that callers guarantee this.
    """
    ordering = packed_text_indexes.sort()
    insert_at = torch.searchsorted(ordering.values, ce_loss_indexes)
    return ordering.indices[insert_at]