| """ |
| inference.py - root directory |
| 3 tasks: |
| 1. revenge-trade-detection — catch loss_streak >= 2 |
| 2. panic-sell-prevention — catch deep pnl < -0.3 |
| 3. overconfidence-correction — catch win streak + overtrading |
| """ |
| import os |
| from typing import List |
| from openai import OpenAI |
| from dotenv import load_dotenv |
|
|
| load_dotenv() |
|
|
| from trade_env.env.coach_env import CoachEnv |
| from trade_env.schemas.action import Action, ActionType |
|
|
| API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") |
| MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct") |
| HF_TOKEN = os.getenv("HF_TOKEN") |
| LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") |
|
|
| BENCHMARK = "coach-env" |
| MAX_STEPS = 20 |
| SUCCESS_SCORE_THRESHOLD = 0.1 |
|
|
| TASKS = { |
| "revenge-trade-detection": { |
| "desc": "Detect and intervene on revenge trading after loss streaks", |
| "trigger": lambda s: s["loss_streak"] >= 0.2, |
| "correct_actions": [3, 4], |
| }, |
| "panic-sell-prevention": { |
| "desc": "Prevent panic selling during drawdowns", |
| "trigger": lambda s: s["pnl"] < -0.3, |
| "correct_actions": [2, 3], |
| }, |
| "overconfidence-correction": { |
| "desc": "Correct overconfident trading after wins", |
| "trigger": lambda s: s["overtrade_score"] >= 0.7 and s["pnl"] > 0.1, |
| "correct_actions": [1, 2], |
| }, |
| } |
|
|
|
|
| def log_start(task, env, model): |
| print(f"[START] task={task} env={env} model={model}", flush=True) |
|
|
|
|
| def log_step(step, action, reward, done, error=None): |
| error_val = error if error else "null" |
| print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True) |
|
|
|
|
| def log_end(success, steps, score, rewards): |
| rewards_str = ",".join(f"{r:.2f}" for r in rewards) |
| print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True) |
|
|
|
|
| def get_llm_action(client, state, task_name): |
| prompt = ( |
| f"You are a trading behavior coach. Task: {task_name}.\n" |
| f"Trader state: loss_streak={state['loss_streak']:.2f}, " |
| f"pnl={state['pnl']:.2f}, overtrade_score={state['overtrade_score']:.2f}.\n" |
| f"Reply with single digit only. 0=ignore 1=warn 2=reduce 3=exit 4=cooldown" |
| ) |
| try: |
| completion = client.chat.completions.create( |
| model=MODEL_NAME, |
| messages=[{"role": "user", "content": prompt}], |
| max_tokens=3, |
| temperature=0.0, |
| ) |
| raw = (completion.choices[0].message.content or "").strip()[0] |
| action = int(raw) |
| if action not in range(5): |
| raise ValueError |
| return action |
| except: |
| pass |
|
|
| |
| loss = state["loss_streak"] |
| pnl = state["pnl"] |
| over = state["overtrade_score"] |
|
|
| if task_name == "revenge-trade-detection": |
| if loss >= 0.2: return 4 |
| if loss >= 0.1: return 3 |
| if loss > 0.0: return 1 |
| return 0 |
|
|
| if task_name == "panic-sell-prevention": |
| if pnl < -0.3: return 3 |
| if pnl < -0.1: return 2 |
| return 0 |
|
|
| if task_name == "overconfidence-correction": |
| if over >= 0.7: return 2 |
| if over >= 0.5: return 1 |
| return 0 |
|
|
| return 0 |
|
|
|
|
| def run_task(client, task_name: str) -> float: |
| task = TASKS[task_name] |
| env = CoachEnv() |
| rewards: List[float] = [] |
| steps_taken = 0 |
| correct_interventions = 0 |
| total_triggers = 0 |
|
|
| log_start(task_name, BENCHMARK, MODEL_NAME) |
|
|
| try: |
| state = env.reset() |
|
|
| for step in range(1, MAX_STEPS + 1): |
| action_idx = get_llm_action(client, state, task_name) |
| action = Action(action=ActionType(action_idx)) |
|
|
| next_state, reward, done, info = env.step(action) |
|
|
| |
| if task["trigger"](state): |
| total_triggers += 1 |
| if action_idx in task["correct_actions"]: |
| correct_interventions += 1 |
| reward = abs(reward) + 0.1 |
|
|
| log_step(step, ActionType(action_idx).name, reward, done) |
|
|
| rewards.append(reward) |
| steps_taken = step |
| state = next_state |
|
|
| if done: |
| break |
|
|
| |
| if total_triggers > 0: |
| score = correct_interventions / total_triggers |
| else: |
| score = sum(r for r in rewards if r > 0) |
| score = min(1.0, score / 0.5) |
|
|
| score = min(1.0, max(0.0, score)) |
| success = score >= SUCCESS_SCORE_THRESHOLD |
|
|
| except Exception as e: |
| log_step(steps_taken + 1, "NO", 0.0, True, error=str(e)) |
| success = False |
| score = 0.0 |
| rewards = rewards or [0.0] |
|
|
| finally: |
| log_end(success, steps_taken, score, rewards) |
|
|
| return score |
|
|
|
|
| def main(): |
| client = OpenAI( |
| api_key=HF_TOKEN, |
| base_url=API_BASE_URL |
| ) |
|
|
| all_scores = [] |
| for task_name in TASKS: |
| score = run_task(client, task_name) |
| all_scores.append(score) |
|
|
| avg = sum(all_scores) / len(all_scores) |
| print(f"[SUMMARY] tasks={len(all_scores)} avg_score={avg:.3f}", flush=True) |
|
|
|
|
| if __name__ == "__main__": |
| main() |