Spaces:
Sleeping
Sleeping
Akhil Soni commited on
Commit ·
e74ff96
1
Parent(s): 025774a
Add custom task input support and update URLs to HF Space
Browse files- reset() now accepts task="custom" with user-defined tasks, meetings, and energy
- Validates custom input (1-10 tasks, bounds clamping on all params)
- Existing easy/medium/hard scenarios unchanged
- Default URL updated from localhost to HF Space
- client.py +1 -1
- inference.py +1 -1
- server/rhythm_environment.py +44 -3
client.py
CHANGED
|
@@ -23,7 +23,7 @@ class RhythmEnv(EnvClient[RhythmAction, RhythmObservation, RhythmState]):
|
|
| 23 |
Client for the RhythmEnv Environment.
|
| 24 |
|
| 25 |
Example:
|
| 26 |
-
>>> async with RhythmEnv(base_url="
|
| 27 |
... result = await client.reset(task="easy")
|
| 28 |
... result = await client.step(RhythmAction(action_type=ActionType.START_TASK, task_id=0))
|
| 29 |
"""
|
|
|
|
| 23 |
Client for the RhythmEnv Environment.
|
| 24 |
|
| 25 |
Example:
|
| 26 |
+
>>> async with RhythmEnv(base_url="https://InosLihka-rhythm-env.hf.space") as client:
|
| 27 |
... result = await client.reset(task="easy")
|
| 28 |
... result = await client.step(RhythmAction(action_type=ActionType.START_TASK, task_id=0))
|
| 29 |
"""
|
inference.py
CHANGED
|
@@ -56,7 +56,7 @@ IMAGE_NAME = os.getenv("IMAGE_NAME")
|
|
| 56 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 57 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 58 |
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 59 |
-
BASE_URL = os.getenv("RHYTHM_ENV_URL", "
|
| 60 |
BENCHMARK = "rhythm_env"
|
| 61 |
TASKS = ["easy", "medium", "hard"]
|
| 62 |
MAX_STEPS = 20
|
|
|
|
| 56 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 57 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 58 |
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 59 |
+
BASE_URL = os.getenv("RHYTHM_ENV_URL", "https://InosLihka-rhythm-env.hf.space")
|
| 60 |
BENCHMARK = "rhythm_env"
|
| 61 |
TASKS = ["easy", "medium", "hard"]
|
| 62 |
MAX_STEPS = 20
|
server/rhythm_environment.py
CHANGED
|
@@ -253,10 +253,14 @@ class RhythmEnvironment(Environment):
|
|
| 253 |
**kwargs: Any,
|
| 254 |
) -> RhythmObservation:
|
| 255 |
task_name = kwargs.get("task", "easy")
|
| 256 |
-
if task_name not in TASK_CONFIGS:
|
| 257 |
-
task_name = "easy"
|
| 258 |
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
# Deep-copy tasks so mutations don't affect the template
|
| 262 |
self._tasks = [dict(t) for t in config["tasks"]]
|
|
@@ -499,6 +503,43 @@ class RhythmEnvironment(Environment):
|
|
| 499 |
return False
|
| 500 |
return True
|
| 501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
def _compute_mode(self) -> str:
|
| 503 |
"""Compute hidden internal mode (not exposed to agent)."""
|
| 504 |
if (
|
|
|
|
| 253 |
**kwargs: Any,
|
| 254 |
) -> RhythmObservation:
|
| 255 |
task_name = kwargs.get("task", "easy")
|
|
|
|
|
|
|
| 256 |
|
| 257 |
+
if task_name == "custom":
|
| 258 |
+
config = self._parse_custom_config(kwargs)
|
| 259 |
+
elif task_name in TASK_CONFIGS:
|
| 260 |
+
config = TASK_CONFIGS[task_name]
|
| 261 |
+
else:
|
| 262 |
+
task_name = "easy"
|
| 263 |
+
config = TASK_CONFIGS[task_name]
|
| 264 |
|
| 265 |
# Deep-copy tasks so mutations don't affect the template
|
| 266 |
self._tasks = [dict(t) for t in config["tasks"]]
|
|
|
|
| 503 |
return False
|
| 504 |
return True
|
| 505 |
|
| 506 |
+
@staticmethod
|
| 507 |
+
def _parse_custom_config(kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
| 508 |
+
"""Parse and validate a custom task configuration from reset kwargs."""
|
| 509 |
+
raw_tasks = kwargs.get("tasks")
|
| 510 |
+
if not raw_tasks or not isinstance(raw_tasks, list):
|
| 511 |
+
raise ValueError("Custom mode requires a 'tasks' list with at least 1 task.")
|
| 512 |
+
if len(raw_tasks) > 10:
|
| 513 |
+
raise ValueError("Maximum 10 tasks allowed.")
|
| 514 |
+
|
| 515 |
+
tasks = []
|
| 516 |
+
for i, t in enumerate(raw_tasks):
|
| 517 |
+
if not isinstance(t, dict):
|
| 518 |
+
raise ValueError(f"Task {i} must be a dict.")
|
| 519 |
+
tasks.append({
|
| 520 |
+
"id": i,
|
| 521 |
+
"name": str(t.get("name", f"Task {i}")),
|
| 522 |
+
"description": str(t.get("description", "")),
|
| 523 |
+
"effort": max(0.05, min(1.0, float(t.get("effort", 0.3)))),
|
| 524 |
+
"progress": 0.0,
|
| 525 |
+
"deadline": max(1, min(MAX_STEPS, int(t.get("deadline", MAX_STEPS - 2)))),
|
| 526 |
+
"importance": max(0.1, min(1.0, float(t.get("importance", 0.5)))),
|
| 527 |
+
})
|
| 528 |
+
|
| 529 |
+
meetings = kwargs.get("meetings", [])
|
| 530 |
+
if not isinstance(meetings, list):
|
| 531 |
+
meetings = []
|
| 532 |
+
meetings = [int(m) for m in meetings if 0 <= int(m) < MAX_STEPS]
|
| 533 |
+
|
| 534 |
+
initial_energy = max(0.1, min(1.0, float(kwargs.get("initial_energy", 0.8))))
|
| 535 |
+
|
| 536 |
+
return {
|
| 537 |
+
"scenario": "Custom task configuration.",
|
| 538 |
+
"tasks": tasks,
|
| 539 |
+
"meetings": meetings,
|
| 540 |
+
"initial_energy": initial_energy,
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
def _compute_mode(self) -> str:
|
| 544 |
"""Compute hidden internal mode (not exposed to agent)."""
|
| 545 |
if (
|