Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- physix/systems/registry.py +1 -2
- train/physix_train_colab.ipynb +24 -4
physix/systems/registry.py
CHANGED
|
@@ -22,8 +22,7 @@ SYSTEM_REGISTRY: dict[str, SystemFactory] = {
|
|
| 22 |
"charged_b_field": ChargedInBField,
|
| 23 |
}
|
| 24 |
|
| 25 |
-
|
| 26 |
-
# TODO: extend training to other systems before exposing them here.
|
| 27 |
SUPPORTED_SYSTEMS: tuple[str, ...] = (
|
| 28 |
"free_fall",
|
| 29 |
"simple_pendulum",
|
|
|
|
| 22 |
"charged_b_field": ChargedInBField,
|
| 23 |
}
|
| 24 |
|
| 25 |
+
|
|
|
|
| 26 |
SUPPORTED_SYSTEMS: tuple[str, ...] = (
|
| 27 |
"free_fall",
|
| 28 |
"simple_pendulum",
|
train/physix_train_colab.ipynb
CHANGED
|
@@ -291,6 +291,9 @@
|
|
| 291 |
" token=os.environ.get(\"HF_TOKEN\"),\n",
|
| 292 |
")\n",
|
| 293 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 294 |
"print(\"physix-live source at\", PHYSIX_LOCAL)\n",
|
| 295 |
"!ls {PHYSIX_LOCAL}"
|
| 296 |
],
|
|
@@ -301,8 +304,21 @@
|
|
| 301 |
"cell_type": "code",
|
| 302 |
"metadata": {},
|
| 303 |
"source": [
|
| 304 |
-
"
|
| 305 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
],
|
| 307 |
"execution_count": null,
|
| 308 |
"outputs": []
|
|
@@ -311,8 +327,12 @@
|
|
| 311 |
"cell_type": "code",
|
| 312 |
"metadata": {},
|
| 313 |
"source": [
|
| 314 |
-
"import
|
| 315 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
"print(f\"unsloth={unsloth.__version__} trl={trl.__version__} transformers={transformers.__version__} datasets={datasets.__version__}\")\n",
|
| 317 |
"print(f\"physix loaded from {physix.__file__}\")\n",
|
| 318 |
"assert trl.__version__ == \"0.24.0\", f\"trl must be pinned to 0.24.0, got {trl.__version__}\""
|
|
|
|
| 291 |
" token=os.environ.get(\"HF_TOKEN\"),\n",
|
| 292 |
")\n",
|
| 293 |
"\n",
|
| 294 |
+
"# Verify the bits cell 6 needs are present before continuing.\n",
|
| 295 |
+
"assert (PHYSIX_LOCAL / \"pyproject.toml\").is_file(), f\"missing pyproject.toml in {PHYSIX_LOCAL}\"\n",
|
| 296 |
+
"assert (PHYSIX_LOCAL / \"physix\" / \"__init__.py\").is_file(), f\"missing physix/ package in {PHYSIX_LOCAL}\"\n",
|
| 297 |
"print(\"physix-live source at\", PHYSIX_LOCAL)\n",
|
| 298 |
"!ls {PHYSIX_LOCAL}"
|
| 299 |
],
|
|
|
|
| 304 |
"cell_type": "code",
|
| 305 |
"metadata": {},
|
| 306 |
"source": [
|
| 307 |
+
"import subprocess, sys\n",
|
| 308 |
+
"result = subprocess.run(\n",
|
| 309 |
+
" [sys.executable, \"-m\", \"pip\", \"install\", \"--no-deps\", \"-e\", str(PHYSIX_LOCAL)],\n",
|
| 310 |
+
" capture_output=True, text=True, check=False,\n",
|
| 311 |
+
")\n",
|
| 312 |
+
"if result.returncode != 0:\n",
|
| 313 |
+
" print(\"pip stdout:\\n\", result.stdout)\n",
|
| 314 |
+
" print(\"pip stderr:\\n\", result.stderr)\n",
|
| 315 |
+
" raise RuntimeError(\n",
|
| 316 |
+
" f\"Editable install of {PHYSIX_LOCAL} failed (rc={result.returncode}). \"\n",
|
| 317 |
+
" \"See the pip output above. Common causes: (1) cell 5 didn't actually \"\n",
|
| 318 |
+
" \"download the bundle; (2) Colab pip is too old (run `!pip install -q \"\n",
|
| 319 |
+
" \"--upgrade pip` and retry); (3) a hatchling build backend hiccup — \"\n",
|
| 320 |
+
" \"rerun this cell.\")\n",
|
| 321 |
+
"print(result.stdout.splitlines()[-1] if result.stdout else \"install ok\")"
|
| 322 |
],
|
| 323 |
"execution_count": null,
|
| 324 |
"outputs": []
|
|
|
|
| 327 |
"cell_type": "code",
|
| 328 |
"metadata": {},
|
| 329 |
"source": [
|
| 330 |
+
"import unsloth # MUST come before trl / transformers so its monkey-patches land first\n",
|
| 331 |
+
"import torch, trl, transformers, datasets, wandb\n",
|
| 332 |
+
"import physix\n",
|
| 333 |
+
"\n",
|
| 334 |
+
"device = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None\n",
|
| 335 |
+
"print(f\"torch={torch.__version__} cuda={torch.cuda.is_available()} device={device}\")\n",
|
| 336 |
"print(f\"unsloth={unsloth.__version__} trl={trl.__version__} transformers={transformers.__version__} datasets={datasets.__version__}\")\n",
|
| 337 |
"print(f\"physix loaded from {physix.__file__}\")\n",
|
| 338 |
"assert trl.__version__ == \"0.24.0\", f\"trl must be pinned to 0.24.0, got {trl.__version__}\""
|