File size: 12,721 Bytes
c9d1b27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbcb71c
c9d1b27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbcb71c
 
c9d1b27
 
 
 
 
dbcb71c
 
 
 
 
 
 
 
 
 
c9d1b27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbcb71c
 
c9d1b27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
"""
Mahoraga Adaptation Engine — FastAPI Bridge
Wraps MahoragaEnv with REST endpoints for the React combat dashboard.
Includes LLM auto-play via trained Qwen 2.5 3B LoRA model.
"""
import sys
import os
import re

sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional

from env.mahoraga_env import MahoragaEnv
from utils.constants import MAX_HP, ENEMY_HP, MAX_TURNS

# ── Action lookup ──
# Maps the environment's action ids (0-4, the same ids listed in build_prompt)
# to the display names reported in TurnLog.mahoraga_action by _do_step, which
# looks up env.last_action here (falling back to "Unknown").
# The None key covers env.last_action being None — presumably a turn where no
# valid action was applied; NOTE(review): confirm against MahoragaEnv.
ACTION_NAMES = {
    0: "Adapt PHYSICAL",
    1: "Adapt CE",
    2: "Adapt TECHNIQUE",
    3: "Judgment Strike",
    4: "Regeneration",
    None: "Wasted Turn",
}

app = FastAPI(title="Mahoraga Adaptation Engine API", version="3.0.0")

# Let the React dashboard call the API cross-origin.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive; Starlette's documentation states wildcards cannot be used when
# credentials are allowed. Consider listing the dashboard origin(s) explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Global state ──
# One module-level environment shared by every request (there are no
# per-client sessions). It is created by /api/reset, or lazily by _do_step.
env: Optional[MahoragaEnv] = None
current_difficulty: str = "hard"  # last difficulty requested via /api/reset

# ── LLM Model (lazy loaded) ──
# Populated by load_llm() the first time a step is played. llm_error records a
# load failure so later calls skip the expensive load and use the rule-based agent.
llm_model = None
llm_tokenizer = None
llm_loaded: bool = False
llm_error: Optional[str] = None


def load_llm():
    """Load Qwen 2.5 3B plus the Mahoraga LoRA adapter for auto-play.

    Called lazily on the first step. The outcome is cached in module globals:

    * on success, ``llm_model``/``llm_tokenizer`` are set and ``llm_loaded`` is True;
    * on failure, ``llm_error`` holds a non-empty description, and every later
      call returns False immediately instead of retrying the slow load.

    Unsloth is tried first; if it (or something it imports) is unavailable,
    the function falls back to transformers + peft.

    Returns:
        bool: True if the model is ready for inference, False otherwise.
    """
    global llm_model, llm_tokenizer, llm_loaded, llm_error

    if llm_loaded:
        return True
    if llm_error is not None:
        return False

    model_path = os.path.join(os.path.dirname(__file__), "mahoraga_loral_final")

    if not os.path.exists(os.path.join(model_path, "adapter_config.json")):
        llm_error = f"LoRA weights not found at {model_path}"
        print(f"[LLM] ERROR: {llm_error}")
        return False

    try:
        print("[LLM] Loading Qwen 2.5 3B + LoRA (4-bit)... This may take 30-60s.")
        try:
            model, tokenizer = _load_with_unsloth(model_path)
            print("[LLM] Model loaded via Unsloth.")
        except ImportError:
            print("[LLM] Unsloth not found, using transformers + peft...")
            model, tokenizer = _load_with_peft(model_path)
            print("[LLM] Model loaded via transformers + peft.")
    except Exception as e:
        # str() is empty for some exceptions (e.g. RuntimeError()). An empty
        # string would make the cached-failure check above falsy and trigger a
        # full reload on every request, so fall back to the class name.
        llm_error = str(e) or type(e).__name__
        print(f"[LLM] Failed to load model: {llm_error}")
        return False

    # Publish only after loading fully succeeded, so a failure part-way through
    # (e.g. in for_inference) never leaves a half-initialised model in the globals.
    llm_model, llm_tokenizer = model, tokenizer
    llm_loaded = True
    return True


def _load_with_unsloth(model_path):
    """Load the 4-bit model + adapter through Unsloth; raises ImportError if Unsloth is missing."""
    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=1024,
        dtype=None,
        load_in_4bit=True,
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer


def _load_with_peft(model_path):
    """Load the NF4-quantised base model via transformers and attach the LoRA adapter with peft."""
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    from peft import PeftModel

    base_model_name = "Qwen/Qwen2.5-3B-Instruct"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    # NOTE(review): trust_remote_code=True lets the hub repo run arbitrary code;
    # recent transformers releases support Qwen2.5 natively, so this may be droppable.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.eval()
    return model, tokenizer


def build_prompt(state_dict):
    """Render the instruction prompt sent to the LoRA policy by llm_choose_action.

    ``state_dict`` must provide ``resistances`` (a mapping with ``physical``,
    ``ce`` and ``technique``), plus ``agent_hp``, ``enemy_hp``,
    ``last_enemy_attack_type``, ``last_action`` and ``turn_number``.
    Each value is formatted with ``str()``. The result is 24 lines joined by
    ``"\\n"``, with no trailing newline.

    NOTE(review): this text presumably has to match the prompt used in training;
    keep the wording in sync with the training script.
    """
    resist = state_dict["resistances"]
    lines = [
        "You are Mahoraga, an adaptive combat agent in a turn-based RL environment.",
        "",
        "Current State:",
        f"- Your HP: {state_dict['agent_hp']}",
        f"- Enemy HP: {state_dict['enemy_hp']}",
        f"- Resistances: Physical={resist['physical']}, CE={resist['ce']}, Technique={resist['technique']}",
        f"- Last Enemy Attack: {state_dict['last_enemy_attack_type']}",
        f"- Last Action Taken: {state_dict['last_action']}",
        f"- Turn: {state_dict['turn_number']}",
        "",
        "Available Actions:",
        "0 = Adapt Physical Resistance (+40 Physical, -20 others)",
        "1 = Adapt CE Resistance (+40 CE, -20 others)",
        "2 = Adapt Technique Resistance (+40 Technique, -20 others)",
        "3 = Judgment Strike (burst if you adapted to enemy's type, resets resistances)",
        "4 = Regeneration (heal 300 HP, 3-turn cooldown)",
        "",
        "WINNING STRATEGY:",
        "1. Adapt to enemy attack type 2 times to build resistance + stacks",
        "2. Use Judgment Strike for burst damage (350 + 50 per stack)",
        "3. Repeat: Adapt → Adapt → Strike",
        "4. Heal ONLY when HP is critically low",
        "",
        "Choose the best action. Return ONLY a single integer (0-4).",
    ]
    return "\n".join(lines)


def parse_action(text):
    """Return the first ASCII digit 0-4 in the model's reply, or 0 when none appears.

    Surrounding whitespace is ignored. Digits outside 0-4 (e.g. "9") are skipped,
    and a reply such as "15" yields 1 because the first matching digit wins.
    """
    found = re.search(r"[0-4]", text.strip())
    return 0 if found is None else int(found.group())


def llm_choose_action(state_dict):
    """Ask the loaded LLM for Mahoraga's next action.

    Builds a chat (system instruction + ``build_prompt(state_dict)``), samples
    up to 8 new tokens at temperature 0.7, and parses the reply with
    ``parse_action``. This requires ``load_llm()`` to have succeeded.

    Returns:
        tuple: ``(action, reply)``, where ``action`` is an int 0-4 and ``reply``
        is the stripped decoded completion.
    """
    import torch

    chat = [
        {"role": "system", "content": "You are a combat AI. Respond with ONLY a single integer 0-4."},
        {"role": "user", "content": build_prompt(state_dict)},
    ]
    rendered = llm_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = llm_tokenizer(rendered, return_tensors="pt").to(llm_model.device)
    # Number of prompt tokens; everything after this in the output is the completion.
    prompt_len = encoded["input_ids"].shape[1]

    with torch.no_grad():
        generated = llm_model.generate(
            **encoded,
            max_new_tokens=8,
            temperature=0.7,
            do_sample=True,
            pad_token_id=llm_tokenizer.eos_token_id,
        )

    completion = llm_tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
    return parse_action(completion), completion.strip()


# ── Response schemas ──
class TurnLog(BaseModel):
    """Summary of one played turn, returned as CombatState.turn_log by /api/step."""
    turn: int  # env turn_number after the step
    enemy_attack_type: str  # env last_enemy_attack_type, or "NONE" when empty
    enemy_subtype: str  # env last_enemy_subtype, or "NONE" when empty
    mahoraga_action: str  # display name from ACTION_NAMES ("Unknown" if unmapped)
    damage_taken: int  # info["damage_taken"] from env.step
    damage_dealt: int  # info["damage_dealt"] from env.step
    correct_adaptation: bool  # info["correct_adaptation"] from env.step
    reward: float  # step reward, rounded to 2 decimals by _do_step
    heal_blocked: bool  # info["heal_on_cooldown"], False when absent


class Resistances(BaseModel):
    """Resistance values sent to the client, using capitalised field names.

    Filled from the env's lowercase ``physical``/``ce``/``technique`` keys.
    """
    Physical: int
    CE: int
    Technique: int


class CombatState(BaseModel):
    """Snapshot of the fight returned by /api/reset and /api/step."""
    enemy_hp: int
    enemy_hp_max: int  # ENEMY_HP constant
    mahoraga_hp: int  # the env's agent_hp
    mahoraga_hp_max: int  # MAX_HP constant
    resistances: Resistances
    adaptation_stack: int
    heal_cooldown: int  # env.heal_cooldown_counter
    turn_number: int
    max_turns: int  # MAX_TURNS constant
    done: bool  # episode-finished flag from env.step (always False on reset)
    done_reason: Optional[str] = None  # info["reason"] from env.step, if provided
    turn_log: Optional[TurnLog] = None  # None on reset; the turn just played on /api/step
    difficulty: str = "hard"  # difficulty of the current environment
    llm_raw: Optional[str] = None  # raw LLM reply, or "[FALLBACK] rule-based" if the rule agent acted


class StepRequest(BaseModel):
    """Body of /api/step."""
    # Forwarded to env.step as enemy_category_override, so the player chooses the
    # enemy's attack. None leaves the choice to the environment (presumably based
    # on difficulty; confirm in MahoragaEnv).
    player_action: Optional[str] = None  # None means auto (based on difficulty)


class ResetRequest(BaseModel):
    """Body of /api/reset."""
    # Passed as-is to MahoragaEnv(difficulty=...); valid values are defined by the env.
    difficulty: str = "hard"


# ── Helper ──
def make_combat_state(state, env_instance, turn_log=None, llm_raw=None, done=False, done_reason=None):
    """Build a CombatState response from an environment state mapping.

    Args:
        state: env state mapping with ``enemy_hp``, ``agent_hp``, ``turn_number``
            and ``resistances`` (keys ``physical``, ``ce``, ``technique``); these
            are the same keys _do_step reads from ``env.step``.
        env_instance: the environment object. ``heal_cooldown_counter`` is read
            from it, and ``adaptation_stack`` when present (0 otherwise).
        turn_log: optional TurnLog for the turn just played.
        llm_raw: optional raw LLM reply to echo back to the client.
        done: whether the episode has ended. Defaults to False, so existing
            callers get the same result as before.
        done_reason: optional reason string accompanying ``done``.

    Returns:
        CombatState: tagged with the module's ``current_difficulty``.
    """
    resistances = state["resistances"]
    return CombatState(
        enemy_hp=state["enemy_hp"],
        enemy_hp_max=ENEMY_HP,
        mahoraga_hp=state["agent_hp"],
        mahoraga_hp_max=MAX_HP,
        resistances=Resistances(
            Physical=resistances["physical"],
            CE=resistances["ce"],
            Technique=resistances["technique"],
        ),
        # Older env objects may not expose adaptation_stack; report 0 then.
        adaptation_stack=getattr(env_instance, "adaptation_stack", 0),
        heal_cooldown=env_instance.heal_cooldown_counter,
        turn_number=state["turn_number"],
        max_turns=MAX_TURNS,
        done=done,
        done_reason=done_reason,
        turn_log=turn_log,
        difficulty=current_difficulty,
        llm_raw=llm_raw,
    )


# ── Endpoints ──

@app.post("/api/reset", response_model=CombatState)
def reset(req: ResetRequest = ResetRequest()):
    """Begin a new episode at ``req.difficulty`` and return the opening snapshot.

    Rebinds the module-level ``env`` and ``current_difficulty``. The returned
    state is assembled from constants rather than read back from the env:
    full HP on both sides, zero resistances, stacks and cooldown, and turn 0.
    """
    global env, current_difficulty
    current_difficulty = req.difficulty
    env = MahoragaEnv(difficulty=current_difficulty)
    env.reset()

    no_resistance = Resistances(Physical=0, CE=0, Technique=0)
    opening = CombatState(
        # Both fighters start at full health.
        enemy_hp=ENEMY_HP,
        enemy_hp_max=ENEMY_HP,
        mahoraga_hp=MAX_HP,
        mahoraga_hp_max=MAX_HP,
        resistances=no_resistance,
        # Nothing has been adapted, used or played yet.
        adaptation_stack=0,
        heal_cooldown=0,
        turn_number=0,
        max_turns=MAX_TURNS,
        done=False,
        done_reason=None,
        turn_log=None,
        difficulty=current_difficulty,
    )
    return opening


def _do_step(player_action=None):
    """Play a single combat turn and package the result as a CombatState.

    Mahoraga's move comes from the LoRA model when it is (or can be) loaded;
    otherwise it comes from the rule-based ``_smart_agent_action`` fallback.
    The optional ``player_action`` is passed to ``env.step`` as
    ``enemy_category_override``, i.e. it steers the enemy side of the turn.

    Returns:
        CombatState: post-step snapshot with a TurnLog and the raw LLM text
        (or the "[FALLBACK] rule-based" marker).
    """
    global env
    if env is None:
        # No /api/reset yet: lazily create an env at the current difficulty.
        env = MahoragaEnv(difficulty=current_difficulty)
        env.reset()

    # load_llm() is only attempted while the model isn't loaded; it remembers a
    # failed load, so the fallback path does not repeat the expensive attempt.
    if llm_loaded or load_llm():
        mahoraga_action, model_text = llm_choose_action(env._get_state())
    else:
        mahoraga_action = _smart_agent_action()
        model_text = "[FALLBACK] rule-based"

    state, reward, done, info = env.step(
        mahoraga_action, enemy_category_override=player_action
    )
    # Display what the env recorded as the last action, not what was requested.
    shown_action = ACTION_NAMES.get(env.last_action, "Unknown")

    log_entry = TurnLog(
        turn=state["turn_number"],
        enemy_attack_type=state["last_enemy_attack_type"] or "NONE",
        enemy_subtype=state["last_enemy_subtype"] or "NONE",
        mahoraga_action=shown_action,
        damage_taken=info["damage_taken"],
        damage_dealt=info["damage_dealt"],
        correct_adaptation=info["correct_adaptation"],
        reward=round(reward, 2),
        heal_blocked=info.get("heal_on_cooldown", False),
    )

    snapshot = CombatState(
        enemy_hp=state["enemy_hp"],
        enemy_hp_max=ENEMY_HP,
        mahoraga_hp=state["agent_hp"],
        mahoraga_hp_max=MAX_HP,
        resistances=Resistances(
            Physical=state["resistances"]["physical"],
            CE=state["resistances"]["ce"],
            Technique=state["resistances"]["technique"],
        ),
        adaptation_stack=info["adaptation_stack"],
        heal_cooldown=env.heal_cooldown_counter,
        turn_number=state["turn_number"],
        max_turns=MAX_TURNS,
        done=done,
        done_reason=info.get("reason"),
        turn_log=log_entry,
        difficulty=current_difficulty,
        llm_raw=model_text,
    )
    return snapshot


@app.post("/api/step", response_model=CombatState)
def step(req: StepRequest):
    """Advance the fight by one turn; ``req.player_action`` (optional) is forwarded to _do_step."""
    chosen = req.player_action
    return _do_step(player_action=chosen)


@app.get("/api/model-status")
def model_status():
    """Report the LLM load state without triggering a load.

    Returns a dict with the ``loaded`` flag, the cached ``error`` (or None), and
    the ``model_path`` directory probed for the LoRA adapter.
    """
    adapter_dir = os.path.join(os.path.dirname(__file__), "mahoraga_loral_final")
    return {"loaded": llm_loaded, "error": llm_error, "model_path": adapter_dir}


def _smart_agent_action():
    """Pick an action with hand-written rules, used when the LLM is unavailable.

    Priority order:
    1. Regeneration (4) when agent HP < 300 and the heal cooldown is 0.
    2. Judgment Strike (3) at 3+ stacks, or at 2+ stacks when the last adapted
       category equals the enemy's last attack type.
    3. Adapt (0/1/2) to the enemy's last attack category.
    4. Otherwise, adapt to the lowest current resistance.

    Returns 0 when no environment exists yet.
    """
    if env is None:
        return 0

    snapshot = env._get_state()
    hp_now = snapshot["agent_hp"]
    resistances = snapshot["resistances"]
    incoming = snapshot.get("last_enemy_attack_type")
    category_to_action = {"PHYSICAL": 0, "CE": 1, "TECHNIQUE": 2}

    # Emergency heal.
    if hp_now < 300 and env.heal_cooldown_counter == 0:
        return 4

    # Cash in accumulated stacks.
    stack = env.adaptation_stack
    if stack >= 3 or (stack >= 2 and env.last_adapted_category == incoming):
        return 3

    # Mirror the enemy's most recent attack category.
    mirrored = category_to_action.get(incoming) if incoming else None
    if mirrored is not None:
        return mirrored

    # No usable attack info: shore up the weakest resistance.
    weakest = min(resistances, key=resistances.get)
    return category_to_action.get(weakest.upper(), 0)

# ── Serve React Frontend (SPA Catch-all) ──
dist_dir = os.path.join(os.path.dirname(__file__), "frontend", "dist")


def _safe_dist_file(full_path):
    """Return the real path of ``full_path`` under ``dist_dir`` if it is an existing file, else None.

    ``full_path`` comes straight from the request URL. Anything that resolves
    outside ``dist_dir`` is rejected: ``..`` segments, absolute paths such as
    ``/etc/passwd`` (which ``os.path.join`` would otherwise honour), and symlinks
    pointing elsewhere. Paths that cannot be resolved (e.g. embedded NUL bytes)
    are also rejected.
    """
    if not full_path:
        return None
    try:
        root = os.path.realpath(dist_dir)
        candidate = os.path.realpath(os.path.join(root, full_path))
        # ValueError here means different drives on Windows, i.e. outside root.
        if os.path.commonpath([root, candidate]) != root:
            return None
    except ValueError:
        return None
    return candidate if os.path.isfile(candidate) else None


@app.get("/{full_path:path}")
async def serve_frontend(full_path: str):
    """
    Catch-all route that serves the React frontend build.

    Files that exist inside ``dist_dir`` are served directly. Any other path
    falls back to ``index.html`` so client-side routing works. Requests that
    would escape ``dist_dir`` also get ``index.html`` instead of the target file.
    """
    # Unmatched /api paths should 404 rather than receive the SPA shell.
    if full_path.startswith("api/"):
        raise HTTPException(status_code=404, detail="API endpoint not found")

    file_path = _safe_dist_file(full_path)
    if file_path is not None:
        return FileResponse(file_path)

    index_path = os.path.join(dist_dir, "index.html")
    if os.path.isfile(index_path):
        return FileResponse(index_path)

    return {"message": "Frontend build not found. Run 'npm run build' in the frontend directory."}



if __name__ == "__main__":
    # Development entry point: serve the API (and the built frontend) on all
    # network interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)