Commit a5fd608
ver1: Implement a deep learning training framework supporting both Wiki GPT and poetry generation
Model architecture:
- Hand-written Mini GPT (Transformer): PositionalEmbedding and TransformerDecoder components
- Hand-written RNN: stacked LSTM structure
Training system:
- Pipeline framework: data loading, tokenizer, training, and generation wrapped end to end
- Checkpoint mechanism: resumable training plus per-generation model saving and loading
- TensorBoard training monitoring
Tasks:
- Wiki GPT: Chinese text generation trained on a Chinese Wikipedia corpus
- Poetry generator (GPT): Transformer-based poetry generation
- Poetry generator (RNN): LSTM-based poetry generation
Engineering support:
- Gradio interactive UI
- Hugging Face Space deployment configuration
- pytest test suite
Other highlights:
- Multiple sampling strategies: top-k, random sampling (temperature), and greedy search (see the sketch after this list)
- Key Pipeline supporting components: ModelBuilder, DataBundle
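The sampling strategies named above are implemented in `pipeline/base/generation.py`, which is not included in this 50-file view. As a rough illustration only — not the repository's actual implementation, and with a hypothetical function name and signature — a minimal next-token sampler combining greedy search, temperature-based random sampling, and top-k could look like this:

```python
import numpy as np

def sample_next_token(logits: np.ndarray, strategy: str = "top_k",
                      temperature: float = 1.0, k: int = 10) -> int:
    """Pick the next token id from a 1-D logits vector.

    Hypothetical helper for illustration; the real code in
    pipeline/base/generation.py may differ.
    """
    if strategy == "greedy":
        # Greedy search: always take the highest-scoring token.
        return int(np.argmax(logits))

    # Temperature scaling: higher temperature flattens the distribution.
    scaled = logits / max(temperature, 1e-6)

    if strategy == "top_k":
        # Keep only the k highest-scoring tokens before sampling.
        top_ids = np.argsort(scaled)[-k:]
        top_logits = scaled[top_ids]
        probs = np.exp(top_logits - top_logits.max())
        probs /= probs.sum()
        return int(np.random.choice(top_ids, p=probs))

    # Plain random sampling over the full temperature-scaled distribution.
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))
```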
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- .gitignore +3 -0
- AGENTS.md +220 -0
- README.md +146 -0
- app.py +58 -0
- data/__init__.py +25 -0
- data/base.py +85 -0
- data/common.py +147 -0
- data/dev/mini_c4/file1.txt +3 -0
- data/dev/mini_c4/file2.txt +4 -0
- data/dev/mini_c4/file3.txt +3 -0
- data/dev/poetry/元.csv +11 -0
- data/dev/poetry/先秦.csv +5 -0
- data/dev/poetry/南北朝.csv +6 -0
- data/poetry/__init__.py +21 -0
- data/poetry/dataset.py +104 -0
- data/poetry/loader.py +59 -0
- data/poetry/runner.py +38 -0
- data/poetry/tokenizer.py +67 -0
- data/poetry/transformer.py +38 -0
- data/runner.py +142 -0
- data/tokenizers.py +89 -0
- data/wiki/__init__.py +8 -0
- data/wiki/dataset.py +100 -0
- data/wiki/loader.py +52 -0
- data/wiki/runner.py +32 -0
- data/wiki/tokenizer.py +60 -0
- data/wiki/transformer.py +47 -0
- data/wiki/wiki_cleaner.py +122 -0
- docs/TODOs.md +3 -0
- docs/pycharm.md +10 -0
- env/keras.py +11 -0
- env/logger.py +52 -0
- env/resolve.py +85 -0
- env/runner.py +23 -0
- env/vocab.py +2 -0
- environment-linux.yml +15 -0
- environment.yml +17 -0
- generate_requirements.py +110 -0
- models/__init__.py +0 -0
- models/mini_gpt/__init__.py +1 -0
- models/mini_gpt/gpt_components.py +61 -0
- models/mini_gpt/model_builder.py +54 -0
- models/rnn/__init__.py +1 -0
- models/rnn/model_builder.py +114 -0
- pipeline/__init__.py +3 -0
- pipeline/base/__init__.py +0 -0
- pipeline/base/checkpoint.py +147 -0
- pipeline/base/configs.py +69 -0
- pipeline/base/generation.py +174 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
+*.keras filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
+__pycache__
+/.idea
+/local
AGENTS.md
ADDED
|
@@ -0,0 +1,220 @@
| 1 |
+
# Agent 编码规范
|
| 2 |
+
|
| 3 |
+
## 防御性编程精简
|
| 4 |
+
|
| 5 |
+
避免过度防御性编程,遵循以下原则:
|
| 6 |
+
|
| 7 |
+
### 1. None 检查
|
| 8 |
+
|
| 9 |
+
- **不要**进行显式的 None 检查
|
| 10 |
+
- 信任输入数据,让程序在真正的错误点上失败
|
| 11 |
+
- 避免 `if x is not None:` 这样的防御性代码
|
| 12 |
+
|
| 13 |
+
```python
|
| 14 |
+
# ❌ 避免
|
| 15 |
+
def process(data):
|
| 16 |
+
if data is not None:
|
| 17 |
+
return data.value
|
| 18 |
+
return None
|
| 19 |
+
|
| 20 |
+
# ✅ 推荐
|
| 21 |
+
def process(data):
|
| 22 |
+
return data.value
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### 2. 类型检查
|
| 26 |
+
|
| 27 |
+
- **不要**使用 `isinstance`、`type()`、`typeof` 等进行运行时类型检查
|
| 28 |
+
- 依靠类型提示和静态类型检查工具(如 mypy)
|
| 29 |
+
- 让 Duck Typing 发挥作用
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
# ❌ 避免
|
| 33 |
+
def calculate(obj):
|
| 34 |
+
if isinstance(obj, int):
|
| 35 |
+
return obj * 2
|
| 36 |
+
elif isinstance(obj, str):
|
| 37 |
+
return obj * 2
|
| 38 |
+
else:
|
| 39 |
+
raise TypeError("不支持的类型")
|
| 40 |
+
|
| 41 |
+
# ✅ 推荐
|
| 42 |
+
def calculate(obj: int | str) -> int | str:
|
| 43 |
+
return obj * 2
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
### 3. 异常处理
|
| 47 |
+
|
| 48 |
+
- **不要**滥用 try-catch 来压制异常
|
| 49 |
+
- **不要**用 try-catch 让程序"容错"运行
|
| 50 |
+
- 只在真正需要处理异常的地方捕获
|
| 51 |
+
- 让未处理的异常自然抛出,暴露真正的问题
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
# ❌ 避免 - 压制异常
|
| 55 |
+
import logging
|
| 56 |
+
|
| 57 |
+
logger = logging.getLogger(__name__)
|
| 58 |
+
|
| 59 |
+
def parse_config(path):
|
| 60 |
+
try:
|
| 61 |
+
with open(path) as f:
|
| 62 |
+
return json.load(f)
|
| 63 |
+
except Exception as e:
|
| 64 |
+
logger.error(f"加载配置失败: {e}")
|
| 65 |
+
return {} # 返回空配置让程序继续运行
|
| 66 |
+
|
| 67 |
+
# ✅ 推荐 - 让异常传播
|
| 68 |
+
def parse_config(path):
|
| 69 |
+
with open(path) as f:
|
| 70 |
+
return json.load(f)
|
| 71 |
+
|
| 72 |
+
# ✅ 或仅在必要时转换异常类型
|
| 73 |
+
def parse_config(path):
|
| 74 |
+
try:
|
| 75 |
+
with open(path) as f:
|
| 76 |
+
return json.load(f)
|
| 77 |
+
except json.JSONDecodeError as e:
|
| 78 |
+
raise ConfigError(f"配置文件格式错误: {e}") from e
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
### 4. 原则总结
|
| 82 |
+
|
| 83 |
+
1. **早失败(Fail Fast)** - 让错误尽早暴露,不要试图掩盖
|
| 84 |
+
2. **信任调用方** - 假设调用方会提供正确的输入
|
| 85 |
+
3. **清晰错误信息** - 让异常信息直接指出问题所在
|
| 86 |
+
4. **代码简洁** - 减少不必要的检查代码,专注于业务逻辑
|
| 87 |
+
|
| 88 |
+
---
|
| 89 |
+
|
| 90 |
+
**核心信条**:清晰的代码比健壮的代码更重要。让错误暴露,让问题可见。
|
| 91 |
+
|
| 92 |
+
## 编辑文件时的精准修改原则
|
| 93 |
+
|
| 94 |
+
在进行代码编辑时,**只修改必要的部分**,不要进行任何无关改动:
|
| 95 |
+
|
| 96 |
+
### 禁止的无关改动
|
| 97 |
+
|
| 98 |
+
- **不要**调整代码缩进或格式
|
| 99 |
+
- **不要**重排 import 语句的顺序
|
| 100 |
+
- **不要**添加或删除空行
|
| 101 |
+
- **不要**修改注释(除非任务明确要求)
|
| 102 |
+
- **不要**修改变量名、函数名等标识符(除非任务明确要求)
|
| 103 |
+
- **不要**进行任何代码重构(除非任务明确要求)
|
| 104 |
+
|
| 105 |
+
### ✅ 正确示例
|
| 106 |
+
|
| 107 |
+
如果任务是将 `import config` 改为 `from mini_gpt import config`:
|
| 108 |
+
|
| 109 |
+
```python
|
| 110 |
+
# 修改前
|
| 111 |
+
import config
|
| 112 |
+
from typing import Callable
|
| 113 |
+
|
| 114 |
+
# 修改后 - 只修改 import 语句,其他保持不变
|
| 115 |
+
from mini_gpt import config
|
| 116 |
+
from typing import Callable
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### ❌ 错误示例
|
| 120 |
+
|
| 121 |
+
#### 示例1:无关地调整 import 顺序
|
| 122 |
+
|
| 123 |
+
```python
|
| 124 |
+
# 修改前
|
| 125 |
+
import config
|
| 126 |
+
from typing import Callable
|
| 127 |
+
|
| 128 |
+
# 错误 - 无关地调整了 import 顺序
|
| 129 |
+
from typing import Callable
|
| 130 |
+
from mini_gpt import config
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
#### 示例2:无关地修改函数参数格式
|
| 134 |
+
|
| 135 |
+
```python
|
| 136 |
+
# 修改前
|
| 137 |
+
def my_function(
|
| 138 |
+
param1,
|
| 139 |
+
param2,
|
| 140 |
+
param3,
|
| 141 |
+
):
|
| 142 |
+
pass
|
| 143 |
+
|
| 144 |
+
# 错误 - 任务只要求修改函数体,却无关地修改了参数格式
|
| 145 |
+
def my_function(
|
| 146 |
+
param1, # 调整了缩进宽度
|
| 147 |
+
param2,
|
| 148 |
+
param3 # 去掉了尾部逗号
|
| 149 |
+
):
|
| 150 |
+
pass
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
**原则**:最小化改动范围,只改必须改的地方。
|
| 154 |
+
|
| 155 |
+
## 运行单元测试
|
| 156 |
+
|
| 157 |
+
本项目使用 pytest 运行单元测试,必须在 `mini-gpt` conda 环境中执行。
|
| 158 |
+
|
| 159 |
+
### 运行命令
|
| 160 |
+
|
| 161 |
+
```bash
|
| 162 |
+
/Users/run/anaconda3/envs/mini-gpt/bin/python -m pytest test/ -v
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
### 重要提示
|
| 166 |
+
|
| 167 |
+
1. **必须使用 mini-gpt 环境** - 基础环境缺少 tensorflow 依赖,会导致测试收集失败
|
| 168 |
+
2. **不要添加 `pytest.importorskip("tensorflow")`** - 这些测试依赖 tensorflow,跳过会掩盖真正的问题
|
| 169 |
+
|
| 170 |
+
## Python 代码风格
|
| 171 |
+
|
| 172 |
+
### 禁止尾逗号
|
| 173 |
+
|
| 174 |
+
**任何情况下都不应出现尾逗号**(trailing comma)。
|
| 175 |
+
|
| 176 |
+
```python
|
| 177 |
+
# ❌ 避免 - 尾逗号
|
| 178 |
+
my_list = [
|
| 179 |
+
1,
|
| 180 |
+
2,
|
| 181 |
+
3,
|
| 182 |
+
]
|
| 183 |
+
|
| 184 |
+
# ✅ 推荐
|
| 185 |
+
my_list = [
|
| 186 |
+
1,
|
| 187 |
+
2,
|
| 188 |
+
3
|
| 189 |
+
]
|
| 190 |
+
|
| 191 |
+
# ❌ 避免 - 函数参数尾逗号
|
| 192 |
+
def my_func(
|
| 193 |
+
arg1,
|
| 194 |
+
arg2,
|
| 195 |
+
):
|
| 196 |
+
pass
|
| 197 |
+
|
| 198 |
+
# ✅ 推荐
|
| 199 |
+
def my_func(
|
| 200 |
+
arg1,
|
| 201 |
+
arg2
|
| 202 |
+
):
|
| 203 |
+
pass
|
| 204 |
+
|
| 205 |
+
# ❌ 避免 - 字典尾逗号
|
| 206 |
+
my_dict = {
|
| 207 |
+
"key1": "value1",
|
| 208 |
+
"key2": "value2",
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
# ✅ 推荐
|
| 212 |
+
my_dict = {
|
| 213 |
+
"key1": "value1",
|
| 214 |
+
"key2": "value2"
|
| 215 |
+
}
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
## 禁止命令行参数
|
| 219 |
+
|
| 220 |
+
永远不要在代码中使用命令行参数(如 `argparse`、`sys.argv` 等)。配置应通过代码中硬编码实现。
|
README.md
ADDED
|
@@ -0,0 +1,146 @@
| 1 |
+
---
|
| 2 |
+
title: General Deep Learning
|
| 3 |
+
emoji: 🏃
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.12.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: General Deep Learning is a practical deep learning experimen
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# 通用深度学习(General Deep Learning)
|
| 15 |
+
|
| 16 |
+
## 项目简介
|
| 17 |
+
|
| 18 |
+
**通用深度学习(General Deep Learning)** 是一个面向实践的深度学习实验平台,致力于打造"训练-部署-体验"一体化的完整工作流。
|
| 19 |
+
|
| 20 |
+
### ✨ 为什么适合你?
|
| 21 |
+
|
| 22 |
+
**🎯 我的愿景**
|
| 23 |
+
- 构建一个**从零开始、透明可学、工程模块化**的深度学习平台。
|
| 24 |
+
|
| 25 |
+
**🎓学习友好**
|
| 26 |
+
- ✅ **纯手工从零构建** - Transformer、RNN 都是一行行代码手撸
|
| 27 |
+
- ✅ **代码即教程** - 没有黑盒封装,每个组件清晰可见
|
| 28 |
+
- ✅ **完整的训练闭环** - 从数据处理到部署,全流程透明
|
| 29 |
+
-
|
| 30 |
+
**🔧 技术特性**
|
| 31 |
+
- ✅ **覆盖主流模型** - Transformer、RNN,未来将扩展至 CNN、Diffusion 等
|
| 32 |
+
- ✅ **模块化架构** - 可插拔设计,新模型/新数据集快速接入
|
| 33 |
+
- ✅ **生产级部署** - 一键部署到 Hugging Face,支持断点续训、TensorBoard 监控
|
| 34 |
+
|
| 35 |
+
### 📅 关于这个项目
|
| 36 |
+
|
| 37 |
+
> *历时俩月,忙里偷闲。*
|
| 38 |
+
|
| 39 |
+
这不是一个追求最新模型的项目,而是一个**"代码即教程"**的个人实验场。
|
| 40 |
+
|
| 41 |
+
**已完成功能**:
|
| 42 |
+
- Wiki GPT - 基于中文维基的手写 Transformer
|
| 43 |
+
- 诗歌生成器 - GPT 和 RNN 双版本对比
|
| 44 |
+
|
| 45 |
+
**未来规划**:
|
| 46 |
+
4 月有事不再投入,5 月开始计划每月新增一个模型,探索更多架构(CNN、Diffusion...)
|
| 47 |
+
|
| 48 |
+
- 🔮 逐步扩展至 CV、多模态等领域
|
| 49 |
+
- 🔮 保持"从零手撸"的风格,让每个新模型都成为学习素材
|
| 50 |
+
|
| 51 |
+
**欢迎一起折腾** —— 反馈问题、贡献代码,或单纯聊聊技术!
|
| 52 |
+
|
| 53 |
+
### 🤗 在线体验
|
| 54 |
+
|
| 55 |
+
[](https://huggingface.co/spaces/yetrun/general-deep-learning)
|
| 56 |
+
|
| 57 |
+
🚀 **在线体验**:[点击访问 Hugging Face Space](https://huggingface.co/spaces/yetrun/general-deep-learning)
|
| 58 |
+
|
| 59 |
+
本项目已部署到 Hugging Face Space,你可以在线体验以下功能:
|
| 60 |
+
|
| 61 |
+
- **Wiki GPT 文本生成**:基于 Transformer 架构的中文文本生成,训练数据来自中文维基语料库
|
| 62 |
+
- **诗歌生成器(GPT)**:基于 Transformer 的中文诗歌生成,支持五言、七言诗等
|
| 63 |
+
- **诗歌生成器(RNN)**:基于 RNN 架构的中文诗歌生成,支持五言、七言诗等
|
| 64 |
+
|
| 65 |
+
## 部署说明
|
| 66 |
+
|
| 67 |
+
本项目已配置为 Hugging Face Space 兼容格式,如需更新部署:
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
# 1. 在 Hugging Face 创建新的 Space(选择 Gradio SDK)
|
| 71 |
+
# 2. 绑定 Space 远程仓库
|
| 72 |
+
git remote add huggingface https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
|
| 73 |
+
# 3. 确保依赖同步(生成 requirements.txt)
|
| 74 |
+
python3 generate_requirements.py
|
| 75 |
+
# 4. 提交并推送
|
| 76 |
+
git push huggingface master
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## 本地开发
|
| 80 |
+
|
| 81 |
+
### Conda 环境使用
|
| 82 |
+
|
| 83 |
+
使用方法:
|
| 84 |
+
```bash
|
| 85 |
+
# 创建环境
|
| 86 |
+
conda env create -f <environment.yml>
|
| 87 |
+
# 激活环境
|
| 88 |
+
conda activate general-dl
|
| 89 |
+
# 更新 environment.yml
|
| 90 |
+
conda env update -f <environment.yml> --prune
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
上述 `<environment.yml>` 是环境配置文件的路径,需要替换成实际的文件名:
|
| 94 |
+
|
| 95 |
+
- 如果你是本地开发,使用 `environment.yml`(Mac Intel 64 环境,`ENV=test`)
|
| 96 |
+
- 如果你是在远程服务器上运行,使用 `environments-linux.yml`(Linux 服务器环境,`ENV=production`)
|
| 97 |
+
|
| 98 |
+
> **插曲:**
|
| 99 |
+
>
|
| 100 |
+
> 环境配置出现了问题,强制重新安装 tensorflow-text 才修复。
|
| 101 |
+
>
|
| 102 |
+
> ```bash
|
| 103 |
+
> pip uninstall tensorflow-text -y
|
| 104 |
+
> pip install --no-cache-dir --force-reinstall tensorflow-text==2.20.0
|
| 105 |
+
> ```
|
| 106 |
+
|
| 107 |
+
### 开发工具配置
|
| 108 |
+
|
| 109 |
+
#### TensorBoard 说明
|
| 110 |
+
|
| 111 |
+
训练时,调用 `tensorboard --logdir=<logdir>` 来启动 TensorBoard,默认访问地址是 http://localhost:6006/.
|
| 112 |
+
|
| 113 |
+
`<logdir>` 通常是 `local/tasks/<project_name>/tensorboard`.
|
| 114 |
+
|
| 115 |
+
> 冷知识:tensorboard 中的代数与我们常规认为的代数不一致,第一代的计数是 0.
|
| 116 |
+
|
| 117 |
+
#### JetBrains 远程开发配置
|
| 118 |
+
|
| 119 |
+
配置本地代码映射:
|
| 120 |
+
|
| 121 |
+
1. 菜单栏:Tools → Deployment → Configuration
|
| 122 |
+
2. 配置目录映射:切换到Mappings标签页,Deployment path 设置远程目录路径
|
| 123 |
+
3. 配置排除目录,一般可排除的本地目录包括:`data/dev`, `local`, `test`.
|
| 124 |
+
|
| 125 |
+
手工同步:
|
| 126 |
+
|
| 127 |
+
- 右键文件/目录 → Deployment → Upload to...
|
| 128 |
+
|
| 129 |
+
## 数据集说明
|
| 130 |
+
|
| 131 |
+
### WIKI 数据集
|
| 132 |
+
|
| 133 |
+
*(本项目中 `wiki_gpt` 任务使用了中文维基语料库进行训练)*
|
| 134 |
+
|
| 135 |
+
下载维基百科的数据。
|
| 136 |
+
|
| 137 |
+
```bash
|
| 138 |
+
wget https://dumps.wikimedia.org/other/mediawiki_content_current/zhwiki/2026-01-01/xml/bzip2/zhwiki-2026-01-01-p1p5254490.xml.bz2
|
| 139 |
+
wget https://dumps.wikimedia.org/other/mediawiki_content_current/zhwiki/2026-01-01/xml/bzip2/zhwiki-2026-01-01-p5254491p9382552.xml.bz2
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
维基百科的数据分成两个文件,可使用 cat 命令合并成一个文件:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
cat zhwiki-2026-01-01-p1p5254490.xml.bz2 zhwiki-2026-01-01-p5254491p9382552.xml.bz2 > zhwiki-2026-01-01.xml.bz2
|
| 146 |
+
```
|
app.py
ADDED
@@ -0,0 +1,58 @@
+"""
+AI 文本生成工具集 - 多页面 Gradio 应用
+
+入口点,提供导航到各个子应用:
+- /:首页导航
+- /wiki_gpt:Wiki GPT 文本生成器
+- /poetry_gpt:诗歌生成器(GPT)
+- /poetry_rnn:诗歌生成器(RNN)
+
+特点:
+- 每个子页面可以独立运行测试
+"""
+import gradio as gr
+
+from tasks.wiki_gpt.gradio import demo as wiki_gpt_demo
+from tasks.poetry_gpt.gradio import demo as poetry_gpt_demo
+from tasks.poetry_rnn.gradio import demo as poetry_rnn_demo
+
+
+with gr.Blocks(title="AI 文本生成工具集") as demo:
+    gr.Markdown("# AI 文本生成工具集")
+    gr.Markdown("请选择要使用的应用:")
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("## 诗歌生成器(GPT)")
+            gr.Markdown("基于 Transformer 的中文诗歌生成,支持五言、七言诗等。")
+            gr.Button("进入诗歌生成器", link="/poetry_gpt")
+
+        with gr.Column():
+            gr.Markdown("## 诗歌生成器(RNN)")
+            gr.Markdown("基于 RNN 的中文诗歌生成,支持五言、七言诗等。")
+            gr.Button("进入诗歌生成器", link="/poetry_rnn")
+
+        with gr.Column():
+            gr.Markdown("## Wiki GPT 文本生成")
+            gr.Markdown("基于 Transformer 的中文文本生成,训练来自于中文维基语料库。")
+            gr.Button("进入 Wiki GPT", link="/wiki_gpt")
+
+    gr.Markdown("---")
+    gr.Markdown("### 说明")
+    gr.Markdown("每个应用都是独立加载的,进入页面后需要等待模型加载完成。")
+
+
+with demo.route("诗歌生成器(GPT)", "/poetry_gpt"):
+    poetry_gpt_demo.render()
+
+
+with demo.route("诗歌生成器(RNN)", "/poetry_rnn"):
+    poetry_rnn_demo.render()
+
+
+with demo.route("Wiki GPT", "/wiki_gpt"):
+    wiki_gpt_demo.render()
+
+
+if __name__ == "__main__":
+    demo.launch()
data/__init__.py
ADDED
@@ -0,0 +1,25 @@
+"""数据集模块
+
+提供统一的数据集接口,包括 Wiki 和诗歌数据集。
+
+Usage:
+    from data import WikiDataset, PoetryDataset
+
+    # Wiki 数据集
+    wiki = WikiDataset(data_dir="~/data/wiki/mini_c4")
+    doc_ds = wiki.doc_ds()
+    tokens_ds = wiki.tokens_ds(seq_length=256, batch_size=32)
+    wiki.stat(seq_length=256)
+
+    # 诗歌数据集
+    poetry = PoetryDataset(data_dir="~/data/Poetry")
+    doc_ds = poetry.doc_ds()
+    tokens_ds = poetry.tokens_ds(seq_length=100, batch_size=128)
+    poetry.stat(seq_length=100)
+"""
+
+from data.base import DataBundle, TokenizerBundle
+from data.wiki import WikiDataset
+from data.poetry import PoetryDataset
+
+__all__ = ["DataBundle", "TokenizerBundle", "WikiDataset", "PoetryDataset"]
data/base.py
ADDED
|
@@ -0,0 +1,85 @@
| 1 |
+
"""数据集抽象基类模块
|
| 2 |
+
|
| 3 |
+
定义 DataBundle 抽象基类,统一数据集的接口规范。
|
| 4 |
+
每个具体的数据集(如 Wiki、诗歌)都应该继承此类并实现相应方法。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import Callable, Optional
|
| 10 |
+
|
| 11 |
+
import tensorflow as tf
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class TokenizerBundle:
|
| 16 |
+
"""分词器信息包装类
|
| 17 |
+
|
| 18 |
+
将分词器相关的属性打包在一起,简化 DataBundle 接口。
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
tokenizer: Callable
|
| 22 |
+
decode: Callable
|
| 23 |
+
end_of_text: int
|
| 24 |
+
vocab_size: int
|
| 25 |
+
vocab_path: str = ""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class DataBundle(ABC):
|
| 30 |
+
"""数据集抽象基类
|
| 31 |
+
|
| 32 |
+
将数据加载、分词、统计等功能绑定在一起,提供统一的数据集接口。
|
| 33 |
+
|
| 34 |
+
Usage:
|
| 35 |
+
dataset = WikiDataset(data_dir="~/data/wiki")
|
| 36 |
+
doc_ds = dataset.doc_ds()
|
| 37 |
+
tokens_ds = dataset.tokens_ds(seq_length=256, batch_size=32)
|
| 38 |
+
dataset.stat()
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
data_dir: str
|
| 42 |
+
sequence_length: int = 256
|
| 43 |
+
|
| 44 |
+
@abstractmethod
|
| 45 |
+
def doc_ds(self) -> tf.data.Dataset:
|
| 46 |
+
"""返回原始文档数据集
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
TensorFlow Dataset,每个元素是一个文档字符串
|
| 50 |
+
"""
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
@abstractmethod
|
| 54 |
+
def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
|
| 55 |
+
"""返回 tokenized 数据集
|
| 56 |
+
|
| 57 |
+
将原始文档转换为 token ID 序列,并分割为训练样本。
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
seq_length: 序列长度
|
| 61 |
+
batch_size: 批次大小
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
TensorFlow Dataset,每个元素是 (input_ids, target_ids) 对
|
| 65 |
+
"""
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
@abstractmethod
|
| 69 |
+
def tokenizer_bundle(self) -> TokenizerBundle:
|
| 70 |
+
"""返回分词器信息"""
|
| 71 |
+
pass
|
| 72 |
+
|
| 73 |
+
def stat(self, seq_length: int | None = None) -> None:
|
| 74 |
+
"""打印数据集统计信息
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
seq_length: 序列长度,用于估算训练样本数
|
| 78 |
+
"""
|
| 79 |
+
from data.common import collect_stats
|
| 80 |
+
|
| 81 |
+
info = self.tokenizer_bundle()
|
| 82 |
+
stats = collect_stats(
|
| 83 |
+
name=self.__class__.__name__, loader=self.doc_ds, tokenizer=info.tokenizer
|
| 84 |
+
)
|
| 85 |
+
stats.print_report(seq_length=seq_length)
|
data/common.py
ADDED
|
@@ -0,0 +1,147 @@
| 1 |
+
"""数据集共享工具模块
|
| 2 |
+
|
| 3 |
+
提供数据集统计、报告生成等共享功能。
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pathlib
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Callable
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import tensorflow as tf
|
| 12 |
+
from keras import layers
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class DatasetStats:
|
| 17 |
+
"""数据集统计结果"""
|
| 18 |
+
|
| 19 |
+
name: str
|
| 20 |
+
doc_count: int
|
| 21 |
+
total_chars: int
|
| 22 |
+
total_tokens: int
|
| 23 |
+
max_length: int
|
| 24 |
+
median_length: int
|
| 25 |
+
|
| 26 |
+
def print_report(self, seq_length: int | None = 256):
|
| 27 |
+
"""打印统一格式的统计报表
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
seq_length: 序列长度,用于估算训练样本数。
|
| 31 |
+
为 None 时表示不切割,一个文档一个样本。
|
| 32 |
+
"""
|
| 33 |
+
avg_chars = self.total_chars / self.doc_count if self.doc_count > 0 else 0
|
| 34 |
+
avg_tokens = self.total_tokens / self.doc_count if self.doc_count > 0 else 0
|
| 35 |
+
|
| 36 |
+
print()
|
| 37 |
+
print("=" * 60)
|
| 38 |
+
print(f"{self.name} 数据集统计")
|
| 39 |
+
print("=" * 60)
|
| 40 |
+
print(f"{'文档数:':<20} {self.doc_count:>15,}")
|
| 41 |
+
print(f"{'总字符数:':<20} {self.total_chars:>15,}")
|
| 42 |
+
print(f"{'总 Token 数:':<20} {self.total_tokens:>15,}")
|
| 43 |
+
print("-" * 60)
|
| 44 |
+
print(f"{'平均每文档字符数:':<20} {avg_chars:>15.1f}")
|
| 45 |
+
print(f"{'平均每文档 Token 数:':<20} {avg_tokens:>15.1f}")
|
| 46 |
+
print(f"{'最长文档字符数:':<20} {self.max_length:>15,}")
|
| 47 |
+
print(f"{'文档长度中位数:':<20} {self.median_length:>15,}")
|
| 48 |
+
print("=" * 60)
|
| 49 |
+
|
| 50 |
+
if self.total_tokens > 0:
|
| 51 |
+
print()
|
| 52 |
+
if seq_length is None:
|
| 53 |
+
print(f"训练样本数: {self.doc_count:,} 个 (一个文档一个样本)")
|
| 54 |
+
else:
|
| 55 |
+
print(f"训练样本预估 (seq={seq_length}):")
|
| 56 |
+
print(f" 可生成约 {self.total_tokens // seq_length:,} 个训练样本")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def collect_stats(
|
| 60 |
+
name: str, loader: Callable[[], tf.data.Dataset], tokenizer: Callable
|
| 61 |
+
) -> DatasetStats:
|
| 62 |
+
"""从 DatasetLoader 收集统计数据
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
name: 数据集名称(用于报表显示)
|
| 66 |
+
loader: 返回 tf.data.Dataset 的加载器函数
|
| 67 |
+
tokenizer: 分词器函数,接收文本返回 token ID 列表
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
DatasetStats 统计结果对象
|
| 71 |
+
"""
|
| 72 |
+
ds = loader()
|
| 73 |
+
|
| 74 |
+
doc_count = 0
|
| 75 |
+
total_chars = 0
|
| 76 |
+
total_tokens = 0
|
| 77 |
+
lengths = []
|
| 78 |
+
|
| 79 |
+
for item in ds:
|
| 80 |
+
text = item.numpy().decode("utf-8")
|
| 81 |
+
if not text.strip():
|
| 82 |
+
continue
|
| 83 |
+
|
| 84 |
+
doc_count += 1
|
| 85 |
+
total_chars += len(text)
|
| 86 |
+
lengths.append(len(text))
|
| 87 |
+
|
| 88 |
+
# Token 统计,过滤掉末尾的 padding (值为 0 的 token)
|
| 89 |
+
try:
|
| 90 |
+
import keras
|
| 91 |
+
|
| 92 |
+
token_ids = keras.ops.convert_to_numpy(tokenizer(text))
|
| 93 |
+
except ImportError:
|
| 94 |
+
# Fallback: assume tokenizer returns numpy array directly
|
| 95 |
+
token_ids = np.array(tokenizer(text))
|
| 96 |
+
|
| 97 |
+
# 只去掉末尾的 0,保留中间内容(包括中间的 OOV/padding)
|
| 98 |
+
valid_tokens = np.trim_zeros(token_ids, "b")
|
| 99 |
+
total_tokens += len(valid_tokens)
|
| 100 |
+
|
| 101 |
+
return DatasetStats(
|
| 102 |
+
name=name,
|
| 103 |
+
doc_count=doc_count,
|
| 104 |
+
total_chars=total_chars,
|
| 105 |
+
total_tokens=total_tokens,
|
| 106 |
+
max_length=max(lengths) if lengths else 0,
|
| 107 |
+
median_length=int(np.median(lengths)) if lengths else 0,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def save_vocabulary(vocab: list[str], vocab_path: pathlib.Path) -> None:
|
| 112 |
+
"""保存词汇表到文件
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
vocab: 词汇表列表
|
| 116 |
+
vocab_path: 保存路径
|
| 117 |
+
"""
|
| 118 |
+
vocab_path.parent.mkdir(parents=True, exist_ok=True)
|
| 119 |
+
with open(vocab_path, "w", encoding="utf-8") as f:
|
| 120 |
+
for char in vocab:
|
| 121 |
+
written = char if char != "\n" else r"\n"
|
| 122 |
+
f.write(written + "\n")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def build_vocab_from_dataset(
|
| 126 |
+
doc_ds: tf.data.Dataset, vocab_path: pathlib.Path
|
| 127 |
+
) -> list[str]:
|
| 128 |
+
"""从文档数据集构建词汇表
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
doc_ds: 文档数据集
|
| 132 |
+
vocab_path: 词汇表保存路径
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
词汇表列表
|
| 136 |
+
"""
|
| 137 |
+
vectorizer = layers.TextVectorization(
|
| 138 |
+
output_mode="int", split="character", standardize=None
|
| 139 |
+
)
|
| 140 |
+
vectorizer.adapt(doc_ds, batch_size=128)
|
| 141 |
+
|
| 142 |
+
vocab = vectorizer.get_vocabulary()
|
| 143 |
+
if "$" not in vocab:
|
| 144 |
+
vocab = [*vocab, "$"]
|
| 145 |
+
|
| 146 |
+
save_vocabulary(vocab, vocab_path)
|
| 147 |
+
return vocab
|
data/dev/mini_c4/file1.txt
ADDED
@@ -0,0 +1,3 @@
+first document of first file
+second document of first file
+third document of first file
data/dev/mini_c4/file2.txt
ADDED
@@ -0,0 +1,4 @@
+first document of second file
+second document of second file
+third document of second file
+fourth document of second file
data/dev/mini_c4/file3.txt
ADDED
@@ -0,0 +1,3 @@
+first document of third file
+second document of third file
+third document of third file
data/dev/poetry/元.csv
ADDED
|
@@ -0,0 +1,11 @@
| 1 |
+
标题,朝代,作者,体裁,内容
|
| 2 |
+
西洱河,元,述律杰,五言排律,洱水何雄壮,源流自邓川。两关龙首尾,九曲势蜿蜒。大理城池固,金汤铁石坚。四洲从古号,三岛至今传。罗阁凭巘崄,蒙人恃极边。要当兵十万,不数客三千。世祖亲征日,初还一统天。雨师清瘴疠,风伯扫氛烟。民物因蕃富,封疆近百年。点苍山色好,铭刻尚依然。
|
| 3 |
+
陟玩春山纪兴,元,忽必烈,七言律诗,时膺韶景陟兰峰,不惮跻攀谒粹容。花色映霞祥彩混,垆烟拂雾瑞光重。雨沾琼干岩边竹,风袭琴声岭际松。净刹玉毫瞻礼罢,回程仙驾驭苍龙。
|
| 4 |
+
结联,元,奥鲁赤,句,久立危栏须北望,无边秋色杳冥冥。
|
| 5 |
+
八月初四日雪坡太守周门拓入云居山中复度岭饮于水月尼寺赋诗书似太守及苏州刺史周义卿,元,杨维桢,七言律诗,文章太守早休牙,五马传呼处士家。好客新分朱露酒,题诗近在白云窝。山中子落千年桂,海上人归八月槎。水月楼头横玉笛,误猜萼绿是韶华。
|
| 6 |
+
用顾松江韵复理贰守并柬雪坡刺史,元,杨维桢,七言律诗,仙客归来隘九州,身骑黄鹤记南游。乌衣故国江山在,铜柱荒台草木秋。起舞刘琨空有志,登高王粲不胜愁。问君蔗境今何在,祇忆当年顾虎头。
|
| 7 |
+
寄小蓬莱主者闻梅涧并简沈元方宇文仲美贤主宾,元,杨维桢,七言律诗,罗浮主者是仙才,东老诸孙亦俊哉。风雨春城花落尽,江山故国燕归来。酒盟自有乌巾在,笑口应随皓齿开。十八仙人重会处,劫灰不到小蓬莱。
|
| 8 |
+
次韵奉答倪元镇,元,杨维桢,七言律诗,坐断深林事不闻,西窗风日爱余曛。旧经高赤寻三传,新咏山王削五君。翠筱侵床落苍雪,石池洗砚动玄云。东邻书屋最相忆,莫遣草堂移浪文。
|
| 9 |
+
送谢太守,元,杨维桢,七言律诗,朝廷遣使航东海,万里南来送玺书。著屐登山良不恶,分符典郡复何如。白苏事业千年后,吴楚封疆百战馀。今日养民方急务,肯将徵算及舟车。
|
| 10 |
+
送谢太守,元,杨维桢,七言律诗,
|
| 11 |
+
回上张太尉(一云"谢赐玳瑁笔见征楚国公碑文"),元,杨维桢,七言律诗,昨夜文星照南极,今朝客省过东维。锦囊颖脱千年兔,斑管光摇九尾龟。墨卷风云随王气,恩分雨露出天池。老夫来草平蛮策,先写新封楚国碑。
|
data/dev/poetry/先秦.csv
ADDED
|
@@ -0,0 +1,5 @@
| 1 |
+
标题,朝代,作者,体裁,内容
|
| 2 |
+
禹玉牒辞,先秦,无名氏,古风,祝融司方发其英,沐日浴月百宝生。
|
| 3 |
+
衣铭,先秦,无名氏,古风,桑蚕苦,女工难,得新捐故后必寒。
|
| 4 |
+
书车,先秦,无名氏,古风,出畏之,入惧之。
|
| 5 |
+
击壤歌,先秦,无名氏,古风,日出而作。日入而息。凿井而饮。耕田而食。帝力于我何有哉。
|
data/dev/poetry/南北朝.csv
ADDED
|
@@ -0,0 +1,6 @@
| 1 |
+
标题,朝代,作者,体裁,内容
|
| 2 |
+
悬瓠方丈竹堂飨侍臣联句诗,南北朝,元宏,古风,白日光天兮无不曜。江左一隅独未照。愿从圣明兮登衡会。万国驰诚混内外。云雷大振兮天门辟。率土来宾一正历。舜舞干戚兮天下归。文德远被莫不思。皇风一鼓兮九地匝。戴日依天清六合。遵彼汝坟兮昔化贞。未若今日道风明。文王政教兮晖江沼。宁如大化光四表。
|
| 3 |
+
歌,南北朝,元宏,句,两菖蒲,新野乐。
|
| 4 |
+
应制赋铜鞮山松诗,南北朝,元协,古风,问松林。松林经几冬。山川何如昔。风云与古同。
|
| 5 |
+
绝命诗二首 其一 ,南北朝,元熙,古风,义实动君子,主辱死忠臣。何以明是节,将解七尺身。
|
| 6 |
+
绝命诗二首 其二 ,南北朝,元熙,古风,平生方寸心,殷勤属知己。从今一销化,悲伤无极已。
|
data/poetry/__init__.py
ADDED
@@ -0,0 +1,21 @@
+"""诗歌数据集模块
+
+从以下 github 地址下载数据集到目录 ./data/Poetry:
+
+> https://github.com/xiu-ze/Poetry.git
+
+数据集的格式是多文件 CSV 格式,统计结果:
+
+> 找到 22 个 CSV 文件
+>
+> 诗歌总数: 1014507
+> 最长字符数: 4872
+> 平均字符数: 66.04
+> 中位数: 48
+
+因此可设置序列长度为 100.
+"""
+
+from data.poetry.dataset import PoetryDataset
+
+__all__ = ["PoetryDataset"]
data/poetry/dataset.py
ADDED
|
@@ -0,0 +1,104 @@
| 1 |
+
"""诗歌数据集主模块
|
| 2 |
+
|
| 3 |
+
实现 PoetryDataset 类,继承自 DataBundle。
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pathlib
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
import tensorflow as tf
|
| 11 |
+
|
| 12 |
+
from data.base import DataBundle, TokenizerBundle
|
| 13 |
+
from data.poetry.loader import doc_load_with_eot
|
| 14 |
+
from data.poetry.transformer import transform
|
| 15 |
+
from data.poetry.tokenizer import load_vectorizer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class PoetryDataset(DataBundle):
|
| 20 |
+
"""诗歌数据集
|
| 21 |
+
|
| 22 |
+
将文档加载、分词、统计等功能绑定在一起的数据集类。
|
| 23 |
+
|
| 24 |
+
Usage:
|
| 25 |
+
dataset = PoetryDataset(
|
| 26 |
+
data_dir="~/data/Poetry/诗歌数据集",
|
| 27 |
+
vocab_path="~/data/Poetry/vocabulary.txt",
|
| 28 |
+
sequence_length=100
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# 获取文档数据集
|
| 32 |
+
doc_ds = dataset.doc_ds()
|
| 33 |
+
|
| 34 |
+
# 获取 token 数据集
|
| 35 |
+
tokens_ds = dataset.tokens_ds(seq_length=100, batch_size=128)
|
| 36 |
+
|
| 37 |
+
# 打印统计信息
|
| 38 |
+
dataset.stat(seq_length=100)
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
vocab_path: str = ""
|
| 42 |
+
|
| 43 |
+
_data_path: pathlib.Path = field(init=False, repr=False)
|
| 44 |
+
_vocab_path: pathlib.Path = field(init=False, repr=False)
|
| 45 |
+
_tokenizer_info: Optional[TokenizerBundle] = field(
|
| 46 |
+
init=False, repr=False, default=None
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
def __post_init__(self):
|
| 50 |
+
self._data_path = pathlib.Path(self.data_dir).expanduser()
|
| 51 |
+
self._vocab_path = pathlib.Path(self.vocab_path).expanduser()
|
| 52 |
+
|
| 53 |
+
def _load_tokenizer(self):
|
| 54 |
+
"""懒加载分词器"""
|
| 55 |
+
if self._tokenizer_info is None:
|
| 56 |
+
tokenizer = load_vectorizer(self._vocab_path, self.sequence_length + 1)
|
| 57 |
+
vocab = tokenizer.get_vocabulary()
|
| 58 |
+
end_of_text = vocab.index("$")
|
| 59 |
+
vocab_size = len(vocab)
|
| 60 |
+
|
| 61 |
+
def decode(token_ids: list[int]) -> str:
|
| 62 |
+
chars = [
|
| 63 |
+
vocab[token_id] for token_id in token_ids if token_id < len(vocab)
|
| 64 |
+
]
|
| 65 |
+
return "".join(chars)
|
| 66 |
+
|
| 67 |
+
self._tokenizer_info = TokenizerBundle(
|
| 68 |
+
tokenizer=tokenizer,
|
| 69 |
+
decode=decode,
|
| 70 |
+
end_of_text=end_of_text,
|
| 71 |
+
vocab_size=vocab_size,
|
| 72 |
+
vocab_path=str(self._vocab_path)
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def doc_ds(self) -> tf.data.Dataset:
|
| 76 |
+
"""返回原始文档数据集
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
TensorFlow Dataset,每个元素是带结束标记的诗歌内容
|
| 80 |
+
"""
|
| 81 |
+
return doc_load_with_eot(self._data_path)
|
| 82 |
+
|
| 83 |
+
def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
|
| 84 |
+
"""返回 tokenized 数据集
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
seq_length: 序列长度(诗歌中此参数主要用于兼容性)
|
| 88 |
+
batch_size: 批次大小
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
TensorFlow Dataset,每个元素是 (input_ids, target_ids) 对
|
| 92 |
+
"""
|
| 93 |
+
self._load_tokenizer()
|
| 94 |
+
ds = self.doc_ds()
|
| 95 |
+
return transform(
|
| 96 |
+
ds=ds,
|
| 97 |
+
tokenizer=self._tokenizer_info.tokenizer,
|
| 98 |
+
batch_size=batch_size,
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def tokenizer_bundle(self) -> TokenizerBundle:
|
| 102 |
+
"""返回分词器信息"""
|
| 103 |
+
self._load_tokenizer()
|
| 104 |
+
return self._tokenizer_info
|
data/poetry/loader.py
ADDED
|
@@ -0,0 +1,59 @@
| 1 |
+
"""诗歌数据集文档加载模块
|
| 2 |
+
|
| 3 |
+
从 CSV 文件加载诗歌文本数据。
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import glob
|
| 7 |
+
import os
|
| 8 |
+
import pathlib
|
| 9 |
+
|
| 10 |
+
import tensorflow as tf
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _parse_csv_line(line: tf.Tensor) -> tf.Tensor:
|
| 14 |
+
"""解析 CSV 行,返回内容列"""
|
| 15 |
+
fields = tf.io.decode_csv(
|
| 16 |
+
line,
|
| 17 |
+
use_quote_delim=False, # 行内的引号是普通字符
|
| 18 |
+
record_defaults=["", "", "", "", ""],
|
| 19 |
+
)
|
| 20 |
+
return fields[4] # 返回 '内容' 列的值
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def doc_load(data_dir: pathlib.Path) -> tf.data.Dataset:
|
| 24 |
+
"""加载诗歌数据集
|
| 25 |
+
|
| 26 |
+
从指定目录下的 CSV 文件中加载诗歌文本数据。
|
| 27 |
+
每个 CSV 文件应该包含以下列:标题、作者、朝代、类型、内容。
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
data_dir: 数据目录路径
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
TensorFlow Dataset,每个元素是诗歌内容字符串
|
| 34 |
+
"""
|
| 35 |
+
csv_files = glob.glob(os.path.join(data_dir, "*.csv"))
|
| 36 |
+
if not csv_files:
|
| 37 |
+
raise ValueError(f"在目录 {data_dir} 中未找到任何 CSV 文件!")
|
| 38 |
+
|
| 39 |
+
files_ds = tf.data.Dataset.from_tensor_slices(csv_files)
|
| 40 |
+
csv_line_ds = files_ds.interleave(
|
| 41 |
+
lambda csv_file: tf.data.TextLineDataset(csv_file).skip(1),
|
| 42 |
+
cycle_length=1,
|
| 43 |
+
)
|
| 44 |
+
return csv_line_ds.map(_parse_csv_line, num_parallel_calls=tf.data.AUTOTUNE).filter(
|
| 45 |
+
lambda x: tf.strings.length(x) > 0
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def doc_load_with_eot(data_dir: pathlib.Path) -> tf.data.Dataset:
|
| 50 |
+
"""加载诗歌数据集,每行末尾添加结束标记
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
data_dir: 数据目录路径
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
TensorFlow Dataset,每个元素是带结束标记的诗歌内容
|
| 57 |
+
"""
|
| 58 |
+
ds = doc_load(data_dir)
|
| 59 |
+
return ds.map(lambda x: tf.strings.join([x, "$"]))
|
data/poetry/runner.py
ADDED
@@ -0,0 +1,38 @@
+"""诗歌数据集 Runner
+
+Usage:
+    python data/poetry/runner.py build_vocab
+    python data/poetry/runner.py test_dataset
+    ENV=production python data/poetry/runner.py build_vocab
+"""
+
+import pathlib
+import sys
+
+sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
+
+from data.poetry.dataset import PoetryDataset
+from data.runner import DatasetRunner
+from env.resolve import resolve_path, resolve_env, resolve_saved
+
+
+dataset = PoetryDataset(
+    data_dir=str(
+        resolve_env(resolve_path("data/dev/poetry"), resolve_path("~/data/Poetry/诗歌数据集"))
+    ),
+    vocab_path=str(
+        resolve_env(
+            resolve_saved("vocab/poetry/vocab.txt"),
+            resolve_path("~/data/Poetry/vocabulary.txt"),
+        )
+    ),
+    sequence_length=100,
+)
+
+runner = DatasetRunner(
+    dataset=dataset,
+    name="poetry",
+)
+
+if __name__ == "__main__":
+    runner()
data/poetry/tokenizer.py
ADDED
|
@@ -0,0 +1,67 @@
| 1 |
+
"""诗歌数据集分词器模块
|
| 2 |
+
|
| 3 |
+
提供诗歌数据集专用的分词器实现。
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pathlib
|
| 7 |
+
|
| 8 |
+
from keras import layers
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_vocabulary(vocab_path: pathlib.Path):
|
| 12 |
+
"""从文本文件加载词汇表,每行一个字符。
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
vocab_path: 词汇表文件路径
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
词汇表列表
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def extract_word(line: str) -> str:
|
| 22 |
+
word = line[:-1] # 去掉行末的换行符
|
| 23 |
+
return word if word != r"\n" else "\n"
|
| 24 |
+
|
| 25 |
+
with open(vocab_path, "r", encoding="utf-8") as f:
|
| 26 |
+
vocab = [extract_word(line) for line in f]
|
| 27 |
+
return vocab
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_vectorizer(
|
| 31 |
+
vocab_path: pathlib.Path, sequence_length: int = 101
|
| 32 |
+
) -> layers.TextVectorization:
|
| 33 |
+
"""从词汇表文件加载分词器
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
vocab_path: 词汇表文件路径
|
| 37 |
+
sequence_length: 输出序列长度,默认为 101
|
| 38 |
+
(多一位是为了在训练时构建输入和目标偏移一位)
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
TextVectorization 层
|
| 42 |
+
"""
|
| 43 |
+
vectorizer = layers.TextVectorization(
|
| 44 |
+
output_mode="int",
|
| 45 |
+
split="character",
|
| 46 |
+
output_sequence_length=sequence_length,
|
| 47 |
+
standardize=None,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
vocab = load_vocabulary(vocab_path)
|
| 51 |
+
vectorizer.set_vocabulary(vocab)
|
| 52 |
+
|
| 53 |
+
return vectorizer
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def create_vectorizer(sequence_length: int = 101) -> layers.TextVectorization:
|
| 57 |
+
"""创建新的分词器(用于训练词汇表)
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
sequence_length: 输出序列长度,默认为 101
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
TextVectorization 层
|
| 64 |
+
"""
|
| 65 |
+
return layers.TextVectorization(
|
| 66 |
+
output_mode="int", split="character", standardize=None
|
| 67 |
+
)
|
data/poetry/transformer.py
ADDED
@@ -0,0 +1,38 @@
+"""诗歌数据集 token 转换模块
+
+将诗歌文档数据集转换为训练用的 token 序列。
+"""
+
+from typing import Callable
+
+import tensorflow as tf
+
+
+def transform(
+    ds: tf.data.Dataset,
+    tokenizer: Callable,
+    batch_size: int,
+) -> tf.data.Dataset:
+    """转换诗歌数据集为训练数据集
+
+    诗歌数据集已经生成了固定数量的 token 序列,不足的部分会 padding。
+
+    Args:
+        ds: 文档数据集
+        tokenizer: 分词器函数
+        batch_size: 批次大小
+
+    Returns:
+        训练数据集,每个元素是 (input_ids, target_ids) 对
+    """
+    # 文本向量化;对于诗歌数据集来说,已经生成了固定数量的 token 序列了,不足的部分会 padding
+    ds = ds.map(tokenizer, num_parallel_calls=8)
+
+    # 构建输入和目标(偏移一位)
+    # 无需在这里添加结束标记,因为在 doc_load 中已经添加了结束标记
+    ds = ds.map(lambda x: (x[:-1], x[1:]))
+
+    # 重新设置批次大小并预取数据以提高性能
+    ds = ds.batch(batch_size).prefetch(8)
+
+    return ds
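Editor's note: the `(x[:-1], x[1:])` mapping above is the standard offset-by-one split for next-token prediction. A toy illustration (not part of the commit), assuming the tokenizer has already padded each poem to `sequence_length + 1` token ids:

```python
import tensorflow as tf

# Each padded row yields an input and a target shifted by one position,
# so the model learns to predict the next token (including padding/EOT).
tokens = tf.constant([5, 12, 7, 9, 3, 0, 0])  # "$" end-of-text then padding
inputs, targets = tokens[:-1], tokens[1:]
print(inputs.numpy())   # [ 5 12  7  9  3  0]
print(targets.numpy())  # [12  7  9  3  0  0]
```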
data/runner.py
ADDED
|
@@ -0,0 +1,142 @@
| 1 |
+
"""数据集 Runner 公共模块
|
| 2 |
+
|
| 3 |
+
提供通用的数据集测试和词汇表生成功能。
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
# 在各自 runner.py 中实例化
|
| 7 |
+
from data.runner import DatasetRunner
|
| 8 |
+
from data.poetry.dataset import PoetryDataset
|
| 9 |
+
from env.resolve import resolve, resolve_saved, resolve_env
|
| 10 |
+
|
| 11 |
+
dataset = PoetryDataset(
|
| 12 |
+
data_dir=str(resolve_env(resolve("data/dev/poetry"), resolve("~/data/Poetry/诗歌数据集"))),
|
| 13 |
+
vocab_path=str(resolve_env(resolve_saved("poetry/vocab.txt"), resolve("~/data/Poetry/vocabulary.txt"))),
|
| 14 |
+
sequence_length=100,
|
| 15 |
+
)
|
| 16 |
+
runner = DatasetRunner(dataset=dataset, name="poetry")
|
| 17 |
+
runner()
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from data.base import DataBundle
|
| 21 |
+
from data.common import build_vocab_from_dataset
|
| 22 |
+
from env.resolve import resolve_saved
|
| 23 |
+
from env.runner import ActionRunner
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DatasetRunner(ActionRunner):
|
| 27 |
+
"""数据集 Runner
|
| 28 |
+
|
| 29 |
+
提供通用的数据集测试和词汇表生成功能。
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
dataset: 数据集实例(PoetryDataset 或 WikiDataset)
|
| 33 |
+
name: 数据集英文名称(如 "poetry", "wiki")
|
| 34 |
+
max_docs: 测试时显示的文档数量,默认 5
|
| 35 |
+
max_samples: 测试时显示的 token 样本数量,默认 3
|
| 36 |
+
max_doc_chars: 文档显示的最大字符数,默认 200
|
| 37 |
+
max_text_display: token 文本显示的最大字符数,默认 80
|
| 38 |
+
|
| 39 |
+
Usage:
|
| 40 |
+
runner = DatasetRunner(dataset=poetry_dataset, name="poetry")
|
| 41 |
+
runner.test_dataset() # 或 runner.build_vocab()
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
# 中英文名称映射
|
| 45 |
+
NAME_MAP = {
|
| 46 |
+
"poetry": "诗歌",
|
| 47 |
+
"wiki": "Wiki",
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
dataset: DataBundle,
|
| 53 |
+
name: str,
|
| 54 |
+
max_docs: int = 5,
|
| 55 |
+
max_samples: int = 3,
|
| 56 |
+
max_doc_chars: int = 200,
|
| 57 |
+
max_text_display: int = 80,
|
| 58 |
+
):
|
| 59 |
+
self.dataset = dataset
|
| 60 |
+
self.name = name
|
| 61 |
+
self.display_name = self.NAME_MAP.get(name, name)
|
| 62 |
+
self.vocab_path = resolve_saved(f"vocab/{name}/vocab.txt")
|
| 63 |
+
self.max_docs = max_docs
|
| 64 |
+
self.max_samples = max_samples
|
| 65 |
+
self.max_doc_chars = max_doc_chars
|
| 66 |
+
self.max_text_display = max_text_display
|
| 67 |
+
|
| 68 |
+
def build_vocab(self) -> None:
|
| 69 |
+
"""生成字符词汇表"""
|
| 70 |
+
print(f"正在加载数据集...")
|
| 71 |
+
ds = self.dataset.doc_ds()
|
| 72 |
+
|
| 73 |
+
print(f"正在保存词汇表到: {self.vocab_path}")
|
| 74 |
+
vocab = build_vocab_from_dataset(ds, self.vocab_path)
|
| 75 |
+
|
| 76 |
+
print(f"词汇表大小: {len(vocab)}")
|
| 77 |
+
print("完成!")
|
| 78 |
+
|
| 79 |
+
def test_dataset(self) -> None:
|
| 80 |
+
"""测试数据集"""
|
| 81 |
+
print("\n" + "=" * 60)
|
| 82 |
+
print(f"{self.display_name} 数据集测试")
|
| 83 |
+
print("=" * 60)
|
| 84 |
+
|
| 85 |
+
self._view_documents(self.dataset.doc_ds())
|
| 86 |
+
self._view_tokens(self.dataset)
|
| 87 |
+
self._show_vocab_info(self.dataset.tokenizer_bundle())
|
| 88 |
+
|
| 89 |
+
print("\n" + "=" * 60)
|
| 90 |
+
print("测试完成")
|
| 91 |
+
print("=" * 60)
|
| 92 |
+
|
| 93 |
+
def _view_documents(self, doc_ds) -> None:
|
| 94 |
+
"""查看原始文档"""
|
| 95 |
+
print("\n【原始文档查看】")
|
| 96 |
+
print("-" * 60)
|
| 97 |
+
count = 0
|
| 98 |
+
for doc in doc_ds.take(self.max_docs):
|
| 99 |
+
count += 1
|
| 100 |
+
text = doc.numpy().decode("utf-8")
|
| 101 |
+
if len(text) > self.max_doc_chars:
|
| 102 |
+
text = text[: self.max_doc_chars] + "..."
|
| 103 |
+
print(f"\n第 {count} 个文档:")
|
| 104 |
+
print(f" {text}")
|
| 105 |
+
print(f"\n共显示 {count} 个文档")
|
| 106 |
+
|
| 107 |
+
def _view_tokens(self, dataset) -> None:
|
| 108 |
+
"""查看 tokenized 数据"""
|
| 109 |
+
print("\n【Tokenized 数据查看】")
|
| 110 |
+
print("-" * 60)
|
| 111 |
+
|
| 112 |
+
tokenizer_info = dataset.tokenizer_bundle()
|
| 113 |
+
tokens_ds = dataset.tokens_ds(seq_length=dataset.sequence_length, batch_size=1)
|
| 114 |
+
|
| 115 |
+
count = 0
|
| 116 |
+
for batch_input, batch_target in tokens_ds.take(self.max_samples):
|
| 117 |
+
count += 1
|
| 118 |
+
input_ids = batch_input[0].numpy()
|
| 119 |
+
target_ids = batch_target[0].numpy()
|
| 120 |
+
|
| 121 |
+
input_text = tokenizer_info.decode(input_ids.tolist())
|
| 122 |
+
target_text = tokenizer_info.decode(target_ids.tolist())
|
| 123 |
+
|
| 124 |
+
if len(input_text) > self.max_text_display:
|
| 125 |
+
input_text = input_text[: self.max_text_display] + "..."
|
| 126 |
+
if len(target_text) > self.max_text_display:
|
| 127 |
+
target_text = target_text[: self.max_text_display] + "..."
|
| 128 |
+
|
| 129 |
+
print(f"\n第 {count} 个样本:")
|
| 130 |
+
print(f" 输入 tokens: {input_ids[:20]}... (长度: {len(input_ids)})")
|
| 131 |
+
print(f" 目标 tokens: {target_ids[:20]}... (长度: {len(target_ids)})")
|
| 132 |
+
print(f" 输入文本: {input_text}")
|
| 133 |
+
print(f" 目标文本: {target_text}")
|
| 134 |
+
print(f"\n共显示 {count} 个样本")
|
| 135 |
+
|
| 136 |
+
@staticmethod
|
| 137 |
+
def _show_vocab_info(tokenizer_info) -> None:
|
| 138 |
+
"""显示词汇表信息"""
|
| 139 |
+
print("\n【词汇表信息】")
|
| 140 |
+
print("-" * 60)
|
| 141 |
+
print(f" 词汇表大小: {tokenizer_info.vocab_size}")
|
| 142 |
+
print(f" 结束标记 ID: {tokenizer_info.end_of_text}")
|
data/tokenizers.py
ADDED
|
@@ -0,0 +1,89 @@
| 1 |
+
"""
|
| 2 |
+
GPT模型的共享组件模块:
|
| 3 |
+
|
| 4 |
+
- 分词器
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import keras
|
| 8 |
+
import keras_hub
|
| 9 |
+
from keras import layers
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def sentence_piece():
|
| 13 |
+
# 用预训练好的分词器,也就是说我们不去自己训练分词器了
|
| 14 |
+
vocabulary_file = keras.utils.get_file(
|
| 15 |
+
origin="https://hf-mirror.com/mattdangerw/spiece/resolve/main/vocabulary.proto"
|
| 16 |
+
)
|
| 17 |
+
# [Note] 依然需要 tensorflow_text 包
|
| 18 |
+
tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(vocabulary_file)
|
| 19 |
+
|
| 20 |
+
end_of_text = tokenizer.token_to_id("<|endoftext|>")
|
| 21 |
+
|
| 22 |
+
def decode(tokens: list[int]) -> str:
|
| 23 |
+
return tokenizer.detokenize(tokens)
|
| 24 |
+
|
| 25 |
+
return tokenizer, end_of_text, decode
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def character_vectorization():
|
| 29 |
+
"""简单的字符级分词器,适用于测试"""
|
| 30 |
+
vectorizer = layers.TextVectorization(output_mode="int", split="character")
|
| 31 |
+
vectorizer.set_vocabulary(
|
| 32 |
+
list("abcdefghijklmnopqrstuvwxyz0123456789 .,!?;:()[]{}<>-_\n")
|
| 33 |
+
+ ["<|endoftext|>"] # 兼容 sentence_piece 分词器的特殊标记
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
vocab = vectorizer.get_vocabulary()
|
| 37 |
+
for idx, word in enumerate(vocab):
|
| 38 |
+
if word == "<|endoftext|>":
|
| 39 |
+
end_of_text = idx
|
| 40 |
+
break
|
| 41 |
+
else:
|
| 42 |
+
raise ValueError("Vocabulary does not contain <|endoftext|> token.")
|
| 43 |
+
|
| 44 |
+
def decode(tokens: list[int]) -> str:
|
| 45 |
+
words = [vocab[token] for token in tokens]
|
| 46 |
+
return "".join(words)
|
| 47 |
+
|
| 48 |
+
return vectorizer, end_of_text, decode
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def poetry_character_vectorization(
|
| 52 |
+
vocab_path: str = "local/saved/vocab/poetry/vocab.txt",
|
| 53 |
+
):
|
| 54 |
+
"""从文本文件加载诗歌字符级分词器。
|
| 55 |
+
|
| 56 |
+
词汇表文件格式:每行一个字符,第一行必须是 <|endoftext|>。
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
vocab_path: 词汇表文件路径,默认为 "local/saved/poetry/vocab.txt"
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
(vectorizer, end_of_text, decode): 分词器、结束标记ID、解码函数
|
| 63 |
+
"""
|
| 64 |
+
from env.resolve import resolve_path
|
| 65 |
+
|
| 66 |
+
# 读取词汇表
|
| 67 |
+
vocab_file = resolve_path(vocab_path)
|
| 68 |
+
with open(vocab_file, "r", encoding="utf-8") as f:
|
| 69 |
+
vocab = [line.rstrip("\n") for line in f]
|
| 70 |
+
|
| 71 |
+
# 创建 TextVectorization 层
|
| 72 |
+
vectorizer = layers.TextVectorization(
|
| 73 |
+
output_mode="int", split="character", standardize=None
|
| 74 |
+
)
|
| 75 |
+
vectorizer.set_vocabulary(vocab)
|
| 76 |
+
|
| 77 |
+
# 找到 end_of_text 的索引
|
| 78 |
+
for idx, word in enumerate(vocab):
|
| 79 |
+
if word == "<|endoftext|>":
|
| 80 |
+
end_of_text = idx
|
| 81 |
+
break
|
| 82 |
+
else:
|
| 83 |
+
raise ValueError("Vocabulary does not contain <|endoftext|> token.")
|
| 84 |
+
|
| 85 |
+
def decode(tokens: list[int]) -> str:
|
| 86 |
+
words = [vocab[token] for token in tokens]
|
| 87 |
+
return "".join(words)
|
| 88 |
+
|
| 89 |
+
return vectorizer, end_of_text, decode
|
data/wiki/__init__.py
ADDED
@@ -0,0 +1,8 @@
+"""Wiki 数据集模块
+
+导出 WikiDataset 类。
+"""
+
+from data.wiki.dataset import WikiDataset
+
+__all__ = ["WikiDataset"]
data/wiki/dataset.py
ADDED
|
@@ -0,0 +1,100 @@
| 1 |
+
"""Wiki 数据集主模块
|
| 2 |
+
|
| 3 |
+
实现 WikiDataset 类,继承自 DataBundle。
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pathlib
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
import tensorflow as tf
|
| 11 |
+
|
| 12 |
+
from data.base import DataBundle, TokenizerBundle
|
| 13 |
+
from data.wiki.loader import doc_load
|
| 14 |
+
from data.wiki.transformer import transform
|
| 15 |
+
from data.wiki.tokenizer import sentence_piece, character_vectorization
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class WikiDataset(DataBundle):
|
| 20 |
+
"""Wiki 数据集
|
| 21 |
+
|
| 22 |
+
将文档加载、分词、统计等功能绑定在一起的数据集类。
|
| 23 |
+
|
| 24 |
+
Usage:
|
| 25 |
+
dataset = WikiDataset(
|
| 26 |
+
data_dir="~/data/wiki/mini_c4",
|
| 27 |
+
tokenizer_type="sentence_piece" # 或 "character"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# 获取文档数据集
|
| 31 |
+
doc_ds = dataset.doc_ds()
|
| 32 |
+
|
| 33 |
+
# 获取 token 数据集
|
| 34 |
+
tokens_ds = dataset.tokens_ds(seq_length=256, batch_size=32)
|
| 35 |
+
|
| 36 |
+
# 打印统计信息
|
| 37 |
+
dataset.stat(seq_length=256)
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
glob_pattern: str = "*"
|
| 41 |
+
tokenizer_type: str = "sentence_piece"
|
| 42 |
+
|
| 43 |
+
_data_path: pathlib.Path = field(init=False, repr=False)
|
| 44 |
+
_tokenizer_bundle: Optional[TokenizerBundle] = field(
|
| 45 |
+
init=False, repr=False, default=None
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def __post_init__(self):
|
| 49 |
+
self._data_path = pathlib.Path(self.data_dir).expanduser()
|
| 50 |
+
|
| 51 |
+
def _load_tokenizer(self):
|
| 52 |
+
"""懒加载分词器"""
|
| 53 |
+
if self._tokenizer_bundle is None:
|
| 54 |
+
if self.tokenizer_type == "sentence_piece":
|
| 55 |
+
tokenizer, end_of_text, decode = sentence_piece()
|
| 56 |
+
elif self.tokenizer_type == "character":
|
| 57 |
+
tokenizer, end_of_text, decode = character_vectorization()
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
|
| 60 |
+
|
| 61 |
+
vocab_size = tokenizer.vocabulary_size()
|
| 62 |
+
self._tokenizer_bundle = TokenizerBundle(
|
| 63 |
+
tokenizer=tokenizer,
|
| 64 |
+
decode=decode,
|
| 65 |
+
end_of_text=end_of_text,
|
| 66 |
+
vocab_size=vocab_size
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
def doc_ds(self) -> tf.data.Dataset:
|
| 70 |
+
"""返回原始文档数据集
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
TensorFlow Dataset,每个元素是一个文档字符串
|
| 74 |
+
"""
|
| 75 |
+
return doc_load(self._data_path, glob_pattern=self.glob_pattern)
|
| 76 |
+
|
| 77 |
+
def tokens_ds(self, seq_length: int, batch_size: int) -> tf.data.Dataset:
|
| 78 |
+
"""返回 tokenized 数据集
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
seq_length: 序列长度
|
| 82 |
+
batch_size: 批次大小
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
TensorFlow Dataset,每个元素是 (input_ids, target_ids) 对
|
| 86 |
+
"""
|
| 87 |
+
self._load_tokenizer()
|
| 88 |
+
ds = self.doc_ds()
|
| 89 |
+
return transform(
|
| 90 |
+
ds=ds,
|
| 91 |
+
tokenizer=self._tokenizer_bundle.tokenizer,
|
| 92 |
+
end_of_text=self._tokenizer_bundle.end_of_text,
|
| 93 |
+
sequence_length=seq_length,
|
| 94 |
+
batch_size=batch_size,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
def tokenizer_bundle(self) -> TokenizerBundle:
|
| 98 |
+
"""返回分词器信息"""
|
| 99 |
+
self._load_tokenizer()
|
| 100 |
+
return self._tokenizer_bundle
|
data/wiki/loader.py
ADDED
|
@@ -0,0 +1,52 @@
| 1 |
+
"""Wiki 数据集文档加载模块
|
| 2 |
+
|
| 3 |
+
从 mini_c4 格式加载文档数据集。
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pathlib
|
| 7 |
+
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def doc_load(
|
| 12 |
+
data_dir: pathlib.Path, glob_pattern: str = "*", cycle_length: int = 32
|
| 13 |
+
) -> tf.data.Dataset:
|
| 14 |
+
"""加载并处理文档数据集为 TensorFlow Dataset。
|
| 15 |
+
|
| 16 |
+
递归查找指定目录下匹配 glob_pattern 的所有文件,使用 doc_extract 函数
|
| 17 |
+
将每个文件转换为 TensorFlow Dataset,然后使用 interleave 进行并行处理。
|
| 18 |
+
|
| 19 |
+
目录下的文件格式要求每行一个文档,其中的换行符使用 "\\n" 转义。
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
data_dir: 数据目录路径
|
| 23 |
+
glob_pattern: 文件匹配模式,如 "*.txt",默认为 "*" 匹配所有文件
|
| 24 |
+
cycle_length: interleave 的 cycle_length 参数,控制并行处理的文件数量,默认为 32
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
合并后的 TensorFlow Dataset,包含所有文件处理后的数据
|
| 28 |
+
"""
|
| 29 |
+
# 获取所有文件(过滤掉目录),递归查找子目录
|
| 30 |
+
files = [str(file) for file in data_dir.rglob(glob_pattern) if file.is_file()]
|
| 31 |
+
if not files:
|
| 32 |
+
raise FileNotFoundError(f"在目录 {data_dir} 中未找到匹配 {glob_pattern} 的文件")
|
| 33 |
+
|
| 34 |
+
# 排序文件列表以确保一致的处理顺序
|
| 35 |
+
files = sorted(files)
|
| 36 |
+
|
| 37 |
+
# 创建数据集管道
|
| 38 |
+
ds = tf.data.Dataset.from_tensor_slices(files)
|
| 39 |
+
ds = ds.interleave(
|
| 40 |
+
_line_doc_extract,
|
| 41 |
+
cycle_length=cycle_length,
|
| 42 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
return ds
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _line_doc_extract(path: str) -> tf.data.Dataset:
|
| 49 |
+
"""Mini-c4 format: one document per line."""
|
| 50 |
+
return tf.data.TextLineDataset(path).map(
|
| 51 |
+
lambda x: tf.strings.regex_replace(x, r"\\n", "\n")
|
| 52 |
+
)
|
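The escaped one-document-per-line format described above can be exercised end to end; the snippet below is a minimal sketch (the temporary directory and sample text are made up for illustration):

import pathlib
import tempfile

from data.wiki.loader import doc_load

# One escaped document per line, exactly as the loader expects.
tmp_dir = pathlib.Path(tempfile.mkdtemp())
(tmp_dir / "sample.txt").write_text("Title A\\nBody of page A.\nTitle B\\nBody of page B.\n")

ds = doc_load(tmp_dir, glob_pattern="*.txt")
for doc in ds.take(2):
    # The "\n" escapes are restored to real newlines by _line_doc_extract.
    print(doc.numpy().decode("utf-8"))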
data/wiki/runner.py
ADDED
@@ -0,0 +1,32 @@
+"""Runner for the wiki dataset.
+
+Usage:
+    python data/wiki/runner.py test_dataset
+    ENV=production python data/wiki/runner.py test_dataset
+"""
+
+import pathlib
+import sys
+
+sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
+
+from data.runner import DatasetRunner
+from data.wiki.dataset import WikiDataset
+from env.resolve import resolve_path, resolve_env
+
+
+dataset = WikiDataset(
+    data_dir=str(
+        resolve_env(resolve_path("data/dev/mini_c4"), resolve_path("~/data/wiki/mini_c4"))
+    ),
+    tokenizer_type=resolve_env("character", "sentence_piece"),
+    sequence_length=256,
+)
+
+runner = DatasetRunner(
+    dataset=dataset,
+    name="wiki",
+)
+
+if __name__ == "__main__":
+    runner()
data/wiki/tokenizer.py
ADDED
@@ -0,0 +1,60 @@
+"""Tokenizer module for the wiki dataset.
+
+Provides the tokenizer implementations used by the wiki dataset.
+"""
+
+import keras
+import keras_hub
+from keras import layers
+
+
+def sentence_piece():
+    """SentencePiece tokenizer.
+
+    Uses a pretrained tokenizer, so no tokenizer training of our own is needed.
+
+    Returns:
+        (tokenizer, end_of_text, decode): the tokenizer, the end-of-text token id, and the decode function
+    """
+    # Use a pretrained tokenizer; in other words we do not train our own.
+    vocabulary_file = keras.utils.get_file(
+        origin="https://hf-mirror.com/mattdangerw/spiece/resolve/main/vocabulary.proto"
+    )
+    # [Note] the tensorflow_text package is still required
+    tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(vocabulary_file)
+
+    end_of_text = tokenizer.token_to_id("<|endoftext|>")
+
+    def decode(tokens: list[int]) -> str:
+        return tokenizer.detokenize(tokens)
+
+    return tokenizer, end_of_text, decode
+
+
+def character_vectorization():
+    """Character-level tokenizer.
+
+    A simple character-level tokenizer, suitable for testing.
+
+    Returns:
+        (tokenizer, end_of_text, decode): the tokenizer, the end-of-text token id, and the decode function
+    """
+    vectorizer = layers.TextVectorization(output_mode="int", split="character")
+    vectorizer.set_vocabulary(
+        list("abcdefghijklmnopqrstuvwxyz0123456789 .,!?;:()[]{}<>-_\n")
+        + ["<|endoftext|>"]  # special token kept for compatibility with the sentence_piece tokenizer
+    )
+
+    vocab = vectorizer.get_vocabulary()
+    for idx, word in enumerate(vocab):
+        if word == "<|endoftext|>":
+            end_of_text = idx
+            break
+    else:
+        raise ValueError("Vocabulary does not contain <|endoftext|> token.")
+
+    def decode(tokens: list[int]) -> str:
+        words = [vocab[token] for token in tokens]
+        return "".join(words)
+
+    return vectorizer, end_of_text, decode
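A quick sanity check of the character-level tokenizer, as a hedged sketch (the sample string is arbitrary):

from data.wiki.tokenizer import character_vectorization

tokenizer, end_of_text, decode = character_vectorization()
ids = tokenizer(["hello, wiki!"])       # TextVectorization expects a batch of strings
print(ids.numpy()[0])                   # integer ids; 0 is reserved for padding, 1 for OOV
print(decode(list(ids.numpy()[0])))     # round-trips back to "hello, wiki!"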
data/wiki/transformer.py
ADDED
@@ -0,0 +1,47 @@
+"""Token conversion module for the wiki dataset.
+
+Converts the document dataset into token sequences for training.
+"""
+
+from typing import Callable
+
+import numpy as np
+import tensorflow as tf
+
+
+def transform(
+    ds: tf.data.Dataset,
+    tokenizer: Callable,
+    end_of_text: int,
+    sequence_length: int,
+    batch_size: int,
+) -> tf.data.Dataset:
+    """Convert the document dataset into a training dataset.
+
+    Converts documents into token ids, appends the end-of-text marker, and splits
+    everything into fixed-length sequences.
+
+    Args:
+        ds: document dataset
+        tokenizer: tokenizer function
+        end_of_text: token id of the end-of-text marker
+        sequence_length: sequence length
+        batch_size: batch size
+
+    Returns:
+        The training dataset, whose elements are (input_ids, target_ids) pairs.
+    """
+    ds = ds.map(tokenizer, num_parallel_calls=8)
+
+    # Separate documents from each other with the end_of_text marker
+    ds = ds.map(lambda x: tf.concat([x, np.array([end_of_text])], -1))
+
+    # Re-chunk the samples into fixed-length sequences
+    ds = ds.rebatch(sequence_length + 1, drop_remainder=True)
+
+    # Build inputs and targets (shifted by one position)
+    ds = ds.map(lambda x: (x[:-1], x[1:]))
+
+    # Re-batch and prefetch to improve throughput
+    ds = ds.batch(batch_size).prefetch(8)
+
+    return ds
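A toy illustration of the shift-by-one target construction performed above (the token ids are made up; 99 stands in for end_of_text):

import tensorflow as tf

chunk = tf.constant([5, 6, 7, 8, 99])    # one rebatched window of length sequence_length + 1
inputs, targets = chunk[:-1], chunk[1:]
print(inputs.numpy())    # tokens 5, 6, 7, 8
print(targets.numpy())   # the same tokens shifted left, ending in the end_of_text id 99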
data/wiki/wiki_cleaner.py
ADDED
@@ -0,0 +1,122 @@
+"""
+Wiki text cleaning module.
+
+Provides a set of filters for cleaning wiki-formatted text data.
+"""
+
+import re
+
+
+def filter_single_line(text: str) -> str | None:
+    """
+    Drop entries that only have a single line (usually redirect pages).
+
+    Args:
+        text: input text
+
+    Returns:
+        None if the text has only one line, otherwise the original text.
+    """
+    lines = [line for line in text.split("\n") if line.strip()]
+    if len(lines) <= 1:
+        return None
+    return text
+
+
+def filter_empty_brackets(text: str) -> str:
+    """
+    Remove empty bracket pairs from the text.
+
+    For example: (), (), ( ), ( ), [ ], 【 】, { } and so on.
+
+    Args:
+        text: input text
+
+    Returns:
+        The text with empty bracket pairs removed.
+    """
+    # Match empty bracket pairs such as () () [] 【】 {}, possibly with whitespace inside
+    pattern = re.compile(r"[\(\)()\[\]【】{}]\s*[\(\)()\[\]【】{}]")
+    return pattern.sub("", text)
+
+
+def filter_html_tags(text: str) -> str:
+    """
+    Remove HTML/XML tags left in HTML-entity-encoded form.
+
+    For example: &lt;templatestyles src="ShareCSS/infobox.css" /&gt;
+
+    Args:
+        text: input text
+
+    Returns:
+        The text with HTML tags removed.
+    """
+    # Match entity-encoded tags of the form &lt;...&gt;
+    pattern = re.compile(r"&lt;[^&]+&gt;")
+    return pattern.sub("", text)
+
+
+def filter_lang_tags(text: str) -> str:
+    """
+    Remove the special language-conversion markers (nesting supported).
+
+    For example: -{H|zh-hans:重定向;zh-hant:重新导向;}-
+    Nested example: -{T|zh:-{zh|}-;zh-hans:-{zh-hans|}-;}-
+
+    Args:
+        text: input text
+
+    Returns:
+        The text with language-conversion markers removed.
+    """
+    # Non-greedy matching, applied in a loop to handle nesting
+    pattern = re.compile(r"-\{[^{}]+?}-")
+    while True:
+        new_text = pattern.sub("", text)
+        if new_text == text:  # no more matches
+            break
+        text = new_text
+    return text
+
+
+def clean(text: str) -> str | None:
+    """
+    Apply all filters to clean the text.
+
+    Filtering order:
+    1. single-line check (redirect pages)
+    2. HTML tags
+    3. empty bracket pairs
+    4. language-conversion markers
+    5. final emptiness check
+
+    Args:
+        text: input text
+
+    Returns:
+        The cleaned text, or None if the entry should be discarded.
+    """
+    # 1. Single-line check
+    result = filter_single_line(text)
+    if result is None:
+        return None
+
+    # 2. Remove HTML tags
+    result = filter_html_tags(result)
+
+    # 3. Remove empty bracket pairs
+    result = filter_empty_brackets(result)
+
+    # 4. Remove language-conversion markers
+    result = filter_lang_tags(result)
+
+    # 5. Collapse runs of blank lines into a single blank line
+    result = re.sub(r"\n\s*\n", "\n\n", result)
+    result = result.strip()
+
+    # 6. Final check: return None if the result is empty or whitespace only
+    if not result.strip():
+        return None
+
+    return result
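A small usage sketch of the combined cleaner (the sample page text is invented):

from data.wiki.wiki_cleaner import clean

raw = "数学\n数学是研究数量、结构与空间等概念的学科-{zh-hans:例子;zh-hant:例子;}-()\n"
# Conversion markers and empty brackets are removed; a None result would mean "drop this page".
print(clean(raw))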
docs/TODOs.md
ADDED
@@ -0,0 +1,3 @@
+- [ ] The `<doc>` format cannot be implemented yet because of computation-graph limitations; planned for later.
+- [ ] Find a way to capture data during training via a Callback or a custom train_step.
+- [ ] After wiki training the model cannot answer factual questions; it looks like overfitting, so try raising dropout to 0.5 (currently 0.1).
docs/pycharm.md
ADDED
@@ -0,0 +1,10 @@
+# PyCharm Development Guide
+
+Recently I tried connecting this project's code in PyCharm to a remote server and ran into some puzzling problems. The workarounds are recorded here for future reference.
+
+In short, applying the existing remote environment directly kept failing in all sorts of ways; only after rebuilding a completely fresh environment did everything work. The steps were:
+
+1. Create a new conda environment on the remote server.
+2. Create a new Python project (I simply moved my project directory and deleted folders such as .idea, .ruff_cache and .pytest_cache).
+3. In local PyCharm, configure a remote Python interpreter that points to the new environment on the server. In this step, take care to set up the directory mappings and to disable automatic upload.
+4. Wait a while, and everything should work normally.
env/keras.py
ADDED
@@ -0,0 +1,11 @@
+"""Keras-related utility module.
+
+Provides Keras configuration helpers.
+"""
+
+import keras
+
+
+def enable_mixed_precision():
+    """Enable mixed-precision training/inference."""
+    keras.config.set_dtype_policy("mixed_float16")
env/logger.py
ADDED
@@ -0,0 +1,52 @@
+import logging
+from functools import wraps
+
+
+def get_logger(name: str, filepath: str = None):
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+
+    # console handler
+    console_handler = logging.StreamHandler()
+    logger.addHandler(console_handler)
+
+    # file handler
+    if filepath:
+        file_handler = logging.FileHandler(filepath)
+        logger.addHandler(file_handler)
+
+    return logger
+
+
+def log(enter_message: str = "", exit_message: str = ""):
+    return _Log(enter_message=enter_message, exit_message=exit_message)
+
+
+class _Log:
+    def __init__(
+        self,
+        enter_message: str = "",
+        exit_message: str = ""
+    ):
+        self.enter_message = enter_message
+        self.exit_message = exit_message
+
+    def __enter__(self):
+        if self.enter_message:
+            print(self.enter_message)
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        if self.exit_message:
+            print(self.exit_message)
+            print("")
+        return False
+
+    def __call__(self, func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            with _Log(self.enter_message, self.exit_message):
+                return func(*args, **kwargs)
+            return None
+
+        return wrapper
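The helper above is usable both as a context manager and as a decorator; a minimal sketch (the function name is illustrative):

from env.logger import log

@log(enter_message="training started", exit_message="training finished")
def train_once():
    ...

with log(enter_message="loading data"):
    pass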
env/resolve.py
ADDED
@@ -0,0 +1,85 @@
+import os
+from enum import StrEnum
+from pathlib import Path
+
+
+"""Project root path"""
+PROJECT_ROOT = Path(__file__).parent.parent.absolute()
+
+
+"""Helper that picks a configuration value based on the ENV environment variable"""
+class Env(StrEnum):
+    TEST = "test"
+    PRODUCTION = "production"
+
+def resolve_env[T](test_conf: T = Env.TEST, prod_conf: T = Env.PRODUCTION) -> T:
+    env = os.environ.get("ENV", str(Env.TEST))
+    return prod_conf if env == str(Env.PRODUCTION) else test_conf
+
+
+"""Preset directories"""
+SAVED_DIR = resolve_env(
+    PROJECT_ROOT / "local" / "saved",
+    PROJECT_ROOT / "saved",
+)
+TASKS_DIR = PROJECT_ROOT / "local" / "tasks"
+
+
+"""Path resolution helpers used across the project"""
+def resolve_saved(path: str | Path = None) -> Path:
+    """Resolve a path relative to the saved directory.
+
+    1. If the value is already a Path object, return it unchanged.
+    2. If path is None, return the saved directory itself.
+    3. Otherwise, resolve path relative to the saved directory.
+    """
+    if isinstance(path, Path):
+        return path
+    return SAVED_DIR / path if path else SAVED_DIR
+
+
+def resolve_task_dir(task_name: str) -> Path:
+    """Resolve the directory a task lives in.
+
+    Args:
+        task_name: the task name, i.e. the name field defined in the Pipeline, such as "poetry_gpt" or "poetry_rnn".
+    """
+    return TASKS_DIR / task_name
+
+
+def resolve_path(path: str | Path) -> Path:
+    """Resolve a path from the project root.
+
+    1. If the value is already a Path object, return it unchanged.
+    2. If the path is an absolute path starting with ~ or /, return it directly.
+    3. If the path is relative, resolve it relative to the project root.
+
+    Args:
+        path: path relative to the project root
+
+    Returns:
+        The resolved absolute path.
+
+    Example:
+        >>> resolve_path("data/dev/mini_c4/file.txt")
+        PosixPath('/Users/.../universal_deeplearning/data/dev/mini_c4/file.txt')
+    """
+    if isinstance(path, Path):
+        return path
+    elif path.startswith("~") or path.startswith("/"):
+        return Path(path).expanduser().resolve()
+    else:
+        return PROJECT_ROOT / path
+
+
+def display_path(path: str | Path) -> str:
+    """Convert a path into a string suitable for display.
+
+    If the path is inside the project root it is shown relative to the project root;
+    otherwise the absolute path is shown.
+    """
+    resolved = resolve_path(path)
+    try:
+        return str(resolved.relative_to(PROJECT_ROOT))
+    except ValueError:
+        return str(resolved)
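A minimal sketch of how the ENV switch is typically used (the remote path is just an example):

from env.resolve import resolve_env, resolve_path

# With ENV unset or ENV=test this picks the first argument; with ENV=production the second.
data_dir = resolve_env(resolve_path("data/dev/poetry"), resolve_path("~/data/poetry"))
print(data_dir)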
env/runner.py
ADDED
@@ -0,0 +1,23 @@
+import sys
+from typing import Callable
+
+
+class ActionRunner:
+    def __call__(self, default_method: str | Callable = None):
+        if len(sys.argv) > 1:
+            method = self._resolve_method(sys.argv[1])
+        else:
+            method = default_method
+        if isinstance(method, str):
+            method = self._resolve_method(method)
+
+        if method:
+            method()
+        else:
+            raise ValueError("No method to run was specified")
+
+    def _resolve_method(self, method_name: str) -> Callable:
+        method = getattr(self, method_name, None)
+        if method is None:
+            raise ValueError(f"Method not found: {method_name}")
+        return method
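ActionRunner dispatches a method name given on the command line; the subclass below is a hypothetical illustration, where `python tools.py greet` would call greet():

from env.runner import ActionRunner

class MyRunner(ActionRunner):
    def greet(self):
        print("hello")

if __name__ == "__main__":
    # Falls back to greet() when no command-line argument is given.
    MyRunner()(default_method="greet")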
env/vocab.py
ADDED
@@ -0,0 +1,2 @@
+# The PAD token id shared by every vocabulary
+PAD = 0
environment-linux.yml
ADDED
@@ -0,0 +1,15 @@
+name: general-dl
+channels:
+  - defaults
+dependencies:
+  - python=3.12
+  - pip
+  - numpy
+  - tensorflow
+  - tensorflow-text
+  - keras
+  - pip:
+    - keras-hub
+    - gradio
+variables:
+  ENV: production
environment.yml
ADDED
@@ -0,0 +1,17 @@
+name: general-dl
+channels:
+  - defaults
+dependencies:
+  - python=3.12
+  - pip
+  - setuptools>=68,<70
+  - numpy
+  - ruff
+  - pytest
+  - pytest-mock
+  - tensorflow
+  - keras
+  - pip:
+    - gradio
+variables:
+  ENV: test
generate_requirements.py
ADDED
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+"""
+Generate requirements.txt from environment-linux.yml.
+Version numbers given in the YAML take highest priority.
+When no version is specified, the version installed in the current environment is used.
+python and pip are excluded.
+"""
+
+import os
+from datetime import datetime
+import yaml
+from importlib.metadata import version, PackageNotFoundError
+
+
+# Packages to exclude (not written to requirements.txt)
+EXCLUDE_PACKAGES = {"python", "pip"}
+
+
+def get_installed_version(package_name):
+    """Return the installed version of a package, or None if it is not installed."""
+    try:
+        return version(package_name)
+    except PackageNotFoundError:
+        return None
+
+
+def parse_package_string(dep):
+    """
+    Parse a package string and return (package name, yaml version or None).
+    For example: "tensorflow=2.15.0" -> ("tensorflow", "2.15.0")
+                 "numpy" -> ("numpy", None)
+    """
+    if "=" in dep:
+        parts = dep.split("=")
+        pkg_name = parts[0]
+        pkg_version = parts[1]
+        return pkg_name, pkg_version
+    else:
+        return dep, None
+
+
+def parse_environment_yml(filepath):
+    """Parse environment-linux.yml and extract the package list with version info."""
+    with open(filepath, "r") as f:
+        env = yaml.safe_load(f)
+
+    packages = []
+
+    for dep in env.get("dependencies", []):
+        if isinstance(dep, str):
+            # plain string form: package or package=version
+            pkg_name, yaml_version = parse_package_string(dep)
+            if pkg_name not in EXCLUDE_PACKAGES:
+                packages.append((pkg_name, yaml_version))
+        elif isinstance(dep, dict) and "pip" in dep:
+            # pip sub-list
+            for pip_dep in dep["pip"]:
+                pkg_name, yaml_version = parse_package_string(pip_dep)
+                if pkg_name not in EXCLUDE_PACKAGES:
+                    packages.append((pkg_name, yaml_version))
+
+    return packages
+
+
+def main():
+    yml_file = "environment-linux.yml"
+    output_file = "requirements.txt"
+
+    print(f"Reading {yml_file}...")
+    packages = parse_environment_yml(yml_file)
+    print(f"Found {len(packages)} packages (excluding {EXCLUDE_PACKAGES})")
+
+    lines = []
+    for pkg_name, yaml_version in packages:
+        if yaml_version:
+            # the YAML gives a version, so it takes priority
+            lines.append(f"{pkg_name}=={yaml_version}")
+            print(f"  ✓ {pkg_name}=={yaml_version} (from YAML)")
+        else:
+            # no version in the YAML, query the current environment
+            env_version = get_installed_version(pkg_name)
+            if env_version:
+                lines.append(f"{pkg_name}=={env_version}")
+                print(f"  ✓ {pkg_name}=={env_version} (from current environment)")
+            else:
+                lines.append(pkg_name)
+                print(f"  ⚠ {pkg_name} (not installed, no version)")
+
+    # header comments
+    header_lines = [
+        f"# Generated from {yml_file}",
+        f"# Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
+        f"# Environment: {os.environ.get('ENV', 'unknown')}",
+        "#",
+    ]
+
+    # merge all lines
+    all_lines = header_lines + lines
+
+    with open(output_file, "w") as f:
+        f.write("\n".join(all_lines) + "\n")
+
+    print(f"\nGenerated {output_file}:")
+    print("-" * 40)
+    print("\n".join(all_lines))
+    print("-" * 40)
+
+
+if __name__ == "__main__":
+    main()
models/__init__.py
ADDED
File without changes
models/mini_gpt/__init__.py
ADDED
@@ -0,0 +1 @@
+from models.mini_gpt.model_builder import GptModelBuilder
models/mini_gpt/gpt_components.py
ADDED
@@ -0,0 +1,61 @@
+"""
+Shared components of the GPT model:
+
+- Positional Encoding
+- Transformer Decoder
+"""
+
+import keras
+from keras import layers, ops
+
+class PositionalEmbedding(keras.Layer):
+    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.token_embeddings = layers.Embedding(input_dim, output_dim)
+        self.position_embeddings = layers.Embedding(sequence_length, output_dim)
+
+    def call(self, inputs, reverse=False):
+        if reverse:
+            token_embeddings = self.token_embeddings.embeddings
+            return ops.matmul(inputs, ops.transpose(token_embeddings))
+        positions = ops.cumsum(ops.ones_like(inputs), axis=-1) - 1
+        embedded_tokens = self.token_embeddings(inputs)
+        embedded_positions = self.position_embeddings(positions)
+        return embedded_tokens + embedded_positions
+
+
+class TransformerDecoder(keras.Layer):
+    def __init__(self, hidden_dim, intermediate_dim, num_heads, **kwargs):
+        super().__init__(**kwargs)
+
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+
+        key_dim = hidden_dim // num_heads
+
+        # self-attention block
+        self.self_attention = layers.MultiHeadAttention(num_heads, key_dim, dropout=0.1)
+        self.self_attention_layernorm = layers.LayerNormalization()
+
+        # feed-forward block
+        self.feed_forward_1 = layers.Dense(intermediate_dim, activation="relu")
+        self.feed_forward_2 = layers.Dense(hidden_dim)
+        self.feed_forward_layernorm = layers.LayerNormalization()
+        self.dropout = layers.Dropout(0.1)
+
+    def call(self, inputs):
+        # self-attention computation
+        residual = x = inputs
+        x = self.self_attention(query=x, key=x, value=x, use_causal_mask=True)
+        x = self.dropout(x)
+        x = x + residual
+        x = self.self_attention_layernorm(x)
+
+        # feed-forward computation
+        residual = x
+        x = self.feed_forward_1(x)
+        x = self.feed_forward_2(x)
+        x = self.dropout(x)
+        x = x + residual
+        x = self.feed_forward_layernorm(x)
+        return x
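A quick shape check of the two components above, as a hedged sketch (the dimensions and batch are arbitrary):

import numpy as np
from models.mini_gpt.gpt_components import PositionalEmbedding, TransformerDecoder

embed = PositionalEmbedding(sequence_length=128, input_dim=1000, output_dim=64)
decoder = TransformerDecoder(hidden_dim=64, intermediate_dim=256, num_heads=4)

tokens = np.random.randint(1, 1000, size=(2, 16))   # batch of 2 sequences, 16 tokens each
x = embed(tokens)                                    # token + position embeddings, shape (2, 16, 64)
y = decoder(x)                                       # causally masked self-attention, shape (2, 16, 64)
print(y.shape)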
models/mini_gpt/model_builder.py
ADDED
@@ -0,0 +1,54 @@
+from dataclasses import dataclass
+from functools import partial
+
+import keras
+from keras import layers
+
+from models.mini_gpt.gpt_components import PositionalEmbedding, TransformerDecoder
+from pipeline.base.generation import generate_with_training_model
+from pipeline.base.model_builder import ModelArtifact
+
+
+@dataclass
+class GptModelBuilder:
+    hidden_dim: int
+    intermediate_dim: int
+    num_heads: int
+    num_layers: int
+
+    def build_training_artifact(
+        self,
+        vocab_size: int,
+        sequence_length: int
+    ) -> ModelArtifact:
+        inputs = keras.Input(shape=(None,), dtype="int32", name="inputs")
+        embedding = PositionalEmbedding(
+            sequence_length,
+            vocab_size,
+            self.hidden_dim,
+            name="embedding"
+        )
+        x = embedding(inputs)
+        x = layers.LayerNormalization(name="input_layer_norm")(x)
+
+        for i in range(self.num_layers):
+            decoder = TransformerDecoder(
+                self.hidden_dim,
+                self.intermediate_dim,
+                self.num_heads,
+                name=f"decoder_{i}"
+            )
+            x = decoder(x)
+
+        outputs = embedding(x, reverse=True)
+        model = keras.Model(inputs, outputs, name="mini_gpt")
+        return ModelArtifact(
+            model=model,
+            generate=partial(generate_with_training_model, model)
+        )
+
+    def build_inference_artifact(
+        self,
+        training_artifact: ModelArtifact
+    ) -> ModelArtifact:
+        return training_artifact
models/rnn/__init__.py
ADDED
@@ -0,0 +1 @@
+from models.rnn.model_builder import RNNModelBuilder
models/rnn/model_builder.py
ADDED
@@ -0,0 +1,114 @@
+from dataclasses import dataclass
+from functools import partial
+
+import keras
+import tensorflow as tf
+from keras import layers
+
+from pipeline.base.generation import generate_with_stateful_model, generate_with_training_model
+from pipeline.base.model_builder import ModelArtifact
+
+
+@dataclass
+class RNNModelBuilder:
+    num_layers: int = 2
+    embedding_dim: int = 100
+    hidden_dim: int = 1024
+
+    def build_training_artifact(
+        self,
+        vocab_size: int,
+        sequence_length: int
+    ) -> ModelArtifact:
+        inputs = keras.Input(shape=(None,), dtype="int32", name="inputs")
+        x = layers.Embedding(
+            input_dim=vocab_size,
+            output_dim=self.embedding_dim,
+            mask_zero=True,
+            name="embedding"
+        )(inputs)
+
+        for i in range(self.num_layers):
+            x = layers.LSTM(
+                self.hidden_dim,
+                return_sequences=True,
+                recurrent_dropout=0.1,
+                name=f"lstm_{i}"
+            )(x)
+            x = layers.Dropout(0.1, name=f"dropout_{i}")(x)
+
+        outputs = layers.Dense(vocab_size, name="logits")(x)
+        model = keras.Model(inputs=inputs, outputs=outputs, name="rnn_training")
+        return ModelArtifact(
+            model=model,
+            generate=partial(generate_with_training_model, model)
+        )
+
+    def build_inference_artifact(
+        self,
+        training_artifact: ModelArtifact
+    ) -> ModelArtifact:
+        inference_model = self._build_inference_model_from_training_model(
+            training_artifact.model
+        )
+        return ModelArtifact(
+            model=inference_model,
+            generate=partial(
+                generate_with_stateful_model,
+                inference_model,
+                initial_states=self._initial_states(batch_size=1)
+            )
+        )
+
+    def _build_inference_model_from_training_model(
+        self,
+        training_model: keras.Model
+    ) -> keras.Model:
+        token_input = keras.Input(shape=(None,), dtype="int32", name="token_input")
+        state_inputs = []
+        for i in range(self.num_layers):
+            h_input = keras.Input(shape=(self.hidden_dim,), name=f"h_{i}_input")
+            c_input = keras.Input(shape=(self.hidden_dim,), name=f"c_{i}_input")
+            state_inputs.extend([h_input, c_input])
+
+        embedding = training_model.get_layer("embedding")
+        logits_layer = training_model.get_layer("logits")
+        x = embedding(token_input)
+
+        new_states = []
+        inference_lstm_layers = []
+        for i in range(self.num_layers):
+            inference_lstm = layers.LSTM(
+                self.hidden_dim,
+                return_sequences=i < self.num_layers - 1,
+                return_state=True,
+                recurrent_dropout=0.1,
+                name=f"lstm_{i}"
+            )
+            h_input = state_inputs[i * 2]
+            c_input = state_inputs[i * 2 + 1]
+            x, new_h, new_c = inference_lstm(x, initial_state=[h_input, c_input])
+            new_states.extend([new_h, new_c])
+            dropout = training_model.get_layer(f"dropout_{i}")
+            x = dropout(x)
+            inference_lstm_layers.append(inference_lstm)
+
+        logits = logits_layer(x)
+        inference_model = keras.Model(
+            [token_input] + state_inputs,
+            [logits] + new_states,
+            name="rnn_inference"
+        )
+
+        for i, inference_lstm in enumerate(inference_lstm_layers):
+            training_lstm = training_model.get_layer(f"lstm_{i}")
+            inference_lstm.set_weights(training_lstm.get_weights())
+
+        return inference_model
+
+    def _initial_states(self, batch_size: int) -> list:
+        states = []
+        for _ in range(self.num_layers):
+            states.append(tf.zeros((batch_size, self.hidden_dim)))
+            states.append(tf.zeros((batch_size, self.hidden_dim)))
+        return states
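The conversion above rebuilds the LSTM stack with return_state=True and copies the trained weights, so generation can feed one token at a time while carrying the hidden state forward. A hedged usage sketch (the vocabulary size and sequence length are arbitrary):

from models.rnn import RNNModelBuilder

builder = RNNModelBuilder(num_layers=2, embedding_dim=100, hidden_dim=1024)
training = builder.build_training_artifact(vocab_size=5000, sequence_length=128)
inference = builder.build_inference_artifact(training)

# The inference model takes [tokens, h_0, c_0, h_1, c_1] and returns [logits, new states...].
print([t.shape for t in inference.model.inputs])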
pipeline/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .runner import PipelineRunner
+from .pipeline import Pipeline
+from .base.configs import CheckpointConfig
pipeline/base/__init__.py
ADDED
File without changes
pipeline/base/checkpoint.py
ADDED
@@ -0,0 +1,147 @@
+"""
+Model utility module.
+
+Contains shared functionality such as model building and checkpoint management.
+"""
+
+import pathlib
+import re
+import warnings
+
+from env.resolve import resolve_path
+
+
+def extract_number_of_filename(filename: str) -> int:
+    """
+    Extract a number from a filename, wherever the number appears.
+
+    For example:
+    - "model_epoch_001.weights.h5" -> 1
+    - "checkpoint_2024_06_30_epoch_002.weights.h5" -> 2
+    - "model_epoch_final.weights.h5" -> raises an exception
+
+    :param filename: a filename string that contains a number
+    :return: the extracted number; raises ValueError if the filename contains no digits
+    """
+    numbers = re.findall(r"\d+", filename)
+    if numbers:
+        return int(numbers[-1])  # return the last number, assuming it is the epoch/generation
+    else:
+        raise ValueError(f"No number found in filename: {filename}")
+
+
+def resolve_checkpoint(
+    dirs: list[pathlib.Path | str] | None = None,
+    path: pathlib.Path | str | None = None,
+    epoch: int | None = None,
+    suffix: str | None = None
+):
+    """Resolve a model checkpoint path in a unified way.
+
+    Supports either pointing directly at a checkpoint file or searching the given directories.
+
+    Args:
+        dirs: list of checkpoint directories
+        path: an explicitly given checkpoint file path (absolute or relative)
+        epoch: the epoch whose .weights.h5 file should be looked up
+        suffix: required checkpoint file suffix
+
+    Returns:
+        (resolved_path, epoch): the absolute path and the epoch number
+
+    Raises:
+        FileNotFoundError: when the given path does not exist or no checkpoint file is found
+        ValueError: when the arguments are invalid
+    """
+    resolved_dirs = _resolve_checkpoint_dirs(dirs)
+
+    if path is not None:
+        path = pathlib.Path(path)
+
+        if not path.is_absolute():
+            if not resolved_dirs:
+                raise ValueError("dirs must be provided when path is a relative path")
+            path = _resolve_relative_checkpoint_path(path, resolved_dirs)
+        else:
+            if dirs is not None:
+                warnings.warn(
+                    "Warning: path is an absolute path, the dirs argument will be ignored",
+                    UserWarning
+                )
+
+        if not path.exists():
+            raise FileNotFoundError(f"Checkpoint file does not exist: {path}")
+        if suffix is not None and not path.name.endswith(suffix):
+            raise FileNotFoundError(f"Checkpoint file suffix does not match: {path}")
+
+        try:
+            epoch_num = extract_number_of_filename(path.stem)
+        except ValueError:
+            epoch_num = 0
+
+        return path, epoch_num
+
+    if not resolved_dirs:
+        raise ValueError("Either dirs or path must be provided")
+
+    files_with_number = _collect_checkpoint_files(
+        checkpoint_dirs=resolved_dirs,
+        suffix=suffix
+    )
+
+    if epoch is not None:
+        matches = [(f, num) for f, num in files_with_number if num == epoch]
+        if not matches:
+            raise FileNotFoundError(f"No checkpoint file found for epoch {epoch}")
+        if len(matches) > 1:
+            raise RuntimeError(
+                f"Multiple checkpoint files found for epoch {epoch}: {[match[0].name for match in matches]}"
+            )
+        return matches[0]
+
+    if not files_with_number:
+        return None, 0
+
+    return max(files_with_number, key=lambda item: item[1])
+
+
+def _resolve_checkpoint_dirs(
+    dirs: list[pathlib.Path | str] | None
+) -> list[pathlib.Path]:
+    if dirs is None:
+        return []
+    return [resolve_path(path) for path in dirs]
+
+
+def _resolve_relative_checkpoint_path(
+    checkpoint_path: pathlib.Path,
+    checkpoint_dirs: list[pathlib.Path]
+) -> pathlib.Path:
+    for checkpoint_dir in checkpoint_dirs:
+        candidate = checkpoint_dir / checkpoint_path
+        if candidate.exists():
+            return candidate
+    return checkpoint_dirs[0] / checkpoint_path
+
+
+def _collect_checkpoint_files(
+    checkpoint_dirs: list[pathlib.Path],
+    suffix: str | None
+) -> list[tuple[pathlib.Path, int]]:
+    files_with_number = []
+    for checkpoint_dir in checkpoint_dirs:
+        if not checkpoint_dir.exists():
+            continue
+        for file_path in sorted(checkpoint_dir.iterdir()):
+            if not file_path.is_file():
+                continue
+            if suffix is not None and not file_path.name.endswith(suffix):
+                continue
+            if suffix is None and not _is_checkpoint_file(file_path):
+                continue
+            files_with_number.append((file_path, extract_number_of_filename(file_path.stem)))
+    return files_with_number
+
+
+def _is_checkpoint_file(file_path: pathlib.Path) -> bool:
+    return file_path.name.endswith(".keras") or file_path.name.endswith(".weights.h5")
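A hedged sketch of how the resolver is typically called for resuming training (the directory name is illustrative):

from pipeline.base.checkpoint import resolve_checkpoint

# Pick the newest checkpoint (highest embedded number) across the given directories.
path, epoch = resolve_checkpoint(dirs=["local/tasks/wiki/checkpoints"], suffix=".weights.h5")
if path is None:
    print("no checkpoint yet, starting from epoch 0")
else:
    print(f"resuming from {path} at epoch {epoch}")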
pipeline/base/configs.py
ADDED
@@ -0,0 +1,69 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Callable
+
+
+@dataclass
+class CheckpointConfig:
+    dirs: list[Path] | None = None
+    path: Path = None
+    epoch: int = None
+    suffix: str = None
+
+
+@dataclass
+class ModelConfig:
+    sequence_length: int = 256
+    hidden_dim: int = 512
+    intermediate_dim: int = 2056
+    num_heads: int = 8
+    num_layers: int = 8
+
+
+@dataclass
+class TrainingRule:
+    batch_size: int = 128
+    epochs: int = 1
+    steps_per_epoch: int = 30
+    validation_batches: int = 1
+
+
+@dataclass
+class GenerationRule:
+    prompts_generator: Callable
+    sample_strategy: Callable
+
+
+@dataclass
+class CheckpointRules:
+    training: CheckpointConfig = field(default_factory=CheckpointConfig)
+    testing: CheckpointConfig = field(default_factory=CheckpointConfig)
+    deployment: CheckpointConfig = field(default_factory=CheckpointConfig)
+
+    def resolve_training_rule(
+        self,
+        default_dirs: list[Path | str] | None = None
+    ) -> dict:
+        return self._resolve_rule(self.training, default_dirs)
+
+    def resolve_testing_rule(
+        self,
+        default_dirs: list[Path | str] | None = None
+    ) -> dict:
+        return self._resolve_rule(self.testing, default_dirs)
+
+    def resolve_deployment_rule(
+        self,
+        default_dirs: list[Path | str] | None = None
+    ) -> dict:
+        return self._resolve_rule(self.deployment, default_dirs)
+
+    @staticmethod
+    def _resolve_rule(checkpoint: CheckpointConfig, default_dirs: list[Path | str] | None) -> dict:
+        dirs = checkpoint.dirs if checkpoint.dirs is not None else default_dirs
+        return {
+            "dirs": dirs,
+            "path": checkpoint.path,
+            "epoch": checkpoint.epoch,
+            "suffix": checkpoint.suffix
+        }
pipeline/base/generation.py
ADDED
@@ -0,0 +1,174 @@
+"""
+Components related to text generation.
+"""
+
+import pathlib
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import keras
+import numpy as np
+from keras import callbacks, ops
+
+from env.vocab import PAD
+from env.logger import get_logger
+from pipeline.base.model_builder import GenerationContext, GenerationResult, ModelArtifact
+
+
+def generate_with_training_model(
+    model: keras.Model,
+    context: GenerationContext,
+    prompt_tokens: list[int]
+) -> GenerationResult:
+    prompt_length = len(prompt_tokens)
+
+    if prompt_length == 0:
+        return GenerationResult([], "<|empty|>")
+
+    tokens = prompt_tokens + [PAD] * (context.max_length - prompt_length)
+
+    for i in range(prompt_length, context.max_length):
+        prediction = model.predict(np.array([tokens]), verbose=0)
+        prediction = prediction[0, i - 1]
+        next_token = ops.convert_to_numpy(context.sample_fn(prediction))
+        next_token_id = np.array(next_token).item()
+        tokens[i] = next_token_id
+
+        if next_token_id == context.end_of_text:
+            return GenerationResult(tokens[:i], "<|endoftext|>")
+        if next_token_id == PAD:
+            return GenerationResult(tokens[:i], "<|pad|>")
+
+    return GenerationResult(tokens, "<|maxlength|>")
+
+
+def generate_with_stateful_model(
+    model: keras.Model,
+    context: GenerationContext,
+    prompt_tokens: list[int],
+    initial_states: list
+) -> GenerationResult:
+    if not prompt_tokens:
+        return GenerationResult([], "<|empty|>")
+
+    tokens = list(prompt_tokens)
+    batch_tokens = np.array([tokens])
+    logits, *states = model.predict([batch_tokens] + initial_states, verbose=0)
+
+    for _ in range(len(tokens), context.max_length):
+        next_token = ops.convert_to_numpy(context.sample_fn(logits[0]))
+        next_token_id = np.array(next_token).item()
+        tokens.append(next_token_id)
+
+        if next_token_id == context.end_of_text:
+            return GenerationResult(tokens[:-1], "<|endoftext|>")
+        if next_token_id <= PAD:
+            return GenerationResult(tokens, "<|pad|>")
+
+        logits, *states = model.predict([np.array([[next_token_id]])] + states, verbose=0)
+
+    return GenerationResult(tokens, "<|maxlength|>")
+
+
+@dataclass
+class TextGenerationResult:
+    text: str
+    stop_reason: str
+
+
+class TextGenerator:
+    def __init__(
+        self,
+        artifact: ModelArtifact,
+        tokenizer: Any,
+        decode: Callable,
+        end_of_text: int,
+        sample_fn: Callable,
+        max_length: int
+    ):
+        self.artifact = artifact
+        self.tokenizer = tokenizer
+        self.decode = decode
+        self.context = GenerationContext(
+            end_of_text=end_of_text,
+            max_length=max_length,
+            sample_fn=sample_fn
+        )
+
+    def generate_tokens(
+        self,
+        prompt: str,
+        max_length: int | None = None,
+        sample_fn: Callable | None = None
+    ) -> GenerationResult:
+        context = GenerationContext(
+            end_of_text=self.context.end_of_text,
+            max_length=max_length if max_length is not None else self.context.max_length,
+            sample_fn=sample_fn if sample_fn is not None else self.context.sample_fn
+        )
+        prompt_tokens = self._tokenize_prompt(prompt)
+        return self.artifact.generate(context, prompt_tokens)
+
+    def generate_text(
+        self,
+        prompt: str,
+        max_length: int | None = None,
+        sample_fn: Callable | None = None
+    ) -> TextGenerationResult:
+        result = self.generate_tokens(prompt, max_length, sample_fn)
+        return TextGenerationResult(
+            text=self.decode(result.token_ids),
+            stop_reason=result.stop_reason
+        )
+
+    def _tokenize_prompt(self, prompt: str) -> list[int]:
+        prompt_tokens = list(ops.convert_to_numpy(self.tokenizer(prompt)))
+        return [token for token in prompt_tokens if token > PAD]
+
+
+class GenerationCallback(callbacks.Callback):
+    def __init__(
+        self,
+        prompts: list[str],
+        log_file: pathlib.Path,
+        tokenizer: Any,
+        decode: Callable,
+        end_of_text: int,
+        max_length: int,
+        sample_fn: Callable,
+        training_artifact: ModelArtifact
+    ):
+        super().__init__()
+        self.prompts = prompts
+        self.tokenizer = tokenizer
+        self.decode = decode
+        self.end_of_text = end_of_text
+        self.max_length = max_length
+        self.sample_fn = sample_fn
+        self.training_artifact = training_artifact
+        self.logger = self.init_logger(log_file)
+
+    def on_epoch_end(self, epoch, logs=None):
+        generator = TextGenerator(
+            artifact=self.training_artifact,
+            tokenizer=self.tokenizer,
+            decode=self.decode,
+            end_of_text=self.end_of_text,
+            max_length=self.max_length,
+            sample_fn=self.sample_fn
+        )
+        self.logger.info(f"\nGenerated text after epoch {epoch + 1}:")
+        for i, prompt in enumerate(self.prompts):
+            result = generator.generate_text(prompt)
+            self.logger.info(f"Prompt {i + 1:2}: {prompt}")
+            self.logger.info(f"Generated: {result.text}{result.stop_reason}\n")
+
+    @staticmethod
+    def init_logger(log_file: pathlib.Path):
+        if not log_file.parent.exists():
+            log_file.parent.mkdir(parents=True)
+
+        logger = get_logger("GenerationCallback", filepath=str(log_file))
+        logger.info("Initialized GenerationCallback logger")
+
+        return logger
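The sample_fn contract used throughout this module is a callable that maps a vector of logits to a single token id. A hedged sketch of a top-k sampler that fits this contract (the helper name, k and temperature are illustrative, not part of the project):

import tensorflow as tf

def make_top_k_sampler(k: int = 10, temperature: float = 1.0):
    def sample_fn(logits):
        # logits is a 1-D vocabulary-sized vector, as passed by the generate functions above
        logits = tf.convert_to_tensor(logits, dtype=tf.float32) / temperature
        values, indices = tf.math.top_k(logits, k=k)
        choice = tf.random.categorical(tf.expand_dims(values, 0), num_samples=1)[0, 0]
        return indices[choice]
    return sample_fn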