---
language:
  - en
license: apache-2.0
pipeline_tag: text-generation
library_name: transformers
datasets:
  - HuggingFaceFW/fineweb-edu
  - HuggingFaceTB/stack-edu
  - HuggingFaceTB/finemath
tags:
  - causal-lm
  - 100m-parameters
  - single-gpu-training
  - flashattention2
  - gqa
model-index:
  - name: Rain-v2
    results:
      - task:
          type: multiple-choice-qa
          name: ARC-Easy (5-shot)
        metrics:
          - type: accuracy
            value: 0.35-0.40
      - task:
          type: multiple-choice-qa
          name: HellaSwag (5-shot)
        metrics:
          - type: accuracy
            value: 0.28-0.30
      - task:
          type: multiple-choice-qa
          name: PIQA (5-shot)
        metrics:
          - type: accuracy
            value: 0.6
      - task:
          type: coreference-resolution
          name: Winogrande (5-shot)
        metrics:
          - type: accuracy
            value: 0.51-0.52
---

# Rain-v2

Rain-v2 is an English autoregressive language model with roughly 100 million parameters, pretrained in about two days on a single RTX 4090. It demonstrates a complete data-to-model pipeline under tight compute constraints.
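The ~100M figure can be sanity-checked from the architecture given in the configuration section. A minimal sketch, assuming a SwiGLU intermediate size of 1365 (≈ 8/3 × 512), which is not stated in this card, and ignoring the negligible RMSNorm parameters:

```python
# Sanity-check the ~100M parameter count from the published config.
n_layers, d_model, n_heads, n_kv_heads = 32, 512, 8, 4
vocab_size = 16_384
d_ff = 1365  # assumed SwiGLU intermediate size (~8/3 * d_model), not from the card

head_dim = d_model // n_heads   # 64
kv_dim = n_kv_heads * head_dim  # 256: GQA shares 4 KV heads across 8 query heads

attn = 2 * d_model * d_model    # Q and output projections
attn += 2 * d_model * kv_dim    # K and V projections (narrower under GQA)
mlp = 3 * d_model * d_ff        # SwiGLU: gate, up, and down projections
per_layer = attn + mlp

embed = vocab_size * d_model    # counted once: input/output weights are tied
total = embed + n_layers * per_layer
print(f"{total / 1e6:.1f}M parameters")  # lands near the advertised ~100M
```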

## Model and Training Configuration

- Parameter count: ≈100M
- Architecture: 32 decoder layers, hidden size 512, 8-head GQA (4 KV heads), RoPE, RMSNorm, SwiGLU, tied input/output embeddings
- Vocabulary: custom-trained BPE, 16,384 tokens, targeting a mixed English/code/math corpus
- Context length: 1024
- Learning-rate schedule: 1% warmup + cosine decay
- Training volume: ≈6.64×10^8 tokens, ~40 hours total on a single RTX 4090
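The "1% warmup + cosine decay" schedule above can be written as a small step-to-learning-rate function. A sketch with hypothetical `peak_lr` and `min_lr` parameters, since the card does not state the actual values:

```python
import math

def lr_at(step, total_steps, peak_lr, min_lr=0.0, warmup_frac=0.01):
    """Linear warmup over the first 1% of steps, then cosine decay to min_lr."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        # ramp linearly from peak_lr/warmup_steps up to peak_lr
        return peak_lr * (step + 1) / warmup_steps
    # cosine decay from peak_lr down to min_lr over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

The schedule peaks exactly at the end of warmup and reaches `min_lr` on the final step.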

## Data Mix

- FineWeb-Edu (high-quality English educational text): 60%
- Stack-Edu (Python instructional code/Q&A subset): 30%
- FineMath-4+ (high-quality math/logic): 10%

The combined corpus totals roughly 10 B tokens.
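A 60/30/10 mix like this is typically realized by weighted sampling across the source datasets (the Hugging Face `datasets` library offers `interleave_datasets(..., probabilities=[0.6, 0.3, 0.1])` for exactly this). A minimal pure-Python illustration of the sampling, with hypothetical source labels:

```python
import random

# Hypothetical source labels standing in for the three corpora.
SOURCES = ["fineweb-edu", "stack-edu", "finemath-4plus"]
WEIGHTS = [0.6, 0.3, 0.1]

def sample_sources(n, seed=0):
    """Draw n document-source labels according to the 60/30/10 mix."""
    rng = random.Random(seed)
    return rng.choices(SOURCES, weights=WEIGHTS, k=n)

# Empirically the drawn proportions converge to the target mix.
counts = {s: 0 for s in SOURCES}
for s in sample_sources(100_000):
    counts[s] += 1
print({s: round(c / 100_000, 3) for s, c in counts.items()})
```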

## Evaluation Summary (5-shot)

- ARC-Easy: 40%
- HellaSwag: 30%
- PIQA: 60%
- Winogrande: 51%

## Safety and Limitations

The model readily outputs factual errors and fabricated information. It has not undergone any alignment and can generate biased, harmful, or illegal content; do not expose it directly to end users.

## Usage Example

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "raincandy-u/Rain-v2", torch_dtype=torch.bfloat16, device_map="auto"
)
tok = AutoTokenizer.from_pretrained("raincandy-u/Rain-v2")

prompt = "Here's a fairy tale about a little pig. A long, long time ago, there was a little pig called "
inputs = tok(prompt, return_tensors="pt").to(model.device)
# do_sample=True is required for temperature/top_p to take effect
out = model.generate(**inputs, max_new_tokens=120, do_sample=True, temperature=0.8, top_p=0.9)
print(tok.decode(out[0], skip_special_tokens=True))
```