shank committed on
Commit ·
2bfaf77
1
Parent(s): 5d0b2d4
fix: torch at build time, remove mergekit (conflicts accelerate/peft/trl)
Browse files- requirements.txt +3 -3
- training/train_grpo.py +6 -5
requirements.txt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
#
|
| 2 |
-
#
|
| 3 |
-
|
|
|
|
| 1 |
+
# torch must be installed at build time (CUDA wheel is ~2GB, too slow at runtime)
|
| 2 |
+
# Everything else is installed at runtime in training/train_grpo.py
|
| 3 |
+
torch
|
training/train_grpo.py
CHANGED
|
@@ -38,8 +38,10 @@ parser.add_argument("--max_steps", type=int, default=500)
|
|
| 38 |
args = parser.parse_args()
|
| 39 |
|
| 40 |
# ── Runtime dependency install ─────────────────────────────────────────────────
|
| 41 |
-
# requirements.txt
|
| 42 |
-
#
|
|
|
|
|
|
|
| 43 |
if not args.test_local:
|
| 44 |
_TRAIN_DEPS = [
|
| 45 |
"wandb==0.18.7",
|
|
@@ -49,15 +51,14 @@ if not args.test_local:
|
|
| 49 |
"trl==0.14.0",
|
| 50 |
"peft==0.13.2",
|
| 51 |
"bitsandbytes==0.43.3",
|
| 52 |
-
"mergekit==0.1.4",
|
| 53 |
-
"pydantic==2.10.6",
|
| 54 |
]
|
| 55 |
print("Installing training dependencies...", flush=True)
|
| 56 |
ret = os.system(
|
| 57 |
f"{sys.executable} -m pip install -q --no-cache-dir " + " ".join(f'"{d}"' for d in _TRAIN_DEPS)
|
| 58 |
)
|
| 59 |
if ret != 0:
|
| 60 |
-
print("
|
|
|
|
| 61 |
print("Dependencies installed.", flush=True)
|
| 62 |
|
| 63 |
# ── GPU/training imports (skipped in --test-local mode) ───────────────────────
|
|
|
|
| 38 |
args = parser.parse_args()
|
| 39 |
|
| 40 |
# ── Runtime dependency install ─────────────────────────────────────────────────
|
| 41 |
+
# requirements.txt only has torch (too large to install at runtime).
|
| 42 |
+
# Everything else is installed here, after gradio is already up.
|
| 43 |
+
# NOTE: mergekit intentionally excluded — conflicts with accelerate/peft/trl.
|
| 44 |
+
# NOTE: torch excluded — installed at Docker build time via requirements.txt.
|
| 45 |
if not args.test_local:
|
| 46 |
_TRAIN_DEPS = [
|
| 47 |
"wandb==0.18.7",
|
|
|
|
| 51 |
"trl==0.14.0",
|
| 52 |
"peft==0.13.2",
|
| 53 |
"bitsandbytes==0.43.3",
|
|
|
|
|
|
|
| 54 |
]
|
| 55 |
print("Installing training dependencies...", flush=True)
|
| 56 |
ret = os.system(
|
| 57 |
f"{sys.executable} -m pip install -q --no-cache-dir " + " ".join(f'"{d}"' for d in _TRAIN_DEPS)
|
| 58 |
)
|
| 59 |
if ret != 0:
|
| 60 |
+
print("ERROR: pip install failed. Training cannot continue.", flush=True)
|
| 61 |
+
sys.exit(1)
|
| 62 |
print("Dependencies installed.", flush=True)
|
| 63 |
|
| 64 |
# ── GPU/training imports (skipped in --test-local mode) ───────────────────────
|