Add files using upload-large-folder tool
- .gitattributes +9 -0
- README.md +64 -0
- all_config.yaml +40 -0
- carry_epoch_1.0.pt +3 -0
- carry_epoch_1.1.pt +3 -0
- carry_epoch_1.2.pt +3 -0
- carry_epoch_1.3.pt +3 -0
- carry_epoch_1.4.pt +3 -0
- carry_epoch_1.5.pt +3 -0
- carry_epoch_1.6.pt +3 -0
- carry_epoch_1.7.pt +3 -0
- fsdp2_epoch_1/.metadata +3 -0
- fsdp2_epoch_1/__0_0.distcp +3 -0
- fsdp2_epoch_1/__1_0.distcp +3 -0
- fsdp2_epoch_1/__2_0.distcp +3 -0
- fsdp2_epoch_1/__3_0.distcp +3 -0
- fsdp2_epoch_1/__4_0.distcp +3 -0
- fsdp2_epoch_1/__5_0.distcp +3 -0
- fsdp2_epoch_1/__6_0.distcp +3 -0
- fsdp2_epoch_1/__7_0.distcp +3 -0
- hrm_nocarry_bp_warmup.py +100 -0
- train_metadata.yaml +13 -0
.gitattributes
CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+fsdp2_epoch_1/.metadata filter=lfs diff=lfs merge=lfs -text
+fsdp2_epoch_1/__5_0.distcp filter=lfs diff=lfs merge=lfs -text
+fsdp2_epoch_1/__7_0.distcp filter=lfs diff=lfs merge=lfs -text
+fsdp2_epoch_1/__3_0.distcp filter=lfs diff=lfs merge=lfs -text
+fsdp2_epoch_1/__0_0.distcp filter=lfs diff=lfs merge=lfs -text
+fsdp2_epoch_1/__4_0.distcp filter=lfs diff=lfs merge=lfs -text
+fsdp2_epoch_1/__6_0.distcp filter=lfs diff=lfs merge=lfs -text
+fsdp2_epoch_1/__1_0.distcp filter=lfs diff=lfs merge=lfs -text
+fsdp2_epoch_1/__2_0.distcp filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,64 @@
+# HRM-Text Ko Terminal B SWE+GLM Pilot
+
+Date: 2026-05-23
+
+This repository contains a raw HRM-Text FSDP2 training checkpoint, not a
+Transformers-ready model. Convert it with `HRM-Text/conversion/convert_to_hf.py`
+after selecting a checkpoint for release.
+
+## Run
+
+| Item | Value |
+|---|---:|
+| Architecture | HRM-Text B |
+| Parameters | 435,159,040 |
+| GPUs | 8 x NVIDIA H200 |
+| Epochs | 1 |
+| Global batch | 262,144 tokens |
+| Context | 4,096 train tokens |
+| Wall time | about 7m 38s |
+| Final train loss | 3.00653 |
+| Final token accuracy | 0.46379 |
+
+Command:
+
+```bash
+WANDB_MODE=offline WANDB_DIR=/home/work/.data/wandb \
+TOKENIZERS_PARALLELISM=false OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
+NCCL_DEBUG=WARN TORCH_NCCL_ASYNC_ERROR_HANDLING=1 \
+torchrun --standalone --nproc_per_node=8 pretrain.py \
+arch/size@arch=B \
+data.path=/home/work/.data/hrm_text_prepared/sft_swe_glm_mix_v1 \
++checkpoint_path=/home/work/.data/hrm_text_checkpoints/koterm_b_swe_glm_pilot_v1 \
++project_name=HRM-Ko-Terminal \
++run_name=koterm_b_swe_glm_pilot_v1 \
+epochs=1 \
+global_batch_size=262144 \
+lr_warmup_steps=100 \
++log_interval=5 \
+checkpoint_interval=1
+```
+
+## Data
+
+Prepared dataset:
+
+`/home/work/.data/hrm_text_prepared/sft_swe_glm_mix_v1`
+
+| Source | Samples | Tokens | Processing |
+|---|---:|---:|---|
+| SWE-ZERO terminal/code trajectories | 53,868 | 182,717,999 | long instructions middle-truncated |
+| GLM-5.1 reasoning cleaned sample | 56,021 | 68,452,781 | `<think>...</think>` stripped, direct answers |
+| Total | 109,889 | 251,170,780 | PrefixLM, response-only loss |
+
+Tokenizer:
+
+`https://huggingface.co/LLM-OS-Models/HRM-Text-Ko-Terminal-Tokenizer-131K`
+
+## Files
+
+- `fsdp2_epoch_1/`: distributed FSDP2 model and optimizer checkpoint.
+- `carry_epoch_1.*.pt`: per-rank carry state.
+- `all_config.yaml`: resolved training config.
+- `train_metadata.yaml`: tokenizer and dataset metadata.
+- `hrm_nocarry_bp_warmup.py`: copied model source for reproducibility.
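The figures in the README's Run and Data tables can be cross-checked with a little arithmetic. The sketch below uses only numbers stated in the README. The derived step count and throughput are rough estimates: they assume all 251,170,780 tokens are consumed exactly once at 262,144 tokens per optimizer step, and that the "about 7m 38s" wall time covers training only. The README confirms neither assumption.

```python
import math

# Figures copied from the README tables: (samples, tokens) per source.
sources = {
    "SWE-ZERO terminal/code trajectories": (53_868, 182_717_999),
    "GLM-5.1 reasoning cleaned sample": (56_021, 68_452_781),
}
reported_total = (109_889, 251_170_780)
global_batch_tokens = 262_144
wall_time_s = 7 * 60 + 38  # "about 7m 38s"

# Check that the per-source rows add up to the reported Total row.
total_samples = sum(samples for samples, _ in sources.values())
total_tokens = sum(tokens for _, tokens in sources.values())
totals_match = (total_samples, total_tokens) == reported_total

# Assumption: a single pass over every token, with no padding or dropped remainder.
est_steps = math.ceil(total_tokens / global_batch_tokens)
# Assumption: the wall time is pure training time.
est_tokens_per_s = total_tokens / wall_time_s
swe_token_share = sources["SWE-ZERO terminal/code trajectories"][1] / total_tokens

print(f"totals match: {totals_match}")
print(f"estimated optimizer steps: {est_steps}")
print(f"estimated throughput: {est_tokens_per_s:,.0f} tokens/s")
print(f"SWE-ZERO share of tokens: {swe_token_share:.1%}")
```

If the estimated step count is roughly right, the 100 warmup steps in the command cover about a tenth of the run.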
all_config.yaml
ADDED
@@ -0,0 +1,40 @@
+arch:
+  H_cycles: 2
+  H_override: {}
+  L_cycles: 3
+  bp_max_steps: 5
+  bp_warmup_ratio: 0.2
+  expansion: 4
+  half_layers: true
+  head: lm_head@LMHead
+  hidden_size: 1024
+  init_type: lecun_normal
+  n_layers: 12
+  name: baselines.hrm_nocarry_bp_warmup@HierarchicalReasoningModel
+  norm_eps: 1.0e-06
+  norm_type: pre
+  num_heads: 8
+  pos_emb_type: rope
+  rope_theta: 10000.0
+beta1: 0.9
+beta2: 0.95
+checkpoint_interval: 1
+checkpoint_path: /home/work/.data/hrm_text_checkpoints/koterm_b_swe_glm_pilot_v1
+data:
+  path: /home/work/.data/hrm_text_prepared/sft_swe_glm_mix_v1
+  target_only: true
+ema: 0.9999
+epochs: 1
+fwd_bwd_dtype: bfloat16
+global_batch_size: 262144
+log_interval: 5
+lr: 0.00022
+lr_min_ratio: 1.0
+lr_warmup_steps: 100
+project_name: HRM-Ko-Terminal
+resume_epoch: null
+resume_from: null
+run_name: koterm_b_swe_glm_pilot_v1
+seed: 0
+weight_decay: 0.1
+weights_only_resume_from_ema: false
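A few quantities follow from the resolved config when it is read together with `hrm_nocarry_bp_warmup.py` and `train_metadata.yaml` from the same commit. The sketch below derives them under several assumptions that the config does not state: that `Transformer` applies `n_layers` blocks per call, that every batch is packed to the full `max_seq_len` of 4,096 tokens, and that sequences are split evenly across the 8 ranks launched by `torchrun`. The variable names are illustrative.

```python
# Values copied from all_config.yaml, train_metadata.yaml, and the torchrun command.
arch = {
    "n_layers": 12,
    "half_layers": True,
    "H_cycles": 2,
    "L_cycles": 3,
    "hidden_size": 1024,
    "num_heads": 8,
}
global_batch_size = 262_144  # tokens per optimizer step
max_seq_len = 4_096
world_size = 8               # --nproc_per_node=8

# half_layers halves n_layers before the H and L stacks are built,
# so each level gets a stack of this depth (H_override is empty).
layers_per_level = arch["n_layers"] // 2 if arch["half_layers"] else arch["n_layers"]

# Per forward pass the L level is called H_cycles * L_cycles times
# and the H level H_cycles times.
l_calls = arch["H_cycles"] * arch["L_cycles"]
h_calls = arch["H_cycles"]
# Assumption: each call applies every block of the stack once.
layer_applications = (l_calls + h_calls) * layers_per_level

head_dim = arch["hidden_size"] // arch["num_heads"]

# Assumption: fully packed sequences, split evenly across ranks.
seqs_per_step = global_batch_size // max_seq_len
seqs_per_rank = seqs_per_step // world_size

print(f"layers per level: {layers_per_level}")
print(f"block applications per forward pass: {layer_applications}")
print(f"attention head dim: {head_dim}")
print(f"sequences per step: {seqs_per_step} ({seqs_per_rank} per rank)")
```

The block-application count shows why the recurrent design is compute-heavy relative to its parameter count: the same two stacks are reused several times per token.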
carry_epoch_1.0.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e3b089e35eacca121e8bc850c0fc138c4ef45a63cf9e370e2f852d6245db36b
+size 1309
carry_epoch_1.1.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66952aadc5b6f1d9d38cd436e6dba2c3b6a487138c6960beb815672ddf699495
+size 1309
carry_epoch_1.2.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a7da182d5d1dfe900b2018b6e4fe6d318c69f791b7a3c94c1727e2112a5f57d
+size 1309
carry_epoch_1.3.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45156a7f9cef48f22d7a3c46c59d92394c3da40b2b596f2681cecece9156177e
+size 1309
carry_epoch_1.4.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34356bd35b48ac6ce98742241b2c0e1c96147c36743415c7b0e432ae28f8bfc8
+size 1309
carry_epoch_1.5.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f5934083e16382c0d96bab003d7f577ced0da026175aff5a4ad2aaf31c603f6
+size 1309
carry_epoch_1.6.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93480251cc3a3285f6cea88b5f8b7c6d46672f04cacd0623127885ff4469e7d8
+size 1309
carry_epoch_1.7.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2de90149405f8416a3f1c4bb6b69b52843f8a2f835c1b874bd13c13129fc3f5
+size 1309
fsdp2_epoch_1/.metadata
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:984cd5ca047ff6b6320306732b5cb74da526e625ed003ea674bffd0d9227368c
+size 377453
fsdp2_epoch_1/__0_0.distcp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f991350c2018ff046cf9cb64e91e6cda0691dd9df1a4095b7c2a65e7773f625f
+size 870637105
fsdp2_epoch_1/__1_0.distcp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df90800179941bc59922b15a6365b7d1d00a567e1aeca89cc41deb6d6709a2fa
+size 870646096
fsdp2_epoch_1/__2_0.distcp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bc9f737b46b876c5ce2b8cfea94af87c0d62e11b4c44209cdafbab8874a7e49
+size 870646096
fsdp2_epoch_1/__3_0.distcp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3e15eb5a237bbd4ece55c0e81c958018d858f7736a6170def544c6d227c9446
+size 870645780
fsdp2_epoch_1/__4_0.distcp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ed8a0de5afb4be166fbd23ae8f1d6f268842b5c9bb39d2734e1b6b6c73a73a7
+size 870645780
fsdp2_epoch_1/__5_0.distcp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5de517605507b4aba4e970c0b3ca88156bc32f0700eb9a6647ea597975899e6a
+size 870645780
fsdp2_epoch_1/__6_0.distcp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68acea8b343e1765c75d56903d948cab297479da6934af6c59bb7c6f2d145ef9
+size 870648468
fsdp2_epoch_1/__7_0.distcp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d357e127548f5f89b30289ce50f4ae3f1ed87ce732db5331b51a61bf00afc57
+size 870644519
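The checkpoint files in this commit are stored as Git LFS pointers: three-line text stubs that carry the spec version, a SHA-256 object id, and the byte size of the real file. A minimal sketch of reading such a pointer and totalling the eight `.distcp` shard sizes follows. `parse_lfs_pointer` is a hypothetical helper written for this illustration, not part of Git LFS or of this repository. The bytes-per-parameter figure divides by the 435,159,040 parameters reported in the README and is only a rough indicator.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse a Git LFS pointer file into its fields (hypothetical helper)."""
    # Each line is "<key> <value>"; split once on the first space.
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


# Pointer contents for fsdp2_epoch_1/__0_0.distcp, copied from the diff.
example = """version https://git-lfs.github.com/spec/v1
oid sha256:f991350c2018ff046cf9cb64e91e6cda0691dd9df1a4095b7c2a65e7773f625f
size 870637105
"""
pointer = parse_lfs_pointer(example)

# Sizes of shards __0_0 through __7_0, copied from the diff.
shard_sizes = [
    870_637_105, 870_646_096, 870_646_096, 870_645_780,
    870_645_780, 870_645_780, 870_648_468, 870_644_519,
]
total_bytes = sum(shard_sizes)
bytes_per_param = total_bytes / 435_159_040  # parameter count from README.md

print(f"shard 0 size: {pointer['size']:,} bytes ({pointer['hash_algo']})")
print(f"total shard size: {total_bytes / 1e9:.2f} GB")
print(f"bytes per parameter: {bytes_per_param:.2f}")
```

The ratio comes out near 16 bytes per parameter. That may reflect several 32-bit copies of each weight (for instance model weights plus optimizer and EMA state), but this is a guess: the checkpoint layout is not documented in the commit.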
hrm_nocarry_bp_warmup.py
ADDED
@@ -0,0 +1,100 @@
+from typing import Tuple, Dict, Any, Optional
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from models.common import trunc_normal_init_
+from models.transformer import Transformer, Cache, TransformerConfig
+
+
+class HierarchicalReasoningModelConfig(TransformerConfig):
+    half_layers: bool = False
+
+    H_cycles: int
+    L_cycles: int
+
+    bp_warmup_ratio: float = 0.0
+    bp_min_steps: int = 2
+    bp_max_steps: int = 5
+
+    # Change some Transformer config of H-level
+    # TODO: Try asymmetric H and L module, such as different size, hidden dims, architecture, attention type, etc.
+    H_override: Dict[str, Any] = {}
+
+
+class HierarchicalReasoningModelRecurrentBlock(nn.Module):
+    def __init__(self, config: TransformerConfig) -> None:
+        super().__init__()
+        self.core = Transformer(config)
+
+        # Create cache function
+        self.create_cache = self.core.create_cache
+
+    def forward(self, hidden_states: Tensor, input_injection: Tensor, **kwargs) -> Tensor:
+        # Input injection (add)
+        # TODO: Try better alternatives, such as GRU / gating in the following papers
+        # Alternatively, "fixed" gating that does not depend on hidden state is also worth trying
+        # E.g. only depends on position and index of hidden_states dimension
+        # https://arxiv.org/pdf/1910.06764
+        # https://arxiv.org/pdf/2202.10447
+
+        # TODO: Asymmetric fusion is also worth trying. assign different number of tokens to H and L.
+        return self.core(hidden_states + input_injection, **kwargs)
+
+
+class HierarchicalReasoningModel(nn.Module):
+    def __init__(self, config_dict: dict) -> None:
+        super().__init__()
+        config = HierarchicalReasoningModelConfig(**config_dict)
+        if config.half_layers:
+            assert config.n_layers % 2 == 0, "n_layers must be divisible by 2."
+            config.n_layers //= 2
+
+        # Reasoning Layers
+        # TODO: Asymmetric.
+        self.H_level = HierarchicalReasoningModelRecurrentBlock(TransformerConfig(**(config.model_dump() | config.H_override)))
+        self.L_level = HierarchicalReasoningModelRecurrentBlock(config)
+
+        # Config
+        self.H_cycles = config.H_cycles
+        self.L_cycles = config.L_cycles
+        self.bp_warmup_ratio = config.bp_warmup_ratio
+        self.bp_min_steps = config.bp_min_steps
+        self.bp_max_steps = config.bp_max_steps
+
+        self.hidden_size = config.hidden_size
+        self.head_hint = self.H_level.core.head_hint  # Hint for LMHead init (inherit from H)
+
+        self.zL_init = nn.Buffer(trunc_normal_init_(torch.empty(config.hidden_size, dtype=torch.bfloat16), std=1.0), persistent=True)  # NOTE: hardcoded dtype.
+
+        # Create cache function
+        self.create_cache = lambda **kwargs: dict(H=[self.H_level.create_cache(**kwargs) for _i in range(self.H_cycles)],
+                                                  L=[self.L_level.create_cache(**kwargs) for _i in range(self.H_cycles * self.L_cycles)])
+
+    def forward(self, carry: None, x: torch.Tensor, cache: Optional[dict[str, list[list[Cache]]]] = None, bp_steps: int = 2, **seq_info) -> Tuple[None, torch.Tensor]:
+        z_H, z_L = x, self.zL_init
+
+        # Calculate H and L bp_steps
+        # Priortize H, and at least 1 is allocated to L.
+        H_bp_steps = min(self.H_cycles, bp_steps - 1)
+        L_bp_steps = bp_steps - H_bp_steps
+
+        for i in range(self.H_cycles):
+            for k in range(i * self.L_cycles, (i + 1) * self.L_cycles):
+                with torch.set_grad_enabled(torch.is_grad_enabled() and (k >= self.H_cycles * self.L_cycles - L_bp_steps)):
+                    z_L = self.L_level(z_L, z_H, **seq_info, cache=cache["L"][k] if cache is not None else None)
+
+            with torch.set_grad_enabled(torch.is_grad_enabled() and (i >= self.H_cycles - H_bp_steps)):
+                z_H = self.H_level(z_H, z_L, **seq_info, cache=cache["H"][i] if cache is not None else None)
+
+        return None, z_H
+
+    def compute_train_extra_args(self, train_state: Any) -> dict[str, Any]:
+        warmup_steps = train_state.total_steps * self.bp_warmup_ratio
+        progress = min(1.0, train_state.step / warmup_steps) if warmup_steps > 0 else 1.0
+
+        return dict(bp_steps=self.bp_min_steps + int(progress * (self.bp_max_steps - self.bp_min_steps)))
+
+    def initial_carry(self, batch_size: int, dtype: torch.dtype) -> None:
+        return None
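The interplay between `compute_train_extra_args` and the gradient gating in `forward` is easier to follow with concrete numbers. The sketch below re-implements only that arithmetic in plain Python, without torch, using `H_cycles=2`, `L_cycles=3`, `bp_min_steps=2`, `bp_max_steps=5`, and `bp_warmup_ratio=0.2` from this commit. `bp_schedule` and `grad_plan` are illustrative names that do not exist in the repository, and `total_steps=1000` is an arbitrary value chosen for the example, not the run's actual step count.

```python
def bp_schedule(step, total_steps, warmup_ratio=0.2, bp_min=2, bp_max=5):
    """Mirror compute_train_extra_args: ramp bp_steps up during the warmup window."""
    warmup_steps = total_steps * warmup_ratio
    progress = min(1.0, step / warmup_steps) if warmup_steps > 0 else 1.0
    return bp_min + int(progress * (bp_max - bp_min))


def grad_plan(bp_steps, h_cycles=2, l_cycles=3):
    """Mirror forward(): flag which L-level and H-level calls keep gradients."""
    h_bp = min(h_cycles, bp_steps - 1)  # H gets priority, L keeps at least one step
    l_bp = bp_steps - h_bp
    total_l_calls = h_cycles * l_cycles
    # Only the last l_bp L calls and the last h_bp H calls run with grad enabled.
    l_flags = [k >= total_l_calls - l_bp for k in range(total_l_calls)]
    h_flags = [i >= h_cycles - h_bp for i in range(h_cycles)]
    return l_flags, h_flags


# total_steps=1000 is an arbitrary illustration value.
for step in (0, 100, 150, 400):
    bp = bp_schedule(step, total_steps=1000)
    l_flags, h_flags = grad_plan(bp)
    print(f"step={step:4d} bp_steps={bp} "
          f"L calls with grad={sum(l_flags)}/6 H calls with grad={sum(h_flags)}/2")
```

Under these settings training starts by backpropagating through one L call and one H call, and ends the warmup backpropagating through the last three L calls and both H calls. Earlier calls always run without gradients, which caps activation memory regardless of how many cycles are configured.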
train_metadata.yaml
ADDED
@@ -0,0 +1,13 @@
+max_seq_len: 4096
+tokenizer_info:
+  boq: <|im_start|>
+  condition_mapping:
+    cot: <|object_ref_end|>
+    direct: <|object_ref_start|>
+    noisy: <|quad_start|>
+    synth: <|quad_end|>
+  eoa: <|box_end|>
+  eoq: <|im_end|>
+  tokenizer_path: /home/work/.data/hrm_text_prepared/sft_swe_glm_mix_v1
+total_length: 251170780
+vocab_size: 131072
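The metadata values can be related to the figures in `README.md` and `all_config.yaml` from the same commit. The sketch below is a quick consistency check, not part of the training code. It assumes the model holds at least one `vocab_size x hidden_size` matrix (a token embedding or output head), which the commit does not show directly, and it borrows the sample count and parameter count from the README.

```python
# Values copied from train_metadata.yaml.
metadata = {"max_seq_len": 4096, "total_length": 251_170_780, "vocab_size": 131_072}

hidden_size = 1024             # from all_config.yaml
reported_params = 435_159_040  # from README.md
reported_samples = 109_889     # from README.md
reported_tokens = 251_170_780  # from README.md

# total_length should equal the README's token total; 131,072 is 2**17 ("131K").
tokens_match = metadata["total_length"] == reported_tokens
vocab_is_power_of_two = metadata["vocab_size"] == 2 ** 17

# Assumption: one vocab_size x hidden_size matrix for embedding or output head.
vocab_matrix_params = metadata["vocab_size"] * hidden_size
vocab_matrix_share = vocab_matrix_params / reported_params

# Mean sample length, to compare against the 4,096-token context.
mean_tokens_per_sample = metadata["total_length"] / reported_samples

print(f"token totals match: {tokens_match}")
print(f"one vocab matrix: {vocab_matrix_params:,} params "
      f"({vocab_matrix_share:.1%} of the reported total)")
print(f"mean tokens per sample: {mean_tokens_per_sample:.0f} "
      f"of {metadata['max_seq_len']} max")
```

With a 131K vocabulary and a hidden size of 1,024, a single vocabulary-sized matrix already accounts for a sizeable fraction of the reported parameter count.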