# Source: dataops-env / inference.py
# Author: Graheet
# Commit 4d67e80 — "Align END log format with validator: use score field"
"""Inference entrypoints for ``dataops-gym``.
This runner keeps the hackathon-required OpenAI-compatible model interface, but
adds a stronger local planner so baseline behavior is still competitive and
reproducible when the model is weak, unavailable, or partially aligned.
"""
from __future__ import annotations
import ast
from collections import Counter, defaultdict
import hashlib
import json
import os
import re
import textwrap
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from openai import OpenAI
from env import DataOpsEnv
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen3-VL-30B-A3B-Instruct:novita")
HF_TOKEN = os.getenv("HF_TOKEN")
BENCHMARK = os.getenv("BROWSERGYM_BENCHMARK", "dataops-env")
TASK_NAME = os.getenv("BROWSERGYM_TASK_NAME", "all")
TASK_ORDER = ["easy", "medium", "hard"]
MAX_STEPS = 10
TEMPERATURE = 0.0
MAX_TOKENS = 160
MODEL_RETRIES = 2
FALLBACK_ACTION = "skip(record_id='0', field='record', confidence=0.0)"
ACTION_PREFIX_RE = re.compile(r"^(action|next action)\s*[:\-]\s*", re.IGNORECASE)
EMAIL_PATTERN = re.compile(r"^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$")
IDENTIFIER_COLUMNS = ("customer_id", "vendor_id", "partner_id")
POLICY_CACHE_PATH = os.getenv("POLICY_CACHE_PATH", ".dataops_policy_cache.json")
POLICY_CACHE_VERSION = 1
SYSTEM_PROMPT = textwrap.dedent(
"""
You control a data-cleaning environment.
Reply with exactly one action string and nothing else.
Only choose from the candidate actions provided by the user prompt.
Favor actions that remove visible issues quickly and avoid actions that were
already blocked because they caused errors or no progress.
Use single quotes for string arguments.
"""
).strip()
class PolicyMemory:
    """Disk-backed memory of per-action outcomes, shared across episodes and runs."""

    def __init__(self, path: str) -> None:
        self.path = path
        # Fresh cache skeleton; replaced by on-disk data when compatible.
        self.data: Dict[str, Any] = {
            "version": POLICY_CACHE_VERSION,
            "states": {},
            "patterns": {},
        }
        self._load()

    def _load(self) -> None:
        """Swap in the persisted cache when it exists and matches our version."""
        if not os.path.exists(self.path):
            return
        try:
            with open(self.path, "r", encoding="utf-8") as handle:
                stored = json.load(handle)
        except (OSError, json.JSONDecodeError):
            return
        if not isinstance(stored, dict):
            return
        if int(stored.get("version", 0)) != POLICY_CACHE_VERSION:
            return
        self.data = stored

    def save(self) -> None:
        """Write the cache atomically: dump to a temp file, then rename over."""
        staging_path = f"{self.path}.tmp"
        with open(staging_path, "w", encoding="utf-8") as handle:
            json.dump(self.data, handle, indent=2, sort_keys=True)
        os.replace(staging_path, self.path)

    def _bucket(self, bucket_name: str, key: str) -> Dict[str, Any]:
        """Fetch (creating on demand) the entry for a state or pattern key."""
        bucket = self.data.setdefault(bucket_name, {})
        return bucket.setdefault(key, {"actions": {}})

    def _action_stats(self, bucket_name: str, key: str, action_text: str) -> Dict[str, Any]:
        """Fetch (creating on demand) the stats dict for one action."""
        fresh_stats: Dict[str, Any] = {
            "attempts": 0,
            "successes": 0,
            "progresses": 0,
            "failures": 0,
            "cumulative_reward": 0.0,
            "last_error": None,
        }
        actions = self._bucket(bucket_name, key).setdefault("actions", {})
        return actions.setdefault(action_text, fresh_stats)

    def update(
        self,
        *,
        state_key: str,
        pattern_key: str,
        action_text: str,
        reward: float,
        progress_delta: float,
        error: Optional[str],
        done: bool,
        task_score: float,
    ) -> None:
        """Fold one step outcome into both the exact-state and pattern buckets."""
        succeeded = done and task_score >= 0.95 and error is None
        progressed = progress_delta > 0.0 or reward > 0.0
        failed = error is not None or (progress_delta == 0.0 and reward <= 0.0)
        for bucket_name, key in (("states", state_key), ("patterns", pattern_key)):
            stats = self._action_stats(bucket_name, key, action_text)
            stats["attempts"] += 1
            running_total = float(stats["cumulative_reward"]) + float(reward)
            stats["cumulative_reward"] = round(running_total, 4)
            stats["last_error"] = error
            if succeeded:
                stats["successes"] += 1
            elif progressed:
                stats["progresses"] += 1
            if failed:
                stats["failures"] += 1

    def _combined_stats(self, state_key: str, pattern_key: str, action_text: str) -> Dict[str, float]:
        """Blend exact-state stats (full weight) with pattern stats (half weight)."""
        totals = {
            "attempts": 0.0,
            "successes": 0.0,
            "progresses": 0.0,
            "failures": 0.0,
            "cumulative_reward": 0.0,
        }
        weighted_sources = (
            ("states", state_key, 1.0),
            ("patterns", pattern_key, 0.5),
        )
        for bucket_name, key, weight in weighted_sources:
            entry = self.data.get(bucket_name, {}).get(key, {})
            stats = entry.get("actions", {}).get(action_text)
            if not isinstance(stats, dict):
                continue
            for name in totals:
                totals[name] += weight * float(stats.get(name, 0.0))
        return totals

    def score_action(self, state_key: str, pattern_key: str, action_text: str) -> float:
        """Turn remembered outcomes into a single preference score."""
        totals = self._combined_stats(state_key, pattern_key, action_text)
        mean_reward = totals["cumulative_reward"] / max(1.0, totals["attempts"])
        raw = (
            totals["successes"] * 3.0
            + totals["progresses"] * 1.25
            + mean_reward
            - totals["failures"] * 2.0
        )
        return round(raw, 4)

    def blocked_actions(self, state_key: str, pattern_key: str) -> set[str]:
        """Actions with >=2 attempts that only ever failed for this state/pattern."""
        blocked: set[str] = set()
        for bucket_name, key in (("states", state_key), ("patterns", pattern_key)):
            actions = self.data.get(bucket_name, {}).get(key, {}).get("actions", {})
            for action_text, stats in actions.items():
                attempts = int(stats.get("attempts", 0))
                if attempts < 2:
                    continue
                if int(stats.get("successes", 0)) or int(stats.get("progresses", 0)):
                    continue
                if int(stats.get("failures", 0)) >= attempts:
                    blocked.add(action_text)
        return blocked
def log_start(task: str, env: str, model: str) -> None:
    """Print the mandated ``[START]`` marker for an episode."""
    message = f"[START] task={task} env={env} model={model}"
    print(message, flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print the mandated ``[STEP]`` marker for one environment step."""
    shown_error = error or "null"
    done_text = str(done).lower()
    line = f"[STEP] step={step} action={action} reward={reward:.2f} done={done_text} error={shown_error}"
    print(line, flush=True)
def log_end(success: bool, steps: int, rewards: List[float], final_score: float) -> None:
    """Print the mandated ``[END]`` marker (score field matches the validator)."""
    rewards_text = ",".join(format(reward, ".2f") for reward in rewards)
    line = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={final_score:.2f} rewards={rewards_text}"
    )
    print(line, flush=True)
def build_history_lines(history: Sequence[str]) -> str:
    """Render at most the five most recent step summaries for the prompt."""
    if not history:
        return "None"
    recent = history[-5:]
    return "\n".join(recent)
def _stable_json(value: Any) -> str:
"""Serialize a value deterministically for memory key generation."""
return json.dumps(value, sort_keys=True, separators=(",", ":"))
def _hash_key(payload: Mapping[str, Any]) -> str:
"""Build a compact deterministic memory key."""
return hashlib.sha1(_stable_json(payload).encode("utf-8")).hexdigest()
def _normalize_issue_text(issue: str) -> str:
"""Remove row-specific numbers so pattern memory generalizes better."""
lowered = issue.lower().strip()
return re.sub(r"\d+", "#", lowered)
def _table_summary(table: Sequence[Mapping[str, Any]]) -> Dict[str, Any]:
    """Build a compact, generalizable description of visible table problems."""
    columns = sorted({key for row in table for key in row if key != "row_id"})
    missing: Dict[str, int] = {}
    for column in columns:
        gap_count = sum(1 for row in table if _is_missing(row.get(column)))
        if gap_count:
            missing[column] = gap_count
    duplicates: Dict[str, int] = {}
    for column in (*IDENTIFIER_COLUMNS, "email"):
        seen = [row.get(column) for row in table if row.get(column) not in (None, "")]
        surplus = len(seen) - len(set(seen))
        if seen and surplus:
            duplicates[column] = surplus
    bad_emails = 0
    bad_phones = 0
    for row in table:
        email = row.get("email")
        if email not in (None, "") and not _is_valid_email(email):
            bad_emails += 1
        phone = row.get("phone")
        if phone not in (None, "") and not _is_valid_phone(phone):
            bad_phones += 1
    needs_casing = [
        column
        for column in ("name", "city")
        if any(_needs_title_case(row.get(column)) for row in table)
    ]
    return {
        "row_count": len(table),
        "present_columns": columns,
        "missing_counts": missing,
        "duplicate_counts": duplicates,
        "invalid_email_count": bad_emails,
        "invalid_phone_count": bad_phones,
        "title_case_columns": sorted(needs_casing),
    }
def build_memory_keys(
    task_name: str,
    task_variant: str,
    goal: str,
    observation: Mapping[str, Any],
) -> Tuple[str, str]:
    """Derive the (exact-state, generalized problem-pattern) memory keys."""
    dataset = observation.get("dataset", {}) if isinstance(observation, dict) else {}
    table = list(dataset.get("modified", []))
    issue_tokens = sorted(
        [
            f"rows={len(table)}",
            f"history={len(observation.get('action_history', []))}",
            f"iter={observation.get('current_iteration_score', 0.0)}",
        ]
    )
    # Canonical row ordering/shape makes the exact-state key deterministic.
    canonical_rows = [
        {key: row.get(key) for key in sorted(row.keys())}
        for row in sorted(table, key=lambda row: int(row.get("row_id", 0)))
    ]
    state_key = _hash_key(
        {
            "task_name": task_name,
            "task_variant": task_variant,
            "goal": goal,
            "table": canonical_rows,
            "issues": issue_tokens,
        }
    )
    pattern_key = _hash_key(
        {
            "task_name": task_name,
            "goal": goal,
            "summary": _table_summary(table),
            "issues": issue_tokens,
        }
    )
    return state_key, pattern_key
def _is_missing(value: Any) -> bool:
"""Return whether a value is missing."""
return value is None or value == ""
def _needs_title_case(value: Any) -> bool:
"""Return whether a string still needs title-case normalization."""
if not isinstance(value, str):
return False
cleaned = value.strip()
return bool(cleaned) and cleaned != cleaned.title()
def _is_valid_email(value: Any) -> bool:
    """True when *value* is a string matching the module email regex."""
    if not isinstance(value, str):
        return False
    return EMAIL_PATTERN.match(value.strip()) is not None
def _is_valid_phone(value: Any) -> bool:
"""Return whether a phone-like string is valid."""
if not isinstance(value, str):
return False
digits = re.sub(r"\D", "", value)
return len(digits) == 10 or (len(digits) == 11 and digits.startswith("1"))
def _slugify_text(value: str) -> str:
"""Convert free text into a stable email-local-part fragment."""
lowered = re.sub(r"[^a-z0-9]+", ".", value.lower()).strip(".")
return lowered or "record"
def _infer_email(row: Mapping[str, Any]) -> str:
    """Derive a deterministic placeholder email from whatever the row offers."""
    name = row.get("name")
    if isinstance(name, str) and name.strip():
        return f"{_slugify_text(name)}@example.com"
    for column in IDENTIFIER_COLUMNS:
        identifier = row.get(column)
        if identifier:
            return f"{str(identifier).lower()}@example.com"
    return f"row{row.get('row_id', 'unknown')}@example.com"
def _infer_name(row: Mapping[str, Any]) -> str:
    """Derive a readable name from the email local part or an identifier."""
    email = row.get("email")
    if isinstance(email, str) and "@" in email:
        local_part = email.split("@", 1)[0]
        return local_part.replace(".", " ").title()
    for column in IDENTIFIER_COLUMNS:
        identifier = row.get(column)
        if identifier:
            return str(identifier).replace("_", " ").title()
    return "Unknown Record"
def _infer_city(table: Sequence[Mapping[str, Any]]) -> str:
"""Infer a plausible city using the mode of visible values."""
candidates = [
str(row.get("city")).title()
for row in table
if isinstance(row.get("city"), str) and row.get("city")
]
if not candidates:
return "Seattle"
return Counter(candidates).most_common(1)[0][0]
def _infer_fill_value(
    row: Mapping[str, Any],
    column: str,
    table: Sequence[Mapping[str, Any]],
) -> str:
    """Pick a deterministic replacement value, preferring local table context."""
    # First preference: copy from another row that shares an identifier.
    for id_column in IDENTIFIER_COLUMNS:
        identifier = row.get(id_column)
        if not identifier:
            continue
        for other in table:
            if other.get("row_id") == row.get("row_id"):
                continue
            if other.get(id_column) != identifier:
                continue
            if not _is_missing(other.get(column)):
                return str(other[column])
    # Otherwise fall back to per-column heuristics.
    if column == "email":
        return _infer_email(row)
    if column == "city":
        return _infer_city(table)
    if column == "name":
        return _infer_name(row)
    constant_fallbacks = {"phone": "555-555-0100", "status": "active"}
    return constant_fallbacks.get(column, "resolved")
def _row_signature(row: Mapping[str, Any]) -> Tuple[Tuple[str, Any], ...]:
"""Create a comparable row signature excluding runtime row identifiers."""
return tuple(sorted((key, value) for key, value in row.items() if key != "row_id"))
def _build_action_string(payload: Mapping[str, Any]) -> str:
"""Reconstruct a normalized action string for logging and filtering."""
action_type = str(payload["action_type"])
args: List[str] = []
for key in ("record_id", "field", "value", "confidence"):
if key not in payload or payload[key] is None:
continue
value = payload[key]
if isinstance(value, str):
args.append(f"{key}='{value}'")
else:
args.append(f"{key}={value}")
return f"{action_type}({', '.join(args)})" if args else f"{action_type}()"
def build_action_string(payload: Dict[str, Any]) -> str:
    """Public alias for :func:`_build_action_string`, kept for older callers."""
    return _build_action_string(payload)
def parse_model_action(response_text: str) -> str:
    """Pull the first plausible ``name(...)`` call out of free-form model output."""
    if not response_text:
        return FALLBACK_ACTION
    # Prefer a whole line that already looks like a call expression.
    for raw_line in response_text.splitlines():
        candidate = ACTION_PREFIX_RE.sub("", raw_line.strip())
        if candidate.endswith(")") and "(" in candidate:
            return re.sub(r"\s+", " ", candidate)
    # Otherwise search anywhere in the prefix-stripped text.
    stripped = ACTION_PREFIX_RE.sub("", response_text.strip())
    found = re.search(r"[a-zA-Z_]+\s*\(.*\)", stripped)
    if found is None:
        return FALLBACK_ACTION
    return re.sub(r"\s+", " ", found.group(0))
def action_string_to_payload(action_str: str, step_number: int) -> Tuple[str, Dict[str, Any]]:
    """Convert a model action string into ``(normalized_text, env_payload)``.

    The string must be a single call such as
    ``fix_value(record_id='3', field='email', value='x@y.com', confidence=0.8)``
    using keyword arguments only.  Anything unparseable, any unknown action
    name, and any positional arguments fall back to a safe ``skip`` action.

    Args:
        action_str: Raw action text, typically from ``parse_model_action``.
        step_number: Current step index (kept for interface compatibility;
            unused by the conversion itself).

    Returns:
        ``(action_text, payload)`` where ``payload`` always carries
        ``action_type``, ``record_id``, ``field`` and ``confidence``.
    """
    fallback_payload: Dict[str, Any] = {
        "action_type": "skip",
        "record_id": "0",
        "field": "record",
        "confidence": 0.0,
    }
    try:
        expression = ast.parse(action_str, mode="eval").body
    except SyntaxError:
        return FALLBACK_ACTION, dict(fallback_payload)
    if not isinstance(expression, ast.Call) or not isinstance(expression.func, ast.Name):
        return FALLBACK_ACTION, dict(fallback_payload)
    allowed_actions = {"detect_issue", "fix_value", "cannot_determine", "skip"}
    action_type = expression.func.id
    if action_type not in allowed_actions:
        return FALLBACK_ACTION, dict(fallback_payload)
    # BUG FIX: positional arguments used to be silently dropped, producing a
    # payload with default record/field values (a potential hallucinated fix).
    # They are ambiguous, so reject them outright.
    if expression.args:
        return FALLBACK_ACTION, dict(fallback_payload)
    payload: Dict[str, Any] = {"action_type": action_type}
    try:
        for keyword in expression.keywords:
            if keyword.arg is None:  # `**kwargs` splat — nothing usable.
                continue
            payload[keyword.arg] = ast.literal_eval(keyword.value)
    except (SyntaxError, ValueError, TypeError):
        return FALLBACK_ACTION, dict(fallback_payload)
    payload.setdefault("record_id", "0")
    payload.setdefault("field", "record")
    payload.setdefault("confidence", 0.6 if action_type != "skip" else 0.0)
    return _build_action_string(payload), payload
def create_client() -> Optional[OpenAI]:
    """Return an API client, or ``None`` when the token is a placeholder."""
    placeholder_tokens = {None, "", "local-test", "test", "dummy"}
    if HF_TOKEN in placeholder_tokens:
        return None
    return OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
def _extract_response_text(content: Any) -> str:
"""Normalize OpenAI response content into plain text."""
if isinstance(content, str):
return content
if isinstance(content, list):
return "".join(
str(part.get("text", ""))
for part in content
if isinstance(part, dict)
)
return str(content or "")
def _table_preview(table: Sequence[Mapping[str, Any]], limit: int = 6) -> str:
"""Render a compact table preview for prompting."""
preview_lines: List[str] = []
for row in table[:limit]:
summary = ", ".join(
f"{key}={value}"
for key, value in row.items()
if key in {"row_id", "name", "city", "email", "phone", "status", "customer_id", "vendor_id", "partner_id", "age", "start_date", "end_date"}
)
preview_lines.append(f"- {summary}")
return "\n".join(preview_lines) if preview_lines else "- None"
def build_user_prompt(
    step: int,
    goal: str,
    observation: Dict[str, Any],
    history: Sequence[str],
    last_error: Optional[str],
    candidate_actions: Sequence[str],
    blocked_actions: Sequence[str],
) -> str:
    """Construct a compact prompt that constrains the model to useful actions."""
    # Only the "modified" table view is shown to the model.
    dataset = observation.get("dataset", {})
    modified = dataset.get("modified", [])
    candidates_text = "\n".join(f"- {action}" for action in candidate_actions)
    # Cap blocked actions at 5 lines to keep the prompt small.
    blocked_text = "\n".join(f"- {action}" for action in blocked_actions[:5]) if blocked_actions else "- None"
    # NOTE(review): dedent() runs AFTER f-string interpolation, so multi-line
    # interpolations that start at column 0 (table preview, history, blocked
    # list) can suppress dedenting of the template lines — confirm the emitted
    # prompt formatting is what downstream parsing expects.
    return textwrap.dedent(
        f"""
        Step: {step}
        Goal: {goal}
        Steps remaining: {observation.get("steps_remaining")}
        Current iteration score: {observation.get("current_iteration_score")}
        Previous iteration score: {observation.get("previous_iteration_score")}
        Per-record scores: {observation.get("per_record_scores")}
        Table preview:
        {_table_preview(modified)}
        Recent history:
        {build_history_lines(history)}
        Last action error: {last_error or "null"}
        Blocked actions:
        {blocked_text}
        Choose exactly one action from this candidate list:
        {candidates_text}
        """
    ).strip()
def _prefer_action(
    candidates: Sequence[Dict[str, Any]],
    blocked_actions: set[str],
) -> Dict[str, Any]:
    """Pick the first candidate whose rendered action text is not blocked."""
    for candidate in candidates:
        rendered = _build_action_string(candidate)
        if rendered in blocked_actions:
            continue
        return dict(candidate)
    # Everything is blocked: fall back to a harmless skip.
    return {"action_type": "skip", "record_id": "0", "field": "record", "confidence": 0.0}
def _record_id(row: Mapping[str, Any]) -> str:
rid = row.get("row_id")
return str(rid) if rid is not None else "0"
def _issue_like_candidates(table: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]:
    """Generate detect/fix/cannot-determine candidates from visible row issues.

    For each suspicious field a ``detect_issue`` candidate is emitted, usually
    followed by a concrete follow-up: ``fix_value`` when a deterministic repair
    can be computed, ``cannot_determine`` when it cannot.

    Args:
        table: Current "modified" table rows as plain mappings.

    Returns:
        Unranked candidate action payloads.
    """
    actions: List[Dict[str, Any]] = []
    for row in table:
        rid = _record_id(row)
        for field, value in row.items():
            if field == "row_id":
                continue
            # Missing or sentinel values ("unknown"/"9999") get a context fill.
            if _is_missing(value) or str(value).strip().lower() in {"unknown", "9999"}:
                actions.append(
                    {"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.85}
                )
                actions.append(
                    {
                        "action_type": "fix_value",
                        "record_id": rid,
                        "field": field,
                        "value": _infer_fill_value(row, field, table),
                        "confidence": 0.75,
                    }
                )
            elif field == "email" and not _is_valid_email(value):
                # Normalize common obfuscations ("[at]", " at ") and spaces.
                fixed = str(value).replace("[at]", "@").replace(" at ", "@").replace(" ", "")
                if "@" in fixed and "." not in fixed.split("@")[-1]:
                    fixed += ".com"  # Domain missing a TLD dot: best-effort repair.
                actions.append({"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.85})
                actions.append({"action_type": "fix_value", "record_id": rid, "field": field, "value": fixed, "confidence": 0.8})
            elif field == "phone" and not _is_valid_phone(value):
                digits = re.sub(r"\D", "", str(value))
                actions.append({"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.8})
                # BUG FIX: only propose a fix when exactly ten digits exist.
                # Previously, for other digit counts, a stale `fixed` from an
                # earlier email branch (or an undefined name) was emitted.
                if len(digits) == 10:
                    formatted = f"{digits[0:3]}-{digits[3:6]}-{digits[6:10]}"
                    actions.append({"action_type": "fix_value", "record_id": rid, "field": field, "value": formatted, "confidence": 0.75})
            elif field in {"start_date", "end_date"}:
                start = row.get("start_date")
                end = row.get("end_date")
                # Lexicographic compare assumes ISO-format dates — TODO confirm.
                # Fires once per date field, so a reversed pair yields two
                # detect/cannot_determine pairs (one per field).
                if start and end and str(end) < str(start):
                    actions.append({"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.8})
                    actions.append({"action_type": "cannot_determine", "record_id": rid, "field": field, "confidence": 0.7})
            elif field == "age":
                try:
                    age = int(value)
                except (TypeError, ValueError):
                    age = -1  # Non-numeric ages are treated as out of range.
                if age < 0 or age > 120:
                    actions.append({"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.9})
                    actions.append({"action_type": "cannot_determine", "record_id": rid, "field": field, "confidence": 0.8})
    return actions
def _detected_keys_from_history(action_history: Sequence[Mapping[str, Any]]) -> set[str]:
"""Extract previously detected issue keys from observation history."""
keys: set[str] = set()
for action in action_history:
if action.get("action_type") != "detect_issue":
continue
keys.add(f"{action.get('record_id')}::{action.get('field')}")
return keys
def propose_candidate_actions(
    observation: Mapping[str, Any],
    blocked_actions: set[str],
) -> List[Dict[str, Any]]:
    """Build a deduplicated, preference-ordered candidate list from the table."""
    dataset = observation.get("dataset", {}) if isinstance(observation, dict) else {}
    table = list(dataset.get("modified", []))
    already_detected = _detected_keys_from_history(observation.get("action_history", []))
    filtered: List[Dict[str, Any]] = []
    for candidate in _issue_like_candidates(table):
        if candidate.get("action_type") == "detect_issue":
            detect_key = f"{candidate.get('record_id')}::{candidate.get('field')}"
            # Each issue is detected at most once; afterwards only fixes matter.
            if detect_key in already_detected:
                continue
        filtered.append(candidate)
    # A skip candidate is always available as the last resort.
    filtered.append(
        {"action_type": "skip", "record_id": "0", "field": "record", "confidence": 0.0}
    )
    deduped: List[Dict[str, Any]] = []
    seen_texts: set[str] = set()
    for candidate in filtered:
        rendered = _build_action_string(candidate)
        if rendered not in seen_texts:
            seen_texts.add(rendered)
            deduped.append(candidate)
    best = _prefer_action(deduped, blocked_actions)
    best_text = _build_action_string(best)
    ranked: List[Dict[str, Any]] = [best]
    for candidate in deduped:
        if _build_action_string(candidate) != best_text:
            ranked.append(candidate)
    return ranked[:12]
def _order_candidates_with_memory(
    candidates: Sequence[Dict[str, Any]],
    memory: PolicyMemory,
    state_key: str,
    pattern_key: str,
    recent_history: Sequence[str],
) -> List[Dict[str, Any]]:
    """Stable-sort candidates by remembered value, penalizing recent repeats."""
    repeat_counts: Counter = Counter()
    for line in recent_history[-5:]:
        try:
            action_part = line.split(" action=", 1)[1]
            action_text = action_part.split(" reward=", 1)[0].strip()
            if action_text:
                repeat_counts[action_text] += 1
        except Exception:
            # Malformed history line: ignore it for repeat counting.
            continue
    ranked: List[Tuple[float, int, Dict[str, Any]]] = []
    for position, candidate in enumerate(candidates):
        rendered = _build_action_string(candidate)
        penalty = 2.0 * repeat_counts.get(rendered, 0)
        memory_score = memory.score_action(state_key, pattern_key, rendered)
        ranked.append((penalty - memory_score, position, dict(candidate)))
    # Original index breaks ties, keeping the ordering stable.
    ranked.sort(key=lambda entry: (entry[0], entry[1]))
    return [entry[2] for entry in ranked]
def model_action(
    client: Optional[OpenAI],
    model_name: str,
    step: int,
    goal: str,
    observation: Dict[str, Any],
    history: Sequence[str],
    last_error: Optional[str],
    candidate_actions: Sequence[str],
    blocked_actions: Sequence[str],
) -> Optional[str]:
    """Ask the model to choose among pre-computed candidate actions.

    Returns the chosen action string, or ``None`` when no client is
    configured, the API call raises, or the model never answers with a valid,
    unblocked candidate within ``MODEL_RETRIES`` attempts.  ``None`` tells the
    caller to fall back to the heuristic planner.
    """
    if client is None:
        # No credentials configured: run fully heuristic.
        return None
    prompt = build_user_prompt(
        step=step,
        goal=goal,
        observation=observation,
        history=history,
        last_error=last_error,
        candidate_actions=candidate_actions,
        blocked_actions=blocked_actions,
    )
    current_prompt = prompt
    # Membership checks below are O(1) against this set.
    candidate_set = set(candidate_actions)
    for _ in range(MODEL_RETRIES):
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": SYSTEM_PROMPT}],
                    },
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": current_prompt}],
                    },
                ],
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                stream=False,
            )
            response_text = _extract_response_text(completion.choices[0].message.content)
            action_text = parse_model_action(response_text)
            # Accept only an answer that is on the list and not blocked.
            if action_text in candidate_set and action_text not in set(blocked_actions):
                return action_text
            # Off-list or blocked answer: re-prompt with an explicit correction.
            current_prompt = (
                prompt
                + "\n\nYour previous answer was invalid or blocked. Choose exactly one action from the candidate list."
            )
        except Exception:  # noqa: BLE001
            # API errors abort immediately (no retry) — fail fast to heuristics.
            return None
    return None
def choose_action(
    client: Optional[OpenAI],
    memory: PolicyMemory,
    task_name: str,
    task_variant: str,
    observation: Dict[str, Any],
    goal: str,
    step_number: int,
    history: Sequence[str],
    last_error: Optional[str],
    blocked_actions: set[str],
) -> Tuple[str, Dict[str, Any], str, str, str]:
    """Select the next action via memory-ranked heuristics, optionally model-arbitrated.

    Returns ``(action_text, payload, source, state_key, pattern_key)`` where
    ``source`` is ``"model"`` or ``"heuristic"``.
    """
    state_key, pattern_key = build_memory_keys(task_name, task_variant, goal, observation)
    all_blocked = set(blocked_actions) | set(memory.blocked_actions(state_key, pattern_key))
    ranked = _order_candidates_with_memory(
        propose_candidate_actions(observation, all_blocked),
        memory,
        state_key,
        pattern_key,
        history,
    )
    fallback_text = _build_action_string(ranked[0])
    ranked_texts = [_build_action_string(candidate) for candidate in ranked]
    model_choice = model_action(
        client=client,
        model_name=MODEL_NAME,
        step=step_number,
        goal=goal,
        observation=observation,
        history=history,
        last_error=last_error,
        candidate_actions=ranked_texts,
        blocked_actions=sorted(all_blocked),
    )
    # Anything off-list is discarded, even if the model returned text.
    if model_choice not in ranked_texts:
        model_choice = None
    chosen = model_choice if model_choice else fallback_text
    normalized_text, payload = action_string_to_payload(chosen, step_number)
    if normalized_text in all_blocked:
        # Normalization landed on a blocked action: use the heuristic instead.
        normalized_text, payload = action_string_to_payload(fallback_text, step_number)
        return normalized_text, payload, "heuristic", state_key, pattern_key
    source = "model" if model_choice else "heuristic"
    return normalized_text, payload, source, state_key, pattern_key
def run_episode(
    client: Optional[OpenAI],
    memory: PolicyMemory,
    task_name: str,
    seed: int,
) -> float:
    """Run one deterministic task episode and return its final task score.

    Emits the required [START]/[STEP]/[END] log lines, records every step
    outcome in the persistent policy memory, and blocks actions that error
    out, make no progress, or repeatedly leave the table unchanged.

    Args:
        client: Optional OpenAI-compatible client (``None`` => heuristic only).
        memory: Cross-episode policy memory; saved even when a step crashes.
        task_name: Environment task to run.
        seed: Deterministic environment seed.

    Returns:
        The final task score reported by the environment (0.0 on failure).
    """
    env = DataOpsEnv(seed=seed, task_name=task_name)
    rewards: List[float] = []
    history: List[str] = []
    blocked_actions: set[str] = set()
    steps_taken = 0
    success = False
    last_error: Optional[str] = None
    final_score = 0.0
    task_variant = "unknown"
    action_repeat_counts: Dict[str, int] = defaultdict(int)
    no_change_counts: Dict[str, int] = defaultdict(int)
    try:
        log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
        observation_model = env.reset()
        observation = observation_model.model_dump()
        task_variant = str(env.state().get("task_variant", "unknown"))
        for step_number in range(1, MAX_STEPS + 1):
            action_text, action_payload, action_source, state_key, pattern_key = choose_action(
                client=client,
                memory=memory,
                task_name=task_name,
                task_variant=task_variant,
                observation=observation,
                goal=str(env.state().get("task", {}).get("goal", "")),
                step_number=step_number,
                history=history,
                last_error=last_error,
                blocked_actions=blocked_actions,
            )
            try:
                before_dataset = observation.get("dataset", {}) if isinstance(observation, dict) else {}
                before_modified = before_dataset.get("modified", [])
                observation_model, reward, done, info = env.step(action_payload)
                observation = observation_model.model_dump()
                after_dataset = observation.get("dataset", {}) if isinstance(observation, dict) else {}
                after_modified = after_dataset.get("modified", [])
                result = info.get("result", {})
                curr_iter = float(observation.get("current_iteration_score", 0.0))
                prev_iter = float(observation.get("previous_iteration_score", 0.0))
                progress_delta = max(0.0, curr_iter - prev_iter)
                # Any validator-reported mistake collapses into one label.
                # (A dead `if error_value == "general"` branch was removed:
                # error_value can only be "step_error" or None here.)
                error_value = "step_error" if (
                    result.get("wrong_fix")
                    or result.get("hallucinated_fix")
                    or result.get("wrong_cannot_determine")
                    or result.get("classification_incorrect")
                ) else None
                final_score = float(info.get("final_task_score", 0.0))
                memory.update(
                    state_key=state_key,
                    pattern_key=pattern_key,
                    action_text=action_text,
                    reward=reward,
                    progress_delta=progress_delta,
                    error=error_value,
                    done=done,
                    task_score=final_score,
                )
                # Block actions that errored or produced no measurable progress.
                if error_value or progress_delta == 0.0 or reward <= 0.0:
                    blocked_actions.add(action_text)
                action_repeat_counts[action_text] += 1
                if action_repeat_counts[action_text] > 2:
                    blocked_actions.add(action_text)
                # Also block actions that leave the table unchanged twice.
                if _stable_json(before_modified) == _stable_json(after_modified):
                    no_change_counts[action_text] += 1
                    if no_change_counts[action_text] >= 2:
                        blocked_actions.add(action_text)
                else:
                    no_change_counts[action_text] = 0
            except Exception as exc:  # noqa: BLE001
                # Treat a crashing step as terminal; the episode still ends
                # cleanly with the required log lines and a saved memory.
                reward = 0.0
                done = True
                error_value = str(exc)
                blocked_actions.add(action_text)
                memory.update(
                    state_key=state_key,
                    pattern_key=pattern_key,
                    action_text=action_text,
                    reward=reward,
                    progress_delta=0.0,
                    error=error_value,
                    done=done,
                    task_score=final_score,
                )
            rewards.append(reward)
            steps_taken = step_number
            last_error = error_value
            log_step(
                step=step_number,
                action=action_text,
                reward=reward,
                done=done,
                error=error_value,
            )
            history.append(
                f"step={step_number} source={action_source} action={action_text} "
                f"reward={reward:.2f} done={str(done).lower()} error={error_value or 'null'}"
            )
            if done:
                success = bool(final_score > 0.0)
                break
    finally:
        # Persist memory and close the env even when the episode crashes.
        memory.save()
        close_method = getattr(env, "close", None)
        if callable(close_method):
            close_method()
    log_end(success=success, steps=steps_taken, rewards=rewards, final_score=final_score)
    return final_score
def main() -> None:
    """Entry point: run the configured task, or every task in fixed order."""
    client = create_client()
    memory = PolicyMemory(POLICY_CACHE_PATH)
    requested = str(TASK_NAME).strip().lower()
    if requested not in {"all", "*"}:
        run_episode(client=client, memory=memory, task_name=requested, seed=0)
        return
    # Seeds follow the fixed task order, keeping "all" runs deterministic.
    for seed, ordered_task in enumerate(TASK_ORDER):
        run_episode(client=client, memory=memory, task_name=ordered_task, seed=seed)
# Allow both module import and direct script execution.
if __name__ == "__main__":
    main()