# Container_yard / inference.py
# Draken1606 — Initial Container Yard env submission (cc75d6e)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Inference script for Container Yard environment using OpenAI API.
This script evaluates a language model's ability to solve container yard placement
tasks using the hackathon-specified output format.
"""
import os
import sys
import json
from typing import Optional
# Optionally pull variables from a local .env file. python-dotenv is not a
# hard requirement: when it is missing, only the real process environment
# is consulted.
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass  # python-dotenv not installed, use system env vars
from openai import OpenAI
# Read environment variables with defaults.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
# HF_TOKEN is the credential handed to the OpenAI client below as api_key.
HF_TOKEN = os.getenv("HF_TOKEN")
# Treat an empty or whitespace-only value (e.g. `HF_TOKEN=` in a .env file)
# like an unset one, so the script fails here with a clear message instead of
# on the first API request with an authentication error.
if HF_TOKEN is None or not HF_TOKEN.strip():
    raise ValueError("HF_TOKEN environment variable is required")

# Initialize OpenAI client (module-level, shared by every task run).
client = OpenAI(
    base_url=API_BASE_URL,
    api_key=HF_TOKEN,
)
# Import environment
from server.Container_Yard_environment import ContainerYardEnvironment
from models import ContainerYardAction
def extract_stack_choice(response: str, num_stacks: int) -> Optional[int]:
    """
    Extract stack choice from LLM response.

    Explicit stack references are tried first: "stack 0", "stack=3",
    "stack: 2", "stack #1", "stacks 4" and "stack_index=2" are all
    recognised, and the first one naming a valid index
    (0 <= idx < num_stacks) wins. Failing that, the first bare number in the
    response is used if it is a valid index.

    Args:
        response: Raw text returned by the model. ``None`` or an empty string
            yields ``None``.
        num_stacks: Number of stacks in the yard; valid indices are
            ``0 .. num_stacks - 1``.

    Returns:
        The chosen stack index, or None if extraction fails.
    """
    import re  # local import keeps this helper self-contained

    if not response:
        return None
    text = response.lower()

    # Explicit "stack<suffix> [:=#] <digits>" references. The suffix is limited
    # to letters/underscore so digits glued to the word (e.g. "stack12") are
    # read as one number instead of being split by regex backtracking.
    for match in re.finditer(r"stack[a-z_]*\s*[:=#]?\s*(\d+)", text):
        stack_idx = int(match.group(1))
        if 0 <= stack_idx < num_stacks:
            return stack_idx

    # Fallback: consider only the first bare number, so numbers echoed later
    # in a longer reply are not mistaken for the choice.
    first_number = re.search(r"\d+", text)
    if first_number is not None:
        stack_idx = int(first_number.group())
        if 0 <= stack_idx < num_stacks:
            return stack_idx
    return None
def _build_prompt(obs) -> str:
    """Render the placement prompt sent to the model for one step.

    Reads ``current_container_id``, ``current_container_priority``,
    ``num_stacks``, ``max_stack_height``, ``stacks`` and ``rehandles_so_far``
    from the observation; ``stacks`` is serialised with ``json.dumps``.
    """
    last_stack = obs.num_stacks - 1
    return "\n".join([
        "You are managing a container yard.",
        "Current state:",
        f"- Container to place: ID={obs.current_container_id}, Priority={obs.current_container_priority}",
        f"- Available stacks: {obs.num_stacks} stacks (0-{last_stack})",
        f"- Max stack height: {obs.max_stack_height}",
        f"- Current stacks: {json.dumps(obs.stacks)}",
        f"- Rehandles so far: {obs.rehandles_so_far}",
        "Place the container in the stack that minimizes future rehandles.",
        f"Reply with ONLY the stack number (0-{last_stack}). No explanation needed.",
    ])


def run_task(task_name: str = "medium", max_steps: int = 100) -> dict:
    """
    Run a single task in the Container Yard environment.

    Prints the hackathon log lines on stdout: one ``[START]`` line, one
    ``[STEP]`` line per placement and one ``[END]`` line. Once ``[START]`` has
    been printed, ``[END]`` is printed and the environment is closed even if
    ``env.reset()`` or a later step raises.

    Args:
        task_name: "easy", "medium", or "hard"
        max_steps: Maximum number of placements attempted before giving up.

    Returns:
        dict with episode results: ``task``, ``success``, ``steps``,
        ``total_rewards``, ``rehandles``, ``efficiency`` (never below 0) and
        ``error`` (the last error message seen, or None).
    """
    env = ContainerYardEnvironment(task_name=task_name)
    print(f"[START] task={task_name} env=container-yard model={MODEL_NAME}", flush=True)

    obs = None  # stays None if env.reset() itself fails
    step_count = 0
    all_rewards = []
    success = False
    last_error = None
    efficiency_score = 0.0
    try:
        obs = env.reset()
        while not obs.done and step_count < max_steps:
            step_count += 1
            prompt = _build_prompt(obs)
            try:
                response = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.7,
                    max_tokens=10,
                )
                action_str = (response.choices[0].message.content or "").strip()
            except Exception as e:
                # Best effort: an API failure places on stack 0 rather than
                # aborting the episode; the message is kept for the result.
                action_str = "0"
                last_error = str(e)

            stack_idx = extract_stack_choice(action_str, obs.num_stacks)
            if stack_idx is None:
                stack_idx = 0  # unparseable reply -> default to stack 0

            try:
                action = ContainerYardAction(stack_index=stack_idx)
                obs = env.step(action)
                reward_value = float(obs.reward or 0.0)
                all_rewards.append(reward_value)
                error_msg = obs.action_error if obs.action_error else "null"
                print(
                    f"[STEP] step={step_count} action=place({stack_idx}) "
                    f"reward={reward_value:.2f} done={str(obs.done).lower()} error={error_msg}",
                    flush=True,
                )
                if obs.done:
                    success = True
                    break
            except Exception as e:
                last_error = str(e)
                print(
                    f"[STEP] step={step_count} action=place({stack_idx}) reward=0.00 done=true error={last_error}",
                    flush=True,
                )
                break

        if step_count > 0:
            # Efficiency is 1 - rehandles / total_containers; clamp at 0 so a
            # run with more rehandles than containers does not score negative.
            efficiency_score = max(
                0.0, 1.0 - (obs.rehandles_so_far / max(obs.total_containers, 1))
            )
            # Reaching done only counts if one step was taken per container.
            success = success and step_count == obs.total_containers
    except Exception as e:
        last_error = str(e)
        print(f"[ERROR] Task {task_name} failed: {e}", file=sys.stderr)
    finally:
        close_fn = getattr(env, "close", None)
        if callable(close_fn):
            try:
                close_fn()
            except Exception:
                pass  # closing is best effort; results are already collected

    rewards_str = ",".join(f"{r:.2f}" for r in all_rewards)
    print(
        f"[END] success={str(success).lower()} steps={step_count} rewards={rewards_str}",
        flush=True,
    )
    return {
        "task": task_name,
        "success": success,
        "steps": step_count,
        "total_rewards": sum(all_rewards),
        "rehandles": obs.rehandles_so_far if obs is not None else 0,
        "efficiency": efficiency_score,
        "error": last_error,
    }
def main():
    """Evaluate the model on the easy, medium and hard tasks in order.

    A task that raises is reported on stderr and recorded as a failed result,
    so the remaining tasks still run. A one-line-per-task summary is printed
    to stderr at the end.
    """
    outcomes = []
    for difficulty in ("easy", "medium", "hard"):
        try:
            outcome = run_task(difficulty)
        except Exception as exc:
            print(f"[ERROR] Task {difficulty} failed: {exc}", file=sys.stderr)
            outcome = {
                "task": difficulty,
                "success": False,
                "steps": 0,
                "total_rewards": 0.0,
                "error": str(exc),
            }
        outcomes.append(outcome)

    # Summary
    print("\n=== Summary ===", file=sys.stderr)
    for outcome in outcomes:
        eff = outcome.get("efficiency", 0.0)
        line = f"Task {outcome['task']}: success={outcome['success']}, efficiency={eff:.2f}"
        print(line, file=sys.stderr)
if __name__ == "__main__":
    # Run the full evaluation only when executed as a script, not on import.
    main()