File size: 4,576 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# astrbot/core/utils/t2i/template_manager.py

import os
import shutil

from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path


class TemplateManager:
    """负责管理 t2i HTML 模板的 CRUD 和重置操作。
    采用“用户覆盖内置”策略:用户模板存储在 data 目录中,并优先于内置模板加载。
    所有创建、更新、删除操作仅影响用户目录,以确保更新框架时用户数据安全。
    """

    CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"]

    def __init__(self) -> None:
        self.builtin_template_dir = os.path.join(
            get_astrbot_path(),
            "astrbot",
            "core",
            "utils",
            "t2i",
            "template",
        )
        self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates")

        os.makedirs(self.user_template_dir, exist_ok=True)
        self._initialize_user_templates()

    def _copy_core_templates(self, overwrite: bool = False) -> None:
        """从内置目录复制核心模板到用户目录。"""
        for filename in self.CORE_TEMPLATES:
            src = os.path.join(self.builtin_template_dir, filename)
            dst = os.path.join(self.user_template_dir, filename)
            if os.path.exists(src) and (overwrite or not os.path.exists(dst)):
                shutil.copyfile(src, dst)

    def _initialize_user_templates(self) -> None:
        """如果用户目录下缺少核心模板,则进行复制。"""
        self._copy_core_templates(overwrite=False)

    def _get_user_template_path(self, name: str) -> str:
        """获取用户模板的完整路径,防止路径遍历漏洞。"""
        if ".." in name or "/" in name or "\\" in name:
            raise ValueError("模板名称包含非法字符。")
        return os.path.join(self.user_template_dir, f"{name}.html")

    def _read_file(self, path: str) -> str:
        """读取文件内容。"""
        with open(path, encoding="utf-8") as f:
            return f.read()

    def list_templates(self) -> list[dict]:
        """列出所有可用模板。
        该列表是内置模板和用户模板的合并视图,用户模板将覆盖同名的内置模板。
        """
        dirs_to_scan = [self.builtin_template_dir, self.user_template_dir]
        all_names = {
            os.path.splitext(f)[0]
            for d in dirs_to_scan
            for f in os.listdir(d)
            if f.endswith(".html")
        }
        return [
            {"name": name, "is_default": name == "base"} for name in sorted(all_names)
        ]

    def get_template(self, name: str) -> str:
        """获取指定模板的内容。
        优先从用户目录加载,如果不存在则回退到内置目录。
        """
        user_path = self._get_user_template_path(name)
        if os.path.exists(user_path):
            return self._read_file(user_path)

        builtin_path = os.path.join(self.builtin_template_dir, f"{name}.html")
        if os.path.exists(builtin_path):
            return self._read_file(builtin_path)

        raise FileNotFoundError("模板不存在。")

    def create_template(self, name: str, content: str) -> None:
        """在用户目录中创建一个新的模板文件。"""
        path = self._get_user_template_path(name)
        if os.path.exists(path):
            raise FileExistsError("同名模板已存在。")
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    def update_template(self, name: str, content: str) -> None:
        """更新一个模板。此操作始终写入用户目录。
        如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本,
        从而实现对内置模板的“覆盖”。
        """
        path = self._get_user_template_path(name)
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    def delete_template(self, name: str) -> None:
        """仅删除用户目录中的模板文件。
        如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。
        """
        path = self._get_user_template_path(name)
        if not os.path.exists(path):
            raise FileNotFoundError("用户模板不存在,无法删除。")
        os.remove(path)

    def reset_default_template(self) -> None:
        """将核心模板从内置目录强制重置到用户目录。"""
        self._copy_core_templates(overwrite=True)