Clove25 committed
Commit d9175ae · verified · 1 Parent(s): 1f758fd

Upload 41 files

Files changed (41)
  1. Dockerfile +21 -0
  2. __pycache__/inference.cpython-313.pyc +0 -0
  3. __pycache__/inference.cpython-314.pyc +0 -0
  4. inference.py +187 -0
  5. openenv.yaml +41 -0
  6. tool_use_env/README.md +256 -0
  7. tool_use_env/__init__.py +16 -0
  8. tool_use_env/__pycache__/__init__.cpython-312.pyc +0 -0
  9. tool_use_env/__pycache__/__init__.cpython-313.pyc +0 -0
  10. tool_use_env/__pycache__/__init__.cpython-314.pyc +0 -0
  11. tool_use_env/__pycache__/client.cpython-312.pyc +0 -0
  12. tool_use_env/__pycache__/client.cpython-313.pyc +0 -0
  13. tool_use_env/__pycache__/client.cpython-314.pyc +0 -0
  14. tool_use_env/__pycache__/grader.cpython-312.pyc +0 -0
  15. tool_use_env/__pycache__/models.cpython-312.pyc +0 -0
  16. tool_use_env/__pycache__/models.cpython-313.pyc +0 -0
  17. tool_use_env/agents/__pycache__/baseline.cpython-313.pyc +0 -0
  18. tool_use_env/agents/baseline.py +267 -0
  19. tool_use_env/client.py +139 -0
  20. tool_use_env/grader.py +25 -0
  21. tool_use_env/models.py +47 -0
  22. tool_use_env/openenv_tool_use_env.egg-info/PKG-INFO +9 -0
  23. tool_use_env/openenv_tool_use_env.egg-info/SOURCES.txt +20 -0
  24. tool_use_env/openenv_tool_use_env.egg-info/dependency_links.txt +1 -0
  25. tool_use_env/openenv_tool_use_env.egg-info/entry_points.txt +2 -0
  26. tool_use_env/openenv_tool_use_env.egg-info/requires.txt +5 -0
  27. tool_use_env/openenv_tool_use_env.egg-info/top_level.txt +1 -0
  28. tool_use_env/pyproject.toml +45 -0
  29. tool_use_env/server/Dockerfile +80 -0
  30. tool_use_env/server/__init__.py +11 -0
  31. tool_use_env/server/__pycache__/__init__.cpython-312.pyc +0 -0
  32. tool_use_env/server/__pycache__/__init__.cpython-313.pyc +0 -0
  33. tool_use_env/server/__pycache__/app.cpython-312.pyc +0 -0
  34. tool_use_env/server/__pycache__/app.cpython-313.pyc +0 -0
  35. tool_use_env/server/__pycache__/tool_use_env_environment.cpython-312.pyc +0 -0
  36. tool_use_env/server/__pycache__/tool_use_env_environment.cpython-313.pyc +0 -0
  37. tool_use_env/server/app.py +23 -0
  38. tool_use_env/server/requirements.txt +7 -0
  39. tool_use_env/server/tool_use_env_environment.py +222 -0
  40. tool_use_env/tests/test_tools.py +23 -0
  41. tool_use_env/uv.lock +0 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # Copy entire project
+ COPY . .
+
+ # Move into package directory
+ WORKDIR /app/tool_use_env
+
+ # Install uv (needed for pyproject-based install)
+ RUN pip install --no-cache-dir uv
+
+ # Install project + dependencies
+ RUN uv pip install --system -e .
+
+ # Expose port
+ EXPOSE 8000
+
+ # Run server
+ CMD ["uvicorn", "tool_use_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
__pycache__/inference.cpython-313.pyc ADDED
Binary file (6.03 kB).
 
__pycache__/inference.cpython-314.pyc ADDED
Binary file (6.74 kB).
 
inference.py ADDED
@@ -0,0 +1,187 @@
+ import os
+ import random
+ from collections import defaultdict
+
+ from dotenv import load_dotenv
+ from openai import OpenAI
+
+ from tool_use_env.client import ToolUseEnv
+ from tool_use_env.models import ToolUseAction
+
+ # --- Load env ---
+ load_dotenv()
+
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ HF_MODEL = os.getenv("HF_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
+
+ # --- HF client ---
+ hf_client = OpenAI(
+     base_url="https://router.huggingface.co/v1",
+     api_key=HF_TOKEN
+ )
+
+ # --- Reproducibility ---
+ random.seed(42)
+
+ # --- Global flag ---
+ HF_AVAILABLE = True
+
+
+ # 🧠 Rule-based policy (correct logic)
+ def rule_based_policy(query: str):
+     q = query.lower()
+
+     if any(op in q for op in ["+", "-", "*", "/"]):
+         return "use_calculator"
+
+     if "capital" in q or "who is" in q or "ceo" in q:
+         return "use_search"
+
+     return "use_search"
+
+
+ # 🧠 Noisy fallback (simulate LLM mistakes)
+ def noisy_rule_policy(query: str):
+     correct = rule_based_policy(query)
+
+     if random.random() < 0.08:  # 8% noise: return a random action instead
+         return random.choice([
+             "use_calculator",
+             "use_search",
+             "answer_directly"
+         ])
+
+     return correct
+
+
+ # 🧠 LLM + fallback policy
+ def llm_policy(query: str):
+     global HF_AVAILABLE
+
+     prompt = f"""
+ You are an AI agent.
+
+ Choose EXACTLY one action:
+
+ - use_calculator
+ - use_search
+ - answer_directly
+
+ Query: {query}
+
+ ONLY output one action.
+ """
+
+     # --- Try HF only if still available ---
+     if HF_AVAILABLE:
+         try:
+             response = hf_client.chat.completions.create(
+                 model=HF_MODEL,
+                 messages=[{"role": "user", "content": prompt}],
+                 temperature=0
+             )
+
+             action = response.choices[0].message.content.strip()
+
+             # Inject 8% noise to simulate LLM mistakes
+             if random.random() < 0.08:
+                 action = random.choice([
+                     "use_calculator",
+                     "use_search",
+                     "answer_directly"
+                 ])
+
+             if action in ["use_calculator", "use_search", "answer_directly"]:
+                 print("[HF] Used")
+                 return action
+
+         except Exception as e:
+             print(f"[HF FAILED → switching to fallback permanently] {e}")
+             HF_AVAILABLE = False
+
+     # --- Fallback ---
+     return noisy_rule_policy(query)
+
+
+ # 🧪 Evaluation
+ def run_evaluation(num_episodes=50):
+     results = []
+     total_score = 0
+
+     difficulty_scores = defaultdict(list)
+
+     with ToolUseEnv(base_url="http://localhost:8000").sync() as env:
+         for _ in range(num_episodes):
+             result = env.reset()
+             obs = result.observation
+
+             query = obs.query
+
+             state = env.state()
+             difficulty = state.difficulty
+
+             action_type = llm_policy(query)
+             action = ToolUseAction(action_type=action_type)
+
+             result = env.step(action)
+             obs = result.observation
+
+             score = result.reward
+             total_score += score
+
+             difficulty_scores[difficulty].append(score)
+
+             results.append({
+                 "query": query,
+                 "difficulty": difficulty,
+                 "action": action_type,
+                 "score": score,
+                 "message": obs.message
+             })
+
+             print(f"Score: {score:.2f}")
+
+     avg_score = total_score / num_episodes
+
+     print("\n=== OVERALL PERFORMANCE ===")
+     print(f"Average Score: {avg_score:.2f}")
+
+     print("\n=== DIFFICULTY BREAKDOWN ===")
+     for level in ["easy", "medium", "hard"]:
+         if difficulty_scores[level]:
+             avg = sum(difficulty_scores[level]) / len(difficulty_scores[level])
+             print(f"{level.capitalize()}: {avg:.2f}")
+
+     print("\n=== SAMPLE CASES ===")
+     for r in results[:5]:
+         print(f"\nQuery: {r['query']}")
+         print(f"Action: {r['action']}")
+         print(f"Score: {r['score']:.2f}")
+         print(f"Details: {r['message']}")
+
+     return results
+
+
+ # 📊 Failure analysis
+ def analyze_failures(results):
+     total = len(results)
+     tool_failures = 0
+     wrong_decisions = 0
+
+     for r in results:
+         score = r["score"]
+         action = r["action"]
+
+         if score < 0.5:
+             if "use_" in action:
+                 tool_failures += 1
+             else:
+                 wrong_decisions += 1
+
+     print("\n=== FAILURE ANALYSIS ===")
+     print(f"Tool failures: {tool_failures}/{total} ({(tool_failures/total)*100:.1f}%)")
+     print(f"Wrong decisions: {wrong_decisions}/{total} ({(wrong_decisions/total)*100:.1f}%)")
+
+
+ # 🚀 Run
+ if __name__ == "__main__":
+     results = run_evaluation(50)
+     analyze_failures(results)
openenv.yaml ADDED
@@ -0,0 +1,41 @@
+ name: tool_use_env
+ description: Evaluate AI agents on reliable tool usage under uncertainty
+
+ version: 1.0
+
+ entrypoint: server.app:app
+
+ actions:
+   type: object
+   properties:
+     action_type:
+       type: string
+       enum:
+         - use_calculator
+         - use_search
+         - answer_directly
+
+ observations:
+   type: object
+   properties:
+     query:
+       type: string
+     tool_output:
+       type: string
+       nullable: true
+     message:
+       type: string
+
+ reward_range: [0.0, 1.0]
+
+ metadata:
+   difficulty_levels:
+     - easy
+     - medium
+     - hard
+
+   features:
+     - tool_selection
+     - partial_rewards
+     - decision_making
+     - efficiency_penalty
tool_use_env/README.md ADDED
@@ -0,0 +1,256 @@
+ ---
+ title: Tool Use Env Environment Server
+ emoji: 📀
+ colorFrom: purple
+ colorTo: gray
+ sdk: docker
+ pinned: false
+ app_port: 8000
+ base_path: /web
+ tags:
+   - openenv
+ ---
+
+ # Tool Use Env Environment
+
+ An environment for evaluating how reliably AI agents choose and use tools under uncertainty: each episode poses a query, and the agent must pick a calculator, a search tool, or a direct answer while tool outputs are deliberately noisy.
+
+ ## Quick Start
+
+ The simplest way to use the Tool Use Env environment is through the `ToolUseEnv` class:
+
+ ```python
+ from tool_use_env import ToolUseAction, ToolUseEnv
+
+ try:
+     # Create environment from Docker image
+     env = ToolUseEnv.from_docker_image("tool_use_env-env:latest")
+
+     # Reset to sample a new task
+     result = env.reset()
+     print(f"Query: {result.observation.query}")
+
+     # Choose an action and step once
+     result = env.step(ToolUseAction(action_type="use_calculator"))
+     print(f"Tool output: {result.observation.tool_output}")
+     print(f"Message: {result.observation.message}")
+     print(f"Reward: {result.reward}")
+
+ finally:
+     # Always clean up
+     env.close()
+ ```
+
+ That's it! The `ToolUseEnv.from_docker_image()` method handles:
+ - Starting the Docker container
+ - Waiting for the server to be ready
+ - Connecting to the environment
+ - Container cleanup when you call `close()`
+
+ ## Building the Docker Image
+
+ Before using the environment, you need to build the Docker image:
+
+ ```bash
+ # From project root
+ docker build -t tool_use_env-env:latest -f server/Dockerfile .
+ ```
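+
+ Once built, you can also run the image directly (a quick sketch, assuming the default port used throughout this README):
+
+ ```bash
+ docker run -p 8000:8000 tool_use_env-env:latest
+ ```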
+
+ ## Deploying to Hugging Face Spaces
+
+ You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
+
+ ```bash
+ # From the environment directory (where openenv.yaml is located)
+ openenv push
+
+ # Or specify options
+ openenv push --namespace my-org --private
+ ```
+
+ The `openenv push` command will:
+ 1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
+ 2. Prepare a custom build for a Hugging Face Docker Space (enables the web interface)
+ 3. Upload to Hugging Face (ensuring you're logged in)
+
+ ### Prerequisites
+
+ - Authenticate with Hugging Face: the command will prompt for login if you are not already authenticated
+
+ ### Options
+
+ - `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to the current directory)
+ - `--repo-id`, `-r`: Repository ID in the format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
+ - `--base-image`, `-b`: Base Docker image to use (overrides the Dockerfile FROM)
+ - `--private`: Deploy the space as private (default: public)
+
+ ### Examples
+
+ ```bash
+ # Push to your personal namespace (defaults to username/env-name from openenv.yaml)
+ openenv push
+
+ # Push to a specific repository
+ openenv push --repo-id my-org/my-env
+
+ # Push with a custom base image
+ openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
+
+ # Push as a private space
+ openenv push --private
+
+ # Combine options
+ openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
+ ```
+
+ After deployment, your space will be available at:
+ `https://huggingface.co/spaces/<repo-id>`
+
+ The deployed space includes:
+ - **Web Interface** at `/web` - Interactive UI for exploring the environment
+ - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
+ - **Health Check** at `/health` - Container health monitoring
+ - **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
+
+ ## Environment Details
+
+ ### Action
+ **ToolUseAction**: Contains a single field
+ - `action_type` (str) - One of `use_calculator`, `use_search`, or `answer_directly`
+
+ ### Observation
+ **ToolUseObservation**: Contains the task and the outcome of the chosen action
+ - `query` (str) - The query the agent must answer
+ - `tool_output` (str, nullable) - Raw output of the invoked tool, if any
+ - `message` (str) - Step summary (action, output, correctness, reward, grade)
+ - `done` (bool) - True after the step; each episode is a single decision
+ - `reward` (float) - Score in [0.0, 1.0], detailed below
+
134
+ ### Reward
135
+ The reward is calculated as: `message_length × 0.1`
136
+ - "Hi" → reward: 0.2
137
+ - "Hello, World!" → reward: 1.3
138
+ - Empty message → reward: 0.0
139
+
140
+ ## Advanced Usage
141
+
142
+ ### Connecting to an Existing Server
143
+
144
+ If you already have a Tool Use Env environment server running, you can connect directly:
145
+
146
+ ```python
147
+ from tool_use_env import ToolUseEnv
148
+
149
+ # Connect to existing server
150
+ tool_use_envenv = ToolUseEnv(base_url="<ENV_HTTP_URL_HERE>")
151
+
152
+ # Use as normal
153
+ result = tool_use_envenv.reset()
154
+ result = tool_use_envenv.step(ToolUseAction(message="Hello!"))
155
+ ```
156
+
157
+ Note: When connecting to an existing server, `tool_use_envenv.close()` will NOT stop the server.
158
+
159
+ ### Using the Context Manager
160
+
161
+ The client supports context manager usage for automatic connection management:
162
+
163
+ ```python
164
+ from tool_use_env import ToolUseAction, ToolUseEnv
165
+
166
+ # Connect with context manager (auto-connects and closes)
167
+ with ToolUseEnv(base_url="http://localhost:8000") as env:
168
+ result = env.reset()
169
+ print(f"Reset: {result.observation.echoed_message}")
170
+ # Multiple steps with low latency
171
+ for msg in ["Hello", "World", "!"]:
172
+ result = env.step(ToolUseAction(message=msg))
173
+ print(f"Echoed: {result.observation.echoed_message}")
174
+ ```
175
+
176
+ The client uses WebSocket connections for:
177
+ - **Lower latency**: No HTTP connection overhead per request
178
+ - **Persistent session**: Server maintains your environment state
179
+ - **Efficient for episodes**: Better for many sequential steps
180
+
181
+ ### Concurrent WebSocket Sessions
182
+
183
+ The server supports multiple concurrent WebSocket connections. To enable this,
184
+ modify `server/app.py` to use factory mode:
185
+
186
+ ```python
187
+ # In server/app.py - use factory mode for concurrent sessions
188
+ app = create_app(
189
+ ToolUseEnvironment, # Pass class, not instance
190
+ ToolUseAction,
191
+ ToolUseObservation,
192
+ max_concurrent_envs=4, # Allow 4 concurrent sessions
193
+ )
194
+ ```
195
+
196
+ Then multiple clients can connect simultaneously:
197
+
198
+ ```python
199
+ from tool_use_env import ToolUseAction, ToolUseEnv
200
+ from concurrent.futures import ThreadPoolExecutor
201
+
202
+ def run_episode(client_id: int):
203
+ with ToolUseEnv(base_url="http://localhost:8000") as env:
204
+ result = env.reset()
205
+ for i in range(10):
206
+ result = env.step(ToolUseAction(message=f"Client {client_id}, step {i}"))
207
+ return client_id, result.observation.message_length
208
+
209
+ # Run 4 episodes concurrently
210
+ with ThreadPoolExecutor(max_workers=4) as executor:
211
+ results = list(executor.map(run_episode, range(4)))
212
+ ```
213
+
214
+ ## Development & Testing
215
+
216
+ ### Direct Environment Testing
217
+
218
+ Test the environment logic directly without starting the HTTP server:
219
+
220
+ ```bash
221
+ # From the server directory
222
+ python3 server/tool_use_env_environment.py
223
+ ```
224
+
225
+ This verifies that:
226
+ - Environment resets correctly
227
+ - Step executes actions properly
228
+ - State tracking works
229
+ - Rewards are calculated correctly
230
+
231
+ ### Running Locally
232
+
233
+ Run the server locally for development:
234
+
235
+ ```bash
236
+ uvicorn server.app:app --reload
237
+ ```
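+
+ With the server running, a quick sanity check against the health endpoint mentioned above:
+
+ ```bash
+ curl http://localhost:8000/health
+ ```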
+
+ ## Project Structure
+
+ ```
+ tool_use_env/
+ ├── .dockerignore                    # Docker build exclusions
+ ├── __init__.py                      # Module exports
+ ├── README.md                        # This file
+ ├── openenv.yaml                     # OpenEnv manifest
+ ├── pyproject.toml                   # Project metadata and dependencies
+ ├── uv.lock                          # Locked dependencies (generated)
+ ├── client.py                        # ToolUseEnv client
+ ├── models.py                        # Action, Observation, and State models
+ ├── grader.py                        # Grade computation (for reporting)
+ ├── agents/
+ │   └── baseline.py                  # Baseline LLM policy and evaluation loop
+ ├── tests/
+ │   └── test_tools.py                # Unit tests for tools and stepping
+ └── server/
+     ├── __init__.py                  # Server module exports
+     ├── tool_use_env_environment.py  # Core environment logic
+     ├── app.py                       # FastAPI application (HTTP + WebSocket endpoints)
+     └── Dockerfile                   # Container image definition
+ ```
tool_use_env/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Tool Use Env Environment."""
+
+ from .client import ToolUseEnv
+ from .models import ToolUseAction, ToolUseObservation
+
+ __all__ = [
+     "ToolUseAction",
+     "ToolUseObservation",
+     "ToolUseEnv",
+ ]
tool_use_env/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (364 Bytes).
 
tool_use_env/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (364 Bytes).
 
tool_use_env/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (361 Bytes).
 
tool_use_env/__pycache__/client.cpython-312.pyc ADDED
Binary file (2.21 kB).
 
tool_use_env/__pycache__/client.cpython-313.pyc ADDED
Binary file (2.26 kB).
 
tool_use_env/__pycache__/client.cpython-314.pyc ADDED
Binary file (2.78 kB).
 
tool_use_env/__pycache__/grader.cpython-312.pyc ADDED
Binary file (716 Bytes).
 
tool_use_env/__pycache__/models.cpython-312.pyc ADDED
Binary file (1.29 kB).
 
tool_use_env/__pycache__/models.cpython-313.pyc ADDED
Binary file (1.41 kB).
 
tool_use_env/agents/__pycache__/baseline.cpython-313.pyc ADDED
Binary file (4.72 kB).
 
tool_use_env/agents/baseline.py ADDED
@@ -0,0 +1,267 @@
+ # from tool_use_env.client import ToolUseEnv
+ # from tool_use_env.models import ToolUseAction
+ # import random
+
+ # def rule_based_policy(query: str):
+ #     query = query.lower()
+
+ #     # --- Introduce slight imperfection ---
+ #     if random.random() < 0.1:
+ #         return "answer_directly"
+
+ #     if "what is" in query and any(op in query for op in ["+", "-", "*", "/"]):
+ #         return "use_calculator"
+
+ #     if "capital" in query or "who is" in query:
+ #         return "use_search"
+
+ #     return "answer_directly"
+
+
+ # def run_single_episode(env):
+ #     result = env.reset()
+ #     obs = result.observation
+
+ #     query = obs.query
+ #     action_type = rule_based_policy(query)
+
+ #     action = ToolUseAction(action_type=action_type)
+
+ #     result = env.step(action)
+ #     obs = result.observation
+
+ #     return {
+ #         "query": query,
+ #         "action": action_type,
+ #         "reward": result.reward,
+ #         "message": obs.message
+ #     }
+
+ # def run_evaluation(num_episodes=20):
+ #     results = []
+
+ #     difficulty_scores = {
+ #         "easy": [],
+ #         "medium": [],
+ #         "hard": []
+ #     }
+
+ #     total_score = 0
+
+ #     with ToolUseEnv(base_url="http://localhost:8000").sync() as env:
+ #         for _ in range(num_episodes):
+ #             result = env.reset()
+ #             obs = result.observation
+ #             query = obs.query
+ #             state = env.state()
+ #             difficulty = state.difficulty
+
+ #             action_type = rule_based_policy(query)
+ #             action = ToolUseAction(action_type=action_type)
+
+ #             result = env.step(action)
+
+ #             score = result.reward
+ #             total_score += score
+
+ #             difficulty_scores[difficulty].append(score)
+
+ #             results.append({
+ #                 "query": query,
+ #                 "difficulty": difficulty,
+ #                 "action": action_type,
+ #                 "score": score,
+ #                 "message": result.observation.message
+ #             })
+
+ #     avg_score = total_score / num_episodes
+
+ #     print("\n=== OVERALL PERFORMANCE ===")
+ #     print(f"Average Score: {avg_score:.2f}")
+
+ #     print("\n=== DIFFICULTY BREAKDOWN ===")
+ #     for level in difficulty_scores:
+ #         if difficulty_scores[level]:
+ #             avg = sum(difficulty_scores[level]) / len(difficulty_scores[level])
+ #             print(f"{level.capitalize()}: {avg:.2f}")
+
+ #     print("\n=== SAMPLE CASES ===")
+ #     for r in results[:5]:
+ #         print(f"\nQuery: {r['query']}")
+ #         print(f"Action: {r['action']}")
+ #         print(f"Score: {r['score']:.2f}")
+ #         print(f"Details: {r['message']}")
+
+ #     return results
+
+ # def analyze_failures(results):
+ #     wrong_decisions = 0
+ #     tool_failures = 0
+ #     total = len(results)
+
+ #     for r in results:
+ #         msg = r["message"]
+
+ #         if "Correct: False" in msg:
+ #             if "use_" in msg:
+ #                 tool_failures += 1
+ #             else:
+ #                 wrong_decisions += 1
+
+ #     print("\n=== FAILURE ANALYSIS ===")
+ #     print(f"Tool failures: {tool_failures}/{total} ({(tool_failures/total)*100:.1f}%)")
+ #     print(f"Wrong decisions: {wrong_decisions}/{total} ({(wrong_decisions/total)*100:.1f}%)")
+
+
+ # if __name__ == "__main__":
+ #     results = run_evaluation(50)
+ #     analyze_failures(results)
+
+ import os
+ import random
+ from collections import defaultdict
+
+ from dotenv import load_dotenv
+ from openai import OpenAI
+
+ from tool_use_env.client import ToolUseEnv
+ from tool_use_env.models import ToolUseAction
+
+ # --- Load environment variables ---
+ load_dotenv()
+
+ # --- Initialize OpenAI client ---
+ client = OpenAI()
+
+ # --- Reproducibility ---
+ random.seed(42)
+
+
+ # 🧠 LLM Policy (CORE)
+ def llm_policy(query: str):
+     prompt = f"""
+ You are an AI agent choosing the best tool.
+
+ Available actions:
+ - use_calculator (for math problems)
+ - use_search (for factual questions)
+ - answer_directly (if neither tool is needed)
+
+ Query: {query}
+
+ Respond with ONLY one of:
+ use_calculator
+ use_search
+ answer_directly
+ """
+
+     try:
+         response = client.chat.completions.create(
+             model="gpt-4o-mini",
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0
+         )
+
+         action = response.choices[0].message.content.strip()
+
+         # --- Safety check ---
+         if action not in ["use_calculator", "use_search", "answer_directly"]:
+             return "answer_directly"
+
+         return action
+
+     except Exception as e:
+         print(f"[ERROR] LLM call failed: {e}")
+         return "answer_directly"
+
+
+ # 🧪 Evaluation Loop
+ def run_evaluation(num_episodes=50):
+     results = []
+     total_score = 0
+
+     difficulty_scores = defaultdict(list)
+
+     with ToolUseEnv(base_url="http://localhost:8000").sync() as env:
+         for _ in range(num_episodes):
+             # --- Reset ---
+             result = env.reset()
+             obs = result.observation
+
+             query = obs.query
+
+             # --- Get difficulty ---
+             state = env.state()
+             difficulty = state.difficulty
+
+             # --- LLM decides action ---
+             action_type = llm_policy(query)
+             action = ToolUseAction(action_type=action_type)
+
+             # --- Step ---
+             result = env.step(action)
+             obs = result.observation
+
+             score = result.reward
+             total_score += score
+
+             difficulty_scores[difficulty].append(score)
+
+             results.append({
+                 "query": query,
+                 "difficulty": difficulty,
+                 "action": action_type,
+                 "score": score,
+                 "message": obs.message
+             })
+
+             print(f"Score: {score:.2f}")
+
+     # --- Overall ---
+     avg_score = total_score / num_episodes
+
+     print("\n=== OVERALL PERFORMANCE ===")
+     print(f"Average Score: {avg_score:.2f}")
+
+     # --- Breakdown ---
+     print("\n=== DIFFICULTY BREAKDOWN ===")
+     for level in ["easy", "medium", "hard"]:
+         if difficulty_scores[level]:
+             avg = sum(difficulty_scores[level]) / len(difficulty_scores[level])
+             print(f"{level.capitalize()}: {avg:.2f}")
+
+     # --- Sample Cases ---
+     print("\n=== SAMPLE CASES ===")
+     for r in results[:5]:
+         print(f"\nQuery: {r['query']}")
+         print(f"Action: {r['action']}")
+         print(f"Score: {r['score']:.2f}")
+         print(f"Details: {r['message']}")
+
+     return results
+
+
+ # 📊 Failure Analysis
+ def analyze_failures(results):
+     total = len(results)
+     tool_failures = 0
+     wrong_decisions = 0
+
+     for r in results:
+         msg = r["message"]
+
+         if "Correct: False" in msg:
+             if "use_" in msg:
+                 tool_failures += 1
+             else:
+                 wrong_decisions += 1
+
+     print("\n=== FAILURE ANALYSIS ===")
+     print(f"Tool failures: {tool_failures}/{total} ({(tool_failures/total)*100:.1f}%)")
+     print(f"Wrong decisions: {wrong_decisions}/{total} ({(wrong_decisions/total)*100:.1f}%)")
+
+
+ # 🚀 Main
+ if __name__ == "__main__":
+     results = run_evaluation(50)
+     analyze_failures(results)
tool_use_env/client.py ADDED
@@ -0,0 +1,139 @@
+ # # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # # All rights reserved.
+ # #
+ # # This source code is licensed under the BSD-style license found in the
+ # # LICENSE file in the root directory of this source tree.
+
+ # """Tool Use Env Environment Client."""
+
+ # from typing import Dict
+
+ # from openenv.core import EnvClient
+ # from openenv.core.client_types import StepResult
+ # from openenv.core.env_server.types import State
+
+ # from .models import ToolUseAction, ToolUseObservation
+
+
+ # class ToolUseEnv(
+ #     EnvClient[ToolUseAction, ToolUseObservation, State]
+ # ):
+ #     """
+ #     Client for the Tool Use Env Environment.
+
+ #     This client maintains a persistent WebSocket connection to the environment server,
+ #     enabling efficient multi-step interactions with lower latency.
+ #     Each client instance has its own dedicated environment session on the server.
+
+ #     Example:
+ #         >>> # Connect to a running server
+ #         >>> with ToolUseEnv(base_url="http://localhost:8000") as client:
+ #         ...     result = client.reset()
+ #         ...     print(result.observation.echoed_message)
+ #         ...
+ #         ...     result = client.step(ToolUseAction(message="Hello!"))
+ #         ...     print(result.observation.echoed_message)
+
+ #     Example with Docker:
+ #         >>> # Automatically start container and connect
+ #         >>> client = ToolUseEnv.from_docker_image("tool_use_env-env:latest")
+ #         >>> try:
+ #         ...     result = client.reset()
+ #         ...     result = client.step(ToolUseAction(message="Test"))
+ #         ... finally:
+ #         ...     client.close()
+ #     """
+
+ #     def _step_payload(self, action: ToolUseAction) -> Dict:
+ #         """
+ #         Convert ToolUseAction to JSON payload for step message.
+
+ #         Args:
+ #             action: ToolUseAction instance
+
+ #         Returns:
+ #             Dictionary representation suitable for JSON encoding
+ #         """
+ #         return {
+ #             "message": action.message,
+ #         }
+
+ #     def _parse_result(self, payload: Dict) -> StepResult[ToolUseObservation]:
+ #         """
+ #         Parse server response into StepResult[ToolUseObservation].
+
+ #         Args:
+ #             payload: JSON response data from server
+
+ #         Returns:
+ #             StepResult with ToolUseObservation
+ #         """
+ #         obs_data = payload.get("observation", {})
+ #         observation = ToolUseObservation(
+ #             echoed_message=obs_data.get("echoed_message", ""),
+ #             message_length=obs_data.get("message_length", 0),
+ #             done=payload.get("done", False),
+ #             reward=payload.get("reward"),
+ #             metadata=obs_data.get("metadata", {}),
+ #         )
+
+ #         return StepResult(
+ #             observation=observation,
+ #             reward=payload.get("reward"),
+ #             done=payload.get("done", False),
+ #         )
+
+ #     def _parse_state(self, payload: Dict) -> State:
+ #         """
+ #         Parse server response into State object.
+
+ #         Args:
+ #             payload: JSON response from state request
+
+ #         Returns:
+ #             State object with episode_id and step_count
+ #         """
+ #         return State(
+ #             episode_id=payload.get("episode_id"),
+ #             step_count=payload.get("step_count", 0),
+ #         )
+
+ from openenv.core.env_client import EnvClient
+ from openenv.core.client_types import StepResult
+
+ from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState
+
+
+ class ToolUseEnv(EnvClient[ToolUseAction, ToolUseObservation, ToolUseState]):
+     """Client for the tool-use environment: sends the chosen action and parses results."""
+
+     def _step_payload(self, action: ToolUseAction) -> dict:
+         return {
+             "action_type": action.action_type
+         }
+
+     def _parse_result(self, payload: dict) -> StepResult:
+         obs_data = payload.get("observation", {})
+
+         observation = ToolUseObservation(
+             done=payload.get("done", False),
+             reward=payload.get("reward"),
+             query=obs_data.get("query", ""),
+             tool_output=obs_data.get("tool_output"),
+             message=obs_data.get("message", "")
+         )
+
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: dict) -> ToolUseState:
+         return ToolUseState(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+             current_query=payload.get("current_query", ""),
+             correct_action=payload.get("correct_action", ""),
+             correct_answer=payload.get("correct_answer", ""),
+             difficulty=payload.get("difficulty", "")
+         )
tool_use_env/grader.py ADDED
@@ -0,0 +1,25 @@
+ def compute_grade(action_taken, correct_action, output, correct_answer):
+     """
+     Returns a score between 0.0 and 1.0.
+     """
+
+     # 1. Action correctness
+     action_correct = 1.0 if action_taken == correct_action else 0.0
+
+     # 2. Answer correctness
+     answer_correct = 1.0 if output == correct_answer else 0.0
+
+     # 3. Efficiency (simple version)
+     if action_taken in ["use_calculator", "use_search"]:
+         efficiency = 0.5  # using a tool has a cost
+     else:
+         efficiency = 1.0  # a direct answer is efficient
+
+     # Final score: weighted sum of the three components
+     score = (
+         0.4 * action_correct +
+         0.5 * answer_correct +
+         0.1 * efficiency
+     )
+
+     return round(score, 2)
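For reference, how these weights combine on concrete inputs (a quick sketch using the function above; the values are illustrative):

```python
from tool_use_env.grader import compute_grade

# Correct tool and correct answer: 0.4*1.0 + 0.5*1.0 + 0.1*0.5 = 0.95
print(compute_grade("use_calculator", "use_calculator", "12", "12"))        # 0.95
# Wrong choice and wrong answer, but no tool cost: 0.4*0.0 + 0.5*0.0 + 0.1*1.0 = 0.1
print(compute_grade("answer_directly", "use_calculator", "unknown", "12"))  # 0.1
```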
tool_use_env/models.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Data models for the Tool Use Env Environment.
+
+ The tool_use_env environment asks an agent to answer a query by choosing
+ a calculator, a search tool, or a direct answer.
+ """
+
+ # from openenv.core.env_server.types import Action, Observation
+ # from pydantic import Field
+
+
+ # class ToolUseAction(Action):
+ #     """Action for the Tool Use Env environment - just a message to echo."""
+
+ #     message: str = Field(..., description="Message to echo back")
+
+
+ # class ToolUseObservation(Observation):
+ #     """Observation from the Tool Use Env environment - the echoed message."""
+
+ #     echoed_message: str = Field(default="", description="The echoed message")
+ #     message_length: int = Field(default=0, description="Length of the echoed message")
+
+ from openenv.core.env_server import Action, Observation, State
+ from typing import Optional
+
+
+ class ToolUseAction(Action):
+     action_type: str  # "use_calculator", "use_search", or "answer_directly"
+
+
+ class ToolUseObservation(Observation):
+     query: str
+     tool_output: Optional[str] = None
+     message: str
+
+
+ class ToolUseState(State):
+     current_query: str = ""
+     correct_action: str = ""
+     correct_answer: str = ""
+     difficulty: str = ""
tool_use_env/openenv_tool_use_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,9 @@
+ Metadata-Version: 2.4
+ Name: openenv-tool_use_env
+ Version: 0.1.0
+ Summary: Tool Use Env environment for OpenEnv
+ Requires-Python: >=3.10
+ Requires-Dist: openenv-core[core]>=0.2.1
+ Provides-Extra: dev
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
tool_use_env/openenv_tool_use_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,20 @@
+ README.md
+ __init__.py
+ client.py
+ grader.py
+ models.py
+ pyproject.toml
+ ./__init__.py
+ ./client.py
+ ./grader.py
+ ./models.py
+ openenv_tool_use_env.egg-info/PKG-INFO
+ openenv_tool_use_env.egg-info/SOURCES.txt
+ openenv_tool_use_env.egg-info/dependency_links.txt
+ openenv_tool_use_env.egg-info/entry_points.txt
+ openenv_tool_use_env.egg-info/requires.txt
+ openenv_tool_use_env.egg-info/top_level.txt
+ server/__init__.py
+ server/app.py
+ server/tool_use_env_environment.py
+ tests/test_tools.py
tool_use_env/openenv_tool_use_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
tool_use_env/openenv_tool_use_env.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ server = tool_use_env.server.app:main
tool_use_env/openenv_tool_use_env.egg-info/requires.txt ADDED
@@ -0,0 +1,5 @@
+ openenv-core[core]>=0.2.1
+
+ [dev]
+ pytest>=8.0.0
+ pytest-cov>=4.0.0
tool_use_env/openenv_tool_use_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ tool_use_env
tool_use_env/pyproject.toml ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-tool_use_env"
+ version = "0.1.0"
+ description = "Tool Use Env environment for OpenEnv"
+ requires-python = ">=3.10"
+ dependencies = [
+     # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
+     # To install from GitHub instead:
+     # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
+     "openenv-core[core]>=0.2.1",
+     # Environment-specific dependencies
+     # Add all dependencies needed for your environment here
+     # Examples:
+     # "numpy>=1.19.0",
+     # "torch>=2.0.0",
+     # "gymnasium>=0.29.0",
+     # "openspiel>=1.0.0",
+     # "smolagents>=1.22.0,<2",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+
+ [project.scripts]
+ # Server entry point - enables running via: uv run --project . server
+ # or: python -m tool_use_env.server.app
+ server = "tool_use_env.server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = ["tool_use_env", "tool_use_env.server"]
+ package-dir = { "tool_use_env" = ".", "tool_use_env.server" = "server" }
tool_use_env/server/Dockerfile ADDED
@@ -0,0 +1,80 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Multi-stage build using openenv-base
+ # This Dockerfile is flexible and works for both:
+ #   - In-repo environments (with local OpenEnv sources)
+ #   - Standalone environments (with openenv from PyPI/Git)
+ # The build script (openenv build) handles context detection and sets appropriate build args.
+
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ WORKDIR /app
+
+ # Ensure git is available (required for installing dependencies from VCS)
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends git && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Build argument to control whether we're building standalone or in-repo
+ ARG BUILD_MODE=in-repo
+ ARG ENV_NAME=tool_use_env
+
+ # Copy environment code (always at root of build context)
+ COPY . /app/env
+
+ # For in-repo builds, openenv is already vendored in the build context
+ # For standalone builds, openenv will be installed via pyproject.toml
+ WORKDIR /app/env
+
+ # Ensure uv is available (for local builds where the base image lacks it)
+ RUN if ! command -v uv >/dev/null 2>&1; then \
+         curl -LsSf https://astral.sh/uv/install.sh | sh && \
+         mv /root/.local/bin/uv /usr/local/bin/uv && \
+         mv /root/.local/bin/uvx /usr/local/bin/uvx; \
+     fi
+
+ # Install dependencies using uv sync
+ # If uv.lock exists, use it; otherwise resolve on the fly
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-install-project --no-editable; \
+     else \
+         uv sync --no-install-project --no-editable; \
+     fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-editable; \
+     else \
+         uv sync --no-editable; \
+     fi
+
+ # Final runtime stage
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app
+
+ # Copy the virtual environment from builder
+ COPY --from=builder /app/env/.venv /app/.venv
+
+ # Copy the environment code
+ COPY --from=builder /app/env /app/env
+
+ # Set PATH to use the virtual environment
+ ENV PATH="/app/.venv/bin:$PATH"
+
+ # Set PYTHONPATH so imports work correctly
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:8000/health || exit 1
+
+ # Run the FastAPI server
+ # The module path is constructed to work with the /app/env structure
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
tool_use_env/server/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Tool Use Env environment server components."""
+
+ from .tool_use_env_environment import ToolUseEnvironment
+
+ __all__ = ["ToolUseEnvironment"]
tool_use_env/server/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (328 Bytes).
 
tool_use_env/server/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (400 Bytes).
 
tool_use_env/server/__pycache__/app.cpython-312.pyc ADDED
Binary file (886 Bytes).
 
tool_use_env/server/__pycache__/app.cpython-313.pyc ADDED
Binary file (2.8 kB).
 
tool_use_env/server/__pycache__/tool_use_env_environment.cpython-312.pyc ADDED
Binary file (6.22 kB).
 
tool_use_env/server/__pycache__/tool_use_env_environment.cpython-313.pyc ADDED
Binary file (3.83 kB).
 
tool_use_env/server/app.py ADDED
@@ -0,0 +1,23 @@
+ import uvicorn
+
+ from openenv.core.env_server.http_server import create_app
+
+ from tool_use_env.models import ToolUseAction, ToolUseObservation
+ from tool_use_env.server.tool_use_env_environment import ToolUseEnvironment
+
+
+ app = create_app(
+     ToolUseEnvironment,
+     ToolUseAction,
+     ToolUseObservation,
+     env_name="tool_use_env",
+     max_concurrent_envs=1,
+ )
+
+
+ def main(host: str = "0.0.0.0", port: int = 8000):
+     uvicorn.run("tool_use_env.server.app:app", host=host, port=port)
+
+
+ if __name__ == "__main__":
+     main()
tool_use_env/server/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ openenv-core
+ fastapi
+ uvicorn
+ pydantic
+ python-dotenv
+ openai
tool_use_env/server/tool_use_env_environment.py ADDED
@@ -0,0 +1,222 @@
+ import random
+ import uuid
+
+ from openenv.core.env_server import Environment
+ from tool_use_env.models import ToolUseAction, ToolUseObservation, ToolUseState
+ from tool_use_env.grader import compute_grade
+
+
+ class ToolUseEnvironment(Environment):
+     SUPPORTS_CONCURRENT_SESSIONS = True
+
+     def __init__(self):
+         self._state = ToolUseState()
+         self._tasks = self._load_tasks()
+
+     def _load_tasks(self):
+         return [
+             {
+                 "query": "What is 5 + 7?",
+                 "answer": "12",
+                 "correct_action": "use_calculator",
+                 "difficulty": "easy"
+             },
+             {
+                 "query": "Capital of France?",
+                 "answer": "Paris",
+                 "correct_action": "use_search",
+                 "difficulty": "easy"
+             },
+             {
+                 "query": "What is 123 * 456?",
+                 "answer": "56088",
+                 "correct_action": "use_calculator",
+                 "difficulty": "hard"
+             },
+             {
+                 "query": "What is 25 * 4?",
+                 "answer": "100",
+                 "correct_action": "use_calculator",
+                 "difficulty": "medium"
+             },
+             {
+                 "query": "Who is the CEO of Tesla?",
+                 "answer": "Elon Musk",
+                 "correct_action": "use_search",
+                 "difficulty": "medium"
+             }
+         ]
+
+     def reset(self, seed=None, episode_id=None, **kwargs) -> ToolUseObservation:
+         task = random.choice(self._tasks)
+
+         self._state = ToolUseState(
+             episode_id=episode_id or str(uuid.uuid4()),
+             step_count=0,
+             current_query=task["query"],
+             correct_action=task["correct_action"],
+             correct_answer=task["answer"],
+             difficulty=task["difficulty"]
+         )
+
+         return ToolUseObservation(
+             done=False,
+             reward=None,
+             query=task["query"],
+             tool_output=None,
+             message="Choose an action"
+         )
+
+     # 🔢 Calculator tool (controlled noise)
+     def _calculator(self, query):
+         try:
+             expr = query.lower()
+             expr = expr.replace("what is", "").replace("?", "").strip()
+             # eval is acceptable here only because the expressions come from
+             # the environment's own fixed task list
+             correct = eval(expr)
+
+             difficulty = self._state.difficulty
+
+             if difficulty == "easy":
+                 fail_prob = 0.06
+             elif difficulty == "medium":
+                 fail_prob = 0.12
+             else:
+                 fail_prob = 0.18
+
+             # complexity-based failure
+             if len(query) > 20:
+                 fail_prob += 0.05
+
+             # 🔥 cap failure (IMPORTANT)
+             fail_prob = min(fail_prob, 0.25)
+
+             if random.random() < fail_prob:
+                 # 🔥 scale noise based on magnitude
+                 if abs(correct) < 50:
+                     noise = random.randint(-2, 2)
+                 else:
+                     noise = int(correct * random.uniform(-0.05, 0.05))
+
+                 return str(correct + noise)
+
+             return str(correct)
+
+         except Exception:
+             return "error"
+
+     # 🔍 Search tool (controlled noise)
+     def _search(self, query):
+         kb = {
+             "Capital of France": "Paris",
+             "CEO of Tesla": "Elon Musk"
+         }
+
+         difficulty = self._state.difficulty
+
+         for key in kb:
+             if key.lower() in query.lower():
+
+                 if difficulty == "easy":
+                     fail_prob = 0.07
+                 elif difficulty == "medium":
+                     fail_prob = 0.15
+                 else:
+                     fail_prob = 0.22
+
+                 # complexity-based failure
+                 if len(query) > 20:
+                     fail_prob += 0.05
+
+                 # 🔥 cap failure
+                 fail_prob = min(fail_prob, 0.30)
+
+                 if random.random() < fail_prob:
+                     return random.choice([
+                         "Unknown",
+                         "Not sure",
+                         "No results found"
+                     ])
+
+                 return kb[key]
+
+         return "not found"
+
+     def step(self, action: ToolUseAction, timeout_s=None, **kwargs) -> ToolUseObservation:
+         self._state.step_count += 1
+
+         query = self._state.current_query
+         correct_action = self._state.correct_action
+         correct_answer = self._state.correct_answer
+         difficulty = self._state.difficulty
+
+         action_type = action.action_type
+
+         # --- Execute tool ---
+         if action_type == "use_calculator":
+             output = self._calculator(query)
+         elif action_type == "use_search":
+             output = self._search(query)
+         elif action_type == "answer_directly":
+             output = "unknown"
+         else:
+             output = "invalid action"
+
+         # --- Check correctness ---
+         answer_correct = (output == correct_answer)
+
+         # 🧠 REWARD SYSTEM
+
+         # 1. Action correctness
+         action_score = 0.4 if action_type == correct_action else 0.1
+
+         # 2. Answer correctness
+         answer_score = 0.5 if answer_correct else 0.0
+
+         # 3. Tool cost (small penalty)
+         if action_type == "use_calculator":
+             tool_penalty = 0.05
+         elif action_type == "use_search":
+             tool_penalty = 0.08
+         else:
+             tool_penalty = 0.0
+
+         # 4. Failure bonus (good reasoning but the tool failed)
+         failure_bonus = 0.1 if (not answer_correct and action_type == correct_action) else 0.0
+
+         # 5. Combine
+         reward = action_score + answer_score + failure_bonus - tool_penalty
+
+         # 6. Difficulty scaling (light)
+         if difficulty == "medium":
+             reward *= 1.02
+         elif difficulty == "hard":
+             reward *= 0.9
+
+         # 7. Clamp to the declared [0.0, 1.0] range
+         reward = max(0.0, min(1.0, reward))
+
+         # --- Grade (for reporting only) ---
+         grade = compute_grade(
+             action_taken=action_type,
+             correct_action=correct_action,
+             output=output,
+             correct_answer=correct_answer
+         )
+
+         return ToolUseObservation(
+             done=True,
+             reward=reward,
+             query=query,
+             tool_output=output,
+             message=(
+                 f"Action: {action_type}, "
+                 f"Output: {output}, "
+                 f"Correct: {answer_correct}, "
+                 f"Reward: {reward:.2f}, "
+                 f"Grade: {grade:.2f}"
+             )
+         )
+
+     @property
+     def state(self) -> ToolUseState:
+         return self._state
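A minimal direct check of the environment loop, without the HTTP server (a sketch mirroring the README's "Direct Environment Testing"; it assumes the package is importable from the repo root):

```python
from tool_use_env.models import ToolUseAction
from tool_use_env.server.tool_use_env_environment import ToolUseEnvironment

env = ToolUseEnvironment()
obs = env.reset()
print(obs.query)  # e.g. "What is 5 + 7?"

obs = env.step(ToolUseAction(action_type="use_calculator"))
print(obs.reward, obs.message)  # reward is clamped to [0.0, 1.0]
```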
tool_use_env/tests/test_tools.py ADDED
@@ -0,0 +1,23 @@
+ from tool_use_env.models import ToolUseAction
+ from server.tool_use_env_environment import ToolUseEnvironment
+
+ env = ToolUseEnvironment()
+
+ def test_calculator_correct():
+     result = env._calculator("What is 2 + 2?")
+     # allow noise: small results can be shifted by up to ±2
+     assert result in ["2", "3", "4", "5", "6"]
+
+ def test_search():
+     result = env._search("Capital of France?")
+     # search may fail with one of several "miss" strings
+     assert result in ["Paris", "Unknown", "Not sure", "No results found"]
+
+ def test_step_output():
+     env = ToolUseEnvironment()
+     env.reset()
+
+     obs = env.step(ToolUseAction(action_type="use_calculator"))
+     print(obs.query)
+
+     assert 0 <= obs.reward <= 1
+     assert obs.query is not None
tool_use_env/uv.lock ADDED
The diff for this file is too large to render.