# core-identity-env/server/core_identity_environment.py
# hirann
# Add Core Identity environment for OpenEnv
# 1f9fc8c
"""Core Identity Environment - server-side implementation."""
import random
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4

try:
    from openenv.core.env_server.interfaces import Environment
    from openenv.core.env_server.types import Action, Observation, State
except ImportError:
    # NOTE(review): this fallback repeats the identical imports, so it can only
    # re-raise the same ImportError; presumably an alternate import path was
    # intended here — confirm.
    from openenv.core.env_server.interfaces import Environment
    from openenv.core.env_server.types import Action, Observation, State

from core_identity_env.models import (
    CoreIdentityObservation,
    CoreIdentityAction,
    VerificationResult,
    TaskType,
    IdentityDocument,
    UserCredentials,
    UserProfile,
)
from core_identity_env.tasks.definitions import (
    get_all_tasks,
    get_task_by_id,
    CoreIdentityTask,
    CoreIdentityTaskEvaluator,
    GradingResult,
)

@dataclass
class _EpisodeState:
    """Mutable bookkeeping for one episode of the Core Identity environment.

    Created by ``CoreIdentityEnvironment.reset`` and updated in place by
    ``CoreIdentityEnvironment.step``.
    """

    # Task definition this episode is graded against.
    task: CoreIdentityTask
    # Identifier of this episode.
    episode_id: str
    # Number of step() calls made so far in this episode.
    current_step: int = 0
    # Running sum of the rewards returned by step().
    cumulative_reward: float = 0.0
    # Most recently submitted verification, as a plain dict.
    submitted_verification: Dict[str, Any] = field(default_factory=dict)
    # Set once the episode terminates (explicit submit or step budget reached).
    episode_complete: bool = False

# Multiplier applied to the grader score for non-terminal steps, keyed by task
# difficulty; harder tasks get a smaller multiplier. step() falls back to 0.1
# for difficulties not listed here.
DIFFICULTY_WEIGHTS = dict(
    easy=0.15,
    medium=0.12,
    hard=0.08,
)

class CoreIdentityEnvironment(Environment):
    """Server-side Core Identity Environment compliant with the OpenEnv spec.

    Each episode is bound to one ``CoreIdentityTask``. ``reset`` selects the
    task (by id, or at random) and returns the initial observation; each
    ``step`` grades the submitted verification with
    ``CoreIdentityTaskEvaluator``. Environment-specific payloads are carried
    in ``Observation.metadata``.
    """

    def __init__(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        max_steps: int = 10,
    ):
        """Create the environment.

        Args:
            task_id: Default task used by ``reset`` when it is not given one;
                if ``None`` as well, ``reset`` picks a random task.
            seed: If given, seeds the module-level ``random`` generator.
            max_steps: Stored on the instance only (see note below).
        """
        self._task_id = task_id
        self._seed = seed
        # NOTE(review): not consulted anywhere; step() terminates on each
        # task's own ``max_steps``. Kept for interface compatibility.
        self._max_steps = max_steps
        self._ep: Optional[_EpisodeState] = None
        self._state = State(episode_id=str(uuid4()), step_count=0)
        if seed is not None:
            # Seeds the global RNG (process-wide side effect), as reset() does.
            random.seed(seed)

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            seed: If given, re-seeds the module-level ``random`` generator
                before the task is selected.
            episode_id: Identifier for the new episode; a UUID4 string is
                generated when omitted.
            task_id: Task to load; falls back to the constructor's ``task_id``
                and then to a random choice from ``get_all_tasks()``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            An ``Observation`` with ``done=False``, ``reward=0.0`` and the
            serialized ``CoreIdentityObservation`` as ``metadata``.
        """
        if seed is not None:
            random.seed(seed)

        target_task_id = task_id or self._task_id
        if target_task_id:
            task = get_task_by_id(target_task_id)
        else:
            task = random.choice(get_all_tasks())

        eid = episode_id or str(uuid4())
        # Build the task models before touching episode state so a failure
        # here leaves the previous episode untouched.
        artifacts = self._task_artifacts(task)

        self._ep = _EpisodeState(task=task, episode_id=eid)
        self._state = State(episode_id=eid, step_count=0)

        obs = self._build_observation(*artifacts)
        return Observation(
            done=False,
            reward=0.0,
            metadata=obs.model_dump(),
        )

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Grade one verification attempt and advance the episode.

        Non-terminal steps earn ``score * DIFFICULTY_WEIGHTS[difficulty]``
        (0.1 for unlisted difficulties). The episode terminates when the
        action sets ``submit`` or the task's ``max_steps`` is reached; the
        terminal reward is the raw score, plus 0.1 if finished within the
        first half of the step budget, capped at 1.0.

        Args:
            action: Incoming action (OpenEnv ``Action``, dict, or object);
                see ``_parse_action`` for how it is normalized.
            timeout_s: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            An ``Observation`` whose ``metadata`` holds ``observation``,
            ``reward``, ``done`` and ``info`` entries. ``info`` contains an
            ``action_error`` message when the action failed validation. If
            ``reset`` was never called or the episode already ended, a
            ``done=True`` observation with an ``error`` message is returned.
        """
        if self._ep is None:
            return Observation(
                done=True,
                reward=0.0,
                metadata={"error": "Environment not reset. Call reset() first."},
            )
        if self._ep.episode_complete:
            return Observation(
                done=True,
                reward=0.0,
                metadata={"error": "Episode already finished."},
            )

        env_action, action_error = self._parse_action(action)

        ep = self._ep
        ep.current_step += 1
        self._state.step_count = ep.current_step

        verification_dict = env_action.verification.model_dump()
        ep.submitted_verification = verification_dict
        result = CoreIdentityTaskEvaluator(ep.task).grade(verification_dict)

        # Intermediate steps earn only a difficulty-scaled fraction of the score.
        reward_value = result.score * DIFFICULTY_WEIGHTS.get(ep.task.difficulty, 0.1)
        terminal = env_action.submit or ep.current_step >= ep.task.max_steps
        if terminal:
            final_score = result.score
            # Early-finish bonus: within the first half of the step budget.
            if ep.current_step <= ep.task.max_steps * 0.5:
                final_score += 0.1
            reward_value = min(1.0, final_score)
            ep.episode_complete = True
        ep.cumulative_reward += reward_value

        obs = self._build_observation(*self._task_artifacts(ep.task))

        info: Dict[str, Any] = {
            "step": ep.current_step,
            "is_final": ep.episode_complete,
        }
        if action_error is not None:
            # Surface why the submitted action was replaced by an empty one.
            info["action_error"] = action_error

        step_result = {
            "observation": obs.model_dump(),
            "reward": {
                "value": reward_value,
                "accuracy": result.accuracy,
                "completeness": result.completeness,
                "total": result.score,
                "feedback": result.feedback,
            },
            "done": ep.episode_complete,
            "info": info,
        }
        return Observation(
            done=ep.episode_complete,
            reward=reward_value,
            metadata=step_result,
        )

    @property
    def state(self) -> State:
        """Return the OpenEnv ``State`` (episode id and step count)."""
        return self._state

    @staticmethod
    def _parse_action(action: Any) -> Tuple[CoreIdentityAction, Optional[str]]:
        """Coerce an incoming action into a ``CoreIdentityAction``.

        The payload is taken from ``action.data`` when that is a dict, from the
        action itself when it is a dict, or from its instance ``__dict__``.

        Returns:
            ``(validated_action, None)`` on success. If validation fails,
            ``(action_with_empty_verification, error_message)`` so the step
            can still be graded (best-effort behaviour).
        """
        action_data: Dict[str, Any] = {}
        if hasattr(action, "data") and isinstance(action.data, dict):
            action_data = action.data
        elif isinstance(action, dict):
            action_data = action
        elif hasattr(action, "__dict__"):
            action_data = vars(action)

        try:
            return CoreIdentityAction.model_validate(action_data), None
        except Exception as exc:  # best-effort: grade an empty verification
            fallback = CoreIdentityAction(verification=VerificationResult())
            return fallback, f"{type(exc).__name__}: {exc}"

    @staticmethod
    def _task_artifacts(
        task: CoreIdentityTask,
    ) -> Tuple[
        Optional[IdentityDocument], Optional[UserCredentials], Optional[UserProfile]
    ]:
        """Instantiate the task's optional document, credentials and profile.

        Each model is built from the corresponding task attribute only when
        that attribute is truthy; otherwise ``None`` fills its slot.
        """
        document = IdentityDocument(**task.document) if task.document else None
        credentials = (
            UserCredentials(**task.credentials) if task.credentials else None
        )
        profile = UserProfile(**task.profile) if task.profile else None
        return document, credentials, profile

    def _build_observation(
        self,
        document: Optional[IdentityDocument],
        credentials: Optional[UserCredentials],
        profile: Optional[UserProfile],
    ) -> CoreIdentityObservation:
        """Assemble a ``CoreIdentityObservation`` for the current episode's task.

        Requires an active episode (``reset`` must have been called).
        """
        ep = self._ep
        return CoreIdentityObservation(
            task_id=ep.task.task_id,
            task_type=TaskType(ep.task.task_type),
            task_name=ep.task.name,
            task_description=ep.task.description,
            difficulty=ep.task.difficulty,
            document=document,
            credentials=credentials,
            profile=profile,
            expected_verification=ep.task.expected_verification,
            challenge_data=ep.task.challenge_data,
            max_steps=ep.task.max_steps,
        )

    def get_task_list(self) -> List[Dict[str, Any]]:
        """Summarize every task returned by ``get_all_tasks()``.

        Returns:
            One dict per task with ``task_id``, ``name``, ``task_type``,
            ``difficulty`` and ``description`` keys.
        """
        return [
            {
                "task_id": t.task_id,
                "name": t.name,
                "task_type": t.task_type,
                "difficulty": t.difficulty,
                "description": t.description,
            }
            for t in get_all_tasks()
        ]