Spaces:
Runtime error
Runtime error
Commit ·
1af4cba
1
Parent(s): 5edec00
HF Spaces GPU training pipeline
Browse files- .dockerignore +28 -0
- Dockerfile +22 -11
- RULES.md +79 -0
- push_to_hub.py +2 -1
- pyproject.toml +5 -3
- run_hf_training.sh +169 -23
- salespath_env/server/rules.py +13 -4
- scripts/run_training.sh +4 -0
- training/eval_baseline_vs_trained.py +251 -0
- training/grpo_train.py +4 -3
- training/hf_keepalive_app.py +131 -0
- training/preflight_check.py +75 -0
.dockerignore
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Git
|
| 2 |
+
.git/
|
| 3 |
+
.gitignore
|
| 4 |
+
|
| 5 |
+
# Python cache
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.pyc
|
| 8 |
+
*.pyo
|
| 9 |
+
*.egg-info/
|
| 10 |
+
|
| 11 |
+
# Training outputs (too large for Docker context)
|
| 12 |
+
salespath_out/
|
| 13 |
+
|
| 14 |
+
# IDE
|
| 15 |
+
.idea/
|
| 16 |
+
.vscode/
|
| 17 |
+
*.swp
|
| 18 |
+
*.swo
|
| 19 |
+
|
| 20 |
+
# OS
|
| 21 |
+
.DS_Store
|
| 22 |
+
Thumbs.db
|
| 23 |
+
|
| 24 |
+
# Notebook checkpoints
|
| 25 |
+
.ipynb_checkpoints/
|
| 26 |
+
|
| 27 |
+
# HF Space specific
|
| 28 |
+
push_to_hub.py
|
Dockerfile
CHANGED
|
@@ -1,32 +1,43 @@
|
|
| 1 |
-
FROM
|
| 2 |
|
| 3 |
# HuggingFace Spaces runs on port 7860 by default
|
| 4 |
ENV PORT=7860
|
| 5 |
ENV PYTHONUNBUFFERED=1
|
| 6 |
ENV PYTHONDONTWRITEBYTECODE=1
|
|
|
|
| 7 |
|
| 8 |
WORKDIR /app
|
| 9 |
|
| 10 |
# Install system dependencies
|
| 11 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 12 |
-
curl \
|
|
|
|
| 13 |
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
|
| 15 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
COPY requirements.txt .
|
| 17 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 18 |
|
| 19 |
# Copy the salespath_env package and training scripts
|
| 20 |
COPY salespath_env/ ./salespath_env/
|
| 21 |
COPY training/ ./training/
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
# Copy and set permissions for the training
|
| 24 |
-
COPY run_hf_training.sh ./
|
| 25 |
-
RUN sed -i 's/\r$//' ./
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
|
| 29 |
-
CMD curl -f http://localhost:${PORT}/health || exit 1
|
| 30 |
|
| 31 |
-
|
| 32 |
-
CMD ["sh", "-c", "uvicorn salespath_env.server.app:app --host 0.0.0.0 --port ${PORT}"]
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
|
| 2 |
|
| 3 |
# HuggingFace Spaces runs on port 7860 by default
|
| 4 |
ENV PORT=7860
|
| 5 |
ENV PYTHONUNBUFFERED=1
|
| 6 |
ENV PYTHONDONTWRITEBYTECODE=1
|
| 7 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 8 |
|
| 9 |
WORKDIR /app
|
| 10 |
|
| 11 |
# Install system dependencies
|
| 12 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 13 |
+
python3 python3-pip python3-dev git curl \
|
| 14 |
+
&& ln -sf /usr/bin/python3 /usr/bin/python \
|
| 15 |
&& rm -rf /var/lib/apt/lists/*
|
| 16 |
|
| 17 |
+
# Pin NumPy to avoid breakage
|
| 18 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 19 |
+
pip install "numpy<2"
|
| 20 |
+
|
| 21 |
+
# Install PyTorch with CUDA 12.1 support
|
| 22 |
+
RUN pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu121
|
| 23 |
+
|
| 24 |
+
# Copy and install Python dependencies
|
| 25 |
COPY requirements.txt .
|
| 26 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 27 |
|
| 28 |
# Copy the salespath_env package and training scripts
|
| 29 |
COPY salespath_env/ ./salespath_env/
|
| 30 |
COPY training/ ./training/
|
| 31 |
+
COPY scripts/ ./scripts/
|
| 32 |
+
|
| 33 |
+
# Install the salespath_env package
|
| 34 |
+
RUN pip install -e . --no-deps || true
|
| 35 |
|
| 36 |
+
# Copy and set permissions for the training entrypoint
|
| 37 |
+
COPY run_hf_training.sh ./run_training.sh
|
| 38 |
+
RUN sed -i 's/\r$//' ./run_training.sh && chmod +x ./run_training.sh
|
| 39 |
|
| 40 |
+
# NO HEALTHCHECK — the entrypoint script starts a background health server
|
| 41 |
+
# to keep HF Spaces alive during long training runs
|
|
|
|
| 42 |
|
| 43 |
+
CMD ["/bin/bash", "./run_training.sh"]
|
|
|
RULES.md
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SalesPath — Business Rules (R01–R09)
|
| 2 |
+
|
| 3 |
+
The environment enforces these 9 business rules at every step.
|
| 4 |
+
Three violations → episode terminates with a heavy penalty.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## R01 — Qualify Before Present
|
| 9 |
+
|
| 10 |
+
> **Must QUALIFY before PRESENT**
|
| 11 |
+
|
| 12 |
+
The agent cannot pitch the product until it has asked qualifying questions about the prospect's needs, budget, and situation.
|
| 13 |
+
|
| 14 |
+
## R02 — Demo Before Negotiate
|
| 15 |
+
|
| 16 |
+
> **Must OFFER_DEMO before NEGOTIATE**
|
| 17 |
+
|
| 18 |
+
No discount or price negotiation is allowed unless a product demo has been offered and scheduled.
|
| 19 |
+
|
| 20 |
+
## R03 — Budget Known Before Negotiate
|
| 21 |
+
|
| 22 |
+
> **Budget must be known before NEGOTIATE**
|
| 23 |
+
|
| 24 |
+
The prospect's budget must be revealed (via QUALIFY action) before the agent can enter negotiations.
|
| 25 |
+
|
| 26 |
+
## R04 — Discount After Objections
|
| 27 |
+
|
| 28 |
+
> **Discount only after 2 objections handled**
|
| 29 |
+
|
| 30 |
+
If the agent mentions a discount during NEGOTIATE, at least 2 prospect objections must have been successfully handled first.
|
| 31 |
+
|
| 32 |
+
## R05 — No Repeat Action
|
| 33 |
+
|
| 34 |
+
> **Cannot repeat same action consecutively**
|
| 35 |
+
|
| 36 |
+
The agent cannot use the same action type twice in a row. A QUALIFY cannot follow a QUALIFY, a PRESENT cannot follow a PRESENT, etc.
|
| 37 |
+
|
| 38 |
+
## R06 — First Action Must Be PROSPECT
|
| 39 |
+
|
| 40 |
+
> **First action must always be PROSPECT**
|
| 41 |
+
|
| 42 |
+
Every episode must begin with the PROSPECT action. Any other first action is invalid.
|
| 43 |
+
|
| 44 |
+
## R07 — Follow-Up Only After Silence
|
| 45 |
+
|
| 46 |
+
> **FOLLOW_UP only after prospect goes silent**
|
| 47 |
+
|
| 48 |
+
FOLLOW_UP is only valid when the prospect has disengaged (returned a `silence` response). If the prospect just responded with actual content, FOLLOW_UP is a violation.
|
| 49 |
+
|
| 50 |
+
## R08 — Disqualify Logic
|
| 51 |
+
|
| 52 |
+
> **DISQUALIFY only if prospect is genuinely unqualified**
|
| 53 |
+
|
| 54 |
+
DISQUALIFY is a violation if the prospect is actually closable (true budget ≥ close threshold AND decision maker is present). Use it only when the deal is truly unwinnable.
|
| 55 |
+
|
| 56 |
+
## R09 — Close Requires Demo
|
| 57 |
+
|
| 58 |
+
> **Must OFFER_DEMO before CLOSE (difficulty 2+)**
|
| 59 |
+
|
| 60 |
+
On difficulty 2 and above, the agent must have completed OFFER_DEMO before attempting to CLOSE the deal.
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
|
| 64 |
+
## How Rules Are Enforced
|
| 65 |
+
|
| 66 |
+
Rules are checked **before** the prospect responds to an action. Violations are accumulated in `constraints_violated` and returned in the observation:
|
| 67 |
+
|
| 68 |
+
```python
|
| 69 |
+
# Observation schema
|
| 70 |
+
{
|
| 71 |
+
"constraints_violated": ["R01", "R05"], # New violations this turn
|
| 72 |
+
"steps_completed": ["PROSPECT", "QUALIFY"],
|
| 73 |
+
...
|
| 74 |
+
}
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
When `len(constraints_violated) >= 3`, the episode terminates with:
|
| 78 |
+
- `r_outcome = -0.5` (terminal penalty)
|
| 79 |
+
- `r_compliance = -0.2 × violations` (per-turn penalty)
|
push_to_hub.py
CHANGED
|
@@ -13,7 +13,8 @@ IGNORE_PATTERNS = [
|
|
| 13 |
"*.egg-info/**",
|
| 14 |
"push_to_hub.py",
|
| 15 |
"salespath_env/server/Dockerfile", # root Dockerfile is used instead
|
| 16 |
-
|
|
|
|
| 17 |
]
|
| 18 |
|
| 19 |
def main():
|
|
|
|
| 13 |
"*.egg-info/**",
|
| 14 |
"push_to_hub.py",
|
| 15 |
"salespath_env/server/Dockerfile", # root Dockerfile is used instead
|
| 16 |
+
# Training scripts ARE included for HF Spaces GPU training
|
| 17 |
+
# "training/**",
|
| 18 |
]
|
| 19 |
|
| 20 |
def main():
|
pyproject.toml
CHANGED
|
@@ -11,10 +11,12 @@ dependencies = [
|
|
| 11 |
"fastapi",
|
| 12 |
"uvicorn",
|
| 13 |
"pydantic>=2.0",
|
| 14 |
-
"trl>=0.
|
| 15 |
-
"
|
| 16 |
"torch",
|
| 17 |
-
"transformers",
|
|
|
|
|
|
|
| 18 |
]
|
| 19 |
|
| 20 |
[tool.setuptools.packages.find]
|
|
|
|
| 11 |
"fastapi",
|
| 12 |
"uvicorn",
|
| 13 |
"pydantic>=2.0",
|
| 14 |
+
"trl>=0.11.0",
|
| 15 |
+
"peft>=0.11.0",
|
| 16 |
"torch",
|
| 17 |
+
"transformers>=4.44.0",
|
| 18 |
+
"accelerate>=0.33.0",
|
| 19 |
+
"bitsandbytes>=0.43.0",
|
| 20 |
]
|
| 21 |
|
| 22 |
[tool.setuptools.packages.find]
|
run_hf_training.sh
CHANGED
|
@@ -1,28 +1,174 @@
|
|
| 1 |
-
#!/bin/bash
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
#
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
echo "
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
--mode grpo \
|
| 14 |
-
--env-url http://127.0.0.1:
|
| 15 |
-
--model-name
|
| 16 |
-
--grpo-steps
|
| 17 |
--grpo-dataset-size 128 \
|
| 18 |
-
--num-generations
|
| 19 |
-
--max-completion-length
|
| 20 |
-
--per-device-train-batch-size
|
| 21 |
-
--gradient-accumulation-steps
|
| 22 |
-
--output-dir
|
| 23 |
-
--logging-steps
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
cd /app
|
| 4 |
|
| 5 |
+
# ====================================================================
|
| 6 |
+
# SalesPath Training Pipeline — Configuration
|
| 7 |
+
# ====================================================================
|
| 8 |
+
# Override any of these via HF Space "Variables and secrets" settings.
|
| 9 |
+
#
|
| 10 |
+
# GPU VRAM Guide (for GRPO with LoRA 4-bit):
|
| 11 |
+
# T4 (16GB) → 0.5B-3B models → num_generations=2-4, batch=1-2
|
| 12 |
+
# L4 (24GB) → 7B models → num_generations=2, batch=1
|
| 13 |
+
# A10G (24GB)→ 7B models → num_generations=4, batch=2
|
| 14 |
+
# A100 (40GB)→ 14B-32B models → num_generations=4, batch=4
|
| 15 |
+
#
|
| 16 |
+
# Example for 7B on L4:
|
| 17 |
+
# MODEL_NAME=Qwen/Qwen2.5-7B-Instruct
|
| 18 |
+
# NUM_GENERATIONS=2
|
| 19 |
+
# PER_DEVICE_BATCH=1
|
| 20 |
+
# MAX_SEQ_LEN=512
|
| 21 |
+
# ====================================================================
|
| 22 |
|
| 23 |
+
export PORT="${PORT:-7860}"
|
| 24 |
+
export HF_MODEL_REPO="${HF_MODEL_REPO:-Imsachin010/salespath-qwen25-0.5b}"
|
| 25 |
+
export MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 26 |
+
export OUTPUT_DIR="${OUTPUT_DIR:-/app/salespath_out}"
|
| 27 |
+
export GRPO_STEPS="${GRPO_STEPS:-100}"
|
| 28 |
+
export NUM_GENERATIONS="${NUM_GENERATIONS:-4}"
|
| 29 |
+
export PER_DEVICE_BATCH="${PER_DEVICE_BATCH:-2}"
|
| 30 |
+
export GRAD_ACCUM="${GRAD_ACCUM:-8}"
|
| 31 |
+
export MAX_SEQ_LEN="${MAX_SEQ_LEN:-1024}"
|
| 32 |
+
export LOGGING_STEPS="${LOGGING_STEPS:-10}"
|
| 33 |
+
export EVAL_EPISODES="${EVAL_EPISODES:-4}"
|
| 34 |
|
| 35 |
+
echo "========================================"
|
| 36 |
+
echo " SalesPath Training Pipeline"
|
| 37 |
+
echo " Model: $MODEL_NAME"
|
| 38 |
+
echo " HF Repo: $HF_MODEL_REPO"
|
| 39 |
+
echo " GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')"
|
| 40 |
+
echo " Port: $PORT"
|
| 41 |
+
echo "========================================"
|
| 42 |
+
|
| 43 |
+
# ------------------------------------------------------------------
|
| 44 |
+
# 1. Background health server (keeps HF Spaces happy during training)
|
| 45 |
+
# ------------------------------------------------------------------
|
| 46 |
+
echo "Starting background health server on port $PORT..."
|
| 47 |
+
python3 -c "
|
| 48 |
+
import http.server, socketserver, os
|
| 49 |
+
PORT = int(os.environ.get('PORT', 7860))
|
| 50 |
+
class H(http.server.SimpleHTTPRequestHandler):
|
| 51 |
+
def do_GET(self):
|
| 52 |
+
if self.path == '/health':
|
| 53 |
+
self.send_response(200); self.end_headers(); self.wfile.write(b'OK')
|
| 54 |
+
else:
|
| 55 |
+
self.send_response(404); self.end_headers()
|
| 56 |
+
def log_message(self, *a): pass
|
| 57 |
+
with socketserver.TCPServer(('', PORT), H) as httpd:
|
| 58 |
+
httpd.serve_forever()
|
| 59 |
+
" &
|
| 60 |
+
HEALTH_PID=$!
|
| 61 |
+
echo "Health server PID: $HEALTH_PID"
|
| 62 |
+
sleep 2
|
| 63 |
+
|
| 64 |
+
# ------------------------------------------------------------------
|
| 65 |
+
# 2. HF login (if token is set as secret)
|
| 66 |
+
# ------------------------------------------------------------------
|
| 67 |
+
if [[ -n "${HF_TOKEN:-}" ]]; then
|
| 68 |
+
echo "Logging in to Hugging Face Hub..."
|
| 69 |
+
huggingface-cli login --token "$HF_TOKEN" --add-to-git-credential
|
| 70 |
+
fi
|
| 71 |
+
|
| 72 |
+
# ------------------------------------------------------------------
|
| 73 |
+
# 3. Pre-flight check
|
| 74 |
+
# ------------------------------------------------------------------
|
| 75 |
+
echo "=== Pre-flight check ==="
|
| 76 |
+
python training/preflight_check.py || echo "Pre-flight warning (non-fatal)"
|
| 77 |
+
|
| 78 |
+
# ------------------------------------------------------------------
|
| 79 |
+
# 4. Start environment server (needed for rollout-based training)
|
| 80 |
+
# ------------------------------------------------------------------
|
| 81 |
+
echo "Starting SalesPath environment server on port 8000..."
|
| 82 |
+
uvicorn salespath_env.server.app:app --host 0.0.0.0 --port 8000 &
|
| 83 |
+
ENV_PID=$!
|
| 84 |
+
sleep 3
|
| 85 |
+
|
| 86 |
+
# Verify environment server is healthy
|
| 87 |
+
python3 -c "
|
| 88 |
+
import httpx, time
|
| 89 |
+
for i in range(10):
|
| 90 |
+
try:
|
| 91 |
+
r = httpx.get('http://127.0.0.1:8000/health', timeout=5)
|
| 92 |
+
if r.status_code == 200: print('Environment server OK'); break
|
| 93 |
+
except: pass
|
| 94 |
+
time.sleep(2)
|
| 95 |
+
"
|
| 96 |
+
|
| 97 |
+
# ------------------------------------------------------------------
|
| 98 |
+
# 5. GRPO Training
|
| 99 |
+
# ------------------------------------------------------------------
|
| 100 |
+
echo ""
|
| 101 |
+
echo "=== GRPO Training with $MODEL_NAME ==="
|
| 102 |
+
echo "Steps: $GRPO_STEPS | Generations: $NUM_GENERATIONS | Batch: $PER_DEVICE_BATCH"
|
| 103 |
+
|
| 104 |
+
PYTORCH_ALLOC_CONF=expandable_segments:True \
|
| 105 |
+
python -u -m training.grpo_train \
|
| 106 |
--mode grpo \
|
| 107 |
+
--env-url http://127.0.0.1:8000 \
|
| 108 |
+
--model-name "$MODEL_NAME" \
|
| 109 |
+
--grpo-steps "$GRPO_STEPS" \
|
| 110 |
--grpo-dataset-size 128 \
|
| 111 |
+
--num-generations "$NUM_GENERATIONS" \
|
| 112 |
+
--max-completion-length "$MAX_SEQ_LEN" \
|
| 113 |
+
--per-device-train-batch-size "$PER_DEVICE_BATCH" \
|
| 114 |
+
--gradient-accumulation-steps "$GRAD_ACCUM" \
|
| 115 |
+
--output-dir "$OUTPUT_DIR" \
|
| 116 |
+
--logging-steps "$LOGGING_STEPS"
|
| 117 |
+
|
| 118 |
+
TRAINING_EXIT=$?
|
| 119 |
+
echo "GRPO training exit code: $TRAINING_EXIT"
|
| 120 |
+
|
| 121 |
+
# ------------------------------------------------------------------
|
| 122 |
+
# 6. Evaluation: baseline vs trained
|
| 123 |
+
# ------------------------------------------------------------------
|
| 124 |
+
if [[ $TRAINING_EXIT -eq 0 ]]; then
|
| 125 |
+
echo ""
|
| 126 |
+
echo "=== Evaluation: Baseline vs Trained ==="
|
| 127 |
+
python training/eval_baseline_vs_trained.py \
|
| 128 |
+
--base "$MODEL_NAME" \
|
| 129 |
+
--trained "$OUTPUT_DIR/grpo_final" \
|
| 130 |
+
--env-url http://127.0.0.1:8000 \
|
| 131 |
+
--episodes-per-level "$EVAL_EPISODES" \
|
| 132 |
+
--output "$OUTPUT_DIR/eval_results.json"
|
| 133 |
+
|
| 134 |
+
echo ""
|
| 135 |
+
echo "=== Generating reward plots ==="
|
| 136 |
+
python training/plot_rewards.py \
|
| 137 |
+
--input "$OUTPUT_DIR/reward_history.txt" \
|
| 138 |
+
--output "$OUTPUT_DIR/reward_graph.png" || echo "Plotting skipped"
|
| 139 |
+
fi
|
| 140 |
+
|
| 141 |
+
# ------------------------------------------------------------------
|
| 142 |
+
# 7. Upload artifacts to Hugging Face Hub
|
| 143 |
+
# ------------------------------------------------------------------
|
| 144 |
+
if [[ $TRAINING_EXIT -eq 0 && -n "${HF_TOKEN:-}" ]]; then
|
| 145 |
+
echo ""
|
| 146 |
+
echo "=== Uploading to $HF_MODEL_REPO ==="
|
| 147 |
+
|
| 148 |
+
# Upload GRPO adapters
|
| 149 |
+
huggingface-cli upload "$HF_MODEL_REPO" "$OUTPUT_DIR/grpo_final" . --repo-type model || true
|
| 150 |
+
|
| 151 |
+
# Upload logs and plots
|
| 152 |
+
for f in reward_history.txt eval_results.json reward_graph.png; do
|
| 153 |
+
if [[ -f "$OUTPUT_DIR/$f" ]]; then
|
| 154 |
+
huggingface-cli upload "$HF_MODEL_REPO" "$OUTPUT_DIR/$f" "training_artifacts/$f" --repo-type model || true
|
| 155 |
+
fi
|
| 156 |
+
done
|
| 157 |
+
|
| 158 |
+
echo "Upload complete!"
|
| 159 |
+
fi
|
| 160 |
+
|
| 161 |
+
# ------------------------------------------------------------------
|
| 162 |
+
# 8. Keep container alive for log inspection
|
| 163 |
+
# ------------------------------------------------------------------
|
| 164 |
+
echo ""
|
| 165 |
+
echo "Training pipeline complete."
|
| 166 |
+
echo "Container will stay alive. Check logs via HF Spaces dashboard."
|
| 167 |
+
echo "Stop the Space manually when done to avoid further billing."
|
| 168 |
+
|
| 169 |
+
# Kill background servers
|
| 170 |
+
kill $HEALTH_PID 2>/dev/null || true
|
| 171 |
+
kill $ENV_PID 2>/dev/null || true
|
| 172 |
+
|
| 173 |
+
# Start keepalive app
|
| 174 |
+
exec uvicorn training.hf_keepalive_app:app --host 0.0.0.0 --port "$PORT"
|
salespath_env/server/rules.py
CHANGED
|
@@ -104,13 +104,22 @@ def _followup_timing(
|
|
| 104 |
) -> bool:
|
| 105 |
"""
|
| 106 |
R07:
|
| 107 |
-
FOLLOW_UP only valid after silence.
|
| 108 |
-
|
|
|
|
| 109 |
"""
|
| 110 |
if action.action_type == "FOLLOW_UP":
|
| 111 |
if state.conversation_history:
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
return False
|
| 115 |
|
| 116 |
|
|
|
|
| 104 |
) -> bool:
|
| 105 |
"""
|
| 106 |
R07:
|
| 107 |
+
FOLLOW_UP only valid after prospect silence.
|
| 108 |
+
Violation if the prospect's last response had actual content
|
| 109 |
+
(i.e., the prospect is still engaged and waiting for a reply).
|
| 110 |
"""
|
| 111 |
if action.action_type == "FOLLOW_UP":
|
| 112 |
if state.conversation_history:
|
| 113 |
+
# Walk backwards to find the last prospect message
|
| 114 |
+
for entry in reversed(state.conversation_history):
|
| 115 |
+
if entry.get("speaker") == "prospect":
|
| 116 |
+
response_token = entry.get("response_token", "")
|
| 117 |
+
# FOLLOW_UP is only valid if the prospect went silent
|
| 118 |
+
return response_token != "silence"
|
| 119 |
+
# No prospect message found — first turn, so violation
|
| 120 |
+
return True
|
| 121 |
+
# No history at all — first turn, can't FOLLOW_UP yet
|
| 122 |
+
return True
|
| 123 |
return False
|
| 124 |
|
| 125 |
|
scripts/run_training.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Simple alias for the main entrypoint
|
| 3 |
+
cd /app
|
| 4 |
+
exec ./run_training.sh "$@"
|
training/eval_baseline_vs_trained.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SalesPath — Evaluate Baseline vs Trained Model
|
| 3 |
+
|
| 4 |
+
Runs episodes at each difficulty level with both the base model
|
| 5 |
+
and the trained (GRPO) model, then compares performance.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python training/eval_baseline_vs_trained.py \
|
| 9 |
+
--base Qwen/Qwen2.5-0.5B-Instruct \
|
| 10 |
+
--trained ./salespath_out/grpo_final \
|
| 11 |
+
--env-url http://127.0.0.1:8000 \
|
| 12 |
+
--episodes-per-level 4 \
|
| 13 |
+
--output ./salespath_out/eval_results.json
|
| 14 |
+
"""
|
| 15 |
+
import argparse
|
| 16 |
+
import asyncio
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
import time
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 25 |
+
|
| 26 |
+
# Ensure project root is on path
|
| 27 |
+
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 28 |
+
if _ROOT not in sys.path:
|
| 29 |
+
sys.path.insert(0, _ROOT)
|
| 30 |
+
|
| 31 |
+
from training.rollout import run_episode
|
| 32 |
+
|
| 33 |
+
DIFFICULTIES = [1, 2, 3, 4]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def eval_model(
|
| 37 |
+
model,
|
| 38 |
+
tokenizer,
|
| 39 |
+
env_url: str,
|
| 40 |
+
episodes_per_level: int = 4,
|
| 41 |
+
label: str = "model",
|
| 42 |
+
) -> dict:
|
| 43 |
+
"""Evaluate a model across all difficulty levels."""
|
| 44 |
+
results = {
|
| 45 |
+
"label": label,
|
| 46 |
+
"per_difficulty": {},
|
| 47 |
+
"overall": {},
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
all_rewards = []
|
| 51 |
+
all_violations = []
|
| 52 |
+
all_closes = []
|
| 53 |
+
all_lengths = []
|
| 54 |
+
|
| 55 |
+
for difficulty in DIFFICULTIES:
|
| 56 |
+
diff_rewards = []
|
| 57 |
+
diff_violations = []
|
| 58 |
+
diff_closes = []
|
| 59 |
+
diff_lengths = []
|
| 60 |
+
|
| 61 |
+
print(f" Difficulty {difficulty}...")
|
| 62 |
+
for ep in range(episodes_per_level):
|
| 63 |
+
result = await run_episode(
|
| 64 |
+
model=model,
|
| 65 |
+
tokenizer=tokenizer,
|
| 66 |
+
env_url=env_url,
|
| 67 |
+
difficulty=difficulty,
|
| 68 |
+
message_timeout_s=120.0,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
trajectory = result["trajectory"]
|
| 72 |
+
reward = result["total_reward"]
|
| 73 |
+
violations = result["violations"]
|
| 74 |
+
steps = result["steps_completed"]
|
| 75 |
+
length = len(trajectory)
|
| 76 |
+
|
| 77 |
+
# Did we close successfully?
|
| 78 |
+
last_action = trajectory[-1]["action_type"] if trajectory else ""
|
| 79 |
+
last_traj = trajectory[-1] if trajectory else {}
|
| 80 |
+
components = last_traj.get("components", {})
|
| 81 |
+
r_outcome = components.get("r_outcome", 0.0)
|
| 82 |
+
closed = last_action == "CLOSE" and r_outcome > 0
|
| 83 |
+
|
| 84 |
+
diff_rewards.append(reward)
|
| 85 |
+
diff_violations.append(len(violations))
|
| 86 |
+
diff_closes.append(1 if closed else 0)
|
| 87 |
+
diff_lengths.append(length)
|
| 88 |
+
|
| 89 |
+
all_rewards.append(reward)
|
| 90 |
+
all_violations.append(len(violations))
|
| 91 |
+
all_closes.append(1 if closed else 0)
|
| 92 |
+
all_lengths.append(length)
|
| 93 |
+
|
| 94 |
+
results["per_difficulty"][difficulty] = {
|
| 95 |
+
"mean_reward": sum(diff_rewards) / len(diff_rewards) if diff_rewards else 0,
|
| 96 |
+
"mean_violations": sum(diff_violations) / len(diff_violations) if diff_violations else 0,
|
| 97 |
+
"close_rate": sum(diff_closes) / len(diff_closes) if diff_closes else 0,
|
| 98 |
+
"mean_episode_length": sum(diff_lengths) / len(diff_lengths) if diff_lengths else 0,
|
| 99 |
+
"num_episodes": len(diff_rewards),
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
results["overall"] = {
|
| 103 |
+
"mean_reward": sum(all_rewards) / len(all_rewards) if all_rewards else 0,
|
| 104 |
+
"mean_violations": sum(all_violations) / len(all_violations) if all_violations else 0,
|
| 105 |
+
"close_rate": sum(all_closes) / len(all_closes) if all_closes else 0,
|
| 106 |
+
"mean_episode_length": sum(all_lengths) / len(all_lengths) if all_lengths else 0,
|
| 107 |
+
"num_episodes": len(all_rewards),
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
return results
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def load_model(model_name_or_path: str):
|
| 114 |
+
"""Load model, detecting if it's a PEFT adapter."""
|
| 115 |
+
print(f"Loading model: {model_name_or_path}")
|
| 116 |
+
|
| 117 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
| 118 |
+
if tokenizer.pad_token is None:
|
| 119 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 120 |
+
|
| 121 |
+
# Check if this is a PEFT adapter directory
|
| 122 |
+
adapter_path = Path(model_name_or_path)
|
| 123 |
+
is_adapter = (adapter_path / "adapter_config.json").exists()
|
| 124 |
+
|
| 125 |
+
if is_adapter:
|
| 126 |
+
print(" Detected PEFT adapter — loading base model + adapter...")
|
| 127 |
+
from peft import PeftModel
|
| 128 |
+
|
| 129 |
+
# Find the base model name from adapter config
|
| 130 |
+
import json as _json
|
| 131 |
+
with open(adapter_path / "adapter_config.json") as f:
|
| 132 |
+
adapter_cfg = _json.load(f)
|
| 133 |
+
base_model_name = adapter_cfg.get("base_model_name_or_path", "Qwen/Qwen2.5-0.5B-Instruct")
|
| 134 |
+
print(f" Base model: {base_model_name}")
|
| 135 |
+
|
| 136 |
+
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
| 137 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 138 |
+
base_model_name,
|
| 139 |
+
torch_dtype=torch.bfloat16 if bf16_supported else torch.float32,
|
| 140 |
+
device_map="auto",
|
| 141 |
+
)
|
| 142 |
+
model = PeftModel.from_pretrained(base_model, model_name_or_path)
|
| 143 |
+
# Merge adapter for faster inference
|
| 144 |
+
model = model.merge_and_unload()
|
| 145 |
+
print(" Adapter loaded and merged ✅")
|
| 146 |
+
else:
|
| 147 |
+
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
| 148 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 149 |
+
model_name_or_path,
|
| 150 |
+
torch_dtype=torch.bfloat16 if bf16_supported else torch.float32,
|
| 151 |
+
device_map="auto",
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
model.eval()
|
| 155 |
+
print(f" Model on: {next(model.parameters()).device}")
|
| 156 |
+
return model, tokenizer
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
async def main():
|
| 160 |
+
parser = argparse.ArgumentParser(description="Evaluate baseline vs trained model")
|
| 161 |
+
parser.add_argument("--base", default="Qwen/Qwen2.5-0.5B-Instruct")
|
| 162 |
+
parser.add_argument("--trained", default="./salespath_out/grpo_final")
|
| 163 |
+
parser.add_argument("--env-url", default="http://127.0.0.1:8000")
|
| 164 |
+
parser.add_argument("--episodes-per-level", type=int, default=4)
|
| 165 |
+
parser.add_argument("--output", default="./salespath_out/eval_results.json")
|
| 166 |
+
args = parser.parse_args()
|
| 167 |
+
|
| 168 |
+
print("=" * 60)
|
| 169 |
+
print("SalesPath — Model Evaluation")
|
| 170 |
+
print("=" * 60)
|
| 171 |
+
print(f"Base model: {args.base}")
|
| 172 |
+
print(f"Trained model: {args.trained}")
|
| 173 |
+
print(f"Episodes/level: {args.episodes_per_level}")
|
| 174 |
+
print()
|
| 175 |
+
|
| 176 |
+
# Evaluate base model
|
| 177 |
+
print("Loading base model...")
|
| 178 |
+
base_model, base_tokenizer = load_model(args.base)
|
| 179 |
+
print("\nEvaluating base model...")
|
| 180 |
+
base_results = await eval_model(
|
| 181 |
+
base_model, base_tokenizer, args.env_url,
|
| 182 |
+
episodes_per_level=args.episodes_per_level,
|
| 183 |
+
label="baseline",
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Clean up
|
| 187 |
+
del base_model
|
| 188 |
+
if torch.cuda.is_available():
|
| 189 |
+
torch.cuda.empty_cache()
|
| 190 |
+
|
| 191 |
+
# Evaluate trained model
|
| 192 |
+
print("\nLoading trained model...")
|
| 193 |
+
trained_model, trained_tokenizer = load_model(args.trained)
|
| 194 |
+
print("\nEvaluating trained model...")
|
| 195 |
+
trained_results = await eval_model(
|
| 196 |
+
trained_model, trained_tokenizer, args.env_url,
|
| 197 |
+
episodes_per_level=args.episodes_per_level,
|
| 198 |
+
label="trained",
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# Print comparison
|
| 202 |
+
print("\n" + "=" * 60)
|
| 203 |
+
print("RESULTS COMPARISON")
|
| 204 |
+
print("=" * 60)
|
| 205 |
+
|
| 206 |
+
for model_results in [base_results, trained_results]:
|
| 207 |
+
label = model_results["label"]
|
| 208 |
+
overall = model_results["overall"]
|
| 209 |
+
print(f"\n--- {label.upper()} ---")
|
| 210 |
+
print(f" Mean reward: {overall['mean_reward']:.4f}")
|
| 211 |
+
print(f" Mean violations: {overall['mean_violations']:.2f}")
|
| 212 |
+
print(f" Close rate: {overall['close_rate']:.2%}")
|
| 213 |
+
print(f" Mean ep. length: {overall['mean_episode_length']:.1f}")
|
| 214 |
+
|
| 215 |
+
for diff, metrics in model_results["per_difficulty"].items():
|
| 216 |
+
print(f" Difficulty {diff}: reward={metrics['mean_reward']:.3f}, "
|
| 217 |
+
f"violations={metrics['mean_violations']:.1f}, "
|
| 218 |
+
f"close={metrics['close_rate']:.0%}")
|
| 219 |
+
|
| 220 |
+
# Save results
|
| 221 |
+
output = {
|
| 222 |
+
"base": base_results,
|
| 223 |
+
"trained": trained_results,
|
| 224 |
+
"comparison": {
|
| 225 |
+
"reward_delta": trained_results["overall"]["mean_reward"] - base_results["overall"]["mean_reward"],
|
| 226 |
+
"violation_reduction": base_results["overall"]["mean_violations"] - trained_results["overall"]["mean_violations"],
|
| 227 |
+
"close_rate_improvement": trained_results["overall"]["close_rate"] - base_results["overall"]["close_rate"],
|
| 228 |
+
},
|
| 229 |
+
"config": {
|
| 230 |
+
"base_model": args.base,
|
| 231 |
+
"trained_model": args.trained,
|
| 232 |
+
"episodes_per_level": args.episodes_per_level,
|
| 233 |
+
"difficulties": DIFFICULTIES,
|
| 234 |
+
},
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
|
| 238 |
+
with open(args.output, "w") as f:
|
| 239 |
+
json.dump(output, f, indent=2)
|
| 240 |
+
print(f"\nResults saved to {args.output}")
|
| 241 |
+
|
| 242 |
+
# Print comparison summary
|
| 243 |
+
print("\n=== KEY METRICS ===")
|
| 244 |
+
c = output["comparison"]
|
| 245 |
+
print(f" Reward delta: {c['reward_delta']:+.4f}")
|
| 246 |
+
print(f" Violation reduction: {c['violation_reduction']:+.2f}")
|
| 247 |
+
print(f" Close rate change: {c['close_rate_improvement']:+.2%}")
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
if __name__ == "__main__":
|
| 251 |
+
asyncio.run(main())
|
training/grpo_train.py
CHANGED
|
@@ -39,13 +39,12 @@ def _load_model_and_tokenizer(model_name: str, use_unsloth: bool = False):
|
|
| 39 |
try:
|
| 40 |
from unsloth import FastLanguageModel
|
| 41 |
print("Loading with Unsloth in 4-bit + LoRA...")
|
| 42 |
-
|
| 43 |
model_name=model_name,
|
| 44 |
max_seq_length=2048,
|
| 45 |
load_in_4bit=True,
|
| 46 |
fast_inference=True,
|
| 47 |
max_lora_rank=16,
|
| 48 |
-
max_lora_rank_type="lora",
|
| 49 |
)
|
| 50 |
# Inject LoRA adapters to drastically reduce VRAM
|
| 51 |
model = FastLanguageModel.get_peft_model(
|
|
@@ -292,7 +291,9 @@ def run_grpo(args):
|
|
| 292 |
"or fix local pyarrow/datasets installation first."
|
| 293 |
) from exc
|
| 294 |
|
| 295 |
-
|
|
|
|
|
|
|
| 296 |
rows = _build_grpo_dataset_rows(args.grpo_dataset_size)
|
| 297 |
train_dataset = Dataset.from_list(rows)
|
| 298 |
|
|
|
|
| 39 |
try:
|
| 40 |
from unsloth import FastLanguageModel
|
| 41 |
print("Loading with Unsloth in 4-bit + LoRA...")
|
| 42 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 43 |
model_name=model_name,
|
| 44 |
max_seq_length=2048,
|
| 45 |
load_in_4bit=True,
|
| 46 |
fast_inference=True,
|
| 47 |
max_lora_rank=16,
|
|
|
|
| 48 |
)
|
| 49 |
# Inject LoRA adapters to drastically reduce VRAM
|
| 50 |
model = FastLanguageModel.get_peft_model(
|
|
|
|
| 291 |
"or fix local pyarrow/datasets installation first."
|
| 292 |
) from exc
|
| 293 |
|
| 294 |
+
# Try Unsloth first (4-bit saves VRAM), fallback to standard HF
|
| 295 |
+
use_unsloth = args.model_name.startswith("unsloth/")
|
| 296 |
+
model, tokenizer = _load_model_and_tokenizer(args.model_name, use_unsloth=use_unsloth)
|
| 297 |
rows = _build_grpo_dataset_rows(args.grpo_dataset_size)
|
| 298 |
train_dataset = Dataset.from_list(rows)
|
| 299 |
|
training/hf_keepalive_app.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
SalesPath — HF Spaces Keepalive App

Serves a simple FastAPI app after training completes to keep
the HF Space alive and display training results.
"""
import os
import json
from pathlib import Path
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse

# ASGI app served by the Space after the training job finishes.
app = FastAPI(title="SalesPath — Training Complete")

# Directory where the training run wrote its artifacts
# (eval_results.json, reward_graph.png, reward_history.txt).
# Overridable via the OUTPUT_DIR environment variable.
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "/app/salespath_out"))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@app.get("/health")
async def health():
    """Liveness probe for the Space runtime."""
    return dict(status="ok", service="SalesPath Training")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@app.get("/")
async def root():
    """Render an HTML summary page for the completed training run.

    Pulls whatever artifacts exist under ``OUTPUT_DIR`` — eval metrics
    (eval_results.json), the reward curve image (reward_graph.png), and
    the per-episode reward log (reward_history.txt) — and degrades
    gracefully when any artifact is missing or malformed.

    Returns:
        HTMLResponse: self-contained page (the graph is inlined as base64).
    """
    html = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>SalesPath Training Complete</title>
        <meta charset="utf-8">
        <style>
            body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
                   max-width: 800px; margin: 40px auto; padding: 20px;
                   background: #0f172a; color: #e2e8f0; }
            h1 { color: #38bdf8; }
            h2 { color: #94a3b8; margin-top: 30px; }
            .card { background: #1e293b; border-radius: 12px; padding: 20px; margin: 16px 0; }
            .metric { display: inline-block; margin: 8px 16px 8px 0; }
            .metric .value { font-size: 24px; font-weight: bold; color: #4ade80; }
            .metric .label { font-size: 12px; color: #64748b; text-transform: uppercase; }
            img { max-width: 100%; border-radius: 8px; margin: 16px 0; }
            .badge { display: inline-block; padding: 4px 12px; border-radius: 20px;
                     font-size: 12px; font-weight: bold; }
            .badge.success { background: #166534; color: #4ade80; }
            .badge.info { background: #1e3a5f; color: #38bdf8; }
            pre { background: #0f172a; padding: 12px; border-radius: 8px; overflow-x: auto; }
        </style>
    </head>
    <body>
        <h1>🏆 SalesPath Training Complete</h1>
        <p>Trained model has been uploaded to Hugging Face Hub.</p>
    """

    # Load eval results (best-effort: a corrupt JSON file is skipped).
    eval_path = OUTPUT_DIR / "eval_results.json"
    if eval_path.exists():
        try:
            eval_data = json.loads(eval_path.read_text())
            html += '<div class="card"><h2>Evaluation Results</h2>'
            for key, value in eval_data.items():
                if isinstance(value, (int, float)):
                    html += f'<div class="metric"><div class="value">{value:.3f}</div><div class="label">{key}</div></div>'
                else:
                    html += f'<pre>{json.dumps(value, indent=2)}</pre>'
            html += "</div>"
        except Exception:
            pass

    # Inline the reward graph as base64 so the page is self-contained.
    graph_path = OUTPUT_DIR / "reward_graph.png"
    if graph_path.exists():
        import base64
        img_b64 = base64.b64encode(graph_path.read_bytes()).decode()
        html += f'<div class="card"><h2>Reward Curve</h2><img src="data:image/png;base64,{img_b64}" alt="Reward Graph"/></div>'

    # Show reward history stats. FIX: the old one-liner
    # `float(line.split("\t")[-1])` raised ValueError on any header or
    # malformed row, crashing the whole page; skip such rows instead.
    history_path = OUTPUT_DIR / "reward_history.txt"
    if history_path.exists():
        rewards = []
        for line in history_path.read_text().strip().splitlines():
            if not line.strip():
                continue
            try:
                rewards.append(float(line.split("\t")[-1]))
            except ValueError:
                continue  # header or malformed row — ignore
        if rewards:
            html += f"""
            <div class="card">
                <h2>Training Stats</h2>
                <div class="metric"><div class="value">{len(rewards)}</div><div class="label">Episodes</div></div>
                <div class="metric"><div class="value">{sum(rewards)/len(rewards):.4f}</div><div class="label">Mean Reward</div></div>
                <div class="metric"><div class="value">{max(rewards):.4f}</div><div class="label">Max Reward</div></div>
                <div class="metric"><div class="value">{min(rewards):.4f}</div><div class="label">Min Reward</div></div>
            </div>
            """

    html += """
        <div class="card">
            <h2>Next Steps</h2>
            <p>1. View model on Hugging Face Hub</p>
            <p>2. Run inference with the trained model</p>
            <p>3. Stop this Space to avoid billing</p>
        </div>
    </body>
    </html>
    """
    return HTMLResponse(html)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@app.get("/api/results")
async def api_results():
    """Return training results as JSON.

    Aggregates eval_results.json and reward_history.txt from
    ``OUTPUT_DIR``. Missing or malformed artifacts are skipped so this
    endpoint never errors out on a partial or interrupted run.

    Returns:
        JSONResponse with optional "eval" and "training" sections.
    """
    import statistics  # stdlib; used for std_reward below

    results = {}

    eval_path = OUTPUT_DIR / "eval_results.json"
    if eval_path.exists():
        try:
            results["eval"] = json.loads(eval_path.read_text())
        except Exception:
            pass

    # FIX: the old one-liner `float(line.split("\t")[-1])` raised
    # ValueError on any header/malformed row; skip such rows instead.
    # Also replaced the `__import__("statistics")` hack with a normal import.
    history_path = OUTPUT_DIR / "reward_history.txt"
    if history_path.exists():
        rewards = []
        for line in history_path.read_text().strip().splitlines():
            if not line.strip():
                continue
            try:
                rewards.append(float(line.split("\t")[-1]))
            except ValueError:
                continue  # header or malformed row — ignore
        if rewards:
            results["training"] = {
                "episodes": len(rewards),
                "mean_reward": sum(rewards) / len(rewards),
                "max_reward": max(rewards),
                "min_reward": min(rewards),
                "std_reward": statistics.stdev(rewards) if len(rewards) > 1 else 0,
            }

    return JSONResponse(results)
|
training/preflight_check.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
SalesPath — Pre-flight Dependency Check
Run at the start of training to catch version mismatches early.

Exit code 0 when every check passes, 1 otherwise (training scripts may
still choose to proceed on a non-zero exit).
"""
import sys
import importlib

# Minimum required versions; None means "must be importable, any version".
REQUIRED_PACKAGES = {
    "torch": "2.0.0",
    "transformers": "4.44.2",
    "trl": "0.11.0",
    "peft": "0.11.1",
    "datasets": "2.0.0",
    "fastapi": "0.100.0",
    "httpx": "0.24.0",
    "openenv": None,
    "accelerate": "0.25.0",
}

all_ok = True

print("=" * 60)
print("SalesPath Pre-flight Check")
print("=" * 60)

# Python version
print(f"Python: {sys.version}")
if sys.version_info < (3, 10):
    print(" WARNING: Python >= 3.10 recommended")
    all_ok = False

# CUDA availability
try:
    import torch
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        # FIX: the attribute is `total_memory`, not `total_mem` — the old
        # name raised AttributeError, which the except below swallowed
        # and falsely flagged the whole PyTorch check as failed.
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
except Exception as e:
    print(f"PyTorch: ERROR — {e}")
    all_ok = False

# Check each package against its minimum version.
for pkg_name, min_version in REQUIRED_PACKAGES.items():
    try:
        mod = importlib.import_module(pkg_name)
        ver = getattr(mod, "__version__", "unknown")
        status = f"{ver}"
        if min_version:
            # `packaging` ships with pip/setuptools in practice, but is
            # not stdlib — an ImportError here is caught below.
            from packaging import version
            if version.parse(ver) < version.parse(min_version):
                status += f" (needs >= {min_version}) ⚠️"
                all_ok = False
            else:
                status += " ✅"
        else:
            status += " ✅"
        print(f"{pkg_name}: {status}")
    except ImportError:
        print(f"{pkg_name}: NOT FOUND ❌")
        all_ok = False
    except Exception as e:
        print(f"{pkg_name}: ERROR — {e} ❌")
        all_ok = False

print("=" * 60)
if all_ok:
    print("All checks passed ✅")
else:
    print("Some checks failed ⚠️ — training may still work")
print("=" * 60)

sys.exit(0 if all_ok else 1)
|