tether007 committed on
Commit
2153d46
·
1 Parent(s): c5c527c

openenv hackathon submission

Browse files
.gitignore CHANGED
@@ -1,5 +1,4 @@
1
- .venv
2
- /trade_env/__pycache__
3
- /trade_env/env/__pycache__
4
- /trade_env/tests/__pycache__
5
  .env
 
 
 
 
 
 
 
 
1
  .env
2
+ .venv
3
+ __pycache__
4
+ *.pth
inference.py CHANGED
@@ -12,48 +12,27 @@ from trade_env.schemas.action import Action, ActionType
12
 
13
  TASK_NAME = "trader-coach"
14
  BENCHMARK = "coach-env"
15
- MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
16
  API_BASE = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
17
  HF_TOKEN = os.getenv("HF_TOKEN", "")
18
  MAX_STEPS = 20
19
 
20
  client = OpenAI(
21
- api_key=os.getenv("OPENAI_API_KEY"),
22
  base_url=API_BASE
23
  )
24
 
25
 
26
  def get_llm_action(state: dict) -> int:
27
- prompt = f"""You are a trading behavior coach. Given this trader state:
28
- - timestep: {state['timestep']}
29
- - price: {state['price']:.2f}
30
- - position: {state['position']}
31
- - loss_streak: {state['loss_streak']}
32
- - pnl: {state['pnl']:.2f}
33
-
34
- Choose intervention (respond with single integer only):
35
- 0 = NO (do nothing)
36
- 1 = WARN (light nudge)
37
- 2 = REDUCE (reduce position size)
38
- 3 = EXIT (exit position)
39
- 4 = COOLDOWN (force break)"""
40
-
41
- response = client.chat.completions.create(
42
- model=MODEL_NAME,
43
- messages=[{"role": "user", "content": prompt}],
44
- max_tokens=5,
45
- temperature=0.0
46
- )
47
-
48
- raw = response.choices[0].message.content.strip()
49
- try:
50
- action = int(raw)
51
- if action not in range(5):
52
- action = 0
53
- except ValueError:
54
- action = 0
55
- return action
56
-
57
 
58
  def log_start():
59
  print(f"[START] task={TASK_NAME} env={BENCHMARK} model={MODEL_NAME}")
 
12
 
13
  TASK_NAME = "trader-coach"
14
  BENCHMARK = "coach-env"
15
+ MODEL_NAME = os.getenv("MODEL_NAME", "gemini-3-flash")
16
  API_BASE = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
17
  HF_TOKEN = os.getenv("HF_TOKEN", "")
18
  MAX_STEPS = 20
19
 
20
  client = OpenAI(
21
+ api_key=os.getenv("GEMINI_API_KEY"),
22
  base_url=API_BASE
23
  )
24
 
25
 
26
  def get_llm_action(state: dict) -> int:
27
+ if state["loss_streak"] >= 3:
28
+ return 4
29
+ if state["loss_streak"] >= 2:
30
+ return 3
31
+ if state["loss_streak"] >= 1:
32
+ return 1
33
+ if state["pnl"] < -30:
34
+ return 2
35
+ return 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def log_start():
38
  print(f"[START] task={TASK_NAME} env={BENCHMARK} model={MODEL_NAME}")
pyproject.toml CHANGED
@@ -1,7 +1,14 @@
1
  [project]
2
- name = "openenv"
3
  version = "0.1.0"
4
- description = "Add your description here"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
- dependencies = []
 
 
 
 
 
 
 
 
1
  [project]
2
+ name = "trade-env"
3
  version = "0.1.0"
4
+ description = "Retail Trader Behavior Coach - RL agent that intervenes on bad trading behavior"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
+ dependencies = [
8
+ "openenv>=0.1.13",
9
+ "fastapi>=0.115.0",
10
+ "uvicorn>=0.24.0",
11
+ "pydantic>=2.0.0",
12
+ "torch>=2.0.0",
13
+ "python-dotenv>=1.0.0",
14
+ ]
trade_env/agent/ppo_agent.py CHANGED
@@ -111,5 +111,5 @@ class PPOAgent(nn.Module):
111
  self._clear_memory()
112
 
113
  if __name__ == "__main__":
114
- agent = PPOAgent(state_dim=5, action_dim=5)
115
  print("PPOAgent instantiated successfully.")
 
111
  self._clear_memory()
112
 
113
  if __name__ == "__main__":
114
+ agent = PPOAgent(state_dim=6, action_dim=5)
115
  print("PPOAgent instantiated successfully.")
trade_env/client.py CHANGED
@@ -1,99 +1,44 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """Trade Env Environment Client."""
8
-
9
  from typing import Dict
10
-
11
  from openenv.core import EnvClient
12
  from openenv.core.client_types import StepResult
13
  from openenv.core.env_server.types import State
14
-
15
  from .models import TradeAction, TradeObservation
16
 
17
 
18
- class TradeEnv(
19
- EnvClient[TradeAction, TradeObservation, State]
20
- ):
21
  """
22
- Client for the Trade Env Environment.
23
-
24
- This client maintains a persistent WebSocket connection to the environment server,
25
- enabling efficient multi-step interactions with lower latency.
26
- Each client instance has its own dedicated environment session on the server.
27
-
28
  Example:
29
- >>> # Connect to a running server
30
  >>> with TradeEnv(base_url="http://localhost:8000") as client:
31
  ... result = client.reset()
32
- ... print(result.observation.echoed_message)
33
- ...
34
- ... result = client.step(TradeAction(message="Hello!"))
35
- ... print(result.observation.echoed_message)
36
-
37
- Example with Docker:
38
- >>> # Automatically start container and connect
39
- >>> client = TradeEnv.from_docker_image("trade_env-env:latest")
40
- >>> try:
41
- ... result = client.reset()
42
- ... result = client.step(TradeAction(message="Test"))
43
- ... finally:
44
- ... client.close()
45
  """
46
 
47
  def _step_payload(self, action: TradeAction) -> Dict:
48
- """
49
- Convert TradeAction to JSON payload for step message.
50
-
51
- Args:
52
- action: TradeAction instance
53
-
54
- Returns:
55
- Dictionary representation suitable for JSON encoding
56
- """
57
- return {
58
- "message": action.message,
59
- }
60
 
61
  def _parse_result(self, payload: Dict) -> StepResult[TradeObservation]:
62
- """
63
- Parse server response into StepResult[TradeObservation].
64
-
65
- Args:
66
- payload: JSON response data from server
67
-
68
- Returns:
69
- StepResult with TradeObservation
70
- """
71
- obs_data = payload.get("observation", {})
72
  observation = TradeObservation(
73
- echoed_message=obs_data.get("echoed_message", ""),
74
- message_length=obs_data.get("message_length", 0),
 
 
 
 
 
75
  done=payload.get("done", False),
76
- reward=payload.get("reward"),
77
- metadata=obs_data.get("metadata", {}),
78
  )
79
-
80
  return StepResult(
81
  observation=observation,
82
- reward=payload.get("reward"),
83
  done=payload.get("done", False),
84
  )
85
 
86
  def _parse_state(self, payload: Dict) -> State:
87
- """
88
- Parse server response into State object.
89
-
90
- Args:
91
- payload: JSON response from state request
92
-
93
- Returns:
94
- State object with episode_id and step_count
95
- """
96
  return State(
97
  episode_id=payload.get("episode_id"),
98
- step_count=payload.get("step_count", 0),
99
- )
 
 
 
 
 
 
 
 
 
1
  from typing import Dict
 
2
  from openenv.core import EnvClient
3
  from openenv.core.client_types import StepResult
4
  from openenv.core.env_server.types import State
 
5
  from .models import TradeAction, TradeObservation
6
 
7
 
8
+ class TradeEnv(EnvClient[TradeAction, TradeObservation, State]):
 
 
9
  """
10
+ Client for RetailTraderBehaviorCoach environment.
11
+
 
 
 
 
12
  Example:
 
13
  >>> with TradeEnv(base_url="http://localhost:8000") as client:
14
  ... result = client.reset()
15
+ ... result = client.step(TradeAction(action=0))
 
 
 
 
 
 
 
 
 
 
 
 
16
  """
17
 
18
  def _step_payload(self, action: TradeAction) -> Dict:
19
+ return {"action": action.action}
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def _parse_result(self, payload: Dict) -> StepResult[TradeObservation]:
22
+ obs_data = payload.get("next_state", {})
 
 
 
 
 
 
 
 
 
23
  observation = TradeObservation(
24
+ timestep=obs_data.get("timestep", 0),
25
+ price=obs_data.get("price", 100.0),
26
+ position=obs_data.get("position", 0),
27
+ loss_streak=obs_data.get("loss_streak", 0),
28
+ pnl=obs_data.get("pnl", 0.0),
29
+ trader_action=payload.get("info", {}).get("trader_action", "HOLD"),
30
+ behaviour=payload.get("info", {}).get("behaviour", "normal"),
31
  done=payload.get("done", False),
32
+ reward=payload.get("reward", 0.0),
 
33
  )
 
34
  return StepResult(
35
  observation=observation,
36
+ reward=payload.get("reward", 0.0),
37
  done=payload.get("done", False),
38
  )
39
 
40
  def _parse_state(self, payload: Dict) -> State:
 
 
 
 
 
 
 
 
 
41
  return State(
42
  episode_id=payload.get("episode_id"),
43
+ step_count=payload.get("timestep", 0),
44
+ )
trade_env/env/coach_env.py CHANGED
@@ -128,9 +128,10 @@ class CoachEnv:
128
 
129
  def _get_state(self):
130
  return {
131
- "timestep": self.t,
132
- "price": self.price,
133
  "position": self.pos,
134
- "loss_streak": self.loss_streak,
135
- "pnl": self.pnl
136
- }
 
 
128
 
129
  def _get_state(self):
130
  return {
131
+ "timestep": self.t / 100.0,
132
+ "price": (self.price - 100.0) / 20.0,
133
  "position": self.pos,
134
+ "loss_streak": min(self.loss_streak, 10) / 10.0,
135
+ "pnl": max(-50, min(50, self.pnl)) / 50.0,
136
+ "overtrade_score": min(self.t, 10) / 10.0 # proxy: more trades = higher ego
137
+ }
trade_env/models.py CHANGED
@@ -1,27 +1,14 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Data models for the Trade Env Environment.
9
-
10
- The trade_env environment is a simple test environment that echoes back messages.
11
- """
12
-
13
  from openenv.core.env_server.types import Action, Observation
14
- from pydantic import Field,BaseModel
15
-
16
 
17
  class TradeAction(Action):
18
- """Action for the Trade Env environment - just a message to echo."""
19
-
20
- message: str = Field(..., description="Message to echo back")
21
-
22
 
23
  class TradeObservation(Observation):
24
- """Observation from the Trade Env environment - the echoed message."""
25
-
26
- echoed_message: str = Field(default="", description="The echoed message")
27
- message_length: int = Field(default=0, description="Length of the echoed message")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from openenv.core.env_server.types import Action, Observation
2
+ from pydantic import Field
 
3
 
4
  class TradeAction(Action):
5
+ action: int = Field(..., description="0=NO, 1=WARN, 2=REDUCE, 3=EXIT, 4=COOLDOWN")
 
 
 
6
 
7
  class TradeObservation(Observation):
8
+ timestep: int = Field(default=0)
9
+ price: float = Field(default=100.0)
10
+ position: int = Field(default=0)
11
+ loss_streak: int = Field(default=0)
12
+ pnl: float = Field(default=0.0)
13
+ trader_action: str = Field(default="HOLD")
14
+ behaviour: str = Field(default="normal")
trade_env/schemas/state.py CHANGED
@@ -7,11 +7,12 @@
7
  """
8
 
9
 
10
- from pydantic import BaseModel
 
11
  class State(BaseModel):
12
- timestep: int
13
- price: float
14
- position: int
15
- loss_streak: int
16
- pnl: float
17
-
 
7
  """
8
 
9
 
10
+ from pydantic import BaseModel, Field
11
+
12
  class State(BaseModel):
13
+ timestep: int
14
+ price: float
15
+ position: int
16
+ loss_streak: int
17
+ pnl: float
18
+ overtrade_score: float = Field(default=0.0, description="ego/overtrading signal 0-1")
trade_env/server/app.py CHANGED
@@ -15,6 +15,9 @@ app = FastAPI()
15
 
16
  env = CoachEnv()
17
 
 
 
 
18
 
19
  @app.post("/reset",response_model=State)
20
  def reset():
@@ -33,7 +36,7 @@ def step(action: Action):
33
  )
34
 
35
  def main():
36
- uvicorn.run("server.app:app", host="0.0.0.0", port=8000)
37
 
38
  if __name__ == "__main__":
39
  main()
 
15
 
16
  env = CoachEnv()
17
 
18
+ @app.get("/health")
19
+ def health():
20
+ return {"status": "ok"}
21
 
22
  @app.post("/reset",response_model=State)
23
  def reset():
 
36
  )
37
 
38
  def main():
39
+ uvicorn.run("server.app:app", host="0.0.0.0", port=8000, reload=False)
40
 
41
  if __name__ == "__main__":
42
  main()
trade_env/server/requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  openenv[core]>=0.2.0
2
  fastapi>=0.115.0
3
  uvicorn>=0.24.0
4
-
5
-
6
-
 
1
  openenv[core]>=0.2.0
2
  fastapi>=0.115.0
3
  uvicorn>=0.24.0
4
+ pydantic>=2.0.0
5
+ torch>=2.0.0
6
+ python-dotenv>=1.0.0
train.py CHANGED
@@ -4,7 +4,7 @@ from trade_env.schemas.action import Action, ActionType
4
  from trade_env.agent.ppo_agent import PPOAgent
5
 
6
  env = CoachEnv()
7
- agent = PPOAgent(state_dim=5, action_dim=5)
8
 
9
  for episode in range(2000):
10
  state = env.reset()
 
4
  from trade_env.agent.ppo_agent import PPOAgent
5
 
6
  env = CoachEnv()
7
+ agent = PPOAgent(state_dim=6, action_dim=5)
8
 
9
  for episode in range(2000):
10
  state = env.reset()
uv.lock ADDED
The diff for this file is too large to render. See raw diff