# Mouse AI - Program Generation Model

A 232M-parameter transformer that generates movement programs for a mouse navigating a maze to collect cheese while avoiding cats.

## Quick Start

```python
import torch
from model.model_2B import StructureAwareTransformer2B
from lightweight_simulator import LightweightGameSimulator

# Load model
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
ckpt = torch.load('model_best.pt', map_location='cpu', weights_only=False)
config = ckpt['model_config']

model = StructureAwareTransformer2B(**config)
model.load_state_dict(ckpt['model_state_dict'])
model = model.to(device)
model.eval()

# Play a game
game = LightweightGameSimulator(level=3)
game.reset()

for run in range(20):
    if game.win_sign or game.lose_sign:
        break

    # Get state vector (828 dimensions); get_state_vector is defined in the
    # "State Vector" section of this README
    state = get_state_vector(game).unsqueeze(0).to(device)

    # Generate program
    with torch.no_grad():
        prog = model.generate(
            state, max_length=12, temperature=0.3,
            top_k=10, grammar_constrained=True
        )

    # Parse output
    if isinstance(prog, tuple): prog = prog[0]
    if isinstance(prog, torch.Tensor): prog = prog[0].tolist()
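    # Remove the start token. ID 0 is also the UP direction token, so this
    # assumes generate() always prepends a start token with ID 0.
    if prog and prog[0] == 0: prog = prog[1:]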
    if 112 in prog: prog = prog[:prog.index(112)]  # remove END and after

    # Execute
    game.execute_program(prog)

print(f"{'WIN' if game.win_sign else 'LOSE'} | Score: {game.score}")
```

## Model Architecture

| Parameter | Value |
|-----------|-------|
| Type | StructureAwareTransformer2B |
| Total Parameters | 232.2M |
| Hidden Dimension | 1024 |
| Layers | 16 |
| Attention Heads | 16 (Query) / 4 (KV, Grouped Query Attention) |
| Feed-Forward Dim | 4096 |
| State Input | 828 dimensions |
| Vocab Size | 113 tokens |
| Max Program Length | 12 tokens |

### Model Config (for initialization)
```python
config = {
    'state_dim': 828,
    'hidden_dim': 1024,
    'vocab_size': 113,
    'max_program_length': 12,
    'num_layers': 16,
    'num_heads': 16,
    'num_kv_heads': 4,
    'ff_dim': 4096,
    'dropout': 0.1,
    'end_token': 112,
}
model = StructureAwareTransformer2B(**config)
```
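
The config values can be sanity-checked with a little arithmetic. The sketch below assumes the conventional per-head dimension `hidden_dim / num_heads` and that the checkpoint stores mostly fp32 weights; neither is confirmed by the model code, so treat the results as a consistency check rather than a specification.

```python
# Values copied from the config and architecture table above.
config = {"hidden_dim": 1024, "num_heads": 16, "num_kv_heads": 4}
total_params = 232.2e6

# Assumption: each attention head gets an equal slice of the hidden dimension.
head_dim = config["hidden_dim"] // config["num_heads"]

# Grouped Query Attention: how many query heads share one KV head.
queries_per_kv_head = config["num_heads"] // config["num_kv_heads"]

# Assumption: 4 bytes per parameter (fp32), weights only.
fp32_mib = total_params * 4 / 2**20

print(f"head_dim={head_dim}, queries per KV head={queries_per_kv_head}")
print(f"estimated fp32 weight size: {fp32_mib:.0f} MiB")
```

Under these assumptions the estimate comes out near the 886MB checkpoint size listed under File Structure.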

## Token Vocabulary (113 tokens)

### Direction Tokens (0-3)
| Token ID | Direction | Movement |
|----------|-----------|----------|
| 0 | UP | Mouse moves up one cell |
| 1 | DOWN | Mouse moves down one cell |
| 2 | LEFT | Mouse moves left one cell |
| 3 | RIGHT | Mouse moves right one cell |

### Number Tokens (100-109)
| Token ID | Value | Usage |
|----------|-------|-------|
| 100 | 1 | LOOP repeat count (1 time) |
| 104 | 5 | LOOP repeat count (5 times) |
| 105 | 6 | LOOP repeat count (6 times) |
| 106 | 7 | LOOP repeat count (7 times) |
| 107 | 8 | LOOP repeat count (8 times) |
| 108 | 9 | LOOP repeat count (9 times) |
| 109 | 10 | LOOP repeat count (10 times) |

Note: Tokens 101-103 (values 2-4) exist in the vocabulary but are not used by the grammar. With grammar constraints enabled, the model generates only NUM tokens 104-109 (5 to 10 repeats) for efficiency.

### Special Tokens
| Token ID | Name | Function |
|----------|------|----------|
| 110 | LOOP | Start a loop structure |
| 112 | END | End of program |

Token 111 (IF) was removed due to simulator incompatibility.

## Grammar Rules

Programs follow a strict context-free grammar:

```
start       -> DIR | LOOP NUM DIR | END
after_DIR   -> DIR | LOOP NUM DIR | END
after_LOOP  -> NUM (must be 104-109)
after_NUM   -> DIR (must be 0-3)
after_END   -> (stop generation)
```

### Valid Program Examples
```
[0, 112]                          # Move UP, END
[2, 2, 2, 112]                   # Move LEFT 3 times, END
[110, 106, 1, 112]               # LOOP(7 times, DOWN), END
[0, 110, 104, 2, 3, 112]         # UP, LOOP(5 times, LEFT), RIGHT, END
[110, 108, 0, 110, 105, 3, 112]  # LOOP(9, UP), LOOP(6, RIGHT), END
```

### Grammar Constraint: LOOP cutoff at position 8
The LOOP token (110) is allowed only at positions 0-7 (zero-based indices in the generated sequence). From position 8 onward, only DIR tokens and END are allowed. This prevents overly long programs.
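
The grammar and the LOOP cutoff can be checked offline with a small validator. This is a minimal sketch, not the constraint logic inside `model.generate`; the function name `is_valid_program` is hypothetical, and it assumes a valid program ends with END and that the 12-token limit includes END.

```python
DIRS = {0, 1, 2, 3}              # UP, DOWN, LEFT, RIGHT
LOOP_NUMS = set(range(104, 110)) # NUM tokens allowed after LOOP
LOOP, END = 110, 112
LOOP_CUTOFF = 8                  # LOOP allowed only at positions 0-7
MAX_LEN = 12                     # assumption: limit includes END


def is_valid_program(tokens):
    """Return True if tokens follow the grammar rules listed above."""
    if not tokens or len(tokens) > MAX_LEN or tokens[-1] != END:
        return False
    expect = "any"  # "start" and "after_DIR" allow the same tokens
    for pos, tok in enumerate(tokens):
        if expect == "num":
            if tok not in LOOP_NUMS:
                return False
            expect = "dir"
        elif expect == "dir":
            if tok not in DIRS:
                return False
            expect = "any"
        elif tok == END:
            # END must be the final token
            return pos == len(tokens) - 1
        elif tok == LOOP:
            if pos >= LOOP_CUTOFF:
                return False
            expect = "num"
        elif tok not in DIRS:
            return False
    return False


print(is_valid_program([110, 108, 0, 110, 105, 3, 112]))
print(is_valid_program([110, 101, 0, 112]))
```

The first call uses one of the valid examples above; the second uses NUM token 101, which the grammar excludes.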

## State Vector (828 dimensions)

The 828-dimensional state vector encodes the complete game state:

```python
import torch


def get_state_vector(sim):
    """Extract 828-dim state vector from game simulator"""
    state_dict = sim.get_state_dict()
    state = []
    DYNAMIC_SCALE = 10.0  # Scale factor for dynamic features

    # --- Grid features (11x11 grids) ---

    # 1. Wall grid (121 dims): 1=wall, 0=empty
    for row in state_dict['wall']:
        state.extend(row)

    # 2. Small Cheese grid (121 dims): 1=cheese present, 0=collected
    #    Scaled by DYNAMIC_SCALE (10.0)
    for row in state_dict['sc']:
        state.extend([v * DYNAMIC_SCALE for v in row])

    # 3. Junction grid (121 dims): 1=junction, 0=not
    for row in state_dict['junc']:
        state.extend(row)

    # 4. Dead-end grid (121 dims): 1=dead-end, 0=not
    for row in state_dict['deadend']:
        state.extend(row)

    # Total grid: 484 dims (4 * 121)

    # --- Entity positions ---

    # 5. Mouse position (2 dims): [x, y]
    mouse = state_dict['mouse']
    state.extend([float(mouse[0]), float(mouse[1])])

    # 6. Cat positions (12 dims): 6 cats * [x, y], unused=-1
    cat_list = state_dict.get('cat', [])
    for i in range(6):
        if i < len(cat_list):
            state.extend([float(cat_list[i][0]), float(cat_list[i][1])])
        else:
            state.extend([-1.0, -1.0])

    # 7. Moving Big Cheese positions (10 dims): 5 * [x, y], unused=-1
    bc_list = state_dict.get('crzbc', [])
    for i in range(5):
        if i < len(bc_list):
            state.extend([float(bc_list[i][0]), float(bc_list[i][1])])
        else:
            state.extend([-1.0, -1.0])

    # Pad to 549 dims (484 + 65)
    while len(state) < 484 + 65:
        state.append(0.0)

    # --- Scalar features (6 dims) ---

    # 8. Score (normalized by 1000, scaled)
    state.append(state_dict.get('score', 0) / 1000.0 * DYNAMIC_SCALE)

    # 9. Life (normalized by 3, scaled) - starts at 3
    state.append(state_dict.get('life', 3) * DYNAMIC_SCALE / 3.0)

    # 10. Current run number (normalized by 20, scaled)
    state.append(state_dict.get('run', 0) * DYNAMIC_SCALE / 20.0)

    # 11. Win flag (DYNAMIC_SCALE if won, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('win_sign', False) else 0.0)

    # 12. Lose flag (DYNAMIC_SCALE if lost, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('lose_sign', False) else 0.0)

    # 13. Step progress (current_step / step_limit, scaled)
    step = state_dict.get('step', 0)
    step_limit = state_dict.get('step_limit', 200)
    state.append(step / step_limit * DYNAMIC_SCALE if step_limit > 0 else 0.0)

    # Pad to 828 dims
    while len(state) < 828:
        state.append(0.0)

    return torch.tensor(state[:828], dtype=torch.float32)
```

### State Vector Layout Summary
| Range | Dims | Content | Scale |
|-------|------|---------|-------|
| 0-120 | 121 | Wall grid (11x11) | 1.0 |
| 121-241 | 121 | Small Cheese grid | 10.0 |
| 242-362 | 121 | Junction grid | 1.0 |
| 363-483 | 121 | Dead-end grid | 1.0 |
| 484-485 | 2 | Mouse position [x,y] | 1.0 |
| 486-497 | 12 | Cat positions (6 cats) | 1.0 |
| 498-507 | 10 | Big Cheese positions (5) | 1.0 |
| 508-548 | 41 | Padding (zeros) | - |
| 549 | 1 | Score / 1000 * 10 | 10.0 |
| 550 | 1 | Life / 3 * 10 | 10.0 |
| 551 | 1 | Run / 20 * 10 | 10.0 |
| 552 | 1 | Win flag | 10.0 |
| 553 | 1 | Lose flag | 10.0 |
| 554 | 1 | Step progress | 10.0 |
| 555-827 | 273 | Padding (zeros) | - |
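
For debugging, the layout table can be turned into named slices. The sketch below is an illustration only: `STATE_LAYOUT`, `split_state`, and `decode_scalars` are hypothetical names, the ranges are copied from the table as half-open intervals, and the scalar decoding simply inverts the scaling used in `get_state_vector`.

```python
# (name, start, stop) with stop exclusive, copied from the layout table.
STATE_LAYOUT = [
    ("wall", 0, 121),
    ("small_cheese", 121, 242),
    ("junction", 242, 363),
    ("dead_end", 363, 484),
    ("mouse", 484, 486),
    ("cats", 486, 498),
    ("big_cheese", 498, 508),
    ("padding_a", 508, 549),
    ("score", 549, 550),
    ("life", 550, 551),
    ("run", 551, 552),
    ("win", 552, 553),
    ("lose", 553, 554),
    ("step_progress", 554, 555),
    ("padding_b", 555, 828),
]


def split_state(vec):
    """Split an 828-element sequence into named segments."""
    if len(vec) != 828:
        raise ValueError(f"expected 828 values, got {len(vec)}")
    return {name: vec[start:stop] for name, start, stop in STATE_LAYOUT}


def decode_scalars(vec):
    """Undo the DYNAMIC_SCALE (10.0) normalization of the scalar features."""
    parts = split_state(vec)
    return {
        "score": parts["score"][0] / 10.0 * 1000.0,
        "life": parts["life"][0] / 10.0 * 3.0,
        "run": parts["run"][0] / 10.0 * 20.0,
        "won": parts["win"][0] > 0,
        "lost": parts["lose"][0] > 0,
        "step_progress": parts["step_progress"][0] / 10.0,
    }


example = [0.0] * 828
example[549] = 2.5   # score 250
example[550] = 10.0  # 3 lives
print(decode_scalars(example))
```

The function accepts a plain list; a 1-D tensor can be passed as `state.tolist()`.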

## Game Rules (Level 3)

### Map
- 11x11 grid maze with walls
- Fixed wall layout for level 3

### Entities
- **Mouse**: Player-controlled, starts at position [10, 10]
- **Cat 0 (Dummy)**: Starts at [2, 2], moves only during command execution (len(command) steps)
- **Cat 1 (Naughty)**: Starts at [5, 5], moves every mouse step
- **Small Cheese (SC)**: 75 stationary items, +10 points each
- **Stationary Big Cheese (movbc)**: 2 items, +500 points each, don't move
- **Moving Big Cheese (crzbc)**: 2 items, +500 points each, move each step

### Cat Movement (Random Mode)
Cats move randomly at junctions (no turning back), continue straight in corridors, and pick a random direction when blocked. This is the `_get_cats_direct_actions` mode in the simulator.

### Scoring
| Event | Points |
|-------|--------|
| Collect Small Cheese | +10 |
| Collect Big Cheese | +500 |
| Hit Wall | -10 |
| Caught by Cat | -500 (+ lose 1 life) |
| Win Bonus | +(run * 10 + step) |
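
As a worked example of the scoring table, the sketch below tallies a score from event counts. `tally_score` is a hypothetical helper, not part of the simulator, and it assumes `run` and `step` in the win bonus are the run number and step count at the moment of winning.

```python
def tally_score(small_cheese=0, big_cheese=0, wall_hits=0,
                times_caught=0, won=False, run=0, step=0):
    """Apply the point values from the scoring table to event counts."""
    score = (10 * small_cheese
             + 500 * big_cheese
             - 10 * wall_hits
             - 500 * times_caught)
    if won:
        # Win bonus as listed in the table: run * 10 + step
        score += run * 10 + step
    return score


# All 75 small cheese plus all 4 big cheese (2 stationary + 2 moving),
# with no wall hits or catches and no win bonus:
print(tally_score(small_cheese=75, big_cheese=4))
```

With those counts the cheese alone is worth 75 * 10 + 4 * 500 = 2750 points.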

### Win/Lose Conditions
- **WIN**: Collect ALL 75 Small Cheese + END token executed
- **LOSE (life)**: Life reaches 0 (caught 3 times)
- **LOSE (step)**: Step count reaches 200
- **LOSE (run)**: 20 runs exhausted without winning

### Game Flow
1. Game starts with mouse at [10,10], 3 lives, 20 max runs
2. Each run: model generates a program -> program executes step by step
3. During execution: mouse moves, cats move randomly, cheese collected, collisions checked
4. After program ends: next run begins
5. Continue until WIN or LOSE

## Program Execution

When a program like `[0, 110, 106, 2, 3, 112]` executes:

1. Token `0` (UP): mouse moves up 1 step
2. Token `110, 106, 2` (LOOP 7 LEFT): mouse moves left 7 steps
3. Token `3` (RIGHT): mouse moves right 1 step
4. Token `112` (END): program ends
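
The expansion above can be sketched as a small decoder that turns a token list into individual moves. `expand_program` is a hypothetical helper, not the simulator's `execute_program`; it assumes the repeat count is `token_id - 99`, which matches every row of the Number Tokens table.

```python
DIR_NAMES = {0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}
LOOP, END = 110, 112


def expand_program(tokens):
    """Expand a token program into a flat list of direction names."""
    steps = []
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok == END:
            break
        if tok == LOOP:
            if i + 2 >= len(tokens):
                raise ValueError("LOOP must be followed by NUM and DIR")
            num_tok, dir_tok = tokens[i + 1], tokens[i + 2]
            if not 100 <= num_tok <= 109 or dir_tok not in DIR_NAMES:
                raise ValueError(f"malformed LOOP at position {i}")
            count = num_tok - 99  # assumption: NUM token 100 -> 1, 109 -> 10
            steps.extend([DIR_NAMES[dir_tok]] * count)
            i += 3
        elif tok in DIR_NAMES:
            steps.append(DIR_NAMES[tok])
            i += 1
        else:
            raise ValueError(f"unexpected token {tok} at position {i}")
    return steps


moves = expand_program([0, 110, 106, 2, 3, 112])
print(len(moves), moves)
```

For the example program this yields 9 moves: one UP, seven LEFT, one RIGHT.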

Each step:
- Mouse attempts to move in the direction
- If wall: mouse stays, -10 points
- Cat 1 moves (random at junctions)
- Cat 0 moves (only during command-length steps)
- Check for cat collision: -500 points, lose 1 life, respawn at [10,10]
- Check for cheese collection: +10 (SC) or +500 (BC)
- Check win/lose conditions

## Performance

| Metric | Value |
|--------|-------|
| Win Rate (temp=0.3, 100 games) | 30% |
| Average Score | 1437 |
| Average Runs per Win | 13.8 |
| Simulator | New simulator (random cats) |

### Training Pipeline
1. **Base Model**: Expert R1 checkpoint (trained on old simulator, 95% win rate on old sim, 14% on new sim)
2. **RM32 Data Generation**: 10,000 games with Running Max 32 (exhaustive 33 candidates), 20.4% win rate, 30,788 winning run samples
3. **SFT Training**: 40 epochs, batch 4096, lr 3e-5, cosine schedule -> 30% win rate

## Generation Parameters

| Parameter | Recommended | Description |
|-----------|-------------|-------------|
| temperature | 0.3 | Lower = more deterministic, higher win rate |
| top_k | 10 | Top-k sampling |
| grammar_constrained | True | MUST be True to generate valid programs |
| max_length | 12 | Maximum program length |
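
To illustrate how these parameters interact, here is a pure-Python sketch of temperature-scaled top-k sampling with an optional set of grammar-allowed tokens. It is not the code inside `model.generate`; the function `sample_top_k` and the order of operations (mask, then top-k, then temperature) are assumptions made for illustration.

```python
import math
import random


def sample_top_k(logits, temperature=0.3, top_k=10, allowed=None, rng=random):
    """Sample one token ID from logits using top-k and temperature."""
    # Keep only grammar-allowed token IDs, if a mask is given.
    candidates = [(i, x) for i, x in enumerate(logits)
                  if allowed is None or i in allowed]
    if not candidates:
        raise ValueError("no candidate tokens to sample from")
    # Keep the top_k highest-scoring candidates.
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    candidates = candidates[:top_k]
    # Softmax weights with temperature; subtract the max for stability.
    max_logit = candidates[0][1]
    weights = [math.exp((x - max_logit) / temperature) for _, x in candidates]
    ids = [i for i, _ in candidates]
    return rng.choices(ids, weights=weights, k=1)[0]


logits = [0.0] * 113
logits[2] = 5.0    # LEFT scores highest among allowed tokens
logits[50] = 9.0   # a token outside the grammar
allowed = {0, 1, 2, 3, 110, 112}
print(sample_top_k(logits, top_k=1, allowed=allowed))
```

Lowering the temperature sharpens the distribution toward the highest-scoring token, which is why lower values behave more deterministically.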

## File Structure

```
hardai_model_export/
  model_best.pt              # Model checkpoint (886MB)
  README.md                  # This file
  lightweight_simulator.py   # Game simulator
  model/                     # Model architecture
    __init__.py
    model_2B.py              # Main model class
    state_encoder.py
    program_embedding.py
    transformer.py           # Flash Attention + gradient checkpointing
    multi_task_head.py
    memory_encoder.py
    memory_state_fusion.py
    value_predictor.py
```

## Requirements

```
torch >= 2.0
numpy
pygame (for simulator, can run headless with SDL_VIDEODRIVER=dummy)
```

## Headless Mode (No Display)

```python
import os
os.environ['SDL_VIDEODRIVER'] = 'dummy'
os.environ['SDL_AUDIODRIVER'] = 'dummy'
```

Set these BEFORE importing the simulator.