# cadforge-cadquery-openenv / server/mechforge_environment.py
# Commit 6de1b61 (verified) by sanjuhs: "Upload folder using huggingface_hub"
from __future__ import annotations

import math
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import EnvironmentMetadata

from mechforge.models import Design, ToolAction
from mechforge.tools import apply_action

from .openenv_models import MechForgeAction, MechForgeObservation, MechForgeState
# Tool names advertised to the agent in the ``available_tools`` field of every
# observation. One name per line; ``str.split()`` turns the block into a list.
AVAILABLE_TOOLS = """
    create_design_family
    set_material
    set_envelope
    set_load
    add_mount_hole
    add_rib
    add_lightening_hole
    run_fea
    commit_design
""".split()
# Task brief used when ``reset`` is called without one. Each sentence is a
# separate literal; they are joined with single spaces.
DEFAULT_TASK = " ".join(
    (
        "Design a lightweight 6061 aluminum cantilever motor-mount bracket.",
        "It is fixed on the left face by two M5 bolts and must carry 120 N downward"
        " at the tip/load boss around 90 mm from the fixed edge.",
        "Keep mass low while maintaining safety factor above 2.0 and minimizing deflection.",
    )
)
class MechForgeEnvironment(Environment[MechForgeAction, MechForgeObservation, MechForgeState]):
    """OpenEnv-compatible MechForge 3D engineering design environment.

    An episode starts without a design. Each ``step`` converts the agent's
    ``MechForgeAction`` into a ``ToolAction``, applies it with
    ``apply_action``, appends the call and its result to the episode trace,
    and returns an observation carrying a shaped reward. The episode ends
    when a tool result reports ``committed`` or after ``MAX_STEPS`` steps.

    The class-level thresholds below drive both the reward shaping and the
    observation warnings, so the two always agree. A subclass may override
    them.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Maximum number of steps per episode before it is force-terminated.
    MAX_STEPS = 300
    # Minimum acceptable safety factor (the 2.0 target stated in DEFAULT_TASK).
    TARGET_SAFETY_FACTOR = 2.0
    # Mass budget, in the units of the ``mass_g`` metric (presumably grams).
    # Heavier designs are penalized and trigger a warning.
    MASS_BUDGET_G = 90.0

    def __init__(self) -> None:
        # Episode bookkeeping: id, step count, task brief, trace, last metrics.
        self._state = MechForgeState(episode_id=str(uuid4()), step_count=0, task_brief=DEFAULT_TASK)
        # Current design; None until a tool call produces one.
        self._design: Design | None = None
        # Set once the episode has terminated; later steps are rejected.
        self._done = False

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_brief: str | None = None,
        **_: object,
    ) -> MechForgeObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Accepted for interface compatibility; currently unused.
            episode_id: Id for the new episode; a random UUID4 when omitted.
            task_brief: Task description; falls back to ``DEFAULT_TASK`` when
                omitted or empty.

        Returns:
            An observation with zero reward and a short getting-started message.
        """
        self._state = MechForgeState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            task_brief=task_brief or DEFAULT_TASK,
        )
        self._design = None
        self._done = False
        return self._observation(reward=0.0, last_result={"message": "Ready. Create a design family, set constraints, then run FEA."})

    def step(self, action: MechForgeAction, timeout_s: float | None = None, **_: object) -> MechForgeObservation:  # type: ignore[override]
        """Apply one tool call and return the resulting observation.

        Args:
            action: Tool name and parameters chosen by the agent.
            timeout_s: Accepted for interface compatibility; currently unused.

        Returns:
            The observation after the tool call. If the episode has already
            ended, the step count is not advanced. In that case the returned
            observation has zero reward, ``done=True`` and an error result.
        """
        if self._done:
            return self._observation(reward=0.0, last_result={"valid": False, "error": "Episode is already done."}, done=True)
        self._state.step_count += 1
        tool_action = ToolAction(tool=action.tool, params=action.params)
        self._design, result = apply_action(self._design, tool_action, self._state.task_brief)
        self._state.trace.append(
            {
                "step": self._state.step_count,
                "action": tool_action.model_dump(),
                "result": result,
            }
        )
        self._state.design = self._design.model_dump() if self._design else None
        # Only overwrite the remembered metrics when this call produced a
        # simulation. Non-FEA tools then keep showing the last FEA results.
        metrics = result.get("simulation", {})
        if metrics:
            self._state.last_metrics = metrics
        reward = self._reward(result)
        if result.get("committed") or self._state.step_count >= self.MAX_STEPS:
            self._done = True
        # A separate dump from the trace entry, so the two dicts never alias.
        return self._observation(reward=reward, last_action=tool_action.model_dump(), last_result=result, done=self._done)

    @property
    def state(self) -> MechForgeState:
        """Return the current episode state object (not a copy)."""
        return self._state

    def get_metadata(self) -> EnvironmentMetadata:
        """Return static descriptive metadata for this environment."""
        return EnvironmentMetadata(
            name="mechforge_3d",
            description=(
                "Agentic 3D engineering design environment. Agents build parametric "
                "brackets/mounts through tool calls, run Python 3D linear FEA, inspect "
                "stress/deflection feedback, and optimize manufacturable designs."
            ),
            version="0.1.0",
            author="Sanjayprasad H S",
        )

    @staticmethod
    def _as_float(value: object, default: float) -> float:
        """Coerce a metric value to float, using *default* when it is None."""
        return default if value is None else float(value)

    def _reward(self, result: dict) -> float:
        """Compute the shaped reward for one tool result.

        * Invalid tool call (``valid`` missing or falsy): -0.05.
        * Valid call without simulation results: +0.02.
        * Otherwise, the reward is the simulation ``score`` minus two
          penalties, clamped to [-1, 1]:
          - 0.2 per unit the safety factor falls short of
            ``TARGET_SAFETY_FACTOR``;
          - 0.01 per 10 units of mass above ``MASS_BUDGET_G``.

        Missing (or None) metrics default to score 0.0, safety factor 0.0
        and mass 999.0. A NaN metric or result returns -1.0.
        """
        if not result.get("valid", False):
            return -0.05
        sim = result.get("simulation")
        if not sim:
            return 0.02
        score = self._as_float(sim.get("score"), 0.0)
        safety_factor = self._as_float(sim.get("safety_factor"), 0.0)
        mass_g = self._as_float(sim.get("mass_g"), 999.0)
        if any(math.isnan(v) for v in (score, safety_factor, mass_g)):
            # Every comparison with NaN is false. A NaN metric would skip its
            # penalty, and a NaN score would clamp to +1.0. Treat it as the
            # worst outcome instead.
            return -1.0
        penalty = 0.0
        if safety_factor < self.TARGET_SAFETY_FACTOR:
            penalty += 0.2 * (self.TARGET_SAFETY_FACTOR - safety_factor)
        if mass_g > self.MASS_BUDGET_G:
            penalty += 0.01 * ((mass_g - self.MASS_BUDGET_G) / 10)
        raw = score - penalty
        # inf - inf (e.g. infinite score with an infinite penalty) is NaN.
        if math.isnan(raw):
            return -1.0
        return max(-1.0, min(1.0, raw))

    def _observation(
        self,
        reward: float,
        last_action: dict | None = None,
        last_result: dict | None = None,
        done: bool | None = None,
    ) -> MechForgeObservation:
        """Build an observation from the current state.

        Args:
            reward: Reward to report for this transition.
            last_action: Dump of the tool action just applied, if any.
            last_result: Tool result to report; an empty dict when None.
            done: Explicit done flag; defaults to the environment's own flag.

        Returns:
            An observation whose warnings reflect the most recent simulation
            metrics checked against the class thresholds.
        """
        metrics = self._state.last_metrics or {}
        warnings: list[str] = []
        if self._as_float(metrics.get("safety_factor"), 999.0) < self.TARGET_SAFETY_FACTOR:
            warnings.append(f"Safety factor below target {self.TARGET_SAFETY_FACTOR}.")
        if self._as_float(metrics.get("mass_g"), 0.0) > self.MASS_BUDGET_G:
            warnings.append("Design is heavy for the current benchmark.")
        return MechForgeObservation(
            task_brief=self._state.task_brief,
            available_tools=AVAILABLE_TOOLS,
            design=self._design.model_dump() if self._design else None,
            last_action=last_action,
            last_result=last_result or {},
            metrics=self._state.last_metrics,
            warnings=warnings,
            trace=self._state.trace,
            done=self._done if done is None else done,
            reward=reward,
        )