shank commited on
Commit Β·
663b8db
1
Parent(s): 8f291e0
Stabilize Space runtime: pin ML deps and disable runtime package drift
Browse files- README.md +1 -0
- training/train_grpo.py +27 -3
README.md
CHANGED
|
@@ -5,6 +5,7 @@ colorFrom: blue
|
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
app_file: app.py
|
|
|
|
| 8 |
pinned: true
|
| 9 |
license: mit
|
| 10 |
---
|
|
|
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
app_file: app.py
|
| 8 |
+
python_version: 3.10.13
|
| 9 |
pinned: true
|
| 10 |
license: mit
|
| 11 |
---
|
training/train_grpo.py
CHANGED
|
@@ -26,6 +26,7 @@ import random
|
|
| 26 |
import subprocess
|
| 27 |
import tempfile
|
| 28 |
import shutil
|
|
|
|
| 29 |
|
| 30 |
# ββ Parse args ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
parser = argparse.ArgumentParser()
|
|
@@ -36,9 +37,15 @@ parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint
|
|
| 36 |
parser.add_argument("--max_steps", type=int, default=500)
|
| 37 |
args = parser.parse_args()
|
| 38 |
|
| 39 |
-
# ββ
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
# ββ GPU/training imports (skipped in --test-local mode) βββββββββββββββββββββββ
|
| 44 |
if not args.test_local:
|
|
@@ -51,6 +58,23 @@ if not args.test_local:
|
|
| 51 |
from peft import get_peft_model, LoraConfig, TaskType
|
| 52 |
from trl import GRPOTrainer, GRPOConfig
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 55 |
from server.reward_calculator import DebugRewardCalculator
|
| 56 |
from server.models import parse_agent_output
|
|
|
|
| 26 |
import subprocess
|
| 27 |
import tempfile
|
| 28 |
import shutil
|
| 29 |
+
from importlib import metadata
|
| 30 |
|
| 31 |
# ββ Parse args ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
parser = argparse.ArgumentParser()
|
|
|
|
| 37 |
parser.add_argument("--max_steps", type=int, default=500)
|
| 38 |
args = parser.parse_args()
|
| 39 |
|
| 40 |
+
# ββ Optional dependency bootstrap (disabled by default in Spaces) βββββββββββββ
|
| 41 |
+
# Runtime installs with loose versions caused repeated breakages from version drift.
|
| 42 |
+
# Keep this opt-in for fresh Colab notebooks only.
|
| 43 |
+
if os.environ.get("FORCE_BOOTSTRAP_DEPS") == "1":
|
| 44 |
+
os.system(
|
| 45 |
+
f"{sys.executable} -m pip install -q "
|
| 46 |
+
"wandb==0.18.7 datasets==3.0.2 transformers==4.46.3 "
|
| 47 |
+
"accelerate==1.0.1 trl==0.12.2 bitsandbytes==0.43.3 peft==0.13.2"
|
| 48 |
+
)
|
| 49 |
|
| 50 |
# ββ GPU/training imports (skipped in --test-local mode) βββββββββββββββββββββββ
|
| 51 |
if not args.test_local:
|
|
|
|
| 58 |
from peft import get_peft_model, LoraConfig, TaskType
|
| 59 |
from trl import GRPOTrainer, GRPOConfig
|
| 60 |
|
| 61 |
+
def _pkg_ver(name: str) -> str:
|
| 62 |
+
try:
|
| 63 |
+
return metadata.version(name)
|
| 64 |
+
except metadata.PackageNotFoundError:
|
| 65 |
+
return "not-installed"
|
| 66 |
+
|
| 67 |
+
print(
|
| 68 |
+
"Runtime package versions | "
|
| 69 |
+
f"python={sys.version.split()[0]} "
|
| 70 |
+
f"torch={_pkg_ver('torch')} "
|
| 71 |
+
f"transformers={_pkg_ver('transformers')} "
|
| 72 |
+
f"trl={_pkg_ver('trl')} "
|
| 73 |
+
f"accelerate={_pkg_ver('accelerate')} "
|
| 74 |
+
f"peft={_pkg_ver('peft')} "
|
| 75 |
+
f"bitsandbytes={_pkg_ver('bitsandbytes')}"
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 79 |
from server.reward_calculator import DebugRewardCalculator
|
| 80 |
from server.models import parse_agent_output
|