Akhil Soni committed on
Commit
e74ff96
·
1 Parent(s): 025774a

Add custom task input support and update URLs to HF Space

Browse files

- reset() now accepts task="custom" with user-defined tasks, meetings, and energy
- Validates custom input (1-10 tasks, bounds clamping on all params)
- Existing easy/medium/hard scenarios unchanged
- Default URL updated from localhost to HF Space

Files changed (3) hide show
  1. client.py +1 -1
  2. inference.py +1 -1
  3. server/rhythm_environment.py +44 -3
client.py CHANGED
@@ -23,7 +23,7 @@ class RhythmEnv(EnvClient[RhythmAction, RhythmObservation, RhythmState]):
23
  Client for the RhythmEnv Environment.
24
 
25
  Example:
26
- >>> async with RhythmEnv(base_url="http://localhost:8000") as client:
27
  ... result = await client.reset(task="easy")
28
  ... result = await client.step(RhythmAction(action_type=ActionType.START_TASK, task_id=0))
29
  """
 
23
  Client for the RhythmEnv Environment.
24
 
25
  Example:
26
+ >>> async with RhythmEnv(base_url="https://InosLihka-rhythm-env.hf.space") as client:
27
  ... result = await client.reset(task="easy")
28
  ... result = await client.step(RhythmAction(action_type=ActionType.START_TASK, task_id=0))
29
  """
inference.py CHANGED
@@ -56,7 +56,7 @@ IMAGE_NAME = os.getenv("IMAGE_NAME")
56
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
57
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
58
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
59
- BASE_URL = os.getenv("RHYTHM_ENV_URL", "http://localhost:8000")
60
  BENCHMARK = "rhythm_env"
61
  TASKS = ["easy", "medium", "hard"]
62
  MAX_STEPS = 20
 
56
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
57
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
58
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
59
+ BASE_URL = os.getenv("RHYTHM_ENV_URL", "https://InosLihka-rhythm-env.hf.space")
60
  BENCHMARK = "rhythm_env"
61
  TASKS = ["easy", "medium", "hard"]
62
  MAX_STEPS = 20
server/rhythm_environment.py CHANGED
@@ -253,10 +253,14 @@ class RhythmEnvironment(Environment):
253
  **kwargs: Any,
254
  ) -> RhythmObservation:
255
  task_name = kwargs.get("task", "easy")
256
- if task_name not in TASK_CONFIGS:
257
- task_name = "easy"
258
 
259
- config = TASK_CONFIGS[task_name]
 
 
 
 
 
 
260
 
261
  # Deep-copy tasks so mutations don't affect the template
262
  self._tasks = [dict(t) for t in config["tasks"]]
@@ -499,6 +503,43 @@ class RhythmEnvironment(Environment):
499
  return False
500
  return True
501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
  def _compute_mode(self) -> str:
503
  """Compute hidden internal mode (not exposed to agent)."""
504
  if (
 
253
  **kwargs: Any,
254
  ) -> RhythmObservation:
255
  task_name = kwargs.get("task", "easy")
 
 
256
 
257
+ if task_name == "custom":
258
+ config = self._parse_custom_config(kwargs)
259
+ elif task_name in TASK_CONFIGS:
260
+ config = TASK_CONFIGS[task_name]
261
+ else:
262
+ task_name = "easy"
263
+ config = TASK_CONFIGS[task_name]
264
 
265
  # Deep-copy tasks so mutations don't affect the template
266
  self._tasks = [dict(t) for t in config["tasks"]]
 
503
  return False
504
  return True
505
 
506
@staticmethod
def _parse_custom_config(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Parse and validate a custom task configuration from reset kwargs.

    Args:
        kwargs: Keyword arguments given to ``reset()``. Keys read here:
            ``tasks`` (required, list of 1-10 dicts), ``meetings``
            (optional list of step indices) and ``initial_energy``
            (optional number, defaults to 0.8).

    Returns:
        A config dict with the keys ``scenario``, ``tasks``, ``meetings``
        and ``initial_energy``. Per task, ``effort`` is clamped to
        [0.05, 1.0], ``deadline`` to [1, MAX_STEPS] and ``importance`` to
        [0.1, 1.0]; ``initial_energy`` is clamped to [0.1, 1.0]. Meeting
        steps outside ``[0, MAX_STEPS)`` are dropped.

    Raises:
        ValueError: If ``tasks`` is missing, empty, not a list, longer
            than 10 entries, contains a non-dict entry, or if any numeric
            field cannot be converted to a number.
    """

    def _coerce(value: Any, cast: Any, label: str) -> Any:
        # Funnel every conversion failure into ValueError so callers only
        # have one exception type to handle for bad custom input
        # (int(None) raises TypeError, int(float("inf")) OverflowError).
        try:
            return cast(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(
                f"{label} must be a number, got {value!r}."
            ) from exc

    def _clamp(value: Any, low: Any, high: Any) -> Any:
        return max(low, min(high, value))

    raw_tasks = kwargs.get("tasks")
    if not raw_tasks or not isinstance(raw_tasks, list):
        raise ValueError("Custom mode requires a 'tasks' list with at least 1 task.")
    if len(raw_tasks) > 10:
        raise ValueError("Maximum 10 tasks allowed.")

    tasks = []
    for i, t in enumerate(raw_tasks):
        if not isinstance(t, dict):
            raise ValueError(f"Task {i} must be a dict.")
        effort = _coerce(t.get("effort", 0.3), float, f"Task {i} 'effort'")
        deadline = _coerce(
            t.get("deadline", MAX_STEPS - 2), int, f"Task {i} 'deadline'"
        )
        importance = _coerce(
            t.get("importance", 0.5), float, f"Task {i} 'importance'"
        )
        tasks.append({
            # Ids are assigned positionally; any user-supplied id is ignored.
            "id": i,
            "name": str(t.get("name", f"Task {i}")),
            "description": str(t.get("description", "")),
            "effort": _clamp(effort, 0.05, 1.0),
            "progress": 0.0,
            "deadline": _clamp(deadline, 1, MAX_STEPS),
            "importance": _clamp(importance, 0.1, 1.0),
        })

    raw_meetings = kwargs.get("meetings", [])
    if not isinstance(raw_meetings, list):
        # A non-list value is treated as "no meetings" rather than an error.
        raw_meetings = []
    meetings = []
    for m in raw_meetings:
        # Convert once, then keep only steps that fall inside the episode.
        step = _coerce(m, int, "Meeting step")
        if 0 <= step < MAX_STEPS:
            meetings.append(step)

    initial_energy = _clamp(
        _coerce(kwargs.get("initial_energy", 0.8), float, "'initial_energy'"),
        0.1,
        1.0,
    )

    return {
        "scenario": "Custom task configuration.",
        "tasks": tasks,
        "meetings": meetings,
        "initial_energy": initial_energy,
    }

542
+
543
  def _compute_mode(self) -> str:
544
  """Compute hidden internal mode (not exposed to agent)."""
545
  if (