# Chinese Classical GPT-2

A 335M-parameter GPT-2 model trained from scratch for style-conditioned classical Chinese text generation, with post-training for Li Bai persona emulation.
## Overview
This project implements a complete pipeline from pre-training to persona-based dialogue:
- Pre-training: Two-stage curriculum learning (general Chinese → classical Chinese) with Style Embedding for 5 literary genres
- Post-training: Continual Pre-training (CPT) + Supervised Fine-Tuning (SFT) for Li Bai persona
## Model Variants

This repository contains multiple checkpoints from different training stages and ablation experiments:

| Model | Checkpoint | Description |
|-------|------------|-------------|
| GPT2-SE | `checkpoints_post/sft_final.pt` | Full model: post-trained with Style Embedding (primary) |
| GPT2-Base | (available on request) | Post-trained without Style Embedding (fair ablation) |
| GPT2-Raw | `checkpoints/stage2_final.pt` | Pre-trained only, no post-training (baseline) |
| Stage 1 | `checkpoints/stage1_final.pt` | General Chinese pre-training checkpoint |
## Architecture

| Parameter | Value |
|-----------|-------|
| Architecture | GPT-2 (decoder-only Transformer) |
| Parameters | 335,609,856 |
| Layers | 24 |
| Attention Heads | 16 |
| Hidden Dimension | 1024 |
| Max Sequence Length | 512 tokens |
| Vocabulary | 32,000 (SentencePiece BPE) |
| Style Conditioning | Learnable embedding (6 styles × 1024 dim) |
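The style-conditioning row above describes a learnable 6 × 1024 embedding table added to the token embeddings. A minimal sketch of that mechanism (module and parameter names are illustrative, not the repository's actual code):

```python
import torch
import torch.nn as nn

class StyleConditionedEmbedding(nn.Module):
    """Token embedding plus a learned per-style offset (illustrative sketch)."""

    def __init__(self, vocab_size=32000, n_styles=6, d_model=1024):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # One learnable 1024-dim vector per style (6 x 1024, as in the table).
        self.style_emb = nn.Embedding(n_styles, d_model)

    def forward(self, idx, style_id):
        # idx: (batch, seq_len) token ids; style_id: (batch,) style ids.
        x = self.tok_emb(idx)                      # (B, T, D)
        s = self.style_emb(style_id).unsqueeze(1)  # (B, 1, D)
        return x + s                               # broadcast over all positions

emb = StyleConditionedEmbedding()
out = emb(torch.zeros(2, 8, dtype=torch.long), torch.tensor([0, 3]))
print(out.shape)  # torch.Size([2, 8, 1024])
```

Adding the style vector at every position (rather than only as a prefix token) keeps the style signal present throughout the sequence.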
## Style Personas (Pre-training)

| Persona | Genre | Era |
|---------|-------|-----|
| Li Bai (李白) | Poetry (诗) | Tang Dynasty |
| Su Shi (苏轼) | Ci Poetry (词) | Song Dynasty |
| Pu Songling (蒲松龄) | Fiction (小说) | Qing Dynasty |
| Han Yu (韩愈) | Prose (散文) | Tang Dynasty |
| Sima Qian (司马迁) | History (史传) | Han Dynasty |
## Training

### Stage 1: General Chinese Pre-training

- Data: 1.68M samples (classical + modern Chinese)
- Result: Loss 10.36 → 4.0, Accuracy 2.5% → 33.5%

### Stage 2: Classical Chinese Specialization

- Data: 1.60M samples (classical Chinese only, with style labels)
- Result: Loss 4.0 → 3.85, Perplexity 42.43

### Post-training: Li Bai Persona (CPT + SFT)

- CPT: 1,329 Li Bai texts (poems, prose, biographies), Loss 4.30 → 1.34
- SFT: 1,000 multi-turn dialogues in Li Bai's voice, Loss 3.76 → 0.58
- Hardware: NVIDIA RTX 4080 SUPER (16GB), ~10 min total
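A common way to implement the SFT stage on dialogue data is to compute cross-entropy only on the answer tokens, masking the prompt and padding. A minimal sketch of that loss (an assumption about the training setup, not necessarily what this repository does):

```python
import torch
import torch.nn.functional as F

def sft_loss(logits, targets, prompt_len, pad_id=0):
    """Cross-entropy over answer tokens only (illustrative sketch).
    Prompt and padding positions are excluded via ignore_index."""
    labels = targets.clone()
    labels[:, :prompt_len] = -100     # do not train on the prompt
    labels[labels == pad_id] = -100   # do not train on padding
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
    )

# Toy shapes: batch of 2, sequence length 16, vocab 32000.
logits = torch.randn(2, 16, 32000)
targets = torch.randint(1, 32000, (2, 16))
loss = sft_loss(logits, targets, prompt_len=8)
print(loss.item() > 0)  # True
```

Masking the prompt concentrates the gradient on the persona's responses, which is one plausible contributor to the very low final SFT loss reported above.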
## Evaluation

### LLM-Judge Quality (Tasks 1-5, scored 0-100)

| Model | Fluency | Coherence | Completeness | Style | Literary | Total |
|-------|---------|-----------|--------------|-------|----------|-------|
| GPT2-Raw | 7.47 | 3.81 | 3.88 | 2.65 | 1.91 | 19.72 |
| GPT2-Base | 16.03 | 14.01 | 13.20 | 15.27 | 10.50 | 69.01 |
| GPT2-SE | 16.30 | 14.32 | 13.39 | 15.74 | 10.52 | 70.27 |
### Adversarial Robustness (Task 6, scored 0-100)

| Model | Boundary | Refusal | Persona | Coherence | Fluency | Total |
|-------|----------|---------|---------|-----------|---------|-------|
| GPT2-Raw | 2.35 | 2.18 | 5.35 | 4.88 | 11.18 | 25.94 |
| GPT2-Base | 10.94 | 10.71 | 17.00 | 15.94 | 18.71 | 73.30 |
| GPT2-SE | 14.35 | 13.94 | 18.18 | 16.18 | 18.41 | 81.06 |
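In both evaluation tables the Total column is simply the sum of the five criterion scores. Checking the GPT2-SE rows:

```python
# GPT2-SE criterion scores from the two evaluation tables.
quality = [16.30, 14.32, 13.39, 15.74, 10.52]      # Tasks 1-5
adversarial = [14.35, 13.94, 18.18, 16.18, 18.41]  # Task 6

print(round(sum(quality), 2))      # 70.27
print(round(sum(adversarial), 2))  # 81.06
```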
### Persona Identification (open-ended, by DeepSeek judge)

| Model | Li Bai Identification Accuracy |
|-------|--------------------------------|
| GPT2-Raw | 17.3% |
| GPT2-Base | 70.1% |
| GPT2-SE | 69.8% |
## Repository Structure

```
├── checkpoints/
│   ├── stage1_final.pt                    # Pre-training Stage 1
│   └── stage2_final.pt                    # Pre-training Stage 2 (GPT2-Raw)
├── checkpoints_post/
│   ├── cpt_final.pt                       # Post-training CPT (with SE)
│   └── sft_final.pt                       # Post-training SFT (GPT2-SE, primary)
├── tokenizer/
│   ├── chinese_sp.model                   # SentencePiece BPE tokenizer
│   └── chinese_sp.vocab
└── evaluation/
    ├── questions.json                     # 130 evaluation questions
    ├── results_posttrain_style.json
    ├── results_posttrain_nostyle_fair.json
    ├── results_posttrain_nostyle_unfair.json
    └── results_baseline.json
```
## Usage

### Persona Dialogue (Post-trained model)

```python
import torch
import sentencepiece as spm

from model import GPT2
from config import ProjectConfig, STYLE_ID_MAP

# Load the SentencePiece BPE tokenizer.
sp = spm.SentencePieceProcessor()
sp.load("tokenizer/chinese_sp.model")

# Build the model and load the post-trained (SFT) checkpoint.
config = ProjectConfig()
config.model.vocab_size = sp.get_piece_size()
model = GPT2(config.model, pad_token_id=0)
state = torch.load("checkpoints_post/sft_final.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.eval()

# Prompt format used during SFT: style tag, system, user, answer.
# ("You are the Tang poet Li Bai" / "Write a poem about homesickness")
prompt = "[STYLE:李白]\n### 系统：你是大唐诗人李白\n### 用户：写一首思乡的诗\n### 回答："
input_ids = [sp.bos_id()] + sp.encode(prompt)
idx = torch.tensor([input_ids])

output = model.generate(
    idx, max_new_tokens=256, style_id=STYLE_ID_MAP["李白"],
    temperature=0.7, top_k=20, top_p=0.8, repetition_penalty=1.3,
)
# Decode only the newly generated tokens.
print(sp.decode(output[0, len(input_ids):].tolist()))
```
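The `generate` call combines temperature, top-k, top-p, and a repetition penalty. A minimal sketch of how those filters are typically applied to the logits at one decoding step (a standalone illustration under common conventions, not the repository's `generate` implementation):

```python
import torch

def filter_logits(logits, generated, temperature=0.7, top_k=20, top_p=0.8,
                  repetition_penalty=1.3):
    """Apply standard sampling filters to a 1-D logits vector (sketch)."""
    logits = logits / temperature
    # Repetition penalty: dampen logits of already-generated tokens.
    for tok in set(generated):
        if logits[tok] > 0:
            logits[tok] = logits[tok] / repetition_penalty
        else:
            logits[tok] = logits[tok] * repetition_penalty
    # Top-k: keep only the k highest logits.
    kth = torch.topk(logits, top_k).values[-1]
    logits[logits < kth] = float("-inf")
    # Top-p (nucleus): keep the smallest prefix whose probability mass >= p.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cutoff = torch.cumsum(probs, dim=-1) > top_p
    cutoff[1:] = cutoff[:-1].clone()  # shift right: always keep the top token
    cutoff[0] = False
    logits[sorted_idx[cutoff]] = float("-inf")
    return logits

filtered = filter_logits(torch.randn(32000), generated=[5, 42])
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, 1)
print((probs > 0).sum().item() <= 20)  # True: at most top_k candidates survive
```

The fairly aggressive settings above (low temperature, small `top_k`, `repetition_penalty=1.3`) suit a small model prone to repetition in classical verse.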
## Limitations

- Local coherence only: 335M parameters cannot maintain long-range narrative logic
- Style bleeding: style signal attenuates in longer outputs (>200 tokens)
- Potential SFT overfitting: low SFT loss (0.58) on 1,000 examples × 10 epochs
- No explicit prosodic supervision: tonal patterns learned incidentally through statistical co-occurrence
## Citation

```bibtex
@misc{chinese-classical-gpt-2026,
  title={Cross-Era Alignment for Emulating Ancient Chinese Literati},
  author={Zichao Wei and Entang Wang and Zhenyu Feng},
  year={2026},
  howpublished={Software Project Neural Networks, Saarland University}
}
```
## License

MIT