Pratap-K commited on
Commit
98fc9b6
·
0 Parent(s):

AutoMathReasoner

Browse files
.gitignore ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Virtual environments
7
+ .venv/
8
+ venv/
9
+ env/
10
+ ENV/
11
+ env.bak/
12
+ venv.bak/
13
+
14
+ # Environment variables
15
+ .env
16
+ .env.local
17
+
18
+ # Build/distribution directories
19
+ build/
20
+ dist/
21
+ *.egg-info/
22
+ .eggs/
23
+ eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+
32
+ # C extensions
33
+ *.so
34
+
35
+ # Unit test / coverage reports
36
+ htmlcov/
37
+ .tox/
38
+ .nox/
39
+ .coverage
40
+ .coverage.*
41
+ .cache
42
+ nosetests.xml
43
+ coverage.xml
44
+ *.cover
45
+ *.py,cover
46
+ .hypothesis/
47
+ .pytest_cache/
48
+ pytest_out*
49
+
50
+ # Machine Learning / Outputs
51
+ outputs/
52
+ colab_outputs/
53
+ wandb/
54
+ checkpoints/
55
+ *.pt
56
+ *.pth
57
+ *.safetensors
58
+ *.ckpt
59
+
60
+ # IDEs and Editors
61
+ .idea/
62
+ .vscode/
63
+ *.swp
64
+ *.swo
65
+ *~
66
+ .spyderproject
67
+ .spyproject
68
+
69
+ # OS generated files
70
+ .DS_Store
71
+ .DS_Store?
72
+ ._*
73
+ .Spotlight-V100
74
+ .Trashes
75
+ ehthumbs.db
76
+ Thumbs.db
77
+
78
+ #docs
79
+ docs
Dockerfile ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local OpenEnv sources)
10
+ # - Standalone environments (with openenv from PyPI/Git)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ${BASE_IMAGE} AS builder
15
+
16
+ WORKDIR /app
17
+
18
+ # Ensure git is available (required for installing dependencies from VCS)
19
+ RUN apt-get update && \
20
+ apt-get install -y --no-install-recommends git && \
21
+ rm -rf /var/lib/apt/lists/*
22
+
23
+ # Build argument to control whether we're building standalone or in-repo
24
+ ARG BUILD_MODE=in-repo
25
+ ARG ENV_NAME=AutoMathReasoner
26
+
27
+ # Copy environment code (always at root of build context)
28
+ COPY . /app/env
29
+
30
+ # For in-repo builds, openenv is already vendored in the build context
31
+ # For standalone builds, openenv will be installed via pyproject.toml
32
+ WORKDIR /app/env
33
+
34
+ # Ensure uv is available (for local builds where base image lacks it)
35
+ RUN if ! command -v uv >/dev/null 2>&1; then \
36
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
37
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
38
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
39
+ fi
40
+
41
+ # Install dependencies using uv sync
42
+ # If uv.lock exists, use it; otherwise resolve on the fly
43
+ RUN --mount=type=cache,target=/root/.cache/uv \
44
+ if [ -f uv.lock ]; then \
45
+ uv sync --frozen --no-install-project --no-editable; \
46
+ else \
47
+ uv sync --no-install-project --no-editable; \
48
+ fi
49
+
50
+ RUN --mount=type=cache,target=/root/.cache/uv \
51
+ if [ -f uv.lock ]; then \
52
+ uv sync --frozen --no-editable; \
53
+ else \
54
+ uv sync --no-editable; \
55
+ fi
56
+
57
+ # Final runtime stage
58
+ FROM ${BASE_IMAGE}
59
+
60
+ WORKDIR /app
61
+
62
+ # Copy the virtual environment from builder
63
+ COPY --from=builder /app/env/.venv /app/.venv
64
+
65
+ # Copy the environment code
66
+ COPY --from=builder /app/env /app/env
67
+
68
+ # Set PATH to use the virtual environment
69
+ ENV PATH="/app/.venv/bin:$PATH"
70
+
71
+ # Set PYTHONPATH so imports work correctly
72
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
73
+
74
+ #Enable Web Interface
75
+ ENV ENABLE_WEB_INTERFACE=true
76
+
77
+ # Health check
78
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
79
+ CMD curl -f http://localhost:7860/health || exit 1
80
+
81
+ # Run the FastAPI server
82
+ # The module path is constructed to work with the /app/env structure
83
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 7860"]
README.md ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: AutoMathReasoner Environment
3
+ emoji: 🧠
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ ---
10
+
11
+ # ♾️ AutoMathReasoner: Self-Improving Mathematics RL Environment
12
+
13
+ **AutoMathReasoner** is an OpenEnv-compliant reinforcement learning server specifically formulated to bootstrap mathematical intelligence in Large Language Models (LLMs). Rooted in principles from DeepSeekMath and Group-Relative Policy Optimization (GRPO), it facilitates absolute, fully autonomous self-improvement through rigorous dense reward curves, exploration entropy, and curriculum scaling.
14
+
15
+ This repository wraps the environment architecture securely into a lightweight Docker-backed REST API for direct ingestion in Google Colab, SageMaker, or distributed compute arrays.
16
+
17
+ ---
18
+
19
+ ## 🏗️ Architecture Overview
20
+
21
+ The system strictly decouples the interactive RL environment from the learning engine. The `FastAPI` instance serves purely as the mathematical world simulation.
22
+
23
+ ```mermaid
24
+ graph TD
25
+ subgraph EnvAPI [OpenEnv API Space]
26
+ GE["Task Generator Engine"] -->|"Yields Math"| Server["FastAPI Server"]
27
+ Server -->|"Computes"| VR["Verifier System & Reward Logic"]
28
+ VR --> Server
29
+ end
30
+
31
+ subgraph ClientNode [Training Node e.g. Colab]
32
+ MD["Language Model Policy"] -->|"Action: Reason & Answer"| HG["HF GRPOTrainer"]
33
+ HG -->|"REST HTTP POST"| Server
34
+ Server -->|"Observation: Rewards"| HG
35
+ HG -->|"Log diff"| MD
36
+ end
37
+
38
+ classDef space fill:transparent,stroke:#9370DB,stroke-width:2px,stroke-dasharray: 5 5;
39
+ classDef client fill:transparent,stroke:#008B8B,stroke-width:2px,stroke-dasharray: 5 5;
40
+
41
+ class EnvAPI space
42
+ class ClientNode client
43
+ ```
44
+
45
+ ---
46
+
47
+ ## 🎯 Reward Composite Hierarchy (Graders)
48
+
49
+ Instead of binary scalar rewards (0 for incorrect, 1 for correct), the AutoMathReasoner relies on an aggressive mathematical dense reward architecture designed to shape logical structures rather than just end targets.
50
+
51
+ The absolute reward matrix evaluates as:
52
+
53
+ $$R = 0.35C + 0.15\tanh(Q) + 0.1P + 0.1R_{\text{ref}} + 0.15D - 0.05E + 0.1X + \mathcal{N}(0, \sigma^2)$$
54
+
55
+ ### Individual Mathematical Graders
56
+
57
+ - **Correctness ($C$):** $C \in \{0.0, 1.0\}$. Passed through an exact match, numeric bound tolerance limit, and generic python evaluation. E.g. correctly evaluating `3.1415 = 3.14159`.
58
+ - **Reasoning Squashing ($Q_{\text{smooth}}$):** $Q_{\text{smooth}} = \tanh(Q)$. Uses hyperbolic tangent functions bounding heuristic step-formatting markers to ensure extreme verbosity does not dominate correctness.
59
+ - **Process Supervision ($P$):** A step-aware structural logic test that algorithmically assigns $-0.5$ scalar penalties for hallucinatory inferential jumps.
60
+ - **Reflection Parsing ($R_{\text{ref}}$):** Tracks deducing logic boundaries ("Wait", "What could be wrong"). Rewards $+1.0$ for successful self-correction routing, and $-0.5$ if it reflects into a broken contradiction.
61
+ - **Entropic Exploration ($X$):** Rewards unique reasoning path token variance mapped dynamically against historical encounter probability:
62
+ $$X = \frac{\log(1 + \text{unique\_ratio})}{\sqrt{1 + \text{times\_seen\_problem}}}$$
63
+ - **Token Efficiency Penalty ($E$):** Penalizes overly verbose traces dynamically. It anchors outputs safely against a $50$-token optimal length via an inverse negative Gaussian curve:
64
+ $$E = \exp\left(-\left(\frac{\text{approx\_tokens} - 50}{50}\right)^2\right) - 1.0$$
65
+ - **History Diversity ($D$):** Employs strict, absolute mathematical blocks against network hacking and identical solution repetition loops:
66
+ $$D = \begin{cases} -\exp(1.0) & \text{if answer repeats exactly} \\ 1.0 & \text{otherwise} \end{cases}$$
67
+
68
+ ---
69
+
70
+ ## 🔄 Self-Curriculum Training Loop
71
+
72
+ The pipeline intrinsically manages mathematical difficulty scaling while systematically applying ReST-Style trajectory filtration to block network poisoning.
73
+
74
+ ```mermaid
75
+ sequenceDiagram
76
+ participant Model as P-Model
77
+ participant Buffer as Replay/LADDER Buffer
78
+ participant Env as AutoMath Env OpenEnv
79
+
80
+ loop Episodic Batch GRPO
81
+ Env->>Model: Emit Algebra Prompt (Diff=2.0)
82
+ Model->>Env: Rollout K=4 Completion Traces
83
+
84
+ Note over Env: Execute Process Supervision<br>Determine Majority Sample Output
85
+ Env-->>Model: Return Normalized Reward Arrays
86
+
87
+ Model->>Model: Compute Relative Log Likelihood
88
+ Model->>Model: LoRA Gradient Step
89
+
90
+ alt is_correct == 1 AND Q_reasoning > 0.6
91
+ Model->>Buffer: Store Trajectory (ReST/LADDER)
92
+ else
93
+ Model->>Buffer: Store as Hard Negative Mine
94
+ end
95
+ end
96
+
97
+ loop Curriculum Scaling Tick
98
+ Note over Env: If Mean Rolling Accuracy >= 65%
99
+ Env->>Env: Diff = Diff + 0.5 (Generate advanced word problems)
100
+ end
101
+ ```
102
+
103
+ ---
104
+
105
+ ## 💻 Steps to Get the Code Running on Your System
106
+
107
+ ### 1. Initialize the Environment Server Locally
108
+
109
+ You can launch the core OpenEnv FastAPI server effortlessly using `uv` to orchestrate dependencies automatically. This handles environment states entirely.
110
+
111
+ ```bash
112
+ # Clone the repository
113
+ git clone https://github.com/yourusername/AutoMathReasoner.git
114
+ cd AutoMathReasoner
115
+
116
+ # Install native editable package bindings via uv
117
+ uv pip install -e .
118
+
119
+ # Launch the FastAPI Server Engine
120
+ uv run server
121
+ ```
122
+ _The server is now live at `http://localhost:7860`. You can visit `http://localhost:7860/docs` to view the raw interactive environment endpoints._
123
+
124
+ ### 2. Begin Reinforcement Learning (GRPO)
125
+
126
+ Once your server is running (either locally or deployed to Hugging Face Spaces), execute the automated GRPO rollout.
127
+
128
+ To execute the free-tier Colab notebook simulation pointing back at your running server:
129
+ ```bash
130
+ # In an entirely separate terminal
131
+ python train/colab_train.py
132
+ ```
133
+ *(Ensure `HF_SPACE_URL` in `train/colab_train.py` points to your `http://localhost:7860` or deployed Space domain!)*
__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Automathreasoner Environment."""
8
+
9
+ from .client import AutomathreasonerEnv
10
+ from .env.models import AutomathreasonerAction, AutomathreasonerObservation
11
+
12
+ __all__ = [
13
+ "AutomathreasonerAction",
14
+ "AutomathreasonerObservation",
15
+ "AutomathreasonerEnv",
16
+ ]
client.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Automathreasoner Environment Client."""
8
+
9
+ from typing import Dict
10
+
11
+ from openenv.core import EnvClient
12
+ from openenv.core.client_types import StepResult
13
+ from openenv.core.env_server.types import State
14
+
15
+ from .env.models import AutomathreasonerAction, AutomathreasonerObservation
16
+
17
+
18
+ class AutomathreasonerEnv(
19
+ EnvClient[AutomathreasonerAction, AutomathreasonerObservation, State]
20
+ ):
21
+ """
22
+ Client for the Automathreasoner Environment.
23
+
24
+ This client maintains a persistent WebSocket connection to the environment server,
25
+ enabling efficient multi-step interactions with lower latency.
26
+ Each client instance has its own dedicated environment session on the server.
27
+
28
+ Example:
29
+ >>> # Connect to a running server
30
+ >>> with AutomathreasonerEnv(base_url="http://localhost:7860") as client:
31
+ ... result = client.reset()
32
+ ... print(result.observation.echoed_message)
33
+ ...
34
+ ... result = client.step(AutomathreasonerAction(message="Hello!"))
35
+ ... print(result.observation.echoed_message)
36
+
37
+ Example with Docker:
38
+ >>> # Automatically start container and connect
39
+ >>> client = AutomathreasonerEnv.from_docker_image("AutoMathReasoner-env:latest")
40
+ >>> try:
41
+ ... result = client.reset()
42
+ ... result = client.step(AutomathreasonerAction(message="Test"))
43
+ ... finally:
44
+ ... client.close()
45
+ """
46
+
47
+ def _step_payload(self, action: AutomathreasonerAction) -> Dict:
48
+ """
49
+ Convert AutomathreasonerAction to JSON payload for step message.
50
+
51
+ Args:
52
+ action: AutomathreasonerAction instance
53
+
54
+ Returns:
55
+ Dictionary representation suitable for JSON encoding
56
+ """
57
+ return {
58
+ "reasoning": action.reasoning,
59
+ "final_answer": action.final_answer,
60
+ }
61
+
62
+ def _parse_result(self, payload: Dict) -> StepResult[AutomathreasonerObservation]:
63
+ """
64
+ Parse server response into StepResult[AutomathreasonerObservation].
65
+
66
+ Args:
67
+ payload: JSON response data from server
68
+
69
+ Returns:
70
+ StepResult with AutomathreasonerObservation
71
+ """
72
+ obs_data = payload.get("observation", {})
73
+ observation = AutomathreasonerObservation(
74
+ problem_text=obs_data.get("problem_text", ""),
75
+ difficulty_level=obs_data.get("difficulty_level", 1.0),
76
+ history=obs_data.get("history", []),
77
+ done=payload.get("done", False),
78
+ reward=payload.get("reward", 0.0),
79
+ metadata=obs_data.get("metadata", {}),
80
+ )
81
+
82
+ return StepResult(
83
+ observation=observation,
84
+ reward=payload.get("reward"),
85
+ done=payload.get("done", False),
86
+ )
87
+
88
+ def _parse_state(self, payload: Dict) -> State:
89
+ """
90
+ Parse server response into State object.
91
+
92
+ Args:
93
+ payload: JSON response from state request
94
+
95
+ Returns:
96
+ State object with episode_id and step_count
97
+ """
98
+ return State(
99
+ episode_id=payload.get("episode_id"),
100
+ step_count=payload.get("step_count", 0),
101
+ )
config/openenv.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ env:
2
+ name: "AutoMathReasoner"
3
+ author: "Meta Hackathon User"
4
+ description: "A self-improving math reasoning environment that dynamically generates tasks, tracking accuracy to provide curriculum learning for RL agents."
5
+ version: "1.0.0"
6
+
7
+ server:
8
+ host: "0.0.0.0"
9
+ port: 7860
10
+ workers: 4
11
+ module: "server.app:app"
12
+
13
+ features:
14
+ multi_reward: true
15
+ prevent_hacking: true
16
+ curriculum_scheduler: true
openenv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: AutoMathReasoner
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 7860
7
+
pyproject.toml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-AutoMathReasoner"
13
+ version = "0.1.0"
14
+ description = "Automathreasoner environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
+ # install from github
19
+ # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ "openenv-core[core]>=0.2.2",
21
+ # Environment-specific dependencies
22
+ # Add all dependencies needed for your environment here
23
+ # Examples:
24
+ # "numpy>=1.19.0",
25
+ # "torch>=2.0.0",
26
+ # "gymnasium>=0.29.0",
27
+ # "openspiel>=1.0.0",
28
+ # "smolagents>=1.22.0,<2",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest>=8.0.0",
34
+ "pytest-cov>=4.0.0",
35
+ ]
36
+
37
+ [project.scripts]
38
+ # Server entry point - enables running via: uv run --project . server
39
+ # or: python -m AutoMathReasoner.server.app
40
+ server = "AutoMathReasoner.server.app:main"
41
+
42
+ [tool.setuptools]
43
+ include-package-data = true
44
+ packages = ["AutoMathReasoner", "AutoMathReasoner.server", "AutoMathReasoner.env"]
45
+ package-dir = { "AutoMathReasoner" = ".", "AutoMathReasoner.server" = "server", "AutoMathReasoner.env" = "env" }
requirements.txt ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv export --no-hashes -o requirements.txt
3
+ -e .
4
+ aiofile==3.9.0
5
+ # via py-key-value-aio
6
+ annotated-doc==0.0.4
7
+ # via
8
+ # fastapi
9
+ # typer
10
+ annotated-types==0.7.0
11
+ # via pydantic
12
+ anyio==4.13.0
13
+ # via
14
+ # gradio
15
+ # httpx
16
+ # mcp
17
+ # openai
18
+ # py-key-value-aio
19
+ # sse-starlette
20
+ # starlette
21
+ # watchfiles
22
+ attrs==26.1.0
23
+ # via
24
+ # cyclopts
25
+ # jsonschema
26
+ # referencing
27
+ audioop-lts==0.2.2 ; python_full_version >= '3.13'
28
+ # via gradio
29
+ authlib==1.7.0
30
+ # via fastmcp
31
+ backports-tarfile==1.2.0 ; python_full_version < '3.12'
32
+ # via jaraco-context
33
+ beartype==0.22.9
34
+ # via py-key-value-aio
35
+ brotli==1.2.0
36
+ # via gradio
37
+ cachetools==7.0.6
38
+ # via py-key-value-aio
39
+ caio==0.9.25
40
+ # via aiofile
41
+ certifi==2026.4.22
42
+ # via
43
+ # httpcore
44
+ # httpx
45
+ # requests
46
+ cffi==2.0.0 ; platform_python_implementation != 'PyPy'
47
+ # via cryptography
48
+ charset-normalizer==3.4.7
49
+ # via requests
50
+ click==8.3.3
51
+ # via
52
+ # typer
53
+ # uvicorn
54
+ colorama==0.4.6 ; sys_platform == 'win32'
55
+ # via
56
+ # click
57
+ # tqdm
58
+ cryptography==46.0.7
59
+ # via
60
+ # authlib
61
+ # joserfc
62
+ # pyjwt
63
+ # secretstorage
64
+ cyclopts==4.11.0
65
+ # via fastmcp
66
+ distro==1.9.0
67
+ # via openai
68
+ dnspython==2.8.0
69
+ # via email-validator
70
+ docstring-parser==0.18.0
71
+ # via cyclopts
72
+ docutils==0.22.4
73
+ # via rich-rst
74
+ email-validator==2.3.0
75
+ # via pydantic
76
+ exceptiongroup==1.3.1
77
+ # via
78
+ # anyio
79
+ # fastmcp
80
+ fastapi==0.136.0
81
+ # via
82
+ # gradio
83
+ # openenv-core
84
+ fastmcp==3.2.4
85
+ # via openenv-core
86
+ filelock==3.29.0
87
+ # via huggingface-hub
88
+ fsspec==2026.3.0
89
+ # via
90
+ # gradio-client
91
+ # huggingface-hub
92
+ gradio==6.13.0
93
+ # via openenv-core
94
+ gradio-client==2.5.0
95
+ # via
96
+ # gradio
97
+ # hf-gradio
98
+ griffelib==2.0.2
99
+ # via fastmcp
100
+ groovy==0.1.2
101
+ # via gradio
102
+ h11==0.16.0
103
+ # via
104
+ # httpcore
105
+ # uvicorn
106
+ hf-gradio==0.4.1
107
+ # via gradio
108
+ hf-xet==1.4.3 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
109
+ # via huggingface-hub
110
+ httpcore==1.0.9
111
+ # via httpx
112
+ httpx==0.28.1
113
+ # via
114
+ # fastmcp
115
+ # gradio
116
+ # gradio-client
117
+ # huggingface-hub
118
+ # mcp
119
+ # openai
120
+ # openenv-core
121
+ # safehttpx
122
+ httpx-sse==0.4.3
123
+ # via mcp
124
+ huggingface-hub==1.11.0
125
+ # via
126
+ # gradio
127
+ # gradio-client
128
+ # openenv-core
129
+ idna==3.13
130
+ # via
131
+ # anyio
132
+ # email-validator
133
+ # httpx
134
+ # requests
135
+ importlib-metadata==8.7.1
136
+ # via
137
+ # keyring
138
+ # opentelemetry-api
139
+ jaraco-classes==3.4.0
140
+ # via keyring
141
+ jaraco-context==6.1.2
142
+ # via keyring
143
+ jaraco-functools==4.4.0
144
+ # via keyring
145
+ jeepney==0.9.0 ; sys_platform == 'linux'
146
+ # via
147
+ # keyring
148
+ # secretstorage
149
+ jinja2==3.1.6
150
+ # via gradio
151
+ jiter==0.14.0
152
+ # via openai
153
+ joserfc==1.6.4
154
+ # via authlib
155
+ jsonref==1.1.0
156
+ # via fastmcp
157
+ jsonschema==4.26.0
158
+ # via mcp
159
+ jsonschema-path==0.4.5
160
+ # via fastmcp
161
+ jsonschema-specifications==2025.9.1
162
+ # via jsonschema
163
+ keyring==25.7.0
164
+ # via py-key-value-aio
165
+ markdown-it-py==4.0.0
166
+ # via rich
167
+ markupsafe==3.0.3
168
+ # via
169
+ # gradio
170
+ # jinja2
171
+ mcp==1.27.0
172
+ # via fastmcp
173
+ mdurl==0.1.2
174
+ # via markdown-it-py
175
+ more-itertools==11.0.2
176
+ # via
177
+ # jaraco-classes
178
+ # jaraco-functools
179
+ numpy==2.2.6 ; python_full_version < '3.11'
180
+ # via
181
+ # gradio
182
+ # pandas
183
+ numpy==2.4.4 ; python_full_version >= '3.11'
184
+ # via
185
+ # gradio
186
+ # pandas
187
+ openai==2.32.0
188
+ # via openenv-core
189
+ openapi-pydantic==0.5.1
190
+ # via fastmcp
191
+ openenv-core==0.2.3
192
+ # via openenv-automathreasoner
193
+ opentelemetry-api==1.41.0
194
+ # via fastmcp
195
+ orjson==3.11.8
196
+ # via gradio
197
+ packaging==26.1
198
+ # via
199
+ # fastmcp
200
+ # gradio
201
+ # gradio-client
202
+ # huggingface-hub
203
+ pandas==2.3.3 ; python_full_version < '3.11'
204
+ # via gradio
205
+ pandas==3.0.2 ; python_full_version >= '3.11'
206
+ # via gradio
207
+ pathable==0.5.0
208
+ # via jsonschema-path
209
+ pillow==12.2.0
210
+ # via gradio
211
+ platformdirs==4.9.6
212
+ # via fastmcp
213
+ py-key-value-aio==0.4.4
214
+ # via fastmcp
215
+ pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
216
+ # via cffi
217
+ pydantic==2.13.3
218
+ # via
219
+ # fastapi
220
+ # fastmcp
221
+ # gradio
222
+ # mcp
223
+ # openai
224
+ # openapi-pydantic
225
+ # openenv-core
226
+ # pydantic-settings
227
+ pydantic-core==2.46.3
228
+ # via pydantic
229
+ pydantic-settings==2.14.0
230
+ # via mcp
231
+ pydub==0.25.1
232
+ # via gradio
233
+ pygments==2.20.0
234
+ # via rich
235
+ pyjwt==2.12.1
236
+ # via mcp
237
+ pyperclip==1.11.0
238
+ # via fastmcp
239
+ python-dateutil==2.9.0.post0
240
+ # via pandas
241
+ python-dotenv==1.2.2
242
+ # via
243
+ # fastmcp
244
+ # pydantic-settings
245
+ python-multipart==0.0.26
246
+ # via
247
+ # gradio
248
+ # mcp
249
+ pytz==2026.1.post1
250
+ # via
251
+ # gradio
252
+ # pandas
253
+ pywin32==311 ; sys_platform == 'win32'
254
+ # via mcp
255
+ pywin32-ctypes==0.2.3 ; sys_platform == 'win32'
256
+ # via keyring
257
+ pyyaml==6.0.3
258
+ # via
259
+ # fastmcp
260
+ # gradio
261
+ # huggingface-hub
262
+ # jsonschema-path
263
+ # openenv-core
264
+ referencing==0.37.0
265
+ # via
266
+ # jsonschema
267
+ # jsonschema-path
268
+ # jsonschema-specifications
269
+ requests==2.33.1
270
+ # via openenv-core
271
+ rich==15.0.0
272
+ # via
273
+ # cyclopts
274
+ # fastmcp
275
+ # openenv-core
276
+ # rich-rst
277
+ # typer
278
+ rich-rst==1.3.2
279
+ # via cyclopts
280
+ rpds-py==0.30.0
281
+ # via
282
+ # jsonschema
283
+ # referencing
284
+ safehttpx==0.1.7
285
+ # via gradio
286
+ secretstorage==3.5.0 ; sys_platform == 'linux'
287
+ # via keyring
288
+ semantic-version==2.10.0
289
+ # via gradio
290
+ shellingham==1.5.4
291
+ # via typer
292
+ six==1.17.0
293
+ # via python-dateutil
294
+ sniffio==1.3.1
295
+ # via openai
296
+ sse-starlette==3.3.4
297
+ # via mcp
298
+ starlette==1.0.0
299
+ # via
300
+ # fastapi
301
+ # gradio
302
+ # mcp
303
+ # sse-starlette
304
+ tomli==2.4.1
305
+ # via
306
+ # cyclopts
307
+ # openenv-core
308
+ tomli-w==1.2.0
309
+ # via openenv-core
310
+ tomlkit==0.14.0
311
+ # via gradio
312
+ tqdm==4.67.3
313
+ # via
314
+ # huggingface-hub
315
+ # openai
316
+ typer==0.24.2
317
+ # via
318
+ # gradio
319
+ # hf-gradio
320
+ # huggingface-hub
321
+ # openenv-core
322
+ typing-extensions==4.15.0
323
+ # via
324
+ # anyio
325
+ # cryptography
326
+ # cyclopts
327
+ # exceptiongroup
328
+ # fastapi
329
+ # gradio
330
+ # gradio-client
331
+ # huggingface-hub
332
+ # mcp
333
+ # openai
334
+ # opentelemetry-api
335
+ # py-key-value-aio
336
+ # pydantic
337
+ # pydantic-core
338
+ # pyjwt
339
+ # referencing
340
+ # starlette
341
+ # typing-inspection
342
+ # uvicorn
343
+ typing-inspection==0.4.2
344
+ # via
345
+ # fastapi
346
+ # mcp
347
+ # pydantic
348
+ # pydantic-settings
349
+ tzdata==2026.1 ; python_full_version < '3.11' or sys_platform == 'emscripten' or sys_platform == 'win32'
350
+ # via pandas
351
+ uncalled-for==0.3.1
352
+ # via fastmcp
353
+ urllib3==2.6.3
354
+ # via requests
355
+ uvicorn==0.46.0
356
+ # via
357
+ # fastmcp
358
+ # gradio
359
+ # mcp
360
+ # openenv-core
361
+ watchfiles==1.1.1
362
+ # via fastmcp
363
+ websockets==16.0
364
+ # via
365
+ # fastmcp
366
+ # openenv-core
367
+ zipp==3.23.1
368
+ # via importlib-metadata
server/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Automathreasoner environment server components."""
8
+
9
+ from AutoMathReasoner.env.environment import AutomathreasonerEnvironment
10
+
11
+ __all__ = ["AutomathreasonerEnvironment"]
server/app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the Automathreasoner Environment.
9
+
10
+ This module creates an HTTP server that exposes the AutomathreasonerEnvironment
11
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
12
+
13
+ Endpoints:
14
+ - POST /reset: Reset the environment
15
+ - POST /step: Execute an action
16
+ - GET /state: Get current environment state
17
+ - GET /schema: Get action/observation schemas
18
+ - WS /ws: WebSocket endpoint for persistent sessions
19
+
20
+ Usage:
21
+ # Development (with auto-reload):
22
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 7860
23
+
24
+ # Production:
25
+ uvicorn server.app:app --host 0.0.0.0 --port 7860 --workers 4
26
+
27
+ # Or run directly:
28
+ python -m server.app
29
+ """
30
+
31
+ try:
32
+ from openenv.core.env_server.http_server import create_app
33
+ except Exception as e: # pragma: no cover
34
+ raise ImportError(
35
+ "openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
36
+ ) from e
37
+
38
+ from AutoMathReasoner.env.models import AutomathreasonerAction, AutomathreasonerObservation
39
+ from AutoMathReasoner.env.environment import AutomathreasonerEnvironment
40
+
41
+
42
+ # Create the app with web interface and README integration
43
+ app = create_app(
44
+ AutomathreasonerEnvironment,
45
+ AutomathreasonerAction,
46
+ AutomathreasonerObservation,
47
+ env_name="AutoMathReasoner",
48
+ max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions
49
+ )
50
+
51
+
52
+ def main(host: str = "0.0.0.0", port: int = 7860):
53
+ """
54
+ Entry point for direct execution via uv run or python -m.
55
+
56
+ This function enables running the server without Docker:
57
+ uv run --project . server
58
+ uv run --project . server --port 8001
59
+ python -m AutoMathReasoner.server.app
60
+
61
+ Args:
62
+ host: Host address to bind to (default: "0.0.0.0")
63
+ port: Port number to listen on (default: 7860)
64
+
65
+ For production deployments, consider using uvicorn directly with
66
+ multiple workers:
67
+ uvicorn AutoMathReasoner.server.app:app --workers 4
68
+ """
69
+ import uvicorn
70
+
71
+ uvicorn.run(app, host=host, port=port)
72
+
73
+
74
+ if __name__ == "__main__":
75
+ import argparse
76
+
77
+ parser = argparse.ArgumentParser()
78
+ parser.add_argument("--port", type=int, default=7860)
79
+ args = parser.parse_args()
80
+ main(port=args.port)
tests/test_env.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4
+
5
+ from env.generator import TaskGenerationEngine
6
+ from env.verifier import VerifierSystem
7
+ from env.rewards import RewardSystem
8
+ from env.environment import AutomathreasonerEnvironment
9
+ from env.models import AutomathreasonerAction
10
+
11
+ def test_generator():
12
+ engine = TaskGenerationEngine()
13
+
14
+ # Test arithmetic
15
+ prob, diff, ans = engine.generate_arithmetic(complexity=1)
16
+ assert prob and ans
17
+
18
+ # Test overall generate task
19
+ task = engine.generate_task(target_difficulty_band=2.0)
20
+ assert "problem" in task
21
+ assert "solution" in task
22
+ assert "difficulty" in task
23
+
24
+ def test_verifier():
25
+ verifier = VerifierSystem()
26
+
27
+ # Exact match
28
+ assert verifier.check_exact_match("42", "42")
29
+ assert verifier.check_exact_match(" 42 ", "42")
30
+
31
+ # Numeric tolerance
32
+ assert verifier.check_numeric_tolerance("3.14159", "3.1415")
33
+ assert not verifier.check_numeric_tolerance("4.1415", "3.1415")
34
+
35
+ # Python execution
36
+ assert verifier.check_python_execution("2 + 2", "4")
37
+
38
+ # Full verification
39
+ c, q = verifier.verify("Because 2 + 2 is 4", "4", "4")
40
+ assert c == 1.0
41
+ assert q > 0.0 # Should have some mock reasoning score
42
+
43
+ def test_rewards():
44
+ reward_sys = RewardSystem(max_len=1000)
45
+ history = [{"final_answer": "42"}]
46
+
47
+ # Test diversity drop on repeat
48
+ d = reward_sys.compute_diversity("42", history)
49
+ assert d == -1.0
50
+
51
+ # Normal compute
52
+ r, comps = reward_sys.compute_reward(
53
+ correctness=1.0,
54
+ reasoning_quality=1.0,
55
+ action_str="step 1: do math. = 42",
56
+ final_answer="42",
57
+ history=[],
58
+ times_seen_problem=0
59
+ )
60
+ assert r > 0.0
61
+
62
+ def test_environment_step():
63
+ env = AutomathreasonerEnvironment()
64
+ obs = env.reset()
65
+
66
+ assert obs.problem_text != ""
67
+ assert obs.difficulty_level > 0
68
+ assert len(obs.history) == 0
69
+
70
+ # Create action where they just pass dummy stuff
71
+ action = AutomathreasonerAction(
72
+ reasoning="I am guessing the answer.",
73
+ final_answer="0"
74
+ )
75
+
76
+ obs_after = env.step(action)
77
+ assert obs_after.reward is not None
78
+ assert len(obs_after.history) == 1
79
+ assert "reward_components" in obs_after.metadata
train/colab_train.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Colab Training Script for AutoMathReasoner (Hugging Face Space + Free T4 GPU)
3
+
4
+ Instructions for Colab:
5
+ 1. Create a new Google Colab notebook (Free Tier: T4 GPU is supported by Unsloth)
6
+ 2. Run the following installation commands in your first cell:
7
+
8
+ !pip install unsloth "trl<0.9.0"
9
+ !pip install openenv-core pydantic httpx
10
+ !git clone <YOUR-GITHUB-REPO-URL>
11
+ !cd AutoMathReasoner && pip install -e .
12
+
13
+ 3. Run the following Python script in the next cell.
14
+ """
15
+
16
+ import collections
17
+ import random
18
+ from datasets import Dataset
19
+ import torch
20
+
21
+ # Unsloth & TRL
22
+ from unsloth import FastLanguageModel
23
+ from trl import GRPOConfig, GRPOTrainer
24
+
25
+ # AutoMathReasoner OpenEnv Client
26
+ import sys
27
+ sys.path.append("./AutoMathReasoner")
28
+ from AutoMathReasoner.client import AutomathreasonerEnv
29
+ from AutoMathReasoner.env.models import AutomathreasonerAction
30
+
31
+ # 1. Configuration
32
+ # Replace with your actual Hugging Face Space URL!
33
+ HF_SPACE_URL = "https://your-username-automathreasoner.hf.space"
34
+ env = AutomathreasonerEnv(url=HF_SPACE_URL)
35
+
36
+ max_seq_length = 1024 # Fits well within Colab T4 16GB VRAM limit
37
+ lora_rank = 16
38
+
39
+ # 2. Load Model via Unsloth (optimized for Free Colab VRAM)
40
+ print("Loading model via Unsloth...")
41
+ model, tokenizer = FastLanguageModel.from_pretrained(
42
+ model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit", # Pre-quantized 4bit for fast download
43
+ max_seq_length = max_seq_length,
44
+ dtype = None,
45
+ load_in_4bit = True,
46
+ )
47
+
48
+ # Enable LoRA fine-tuning
49
+ model = FastLanguageModel.get_peft_model(
50
+ model,
51
+ r = lora_rank,
52
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
53
+ "gate_proj", "up_proj", "down_proj"],
54
+ lora_alpha = lora_rank,
55
+ use_gradient_checkpointing = "unsloth", # Crucial for fitting into T4
56
+ )
57
+
58
+ # 3. Prepare Dummy Prompts from the Remote Environment
59
+ print("Gathering initial prompts from HF Space environment...")
60
+ initial_prompts = []
61
+ for _ in range(30):
62
+ # This fires an HTTP request to your Hugging Face Space
63
+ obs = env.reset()
64
+ initial_prompts.append({"prompt": obs.problem_text})
65
+
66
+ dataset = Dataset.from_list(initial_prompts)
67
+
68
+ # 4. Define Reward Function for TRL
69
+ def compute_rewards(prompts, completions, **kwargs):
70
+ """
71
+ Interfaces with the OpenEnv running on Hugging Face Spaces.
72
+ Extracts the generation, passes it via HTTP to the env, and yields the dense reward.
73
+ """
74
+ rewards = []
75
+ parsed_actions = []
76
+ prompt_answers = collections.defaultdict(list)
77
+
78
+ # Track completion variants
79
+ for prompt, completion in zip(prompts, completions):
80
+ try:
81
+ parts = completion.split("Answer:")
82
+ reasoning = parts[0].strip()
83
+ answer = parts[1].strip() if len(parts) > 1 else ""
84
+ except Exception:
85
+ reasoning = completion
86
+ answer = ""
87
+
88
+ parsed_actions.append((prompt, completion, reasoning, answer))
89
+ prompt_answers[prompt].append(answer)
90
+
91
+ majority_answers = {}
92
+ for p, ans_list in prompt_answers.items():
93
+ if ans_list:
94
+ majority_answers[p] = collections.Counter(ans_list).most_common(1)[0][0]
95
+
96
+ for p, c, r, a in parsed_actions:
97
+ action = AutomathreasonerAction(reasoning=r, final_answer=a)
98
+
99
+ # In a real environment mapping, we would initialize the episode with the specific prompt.
100
+ # But for REST API environments, we simply reset and forcefully simulate.
101
+ obs = env.reset()
102
+
103
+ # Step through HTTP API
104
+ step_obs = env.step(action)
105
+ r_total = step_obs.reward
106
+
107
+ # Self-consistency matching bonus
108
+ majority = majority_answers.get(p, "")
109
+ if (a == majority) and len(a) > 0:
110
+ r_total += 0.2
111
+
112
+ rewards.append(r_total)
113
+
114
+ return rewards
115
+
116
+ # 5. Execute Training
117
+ training_args = GRPOConfig(
118
+ output_dir="colab_outputs",
119
+ learning_rate=2e-5,
120
+ per_device_train_batch_size=1, # 1 for Colab GPUs to prevent OOM
121
+ gradient_accumulation_steps=4,
122
+ max_prompt_length=128,
123
+ max_completion_length=256,
124
+ num_generations=4, # K=4 (Reduced from 8 for Colab T4 Memory limitations)
125
+ max_steps=150,
126
+ logging_steps=10,
127
+ optim="adamw_8bit", # 8-bit optimizer saves VRAM
128
+ )
129
+
130
+ trainer = GRPOTrainer(
131
+ model=model,
132
+ reward_funcs=[compute_rewards],
133
+ args=training_args,
134
+ train_dataset=dataset,
135
+ )
136
+
137
+ print("Starting GRPO Training in Colab using Remote HF Environment...")
138
+ # Will show wandb/tensorboard logging so you can prove "it is actually learning"
139
+ trainer.train()
140
+
141
+ # 6. Push to Hugging Face
142
+ # Optional: save locally or push to Hub after it learns
143
+ # model.push_to_hub("your-name/AutoMathReasoner-Trained")
train/sft_warm_start.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from trl import SFTTrainer, SFTConfig
3
+ from unsloth import FastLanguageModel
4
+
5
+ def main():
6
+ max_seq_length = 1024
7
+
8
+ # Load model and tokenizer
9
+ model, tokenizer = FastLanguageModel.from_pretrained(
10
+ model_name = "llama-3-8b-instruct",
11
+ max_seq_length = max_seq_length,
12
+ dtype = None,
13
+ load_in_4bit = True,
14
+ )
15
+
16
+ # We use a subset of GSM8K style data to warm start the reasoning format
17
+ # In practice, this would load a custom generated dataset locally
18
+ try:
19
+ dataset = load_dataset("gsm8k", "main", split="train[:5%]")
20
+ except Exception:
21
+ # Fallback dummy dataset
22
+ dataset = load_dataset("json", data_files={"train": ["dummy.json"]}, split="train")
23
+
24
+ def formatting_prompts_func(examples):
25
+ texts = []
26
+ for q, a in zip(examples['question'], examples['answer']):
27
+ # Assuming 'answer' has reasoning and then '#### answer'
28
+ parts = a.split("####")
29
+ reasoning = parts[0].strip()
30
+ final_answer = parts[1].strip() if len(parts) > 1 else ""
31
+
32
+ text = f"Problem: {q}\nReasoning: {reasoning}\nAnswer: {final_answer}"
33
+ texts.append(text)
34
+ return { "text" : texts }
35
+
36
+ dataset = dataset.map(formatting_prompts_func, batched = True)
37
+
38
+ training_args = SFTConfig(
39
+ output_dir="sft_outputs",
40
+ dataset_text_field="text",
41
+ max_seq_length=max_seq_length,
42
+ per_device_train_batch_size=2,
43
+ max_steps=100,
44
+ learning_rate=2e-5,
45
+ )
46
+
47
+ trainer = SFTTrainer(
48
+ model=model,
49
+ train_dataset=dataset,
50
+ args=training_args,
51
+ )
52
+
53
+ print("Starting SFT Warm-Start...")
54
+ trainer.train()
55
+
56
+ if __name__ == "__main__":
57
+ main()
train/train_grpo.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import collections
3
+ import torch
4
+ import numpy as np
5
+ from datasets import Dataset
6
+ from trl import GRPOTrainer, GRPOConfig
7
+ from unsloth import FastLanguageModel
8
+
9
+ import sys
10
+ import os
11
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
+
13
+ from env.environment import AutomathreasonerEnvironment
14
+ from env.models import AutomathreasonerAction
15
+
16
+ class ReplayBuffer:
17
+ def __init__(self):
18
+ self.ladder_buffer = [] # A. LADDER-STYLE self-bootstrapping buffer
19
+ self.failed = [] # F. HARD NEGATIVE MINING buffer
20
+ self.all_history = []
21
+
22
+ def add_ladder(self, item):
23
+ """
24
+ [PAPER TRACEABILITY: LADDER-Style Self-Bootstrapping]
25
+ Stores only high-quality trajectories.
26
+ """
27
+ self.ladder_buffer.append(item)
28
+ # Keep top 20% effectively by hard capping and sorting if applicable
29
+ # Simplistic version: Just keep recent highest
30
+ if len(self.ladder_buffer) > 200:
31
+ self.ladder_buffer.sort(key=lambda x: x['reward'], reverse=True)
32
+ self.ladder_buffer = self.ladder_buffer[:100]
33
+
34
+ def add(self, problem, best_solution, failed_attempts, reward=0.0):
35
+ item = {
36
+ "prompt": problem,
37
+ "best_solution": best_solution,
38
+ "failed_attempts": failed_attempts,
39
+ "reward": reward
40
+ }
41
+ self.all_history.append(item)
42
+
43
+ # F. HARD NEGATIVE MINING
44
+ # Prioritize tracking failed problems
45
+ if failed_attempts:
46
+ # We explicitly track failures to reintroduce them
47
+ self.failed.append(item)
48
+ if len(self.failed) > 200:
49
+ self.failed.pop(0)
50
+
51
+ def sample(self, batch_size) -> list:
52
+ """
53
+ [PAPER TRACEABILITY: Hard Negative Mining]
54
+ Samples from Ladder/High-quality, Failed, and Random.
55
+ """
56
+ if len(self.all_history) < batch_size:
57
+ return self.all_history
58
+
59
+ n_ladder = int(batch_size * 0.5)
60
+ n_failed = int(batch_size * 0.3)
61
+ n_random = batch_size - n_ladder - n_failed
62
+
63
+ batch = []
64
+ batch.extend(random.choices(self.ladder_buffer if self.ladder_buffer else self.all_history, k=n_ladder))
65
+ batch.extend(random.choices(self.failed if self.failed else self.all_history, k=n_failed))
66
+ batch.extend(random.choices(self.all_history, k=n_random))
67
+
68
+ return batch
69
+
70
+ def main():
71
+ max_seq_length = 1024
72
+ # Load model via Unsloth
73
+ model, tokenizer = FastLanguageModel.from_pretrained(
74
+ model_name = "llama-3-8b-instruct",
75
+ max_seq_length = max_seq_length,
76
+ dtype = None,
77
+ load_in_4bit = True,
78
+ )
79
+
80
+ env = AutomathreasonerEnvironment()
81
+ replay_buffer = ReplayBuffer()
82
+
83
+ # Generate some initial experiences
84
+ initial_prompts = []
85
+ for _ in range(50):
86
+ obs = env.reset()
87
+ initial_prompts.append({"prompt": obs.problem_text})
88
+
89
+ dataset = Dataset.from_list(initial_prompts)
90
+
91
+ def compute_rewards(prompts, completions, **kwargs):
92
+ """
93
+ [PAPER TRACEABILITY: GRPO (Group-Relative Policy Optimization)]
94
+ D. GROUP-RELATIVE TRAINING
95
+ TRL GRPOTrainer automatically handles the relative optimization aspect:
96
+ log π(best) − log π(worst) by using the normalized rewards returned here.
97
+ """
98
+ rewards = []
99
+
100
+ # C. SELF-CONSISTENCY SAMPLING
101
+ # We group generated outputs by prompt to find the majority answer
102
+ # TRL provides completions aligned with prompts. Usually completions are batched by K per prompt.
103
+ prompt_answers = collections.defaultdict(list)
104
+
105
+ parsed_actions = []
106
+ for prompt, completion in zip(prompts, completions):
107
+ try:
108
+ parts = completion.split("Answer:")
109
+ reasoning = parts[0].strip()
110
+ answer = parts[1].strip() if len(parts) > 1 else ""
111
+ except Exception:
112
+ reasoning = completion
113
+ answer = ""
114
+
115
+ parsed_actions.append((prompt, completion, reasoning, answer))
116
+ prompt_answers[prompt].append(answer)
117
+
118
+ majority_answers = {}
119
+ for p, ans_list in prompt_answers.items():
120
+ if ans_list:
121
+ majority_answers[p] = collections.Counter(ans_list).most_common(1)[0][0]
122
+
123
+ for p, c, r, a in parsed_actions:
124
+ action = AutomathreasonerAction(reasoning=r, final_answer=a)
125
+
126
+ # Simulate step
127
+ env.reset()
128
+ env.current_problem = p
129
+ step_obs = env.step(action)
130
+ r_total = step_obs.reward
131
+
132
+ # [PAPER TRACEABILITY: Self-Consistency Sampling]
133
+ # Verify majority match
134
+ majority = majority_answers.get(p, "")
135
+ is_majority = (a == majority) and len(a) > 0
136
+ if is_majority:
137
+ r_total += 0.2 # Bonus reward for mapping to majority
138
+
139
+ rewards.append(r_total)
140
+
141
+ is_correct = step_obs.metadata.get('is_correct', False)
142
+ q_score = step_obs.metadata.get('reward_components', {}).get('Q_reasoning', 0.0)
143
+
144
+ # B. ReST-STYLE FILTERING (SELF-TRAINING)
145
+ # Filter samples where correctness = 1 AND reasoning quality > 0.6
146
+ # [PAPER TRACEABILITY: ReST (Rest-Style Filtering)]
147
+ if is_correct and q_score > 0.6:
148
+ # Store as High Quality trajectory in Ladder buffer
149
+ ladder_item = {
150
+ "prompt": p,
151
+ "best_solution": c,
152
+ "failed_attempts": [],
153
+ "reward": r_total
154
+ }
155
+ replay_buffer.add_ladder(ladder_item)
156
+
157
+ # Standard buffer mapping
158
+ if is_correct:
159
+ replay_buffer.add(p, c, [], reward=r_total)
160
+ else:
161
+ replay_buffer.add(p, "", [c], reward=r_total)
162
+
163
+ return rewards
164
+
165
+ training_args = GRPOConfig(
166
+ output_dir="outputs",
167
+ learning_rate=1e-5,
168
+ per_device_train_batch_size=1,
169
+ gradient_accumulation_steps=4,
170
+ max_prompt_length=128,
171
+ max_completion_length=256,
172
+ num_generations=8, # K=8 outputs per problem (Allows Self-consistency majority to work)
173
+ max_steps=100,
174
+ logging_steps=10,
175
+ )
176
+
177
+ trainer = GRPOTrainer(
178
+ model=model,
179
+ reward_funcs=[compute_rewards],
180
+ args=training_args,
181
+ train_dataset=dataset,
182
+ )
183
+
184
+ print("Starting GRPO Training with Research-Aligned Modules...")
185
+ trainer.train()
186
+
187
+ if __name__ == "__main__":
188
+ main()
uv.lock ADDED
The diff for this file is too large to render. See raw diff