Fix circular import error in server/app.py
Browse files- Dockerfile +3 -8
- server/app.py +4 -90
Dockerfile
CHANGED
|
@@ -1,21 +1,16 @@
|
|
| 1 |
FROM python:3.11-slim
|
| 2 |
|
| 3 |
-
# Create non-root user
|
| 4 |
-
RUN useradd -m -u 1000 user
|
| 5 |
-
USER user
|
| 6 |
-
ENV PATH="/home/user/.local/bin:$PATH"
|
| 7 |
-
|
| 8 |
WORKDIR /app
|
| 9 |
|
| 10 |
# Copy requirements first
|
| 11 |
COPY requirements.txt .
|
| 12 |
-
RUN pip install --no-cache-dir -
|
| 13 |
|
| 14 |
# Copy all files
|
| 15 |
-
COPY
|
| 16 |
|
| 17 |
# Expose port 7860 (Hugging Face default)
|
| 18 |
EXPOSE 7860
|
| 19 |
|
| 20 |
-
# Run the server
|
| 21 |
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
|
| 1 |
FROM python:3.11-slim
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
# Copy requirements first
|
| 6 |
COPY requirements.txt .
|
| 7 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 8 |
|
| 9 |
# Copy all files
|
| 10 |
+
COPY . .
|
| 11 |
|
| 12 |
# Expose port 7860 (Hugging Face default)
|
| 13 |
EXPOSE 7860
|
| 14 |
|
| 15 |
+
# Run the server directly from server.app
|
| 16 |
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
server/app.py
CHANGED
|
@@ -4,37 +4,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| 4 |
|
| 5 |
from fastapi import FastAPI, HTTPException
|
| 6 |
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
-
from enum import Enum
|
| 10 |
-
from server.app import app as quant_gym_app
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
app = quant_gym_app
|
| 14 |
-
|
| 15 |
-
# Simple models for the API
|
| 16 |
-
class ActionType(str, Enum):
|
| 17 |
-
GET_PRICE = "GET_PRICE"
|
| 18 |
-
GET_NEWS = "GET_NEWS"
|
| 19 |
-
BUY = "BUY"
|
| 20 |
-
SELL = "SELL"
|
| 21 |
-
BACKTEST = "BACKTEST"
|
| 22 |
-
|
| 23 |
-
class AgentAction(BaseModel):
|
| 24 |
-
type: ActionType
|
| 25 |
-
symbol: Optional[str] = "AAPL"
|
| 26 |
-
amount: Optional[int] = 0
|
| 27 |
-
explanation: Optional[str] = None
|
| 28 |
-
strategy: Optional[str] = None
|
| 29 |
-
|
| 30 |
-
class MarketObservation(BaseModel):
|
| 31 |
-
timestamp: str = ""
|
| 32 |
-
price: float = 150.0
|
| 33 |
-
balance: float = 10000.0
|
| 34 |
-
holdings: int = 0
|
| 35 |
-
portfolio_value: float = 10000.0
|
| 36 |
-
last_news: Optional[Dict[str, Any]] = None
|
| 37 |
-
backtest_results: Optional[Dict[str, float]] = None
|
| 38 |
|
| 39 |
app = FastAPI(title="Quant-Gym", description="Financial Analysis Environment")
|
| 40 |
|
|
@@ -46,64 +17,7 @@ app.add_middleware(
|
|
| 46 |
allow_headers=["*"],
|
| 47 |
)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
class SimpleEnv:
|
| 51 |
-
def __init__(self):
|
| 52 |
-
self.prices = [150, 152, 151, 153, 155, 154, 156, 158, 157, 159]
|
| 53 |
-
self.news = [
|
| 54 |
-
{"headline": "Apple announces new AI chip", "sentiment": "positive"},
|
| 55 |
-
{"headline": "Supply chain delays expected", "sentiment": "negative"},
|
| 56 |
-
{"headline": "Analysts raise price target", "sentiment": "positive"},
|
| 57 |
-
{"headline": "Market shows strong growth", "sentiment": "positive"},
|
| 58 |
-
]
|
| 59 |
-
self.reset()
|
| 60 |
-
|
| 61 |
-
def reset(self):
|
| 62 |
-
self.idx = 0
|
| 63 |
-
self.cash = 10000.0
|
| 64 |
-
self.shares = 0
|
| 65 |
-
return self._get_observation()
|
| 66 |
-
|
| 67 |
-
def step(self, action: AgentAction):
|
| 68 |
-
# Move time forward
|
| 69 |
-
self.idx = min(self.idx + 1, len(self.prices) - 1)
|
| 70 |
-
price = self.prices[self.idx]
|
| 71 |
-
|
| 72 |
-
if action.type == "BUY" and action.amount:
|
| 73 |
-
cost = price * action.amount
|
| 74 |
-
if cost <= self.cash:
|
| 75 |
-
self.cash -= cost
|
| 76 |
-
self.shares += action.amount
|
| 77 |
-
elif action.type == "SELL" and action.amount:
|
| 78 |
-
if action.amount <= self.shares:
|
| 79 |
-
self.cash += price * action.amount
|
| 80 |
-
self.shares -= action.amount
|
| 81 |
-
|
| 82 |
-
return self._get_observation()
|
| 83 |
-
|
| 84 |
-
def _get_observation(self):
|
| 85 |
-
price = self.prices[self.idx]
|
| 86 |
-
news_idx = self.idx % len(self.news)
|
| 87 |
-
|
| 88 |
-
return MarketObservation(
|
| 89 |
-
timestamp=f"step_{self.idx}",
|
| 90 |
-
price=float(price),
|
| 91 |
-
balance=round(self.cash, 2),
|
| 92 |
-
holdings=self.shares,
|
| 93 |
-
portfolio_value=round(self.cash + self.shares * price, 2),
|
| 94 |
-
last_news=self.news[news_idx]
|
| 95 |
-
)
|
| 96 |
-
|
| 97 |
-
def get_state(self):
|
| 98 |
-
obs = self._get_observation()
|
| 99 |
-
return {
|
| 100 |
-
"current_step": self.idx,
|
| 101 |
-
"total_steps": len(self.prices),
|
| 102 |
-
"observation": obs.dict(),
|
| 103 |
-
"tasks_completed": []
|
| 104 |
-
}
|
| 105 |
-
|
| 106 |
-
env = SimpleEnv()
|
| 107 |
|
| 108 |
@app.get("/")
|
| 109 |
def root():
|
|
@@ -128,7 +42,7 @@ def step(action: AgentAction):
|
|
| 128 |
|
| 129 |
@app.get("/state")
|
| 130 |
def get_state():
|
| 131 |
-
return env.
|
| 132 |
|
| 133 |
@app.get("/tasks")
|
| 134 |
def get_tasks():
|
|
|
|
| 4 |
|
| 5 |
from fastapi import FastAPI, HTTPException
|
| 6 |
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from models import AgentAction
|
| 8 |
+
from server.environment import TradingEnvironment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
app = FastAPI(title="Quant-Gym", description="Financial Analysis Environment")
|
| 11 |
|
|
|
|
| 17 |
allow_headers=["*"],
|
| 18 |
)
|
| 19 |
|
| 20 |
+
env = TradingEnvironment()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
@app.get("/")
|
| 23 |
def root():
|
|
|
|
| 42 |
|
| 43 |
@app.get("/state")
|
| 44 |
def get_state():
|
| 45 |
+
return env.state()
|
| 46 |
|
| 47 |
@app.get("/tasks")
|
| 48 |
def get_tasks():
|