# Scraped from a Hugging Face Space page; the page banner read "Spaces: Runtime error".
"""Core Identity Environment - server-side implementation."""
import random
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4

# NOTE(review): both branches import the same paths, so the fallback can never
# help; presumably one was meant to be an alternate/relative path — confirm.
try:
    from openenv.core.env_server.interfaces import Environment
    from openenv.core.env_server.types import Action, Observation, State
except ImportError:
    from openenv.core.env_server.interfaces import Environment
    from openenv.core.env_server.types import Action, Observation, State

from core_identity_env.models import (
    CoreIdentityObservation,
    CoreIdentityAction,
    VerificationResult,
    TaskType,
    IdentityDocument,
    UserCredentials,
    UserProfile,
)
from core_identity_env.tasks.definitions import (
    get_all_tasks,
    get_task_by_id,
    CoreIdentityTask,
    CoreIdentityTaskEvaluator,
    GradingResult,
)
@dataclass
class _EpisodeState:
    """Mutable per-episode bookkeeping for ``CoreIdentityEnvironment``.

    This must be a dataclass: the environment constructs it with keyword
    arguments, and ``submitted_verification`` relies on
    ``field(default_factory=dict)`` to get a fresh dict per instance.
    Without the decorator, keyword construction raises ``TypeError``.
    """

    task: CoreIdentityTask  # task being solved in this episode
    episode_id: str
    current_step: int = 0  # number of step() calls taken so far
    cumulative_reward: float = 0.0  # running sum of per-step rewards
    # most recent verification payload submitted by the agent
    submitted_verification: Dict[str, Any] = field(default_factory=dict)
    episode_complete: bool = False  # set once a terminal step has been taken
# Multiplier applied to a non-terminal step's graded score, keyed by the
# task's difficulty string. step() falls back to 0.1 for unlisted difficulties.
DIFFICULTY_WEIGHTS = dict(
    easy=0.15,
    medium=0.12,
    hard=0.08,
)
class CoreIdentityEnvironment(Environment):
    """Server-side Core Identity Environment compliant with OpenEnv spec.

    Each episode uses one task: either a fixed ``task_id`` or a random pick
    from ``get_all_tasks()``. The task's document/credentials/profile are
    exposed in the observation. Every submitted verification is graded with
    ``CoreIdentityTaskEvaluator``. The episode ends when the agent sets
    ``submit`` or the task's ``max_steps`` is reached.
    """

    def __init__(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        max_steps: int = 10,
    ):
        """Create the environment.

        Args:
            task_id: Task used by every ``reset()`` that does not pass its own
                ``task_id``. When ``None``, a random task is chosen per episode.
            seed: If given, seeds the module-level ``random`` generator
                (global state) so task selection is reproducible.
            max_steps: Stored on the instance. NOTE(review): episode length is
                governed by ``task.max_steps``; this value is not consulted
                anywhere in this class — confirm intent.
        """
        # Let the OpenEnv base class initialise whatever state it defines;
        # the original implementation skipped this call.
        super().__init__()
        self._task_id = task_id
        self._seed = seed
        self._max_steps = max_steps
        self._ep: Optional[_EpisodeState] = None
        self._state = State(episode_id=str(uuid4()), step_count=0)
        if seed is not None:
            random.seed(seed)

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            seed: If given, reseeds the module-level ``random`` generator.
            episode_id: Episode identifier. A fresh UUID4 is used when omitted.
            task_id: Task to load. Falls back to the constructor's ``task_id``,
                then to a random task.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            An ``Observation`` with ``done=False``, ``reward=0.0`` and the
            serialized ``CoreIdentityObservation`` as ``metadata``.

        Raises:
            ValueError: If ``task_id`` is given but the lookup returns ``None``.
        """
        if seed is not None:
            random.seed(seed)

        target_task_id = task_id or self._task_id
        if target_task_id:
            task = get_task_by_id(target_task_id)
            # Fail with a clear message instead of an AttributeError further
            # down. Harmless if get_task_by_id already raises for unknown ids.
            if task is None:
                raise ValueError(f"Unknown task_id: {target_task_id!r}")
        else:
            task = random.choice(get_all_tasks())

        eid = episode_id or str(uuid4())
        # Build payloads before mutating episode state, so a malformed task
        # leaves the previous episode untouched.
        document, credentials, profile = self._task_payloads(task)

        self._ep = _EpisodeState(
            task=task,
            episode_id=eid,
            current_step=0,
            cumulative_reward=0.0,
            submitted_verification={},
            episode_complete=False,
        )
        self._state = State(episode_id=eid, step_count=0)
        obs = self._build_observation(document, credentials, profile)
        return Observation(
            done=False,
            reward=0.0,
            metadata=obs.model_dump(),
        )

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Grade one submitted verification and advance the episode.

        The reward depends on whether the step is terminal:
        - Non-terminal step: the evaluator's score times the task's difficulty
          weight (default 0.1).
        - Terminal step (``submit`` set, or ``task.max_steps`` reached): the
          score plus a 0.1 bonus when finished within half of ``max_steps``,
          capped at 1.0.

        Args:
            action: The agent action. Its payload is read from a dict
                ``action.data``, a plain dict, or the instance ``__dict__``.
            timeout_s: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            An ``Observation`` whose ``metadata`` has ``observation``,
            ``reward``, ``done`` and ``info`` keys. If the environment was
            never reset, or the episode already ended, ``metadata`` instead
            holds an ``error`` message and ``done=True``.
        """
        ep = self._ep
        if ep is None:
            return Observation(
                done=True,
                reward=0.0,
                metadata={"error": "Environment not reset. Call reset() first."},
            )
        if ep.episode_complete:
            return Observation(
                done=True,
                reward=0.0,
                metadata={"error": "Episode already finished."},
            )

        action_data = self._action_to_dict(action)
        action_error: Optional[str] = None
        try:
            env_action = CoreIdentityAction.model_validate(action_data)
        except Exception as exc:
            # Deliberately best-effort: a malformed action is graded as an empty
            # verification rather than crashing the server. The reason is kept
            # so the agent can see why it scored nothing.
            action_error = f"{type(exc).__name__}: {exc}"
            env_action = CoreIdentityAction(verification=VerificationResult())

        ep.current_step += 1
        self._state.step_count = ep.current_step

        verification_dict = env_action.verification.model_dump()
        ep.submitted_verification = verification_dict
        result = CoreIdentityTaskEvaluator(ep.task).grade(verification_dict)

        reward_value = result.score * DIFFICULTY_WEIGHTS.get(ep.task.difficulty, 0.1)
        terminal = env_action.submit or ep.current_step >= ep.task.max_steps
        if terminal:
            final_score = result.score
            # Efficiency bonus for finishing within the first half of the budget.
            if ep.current_step <= ep.task.max_steps * 0.5:
                final_score += 0.1
            reward_value = min(1.0, final_score)
            ep.episode_complete = True
        ep.cumulative_reward += reward_value

        obs = self._build_observation(*self._task_payloads(ep.task))
        info: Dict[str, Any] = {
            "step": ep.current_step,
            "is_final": ep.episode_complete,
        }
        if action_error is not None:
            info["action_error"] = action_error

        step_result = {
            "observation": obs.model_dump(),
            "reward": {
                "value": reward_value,
                "accuracy": result.accuracy,
                "completeness": result.completeness,
                "total": result.score,
                "feedback": result.feedback,
            },
            "done": ep.episode_complete,
            "info": info,
        }
        return Observation(
            done=ep.episode_complete,
            reward=reward_value,
            metadata=step_result,
        )

    @property
    def state(self) -> State:
        """Return the current OpenEnv ``State`` (episode id and step count)."""
        return self._state

    @staticmethod
    def _action_to_dict(action: Any) -> Dict[str, Any]:
        """Extract the raw action payload to validate.

        Tries, in order: a dict-valued ``action.data``, ``action`` itself when
        it is a dict, then ``vars(action)``. Falls back to ``{}``.
        """
        if hasattr(action, "data") and isinstance(action.data, dict):
            return action.data
        if isinstance(action, dict):
            return action
        if hasattr(action, "__dict__"):
            return vars(action)
        return {}

    @staticmethod
    def _task_payloads(
        task: CoreIdentityTask,
    ) -> Tuple[
        Optional[IdentityDocument],
        Optional[UserCredentials],
        Optional[UserProfile],
    ]:
        """Build the (document, credentials, profile) models for ``task``.

        Each entry is ``None`` when the task's corresponding raw dict is falsy.
        """
        document = IdentityDocument(**task.document) if task.document else None
        credentials = UserCredentials(**task.credentials) if task.credentials else None
        profile = UserProfile(**task.profile) if task.profile else None
        return document, credentials, profile

    def _build_observation(
        self,
        document: Optional[IdentityDocument],
        credentials: Optional[UserCredentials],
        profile: Optional[UserProfile],
    ) -> CoreIdentityObservation:
        """Assemble the typed observation for the current episode's task.

        Must only be called after ``reset()`` has populated ``self._ep``.
        """
        ep = self._ep
        return CoreIdentityObservation(
            task_id=ep.task.task_id,
            task_type=TaskType(ep.task.task_type),
            task_name=ep.task.name,
            task_description=ep.task.description,
            difficulty=ep.task.difficulty,
            document=document,
            credentials=credentials,
            profile=profile,
            expected_verification=ep.task.expected_verification,
            challenge_data=ep.task.challenge_data,
            max_steps=ep.task.max_steps,
        )

    def get_task_list(self) -> List[Dict[str, Any]]:
        """Return a summary dict (id, name, type, difficulty, description) per task."""
        return [
            {
                "task_id": t.task_id,
                "name": t.name,
                "task_type": t.task_type,
                "difficulty": t.difficulty,
                "description": t.description,
            }
            for t in get_all_tasks()
        ]