tether007 committed on
Commit ·
c5c527c
1
Parent(s): a2cdc3e
simulator remaining and environment
Browse files- .gitignore +1 -1
- README.md +3 -1
- inference.py +109 -3
- openenv.yaml +17 -7
- trade_env/{train.py → agent/__init__.py} +0 -0
- trade_env/agent/__pycache__/__init__.cpython-312.pyc +0 -0
- trade_env/agent/__pycache__/ppo_agent.cpython-312.pyc +0 -0
- trade_env/agent/ppo_agent.py +115 -0
- trade_env/env/coach_env.py +9 -17
- trade_env/schemas/__pycache__/__init__.cpython-312.pyc +0 -0
- trade_env/schemas/__pycache__/action.cpython-312.pyc +0 -0
- trade_env/schemas/__pycache__/state.cpython-312.pyc +0 -0
- trade_env/schemas/__pycache__/step_response.cpython-312.pyc +0 -0
- trade_env/schemas/state.py +2 -1
- trade_env/server/__pycache__/__init__.cpython-312.pyc +0 -0
- trade_env/server/__pycache__/app.cpython-312.pyc +0 -0
- trade_env/server/__pycache__/environment.cpython-312.pyc +0 -0
- trade_env/server/app.py +25 -44
- trade_env/tests/InferenceTest.py +0 -0
- train.py +21 -0
.gitignore
CHANGED
|
@@ -2,4 +2,4 @@
|
|
| 2 |
/trade_env/__pycache__
|
| 3 |
/trade_env/env/__pycache__
|
| 4 |
/trade_env/tests/__pycache__
|
| 5 |
-
|
|
|
|
| 2 |
/trade_env/__pycache__
|
| 3 |
/trade_env/env/__pycache__
|
| 4 |
/trade_env/tests/__pycache__
|
| 5 |
+
.env
|
README.md
CHANGED
|
@@ -4,4 +4,6 @@
|
|
| 4 |
4. inference.py
|
| 5 |
5. run end-to-end
|
| 6 |
6. add PPO agent
|
| 7 |
-
7. improve logic
|
|
|
|
|
|
|
|
|
| 4 |
4. inference.py
|
| 5 |
5. run end-to-end
|
| 6 |
6. add PPO agent
|
| 7 |
+
7. improve logic
|
| 8 |
+
|
| 9 |
+
`uvicorn trade_env.server.app:app --reload`
|
inference.py
CHANGED
|
@@ -1,5 +1,111 @@
|
|
| 1 |
-
"""
|
| 2 |
-
The main mandatory file which calls the env , logs
|
| 3 |
-
|
| 4 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
inference.py - must be in root directory
Uses OpenAI client for LLM calls as per hackathon requirements
Emits [START], [STEP], [END] structured logs
"""
from dotenv import load_dotenv

# Load .env before any os.getenv() calls below read configuration.
load_dotenv()
import os
from openai import OpenAI
from trade_env.env.coach_env import CoachEnv
from trade_env.schemas.action import Action, ActionType

# Identifiers echoed verbatim in the [START] structured log line.
TASK_NAME = "trader-coach"
BENCHMARK = "coach-env"
# Model and endpoint are overridable via environment; defaults target OpenAI.
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
API_BASE = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
# NOTE(review): HF_TOKEN is read but never used in this file — presumably
# intended for HF-hosted inference endpoints; confirm before removing.
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Hard cap on episode length for the inference loop.
MAX_STEPS = 20

# Shared OpenAI-compatible client; base_url lets this point at any
# OpenAI-compatible provider, not just api.openai.com.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=API_BASE
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_llm_action(state: dict) -> int:
    """Ask the LLM coach to choose an intervention for the current state.

    Args:
        state: dict with keys 'timestep', 'price', 'position',
            'loss_streak', 'pnl' (as produced by CoachEnv).

    Returns:
        An int in 0..4 (NO/WARN/REDUCE/EXIT/COOLDOWN). Falls back to 0
        whenever the model reply is empty or cannot be parsed as a
        valid action index.
    """
    import re  # local import keeps this parsing fix self-contained

    prompt = f"""You are a trading behavior coach. Given this trader state:
- timestep: {state['timestep']}
- price: {state['price']:.2f}
- position: {state['position']}
- loss_streak: {state['loss_streak']}
- pnl: {state['pnl']:.2f}

Choose intervention (respond with single integer only):
0 = NO (do nothing)
1 = WARN (light nudge)
2 = REDUCE (reduce position size)
3 = EXIT (exit position)
4 = COOLDOWN (force break)"""

    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=5,
        temperature=0.0
    )

    # BUG FIX: message.content may be None (refusals, filtered output) —
    # the old `.content.strip()` raised AttributeError. Also accept
    # replies like "Action: 2" or "2." by extracting the first integer
    # instead of requiring int(raw) to parse the whole string.
    raw = (response.choices[0].message.content or "").strip()
    match = re.search(r"\d+", raw)
    action = int(match.group()) if match else 0
    if action not in range(5):
        action = 0
    return action
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def log_start():
    """Emit the [START] structured log line identifying task, env and model."""
    banner = f"[START] task={TASK_NAME} env={BENCHMARK} model={MODEL_NAME}"
    print(banner)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def log_step(step, action, reward, done, error=None):
    """Emit one [STEP] structured log line for the harness to parse."""
    err = "null" if not error else error
    done_str = str(done).lower()
    print(
        f"[STEP] step={step} action={action} "
        f"reward={reward:.4f} done={done_str} error={err}"
    )
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def log_end(success, steps, score, rewards):
    """Emit the final [END] structured log line summarising the episode."""
    formatted = [f"{r:.4f}" for r in rewards]
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.4f} rewards={','.join(formatted)}"
    )
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def main():
    """Run one episode: the LLM picks actions, the env steps, logs are emitted.

    Emits [START] once, one [STEP] per environment step, and a final [END]
    carrying a score in [0, 1] derived from the clipped total reward. Any
    exception is surfaced as an error [STEP] and the run is scored 0.
    """
    env = CoachEnv()
    rewards = []
    steps_taken = 0

    log_start()

    try:
        state = env.reset()

        for step in range(1, MAX_STEPS + 1):
            action_idx = get_llm_action(state)
            action = Action(action=ActionType(action_idx))

            # info is unused in this loop; keep the 4-tuple unpacking explicit
            next_state, reward, done, _info = env.step(action)

            log_step(step, ActionType(action_idx).name, reward, done)

            rewards.append(reward)
            steps_taken = step
            state = next_state

            if done:
                break

        total_reward = sum(rewards)
        # map a total reward in [-1, 1] onto a [0, 1] score
        score = max(0.0, min(1.0, (total_reward + 1.0) / 2.0))
        success = score > 0.1

    except Exception as e:
        # surface the failure in the structured log stream, then score 0
        log_step(steps_taken + 1, "NO", 0.0, True, error=str(e))
        success = False
        score = 0.0
        rewards = rewards or [0.0]

    log_end(success, steps_taken, score, rewards)


if __name__ == "__main__":
    main()
|
openenv.yaml
CHANGED
|
@@ -1,7 +1,17 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: RetailTraderBehaviorCoach
|
| 2 |
+
version: "1.0"
|
| 3 |
+
state:
|
| 4 |
+
timestep: int
|
| 5 |
+
price: float
|
| 6 |
+
position: int
|
| 7 |
+
loss_streak: int
|
| 8 |
+
pnl: float
|
| 9 |
+
actions:
|
| 10 |
+
- NO
|
| 11 |
+
- WARN
|
| 12 |
+
- REDUCE
|
| 13 |
+
- EXIT
|
| 14 |
+
- COOLDOWN
|
| 15 |
+
endpoints:
|
| 16 |
+
reset: /reset
|
| 17 |
+
step: /step
|
trade_env/{train.py → agent/__init__.py}
RENAMED
|
File without changes
|
trade_env/agent/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (135 Bytes). View file
|
|
|
trade_env/agent/__pycache__/ppo_agent.cpython-312.pyc
ADDED
|
Binary file (7.54 kB). View file
|
|
|
trade_env/agent/ppo_agent.py
CHANGED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Proximal policy Optimization(PPO)"""
|
| 2 |
+
from trade_env.schemas.action import Action
|
| 3 |
+
from trade_env.schemas.state import State
|
| 4 |
+
from trade_env.schemas.step_response import StepResponse
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torch.distributions import Categorical
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ActorCritic(torch.nn.Module):
    """Shared-trunk actor-critic: two 64-unit Tanh layers feed a policy
    head (action logits) and a value head (scalar state value)."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
        )
        self.actor = nn.Linear(64, action_dim)
        self.critic = nn.Linear(64, 1)

    def forward(self, x):
        """Return (action_logits, state_value) for state tensor x."""
        x = self.shared(x)
        return self.actor(x), self.critic(x)


class PPOAgent(nn.Module):
    """On-policy PPO agent with the clipped surrogate objective.

    Usage: call select_action()/store_outcome() during a rollout, then
    update() to optimise on the collected transitions and clear them.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, eps_clip=0.2):
        super().__init__()
        self.gamma = gamma          # discount factor for returns
        self.eps_clip = eps_clip    # PPO probability-ratio clipping radius
        self.model = ActorCritic(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self._clear_memory()

    def _clear_memory(self):
        # Rollout buffers, filled by select_action()/store_outcome().
        self.states = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.dones = []
        self.values = []

    def _state_to_tensor(self, state):
        # NOTE(review): relies on dict insertion order matching the
        # state_dim layout used at construction — confirm against callers.
        return torch.tensor(list(state.values()), dtype=torch.float32)

    def select_action(self, state):
        """Sample an action for `state` and record it in the rollout buffer.

        Returns the sampled action index as a plain int.
        """
        state_t = self._state_to_tensor(state)
        with torch.no_grad():
            logits, value = self.model(state_t)

        dist = Categorical(logits=logits)
        action = dist.sample()

        self.states.append(state_t)
        self.actions.append(action)
        self.log_probs.append(dist.log_prob(action))
        self.values.append(value.squeeze())

        return action.item()

    def store_outcome(self, reward, done):
        """Record the reward and done flag for the most recent action."""
        self.rewards.append(reward)
        self.dones.append(done)

    def _compute_returns(self):
        # Discounted returns, resetting the accumulator at episode ends.
        returns = []
        G = 0
        for reward, done in zip(reversed(self.rewards), reversed(self.dones)):
            if done:
                G = 0
            G = reward + self.gamma * G
            returns.insert(0, G)
        return torch.tensor(returns, dtype=torch.float32)

    def update(self, epochs=4):
        """Run `epochs` clipped-PPO passes over the rollout, then clear it."""
        returns = self._compute_returns()

        # detach everything collected during rollout
        states = torch.stack(self.states).detach()
        actions = torch.stack(self.actions).detach()
        log_probs_old = torch.stack(self.log_probs).detach()
        values_old = torch.stack(self.values).detach()

        advantages = returns - values_old
        # BUG FIX: std() of a single element is NaN and would poison the
        # loss — skip normalization for 1-step rollouts.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(epochs):
            logits, new_values = self.model(states)
            dist = Categorical(logits=logits)
            new_log_probs = dist.log_prob(actions)

            # probability ratio between current and rollout policies
            ratio = torch.exp(new_log_probs - log_probs_old)

            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = nn.MSELoss()(new_values.squeeze(), returns)
            entropy_bonus = dist.entropy().mean()  # encourages exploration

            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy_bonus

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            self.optimizer.step()

        self._clear_memory()


if __name__ == "__main__":
    agent = PPOAgent(state_dim=5, action_dim=5)
    print("PPOAgent instantiated successfully.")
|
trade_env/env/coach_env.py
CHANGED
|
@@ -27,16 +27,6 @@ import random
|
|
| 27 |
from enum import Enum
|
| 28 |
from trade_env.schemas.action import ActionType, Action
|
| 29 |
|
| 30 |
-
|
| 31 |
-
class Action(Enum):
|
| 32 |
-
|
| 33 |
-
NO = 0
|
| 34 |
-
WARN = 1
|
| 35 |
-
REDUCE = 2
|
| 36 |
-
EXIT = 3
|
| 37 |
-
COOLDOWN = 4 #force stop for a tmframe
|
| 38 |
-
|
| 39 |
-
|
| 40 |
class CoachEnv:
|
| 41 |
|
| 42 |
def __init__(self):
|
|
@@ -46,6 +36,8 @@ class CoachEnv:
|
|
| 46 |
self.pnl = 0
|
| 47 |
self.loss_streak = 0
|
| 48 |
self.pos = 0
|
|
|
|
|
|
|
| 49 |
|
| 50 |
def reset(self):
|
| 51 |
""" resets the env
|
|
@@ -56,7 +48,7 @@ class CoachEnv:
|
|
| 56 |
self.pnl = 0
|
| 57 |
self.loss_streak = 0
|
| 58 |
self.pos = 0
|
| 59 |
-
|
| 60 |
|
| 61 |
return self._get_state()
|
| 62 |
|
|
@@ -68,14 +60,14 @@ class CoachEnv:
|
|
| 68 |
Args:
|
| 69 |
action (): task for the agent to take given the sensor inputs in the env present
|
| 70 |
"""
|
| 71 |
-
action_type = action.
|
| 72 |
|
| 73 |
intr = 0
|
| 74 |
if(action_type == ActionType.WARN):
|
| 75 |
intr = .2
|
| 76 |
-
elif action_type == ActionType.
|
| 77 |
intr = 0.4
|
| 78 |
-
elif action_type == ActionType.
|
| 79 |
self.pos = 0
|
| 80 |
elif action_type == ActionType.COOLDOWN:
|
| 81 |
intr = 1.0
|
|
@@ -113,8 +105,8 @@ class CoachEnv:
|
|
| 113 |
else:
|
| 114 |
self.loss_streak = 0
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
self.t += 1
|
| 119 |
done = False
|
| 120 |
|
|
@@ -141,4 +133,4 @@ class CoachEnv:
|
|
| 141 |
"position": self.pos,
|
| 142 |
"loss_streak": self.loss_streak,
|
| 143 |
"pnl": self.pnl
|
| 144 |
-
}
|
|
|
|
| 27 |
from enum import Enum
|
| 28 |
from trade_env.schemas.action import ActionType, Action
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
class CoachEnv:
|
| 31 |
|
| 32 |
def __init__(self):
|
|
|
|
| 36 |
self.pnl = 0
|
| 37 |
self.loss_streak = 0
|
| 38 |
self.pos = 0
|
| 39 |
+
self.entry_price = 100
|
| 40 |
+
|
| 41 |
|
| 42 |
def reset(self):
|
| 43 |
""" resets the env
|
|
|
|
| 48 |
self.pnl = 0
|
| 49 |
self.loss_streak = 0
|
| 50 |
self.pos = 0
|
| 51 |
+
self.entry_price = 100
|
| 52 |
|
| 53 |
return self._get_state()
|
| 54 |
|
|
|
|
| 60 |
Args:
|
| 61 |
action (): task for the agent to take given the sensor inputs in the env present
|
| 62 |
"""
|
| 63 |
+
action_type = action.action
|
| 64 |
|
| 65 |
intr = 0
|
| 66 |
if(action_type == ActionType.WARN):
|
| 67 |
intr = .2
|
| 68 |
+
elif action_type == ActionType.REDUCE:
|
| 69 |
intr = 0.4
|
| 70 |
+
elif action_type == ActionType.EXIT:
|
| 71 |
self.pos = 0
|
| 72 |
elif action_type == ActionType.COOLDOWN:
|
| 73 |
intr = 1.0
|
|
|
|
| 105 |
else:
|
| 106 |
self.loss_streak = 0
|
| 107 |
|
| 108 |
+
raw_reward = step_pnl - (0.1 * intr) - (0.5 * self.loss_streak if step_pnl < 0 else 0)
|
| 109 |
+
reward = max(-1.0, min(1.0, raw_reward / 50.0))
|
| 110 |
self.t += 1
|
| 111 |
done = False
|
| 112 |
|
|
|
|
| 133 |
"position": self.pos,
|
| 134 |
"loss_streak": self.loss_streak,
|
| 135 |
"pnl": self.pnl
|
| 136 |
+
}
|
trade_env/schemas/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (137 Bytes). View file
|
|
|
trade_env/schemas/__pycache__/action.cpython-312.pyc
ADDED
|
Binary file (764 Bytes). View file
|
|
|
trade_env/schemas/__pycache__/state.cpython-312.pyc
ADDED
|
Binary file (675 Bytes). View file
|
|
|
trade_env/schemas/__pycache__/step_response.cpython-312.pyc
ADDED
|
Binary file (758 Bytes). View file
|
|
|
trade_env/schemas/state.py
CHANGED
|
@@ -13,4 +13,5 @@ class State(BaseModel):
|
|
| 13 |
price: float
|
| 14 |
position: int
|
| 15 |
loss_streak: int
|
| 16 |
-
pnl: float
|
|
|
|
|
|
| 13 |
price: float
|
| 14 |
position: int
|
| 15 |
loss_streak: int
|
| 16 |
+
pnl: float
|
| 17 |
+
|
trade_env/server/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (279 Bytes). View file
|
|
|
trade_env/server/__pycache__/app.cpython-312.pyc
ADDED
|
Binary file (1.57 kB). View file
|
|
|
trade_env/server/__pycache__/environment.cpython-312.pyc
ADDED
|
Binary file (3.55 kB). View file
|
|
|
trade_env/server/app.py
CHANGED
|
@@ -4,57 +4,38 @@ fast api endpoints which will be an HTTP server
|
|
| 4 |
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
|
| 14 |
-
|
| 15 |
-
from ..models import TradeAction, TradeObservation
|
| 16 |
-
from .environment import TradeEnvironment
|
| 17 |
-
except ModuleNotFoundError:
|
| 18 |
-
from models import TradeAction, TradeObservation
|
| 19 |
-
from trade_env.server.environment import TradeEnvironment
|
| 20 |
|
|
|
|
| 21 |
|
| 22 |
-
# Create the app with web interface and README integration
|
| 23 |
-
app = create_app(
|
| 24 |
-
TradeEnvironment,
|
| 25 |
-
TradeAction,
|
| 26 |
-
TradeObservation,
|
| 27 |
-
env_name="trade_env",
|
| 28 |
-
max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions
|
| 29 |
-
)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
host: Host address to bind to (default: "0.0.0.0")
|
| 43 |
-
port: Port number to listen on (default: 8000)
|
| 44 |
-
|
| 45 |
-
For production deployments, consider using uvicorn directly with
|
| 46 |
-
multiple workers:
|
| 47 |
-
uvicorn trade_env.server.app:app --workers 4
|
| 48 |
-
"""
|
| 49 |
-
import uvicorn
|
| 50 |
-
|
| 51 |
-
uvicorn.run(app, host=host, port=port)
|
| 52 |
|
|
|
|
|
|
|
| 53 |
|
| 54 |
if __name__ == "__main__":
|
| 55 |
-
|
| 56 |
|
| 57 |
-
|
| 58 |
-
parser.add_argument("--port", type=int, default=8000)
|
| 59 |
-
args = parser.parse_args()
|
| 60 |
-
main(port=args.port)
|
|
|
|
| 4 |
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
from fastapi import FastAPI
import uvicorn
from trade_env.env.coach_env import CoachEnv
from trade_env.schemas.action import Action
from trade_env.schemas.state import State
from trade_env.schemas.step_response import StepResponse

app = FastAPI()

# Single module-level environment instance: the server is stateful and
# serves one episode at a time.
env = CoachEnv()


@app.post("/reset", response_model=State)
def reset():
    """Reset the environment and return the initial state."""
    state = env.reset()
    return State(**state)


@app.post("/step", response_model=StepResponse)
def step(action: Action):
    """Advance the environment one step with the given action."""
    next_state, reward, done, info = env.step(action)

    return StepResponse(
        next_state=State(**next_state),
        reward=reward,
        done=done,
        info=info
    )


def main():
    """Run the API server on 0.0.0.0:8000.

    BUG FIX: the import string must be the full module path — the old
    "server.app:app" only resolved when CWD was trade_env/. The README
    runs `uvicorn trade_env.server.app:app`, so match that.
    """
    uvicorn.run("trade_env.server.app:app", host="0.0.0.0", port=8000)


if __name__ == "__main__":
    main()
|
| 40 |
|
| 41 |
+
|
|
|
|
|
|
|
|
|
trade_env/tests/InferenceTest.py
ADDED
|
File without changes
|
train.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# train.py — PPO training loop for the trader-coach environment.
from trade_env.env.coach_env import CoachEnv
from trade_env.schemas.action import Action, ActionType
from trade_env.agent.ppo_agent import PPOAgent

env = CoachEnv()
agent = PPOAgent(state_dim=5, action_dim=5)

for episode in range(2000):
    # roll out one full episode, buffering transitions in the agent
    state, done = env.reset(), False

    while not done:
        action_idx = agent.select_action(state)
        wrapped = Action(action=ActionType(action_idx))
        state, reward, done, info = env.step(wrapped)
        agent.store_outcome(reward, done)

    # one PPO update per completed episode, then report the last step
    agent.update()
    print(f"Ep {episode} | PnL: {info['pnl']:.2f} | Action: {action_idx} | Trader: {info['trader_action']}")
|