# coach_env / inference.py
# tether007
# 3 tasks with graders score 1.0
# 91e580b
"""
inference.py - root directory
3 tasks:
1. revenge-trade-detection — catch loss streaks (trigger: loss_streak >= 0.2)
2. panic-sell-prevention — catch deep drawdowns (trigger: pnl < -0.3)
3. overconfidence-correction — catch overtrading while in profit (trigger: overtrade_score >= 0.7 and pnl > 0.1)
"""
import os
from typing import List
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
from trade_env.env.coach_env import CoachEnv
from trade_env.schemas.action import Action, ActionType
# --- Runtime configuration -------------------------------------------------
# Each of these is read from the process environment once, at import time.
# The first two have defaults; the other two are None when the variable is unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")

# --- Benchmark constants ---------------------------------------------------
BENCHMARK = "coach-env"  # environment label printed in the [START] record
MAX_STEPS = 20  # upper bound on steps per task episode
SUCCESS_SCORE_THRESHOLD = 0.1  # a task counts as a success at or above this score
def _revenge_trigger(state):
    """Return True once the state's ``loss_streak`` reading reaches 0.2."""
    return state["loss_streak"] >= 0.2


def _panic_trigger(state):
    """Return True while the state's ``pnl`` is below -0.3."""
    return state["pnl"] < -0.3


def _overconfidence_trigger(state):
    """Return True when ``overtrade_score`` >= 0.7 and ``pnl`` is above 0.1."""
    return state["overtrade_score"] >= 0.7 and state["pnl"] > 0.1


# Task registry, keyed by task name. Each entry holds:
#   desc            - human-readable summary of the task
#   trigger         - predicate over a state mapping; True means an
#                     intervention is expected on that step
#   correct_actions - action indices graded as correct when the trigger fires
#                     (0=ignore 1=warn 2=reduce 3=exit 4=cooldown)
TASKS = {
    "revenge-trade-detection": {
        "desc": "Detect and intervene on revenge trading after loss streaks",
        "trigger": _revenge_trigger,
        "correct_actions": [3, 4],  # EXIT or COOLDOWN
    },
    "panic-sell-prevention": {
        "desc": "Prevent panic selling during drawdowns",
        "trigger": _panic_trigger,
        "correct_actions": [2, 3],  # REDUCE or EXIT
    },
    "overconfidence-correction": {
        "desc": "Correct overconfident trading after wins",
        "trigger": _overconfidence_trigger,
        "correct_actions": [1, 2],  # WARN or REDUCE
    },
}
def log_start(task, env, model):
    """Print the ``[START]`` record that opens a task run.

    Args:
        task: Task name to report.
        env: Benchmark/environment label to report.
        model: Model identifier to report.
    """
    record = f"[START] task={task} env={env} model={model}"
    print(record, flush=True)
def log_step(step, action, reward, done, error=None):
    """Print one ``[STEP]`` record for a single environment step.

    Args:
        step: Step number to report.
        action: Action label to report.
        reward: Numeric reward, printed with two decimals.
        done: Episode-finished flag, printed lower-cased (``true``/``false``).
        error: Optional error text; any falsy value is printed as ``null``.
    """
    error_val = error or "null"
    print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
def log_end(success, steps, score, rewards):
    """Print the ``[END]`` record that closes a task run.

    Args:
        success: Success flag, printed lower-cased (``true``/``false``).
        steps: Number of steps to report.
        score: Final score, printed with three decimals.
        rewards: Iterable of per-step rewards, printed comma-separated with
            two decimals each.
    """
    formatted = [f"{r:.2f}" for r in rewards]
    rewards_str = ",".join(formatted)
    print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
def get_llm_action(client, state, task_name):
    """Pick a coaching action index (0-4) for the current trader state.

    The LLM is asked first. If the request fails or the reply contains no
    usable digit, a deterministic rule-based policy decides instead.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        state: Mapping with numeric ``loss_streak``, ``pnl`` and
            ``overtrade_score`` entries.
        task_name: Name of the task being run; selects the fallback rules.

    Returns:
        int: 0=ignore, 1=warn, 2=reduce, 3=exit, 4=cooldown.
    """
    prompt = (
        f"You are a trading behavior coach. Task: {task_name}.\n"
        f"Trader state: loss_streak={state['loss_streak']:.2f}, "
        f"pnl={state['pnl']:.2f}, overtrade_score={state['overtrade_score']:.2f}.\n"
        f"Reply with single digit only. 0=ignore 1=warn 2=reduce 3=exit 4=cooldown"
    )
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=3,
            temperature=0.0,
        )
        reply = completion.choices[0].message.content or ""
        # Take the first valid action digit anywhere in the reply so that
        # answers such as "Action: 3" are honoured instead of discarded.
        digit = next((ch for ch in reply if ch in "01234"), None)
        if digit is not None:
            return int(digit)
    except Exception:
        # Deliberate best effort: any API or response-shape failure falls
        # through to the rules below. ``Exception`` (not a bare ``except``)
        # lets KeyboardInterrupt and SystemExit propagate to the caller.
        pass
    return _rule_based_action(state, task_name)


def _rule_based_action(state, task_name):
    """Return the deterministic fallback action for ``task_name``."""
    loss = state["loss_streak"]
    pnl = state["pnl"]
    over = state["overtrade_score"]
    if task_name == "revenge-trade-detection":
        if loss >= 0.2:
            return 4
        if loss >= 0.1:
            return 3
        if loss > 0.0:
            return 1
        return 0
    if task_name == "panic-sell-prevention":
        if pnl < -0.3:
            return 3
        if pnl < -0.1:
            return 2
        return 0
    if task_name == "overconfidence-correction":
        if over >= 0.7:
            return 2
        if over >= 0.5:
            return 1
        return 0
    # Unknown task: take no action.
    return 0
def run_task(client, task_name: str) -> float:
    """Run one episode of ``task_name`` and return its score in [0, 1].

    A fresh ``CoachEnv`` is stepped for at most ``MAX_STEPS`` steps, with the
    action on each step chosen by ``get_llm_action``. Progress is reported on
    stdout as one ``[START]`` record, one ``[STEP]`` record per step and a
    final ``[END]`` record.

    Scoring:
        * If the task's trigger fired on at least one step, the score is the
          fraction of those steps on which a correct action was chosen.
        * Otherwise the score is the sum of positive rewards divided by 0.5,
          capped at 1.0.

    Args:
        client: OpenAI-compatible client forwarded to ``get_llm_action``.
        task_name: Key into ``TASKS``.

    Returns:
        The clamped score; 0.0 if the episode raised an ``Exception``.

    Raises:
        KeyError: If ``task_name`` is not a key of ``TASKS``.
    """
    task = TASKS[task_name]
    env = CoachEnv()
    rewards: List[float] = []
    steps_taken = 0
    correct_interventions = 0
    total_triggers = 0
    # Bound before the ``try`` so the ``finally`` log line never references an
    # unbound name when a non-``Exception`` (e.g. KeyboardInterrupt) escapes;
    # otherwise an UnboundLocalError would mask the real exception.
    score = 0.0
    success = False

    log_start(task_name, BENCHMARK, MODEL_NAME)
    try:
        state = env.reset()
        for step in range(1, MAX_STEPS + 1):
            action_idx = get_llm_action(client, state, task_name)
            action = Action(action=ActionType(action_idx))
            next_state, reward, done, info = env.step(action)

            # Grade against the state the action was chosen for (pre-step):
            # did the agent pick a correct action when the trigger fired?
            if task["trigger"](state):
                total_triggers += 1
                if action_idx in task["correct_actions"]:
                    correct_interventions += 1
                    reward = abs(reward) + 0.1  # bonus for correct intervention

            log_step(step, ActionType(action_idx).name, reward, done)
            rewards.append(reward)
            steps_taken = step
            state = next_state
            if done:
                break

        # score = intervention accuracy when triggers fired
        if total_triggers > 0:
            score = correct_interventions / total_triggers
        else:
            # No trigger fired: fall back to positive reward mass, where a
            # total of 0.5 already maps to a full score.
            score = sum(r for r in rewards if r > 0)
            score = min(1.0, score / 0.5)
        score = min(1.0, max(0.0, score))
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as e:
        # Report the failure as a terminal step and score the task as zero.
        log_step(steps_taken + 1, "NO", 0.0, True, error=str(e))
        success = False
        score = 0.0
        rewards = rewards or [0.0]
    finally:
        log_end(success, steps_taken, score, rewards)
    return score
def main():
    """Run every task in ``TASKS`` and print a ``[SUMMARY]`` record.

    Builds one OpenAI-compatible client from ``API_BASE_URL`` and ``HF_TOKEN``,
    runs the tasks in registry order, and reports the task count together with
    the mean score.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    all_scores = [run_task(client, task_name) for task_name in TASKS]
    avg = sum(all_scores) / len(all_scores)
    print(f"[SUMMARY] tasks={len(all_scores)} avg_score={avg:.3f}", flush=True)


if __name__ == "__main__":
    main()