# medusa_env/client.py
# Uploaded by rampluto via huggingface_hub (commit fe7e1aa, verified).
"""MEDUSA Environment Client.

Connects to a running MEDUSA server via WebSocket for persistent sessions.

Example:
    >>> # Connect to a running server
    >>> from medusa_env.models import MedusaAction, MedusaActionType
    >>> with medusa_env(base_url="http://localhost:8000") as client:
    ...     result = client.reset(seed=0)
    ...     print(result.observation.message)
    ...
    ...     result = client.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
    ...     print(f"Reward: {result.reward}")

Example with Docker (uses the imports from the example above):
    >>> client = medusa_env.from_docker_image("medusa_env:latest")
    >>> try:
    ...     result = client.reset()
    ...     result = client.step(MedusaAction(action=MedusaActionType.COMMIT))
    ... finally:
    ...     client.close()
"""
from typing import Any, Dict
# Support both in-repo and standalone imports
try:
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from medusa_env.models import MedusaAction, MedusaObservation, MedusaState
except ImportError:
from models import MedusaAction, MedusaObservation, MedusaState
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
class medusa_env(EnvClient[MedusaAction, MedusaObservation, MedusaState]):
    """Client for the MEDUSA Bronze→Silver integration environment.

    Maintains a persistent WebSocket connection to the MEDUSA server.
    Each client instance has its own dedicated environment session.

    The agent observes a 16-float data quality feature vector and chooses
    from 11 discrete ETL actions to build a correct Silver entity from
    two Bronze sources (Fact + Dimension).

    The lowercase class name does not follow PEP 8 class naming, but it is
    kept unchanged because callers construct the client as ``medusa_env(...)``.

    Example:
        >>> from medusa_env.models import MedusaAction, MedusaActionType
        >>> with medusa_env(base_url="http://localhost:8000") as env:
        ...     result = env.reset(seed=0)  # clean scenario
        ...     result = env.step(MedusaAction(action=MedusaActionType.SYNC_CHECK))
        ...     result = env.step(MedusaAction(action=MedusaActionType.PREP_KEYS_A))
        ...     result = env.step(MedusaAction(action=MedusaActionType.PREP_KEYS_B))
        ...     result = env.step(MedusaAction(action=MedusaActionType.DEDUPLICATE_B))
        ...     result = env.step(MedusaAction(action=MedusaActionType.EXECUTE_JOIN_LEFT))
        ...     result = env.step(MedusaAction(action=MedusaActionType.APPLY_SCD_2))
        ...     result = env.step(MedusaAction(action=MedusaActionType.COMMIT))
        ...     print(result.reward)
    """

    def _step_payload(self, action: MedusaAction) -> Dict[str, Any]:
        """Convert a MedusaAction into the JSON payload for a step request.

        Args:
            action: The action to send. ``action.action`` is normally an enum
                member (e.g. ``MedusaActionType.COMMIT``), whose ``.value`` is sent.

        Returns:
            A dict with ``"action"`` (the action identifier) and ``"params"``
            (``action.params``, forwarded unchanged).
        """
        action_type = action.action
        # Send the enum's underlying value. If the model already holds a plain
        # value (e.g. configured to store enum values), send it as-is instead
        # of failing with AttributeError on ``.value``.
        return {
            "action": getattr(action_type, "value", action_type),
            "params": action.params,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[MedusaObservation]:
        """Parse a server response into ``StepResult[MedusaObservation]``.

        Args:
            payload: Decoded server response. Top-level ``"reward"`` and
                ``"done"`` are read, plus the nested ``"observation"`` mapping
                (``message``, ``features``, ``metrics``, ``metadata``).

        Returns:
            A StepResult whose ``reward``/``done`` are also copied onto the
            observation. Missing keys fall back to ``""``/``[]``/``{}``,
            ``reward=None`` and ``done=False``.
        """
        # ``or {}`` also covers an explicit JSON ``null`` observation, which
        # would otherwise crash with AttributeError on ``None.get``.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward")
        done = payload.get("done", False)
        observation = MedusaObservation(
            message=obs_data.get("message", ""),
            features=obs_data.get("features", []),
            metrics=obs_data.get("metrics", {}),
            metadata=obs_data.get("metadata", {}),
            reward=reward,
            done=done,
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> MedusaState:
        """Parse a server state response into MedusaState.

        Each MedusaState field is read from the payload key of the same name.
        A missing key gets the default shown below (None for identifiers and
        optional types, 0/0.0 for counters and ratios, 1.0 for uniqueness,
        False for flags, "init" for stage, "" for the grader report). A key
        present with an explicit null is passed through as None.

        Args:
            payload: Decoded server state response.

        Returns:
            The populated MedusaState.
        """
        return MedusaState(
            run_id=payload.get("run_id"),
            seed=payload.get("seed"),
            scenario_id=payload.get("scenario_id"),
            step_idx=payload.get("step_idx", 0),
            stage=payload.get("stage", "init"),
            # Freshness
            time_delta_a=payload.get("time_delta_a", 0.0),
            time_delta_b=payload.get("time_delta_b", 0.0),
            is_stale_a=payload.get("is_stale_a", False),
            is_stale_b=payload.get("is_stale_b", False),
            did_sync_check=payload.get("did_sync_check", False),
            # Key health
            null_ratio_key_a=payload.get("null_ratio_key_a", 0.0),
            null_ratio_key_b=payload.get("null_ratio_key_b", 0.0),
            uniqueness_a=payload.get("uniqueness_a", 1.0),
            uniqueness_b=payload.get("uniqueness_b", 1.0),
            did_prep_a=payload.get("did_prep_a", False),
            did_prep_b=payload.get("did_prep_b", False),
            did_dedup_b=payload.get("did_dedup_b", False),
            # Join
            match_rate=payload.get("match_rate", 0.0),
            did_join=payload.get("did_join", False),
            join_type=payload.get("join_type"),
            join_row_count=payload.get("join_row_count", 0),
            explosion_detected=payload.get("explosion_detected", False),
            # SCD
            did_scd=payload.get("did_scd", False),
            scd_type=payload.get("scd_type"),
            scd_inserts=payload.get("scd_inserts", 0),
            scd_updates=payload.get("scd_updates", 0),
            # Silver / Quarantine
            silver_row_count=payload.get("silver_row_count", 0),
            quarantine_row_count=payload.get("quarantine_row_count", 0),
            source_a_row_count=payload.get("source_a_row_count", 0),
            # Grader
            grader_passed=payload.get("grader_passed", False),
            grader_report=payload.get("grader_report", ""),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
        )