Upload folder using huggingface_hub
Browse files- Dockerfile +69 -0
- README.md +6 -9
- __init__.py +16 -0
- client.py +51 -0
- inference.py +41 -0
- models.py +94 -0
- openenv.yaml +6 -0
- pyproject.toml +45 -0
- server/__init__.py +11 -0
- server/app.py +53 -0
- server/app_environment.py +96 -0
- server/requirements.txt +6 -0
- utils.py +294 -0
Dockerfile
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
+
FROM ${BASE_IMAGE} AS builder
|
| 3 |
+
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 7 |
+
RUN apt-get update && \
|
| 8 |
+
apt-get install -y --no-install-recommends git && \
|
| 9 |
+
rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 12 |
+
ARG BUILD_MODE=in-repo
|
| 13 |
+
ARG ENV_NAME=app
|
| 14 |
+
|
| 15 |
+
# Copy environment code (always at root of build context)
|
| 16 |
+
COPY . /app/env
|
| 17 |
+
|
| 18 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 19 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 20 |
+
WORKDIR /app/env
|
| 21 |
+
|
| 22 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 23 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 24 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 25 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 26 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 27 |
+
fi
|
| 28 |
+
|
| 29 |
+
# Install dependencies using uv sync
|
| 30 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 31 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 32 |
+
if [ -f uv.lock ]; then \
|
| 33 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 34 |
+
else \
|
| 35 |
+
uv sync --no-install-project --no-editable; \
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 39 |
+
if [ -f uv.lock ]; then \
|
| 40 |
+
uv sync --frozen --no-editable; \
|
| 41 |
+
else \
|
| 42 |
+
uv sync --no-editable; \
|
| 43 |
+
fi
|
| 44 |
+
|
| 45 |
+
# Final runtime stage
|
| 46 |
+
FROM ${BASE_IMAGE}
|
| 47 |
+
|
| 48 |
+
WORKDIR /app
|
| 49 |
+
|
| 50 |
+
# Copy the virtual environment from builder
|
| 51 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 52 |
+
|
| 53 |
+
# Copy the environment code
|
| 54 |
+
COPY --from=builder /app/env /app/env
|
| 55 |
+
|
| 56 |
+
# Set PATH to use the virtual environment
|
| 57 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 58 |
+
|
| 59 |
+
# Set PYTHONPATH so imports work correctly
|
| 60 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 61 |
+
|
| 62 |
+
# Health check
|
| 63 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 64 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 65 |
+
|
| 66 |
+
# Run the FastAPI server
|
| 67 |
+
# The module path is constructed to work with the /app/env structure
|
| 68 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 69 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,7 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
---
|
| 9 |
-
|
| 10 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Object Placement
|
| 3 |
+
emoji: 🔊
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
base_path: /web
|
| 7 |
+
---
|
|
|
|
|
|
|
|
|
__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""App Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import AppEnv
|
| 10 |
+
from .models import AppAction, AppObservation
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"AppAction",
|
| 14 |
+
"AppObservation",
|
| 15 |
+
"AppEnv",
|
| 16 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""App Environment Client."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from openenv.core import EnvClient
|
| 6 |
+
from openenv.core.client_types import StepResult
|
| 7 |
+
|
| 8 |
+
from .models import AppAction, AppObservation, AppState
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AppEnv(EnvClient[AppAction, AppObservation, AppState]):
    """Client for the App environment server.

    Converts :class:`AppAction` objects into request payloads and rebuilds
    :class:`AppObservation` / :class:`AppState` models from response payloads.
    """

    def _step_payload(self, action: AppAction) -> Dict:
        """Return the dict sent to the server for *action*."""
        return {
            "placement": action.placement,
            "isSegmentation": action.isSegmentation,
            "findObjects": action.findObjects,
        }

    def _parse_result(self, payload: Dict) -> StepResult[AppObservation]:
        """Build a ``StepResult`` from a step/reset response payload.

        Missing keys fall back to the same defaults the models declare, and a
        missing or null ``observation`` entry is treated as an empty dict.
        """
        obs_data = payload.get("observation") or {}

        observation = AppObservation(
            currentGrid=obs_data.get("currentGrid", []),
            positions=obs_data.get("positions", {}),
            objectsLeft=obs_data.get("objectsLeft", []),
            objectsFound=obs_data.get("objectsFound", []),
            reward=obs_data.get("reward", 0.0),
            isDone=obs_data.get("isDone", False),
            # The server fills these on every observation; forward them
            # instead of silently resetting them to empty lists.
            rewardFeedback=obs_data.get("rewardFeedback", []),
            rewardList=obs_data.get("rewardList", []),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            # Prefer the top-level flag, fall back to the observation's own.
            done=payload.get("done", obs_data.get("isDone", False)),
        )

    def _parse_state(self, payload: Dict) -> AppState:
        """Build an ``AppState`` from a state response payload."""
        return AppState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            currentGrid=payload.get("currentGrid", []),
            weightedGrid=payload.get("weightedGrid", []),
            reward=payload.get("reward", 0.0),
            isDone=payload.get("isDone", False),
            objectsLeft=payload.get("objectsLeft", []),
            objectsFound=payload.get("objectsFound", []),
            ObjectsPresent=payload.get("ObjectsPresent", {}),
            rewardFeedback=payload.get("rewardFeedback", []),
            rewardList=payload.get("rewardList", []),
        )
|
inference.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import base64
|
| 4 |
+
import textwrap
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from typing import List, Optional, Dict
|
| 7 |
+
|
| 8 |
+
from openai import OpenAI
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
API_BASE_URL = os.getenv("API_BASE_URL")
|
| 14 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 15 |
+
MODEL_NAME = os.getenv("MODEL_NAME")
|
| 16 |
+
|
| 17 |
+
SYSTEM_PROMPT = textwrap.dedent(
|
| 18 |
+
"""
|
| 19 |
+
You control a web browser through BrowserGym.
|
| 20 |
+
Reply with exactly one action string.
|
| 21 |
+
The action must be a valid BrowserGym command such as:
|
| 22 |
+
- noop()
|
| 23 |
+
- click('<BID>')
|
| 24 |
+
- type('selector', 'text to enter')
|
| 25 |
+
- fill('selector', 'text to enter')
|
| 26 |
+
- send_keys('Enter')
|
| 27 |
+
- scroll('down')
|
| 28 |
+
Use single quotes around string arguments.
|
| 29 |
+
When clicking, use the BrowserGym element IDs (BIDs) listed in the user message.
|
| 30 |
+
If you are unsure, respond with noop().
|
| 31 |
+
Do not include explanations or additional text.
|
| 32 |
+
"""
|
| 33 |
+
).strip()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main() -> None:
    """Create the OpenAI-compatible client from the module-level settings.

    NOTE(review): the client is constructed but not used afterwards; no
    inference loop is visible in this file.
    """
    client = OpenAI(
        api_key=API_KEY,
        base_url=API_BASE_URL,
    )
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 2 |
+
from pydantic import Field
|
| 3 |
+
from typing import List, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AppAction(Action):
    """Action for the App environment"""

    # Object name -> (x, y, z, stack flag) requested for placement.
    placement: Dict[str, Tuple[int, int, int, bool]] = Field(
        description="Placement of the object in a 3D grid",
        default_factory=dict,
    )

    # Selects the mode of the step: the server only scores ``findObjects``
    # when this is true and only accepts ``placement`` when it is false.
    isSegmentation: bool = Field(
        description="Whether the model is segmenting the objects",
        default=True,
    )

    # Object name -> (x, y, z, stack flag) the agent claims to have located.
    findObjects: Dict[str, Tuple[int, int, int, bool]] = Field(
        description="Dictionary of objects",
        default_factory=dict,
    )
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AppObservation(Observation):
    """Observation from the App environment"""

    # Occupancy counts indexed as grid[x][y][z].
    currentGrid: List[List[List[int]]] = Field(
        description="Current placement of the objects in a 3D grid",
        default_factory=list,
    )

    # Object name -> (x, y, z, stack flag).
    positions: Dict[str, Tuple[int, int, int, bool]] = Field(
        description="Dictionary of objects with their positions in the environment",
        default_factory=dict,
    )

    objectsLeft: List[str] = Field(
        description="List of unorganised objects left in the environment",
        default_factory=list,
    )

    objectsFound: List[str] = Field(
        description="List of objects found in the environment",
        default_factory=list,
    )

    reward: float = Field(
        description="Reward received after taking the action",
        default=0.0,
    )

    isDone: bool = Field(
        description="Whether the episode has ended",
        default=False,
    )

    # rewardFeedback[i] is the message that goes with rewardList[i].
    rewardFeedback: List[str] = Field(
        description="List of feedback strings describing the reward received after taking the action",
        default_factory=list,
    )

    rewardList: List[float] = Field(
        description="List of reward values received after taking the action",
        default_factory=list,
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class AppState(State):
    """State for the App environment"""

    # Occupancy counts indexed as grid[x][y][z]; mutated in place by placement.
    currentGrid: List[List[List[int]]] = Field(
        description="Initial state of the environment with unorganised objects",
        default_factory=list,
    )

    # Per-cell multipliers applied to the placement reward.
    weightedGrid: List[List[List[float]]] = Field(
        description="Weighted grid used when scoring placements",
        default_factory=list,
    )

    objectsLeft: List[str] = Field(
        description="List of unorganised objects left in the environment",
        default_factory=list,
    )

    objectsFound: List[str] = Field(
        description="List of objects found in the environment",
        default_factory=list,
    )

    reward: float = Field(
        description="Reward received after taking the action",
        default=0.0,
    )

    isDone: bool = Field(
        description="Whether the episode has ended",
        default=False,
    )

    # Object name -> (x, y, z, stack flag).
    ObjectsPresent: Dict[str, Tuple[int, int, int, bool]] = Field(
        description="Placed objects and their current positions in the environment",
        default_factory=dict,
    )

    # rewardFeedback[i] is the message that goes with rewardList[i].
    rewardFeedback: List[str] = Field(
        description="List of feedback strings describing the reward received after taking the action",
        default_factory=list,
    )

    rewardList: List[float] = Field(
        description="List of reward values received after taking the action",
        default_factory=list,
    )
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: app
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-app"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "App environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
    # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
    # install from github
    # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
    "openenv-core[core]>=0.2.1",
    # Environment-specific dependencies.
    # utils.py imports numpy, sklearn.metrics and matplotlib.pylab at module
    # load, and the Docker build installs only what is declared here
    # (`uv sync`), so they must be listed or the server cannot import.
    "numpy",
    "scikit-learn",
    "matplotlib",
    # Add any further dependencies needed for your environment here
    # Examples:
    # "torch>=2.0.0",
    # "gymnasium>=0.29.0",
    # "openspiel>=1.0.0",
    # "smolagents>=1.22.0,<2",
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
# Server entry point - enables running via: uv run --project . server
|
| 39 |
+
# or: python -m app.server.app
|
| 40 |
+
server = "app.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
include-package-data = true
|
| 44 |
+
packages = ["app", "app.server"]
|
| 45 |
+
package-dir = { "app" = ".", "app.server" = "server" }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""App environment server components."""
|
| 8 |
+
|
| 9 |
+
from .app_environment import AppEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["AppEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
from openenv.core.env_server.http_server import create_app
|
| 3 |
+
except Exception as e: # pragma: no cover
|
| 4 |
+
raise ImportError(
|
| 5 |
+
"openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
|
| 6 |
+
) from e
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from app.models import AppAction, AppObservation
|
| 10 |
+
from app.server.app_environment import AppEnvironment
|
| 11 |
+
except ModuleNotFoundError:
|
| 12 |
+
from models import AppAction, AppObservation
|
| 13 |
+
from server.app_environment import AppEnvironment
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ASGI application object; this is the ``server.app:app`` target that uvicorn
# is pointed at.
app = create_app(
    AppEnvironment,
    AppAction,
    AppObservation,
    env_name="app",
    # increase this number to allow more concurrent WebSocket sessions
    max_concurrent_envs=1,
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
    """
    Run the environment server with uvicorn.

    This function enables running the server without Docker:
        uv run --project . server
        python -m app.server.app
        python -m app.server.app --host 127.0.0.1 --port 8001

    Command-line flags are parsed only by the ``__main__`` guard at the
    bottom of this module. This function itself never reads ``sys.argv``, so
    an entry point that calls ``main()`` with no arguments always binds the
    defaults below.

    Args:
        host: Host address to bind to (default: "0.0.0.0")
        port: Port number to listen on (default: 8000)

    For production deployments, consider using uvicorn directly with
    multiple workers:
        uvicorn app.server.app:app --workers 4
    """
    # Imported lazily so importing this module for ``app`` alone does not
    # require uvicorn to be importable.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the App environment server.")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    main(host=args.host, port=args.port)
|
server/app_environment.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from uuid import uuid4
|
| 2 |
+
|
| 3 |
+
from openenv.core.env_server.interfaces import Environment
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from ..models import AppAction, AppObservation, AppState
|
| 7 |
+
except ImportError:
|
| 8 |
+
from models import AppAction, AppObservation, AppState
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from ..utils import *
|
| 12 |
+
except ImportError:
|
| 13 |
+
from utils import *
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AppEnvironment(Environment):
    """Object-placement environment built on a randomly initialised 3D grid.

    Each episode owns an :class:`AppState`. ``step`` scores the action's
    segmentation flag, placements and object findings, and returns an
    :class:`AppObservation` mirroring the updated state.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._state = self._new_state()
        # Number of times reset() has been called on this instance.
        self._reset_count = 0

    def _new_state(self) -> AppState:
        """Create the state for a fresh episode."""
        grid, placed = initGrid()

        # NOTE(review): initWeightedGrid() draws its own shape, independent of
        # ``grid``; confirm the two are meant to be indexed with the same
        # coordinates when scoring placements.
        return AppState(
            episode_id=str(uuid4()),
            step_count=0,
            currentGrid=grid,
            weightedGrid=initWeightedGrid(),
            objectsLeft=list(OBJECTS.keys()),
            objectsFound=[],
            reward=0.0,
            isDone=False,
            ObjectsPresent=placed,
            rewardFeedback=[],
            rewardList=[],
        )

    def _observation(self) -> AppObservation:
        """Build the observation that mirrors the current state."""
        state = self._state
        return AppObservation(
            currentGrid=state.currentGrid,
            positions=state.ObjectsPresent,
            objectsLeft=state.objectsLeft,
            objectsFound=state.objectsFound,
            reward=state.reward,
            isDone=state.isDone,
            rewardFeedback=state.rewardFeedback,
            rewardList=state.rewardList,
        )

    def reset(self) -> AppObservation:
        """Start a new episode and return its initial observation."""
        self._state = self._new_state()
        self._reset_count += 1
        return self._observation()

    def step(self, action: AppAction) -> AppObservation:
        """Apply *action*, update the cumulative reward and return the observation.

        Args:
            action: Segmentation flag plus optional placements / findings.

        Returns:
            The observation reflecting the state after the action.
        """
        if not isinstance(self._state, AppState):
            self._state = self._new_state()

        state = self._state
        state.step_count += 1

        # NOTE(review): the summary entries logged below record the running
        # total for this step (not the individual component), and their
        # wording is unconditional even when the helper returned a penalty.
        reward = 0.0
        if action.isSegmentation:
            reward += 10.0
            appendRewardFeedback(state, "Segmentation successful.", reward)

        if action.placement:
            reward += place(action.isSegmentation, action.placement, state)
            appendRewardFeedback(state, "Object placed successfully.", reward)

        if action.findObjects:
            reward += findobject(action.isSegmentation, action.findObjects, state)
            appendRewardFeedback(state, "Object found successfully.", reward)

        if not state.objectsLeft:
            state.isDone = True
            reward += 10.0
            appendRewardFeedback(
                state, "All objects found. Episode completed!", reward
            )

        # The step's reward is divided by 10**step_count before being added
        # to the cumulative reward.
        state.reward += reward / (10**state.step_count)

        return self._observation()

    @property
    def state(self) -> dict:
        """Return the current state dumped to a plain dict.

        NOTE(review): this returns ``model_dump()`` output rather than the
        ``AppState`` object; confirm that is what the base class expects.
        """
        return self._state.model_dump()
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core
|
| 2 |
+
fastapi
|
| 3 |
+
uvicorn
|
| 4 |
+
numpy
|
| 5 |
+
scikit-learn
|
| 6 |
+
matplotlib
|
utils.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from matplotlib.pylab import randint
|
| 2 |
+
from numpy import ones, zeros, random
|
| 3 |
+
from sklearn.metrics import mean_squared_error as MSE
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
random.seed(123)
|
| 7 |
+
|
| 8 |
+
# Catalogue of known objects. "dims" is the object's (x, y, z) extent in grid
# cells; "stack" is the flag the grid initialiser consults to decide whether
# the object may be placed over occupied cells.
OBJECTS = {
    "book": {"dims": [4, 4, 2], "stack": True},
    "penstand": {"dims": [2, 2, 4], "stack": True},
    "bottle": {"dims": [2, 2, 6], "stack": False},
    "pen": {"dims": [1, 1, 4], "stack": False},
    "pencil": {"dims": [1, 1, 6], "stack": False},
    "eraser": {"dims": [2, 1, 1], "stack": False},
    "powerbank": {"dims": [4, 2, 1], "stack": False},
    "mobile": {"dims": [4, 2, 1], "stack": False},
    "laptop": {"dims": [6, 4, 1], "stack": True},
    "monitor": {"dims": [6, 4, 2], "stack": False},
    "keyboard": {"dims": [6, 2, 1], "stack": False},
    "mouse": {"dims": [4, 2, 1], "stack": False},
    "headphones": {"dims": [4, 4, 2], "stack": False},
    "charger": {"dims": [2, 2, 1], "stack": False},
    "notebook": {"dims": [4, 4, 1], "stack": True},
    "folder": {"dims": [4, 4, 1], "stack": True},
    "backpack": {"dims": [6, 4, 2], "stack": False},
    "pouch": {"dims": [4, 4, 2], "stack": False},
}

# Object names in catalogue order. Derived from OBJECTS (dicts preserve
# insertion order) so the two can never drift apart.
OBJECT_NAMES = list(OBJECTS)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def appendRewardFeedback(state, feedback, reward):
    """Log one feedback message and its reward value on *state*.

    Both lists are appended to together so that ``rewardFeedback[i]``
    describes ``rewardList[i]``.
    """
    messages = state.rewardFeedback
    messages.append(feedback)
    values = state.rewardList
    values.append(reward)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def initDimentions(obj):
    """Return a nested list of ones shaped like ``obj["dims"]``.

    An empty list is returned when the "dims" key is missing or None.
    """
    shape = obj.get("dims")
    return [] if shape is None else ones(shape, dtype=int).tolist()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def initGrid():
    """Create a random grid and drop a random selection of objects onto it.

    The grid size is drawn per axis with ``randint(8, 12)``, then between 3
    and ``len(OBJECT_NAMES)`` distinct objects are chosen. Each object gets up
    to 100 random (x, y) positions at z = 0; an object that cannot be placed
    within that budget is skipped.

    Returns:
        ``(grid, placed)`` where ``grid`` is a nested list of occupancy counts
        and ``placed`` maps each placed object's name to
        ``(x, y, z, stack flag)``.
    """
    maxTries = 100

    sizeX, sizeY, sizeZ = randint(8, 12), randint(8, 12), randint(8, 12)
    grid = zeros((sizeX, sizeY, sizeZ), dtype=int).tolist()

    numObjs = randint(3, len(OBJECT_NAMES) + 1)
    chosenNames = random.choice(OBJECT_NAMES, size=numObjs, replace=False)

    placed = {}

    for name in chosenNames:
        obj = OBJECTS.get(name)
        dimX, dimY, dimZ = obj["dims"]

        # An object larger than the grid can never fit.
        if dimX > sizeX or dimY > sizeY or dimZ > sizeZ:
            continue

        stackable = obj["stack"]

        # Bounded retry loop. The original attempt counter was never
        # incremented, so a non-stackable object with no free footprint left
        # would spin forever.
        for _ in range(maxTries):
            posX = randint(0, sizeX - dimX + 1)
            posY = randint(0, sizeY - dimY + 1)
            posZ = 0

            # Only non-stackable objects are blocked by occupied cells.
            blocked = not stackable and any(
                grid[posX + i][posY + j][posZ + k] != 0
                for i in range(dimX)
                for j in range(dimY)
                for k in range(dimZ)
            )
            if blocked:
                continue

            for i in range(dimX):
                for j in range(dimY):
                    for k in range(dimZ):
                        column = grid[posX + i][posY + j]
                        if (
                            stackable
                            and column[posZ + k] > 0
                            and posZ + k + 1 < sizeZ
                        ):
                            # Occupied cell: count the stackable object one
                            # level up instead.
                            column[posZ + k + 1] += 1
                        else:
                            column[posZ + k] += 1

            placed[name] = (posX, posY, posZ, stackable)
            break

    return (grid, placed)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def initWeightedGrid(shape=None):
    """Return a 3D array of uniform random weights in [0, 1).

    A block of the grid is then scaled by 0.2: along x it is centred on the
    middle index and spans a quarter of the axis each way; along y it covers
    the first third; along z it covers everything.

    Args:
        shape: Optional ``(x, y, z)`` size. When omitted, each axis is drawn
            with ``randint(5, 11)`` as before. Pass the object grid's shape
            to obtain weights that can be indexed with the same coordinates.

    Returns:
        A numpy array of floats with the requested (or random) shape.
    """
    if shape is None:
        shape = (randint(5, 11), randint(5, 11), randint(5, 11))

    grid = random.uniform(0, 1, tuple(shape))

    x_mid = grid.shape[0] // 2
    x_span = grid.shape[0] // 4
    y_front = grid.shape[1] // 3

    grid[x_mid - x_span : x_mid + x_span, :y_front, :] *= 0.2

    return grid
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _weightAt(weight, x, y, z):
    """Return ``weight[x][y][z]``, or 0.0 when the cell is outside *weight*.

    The weighted grid is not guaranteed to cover every cell of the object
    grid, so an uncovered cell earns no bonus instead of raising IndexError.
    Coordinates are expected to be non-negative.
    """
    if x < len(weight) and y < len(weight[x]) and z < len(weight[x][y]):
        return weight[x][y][z]
    return 0.0


def _placeObject(obj_name, pos, obj, state, reward_per_obj_placed):
    """Write one object's cells into the grid; return ``(reward, succeeded)``.

    NOTE(review): cells written before a failure are not rolled back, and the
    bonus already earned for them is kept (same as the previous behaviour).
    """
    dims = state.currentGrid
    weight = state.weightedGrid
    sizeX = len(dims)
    sizeY = len(dims[0]) if sizeX else 0
    sizeZ = len(dims[0][0]) if sizeY else 0

    earned = 0.0

    def fail(message):
        appendRewardFeedback(state, message, -reward_per_obj_placed)
        return earned - reward_per_obj_placed, False

    objGrid = initDimentions(obj)
    for i in range(len(objGrid)):
        for j in range(len(objGrid[0])):
            for k in range(len(objGrid[0][0])):
                x, y, z = pos[0] + i, pos[1] + j, pos[2] + k

                # Negative coordinates would silently wrap around with
                # Python indexing, so they count as out of bounds too.
                if min(x, y, z) < 0 or x >= sizeX or y >= sizeY or z >= sizeZ:
                    return fail(f"Object '{obj_name}' placement is out of bounds.")

                occupied = dims[x][y][z] > 0
                if occupied and not pos[3]:
                    return fail(
                        f"Object '{obj_name}' placement overlaps with another object and stacking is not allowed."
                    )

                if occupied:
                    # Stack one level up. The bound is the grid depth, not
                    # the object's own depth.
                    if z + 1 >= sizeZ:
                        return fail(
                            f"Object '{obj_name}' placement failed. No space for stacking."
                        )
                    dims[x][y][z + 1] += 1
                    bonus = _weightAt(weight, x, y, z + 1) * reward_per_obj_placed
                    earned += bonus
                    appendRewardFeedback(
                        state,
                        f"Object '{obj_name}' placed with stacking. Bonus: {bonus:.2f}",
                        bonus,
                    )
                else:
                    dims[x][y][z] = 1
                    bonus = reward_per_obj_placed * _weightAt(weight, x, y, z)
                    earned += bonus
                    appendRewardFeedback(
                        state,
                        f"Object '{obj_name}' placed successfully. Bonus: {bonus:.2f}",
                        bonus,
                    )

    return earned, True


def place(segment, objects, state):
    """Place the requested objects into ``state.currentGrid`` and score them.

    Args:
        segment: The action's segmentation flag. Placement is rejected with a
            -60.0 penalty when it is truthy.
        objects: Mapping of object name to ``(x, y, z, stack flag)``.
        state: Environment state. Its grid, ``ObjectsPresent`` and reward
            feedback lists are updated in place.

    Returns:
        The total reward. Every written cell earns ``45 / len(objects)``
        multiplied by that cell's weight; an unknown object or a failed
        placement costs ``45 / len(objects)``.
    """
    # An empty request has nothing to score and would divide by zero below.
    if not objects:
        return 0.0

    # NOTE(review): the message says "without segmentation" although this
    # branch runs when the segmentation flag is set; confirm which is meant.
    if segment:
        appendRewardFeedback(
            state, "Placing objects without segmentation is not allowed.", -60.0
        )
        return -60.0

    reward = 0.0
    reward_per_obj_placed = 45.0 / len(objects)

    for obj_name, pos in objects.items():
        obj = OBJECTS.get(obj_name)
        if obj is None:
            appendRewardFeedback(
                state, f"Object '{obj_name}' is not recognized.", -reward_per_obj_placed
            )
            reward -= reward_per_obj_placed
            continue

        earned, succeeded = _placeObject(
            obj_name, pos, obj, state, reward_per_obj_placed
        )
        reward += earned
        if succeeded:
            state.ObjectsPresent[obj_name] = pos

    return reward
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def findobject(segment, objects, state):
    """Score the agent's claimed object positions against the real ones.

    Args:
        segment: The action's segmentation flag. Finding is rejected with a
            -60.0 penalty when it is falsy.
        objects: Mapping of object name to the claimed
            ``(x, y, z, stack flag)``.
        state: Environment state. Exactly matched objects are moved from
            ``objectsLeft`` to ``objectsFound``, and feedback is appended.

    Returns:
        The total reward, with ``unit = 45 / len(state.ObjectsPresent)``:
        an exact match earns ``unit``; an object that is not present costs
        ``unit``; a wrong position costs the mean squared error of the
        coordinates, plus or minus ``unit / 4`` depending on whether the
        stack flag matches.
    """
    if not segment:
        appendRewardFeedback(
            state, "Finding objects without segmentation is not allowed.", -60.0
        )
        return -60.0

    # With no objects present there is no per-object unit to score with, and
    # the division below would raise ZeroDivisionError.
    if not state.ObjectsPresent:
        appendRewardFeedback(
            state, "No objects are present in the environment.", 0.0
        )
        return 0.0

    reward = 0.0
    glMetric = 45.0 / len(state.ObjectsPresent)
    matched = []

    for obj_found, pos_found in objects.items():
        pos_real = state.ObjectsPresent.get(obj_found)
        if pos_real is None:
            reward -= glMetric
            appendRewardFeedback(
                state, f"Object '{obj_found}' not found in the environment.", -glMetric
            )
            continue

        if pos_found == pos_real:
            reward += glMetric
            appendRewardFeedback(
                state,
                f"Object '{obj_found}' found with correct position and stacking.",
                glMetric,
            )
            matched.append(obj_found)
            continue

        mse = MSE(pos_real[:3], pos_found[:3])
        reward -= mse
        appendRewardFeedback(
            state,
            f"Object '{obj_found}' found with incorrect position. MSE: {mse:.2f}",
            -mse,
        )

        if pos_found[3] != pos_real[3]:
            reward -= glMetric / 4.0
            appendRewardFeedback(
                state,
                f"Object '{obj_found}' found with incorrect stacking. Penalty: {glMetric / 4.0}",
                -glMetric / 4.0,
            )
        else:
            reward += glMetric / 4.0
            appendRewardFeedback(
                state,
                f"Object '{obj_found}' found with correct stacking. Bonus: {glMetric / 4.0}",
                glMetric / 4.0,
            )

    for obj in matched:
        # An object reported again after it was already found is no longer in
        # objectsLeft; skipping it avoids a ValueError from remove() and a
        # duplicate entry in objectsFound.
        if obj in state.objectsLeft:
            state.objectsLeft.remove(obj)
            state.objectsFound.append(obj)

    return reward
|