yetrun commited on
Commit
a5fd608
·
0 Parent(s):

ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务

Browse files

模型架构:
- 手写 Mini GPT(Transformer):实现 PositionalEmbedding、TransformerDecoder 组件
- 手写 RNN:LSTM 堆叠结构
训练系统:
- Pipeline 框架:数据加载、Tokenizer、训练、生成全流程封装
- Checkpoint 机制:支持断点续训、分代模型保存和加载
- TensorBoard 训练监控
任务实现:
- Wiki GPT:基于中文维基语料的中文文本生成
- 诗歌生成器(GPT):基于 Transformer 的诗歌生成
- 诗歌生成器(RNN):基于 LSTM 的诗歌生成
工程支持:
- Gradio 交互界面
- Hugging Face Space 部署配置
- pytest 测试套件
其他特点:
- 多种采样方法:top-k、随机采样(temperature)、贪婪搜索
- Pipeline 重要支持组件:ModelBuidler、DataBundle

This view is limited to 50 files because it contains too many changes.   See raw diff
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.keras filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__
2
+ /.idea
3
+ /local
AGENTS.md ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Agent 编码规范
2
+
3
+ ## 防御性编程精简
4
+
5
+ 避免过度防御性编程,遵循以下原则:
6
+
7
+ ### 1. None 检查
8
+
9
+ - **不要**进行显式的 None 检查
10
+ - 信任输入数据,让程序在真正的错误点上失败
11
+ - 避免 `if x is not None:` 这样的防御性代码
12
+
13
+ ```python
14
+ # ❌ 避免
15
+ def process(data):
16
+ if data is not None:
17
+ return data.value
18
+ return None
19
+
20
+ # ✅ 推荐
21
+ def process(data):
22
+ return data.value
23
+ ```
24
+
25
+ ### 2. 类型检查
26
+
27
+ - **不要**使用 `isinstance`、`type()`、`typeof` 等进行运行时类型检查
28
+ - 依靠类型提示和静态类型检查工具(如 mypy)
29
+ - 让 Duck Typing 发挥作用
30
+
31
+ ```python
32
+ # ❌ 避免
33
+ def calculate(obj):
34
+ if isinstance(obj, int):
35
+ return obj * 2
36
+ elif isinstance(obj, str):
37
+ return obj * 2
38
+ else:
39
+ raise TypeError("不支持的类型")
40
+
41
+ # ✅ 推荐
42
+ def calculate(obj: int | str) -> int | str:
43
+ return obj * 2
44
+ ```
45
+
46
+ ### 3. 异常处理
47
+
48
+ - **不要**滥用 try-catch 来压制异常
49
+ - **不要**用 try-catch 让程序"容错"运行
50
+ - 只在真正需要处理异常的地方捕获
51
+ - 让未处理的异常自然抛出,暴露真正的问题
52
+
53
+ ```python
54
+ # ❌ 避免 - 压制异常
55
+ import logging
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+ def parse_config(path):
60
+ try:
61
+ with open(path) as f:
62
+ return json.load(f)
63
+ except Exception as e:
64
+ logger.error(f"加载配置失败: {e}")
65
+ return {} # 返回空配置让程序继续运行
66
+
67
+ # ✅ 推荐 - 让异常传播
68
+ def parse_config(path):
69
+ with open(path) as f:
70
+ return json.load(f)
71
+
72
+ # ✅ 或仅在必要时转换异常类型
73
+ def parse_config(path):
74
+ try:
75
+ with open(path) as f:
76
+ return json.load(f)
77
+ except json.JSONDecodeError as e:
78
+ raise ConfigError(f"配置文件格式错误: {e}") from e
79
+ ```
80
+
81
+ ### 4. 原则总结
82
+
83
+ 1. **早失败(Fail Fast)** - 让错误尽早暴露,不要试图掩盖
84
+ 2. **信任调用方** - 假设调用方会提供正确的输入
85
+ 3. **清晰错误信息** - 让异常信息直接指出问题所在
86
+ 4. **代码简洁** - 减少不必要的检查代码,专注于业务逻辑
87
+
88
+ ---
89
+
90
+ **核心信条**:清晰的代码比健壮的代码更重要。让错误暴露,让问题可见。
91
+
92
+ ## 编辑文件时的精准修改原则
93
+
94
+ 在进行代码编辑时,**只修改必要的部分**,不要进行任何无关改动:
95
+
96
+ ### 禁止的无关改动
97
+
98
+ - **不要**调整代码缩进或格式
99
+ - **不要**重排 import 语句的顺序
100
+ - **不要**添加或删除空行
101
+ - **不要**修改注释(除非任务明确要求)
102
+ - **不要**修改变量名、函数名等标识符(除非任务明确要求)
103
+ - **不要**进行任何代码重构(除非任务明确要求)
104
+
105
+ ### ✅ 正确示例
106
+
107
+ 如果任务是将 `import config` 改为 `from mini_gpt import config`:
108
+
109
+ ```python
110
+ # 修改前
111
+ import config
112
+ from typing import Callable
113
+
114
+ # 修改后 - 只修改 import 语句,其他保持不变
115
+ from mini_gpt import config
116
+ from typing import Callable
117
+ ```
118
+
119
+ ### ❌ 错误示例
120
+
121
+ #### 示例1:无关地调整 import 顺序
122
+
123
+ ```python
124
+ # 修改前
125
+ import config
126
+ from typing import Callable
127
+
128
+ # 错误 - 无关地调整了 import 顺序
129
+ from typing import Callable
130
+ from mini_gpt import config
131
+ ```
132
+
133
+ #### 示例2:无关地修改函数参数格式
134
+
135
+ ```python
136
+ # 修改前
137
+ def my_function(
138
+ param1,
139
+ param2,
140
+ param3,
141
+ ):
142
+ pass
143
+
144
+ # 错误 - 任务只要求修改函数体,却无关地修改了参数格式
145
+ def my_function(
146
+ param1, # 调整了缩进宽度
147
+ param2,
148
+ param3 # 去掉了尾部逗号
149
+ ):
150
+ pass
151
+ ```
152
+
153
+ **原则**:最小化改动范围,只改必须改的地方。
154
+
155
+ ## 运行单元测试
156
+
157
+ 本项目使用 pytest 运行单元测试,必须在 `mini-gpt` conda 环境中执行。
158
+
159
+ ### 运行命令
160
+
161
+ ```bash
162
+ /Users/run/anaconda3/envs/mini-gpt/bin/python -m pytest test/ -v
163
+ ```
164
+
165
+ ### 重要提示
166
+
167
+ 1. **必须使用 mini-gpt 环境** - 基础环境缺少 tensorflow 依赖,会导致测试收集失败
168
+ 2. **不要添加 `pytest.importorskip("tensorflow")`** - 这些测试依赖 tensorflow,跳过会掩盖真正的问题
169
+
170
+ ## Python 代码风格
171
+
172
+ ### 禁止尾逗号
173
+
174
+ **任何情况下都不应出现尾逗号**(trailing comma)。
175
+
176
+ ```python
177
+ # ❌ 避免 - 尾逗号
178
+ my_list = [
179
+ 1,
180
+ 2,
181
+ 3,
182
+ ]
183
+
184
+ # ✅ 推荐
185
+ my_list = [
186
+ 1,
187
+ 2,
188
+ 3
189
+ ]
190
+
191
+ # ❌ 避免 - 函数参数尾逗号
192
+ def my_func(
193
+ arg1,
194
+ arg2,
195
+ ):
196
+ pass
197
+
198
+ # ✅ 推荐
199
+ def my_func(
200
+ arg1,
201
+ arg2
202
+ ):
203
+ pass
204
+
205
+ # ❌ 避免 - 字典尾逗号
206
+ my_dict = {
207
+ "key1": "value1",
208
+ "key2": "value2",
209
+ }
210
+
211
+ # ✅ 推荐
212
+ my_dict = {
213
+ "key1": "value1",
214
+ "key2": "value2"
215
+ }
216
+ ```
217
+
218
+ ## 禁止命令行参数
219
+
220
+ 永远不要在代码中使用命令行参数(如 `argparse`、`sys.argv` 等)。配置应通过代码中硬编码实现。
README.md ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: General Deep Learning
3
+ emoji: 🏃
4
+ colorFrom: yellow
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 6.12.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: General Deep Learning is a practical deep learning experimen
12
+ ---
13
+
14
+ # 通用深度学习(General Deep Learning)
15
+
16
+ ## 项目简介
17
+
18
+ **通用深度学习(General Deep Learning)** 是一个面向实践的深度学习实验平台,致力于打造"训练-部署-体验"一体化的完整工作流。
19
+
20
+ ### ✨ 为什么适合你?
21
+
22
+ **🎯 我的愿景**
23
+ - 构建一个**从零开始、透明可学、工程模块化**的深度学习平台。
24
+
25
+ **🎓学习友好**
26
+ - ✅ **纯手工从零构建** - Transformer、RNN 都是一行行代码手撸
27
+ - ✅ **代码即教程** - 没有黑盒封装,每个组件清晰可见
28
+ - ✅ **完整的训练闭环** - 从数据处理到部署,全流程透明
29
+ -
30
+ **🔧 技术特性**
31
+ - ✅ **覆盖主流模型** - Transformer、RNN,未来将扩展至 CNN、Diffusion 等
32
+ - ✅ **模块化架构** - 可插拔设计,新模型/新数据集快速接入
33
+ - ✅ **生产级部署** - 一键部署到 Hugging Face,支持断点续训、TensorBoard 监控
34
+
35
+ ### 📅 关于这个项目
36
+
37
+ > *历时俩月,忙里偷闲。*
38
+
39
+ 这不是一个追求最新模型的项目,而是一个**"代码即教程"**的个人实验场。
40
+
41
+ **已完成功能**:
42
+ - Wiki GPT - 基于中文维基的手写 Transformer
43
+ - 诗歌生成器 - GPT 和 RNN 双版本对比
44
+
45
+ **未来规划**:
46
+ 4 月有事不再投入,5 月开始计划每月新增一个模型,探索更多架构(CNN、Diffusion...)
47
+
48
+ - 🔮 逐步扩展至 CV、多模态等领域
49
+ - 🔮 保持"从零手撸"的风格,让每个新模型都成为学习素材
50
+
51
+ **欢迎一起折腾** —— 反馈问题、贡献代码,或单纯聊聊技术!
52
+
53
+ ### 🤗 在线体验
54
+
55
+ [![Hugging Face Space](https://img.shields.io/badge/🤗-Hugging%20Face%20Space-blue)](https://huggingface.co/spaces/yetrun/general-deep-learning)
56
+
57
+ 🚀 **在线体验**:[点击访问 Hugging Face Space](https://huggingface.co/spaces/yetrun/general-deep-learning)
58
+
59
+ 本项目已部署到 Hugging Face Space,你可以在线体验以下功能:
60
+
61
+ - **Wiki GPT 文本生成**:基于 Transformer 架构的中文文本生成,训练数据来自中文维基语料库
62
+ - **诗歌生成器(GPT)**:基于 Transformer 的中文诗歌生成,支持五言、七言诗等
63
+ - **诗歌生成器(RNN)**:基于 RNN 架构的中文诗歌生成,支持五言、七言诗等
64
+
65
+ ## 部署说明
66
+
67
+ 本项目已配置为 Hugging Face Space 兼容格式,如需更新部署:
68
+
69
+ ```bash
70
+ # 1. 在 Hugging Face 创建新的 Space(选择 Gradio SDK)
71
+ # 2. 绑定 Space 远程仓库
72
+ git remote add huggingface https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
73
+ # 3. 确保依赖同步(生成 requirements.txt)
74
+ python3 generate_requirements.py
75
+ # 4. 提交并推送
76
+ git push huggingface master
77
+ ```
78
+
79
+ ## 本地开发
80
+
81
+ ### Conda 环境使用
82
+
83
+ 使用方法:
84
+ ```bash
85
+ # 创建环境
86
+ conda env create -f <environment.yml>
87
+ # 激活环境
88
+ conda activate general-dl
89
+ # 更新 environment.yml
90
+ conda env update -f <environment.yml> --prune
91
+ ```
92
+
93
+ 上述 `<environment.yml>` 是环境配置文件的路径,需要替换成实际的文件名:
94
+
95
+ - 如果你是本地开发,使用 `environment.yml`(Mac Intel 64 环境,`ENV=test`)
96
+ - 如果你是在远程服务器上运行,使用 `environments-linux.yml`(Linux 服务器环境,`ENV=production`)
97
+
98
+ > **插曲:**
99
+ >
100
+ > 环境配置出现了问题,强制重新安装 tensorflow-text 才修复。
101
+ >
102
+ > ```bash
103
+ > pip uninstall tensorflow-text -y
104
+ > pip install --no-cache-dir --force-reinstall tensorflow-text==2.20.0
105
+ > ```
106
+
107
+ ### 开发工具配置
108
+
109
+ #### TensorBoard 说明
110
+
111
+ 训练时,调用 `tensorboard --logdir=<logdir>` 来启动 TensorBoard,默认访问地址是 http://localhost:6006/.
112
+
113
+ `<logdir>` 通常是 `local/tasks/<project_name>/tensorboard`.
114
+
115
+ > 冷知识:tensorboard 中的代数与我们常规认为的代数不一致,第一代的计数是 0.
116
+
117
+ #### JetBrains 远程开发配置
118
+
119
+ 配置本地代码映射:
120
+
121
+ 1. 菜单栏:Tools → Deployment → Configuration
122
+ 2. 配置目录映射:切换到Mappings标签页,Deployment path 设置远程目录路径
123
+ 3. 配置排除目录,一般可排除的本地目录包括:`data/dev`, `local`, `test`.
124
+
125
+ 手工同步:
126
+
127
+ - 右键文件/目录 → Deployment → Upload to...
128
+
129
+ ## 数据集说明
130
+
131
+ ### WIKI 数据集
132
+
133
+ *(本项目中 `wiki_gpt` 任务使用了中文维基语料库进行训练)*
134
+
135
+ 下载维基百科的数据。
136
+
137
+ ```bash
138
+ wget https://dumps.wikimedia.org/other/mediawiki_content_current/zhwiki/2026-01-01/xml/bzip2/zhwiki-2026-01-01-p1p5254490.xml.bz2
139
+ wget https://dumps.wikimedia.org/other/mediawiki_content_current/zhwiki/2026-01-01/xml/bzip2/zhwiki-2026-01-01-p5254491p9382552.xml.bz2
140
+ ```
141
+
142
+ 维基百科的数据分成两个文件,可使用 cat 命令合并成一个文件:
143
+
144
+ ```bash
145
+ cat zhwiki-2026-01-01-p1p5254490.xml.bz2 zhwiki-2026-01-01-p5254491p9382552.xml.bz2 > zhwiki-2026-01-01.xml.bz2
146
+ ```
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AI 文本生成工具集 - 多页面 Gradio 应用
3
+
4
+ 入口点,提供导航到各个子应用:
5
+ - /:首页导航
6
+ - /wiki_gpt:Wiki GPT 文本生成器
7
+ - /poetry_gpt:诗歌生成器(GPT)
8
+ - /poetry_rnn:诗歌生成器(RNN)
9
+
10
+ 特点:
11
+ - 每个子页面可以独立运行测试
12
+ """
13
+ import gradio as gr
14
+
15
+ from tasks.wiki_gpt.gradio import demo as wiki_gpt_demo
16
+ from tasks.poetry_gpt.gradio import demo as poetry_gpt_demo
17
+ from tasks.poetry_rnn.gradio import demo as poetry_rnn_demo
18
+
19
+
20
+ with gr.Blocks(title="AI 文本生成工具集") as demo:
21
+ gr.Markdown("# AI 文本生成工具集")
22
+ gr.Markdown("请选择要使用的应用:")
23
+
24
+ with gr.Row():
25
+ with gr.Column():
26
+ gr.Markdown("## 诗歌生成器(GPT)")
27
+ gr.Markdown("基于 Transformer 的中文诗歌生成,支持五言、七言诗等。")
28
+ gr.Button("进入诗歌生成器", link="/poetry_gpt")
29
+
30
+ with gr.Column():
31
+ gr.Markdown("## 诗歌生成器(RNN)")
32
+ gr.Markdown("基于 RNN 的中文诗歌生成,支持五言、七言诗等。")
33
+ gr.Button("进入诗歌生成器", link="/poetry_rnn")
34
+
35
+ with gr.Column():
36
+ gr.Markdown("## Wiki GPT 文本生成")
37
+ gr.Markdown("基于 Transformer 的中文文本生成,训练来自于中文维基语料库。")
38
+ gr.Button("进入 Wiki GPT", link="/wiki_gpt")
39
+
40
+ gr.Markdown("---")
41
+ gr.Markdown("### 说明")
42
+ gr.Markdown("每个应用都是独立加载的,进入页面后需要等待模型加载完成。")
43
+
44
+
45
+ with demo.route("诗歌生成器(GPT)", "/poetry_gpt"):
46
+ poetry_gpt_demo.render()
47
+
48
+
49
+ with demo.route("诗歌生成器(RNN)", "/poetry_rnn"):
50
+ poetry_rnn_demo.render()
51
+
52
+
53
+ with demo.route("Wiki GPT", "/wiki_gpt"):
54
+ wiki_gpt_demo.render()
55
+
56
+
57
+ if __name__ == "__main__":
58
+ demo.launch()
data/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """数据集模块
2
+
3
+ 提供统一的数据集接口,包括 Wiki 和诗歌数据集。
4
+
5
+ Usage:
6
+ from data import WikiDataset, PoetryDataset
7
+
8
+ # Wiki 数据集
9
+ wiki = WikiDataset(data_dir="~/data/wiki/mini_c4")
10
+ doc_ds = wiki.doc_ds()
11
+ tokens_ds = wiki.tokens_ds(seq_length=256, batch_size=32)
12
+ wiki.stat(seq_length=256)
13
+
14
+ # 诗歌数据集
15
+ poetry = PoetryDataset(data_dir="~/data/Poetry")
16
+ doc_ds = poetry.doc_ds()
17
+ tokens_ds = poetry.tokens_ds(seq_length=100, batch_size=128)
18
+ poetry.stat(seq_length=100)
19
+ """
20
+
21
+ from data.base import DataBundle, TokenizerBundle
22
+ from data.wiki import WikiDataset
23
+ from data.poetry import PoetryDataset
24
+
25
+ __all__ = ["DataBundle", "TokenizerBundle", "WikiDataset", "PoetryDataset"]
data/base.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """数据集抽象基类模块
2
+
3
+ 定义 DataBundle 抽象基类,统一数据集的接口规范。
4
+ 每个具体的数据集(如 Wiki、诗歌)都应该继承此类并实现相应方法。
5
+ """
6
+
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass, field
9
+ from typing import Callable, Optional
10
+
11
+ import tensorflow as tf
12
+
13
+
14
+ @dataclass
15
+ class TokenizerBundle:
16
+ """分词器信息包装类
17
+
18
+ 将分词器相关的属性打包在一起,简化 DataBundle 接口。
19
+ """
20
+
21
+ tokenizer: Callable
22
+ decode: Callable
23
+ end_of_text: int
24
+ vocab_size: int
25
+ vocab_path: str = ""
26
+
27
+
28
+ @dataclass
29
+ class DataBundle(ABC):
30
+ """数据集抽象基类
31
+
32
+ 将数据加载、分词、统计等功能绑定在一起,提供统一的数据集接口。
33
+
34
+ Usage:
35
+ dataset = WikiDataset(data_dir="~/data/wiki")
36
+ doc_ds = dataset.doc_ds()
37
+ tokens_ds = dataset.tokens_ds(seq_length=256, batch_size=32)
38
+ dataset.stat()
39
+ """
40
+
41
+ data_dir: str
42
+ sequence_length: int = 256
43
+
44
+ @abstractmethod
45
+ def doc_ds(self) -> tf.data.Dataset:
46
+ """返回原始文档数据集
47
+
48
+ Returns:
49
+ TensorFlow Dataset,每个元素是一个文档字符串
50
+ """
51
+ pass
52
+
53
+ @abstractmethod
54
+ def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
55
+ """返回 tokenized 数据集
56
+
57
+ 将原始文档转换为 token ID 序列,并分割为训练样本。
58
+
59
+ Args:
60
+ seq_length: 序列长度
61
+ batch_size: 批次大小
62
+
63
+ Returns:
64
+ TensorFlow Dataset,每个元素是 (input_ids, target_ids) 对
65
+ """
66
+ pass
67
+
68
+ @abstractmethod
69
+ def tokenizer_bundle(self) -> TokenizerBundle:
70
+ """返回分词器信息"""
71
+ pass
72
+
73
+ def stat(self, seq_length: int | None = None) -> None:
74
+ """打印数据集统计信息
75
+
76
+ Args:
77
+ seq_length: 序列长度,用于估算训练样本数
78
+ """
79
+ from data.common import collect_stats
80
+
81
+ info = self.tokenizer_bundle()
82
+ stats = collect_stats(
83
+ name=self.__class__.__name__, loader=self.doc_ds, tokenizer=info.tokenizer
84
+ )
85
+ stats.print_report(seq_length=seq_length)
data/common.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """数据集共享工具模块
2
+
3
+ 提供数据集统计、报告生成等共享功能。
4
+ """
5
+
6
+ import pathlib
7
+ from dataclasses import dataclass
8
+ from typing import Callable
9
+
10
+ import numpy as np
11
+ import tensorflow as tf
12
+ from keras import layers
13
+
14
+
15
+ @dataclass
16
+ class DatasetStats:
17
+ """数据集统计结果"""
18
+
19
+ name: str
20
+ doc_count: int
21
+ total_chars: int
22
+ total_tokens: int
23
+ max_length: int
24
+ median_length: int
25
+
26
+ def print_report(self, seq_length: int | None = 256):
27
+ """打印统一格式的统计报表
28
+
29
+ Args:
30
+ seq_length: 序列长度,用于估算训练样本数。
31
+ 为 None 时表示不切割,一个文档一个样本。
32
+ """
33
+ avg_chars = self.total_chars / self.doc_count if self.doc_count > 0 else 0
34
+ avg_tokens = self.total_tokens / self.doc_count if self.doc_count > 0 else 0
35
+
36
+ print()
37
+ print("=" * 60)
38
+ print(f"{self.name} 数据集统计")
39
+ print("=" * 60)
40
+ print(f"{'文档数:':<20} {self.doc_count:>15,}")
41
+ print(f"{'总字符数:':<20} {self.total_chars:>15,}")
42
+ print(f"{'总 Token 数:':<20} {self.total_tokens:>15,}")
43
+ print("-" * 60)
44
+ print(f"{'平均每文档字符数:':<20} {avg_chars:>15.1f}")
45
+ print(f"{'平均每文档 Token 数:':<20} {avg_tokens:>15.1f}")
46
+ print(f"{'最长文档字符数:':<20} {self.max_length:>15,}")
47
+ print(f"{'文档长度中位数:':<20} {self.median_length:>15,}")
48
+ print("=" * 60)
49
+
50
+ if self.total_tokens > 0:
51
+ print()
52
+ if seq_length is None:
53
+ print(f"训练样本数: {self.doc_count:,} 个 (一个文档一个样本)")
54
+ else:
55
+ print(f"训练样本预估 (seq={seq_length}):")
56
+ print(f" 可生成约 {self.total_tokens // seq_length:,} 个训练样本")
57
+
58
+
59
+ def collect_stats(
60
+ name: str, loader: Callable[[], tf.data.Dataset], tokenizer: Callable
61
+ ) -> DatasetStats:
62
+ """从 DatasetLoader 收集统计数据
63
+
64
+ Args:
65
+ name: 数据集名称(用于报表显示)
66
+ loader: 返回 tf.data.Dataset 的加载器函数
67
+ tokenizer: 分词器函数,接收文本返回 token ID 列表
68
+
69
+ Returns:
70
+ DatasetStats 统计结果对象
71
+ """
72
+ ds = loader()
73
+
74
+ doc_count = 0
75
+ total_chars = 0
76
+ total_tokens = 0
77
+ lengths = []
78
+
79
+ for item in ds:
80
+ text = item.numpy().decode("utf-8")
81
+ if not text.strip():
82
+ continue
83
+
84
+ doc_count += 1
85
+ total_chars += len(text)
86
+ lengths.append(len(text))
87
+
88
+ # Token 统计,过滤掉末尾的 padding (值为 0 的 token)
89
+ try:
90
+ import keras
91
+
92
+ token_ids = keras.ops.convert_to_numpy(tokenizer(text))
93
+ except ImportError:
94
+ # Fallback: assume tokenizer returns numpy array directly
95
+ token_ids = np.array(tokenizer(text))
96
+
97
+ # 只去掉末尾的 0,保留中间内容(包括中间的 OOV/padding)
98
+ valid_tokens = np.trim_zeros(token_ids, "b")
99
+ total_tokens += len(valid_tokens)
100
+
101
+ return DatasetStats(
102
+ name=name,
103
+ doc_count=doc_count,
104
+ total_chars=total_chars,
105
+ total_tokens=total_tokens,
106
+ max_length=max(lengths) if lengths else 0,
107
+ median_length=int(np.median(lengths)) if lengths else 0,
108
+ )
109
+
110
+
111
+ def save_vocabulary(vocab: list[str], vocab_path: pathlib.Path) -> None:
112
+ """保存词汇表到文件
113
+
114
+ Args:
115
+ vocab: 词汇表列表
116
+ vocab_path: 保存路径
117
+ """
118
+ vocab_path.parent.mkdir(parents=True, exist_ok=True)
119
+ with open(vocab_path, "w", encoding="utf-8") as f:
120
+ for char in vocab:
121
+ written = char if char != "\n" else r"\n"
122
+ f.write(written + "\n")
123
+
124
+
125
+ def build_vocab_from_dataset(
126
+ doc_ds: tf.data.Dataset, vocab_path: pathlib.Path
127
+ ) -> list[str]:
128
+ """从文档数据集构建词汇表
129
+
130
+ Args:
131
+ doc_ds: 文档数据集
132
+ vocab_path: 词汇表保存路径
133
+
134
+ Returns:
135
+ 词汇表列表
136
+ """
137
+ vectorizer = layers.TextVectorization(
138
+ output_mode="int", split="character", standardize=None
139
+ )
140
+ vectorizer.adapt(doc_ds, batch_size=128)
141
+
142
+ vocab = vectorizer.get_vocabulary()
143
+ if "$" not in vocab:
144
+ vocab = [*vocab, "$"]
145
+
146
+ save_vocabulary(vocab, vocab_path)
147
+ return vocab
data/dev/mini_c4/file1.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ first document of first file
2
+ second document of first file
3
+ third document of first file
data/dev/mini_c4/file2.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ first document of second file
2
+ second document of second file
3
+ third document of second file
4
+ fourth document of second file
data/dev/mini_c4/file3.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ first document of third file
2
+ second document of third file
3
+ third document of third file
data/dev/poetry/元.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 标题,朝代,作者,体裁,内容
2
+ 西洱河,元,述律杰,五言排律,洱水何雄壮,源流自邓川。两关龙首尾,九曲势蜿蜒。大理城池固,金汤铁石坚。四洲从古号,三岛至今传。罗阁凭巘崄,蒙人恃极边。要当兵十万,不数客三千。世祖亲征日,初还一统天。雨师清瘴疠,风伯扫氛烟。民物因蕃富,封疆近百年。点苍山色好,铭刻尚依然。
3
+ 陟玩春山纪兴,元,忽必烈,七言律诗,时膺韶景陟兰峰,不惮跻攀谒粹容。花色映霞祥彩混,垆烟拂雾瑞光重。雨沾琼干岩边竹,风袭琴声岭际松。净刹玉毫瞻礼罢,回程仙驾驭苍龙。
4
+ 结联,元,奥鲁赤,句,久立危栏须北望,无边秋色杳冥冥。
5
+ 八月初四日雪坡太守周门拓入云居山中复度岭饮于水月尼寺赋诗书似太守及苏州刺史周义卿,元,杨维桢,七言律诗,文章太守早休牙,五马传呼处士家。好客新分朱露酒,题诗近在白云窝。山中子落千年桂,海上人归八月槎。水月楼头横玉笛,误猜萼绿是韶华。
6
+ 用顾松江韵复理贰守并柬雪坡刺史,元,杨维桢,七言律诗,仙客归来隘九州,身骑黄鹤记南游。乌衣故国江山在,铜柱荒台草木秋。起舞刘琨空有志,登高王粲不胜愁。问君蔗境今何在,祇忆当年顾虎头。
7
+ 寄小蓬莱主者闻梅涧并简沈元方宇文仲美贤主宾,元,杨维桢,七言律诗,罗浮主者是仙才,东老诸孙亦俊哉。风雨春城花落尽,江山故国燕归来。酒盟自有乌巾在,笑口应随皓齿开。十八仙人重会处,劫灰不到小蓬莱。
8
+ 次韵奉答倪元镇,元,杨维桢,七言律诗,坐断深林事不闻,西窗风日爱余曛。旧经高赤寻三传,新咏山王削五君。翠筱侵床落苍雪,石池洗砚动玄云。东邻书屋最相忆,莫遣草堂移浪文。
9
+ 送谢太守,元,杨维桢,七言律诗,朝廷遣使航东海,万里南来送玺书。著屐登山良不恶,分符典郡复何如。白苏事业千年后,吴楚封疆百战馀。今日养民方急务,肯将徵算及舟车。
10
+ 送谢太守,元,杨维桢,七言律诗,
11
+ 回上张太尉(一云"谢赐玳瑁笔见征楚国公碑文"),元,杨维桢,七言律诗,昨夜文星照南极,今朝客省过东维。锦囊颖脱千年兔,斑管光摇九尾龟。墨卷风云随王气,恩分雨露出天池。老夫来草平蛮策,先写新封楚国碑。
data/dev/poetry/先秦.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 标题,朝代,作者,体裁,内容
2
+ 禹玉牒辞,先秦,无名氏,古风,祝融司方发其英,沐日浴月百宝生。
3
+ 衣铭,先秦,无名氏,古风,桑蚕苦,女工难,得新捐故后必寒。
4
+ 书车,先秦,无名氏,古风,出畏之,入惧之。
5
+ 击壤歌,先秦,无名氏,古风,日出而作。日入而息。凿井而饮。耕田而食。帝力于我何有哉。
data/dev/poetry/南北朝.csv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ 标题,朝代,作者,体裁,内容
2
+ 悬瓠方丈竹堂飨侍臣联句诗,南北朝,元宏,古风,白日光天兮无不曜。江左一隅独未照。愿从圣明兮登衡会。万国驰诚混内外。云雷大振兮天门辟。率土来宾一正历。舜舞干戚兮天下归。文德远被莫不思。皇风一鼓兮九地匝。戴日依天清六合。遵彼汝坟兮昔化贞。未若今日道风明。文王政教兮晖江沼。宁如大化光四表。
3
+ 歌,南北朝,元宏,句,两菖蒲,新野乐。
4
+ 应制赋铜鞮山松诗,南北朝,元协,古风,问松林。松林经几冬。山川何如昔。风云与古同。
5
+ 绝命诗二首 其一 ,南北朝,元熙,古风,义实动君子,主辱死忠臣。何以明是节,将解七尺身。
6
+ 绝命诗二首 其二 ,南北朝,元熙,古风,平生方寸心,殷勤属知己。从今一销化,悲伤无极已。
data/poetry/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """诗歌数据集模块
2
+
3
+ 从以下 github 地址下载数据集到目录 ./data/Poetry:
4
+
5
+ > https://github.com/xiu-ze/Poetry.git
6
+
7
+ 数据集的格式是多文件 CSV 格式,统计结果:
8
+
9
+ > 找到 22 个 CSV 文件
10
+ >
11
+ > 诗歌总数: 1014507
12
+ > 最长字符数: 4872
13
+ > 平均字符数: 66.04
14
+ > 中位数: 48
15
+
16
+ 因此可设置序列长度为 100.
17
+ """
18
+
19
+ from data.poetry.dataset import PoetryDataset
20
+
21
+ __all__ = ["PoetryDataset"]
data/poetry/dataset.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """诗歌数据集主模块
2
+
3
+ 实现 PoetryDataset 类,继承自 DataBundle。
4
+ """
5
+
6
+ import pathlib
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional
9
+
10
+ import tensorflow as tf
11
+
12
+ from data.base import DataBundle, TokenizerBundle
13
+ from data.poetry.loader import doc_load_with_eot
14
+ from data.poetry.transformer import transform
15
+ from data.poetry.tokenizer import load_vectorizer
16
+
17
+
18
+ @dataclass
19
+ class PoetryDataset(DataBundle):
20
+ """诗歌数据集
21
+
22
+ 将文档加载、分词、统计等功能绑定在一起的数据集类。
23
+
24
+ Usage:
25
+ dataset = PoetryDataset(
26
+ data_dir="~/data/Poetry/诗歌数据集",
27
+ vocab_path="~/data/Poetry/vocabulary.txt",
28
+ sequence_length=100
29
+ )
30
+
31
+ # 获取文档数据集
32
+ doc_ds = dataset.doc_ds()
33
+
34
+ # 获取 token 数据集
35
+ tokens_ds = dataset.tokens_ds(seq_length=100, batch_size=128)
36
+
37
+ # 打印统计信息
38
+ dataset.stat(seq_length=100)
39
+ """
40
+
41
+ vocab_path: str = ""
42
+
43
+ _data_path: pathlib.Path = field(init=False, repr=False)
44
+ _vocab_path: pathlib.Path = field(init=False, repr=False)
45
+ _tokenizer_info: Optional[TokenizerBundle] = field(
46
+ init=False, repr=False, default=None
47
+ )
48
+
49
+ def __post_init__(self):
50
+ self._data_path = pathlib.Path(self.data_dir).expanduser()
51
+ self._vocab_path = pathlib.Path(self.vocab_path).expanduser()
52
+
53
+ def _load_tokenizer(self):
54
+ """懒加载分词器"""
55
+ if self._tokenizer_info is None:
56
+ tokenizer = load_vectorizer(self._vocab_path, self.sequence_length + 1)
57
+ vocab = tokenizer.get_vocabulary()
58
+ end_of_text = vocab.index("$")
59
+ vocab_size = len(vocab)
60
+
61
+ def decode(token_ids: list[int]) -> str:
62
+ chars = [
63
+ vocab[token_id] for token_id in token_ids if token_id < len(vocab)
64
+ ]
65
+ return "".join(chars)
66
+
67
+ self._tokenizer_info = TokenizerBundle(
68
+ tokenizer=tokenizer,
69
+ decode=decode,
70
+ end_of_text=end_of_text,
71
+ vocab_size=vocab_size,
72
+ vocab_path=str(self._vocab_path)
73
+ )
74
+
75
+ def doc_ds(self) -> tf.data.Dataset:
76
+ """返回原始文档数据集
77
+
78
+ Returns:
79
+ TensorFlow Dataset,每个元素是带结束标记的诗歌内容
80
+ """
81
+ return doc_load_with_eot(self._data_path)
82
+
83
+ def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
84
+ """返回 tokenized 数据集
85
+
86
+ Args:
87
+ seq_length: 序列长度(诗歌中此参数主要用于兼容性)
88
+ batch_size: 批次大小
89
+
90
+ Returns:
91
+ TensorFlow Dataset,每个元素是 (input_ids, target_ids) 对
92
+ """
93
+ self._load_tokenizer()
94
+ ds = self.doc_ds()
95
+ return transform(
96
+ ds=ds,
97
+ tokenizer=self._tokenizer_info.tokenizer,
98
+ batch_size=batch_size,
99
+ )
100
+
101
+ def tokenizer_bundle(self) -> TokenizerBundle:
102
+ """返回分词器信息"""
103
+ self._load_tokenizer()
104
+ return self._tokenizer_info
data/poetry/loader.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """诗歌数据集文档加载模块
2
+
3
+ 从 CSV 文件加载诗歌文本数据。
4
+ """
5
+
6
+ import glob
7
+ import os
8
+ import pathlib
9
+
10
+ import tensorflow as tf
11
+
12
+
13
+ def _parse_csv_line(line: tf.Tensor) -> tf.Tensor:
14
+ """解析 CSV 行,返回内容列"""
15
+ fields = tf.io.decode_csv(
16
+ line,
17
+ use_quote_delim=False, # 行内的引号是普通字符
18
+ record_defaults=["", "", "", "", ""],
19
+ )
20
+ return fields[4] # 返回 '内容' 列的值
21
+
22
+
23
+ def doc_load(data_dir: pathlib.Path) -> tf.data.Dataset:
24
+ """加载诗歌数据集
25
+
26
+ 从指定目录下的 CSV 文件中加载诗歌文本数据。
27
+ 每个 CSV 文件应该包含以下列:标题、作者、朝代、类型、内容。
28
+
29
+ Args:
30
+ data_dir: 数据目录路径
31
+
32
+ Returns:
33
+ TensorFlow Dataset,每个元素是诗歌内容字符串
34
+ """
35
+ csv_files = glob.glob(os.path.join(data_dir, "*.csv"))
36
+ if not csv_files:
37
+ raise ValueError(f"在目录 {data_dir} 中未找到任何 CSV 文件!")
38
+
39
+ files_ds = tf.data.Dataset.from_tensor_slices(csv_files)
40
+ csv_line_ds = files_ds.interleave(
41
+ lambda csv_file: tf.data.TextLineDataset(csv_file).skip(1),
42
+ cycle_length=1,
43
+ )
44
+ return csv_line_ds.map(_parse_csv_line, num_parallel_calls=tf.data.AUTOTUNE).filter(
45
+ lambda x: tf.strings.length(x) > 0
46
+ )
47
+
48
+
49
+ def doc_load_with_eot(data_dir: pathlib.Path) -> tf.data.Dataset:
50
+ """加载诗歌数据集,每行末尾添加结束标记
51
+
52
+ Args:
53
+ data_dir: 数据目录路径
54
+
55
+ Returns:
56
+ TensorFlow Dataset,每个元素是带结束标记的诗歌内容
57
+ """
58
+ ds = doc_load(data_dir)
59
+ return ds.map(lambda x: tf.strings.join([x, "$"]))
data/poetry/runner.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """诗歌数据集 Runner
2
+
3
+ Usage:
4
+ python data/poetry/runner.py build_vocab
5
+ python data/poetry/runner.py test_dataset
6
+ ENV=production python data/poetry/runner.py build_vocab
7
+ """
8
+
9
+ import pathlib
10
+ import sys
11
+
12
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
13
+
14
+ from data.poetry.dataset import PoetryDataset
15
+ from data.runner import DatasetRunner
16
+ from env.resolve import resolve_path, resolve_env, resolve_saved
17
+
18
+
19
+ dataset = PoetryDataset(
20
+ data_dir=str(
21
+ resolve_env(resolve_path("data/dev/poetry"), resolve_path("~/data/Poetry/诗歌数据集"))
22
+ ),
23
+ vocab_path=str(
24
+ resolve_env(
25
+ resolve_saved("vocab/poetry/vocab.txt"),
26
+ resolve_path("~/data/Poetry/vocabulary.txt"),
27
+ )
28
+ ),
29
+ sequence_length=100,
30
+ )
31
+
32
+ runner = DatasetRunner(
33
+ dataset=dataset,
34
+ name="poetry",
35
+ )
36
+
37
+ if __name__ == "__main__":
38
+ runner()
data/poetry/tokenizer.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """诗歌数据集分词器模块
2
+
3
+ 提供诗歌数据集专用的分词器实现。
4
+ """
5
+
6
+ import pathlib
7
+
8
+ from keras import layers
9
+
10
+
11
+ def load_vocabulary(vocab_path: pathlib.Path):
12
+ """从文本文件加载词汇表,每行一个字符。
13
+
14
+ Args:
15
+ vocab_path: 词汇表文件路径
16
+
17
+ Returns:
18
+ 词汇表列表
19
+ """
20
+
21
+ def extract_word(line: str) -> str:
22
+ word = line[:-1] # 去掉行末的换行符
23
+ return word if word != r"\n" else "\n"
24
+
25
+ with open(vocab_path, "r", encoding="utf-8") as f:
26
+ vocab = [extract_word(line) for line in f]
27
+ return vocab
28
+
29
+
30
+ def load_vectorizer(
31
+ vocab_path: pathlib.Path, sequence_length: int = 101
32
+ ) -> layers.TextVectorization:
33
+ """从词汇表文件加载分词器
34
+
35
+ Args:
36
+ vocab_path: 词汇表文件路径
37
+ sequence_length: 输出序列长度,默认为 101
38
+ (多一位是为了在训练时构建输入和目标偏移一位)
39
+
40
+ Returns:
41
+ TextVectorization 层
42
+ """
43
+ vectorizer = layers.TextVectorization(
44
+ output_mode="int",
45
+ split="character",
46
+ output_sequence_length=sequence_length,
47
+ standardize=None,
48
+ )
49
+
50
+ vocab = load_vocabulary(vocab_path)
51
+ vectorizer.set_vocabulary(vocab)
52
+
53
+ return vectorizer
54
+
55
+
56
+ def create_vectorizer(sequence_length: int = 101) -> layers.TextVectorization:
57
+ """创建新的分词器(用于训练词汇表)
58
+
59
+ Args:
60
+ sequence_length: 输出序列长度,默认为 101
61
+
62
+ Returns:
63
+ TextVectorization 层
64
+ """
65
+ return layers.TextVectorization(
66
+ output_mode="int", split="character", standardize=None
67
+ )
data/poetry/transformer.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """诗歌数据集 token 转换模块
2
+
3
+ 将诗歌文档数据集转换为训练用的 token 序列。
4
+ """
5
+
6
+ from typing import Callable
7
+
8
+ import tensorflow as tf
9
+
10
+
11
+ def transform(
12
+ ds: tf.data.Dataset,
13
+ tokenizer: Callable,
14
+ batch_size: int,
15
+ ) -> tf.data.Dataset:
16
+ """转换诗歌数据集为训练数据集
17
+
18
+ 诗歌数据集已经生成了固定数量的 token 序列,不足的部分会 padding。
19
+
20
+ Args:
21
+ ds: 文档数据集
22
+ tokenizer: 分词器函数
23
+ batch_size: 批次大小
24
+
25
+ Returns:
26
+ 训练数据集,每个元素是 (input_ids, target_ids) 对
27
+ """
28
+ # 文本向量化;对于诗歌数据集来说,已经生成了固定数量的 token 序列了,不足的部分会 padding
29
+ ds = ds.map(tokenizer, num_parallel_calls=8)
30
+
31
+ # 构建输入和目标(偏移一位)
32
+ # 无需在这里添加结束标记,因为在 doc_load 中已经添加了结束标记
33
+ ds = ds.map(lambda x: (x[:-1], x[1:]))
34
+
35
+ # 重新设置批次大小并预取数据以提高性能
36
+ ds = ds.batch(batch_size).prefetch(8)
37
+
38
+ return ds
data/runner.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """数据集 Runner 公共模块
2
+
3
+ 提供通用的数据集测试和词汇表生成功能。
4
+
5
+ Usage:
6
+ # 在各自 runner.py 中实例化
7
+ from data.runner import DatasetRunner
8
+ from data.poetry.dataset import PoetryDataset
9
+ from env.resolve import resolve, resolve_saved, resolve_env
10
+
11
+ dataset = PoetryDataset(
12
+ data_dir=str(resolve_env(resolve("data/dev/poetry"), resolve("~/data/Poetry/诗歌数据集"))),
13
+ vocab_path=str(resolve_env(resolve_saved("poetry/vocab.txt"), resolve("~/data/Poetry/vocabulary.txt"))),
14
+ sequence_length=100,
15
+ )
16
+ runner = DatasetRunner(dataset=dataset, name="poetry")
17
+ runner()
18
+ """
19
+
20
+ from data.base import DataBundle
21
+ from data.common import build_vocab_from_dataset
22
+ from env.resolve import resolve_saved
23
+ from env.runner import ActionRunner
24
+
25
+
26
+ class DatasetRunner(ActionRunner):
27
+ """数据集 Runner
28
+
29
+ 提供通用的数据集测试和词汇表生成功能。
30
+
31
+ Args:
32
+ dataset: 数据集实例(PoetryDataset 或 WikiDataset)
33
+ name: 数据集英文名称(如 "poetry", "wiki")
34
+ max_docs: 测试时显示的文档数量,默认 5
35
+ max_samples: 测试时显示的 token 样本数量,默认 3
36
+ max_doc_chars: 文档显示的最大字符数,默认 200
37
+ max_text_display: token 文本显示的最大字符数,默认 80
38
+
39
+ Usage:
40
+ runner = DatasetRunner(dataset=poetry_dataset, name="poetry")
41
+ runner.test_dataset() # 或 runner.build_vocab()
42
+ """
43
+
44
+ # 中英文名称映射
45
+ NAME_MAP = {
46
+ "poetry": "诗歌",
47
+ "wiki": "Wiki",
48
+ }
49
+
50
+ def __init__(
51
+ self,
52
+ dataset: DataBundle,
53
+ name: str,
54
+ max_docs: int = 5,
55
+ max_samples: int = 3,
56
+ max_doc_chars: int = 200,
57
+ max_text_display: int = 80,
58
+ ):
59
+ self.dataset = dataset
60
+ self.name = name
61
+ self.display_name = self.NAME_MAP.get(name, name)
62
+ self.vocab_path = resolve_saved(f"vocab/{name}/vocab.txt")
63
+ self.max_docs = max_docs
64
+ self.max_samples = max_samples
65
+ self.max_doc_chars = max_doc_chars
66
+ self.max_text_display = max_text_display
67
+
68
+ def build_vocab(self) -> None:
69
+ """生成字符词汇表"""
70
+ print(f"正在加载数据集...")
71
+ ds = self.dataset.doc_ds()
72
+
73
+ print(f"正在保存词汇表到: {self.vocab_path}")
74
+ vocab = build_vocab_from_dataset(ds, self.vocab_path)
75
+
76
+ print(f"词汇表大小: {len(vocab)}")
77
+ print("完成!")
78
+
79
+ def test_dataset(self) -> None:
80
+ """测试数据集"""
81
+ print("\n" + "=" * 60)
82
+ print(f"{self.display_name} 数据集测试")
83
+ print("=" * 60)
84
+
85
+ self._view_documents(self.dataset.doc_ds())
86
+ self._view_tokens(self.dataset)
87
+ self._show_vocab_info(self.dataset.tokenizer_bundle())
88
+
89
+ print("\n" + "=" * 60)
90
+ print("测试完成")
91
+ print("=" * 60)
92
+
93
+ def _view_documents(self, doc_ds) -> None:
94
+ """查看原始文档"""
95
+ print("\n【原始文档查看】")
96
+ print("-" * 60)
97
+ count = 0
98
+ for doc in doc_ds.take(self.max_docs):
99
+ count += 1
100
+ text = doc.numpy().decode("utf-8")
101
+ if len(text) > self.max_doc_chars:
102
+ text = text[: self.max_doc_chars] + "..."
103
+ print(f"\n第 {count} 个文档:")
104
+ print(f" {text}")
105
+ print(f"\n共显示 {count} 个文档")
106
+
107
+ def _view_tokens(self, dataset) -> None:
108
+ """查看 tokenized 数据"""
109
+ print("\n【Tokenized 数据查看】")
110
+ print("-" * 60)
111
+
112
+ tokenizer_info = dataset.tokenizer_bundle()
113
+ tokens_ds = dataset.tokens_ds(seq_length=dataset.sequence_length, batch_size=1)
114
+
115
+ count = 0
116
+ for batch_input, batch_target in tokens_ds.take(self.max_samples):
117
+ count += 1
118
+ input_ids = batch_input[0].numpy()
119
+ target_ids = batch_target[0].numpy()
120
+
121
+ input_text = tokenizer_info.decode(input_ids.tolist())
122
+ target_text = tokenizer_info.decode(target_ids.tolist())
123
+
124
+ if len(input_text) > self.max_text_display:
125
+ input_text = input_text[: self.max_text_display] + "..."
126
+ if len(target_text) > self.max_text_display:
127
+ target_text = target_text[: self.max_text_display] + "..."
128
+
129
+ print(f"\n第 {count} 个样本:")
130
+ print(f" 输入 tokens: {input_ids[:20]}... (长度: {len(input_ids)})")
131
+ print(f" 目标 tokens: {target_ids[:20]}... (长度: {len(target_ids)})")
132
+ print(f" 输入文本: {input_text}")
133
+ print(f" 目标文本: {target_text}")
134
+ print(f"\n共显示 {count} 个样本")
135
+
136
+ @staticmethod
137
+ def _show_vocab_info(tokenizer_info) -> None:
138
+ """显示词汇表信息"""
139
+ print("\n【词汇表信息】")
140
+ print("-" * 60)
141
+ print(f" 词汇表大小: {tokenizer_info.vocab_size}")
142
+ print(f" 结束标记 ID: {tokenizer_info.end_of_text}")
data/tokenizers.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPT模型的共享组件模块:
3
+
4
+ - 分词器
5
+ """
6
+
7
+ import keras
8
+ import keras_hub
9
+ from keras import layers
10
+
11
+
12
+ def sentence_piece():
13
+ # 用预训练好的分词器,也就是说我们不去自己训练分词器了
14
+ vocabulary_file = keras.utils.get_file(
15
+ origin="https://hf-mirror.com/mattdangerw/spiece/resolve/main/vocabulary.proto"
16
+ )
17
+ # [Note] 依然需要 tensorflow_text 包
18
+ tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(vocabulary_file)
19
+
20
+ end_of_text = tokenizer.token_to_id("<|endoftext|>")
21
+
22
+ def decode(tokens: list[int]) -> str:
23
+ return tokenizer.detokenize(tokens)
24
+
25
+ return tokenizer, end_of_text, decode
26
+
27
+
28
+ def character_vectorization():
29
+ """简单的字符级分词器,适用于测试"""
30
+ vectorizer = layers.TextVectorization(output_mode="int", split="character")
31
+ vectorizer.set_vocabulary(
32
+ list("abcdefghijklmnopqrstuvwxyz0123456789 .,!?;:()[]{}<>-_\n")
33
+ + ["<|endoftext|>"] # 兼容 sentence_piece 分词器的特殊标记
34
+ )
35
+
36
+ vocab = vectorizer.get_vocabulary()
37
+ for idx, word in enumerate(vocab):
38
+ if word == "<|endoftext|>":
39
+ end_of_text = idx
40
+ break
41
+ else:
42
+ raise ValueError("Vocabulary does not contain <|endoftext|> token.")
43
+
44
+ def decode(tokens: list[int]) -> str:
45
+ words = [vocab[token] for token in tokens]
46
+ return "".join(words)
47
+
48
+ return vectorizer, end_of_text, decode
49
+
50
+
51
+ def poetry_character_vectorization(
52
+ vocab_path: str = "local/saved/vocab/poetry/vocab.txt",
53
+ ):
54
+ """从文本文件加载诗歌字符级分词器。
55
+
56
+ 词汇表文件格式:每行一个字符,第一行必须是 <|endoftext|>。
57
+
58
+ Args:
59
+ vocab_path: 词汇表文件路径,默认为 "local/saved/poetry/vocab.txt"
60
+
61
+ Returns:
62
+ (vectorizer, end_of_text, decode): 分词器、结束标记ID、解码函数
63
+ """
64
+ from env.resolve import resolve_path
65
+
66
+ # 读取词汇表
67
+ vocab_file = resolve_path(vocab_path)
68
+ with open(vocab_file, "r", encoding="utf-8") as f:
69
+ vocab = [line.rstrip("\n") for line in f]
70
+
71
+ # 创建 TextVectorization 层
72
+ vectorizer = layers.TextVectorization(
73
+ output_mode="int", split="character", standardize=None
74
+ )
75
+ vectorizer.set_vocabulary(vocab)
76
+
77
+ # 找到 end_of_text 的索引
78
+ for idx, word in enumerate(vocab):
79
+ if word == "<|endoftext|>":
80
+ end_of_text = idx
81
+ break
82
+ else:
83
+ raise ValueError("Vocabulary does not contain <|endoftext|> token.")
84
+
85
+ def decode(tokens: list[int]) -> str:
86
+ words = [vocab[token] for token in tokens]
87
+ return "".join(words)
88
+
89
+ return vectorizer, end_of_text, decode
data/wiki/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Wiki 数据集模块
2
+
3
+ 导出 WikiDataset 类。
4
+ """
5
+
6
+ from data.wiki.dataset import WikiDataset
7
+
8
+ __all__ = ["WikiDataset"]
data/wiki/dataset.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wiki 数据集主模块
2
+
3
+ 实现 WikiDataset 类,继承自 DataBundle。
4
+ """
5
+
6
+ import pathlib
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional
9
+
10
+ import tensorflow as tf
11
+
12
+ from data.base import DataBundle, TokenizerBundle
13
+ from data.wiki.loader import doc_load
14
+ from data.wiki.transformer import transform
15
+ from data.wiki.tokenizer import sentence_piece, character_vectorization
16
+
17
+
18
+ @dataclass
19
+ class WikiDataset(DataBundle):
20
+ """Wiki 数据集
21
+
22
+ 将文档加载、分词、统计等功能绑定在一起的数据集类。
23
+
24
+ Usage:
25
+ dataset = WikiDataset(
26
+ data_dir="~/data/wiki/mini_c4",
27
+ tokenizer_type="sentence_piece" # 或 "character"
28
+ )
29
+
30
+ # 获取文档数据集
31
+ doc_ds = dataset.doc_ds()
32
+
33
+ # 获取 token 数据集
34
+ tokens_ds = dataset.tokens_ds(seq_length=256, batch_size=32)
35
+
36
+ # 打印统计信息
37
+ dataset.stat(seq_length=256)
38
+ """
39
+
40
+ glob_pattern: str = "*"
41
+ tokenizer_type: str = "sentence_piece"
42
+
43
+ _data_path: pathlib.Path = field(init=False, repr=False)
44
+ _tokenizer_bundle: Optional[TokenizerBundle] = field(
45
+ init=False, repr=False, default=None
46
+ )
47
+
48
+ def __post_init__(self):
49
+ self._data_path = pathlib.Path(self.data_dir).expanduser()
50
+
51
+ def _load_tokenizer(self):
52
+ """懒加载分词器"""
53
+ if self._tokenizer_bundle is None:
54
+ if self.tokenizer_type == "sentence_piece":
55
+ tokenizer, end_of_text, decode = sentence_piece()
56
+ elif self.tokenizer_type == "character":
57
+ tokenizer, end_of_text, decode = character_vectorization()
58
+ else:
59
+ raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
60
+
61
+ vocab_size = tokenizer.vocabulary_size()
62
+ self._tokenizer_bundle = TokenizerBundle(
63
+ tokenizer=tokenizer,
64
+ decode=decode,
65
+ end_of_text=end_of_text,
66
+ vocab_size=vocab_size
67
+ )
68
+
69
+ def doc_ds(self) -> tf.data.Dataset:
70
+ """返回原始文档数据集
71
+
72
+ Returns:
73
+ TensorFlow Dataset,每个元素是一个文档字符串
74
+ """
75
+ return doc_load(self._data_path, glob_pattern=self.glob_pattern)
76
+
77
+ def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
78
+ """返回 tokenized 数据集
79
+
80
+ Args:
81
+ seq_length: 序列长度
82
+ batch_size: 批次大小
83
+
84
+ Returns:
85
+ TensorFlow Dataset,每个元素是 (input_ids, target_ids) 对
86
+ """
87
+ self._load_tokenizer()
88
+ ds = self.doc_ds()
89
+ return transform(
90
+ ds=ds,
91
+ tokenizer=self._tokenizer_bundle.tokenizer,
92
+ end_of_text=self._tokenizer_bundle.end_of_text,
93
+ sequence_length=seq_length,
94
+ batch_size=batch_size,
95
+ )
96
+
97
+ def tokenizer_bundle(self) -> TokenizerBundle:
98
+ """返回分词器信息"""
99
+ self._load_tokenizer()
100
+ return self._tokenizer_bundle
data/wiki/loader.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wiki 数据集文档加载模块
2
+
3
+ 从 mini_c4 格式加载文档数据集。
4
+ """
5
+
6
+ import pathlib
7
+
8
+ import tensorflow as tf
9
+
10
+
11
+ def doc_load(
12
+ data_dir: pathlib.Path, glob_pattern: str = "*", cycle_length: int = 32
13
+ ) -> tf.data.Dataset:
14
+ """加载并处理文档数据集为 TensorFlow Dataset。
15
+
16
+ 递归查找指定目录下匹配 glob_pattern 的所有文件,使用 doc_extract 函数
17
+ 将每个文件转换为 TensorFlow Dataset,然后使用 interleave 进行并行处理。
18
+
19
+ 目录下的文件格式要求每行一个文档,其中的换行符使用 "\\n" 转义。
20
+
21
+ Args:
22
+ data_dir: 数据目录路径
23
+ glob_pattern: 文件匹配模式,如 "*.txt",默认为 "*" 匹配所有文件
24
+ cycle_length: interleave 的 cycle_length 参数,控制并行处理的文件数量,默认为 32
25
+
26
+ Returns:
27
+ 合并后的 TensorFlow Dataset,包含所有文件处理后的数据
28
+ """
29
+ # 获取所有文件(过滤掉目录),递归查找子目录
30
+ files = [str(file) for file in data_dir.rglob(glob_pattern) if file.is_file()]
31
+ if not files:
32
+ raise FileNotFoundError(f"在目录 {data_dir} 中未找到匹配 {glob_pattern} 的文件")
33
+
34
+ # 排序文件列表以确保一致的处理顺序
35
+ files = sorted(files)
36
+
37
+ # 创建数据集管道
38
+ ds = tf.data.Dataset.from_tensor_slices(files)
39
+ ds = ds.interleave(
40
+ _line_doc_extract,
41
+ cycle_length=cycle_length,
42
+ num_parallel_calls=tf.data.AUTOTUNE,
43
+ )
44
+
45
+ return ds
46
+
47
+
48
+ def _line_doc_extract(path: str) -> tf.data.Dataset:
49
+ """Mini-c4 format: one document per line."""
50
+ return tf.data.TextLineDataset(path).map(
51
+ lambda x: tf.strings.regex_replace(x, r"\\n", "\n")
52
+ )
data/wiki/runner.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wiki 数据集 Runner
2
+
3
+ Usage:
4
+ python data/wiki/runner.py test_dataset
5
+ ENV=production python data/wiki/runner.py test_dataset
6
+ """
7
+
8
+ import pathlib
9
+ import sys
10
+
11
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
12
+
13
+ from data.runner import DatasetRunner
14
+ from data.wiki.dataset import WikiDataset
15
+ from env.resolve import resolve_path, resolve_env
16
+
17
+
18
+ dataset = WikiDataset(
19
+ data_dir=str(
20
+ resolve_env(resolve_path("data/dev/mini_c4"), resolve_path("~/data/wiki/mini_c4"))
21
+ ),
22
+ tokenizer_type=resolve_env("character", "sentence_piece"),
23
+ sequence_length=256,
24
+ )
25
+
26
+ runner = DatasetRunner(
27
+ dataset=dataset,
28
+ name="wiki",
29
+ )
30
+
31
+ if __name__ == "__main__":
32
+ runner()
data/wiki/tokenizer.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wiki 数据集分词器模块
2
+
3
+ 提供 Wiki 数据集专用的分词器实现。
4
+ """
5
+
6
+ import keras
7
+ import keras_hub
8
+ from keras import layers
9
+
10
+
11
+ def sentence_piece():
12
+ """SentencePiece 分词器
13
+
14
+ 使用预训练好的分词器,无需自己训练。
15
+
16
+ Returns:
17
+ (tokenizer, end_of_text, decode): 分词器、结束标记ID、解码函数
18
+ """
19
+ # 用预训练好的分词器,也就是说我们不去自己训练分词器了
20
+ vocabulary_file = keras.utils.get_file(
21
+ origin="https://hf-mirror.com/mattdangerw/spiece/resolve/main/vocabulary.proto"
22
+ )
23
+ # [Note] 依然需要 tensorflow_text 包
24
+ tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(vocabulary_file)
25
+
26
+ end_of_text = tokenizer.token_to_id("<|endoftext|>")
27
+
28
+ def decode(tokens: list[int]) -> str:
29
+ return tokenizer.detokenize(tokens)
30
+
31
+ return tokenizer, end_of_text, decode
32
+
33
+
34
+ def character_vectorization():
35
+ """字符级分词器
36
+
37
+ 简单的字符级分词器,适用于测试。
38
+
39
+ Returns:
40
+ (tokenizer, end_of_text, decode): 分词器、结束标记ID、解码函数
41
+ """
42
+ vectorizer = layers.TextVectorization(output_mode="int", split="character")
43
+ vectorizer.set_vocabulary(
44
+ list("abcdefghijklmnopqrstuvwxyz0123456789 .,!?;:()[]{}\u003c\u003e-_\n")
45
+ + ["<|endoftext|>"] # 兼容 sentence_piece 分词器的特殊标记
46
+ )
47
+
48
+ vocab = vectorizer.get_vocabulary()
49
+ for idx, word in enumerate(vocab):
50
+ if word == "<|endoftext|>":
51
+ end_of_text = idx
52
+ break
53
+ else:
54
+ raise ValueError("Vocabulary does not contain <|endoftext|> token.")
55
+
56
+ def decode(tokens: list[int]) -> str:
57
+ words = [vocab[token] for token in tokens]
58
+ return "".join(words)
59
+
60
+ return vectorizer, end_of_text, decode
data/wiki/transformer.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wiki 数据集 token 转换模块
2
+
3
+ 将文档数据集转换为训练用的 token 序列。
4
+ """
5
+
6
+ from typing import Callable
7
+
8
+ import numpy as np
9
+ import tensorflow as tf
10
+
11
+
12
+ def transform(
13
+ ds: tf.data.Dataset,
14
+ tokenizer: Callable,
15
+ end_of_text: int,
16
+ sequence_length: int,
17
+ batch_size: int,
18
+ ) -> tf.data.Dataset:
19
+ """转换文档数据集为训练数据集
20
+
21
+ 将文档转换为 token ID,添加结束标记,分割为固定长度的序列。
22
+
23
+ Args:
24
+ ds: 文档数据集
25
+ tokenizer: 分词器函数
26
+ end_of_text: 结束标记的 token ID
27
+ sequence_length: 序列长度
28
+ batch_size: 批次大小
29
+
30
+ Returns:
31
+ 训练数据集,每个元素是 (input_ids, target_ids) 对
32
+ """
33
+ ds = ds.map(tokenizer, num_parallel_calls=8)
34
+
35
+ # 将文档之间添加 end_of_text 标记分隔
36
+ ds = ds.map(lambda x: tf.concat([x, np.array([end_of_text])], -1))
37
+
38
+ # 重新设置样本大小为固定长度序列
39
+ ds = ds.rebatch(sequence_length + 1, drop_remainder=True)
40
+
41
+ # 构建输入和目标(偏移一位)
42
+ ds = ds.map(lambda x: (x[:-1], x[1:]))
43
+
44
+ # 重新设置批次大小并预取数据以提高性能
45
+ ds = ds.batch(batch_size).prefetch(8)
46
+
47
+ return ds
data/wiki/wiki_cleaner.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wiki 文本清洗模块。
3
+
4
+ 提供多种过滤器用于清洗 wiki 格式的文本数据。
5
+ """
6
+
7
+ import re
8
+
9
+
10
+ def filter_single_line(text: str) -> str | None:
11
+ """
12
+ 过滤只有一行的数据(通常是重定向页面)。
13
+
14
+ Args:
15
+ text: 输入文本
16
+
17
+ Returns:
18
+ 如果只有一行返回 None,否则返回原文本
19
+ """
20
+ lines = [line for line in text.split("\n") if line.strip()]
21
+ if len(lines) <= 1:
22
+ return None
23
+ return text
24
+
25
+
26
+ def filter_empty_brackets(text: str) -> str:
27
+ """
28
+ 移除文本中的空括号对。
29
+
30
+ 例如:()、()、( )、( )、[ ]、【 】、{ } 等
31
+
32
+ Args:
33
+ text: 输入文本
34
+
35
+ Returns:
36
+ 移除空括号后的文本
37
+ """
38
+ # 匹配空括号对:() () [] 【】 {} 等,中间可有空白
39
+ pattern = re.compile(r"[\(\)()\[\]【】{}]\s*[\(\)()\[\]【】{}]")
40
+ return pattern.sub("", text)
41
+
42
+
43
+ def filter_html_tags(text: str) -> str:
44
+ """
45
+ 移除 HTML/XML 标签(HTML 实体编码格式)。
46
+
47
+ 例如:&lt;templatestyles src="ShareCSS/infobox.css" /&gt;
48
+
49
+ Args:
50
+ text: 输入文本
51
+
52
+ Returns:
53
+ 移除 HTML 标签后的文本
54
+ """
55
+ # 匹配 &lt;...&gt; 格式的实体编码标签
56
+ pattern = re.compile(r"&lt;[^&]+&gt;")
57
+ return pattern.sub("", text)
58
+
59
+
60
+ def filter_lang_tags(text: str) -> str:
61
+ """
62
+ 移除特殊的语言标记(支持嵌套)。
63
+
64
+ 例如:-{H|zh-hans:重定向;zh-hant:重新导向;}-
65
+ 嵌套例如:-{T|zh:-{zh|}-;zh-hans:-{zh-hans|}-;}-
66
+
67
+ Args:
68
+ text: 输入文本
69
+
70
+ Returns:
71
+ 移除语言转换标记后的文本
72
+ """
73
+ # 使用非贪婪匹配,循环处理嵌套
74
+ pattern = re.compile(r"-\{[^{}]+?}-")
75
+ while True:
76
+ new_text = pattern.sub("", text)
77
+ if new_text == text: # 没有更多匹配了
78
+ break
79
+ text = new_text
80
+ return text
81
+
82
+
83
+ def clean(text: str) -> str | None:
84
+ """
85
+ 应用所有过滤器清洗文本。
86
+
87
+ 过滤顺序:
88
+ 1. 单行检查(重定向页面)
89
+ 2. HTML 标签
90
+ 3. 空白括号行
91
+ 4. 语言转换标记
92
+ 5. 最终空检查
93
+
94
+ Args:
95
+ text: 输入文本
96
+
97
+ Returns:
98
+ 清洗后的文本,如果应该丢弃则返回 None
99
+ """
100
+ # 1. 检查单行
101
+ result = filter_single_line(text)
102
+ if result is None:
103
+ return None
104
+
105
+ # 2. 移除 HTML 标签
106
+ result = filter_html_tags(result)
107
+
108
+ # 3. 移除空白括号行
109
+ result = filter_empty_brackets(result)
110
+
111
+ # 4. 移除语言转换标记
112
+ result = filter_lang_tags(result)
113
+
114
+ # 5. 多个连续空行替换为一个空行
115
+ result = re.sub(r"\n\s*\n", "\n\n", result)
116
+ result = result.strip()
117
+
118
+ # 6. 最终检查:如果结果为空或只剩空白,返回 None
119
+ if not result.strip():
120
+ return None
121
+
122
+ return result
docs/TODOs.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ - [ ] `<doc>` 格式由于计算图的限制还无法实现,未来打算实现。
2
+ - [ ] 希望能通过 Callback 或 train_step 截取到训练过程中的数据。
3
+ - [ ] wiki 训练后不能回答事实性问题,感觉是过拟合了,将 dropout 调成 0.5 试一试(当前 0.1)。
docs/pycharm.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # PyCharm 开发指南
2
+
3
+ 最近,我在项目里尝试将 PyCharm 的代码连接到远程服务器运行,遇到了一些莫名的问题。现在将一些解决方案记录下来,供以后参考。
4
+
5
+ 总的来说,我直接应用远程环境就会出错,有各种各样的问题。我需要重新构建一个全新的环境才使得它正常运作。记录如下:
6
+
7
+ 1. 在远程服务器上创建一个新的 conda 环境。
8
+ 2. 创建一个新的 Python 项目(我直接移动了我的项目目录,并删除目录下的 .idea, .ruff_cache, .pytest_cache 等文件夹)。
9
+ 3. 在本地 PyCharm 中配置远程 Python 解释器,指向远程服务器上的新环境。这一步骤中,注意配置好目录映射,和不自动上传文件。
10
+ 4. 等待一段时间,就能正常运作了。
env/keras.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Keras 相关工具模块
2
+
3
+ 提供 Keras 配置相关的功能。
4
+ """
5
+
6
+ import keras
7
+
8
+
9
+ def enable_mixed_precision():
10
+ """开启混合精度训练/推理"""
11
+ keras.config.set_dtype_policy("mixed_float16")
env/logger.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from functools import wraps
3
+
4
+
5
+ def get_logger(name: str, filepath: str = None):
6
+ logger = logging.getLogger(name)
7
+ logger.setLevel(logging.INFO)
8
+
9
+ # 控制台
10
+ console_handler = logging.StreamHandler()
11
+ logger.addHandler(console_handler)
12
+
13
+ # 文件
14
+ if filepath:
15
+ file_handler = logging.FileHandler(filepath)
16
+ logger.addHandler(file_handler)
17
+
18
+ return logger
19
+
20
+
21
+ def log(enter_message: str = "", exit_message: str = ""):
22
+ return _Log(enter_message=enter_message, exit_message=exit_message)
23
+
24
+
25
+ class _Log:
26
+ def __init__(
27
+ self,
28
+ enter_message: str = "",
29
+ exit_message: str = ""
30
+ ):
31
+ self.enter_message = enter_message
32
+ self.exit_message = exit_message
33
+
34
+ def __enter__(self):
35
+ if self.enter_message:
36
+ print(self.enter_message)
37
+ return self
38
+
39
+ def __exit__(self, exc_type, exc, tb):
40
+ if self.exit_message:
41
+ print(self.exit_message)
42
+ print("")
43
+ return False
44
+
45
+ def __call__(self, func):
46
+ @wraps(func)
47
+ def wrapper(*args, **kwargs):
48
+ with _Log(self.enter_message, self.exit_message):
49
+ return func(*args, **kwargs)
50
+ return None
51
+
52
+ return wrapper
env/resolve.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from enum import Enum, StrEnum
3
+ from pathlib import Path
4
+
5
+
6
+ """定义项目的根路径"""
7
+ PROJECT_ROOT = Path(__file__).parent.parent.absolute()
8
+
9
+
10
+ """定义根据环境变量选择配置的函数"""
11
+ class Env(StrEnum):
12
+ TEST = "test"
13
+ PRODUCTION = "production"
14
+
15
+ def resolve_env[T](test_conf: T = Env.TEST, prod_conf: T = Env.PRODUCTION) -> T:
16
+ env = os.environ.get("ENV", str(Env.TEST))
17
+ return prod_conf if env == str(Env.PRODUCTION) else test_conf
18
+
19
+
20
+ """定义一些预设的目录"""
21
+ SAVED_DIR = resolve_env(
22
+ PROJECT_ROOT / "local" / "saved",
23
+ PROJECT_ROOT / "saved",
24
+ )
25
+ TASKS_DIR = PROJECT_ROOT / "local" / "tasks"
26
+
27
+
28
+ """定义一些路径解析函数,方便在项目中使用"""
29
+ def resolve_saved(path: str | Path = None) -> Path:
30
+ """解析相对于 saved 目录的路径
31
+
32
+ 1. 如果本身就是 Path 对象,直接返回。
33
+ 2. 如果 path 是 None,返回 saved 目录本身。
34
+ 3. 否则,将 path 解析为相对于 saved 目录的路径。
35
+ """
36
+ if isinstance(path, Path):
37
+ return path
38
+ return SAVED_DIR / path if path else SAVED_DIR
39
+
40
+
41
+ def resolve_task_dir(task_name: str) -> Path:
42
+ """解析任务所在的目录
43
+
44
+ Args:
45
+ task_name: 任务名称,即定义在 Pipeline 中的 name 字段,例如 "poetry_gpt" 或 "poetry_rnn"。
46
+ """
47
+ return TASKS_DIR / task_name
48
+
49
+
50
+ def resolve_path(path: str | Path) -> Path:
51
+ """从项目根目录解析路径
52
+
53
+ 1. 如果路径是 Path 对象,直接返回。
54
+ 2. 如果路径是以 ~ 或 / 开头的绝对路径,则直接返回该路径。
55
+ 3. 如果路径是相对路径,则将其解析为相对于项目根目录的路径。
56
+
57
+ Args:
58
+ path: 相对于项目根目录的路径
59
+
60
+ Returns:
61
+ 解析后的绝对路径
62
+
63
+ Example:
64
+ >>> resolve_path("data/dev/mini_c4/file.txt")
65
+ PosixPath('/Users/.../universal_deeplearning/data/dev/mini_c4/file.txt')
66
+ """
67
+ if isinstance(path, Path):
68
+ return path
69
+ elif path.startswith("~") or path.startswith("/"):
70
+ return Path(path).expanduser().resolve()
71
+ else:
72
+ return PROJECT_ROOT / path
73
+
74
+
75
+ def display_path(path: str | Path) -> str:
76
+ """将路径转换为适合展示的字符串
77
+
78
+ 如果路径位于项目根目录内,则显示为相对项目根目录的路径;
79
+ 否则显示绝对路径。
80
+ """
81
+ resolved = resolve_path(path)
82
+ try:
83
+ return str(resolved.relative_to(PROJECT_ROOT))
84
+ except ValueError:
85
+ return str(resolved)
env/runner.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from typing import Callable
3
+
4
+
5
+ class ActionRunner:
6
+ def __call__(self, default_method: str | Callable = None):
7
+ if len(sys.argv) > 1:
8
+ method = self._resolve_method(sys.argv[1])
9
+ else:
10
+ method = default_method
11
+ if type(method) == str:
12
+ method = self._resolve_method(method)
13
+
14
+ if method:
15
+ method()
16
+ else:
17
+ raise ValueError("没有指定要执行的方法")
18
+
19
+ def _resolve_method(self, method_name: str) -> Callable:
20
+ method = getattr(self, method_name, None)
21
+ if method is None:
22
+ raise ValueError(f"没有找到对应的方法:{method_name}")
23
+ return method
env/vocab.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # 定义所有词典中 PAD 的 id token
2
+ PAD = 0
environment-linux.yml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: general-dl
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - python=3.12
6
+ - pip
7
+ - numpy
8
+ - tensorflow
9
+ - tensorflow-text
10
+ - keras
11
+ - pip:
12
+ - keras-hub
13
+ - gradio
14
+ variables:
15
+ ENV: production
environment.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: general-dl
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - python=3.12
6
+ - pip
7
+ - setuptools>=68,<70
8
+ - numpy
9
+ - ruff
10
+ - pytest
11
+ - pytest-mock
12
+ - tensorflow
13
+ - keras
14
+ - pip:
15
+ - gradio
16
+ variables:
17
+ ENV: test
generate_requirements.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 从 environment-linux.yml 生成 requirements.txt
4
+ YAML 中的版本号优先级最高
5
+ 未指定版本号时查询当前环境的实际版本
6
+ 排除 python 和 pip
7
+ """
8
+
9
+ import os
10
+ from datetime import datetime
11
+ import yaml
12
+ from importlib.metadata import version, PackageNotFoundError
13
+
14
+
15
+ # 排除的包(不加入 requirements.txt)
16
+ EXCLUDE_PACKAGES = {"python", "pip"}
17
+
18
+
19
+ def get_installed_version(package_name):
20
+ """获取包的安装版本,未安装返回 None"""
21
+ try:
22
+ return version(package_name)
23
+ except PackageNotFoundError:
24
+ return None
25
+
26
+
27
+ def parse_package_string(dep):
28
+ """
29
+ 解析包字符串,返回 (包名, yaml版本号或None)
30
+ 例如: "tensorflow=2.15.0" -> ("tensorflow", "2.15.0")
31
+ "numpy" -> ("numpy", None)
32
+ """
33
+ if "=" in dep:
34
+ parts = dep.split("=")
35
+ pkg_name = parts[0]
36
+ pkg_version = parts[1]
37
+ return pkg_name, pkg_version
38
+ else:
39
+ return dep, None
40
+
41
+
42
+ def parse_environment_yml(filepath):
43
+ """解析 environment-linux.yml,提取包列表和版本信息"""
44
+ with open(filepath, "r") as f:
45
+ env = yaml.safe_load(f)
46
+
47
+ packages = []
48
+
49
+ for dep in env.get("dependencies", []):
50
+ if isinstance(dep, str):
51
+ # 简单字符串格式:package 或 package=version
52
+ pkg_name, yaml_version = parse_package_string(dep)
53
+ if pkg_name not in EXCLUDE_PACKAGES:
54
+ packages.append((pkg_name, yaml_version))
55
+ elif isinstance(dep, dict) and "pip" in dep:
56
+ # pip 子列表
57
+ for pip_dep in dep["pip"]:
58
+ pkg_name, yaml_version = parse_package_string(pip_dep)
59
+ if pkg_name not in EXCLUDE_PACKAGES:
60
+ packages.append((pkg_name, yaml_version))
61
+
62
+ return packages
63
+
64
+
65
+ def main():
66
+ yml_file = "environment-linux.yml"
67
+ output_file = "requirements.txt"
68
+
69
+ print(f"读取 {yml_file}...")
70
+ packages = parse_environment_yml(yml_file)
71
+ print(f"发现 {len(packages)} 个包(排除 {EXCLUDE_PACKAGES})")
72
+
73
+ lines = []
74
+ for pkg_name, yaml_version in packages:
75
+ if yaml_version:
76
+ # YAML 中有版本号,优先使用
77
+ lines.append(f"{pkg_name}=={yaml_version}")
78
+ print(f" ✓ {pkg_name}=={yaml_version} (来自 YAML)")
79
+ else:
80
+ # YAML 中没有版本号,查询当前环境
81
+ env_version = get_installed_version(pkg_name)
82
+ if env_version:
83
+ lines.append(f"{pkg_name}=={env_version}")
84
+ print(f" ✓ {pkg_name}=={env_version} (来自当前环境)")
85
+ else:
86
+ lines.append(pkg_name)
87
+ print(f" ⚠ {pkg_name} (未安装,无版本号)")
88
+
89
+ # 添加头部注释
90
+ header_lines = [
91
+ f"# Generated from {yml_file}",
92
+ f"# Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
93
+ f"# Environment: {os.environ.get('ENV', 'unknown')}",
94
+ "#",
95
+ ]
96
+
97
+ # 合并所有行
98
+ all_lines = header_lines + lines
99
+
100
+ with open(output_file, "w") as f:
101
+ f.write("\n".join(all_lines) + "\n")
102
+
103
+ print(f"\n已生成 {output_file}:")
104
+ print("-" * 40)
105
+ print("\n".join(all_lines))
106
+ print("-" * 40)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ main()
models/__init__.py ADDED
File without changes
models/mini_gpt/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from models.mini_gpt.model_builder import GptModelBuilder
models/mini_gpt/gpt_components.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPT模型的共享组件模块:
3
+
4
+ - Positional Encoding
5
+ - Transformer Decoder
6
+ """
7
+
8
+ import keras
9
+ from keras import layers, ops
10
+
11
+ class PositionalEmbedding(keras.Layer):
12
+ def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
13
+ super().__init__(**kwargs)
14
+ self.token_embeddings = layers.Embedding(input_dim, output_dim)
15
+ self.position_embeddings = layers.Embedding(sequence_length, output_dim)
16
+
17
+ def call(self, inputs, reverse=False):
18
+ if reverse:
19
+ token_embeddings = self.token_embeddings.embeddings
20
+ return ops.matmul(inputs, ops.transpose(token_embeddings))
21
+ positions = ops.cumsum(ops.ones_like(inputs), axis=-1) - 1
22
+ embedded_tokens = self.token_embeddings(inputs)
23
+ embedded_positions = self.position_embeddings(positions)
24
+ return embedded_tokens + embedded_positions
25
+
26
+
27
+ class TransformerDecoder(keras.Layer):
28
+ def __init__(self, hidden_dim, intermediate_dim, num_heads, **kwargs):
29
+ super().__init__(**kwargs)
30
+
31
+ self.hidden_dim = hidden_dim
32
+ self.intermediate_dim = intermediate_dim
33
+
34
+ key_dim = hidden_dim // num_heads
35
+
36
+ # self-attention 层
37
+ self.self_attention = layers.MultiHeadAttention(num_heads, key_dim, dropout=0.1)
38
+ self.self_attention_layernorm = layers.LayerNormalization()
39
+
40
+ # feed-forward 层
41
+ self.feed_forward_1 = layers.Dense(intermediate_dim, activation="relu")
42
+ self.feed_forward_2 = layers.Dense(hidden_dim)
43
+ self.feed_forward_layernorm = layers.LayerNormalization()
44
+ self.dropout = layers.Dropout(0.1)
45
+
46
+ def call(self, inputs):
47
+ # self-attention 计算
48
+ residual = x = inputs
49
+ x = self.self_attention(query=x, key=x, value=x, use_causal_mask=True)
50
+ x = self.dropout(x)
51
+ x = x + residual
52
+ x = self.self_attention_layernorm(x)
53
+
54
+ # feed-forward 计算
55
+ residual = x
56
+ x = self.feed_forward_1(x)
57
+ x = self.feed_forward_2(x)
58
+ x = self.dropout(x)
59
+ x = x + residual
60
+ x = self.feed_forward_layernorm(x)
61
+ return x
models/mini_gpt/model_builder.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from functools import partial
3
+
4
+ import keras
5
+ from keras import layers
6
+
7
+ from models.mini_gpt.gpt_components import PositionalEmbedding, TransformerDecoder
8
+ from pipeline.base.generation import generate_with_training_model
9
+ from pipeline.base.model_builder import ModelArtifact
10
+
11
+
12
+ @dataclass
13
+ class GptModelBuilder:
14
+ hidden_dim: int
15
+ intermediate_dim: int
16
+ num_heads: int
17
+ num_layers: int
18
+
19
+ def build_training_artifact(
20
+ self,
21
+ vocab_size: int,
22
+ sequence_length: int
23
+ ) -> ModelArtifact:
24
+ inputs = keras.Input(shape=(None,), dtype="int32", name="inputs")
25
+ embedding = PositionalEmbedding(
26
+ sequence_length,
27
+ vocab_size,
28
+ self.hidden_dim,
29
+ name="embedding"
30
+ )
31
+ x = embedding(inputs)
32
+ x = layers.LayerNormalization(name="input_layer_norm")(x)
33
+
34
+ for i in range(self.num_layers):
35
+ decoder = TransformerDecoder(
36
+ self.hidden_dim,
37
+ self.intermediate_dim,
38
+ self.num_heads,
39
+ name=f"decoder_{i}"
40
+ )
41
+ x = decoder(x)
42
+
43
+ outputs = embedding(x, reverse=True)
44
+ model = keras.Model(inputs, outputs, name="mini_gpt")
45
+ return ModelArtifact(
46
+ model=model,
47
+ generate=partial(generate_with_training_model, model)
48
+ )
49
+
50
+ def build_inference_artifact(
51
+ self,
52
+ training_artifact: ModelArtifact
53
+ ) -> ModelArtifact:
54
+ return training_artifact
models/rnn/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from models.rnn.model_builder import RNNModelBuilder
models/rnn/model_builder.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from functools import partial
3
+
4
+ import keras
5
+ import tensorflow as tf
6
+ from keras import layers
7
+
8
+ from pipeline.base.generation import generate_with_stateful_model, generate_with_training_model
9
+ from pipeline.base.model_builder import ModelArtifact
10
+
11
+
12
+ @dataclass
13
+ class RNNModelBuilder:
14
+ num_layers: int = 2
15
+ embedding_dim: int = 100
16
+ hidden_dim: int = 1024
17
+
18
+ def build_training_artifact(
19
+ self,
20
+ vocab_size: int,
21
+ sequence_length: int
22
+ ) -> ModelArtifact:
23
+ inputs = keras.Input(shape=(None,), dtype="int32", name="inputs")
24
+ x = layers.Embedding(
25
+ input_dim=vocab_size,
26
+ output_dim=self.embedding_dim,
27
+ mask_zero=True,
28
+ name="embedding"
29
+ )(inputs)
30
+
31
+ for i in range(self.num_layers):
32
+ x = layers.LSTM(
33
+ self.hidden_dim,
34
+ return_sequences=True,
35
+ recurrent_dropout=0.1,
36
+ name=f"lstm_{i}"
37
+ )(x)
38
+ x = layers.Dropout(0.1, name=f"dropout_{i}")(x)
39
+
40
+ outputs = layers.Dense(vocab_size, name="logits")(x)
41
+ model = keras.Model(inputs=inputs, outputs=outputs, name="rnn_training")
42
+ return ModelArtifact(
43
+ model=model,
44
+ generate=partial(generate_with_training_model, model)
45
+ )
46
+
47
+ def build_inference_artifact(
48
+ self,
49
+ training_artifact: ModelArtifact
50
+ ) -> ModelArtifact:
51
+ inference_model = self._build_inference_model_from_training_model(
52
+ training_artifact.model
53
+ )
54
+ return ModelArtifact(
55
+ model=inference_model,
56
+ generate=partial(
57
+ generate_with_stateful_model,
58
+ inference_model,
59
+ initial_states=self._initial_states(batch_size=1)
60
+ )
61
+ )
62
+
63
+ def _build_inference_model_from_training_model(
64
+ self,
65
+ training_model: keras.Model
66
+ ) -> keras.Model:
67
+ token_input = keras.Input(shape=(None,), dtype="int32", name="token_input")
68
+ state_inputs = []
69
+ for i in range(self.num_layers):
70
+ h_input = keras.Input(shape=(self.hidden_dim,), name=f"h_{i}_input")
71
+ c_input = keras.Input(shape=(self.hidden_dim,), name=f"c_{i}_input")
72
+ state_inputs.extend([h_input, c_input])
73
+
74
+ embedding = training_model.get_layer("embedding")
75
+ logits_layer = training_model.get_layer("logits")
76
+ x = embedding(token_input)
77
+
78
+ new_states = []
79
+ inference_lstm_layers = []
80
+ for i in range(self.num_layers):
81
+ inference_lstm = layers.LSTM(
82
+ self.hidden_dim,
83
+ return_sequences=i < self.num_layers - 1,
84
+ return_state=True,
85
+ recurrent_dropout=0.1,
86
+ name=f"lstm_{i}"
87
+ )
88
+ h_input = state_inputs[i * 2]
89
+ c_input = state_inputs[i * 2 + 1]
90
+ x, new_h, new_c = inference_lstm(x, initial_state=[h_input, c_input])
91
+ new_states.extend([new_h, new_c])
92
+ dropout = training_model.get_layer(f"dropout_{i}")
93
+ x = dropout(x)
94
+ inference_lstm_layers.append(inference_lstm)
95
+
96
+ logits = logits_layer(x)
97
+ inference_model = keras.Model(
98
+ [token_input] + state_inputs,
99
+ [logits] + new_states,
100
+ name="rnn_inference"
101
+ )
102
+
103
+ for i, inference_lstm in enumerate(inference_lstm_layers):
104
+ training_lstm = training_model.get_layer(f"lstm_{i}")
105
+ inference_lstm.set_weights(training_lstm.get_weights())
106
+
107
+ return inference_model
108
+
109
+ def _initial_states(self, batch_size: int) -> list:
110
+ states = []
111
+ for _ in range(self.num_layers):
112
+ states.append(tf.zeros((batch_size, self.hidden_dim)))
113
+ states.append(tf.zeros((batch_size, self.hidden_dim)))
114
+ return states
pipeline/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .runner import PipelineRunner
2
+ from .pipeline import Pipeline
3
+ from .base.configs import CheckpointConfig
pipeline/base/__init__.py ADDED
File without changes
pipeline/base/checkpoint.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 模型工具模块
3
+
4
+ 包含模型构建、检查点管理等通用功能。
5
+ """
6
+
7
+ import pathlib
8
+ import re
9
+ import warnings
10
+
11
+ from env.resolve import resolve_path
12
+
13
+
14
+ def extract_number_of_filename(filename: str) -> int:
15
+ """
16
+ 从文件名中提取数字,无论数字出现在文件名的哪个位置。
17
+
18
+ 例如:
19
+ - "model_epoch_001.weights.h5" -> 1
20
+ - "checkpoint_2024_06_30_epoch_002.weights.h5" -> 2
21
+ - "model_epoch_final.weights.h5" -> 抛出异常
22
+
23
+ :param filename: 包含数字的文件名字符串
24
+ :return: 提取的数字,如果没有数字则返回0
25
+ """
26
+ numbers = re.findall(r"\d+", filename)
27
+ if numbers:
28
+ return int(numbers[-1]) # 返回最后一个数字,假设它是代数
29
+ else:
30
+ raise ValueError(f"No number found in filename: {filename}")
31
+
32
+
33
+ def resolve_checkpoint(
34
+ dirs: list[pathlib.Path | str] | None = None,
35
+ path: pathlib.Path | str | None = None,
36
+ epoch: int | None = None,
37
+ suffix: str | None = None
38
+ ):
39
+ """统一解析模型检查点路径
40
+
41
+ 支持直接指定检查点文件路径或在目录中查找检查点文件。
42
+
43
+ 参数:
44
+ dirs: 检查点目录列表
45
+ path: 直接指定的检查点文件路径(支持绝对路径和相对路径)
46
+ epoch: 指定的 epoch,用于查找对应的 .weights.h5 文件
47
+ suffix: 指定检查点文件后缀
48
+
49
+ 返回:
50
+ (resolved_path, epoch): 绝对路径和 epoch 数
51
+
52
+ 抛出:
53
+ FileNotFoundError: 当指定的路径不存在或未找到检查点文件时
54
+ ValueError: 当参数无效时
55
+ """
56
+ resolved_dirs = _resolve_checkpoint_dirs(dirs)
57
+
58
+ if path is not None:
59
+ path = pathlib.Path(path)
60
+
61
+ if not path.is_absolute():
62
+ if not resolved_dirs:
63
+ raise ValueError("path 是相对路径时,必须提供 dirs")
64
+ path = _resolve_relative_checkpoint_path(path, resolved_dirs)
65
+ else:
66
+ if dirs is not None:
67
+ warnings.warn(
68
+ "警告:path 是绝对路径,dirs 参数将被忽略",
69
+ UserWarning
70
+ )
71
+
72
+ if not path.exists():
73
+ raise FileNotFoundError(f"检查点文件不存在: {path}")
74
+ if suffix is not None and not path.name.endswith(suffix):
75
+ raise FileNotFoundError(f"检查点文件后缀不匹配: {path}")
76
+
77
+ try:
78
+ epoch_num = extract_number_of_filename(path.stem)
79
+ except ValueError:
80
+ epoch_num = 0
81
+
82
+ return path, epoch_num
83
+
84
+ if not resolved_dirs:
85
+ raise ValueError("必须提供 dirs 或 path")
86
+
87
+ files_with_number = _collect_checkpoint_files(
88
+ checkpoint_dirs=resolved_dirs,
89
+ suffix=suffix
90
+ )
91
+
92
+ if epoch is not None:
93
+ matches = [(f, num) for f, num in files_with_number if num == epoch]
94
+ if not matches:
95
+ raise FileNotFoundError(f"未找到 epoch {epoch} 对应的检查点文件")
96
+ if len(matches) > 1:
97
+ raise RuntimeError(
98
+ f"找到多个 epoch {epoch} 对应的检查点文件: {[match[0].name for match in matches]}"
99
+ )
100
+ return matches[0]
101
+
102
+ if not files_with_number:
103
+ return None, 0
104
+
105
+ return max(files_with_number, key=lambda item: item[1])
106
+
107
+
108
+ def _resolve_checkpoint_dirs(
109
+ dirs: list[pathlib.Path | str] | None
110
+ ) -> list[pathlib.Path]:
111
+ if dirs is None:
112
+ return []
113
+ return [resolve_path(path) for path in dirs]
114
+
115
+
116
+ def _resolve_relative_checkpoint_path(
117
+ checkpoint_path: pathlib.Path,
118
+ checkpoint_dirs: list[pathlib.Path]
119
+ ) -> pathlib.Path:
120
+ for checkpoint_dir in checkpoint_dirs:
121
+ candidate = checkpoint_dir / checkpoint_path
122
+ if candidate.exists():
123
+ return candidate
124
+ return checkpoint_dirs[0] / checkpoint_path
125
+
126
+
127
+ def _collect_checkpoint_files(
128
+ checkpoint_dirs: list[pathlib.Path],
129
+ suffix: str | None
130
+ ) -> list[tuple[pathlib.Path, int]]:
131
+ files_with_number = []
132
+ for checkpoint_dir in checkpoint_dirs:
133
+ if not checkpoint_dir.exists():
134
+ continue
135
+ for file_path in sorted(checkpoint_dir.iterdir()):
136
+ if not file_path.is_file():
137
+ continue
138
+ if suffix is not None and not file_path.name.endswith(suffix):
139
+ continue
140
+ if suffix is None and not _is_checkpoint_file(file_path):
141
+ continue
142
+ files_with_number.append((file_path, extract_number_of_filename(file_path.stem)))
143
+ return files_with_number
144
+
145
+
146
+ def _is_checkpoint_file(file_path: pathlib.Path) -> bool:
147
+ return file_path.name.endswith(".keras") or file_path.name.endswith(".weights.h5")
pipeline/base/configs.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from pathlib import Path
3
+ from typing import Callable
4
+
5
+
6
+ @dataclass
7
+ class CheckpointConfig:
8
+ dirs: list[Path] | None = None
9
+ path: Path = None
10
+ epoch: int = None
11
+ suffix: str = None
12
+
13
+
14
+ @dataclass
15
+ class ModelConfig:
16
+ sequence_length: int = 256
17
+ hidden_dim: int = 512
18
+ intermediate_dim: int = 2056
19
+ num_heads: int = 8
20
+ num_layers: int = 8
21
+
22
+
23
+ @dataclass
24
+ class TrainingRule:
25
+ batch_size: int = 128
26
+ epochs: int = 1
27
+ steps_per_epoch: int = 30
28
+ validation_batches: int = 1
29
+
30
+
31
+ @dataclass
32
+ class GenerationRule:
33
+ prompts_generator: Callable
34
+ sample_strategy: Callable
35
+
36
+
37
+ @dataclass
38
+ class CheckpointRules:
39
+ training: CheckpointConfig = field(default_factory=CheckpointConfig)
40
+ testing: CheckpointConfig = field(default_factory=CheckpointConfig)
41
+ deployment: CheckpointConfig = field(default_factory=CheckpointConfig)
42
+
43
+ def resolve_training_rule(
44
+ self,
45
+ default_dirs: list[Path | str] | None = None
46
+ ) -> dict:
47
+ return self._resolve_rule(self.training, default_dirs)
48
+
49
+ def resolve_testing_rule(
50
+ self,
51
+ default_dirs: list[Path | str] | None = None
52
+ ) -> dict:
53
+ return self._resolve_rule(self.testing, default_dirs)
54
+
55
+ def resolve_deployment_rule(
56
+ self,
57
+ default_dirs: list[Path | str] | None = None
58
+ ) -> dict:
59
+ return self._resolve_rule(self.deployment, default_dirs)
60
+
61
+ @staticmethod
62
+ def _resolve_rule(checkpoint: CheckpointConfig, default_dirs: list[Path | str] | None) -> dict:
63
+ dirs = checkpoint.dirs if checkpoint.dirs is not None else default_dirs
64
+ return {
65
+ "dirs": dirs,
66
+ "path": checkpoint.path,
67
+ "epoch": checkpoint.epoch,
68
+ "suffix": checkpoint.suffix
69
+ }
pipeline/base/generation.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 与生成有关的组件
3
+ """
4
+
5
+ import pathlib
6
+ from dataclasses import dataclass
7
+ from typing import Any, Callable
8
+
9
+ import keras
10
+ import numpy as np
11
+ from keras import callbacks, ops
12
+
13
+ from env.vocab import PAD
14
+ from env.logger import get_logger
15
+ from pipeline.base.model_builder import GenerationContext, GenerationResult, ModelArtifact
16
+
17
+
18
+ def generate_with_training_model(
19
+ model: keras.Model,
20
+ context: GenerationContext,
21
+ prompt_tokens: list[int]
22
+ ) -> GenerationResult:
23
+ prompt_length = len(prompt_tokens)
24
+
25
+ if prompt_length == 0:
26
+ return GenerationResult([], "<|empty|>")
27
+
28
+ tokens = prompt_tokens + [PAD] * (context.max_length - prompt_length)
29
+
30
+ for i in range(prompt_length, context.max_length):
31
+ prediction = model.predict(np.array([tokens]), verbose=0)
32
+ prediction = prediction[0, i - 1]
33
+ next_token = ops.convert_to_numpy(context.sample_fn(prediction))
34
+ next_token_id = np.array(next_token).item()
35
+ tokens[i] = next_token_id
36
+
37
+ if next_token_id == context.end_of_text:
38
+ return GenerationResult(tokens[:i], "<|endoftext|>")
39
+ if next_token_id == PAD:
40
+ return GenerationResult(tokens[:i], "<|pad|>")
41
+
42
+ return GenerationResult(tokens, "<|maxlength|>")
43
+
44
+
45
+ def generate_with_stateful_model(
46
+ model: keras.Model,
47
+ context: GenerationContext,
48
+ prompt_tokens: list[int],
49
+ initial_states: list
50
+ ) -> GenerationResult:
51
+ if not prompt_tokens:
52
+ return GenerationResult([], "<|empty|>")
53
+
54
+ tokens = list(prompt_tokens)
55
+ batch_tokens = np.array([tokens])
56
+ logits, *states = model.predict([batch_tokens] + initial_states, verbose=0)
57
+
58
+ for _ in range(len(tokens), context.max_length):
59
+ next_token = ops.convert_to_numpy(context.sample_fn(logits[0]))
60
+ next_token_id = np.array(next_token).item()
61
+ tokens.append(next_token_id)
62
+
63
+ if next_token_id == context.end_of_text:
64
+ return GenerationResult(tokens[:-1], "<|endoftext|>")
65
+ if next_token_id <= PAD:
66
+ return GenerationResult(tokens, "<|pad|>")
67
+
68
+ logits, *states = model.predict([np.array([[next_token_id]])] + states, verbose=0)
69
+
70
+ return GenerationResult(tokens, "<|maxlength|>")
71
+
72
+
73
+ @dataclass
74
+ class TextGenerationResult:
75
+ text: str
76
+ stop_reason: str
77
+
78
+
79
+ class TextGenerator:
80
+ def __init__(
81
+ self,
82
+ artifact: ModelArtifact,
83
+ tokenizer: Any,
84
+ decode: Callable,
85
+ end_of_text: int,
86
+ sample_fn: Callable,
87
+ max_length: int
88
+ ):
89
+ self.artifact = artifact
90
+ self.tokenizer = tokenizer
91
+ self.decode = decode
92
+ self.context = GenerationContext(
93
+ end_of_text=end_of_text,
94
+ max_length=max_length,
95
+ sample_fn=sample_fn
96
+ )
97
+
98
+ def generate_tokens(
99
+ self,
100
+ prompt: str,
101
+ max_length: int | None = None,
102
+ sample_fn: Callable | None = None
103
+ ) -> GenerationResult:
104
+ context = GenerationContext(
105
+ end_of_text=self.context.end_of_text,
106
+ max_length=max_length if max_length is not None else self.context.max_length,
107
+ sample_fn=sample_fn if sample_fn is not None else self.context.sample_fn
108
+ )
109
+ prompt_tokens = self._tokenize_prompt(prompt)
110
+ return self.artifact.generate(context, prompt_tokens)
111
+
112
+ def generate_text(
113
+ self,
114
+ prompt: str,
115
+ max_length: int | None = None,
116
+ sample_fn: Callable | None = None
117
+ ) -> TextGenerationResult:
118
+ result = self.generate_tokens(prompt, max_length, sample_fn)
119
+ return TextGenerationResult(
120
+ text=self.decode(result.token_ids),
121
+ stop_reason=result.stop_reason
122
+ )
123
+
124
+ def _tokenize_prompt(self, prompt: str) -> list[int]:
125
+ prompt_tokens = list(ops.convert_to_numpy(self.tokenizer(prompt)))
126
+ return [token for token in prompt_tokens if token > PAD]
127
+
128
+
129
+ class GenerationCallback(callbacks.Callback):
130
+ def __init__(
131
+ self,
132
+ prompts: list[str],
133
+ log_file: pathlib.Path,
134
+ tokenizer: Any,
135
+ decode: Callable,
136
+ end_of_text: int,
137
+ max_length: int,
138
+ sample_fn: Callable,
139
+ training_artifact: ModelArtifact
140
+ ):
141
+ super().__init__()
142
+ self.prompts = prompts
143
+ self.tokenizer = tokenizer
144
+ self.decode = decode
145
+ self.end_of_text = end_of_text
146
+ self.max_length = max_length
147
+ self.sample_fn = sample_fn
148
+ self.training_artifact = training_artifact
149
+ self.logger = self.init_logger(log_file)
150
+
151
+ def on_epoch_end(self, epoch, logs=None):
152
+ generator = TextGenerator(
153
+ artifact=self.training_artifact,
154
+ tokenizer=self.tokenizer,
155
+ decode=self.decode,
156
+ end_of_text=self.end_of_text,
157
+ max_length=self.max_length,
158
+ sample_fn=self.sample_fn
159
+ )
160
+ self.logger.info(f"\nGenerated text after epoch {epoch + 1}:")
161
+ for i, prompt in enumerate(self.prompts):
162
+ result = generator.generate_text(prompt)
163
+ self.logger.info(f"Prompt {i + 1:2}: {prompt}")
164
+ self.logger.info(f"Generated: {result.text}{result.stop_reason}\n")
165
+
166
+ @staticmethod
167
+ def init_logger(log_file: pathlib.Path):
168
+ if not log_file.parent.exists():
169
+ log_file.parent.mkdir(parents=True)
170
+
171
+ logger = get_logger("GenerationCallback", filepath=str(log_file))
172
+ logger.info("Initialized GenerationCallback logger")
173
+
174
+ return logger