RayMelius and Claude Opus 4.6 committed
Commit bea6321 · 1 Parent(s): 7e43568

Diversify agent activities, add local training script, scheduled Gemini cycle, architecture diagram

- routine.py: Add morning/post-work exercise, park lunch, sports_field for more
personality types, varied activity descriptions, park as exercise destination
- nn_train.py: Standalone local training script (equivalent to Kaggle notebook)
with synthetic data gen, ONNX export, HF push, CUDA support
- nn_selfimprove.py: Add 'scheduled' mode for nightly Gemini collection + retrain,
oversample LLM-sourced samples 3x during training
- routes.py: Fix valid providers list (nn instead of hf)
- docs/architecture.html: SVG diagram of SociAgentTransformer architecture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

docs/architecture.html ADDED
@@ -0,0 +1,377 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+ <meta charset="UTF-8">
+ <title>SociAgentTransformer Architecture</title>
+ <style>
+ body {
+ margin: 0;
+ background: #0d1117;
+ display: flex;
+ justify-content: center;
+ align-items: flex-start;
+ min-height: 100vh;
+ font-family: 'Segoe UI', system-ui, -apple-system, sans-serif;
+ padding: 40px 20px;
+ }
+ svg {
+ filter: drop-shadow(0 4px 24px rgba(0,0,0,0.4));
+ }
+ .title {
+ font-size: 22px;
+ font-weight: 700;
+ fill: #e6edf3;
+ letter-spacing: 0.5px;
+ }
+ .subtitle {
+ font-size: 12px;
+ fill: #8b949e;
+ font-weight: 400;
+ }
+ .box-label {
+ font-size: 11px;
+ font-weight: 600;
+ fill: #e6edf3;
+ }
+ .box-detail {
+ font-size: 9.5px;
+ fill: #8b949e;
+ }
+ .box-dim {
+ font-size: 9px;
+ fill: #58a6ff;
+ font-weight: 600;
+ font-family: 'Cascadia Code', 'Consolas', monospace;
+ }
+ .group-label {
+ font-size: 9px;
+ font-weight: 600;
+ fill: #e6edf3;
+ }
+ .group-dim {
+ font-size: 8px;
+ fill: #8b949e;
+ font-family: 'Cascadia Code', 'Consolas', monospace;
+ }
+ .section-label {
+ font-size: 10px;
+ font-weight: 700;
+ fill: #8b949e;
+ letter-spacing: 1.5px;
+ text-transform: uppercase;
+ }
+ .arrow {
+ stroke: #30363d;
+ stroke-width: 2;
+ fill: none;
+ marker-end: url(#arrowhead);
+ }
+ .arrow-action {
+ stroke: #f0883e;
+ stroke-width: 1.5;
+ fill: none;
+ stroke-dasharray: 4 3;
+ marker-end: url(#arrowhead-orange);
+ }
+ .brace-text {
+ font-size: 9px;
+ fill: #8b949e;
+ font-style: italic;
+ }
+ .repeat-badge {
+ font-size: 9px;
+ font-weight: 700;
+ fill: #f0883e;
+ }
+ .param-text {
+ font-size: 8.5px;
+ fill: #7ee787;
+ font-family: 'Cascadia Code', 'Consolas', monospace;
+ }
+ </style>
+ </head>
+ <body>
+ <svg xmlns="http://www.w3.org/2000/svg" width="720" height="1280" viewBox="0 0 720 1280">
+ <defs>
+ <!-- Rounded rect filter for glow -->
+ <filter id="glow">
+ <feGaussianBlur stdDeviation="2" result="blur"/>
+ <feMerge><feMergeNode in="blur"/><feMergeNode in="SourceGraphic"/></feMerge>
+ </filter>
+
+ <!-- Arrow markers -->
+ <marker id="arrowhead" markerWidth="8" markerHeight="6" refX="8" refY="3" orient="auto">
+ <polygon points="0 0, 8 3, 0 6" fill="#30363d"/>
+ </marker>
+ <marker id="arrowhead-orange" markerWidth="8" markerHeight="6" refX="8" refY="3" orient="auto">
+ <polygon points="0 0, 8 3, 0 6" fill="#f0883e"/>
+ </marker>
+
+ <!-- Gradient backgrounds -->
+ <linearGradient id="grad-input" x1="0" y1="0" x2="0" y2="1">
+ <stop offset="0%" stop-color="#1a2332"/>
+ <stop offset="100%" stop-color="#161b22"/>
+ </linearGradient>
+ <linearGradient id="grad-tokenizer" x1="0" y1="0" x2="0" y2="1">
+ <stop offset="0%" stop-color="#1c2d1e"/>
+ <stop offset="100%" stop-color="#161b22"/>
+ </linearGradient>
+ <linearGradient id="grad-transformer" x1="0" y1="0" x2="0" y2="1">
+ <stop offset="0%" stop-color="#2d1f32"/>
+ <stop offset="100%" stop-color="#161b22"/>
+ </linearGradient>
+ <linearGradient id="grad-cls" x1="0" y1="0" x2="0" y2="1">
+ <stop offset="0%" stop-color="#2d2a1f"/>
+ <stop offset="100%" stop-color="#161b22"/>
+ </linearGradient>
+ <linearGradient id="grad-head-action" x1="0" y1="0" x2="0" y2="1">
+ <stop offset="0%" stop-color="#2d1f1f"/>
+ <stop offset="100%" stop-color="#1a1515"/>
+ </linearGradient>
+ <linearGradient id="grad-head-loc" x1="0" y1="0" x2="0" y2="1">
+ <stop offset="0%" stop-color="#1f2a2d"/>
+ <stop offset="100%" stop-color="#151a1a"/>
+ </linearGradient>
+ <linearGradient id="grad-head-dur" x1="0" y1="0" x2="0" y2="1">
+ <stop offset="0%" stop-color="#2d2d1f"/>
+ <stop offset="100%" stop-color="#1a1a15"/>
+ </linearGradient>
+ </defs>
+
+ <!-- Background -->
+ <rect width="720" height="1280" rx="16" fill="#0d1117" stroke="#21262d" stroke-width="1"/>
+
+ <!-- Title -->
+ <text x="360" y="38" text-anchor="middle" class="title">SociAgentTransformer</text>
+ <text x="360" y="56" text-anchor="middle" class="subtitle">Transformer + Mixture-of-Experts for Agent Decision Making</text>
+ <text x="360" y="72" text-anchor="middle" class="param-text">1.45M params | ~5.5 MB (fp32) | ~1ms inference (50 agents, ONNX)</text>
+
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <!-- INPUT LAYER -->
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <text x="36" y="108" class="section-label">Input</text>
+
+ <rect x="110" y="92" width="500" height="44" rx="8" fill="url(#grad-input)" stroke="#1f6feb" stroke-width="1.5"/>
+ <text x="360" y="112" text-anchor="middle" class="box-label">Agent State Feature Vector</text>
+ <text x="360" y="126" text-anchor="middle" class="box-dim">(B, 47)</text>
+
+ <!-- Arrow down -->
+ <line x1="360" y1="136" x2="360" y2="158" class="arrow"/>
+
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <!-- FEATURE TOKENIZER -->
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <text x="36" y="178" class="section-label">Tokenizer</text>
+
+ <rect x="60" y="162" width="600" height="155" rx="10" fill="none" stroke="#238636" stroke-width="1.5" stroke-dasharray="5 3"/>
+ <text x="360" y="182" text-anchor="middle" class="box-label">Feature Tokenizer</text>
+ <text x="360" y="194" text-anchor="middle" class="box-detail">Split features into 6 semantic groups, project each to d_model</text>
+
+ <!-- 6 Feature group boxes -->
+ <!-- Row 1 -->
+ <rect x="80" y="206" width="165" height="44" rx="6" fill="url(#grad-tokenizer)" stroke="#238636" stroke-width="1"/>
+ <text x="162" y="222" text-anchor="middle" class="group-label">Personality</text>
+ <text x="162" y="236" text-anchor="middle" class="group-dim">[0:6] Big5 + Age</text>
+ <text x="162" y="246" text-anchor="middle" class="box-dim">6 -> 128</text>
+
+ <rect x="277" y="206" width="165" height="44" rx="6" fill="url(#grad-tokenizer)" stroke="#238636" stroke-width="1"/>
+ <text x="360" y="222" text-anchor="middle" class="group-label">Time</text>
+ <text x="360" y="236" text-anchor="middle" class="group-dim">[6:12] sin/cos + day</text>
+ <text x="360" y="246" text-anchor="middle" class="box-dim">6 -> 128</text>
+
+ <rect x="474" y="206" width="165" height="44" rx="6" fill="url(#grad-tokenizer)" stroke="#238636" stroke-width="1"/>
+ <text x="557" y="222" text-anchor="middle" class="group-label">Needs + Mood</text>
+ <text x="557" y="236" text-anchor="middle" class="group-dim">[12:21] 6 needs + urgency</text>
+ <text x="557" y="246" text-anchor="middle" class="box-dim">9 -> 128</text>
+
+ <!-- Row 2 -->
+ <rect x="80" y="258" width="165" height="44" rx="6" fill="url(#grad-tokenizer)" stroke="#238636" stroke-width="1"/>
+ <text x="162" y="274" text-anchor="middle" class="group-label">Location</text>
+ <text x="162" y="288" text-anchor="middle" class="group-dim">[21:31] zone + flags + people</text>
+ <text x="162" y="298" text-anchor="middle" class="box-dim">10 -> 128</text>
+
+ <rect x="277" y="258" width="165" height="44" rx="6" fill="url(#grad-tokenizer)" stroke="#238636" stroke-width="1"/>
+ <text x="360" y="274" text-anchor="middle" class="group-label">Time Period</text>
+ <text x="360" y="288" text-anchor="middle" class="group-dim">[31:38] 7-class one-hot</text>
+ <text x="360" y="298" text-anchor="middle" class="box-dim">7 -> 128</text>
+
+ <rect x="474" y="258" width="165" height="44" rx="6" fill="url(#grad-tokenizer)" stroke="#238636" stroke-width="1"/>
+ <text x="557" y="274" text-anchor="middle" class="group-label">Last Action</text>
+ <text x="557" y="288" text-anchor="middle" class="group-dim">[38:47] 9-class one-hot</text>
+ <text x="557" y="298" text-anchor="middle" class="box-dim">9 -> 128</text>
+
+ <!-- Plus positional embeddings note -->
+ <text x="360" y="316" text-anchor="middle" class="brace-text">+ learnable positional embeddings per token</text>
+
+ <!-- Output shape from tokenizer -->
+ <text x="360" y="330" text-anchor="middle" class="box-dim">(B, 6, 128)</text>
+
+ <!-- Arrow down -->
+ <line x1="360" y1="335" x2="360" y2="362" class="arrow"/>
+
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <!-- TRANSFORMER ENCODER (x4) -->
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <text x="36" y="382" class="section-label">Encoder</text>
+
+ <!-- Repeat bracket -->
+ <rect x="60" y="366" width="600" height="310" rx="10" fill="none" stroke="#8b5cf6" stroke-width="1.5" stroke-dasharray="5 3"/>
+ <rect x="600" y="366" width="56" height="22" rx="6" fill="#8b5cf6" fill-opacity="0.2" stroke="#8b5cf6" stroke-width="1"/>
+ <text x="628" y="381" text-anchor="middle" class="repeat-badge">x 4</text>
+
+ <text x="360" y="386" text-anchor="middle" class="box-label">Transformer Encoder Block</text>
+
+ <!-- Multi-Head Self Attention -->
+ <rect x="130" y="396" width="460" height="52" rx="8" fill="url(#grad-transformer)" stroke="#8b5cf6" stroke-width="1.2"/>
+ <text x="360" y="416" text-anchor="middle" class="box-label">Multi-Head Self-Attention</text>
+ <text x="360" y="430" text-anchor="middle" class="box-detail">8 heads, d_k=16, batch_first=True</text>
+ <text x="360" y="442" text-anchor="middle" class="param-text">Q, K, V: (B, 6, 128) -> (B, 6, 128)</text>
+
+ <!-- Residual + LayerNorm -->
+ <rect x="220" y="454" width="280" height="24" rx="6" fill="#161b22" stroke="#30363d" stroke-width="1"/>
+ <text x="360" y="470" text-anchor="middle" class="box-detail">Add &amp; LayerNorm</text>
+
+ <!-- Arrow -->
+ <line x1="360" y1="478" x2="360" y2="496" class="arrow"/>
+
+ <!-- MoE Feed-Forward -->
+ <rect x="130" y="498" width="460" height="130" rx="8" fill="url(#grad-transformer)" stroke="#8b5cf6" stroke-width="1.2"/>
+ <text x="360" y="518" text-anchor="middle" class="box-label">Mixture-of-Experts Feed-Forward</text>
+ <text x="360" y="532" text-anchor="middle" class="box-detail">4 experts, top-2 routing, gated softmax</text>
+
+ <!-- 4 Expert boxes inside -->
+ <rect x="155" y="544" width="95" height="36" rx="5" fill="#1c1c2e" stroke="#6e40c9" stroke-width="1"/>
+ <text x="202" y="558" text-anchor="middle" class="group-label">Expert 0</text>
+ <text x="202" y="572" text-anchor="middle" class="group-dim">128->256->128</text>
+
+ <rect x="263" y="544" width="95" height="36" rx="5" fill="#1c1c2e" stroke="#6e40c9" stroke-width="1"/>
+ <text x="310" y="558" text-anchor="middle" class="group-label">Expert 1</text>
+ <text x="310" y="572" text-anchor="middle" class="group-dim">128->256->128</text>
+
+ <rect x="371" y="544" width="95" height="36" rx="5" fill="#1c1c2e" stroke="#6e40c9" stroke-width="1"/>
+ <text x="418" y="558" text-anchor="middle" class="group-label">Expert 2</text>
+ <text x="418" y="572" text-anchor="middle" class="group-dim">128->256->128</text>
+
+ <rect x="479" y="544" width="95" height="36" rx="5" fill="#1c1c2e" stroke="#6e40c9" stroke-width="1"/>
+ <text x="526" y="558" text-anchor="middle" class="group-label">Expert 3</text>
+ <text x="526" y="572" text-anchor="middle" class="group-dim">128->256->128</text>
+
+ <!-- Gate -->
+ <rect x="260" y="590" width="200" height="26" rx="5" fill="#1c1c2e" stroke="#f0883e" stroke-width="1"/>
+ <text x="360" y="607" text-anchor="middle" class="group-label" style="fill:#f0883e">Gate: Linear(128, 4) -> top-2</text>
+
+ <!-- Residual + LayerNorm -->
+ <rect x="220" y="634" width="280" height="24" rx="6" fill="#161b22" stroke="#30363d" stroke-width="1"/>
+ <text x="360" y="650" text-anchor="middle" class="box-detail">Add &amp; LayerNorm</text>
+
+ <!-- Output shape -->
+ <text x="360" y="680" text-anchor="middle" class="box-dim">(B, 6, 128)</text>
+
+ <!-- Arrow down -->
+ <line x1="360" y1="685" x2="360" y2="710" class="arrow"/>
+
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <!-- CLS AGGREGATION -->
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <text x="36" y="735" class="section-label">Pooling</text>
+
+ <rect x="110" y="716" width="500" height="90" rx="8" fill="url(#grad-cls)" stroke="#d29922" stroke-width="1.5"/>
+ <text x="360" y="738" text-anchor="middle" class="box-label">[CLS] Query Aggregation</text>
+ <text x="360" y="754" text-anchor="middle" class="box-detail">Learned query (1, 1, 128) attends to all 6 tokens via cross-attention</text>
+ <text x="360" y="770" text-anchor="middle" class="param-text">cls_query -> cross_attn(Q=cls, K=tokens, V=tokens) -> LayerNorm</text>
+ <text x="360" y="788" text-anchor="middle" class="box-dim">h: (B, 128)</text>
+
+ <!-- Arrow splits into 3 -->
+ <line x1="360" y1="806" x2="360" y2="830" class="arrow"/>
+
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <!-- TASK HEADS -->
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <text x="36" y="860" class="section-label">Task Heads</text>
+
+ <!-- Horizontal split line -->
+ <line x1="160" y1="840" x2="560" y2="840" stroke="#30363d" stroke-width="1"/>
+
+ <!-- Three vertical arrows from split -->
+ <line x1="180" y1="840" x2="180" y2="868" class="arrow"/>
+ <line x1="360" y1="840" x2="360" y2="868" class="arrow"/>
+ <line x1="540" y1="840" x2="540" y2="868" class="arrow"/>
+
+ <!-- ACTION HEAD -->
+ <rect x="80" y="870" width="200" height="110" rx="8" fill="url(#grad-head-action)" stroke="#f85149" stroke-width="1.5"/>
+ <text x="180" y="892" text-anchor="middle" class="box-label" style="fill:#f85149">Action Head</text>
+ <text x="180" y="908" text-anchor="middle" class="box-detail">2-layer MLP</text>
+ <text x="180" y="926" text-anchor="middle" class="param-text">Linear(128, 128)</text>
+ <text x="180" y="938" text-anchor="middle" class="param-text">GELU + Dropout(0.1)</text>
+ <text x="180" y="950" text-anchor="middle" class="param-text">Linear(128, 9)</text>
+ <text x="180" y="972" text-anchor="middle" class="box-dim">(B, 9) logits</text>
+
+ <!-- LOCATION HEAD -->
+ <rect x="300" y="870" width="200" height="110" rx="8" fill="url(#grad-head-loc)" stroke="#58a6ff" stroke-width="1.5"/>
+ <text x="400" y="892" text-anchor="middle" class="box-label" style="fill:#58a6ff">Location Head</text>
+ <text x="400" y="908" text-anchor="middle" class="box-detail">Action-conditioned MLP</text>
+ <text x="400" y="926" text-anchor="middle" class="param-text">Linear(128+9, 128)</text>
+ <text x="400" y="938" text-anchor="middle" class="param-text">GELU + Dropout(0.1)</text>
+ <text x="400" y="950" text-anchor="middle" class="param-text">Linear(128, 38)</text>
+ <text x="400" y="972" text-anchor="middle" class="box-dim">(B, 38) logits</text>
+
+ <!-- DURATION HEAD -->
+ <rect x="520" y="870" width="140" height="110" rx="8" fill="url(#grad-head-dur)" stroke="#d29922" stroke-width="1.5"/>
+ <text x="590" y="892" text-anchor="middle" class="box-label" style="fill:#d29922">Duration Head</text>
+ <text x="590" y="908" text-anchor="middle" class="box-detail">Regression MLP</text>
+ <text x="590" y="926" text-anchor="middle" class="param-text">Linear(137, 64)</text>
+ <text x="590" y="938" text-anchor="middle" class="param-text">GELU</text>
+ <text x="590" y="950" text-anchor="middle" class="param-text">Linear(64, 1)</text>
+ <text x="590" y="972" text-anchor="middle" class="box-dim">sigmoid*7+1</text>
+
+ <!-- Action probs feedback arrows -->
+ <path d="M 180 980 L 180 1000 L 320 1000 L 320 920 L 300 920" class="arrow-action"/>
+ <path d="M 180 980 L 180 1010 L 540 1010 L 540 920 L 520 920" class="arrow-action"/>
+ <text x="250" y="996" class="brace-text" style="fill:#f0883e">softmax(action).detach()</text>
+
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <!-- OUTPUT -->
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <text x="36" y="1060" class="section-label">Output</text>
+
+ <!-- Three output arrows -->
+ <line x1="180" y1="980" x2="180" y2="1068" class="arrow"/>
+ <line x1="400" y1="980" x2="400" y2="1068" class="arrow"/>
+ <line x1="590" y1="980" x2="590" y2="1068" class="arrow"/>
+
+ <!-- Output boxes -->
+ <rect x="95" y="1070" width="170" height="52" rx="8" fill="#1a1515" stroke="#f85149" stroke-width="1.2"/>
+ <text x="180" y="1090" text-anchor="middle" class="group-label" style="fill:#f85149">Action Type</text>
+ <text x="180" y="1104" text-anchor="middle" class="group-dim">9 classes: move, work,</text>
+ <text x="180" y="1114" text-anchor="middle" class="group-dim">eat, sleep, talk, ...</text>
+
+ <rect x="315" y="1070" width="170" height="52" rx="8" fill="#151a1a" stroke="#58a6ff" stroke-width="1.2"/>
+ <text x="400" y="1090" text-anchor="middle" class="group-label" style="fill:#58a6ff">Target Location</text>
+ <text x="400" y="1104" text-anchor="middle" class="group-dim">38 locations: cafe,</text>
+ <text x="400" y="1114" text-anchor="middle" class="group-dim">park, office, home, ...</text>
+
+ <rect x="520" y="1070" width="140" height="52" rx="8" fill="#1a1a15" stroke="#d29922" stroke-width="1.2"/>
+ <text x="590" y="1090" text-anchor="middle" class="group-label" style="fill:#d29922">Duration</text>
+ <text x="590" y="1104" text-anchor="middle" class="group-dim">1-8 ticks</text>
+ <text x="590" y="1114" text-anchor="middle" class="group-dim">(15 min each)</text>
+
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <!-- LOSS SECTION -->
+ <!-- ═══════════════════════════════════════════════════════ -->
+ <text x="36" y="1160" class="section-label">Training</text>
+
+ <rect x="80" y="1146" width="560" height="52" rx="8" fill="#161b22" stroke="#30363d" stroke-width="1"/>
+ <text x="360" y="1168" text-anchor="middle" class="box-label">Multi-Task Loss</text>
+ <text x="360" y="1184" text-anchor="middle" class="param-text">L = 1.0*CE_action(weighted) + 0.5*CE_location + 0.2*MSE_duration</text>
+
+ <rect x="80" y="1206" width="560" height="34" rx="8" fill="#161b22" stroke="#30363d" stroke-width="1"/>
+ <text x="360" y="1224" text-anchor="middle" class="box-detail">
+ AdamW (lr=3e-4, wd=1e-4) | CosineAnnealing | Grad clip=1.0 | 30 epochs | Batch=512
+ </text>
+
+ <!-- Footer -->
+ <text x="360" y="1268" text-anchor="middle" class="subtitle">ONNX export with opset 17 | CPU inference ~1ms for 50 agents</text>
+
+ </svg>
+ </body>
+ </html>
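The Mixture-of-Experts feed-forward shown in the encoder block (4 experts of 128 -> 256 -> 128, a Linear(128, 4) gate, top-2 routing with a gated softmax) can be sketched in NumPy. This is a minimal illustration under the diagram's dimensions, not the repo's actual implementation; random weights stand in for trained ones, and ReLU stands in for the GELU used elsewhere in the model:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff, n_experts = 128, 256, 4

# One weight pair per expert: 128 -> 256 -> 128, as in the diagram
W1 = rng.normal(0, 0.02, (n_experts, d_model, d_ff))
W2 = rng.normal(0, 0.02, (n_experts, d_ff, d_model))
Wg = rng.normal(0, 0.02, (d_model, n_experts))  # Gate: Linear(128, 4)

def moe_ffn(x: np.ndarray) -> np.ndarray:
    """x: (tokens, d_model) -> (tokens, d_model) via a top-2 expert mixture."""
    logits = x @ Wg                             # (tokens, 4) gate logits
    top2 = np.argsort(logits, axis=-1)[:, -2:]  # indices of the 2 best experts
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        sel = logits[t, top2[t]]
        gates = np.exp(sel - sel.max())
        gates /= gates.sum()                    # softmax over the selected 2
        for g, e in zip(gates, top2[t]):
            h = np.maximum(x[t] @ W1[e], 0.0)   # expert FFN (ReLU stand-in)
            out[t] += g * (h @ W2[e])
    return out

tokens = rng.normal(size=(6, d_model))          # 6 feature tokens per agent
y = moe_ffn(tokens)
print(y.shape)  # (6, 128)
```

Only 2 of the 4 expert FFNs run per token, which is why the MoE layer adds capacity without a matching increase in per-token compute.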
scripts/nn_selfimprove.py CHANGED
@@ -309,19 +309,30 @@ def train(epochs: int = 20, batch_size: int = 512, lr: float = 3e-4):

      # ── Load collected data ──────────────────────────────────────────
      collected = []
+     source_counts: dict[str, int] = {}
      if SAMPLES_FILE.exists():
          with open(SAMPLES_FILE) as f:
              for line in f:
                  line = line.strip()
                  if line:
-                     collected.append(json.loads(line))
-         logger.info(f"Loaded {len(collected):,} collected samples")
+                     sample = json.loads(line)
+                     collected.append(sample)
+                     src = sample.get("source", "unknown")
+                     source_counts[src] = source_counts.get(src, 0) + 1
+         logger.info(f"Loaded {len(collected):,} collected samples — sources: {source_counts}")
      else:
          logger.warning(f"No collected samples at {SAMPLES_FILE}")

+     # Oversample LLM-sourced data (Gemini/Claude/Groq) — these are higher quality
+     # than NN or routine-generated samples, so we duplicate them 3x
+     llm_sources = {"gemini", "claude", "groq"}
+     llm_samples = [s for s in collected if s.get("source", "") in llm_sources]
+     if llm_samples:
+         logger.info(f"Oversampling {len(llm_samples):,} LLM-sourced samples (3x weight)")
+         collected.extend(llm_samples * 2)  # 2 extra copies = 3x total weight
+
      if len(collected) < 100:
          logger.warning("Too few collected samples — generating synthetic data to supplement")
-         # Import synthetic generator from the notebook's logic (inline here)
          collected.extend(_generate_synthetic(50_000 - len(collected)))

      # ── Dataset ──────────────────────────────────────────────────────
@@ -718,18 +729,151 @@ def _generate_synthetic(n: int) -> list[dict]:
      return data


+ # ════════════════════════════════════════════════════════════════════════
+ # STEP 4: SCHEDULED — Nightly Gemini collection + retrain cycle
+ # ════════════════════════════════════════════════════════════════════════
+
+ async def scheduled(
+     base_url: str = "https://raymelius-soci2.hf.space",
+     collect_minutes: int = 120,
+     epochs: int = 25,
+     repo_id: str = "RayMelius/soci-agent-nn",
+     gemini_prob: float = 0.50,
+ ):
+     """Nightly training cycle: switch to Gemini at midnight, collect, retrain, push.
+
+     Flow:
+       1. Wait until Gemini quota resets (midnight PT / configurable)
+       2. Switch live sim to Gemini provider, raise probability
+       3. Collect high-quality (state, action) samples from Gemini decisions
+       4. Switch back to NN when done (or when quota exhausted)
+       5. Train on collected Gemini samples (weighted 3x vs NN/routine samples)
+       6. Push improved model to HF Hub
+       7. Repeat next night
+
+     Usage:
+         python nn_selfimprove.py scheduled --collect-minutes 120 --gemini-prob 0.50
+     """
+     import datetime
+
+     async def _api_call(client: httpx.AsyncClient, method: str, path: str, **kwargs):
+         """Make API call with retries."""
+         for attempt in range(3):
+             try:
+                 resp = await getattr(client, method)(path, timeout=30.0, **kwargs)
+                 return resp
+             except httpx.HTTPError as e:
+                 logger.warning(f"API {method.upper()} {path} attempt {attempt+1} failed: {e}")
+                 if attempt < 2:
+                     await asyncio.sleep(5)
+         return None
+
+     async def switch_provider(client: httpx.AsyncClient, provider: str, prob: float):
+         """Switch the live sim's LLM provider and probability."""
+         resp = await _api_call(client, "post", "/api/llm/provider",
+                                json={"provider": provider})
+         if resp and resp.status_code == 200:
+             logger.info(f"Switched provider to: {provider}")
+         else:
+             logger.error(f"Failed to switch to {provider}: {resp.status_code if resp else 'no response'}")
+             return False
+
+         resp = await _api_call(client, "post", "/api/llm/probability",
+                                json={"value": prob})
+         if resp and resp.status_code == 200:
+             logger.info(f"Set probability to: {prob:.0%}")
+         else:
+             logger.warning(f"Failed to set probability: {resp.status_code if resp else 'no response'}")
+
+         return True
+
+     async def wait_until_midnight():
+         """Wait until next midnight (local time) when Gemini quota resets."""
+         now = datetime.datetime.now()
+         tomorrow = now.replace(hour=0, minute=0, second=5, microsecond=0) + datetime.timedelta(days=1)
+         wait_secs = (tomorrow - now).total_seconds()
+         logger.info(f"Waiting {wait_secs/3600:.1f}h until midnight ({tomorrow.strftime('%Y-%m-%d %H:%M')})")
+         await asyncio.sleep(wait_secs)
+
+     # ── Main loop ─────────────────────────────────────────────────────
+     cycle = 0
+     while True:
+         cycle += 1
+         logger.info(f"{'='*60}")
+         logger.info(f"TRAINING CYCLE {cycle}")
+         logger.info(f"{'='*60}")
+
+         # 1. Wait for midnight (Gemini quota reset)
+         await wait_until_midnight()
+
+         async with httpx.AsyncClient(base_url=base_url) as client:
+             # 2. Switch to Gemini + raise probability
+             logger.info("Switching live sim to Gemini...")
+             ok = await switch_provider(client, "gemini", gemini_prob)
+             if not ok:
+                 logger.error("Could not switch to Gemini — skipping this cycle")
+                 continue
+
+         # 3. Collect samples from Gemini-powered sim
+         logger.info(f"Collecting for {collect_minutes} min with Gemini at {gemini_prob:.0%} probability...")
+
+         # collect() creates its own client
+         n_samples = await collect(
+             base_url=base_url,
+             duration_minutes=collect_minutes,
+             poll_interval=3.0,
+         )
+         logger.info(f"Collected {n_samples:,} samples this cycle")
+
+         # 4. Switch back to NN + restore default probability
+         async with httpx.AsyncClient(base_url=base_url) as client:
+             await switch_provider(client, "nn", 1.0)
+
+         # 5. Count Gemini-sourced samples
+         gemini_samples = 0
+         if SAMPLES_FILE.exists():
+             with open(SAMPLES_FILE) as f:
+                 for line in f:
+                     if '"source": "gemini"' in line or '"source":"gemini"' in line:
+                         gemini_samples += 1
+         logger.info(f"Total Gemini-sourced samples in file: {gemini_samples:,}")
+
+         if gemini_samples < 50:
+             logger.warning("Too few Gemini samples — skipping training this cycle")
+             continue
+
+         # 6. Train (Gemini samples get 3x weight in the training loop)
+         logger.info("Starting retraining...")
+         best_acc = train(epochs=epochs)
+         logger.info(f"Training done — best accuracy: {best_acc:.1%}")
+
+         # 7. Push improved model
+         if os.environ.get("HF_TOKEN"):
+             logger.info("Pushing improved model to HF Hub...")
+             push(repo_id=repo_id)
+         else:
+             logger.warning("HF_TOKEN not set — skipping push")
+
+         logger.info(f"Cycle {cycle} complete! Next cycle at midnight.")
+
+
  # ════════════════════════════════════════════════════════════════════════
  # CLI
  # ════════════════════════════════════════════════════════════════════════

  def main():
      parser = argparse.ArgumentParser(description="Soci Agent NN — Self-Improvement Pipeline")
-     parser.add_argument("mode", choices=["collect", "train", "push", "all"],
-                         help="collect=watch live sim, train=retrain NN, push=upload to HF, all=full pipeline")
+     parser.add_argument("mode", choices=["collect", "train", "push", "all", "scheduled"],
+                         help="collect=watch live sim, train=retrain NN, push=upload to HF, "
+                              "all=full pipeline, scheduled=nightly Gemini cycle")
      parser.add_argument("--url", default="https://raymelius-soci2.hf.space",
                          help="Live simulation URL (default: HF Space)")
      parser.add_argument("--minutes", type=int, default=60,
                          help="Collection duration in minutes (default: 60)")
+     parser.add_argument("--collect-minutes", type=int, default=120,
+                         help="Scheduled mode: collection duration in minutes (default: 120)")
+     parser.add_argument("--gemini-prob", type=float, default=0.50,
+                         help="Scheduled mode: LLM probability during Gemini collection (default: 0.50)")
      parser.add_argument("--epochs", type=int, default=20,
                          help="Training epochs (default: 20)")
      parser.add_argument("--repo", default="RayMelius/soci-agent-nn",
@@ -745,6 +889,15 @@ def main():
      if args.mode in ("push", "all"):
          push(repo_id=args.repo)

+     if args.mode == "scheduled":
+         asyncio.run(scheduled(
+             base_url=args.url,
+             collect_minutes=args.collect_minutes,
+             epochs=args.epochs,
+             repo_id=args.repo,
+             gemini_prob=args.gemini_prob,
+         ))
+

  if __name__ == "__main__":
      main()
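The oversampling change in the diff above gives LLM-sourced samples 3x effective weight by appending two extra copies before training. A standalone toy check of the resulting source mix (the sample dicts here are made up for illustration; only the `source` tag and the duplication logic mirror the diff):

```python
from collections import Counter

llm_sources = {"gemini", "claude", "groq"}

# Hypothetical collected samples, tagged by origin as in the diff
collected = (
    [{"source": "nn"}] * 70
    + [{"source": "gemini"}] * 20
    + [{"source": "routine"}] * 10
)

llm_samples = [s for s in collected if s.get("source", "") in llm_sources]
collected.extend(llm_samples * 2)  # 2 extra copies = 3x total weight

counts = Counter(s["source"] for s in collected)
print(counts["gemini"], len(collected))  # 60 140
```

With a uniform sampler over the training set, the 20 Gemini samples now occupy 60 of 140 slots, so each is drawn three times as often as it would have been without duplication.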
scripts/nn_train.py ADDED
@@ -0,0 +1,926 @@
+ #!/usr/bin/env python3
+ """Soci Agent NN — Local Training Script
+
+ Equivalent to notebooks/soci_agent_nn.ipynb but runs as a standalone script.
+ Trains the SociAgentTransformer, exports to ONNX, and optionally pushes to HF Hub.
+
+ Usage:
+     python scripts/nn_train.py                          # Train from scratch (synthetic data)
+     python scripts/nn_train.py --data data/nn_training  # Train on collected + synthetic data
+     python scripts/nn_train.py --push                   # Train and push to HF Hub
+     python scripts/nn_train.py --epochs 50 --lr 1e-4    # Custom hyperparameters
+     python scripts/nn_train.py --resume                 # Resume from existing weights
+
+ Requires: pip install torch onnx onnxruntime numpy huggingface_hub
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import logging
+ import math
+ import os
+ import random
+ import sys
+ import time
+ from collections import Counter
+ from pathlib import Path
+
+ import numpy as np
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+     stream=sys.stdout,
+ )
+ logger = logging.getLogger("nn_train")
+
+ # ── Paths ────────────────────────────────────────────────────────────────
+ SCRIPT_DIR = Path(__file__).parent
+ PROJECT_DIR = SCRIPT_DIR.parent
+ MODEL_DIR = PROJECT_DIR / "models"
+ DATA_DIR = PROJECT_DIR / "data" / "nn_training"
+ SAMPLES_FILE = DATA_DIR / "collected_samples.jsonl"
+
+ # ══════════════════════════════════════════════════════════════════════════
+ # 1. Domain Constants — must match the Soci simulation
+ # ══════════════════════════════════════════════════════════════════════════
+
+ ACTION_TYPES = ["move", "work", "eat", "sleep", "talk", "exercise", "shop", "relax", "wander"]
+ ACTION_TO_IDX = {a: i for i, a in enumerate(ACTION_TYPES)}
+ NUM_ACTIONS = len(ACTION_TYPES)
+
+ LOCATIONS = [
+     # Residential (17)
+     "house_elena", "house_marcus", "house_helen", "house_diana", "house_kai",
+     "house_priya", "house_james", "house_rosa", "house_yuki", "house_frank",
+     "apartment_block_1", "apartment_block_2", "apartment_block_3",
+     "apt_northeast", "apt_northwest", "apt_southeast", "apt_southwest",
+     # Commercial (8)
+     "cafe", "grocery", "bar", "restaurant", "bakery", "cinema", "diner", "pharmacy",
+     # Work (5)
+     "office", "office_tower", "factory", "school", "hospital",
+     # Public (10)
+     "park", "gym", "library", "church", "town_square", "sports_field",
+     "street_north", "street_south", "street_east", "street_west",
+ ]
+ LOC_TO_IDX = {loc: i for i, loc in enumerate(LOCATIONS)}
+ NUM_LOCATIONS = len(LOCATIONS)
+
+ # Zone encoding: 0=residential, 1=commercial, 2=work, 3=public
+ LOC_ZONE = {}
+ for _loc in LOCATIONS:
+     if _loc.startswith(("house_", "apartment_", "apt_")):
+         LOC_ZONE[_loc] = 0
+     elif _loc in ("cafe", "grocery", "bar", "restaurant", "bakery", "cinema", "diner", "pharmacy"):
+         LOC_ZONE[_loc] = 1
+     elif _loc in ("office", "office_tower", "factory", "school", "hospital"):
+         LOC_ZONE[_loc] = 2
+     else:
+         LOC_ZONE[_loc] = 3
+
+ ACTION_NEEDS = {
+     "work": {"purpose": 0.3},
+     "eat": {"hunger": 0.5},
+     "sleep": {"energy": 0.6},
+     "talk": {"social": 0.3},
+     "exercise": {"energy": -0.1, "fun": 0.2, "comfort": 0.1},
+     "shop": {"hunger": 0.1, "comfort": 0.1},
+     "relax": {"energy": 0.1, "fun": 0.2, "comfort": 0.2},
+     "wander": {"fun": 0.1},
+     "move": {},
+ }
+
+ ACTION_DURATIONS = {"move": 1, "work": 4, "eat": 2, "sleep": 8, "talk": 2, "exercise": 3, "shop": 2, "relax": 2, "wander": 1}
+ NEED_NAMES = ["hunger", "energy", "social", "purpose", "comfort", "fun"]
+ PERSONALITY_NAMES = ["openness", "conscientiousness", "extraversion", "agreeableness", "neuroticism"]
+
+ NUM_TIME_PERIODS = 7
+ FEATURE_DIM = 47
+
+
+ # ══════════════════════════════════════════════════════════════════════════
+ # 2. Personas — 20 Soci characters (from personas.yaml)
+ # ══════════════════════════════════════════════════════════════════════════
+
+ PERSONAS = [
+     {"id": "elena",  "name": "Elena Vasquez",        "age": 34, "occ": "software engineer",         "O": 8, "C": 7, "E": 4, "A": 6, "N": 5, "home": "house_elena",  "work": "office"},
+     {"id": "lila",   "name": "Lila Santos",          "age": 33, "occ": "artist",                    "O": 10, "C": 3, "E": 6, "A": 7, "N": 7, "home": "house_elena",  "work": "library"},
+     {"id": "marcus", "name": "Marcus Chen-Williams", "age": 32, "occ": "personal trainer",          "O": 6, "C": 7, "E": 9, "A": 5, "N": 3, "home": "house_marcus", "work": "gym"},
+     {"id": "zoe",    "name": "Zoe Chen-Williams",    "age": 19, "occ": "college student",           "O": 8, "C": 4, "E": 8, "A": 6, "N": 7, "home": "house_marcus", "work": "library"},
+     {"id": "helen",  "name": "Helen Park",           "age": 68, "occ": "retired librarian",         "O": 7, "C": 6, "E": 3, "A": 8, "N": 4, "home": "house_helen",  "work": "library"},
+     {"id": "alice",  "name": "Alice Fontaine",       "age": 58, "occ": "retired accountant",        "O": 5, "C": 8, "E": 5, "A": 8, "N": 3, "home": "house_helen",  "work": "bakery"},
+     {"id": "diana",  "name": "Diana Delgado",        "age": 42, "occ": "grocery store owner",       "O": 4, "C": 8, "E": 5, "A": 6, "N": 4, "home": "house_diana",  "work": "grocery"},
+     {"id": "marco",  "name": "Marco Delgado",        "age": 16, "occ": "high school student",       "O": 9, "C": 4, "E": 6, "A": 4, "N": 6, "home": "house_diana",  "work": "school"},
+     {"id": "kai",    "name": "Kai Okonkwo",          "age": 22, "occ": "barista",                   "O": 9, "C": 3, "E": 8, "A": 5, "N": 5, "home": "house_kai",    "work": "cafe"},
+     {"id": "priya",  "name": "Priya Sharma",         "age": 38, "occ": "doctor",                    "O": 7, "C": 8, "E": 5, "A": 7, "N": 6, "home": "house_priya",  "work": "hospital"},
+     {"id": "nina",   "name": "Nina Volkov",          "age": 29, "occ": "real estate agent",         "O": 5, "C": 7, "E": 8, "A": 5, "N": 5, "home": "house_priya",  "work": "office"},
+     {"id": "james",  "name": "James O'Brien",        "age": 40, "occ": "bar owner",                 "O": 6, "C": 5, "E": 7, "A": 6, "N": 4, "home": "house_james",  "work": "bar"},
+     {"id": "theo",   "name": "Theo Blackwood",       "age": 45, "occ": "construction worker",       "O": 3, "C": 8, "E": 4, "A": 5, "N": 5, "home": "house_james",  "work": "factory"},
+     {"id": "rosa",   "name": "Rosa Martelli",        "age": 62, "occ": "restaurant owner",          "O": 5, "C": 7, "E": 7, "A": 9, "N": 4, "home": "house_rosa",   "work": "restaurant"},
+     {"id": "omar",   "name": "Omar Hassan",          "age": 50, "occ": "taxi driver",               "O": 6, "C": 6, "E": 7, "A": 7, "N": 4, "home": "house_rosa",   "work": "restaurant"},
+     {"id": "yuki",   "name": "Yuki Tanaka",          "age": 26, "occ": "yoga instructor",           "O": 8, "C": 6, "E": 5, "A": 9, "N": 3, "home": "house_yuki",   "work": "gym"},
+     {"id": "devon",  "name": "Devon Reeves",         "age": 30, "occ": "freelance journalist",      "O": 9, "C": 5, "E": 6, "A": 5, "N": 6, "home": "house_yuki",   "work": "office"},
+     {"id": "frank",  "name": "Frank Kowalski",       "age": 72, "occ": "retired mechanic",          "O": 3, "C": 6, "E": 4, "A": 4, "N": 5, "home": "house_frank",  "work": "bar"},
+     {"id": "george", "name": "George Adeyemi",       "age": 47, "occ": "night shift security",      "O": 5, "C": 7, "E": 3, "A": 6, "N": 4, "home": "house_frank",  "work": "factory"},
+     {"id": "sam",    "name": "Sam Torres",           "age": 35, "occ": "elementary school teacher", "O": 6, "C": 8, "E": 3, "A": 7, "N": 5, "home": "house_frank",  "work": "school"},
+ ]
+
+
+ # ══════════════════════════════════════════════════════════════════════════
+ # 3. Feature Encoding
+ # ══════════════════════════════════════════════════════════════════════════
+
+ def _time_period(hour: int) -> int:
+     if hour < 6: return 0
+     if hour < 9: return 1
+     if hour < 12: return 2
+     if hour < 14: return 3
+     if hour < 18: return 4
+     if hour < 22: return 5
+     return 6
+
+
+ def encode_features(
+     persona: dict, hour: int, minute: int, day: int,
+     needs: dict, mood: float, current_loc: str,
+     num_people_here: int = 0,
+ ) -> list[float]:
+     """Encode agent state into 47-dim feature vector."""
+     f: list[float] = []
+     # Personality (5)
+     f.append(persona.get("O", persona.get("openness", 5)) / 10.0)
+     f.append(persona.get("C", persona.get("conscientiousness", 5)) / 10.0)
+     f.append(persona.get("E", persona.get("extraversion", 5)) / 10.0)
+     f.append(persona.get("A", persona.get("agreeableness", 5)) / 10.0)
+     f.append(persona.get("N", persona.get("neuroticism", 5)) / 10.0)
+     # Age (1)
+     f.append(persona.get("age", 30) / 100.0)
+     # Time cyclical (4)
+     f.append(math.sin(2 * math.pi * hour / 24))
+     f.append(math.cos(2 * math.pi * hour / 24))
+     f.append(math.sin(2 * math.pi * minute / 60))
+     f.append(math.cos(2 * math.pi * minute / 60))
+     # Day (2)
+     dow = ((day - 1) % 7)
+     f.append(dow / 7.0)
+     f.append(1.0 if dow >= 5 else 0.0)
+     # Needs (6)
+     for n in NEED_NAMES:
+         f.append(needs.get(n, 0.5))
+     # Mood (1)
+     f.append(max(-1.0, min(1.0, mood)))
+     # Urgency (2)
+     vals = [needs.get(n, 0.5) for n in NEED_NAMES]
+     urgent_idx = int(np.argmin(vals))
+     f.append(urgent_idx / 5.0)
+     f.append(1.0 if any(v < 0.15 for v in vals) else 0.0)
+     # Location zone (1)
+     zone = LOC_ZONE.get(current_loc, 3)
+     f.append(zone / 3.0)
+     # Home/work flags (2)
+     home = persona.get("home", persona.get("home_location", ""))
+     work = persona.get("work", persona.get("work_location", ""))
+     f.append(1.0 if current_loc == home else 0.0)
+     f.append(1.0 if current_loc == work else 0.0)
+     # People density (1)
+     f.append(min(num_people_here / 10.0, 1.0))
+     # Location type one-hot (6)
+     loc_oh = [0.0] * 6
+     if current_loc.startswith(("house_", "apartment_", "apt_")):
+         loc_oh[0] = 1.0
+     elif zone == 1:
+         loc_oh[1] = 1.0
+     elif zone == 2:
+         loc_oh[2] = 1.0
+     elif current_loc.startswith("street_"):
+         loc_oh[4] = 1.0
+     else:
+         loc_oh[3] = 1.0
+     if current_loc == home:
+         loc_oh[5] = 1.0
+     f.extend(loc_oh)
+     # Time period one-hot (7)
+     tp = [0.0] * NUM_TIME_PERIODS
+     tp[_time_period(hour)] = 1.0
+     f.extend(tp)
+     # Last action one-hot (9) — random for synthetic, zeros for real
+     last_action_oh = [0.0] * NUM_ACTIONS
+     if random.random() < 0.8:
+         last_action_oh[random.randint(0, NUM_ACTIONS - 1)] = 1.0
+     f.extend(last_action_oh)
+     return f
+
+
+ # ══════════════════════════════════════════════════════════════════════════
+ # 4. Synthetic Data Generator
+ # ══════════════════════════════════════════════════════════════════════════
+
+ def generate_action_example(persona: dict) -> dict:
+     """Generate one training example with rule-based labels."""
+     hour = random.randint(0, 23)
+     minute = random.choice([0, 15, 30, 45])
+     day = random.randint(1, 30)
+     is_weekend = ((day - 1) % 7) >= 5
+
+     # Random needs (15% chance of critical)
+     needs = {}
+     for n in NEED_NAMES:
+         if random.random() < 0.15:
+             needs[n] = round(random.uniform(0.0, 0.2), 2)
+         else:
+             needs[n] = round(random.uniform(0.2, 1.0), 2)
+
+     mood = round(random.uniform(-1.0, 1.0), 2)
+     current_loc = random.choice(LOCATIONS)
+
+     # --- Determine action using rule-based logic ---
+     # Priority 1: Critical needs
+     urgent = [(n, v) for n, v in needs.items() if v < 0.15]
+     urgent.sort(key=lambda x: x[1])
+
+     action = None
+     target_loc = current_loc
+     duration = 1
+
+     if urgent:
+         need_name = urgent[0][0]
+         if need_name == "hunger":
+             action = "eat"
+             target_loc = random.choice(["cafe", "restaurant", "grocery", "bakery", "diner", persona["home"]])
+             duration = 2
+         elif need_name == "energy":
+             action = "sleep"
+             target_loc = persona["home"]
+             duration = random.choice([4, 6, 8])
+         elif need_name == "social":
+             action = "talk"
+             target_loc = random.choice(["cafe", "bar", "park", "town_square", current_loc])
+             duration = 2
+         elif need_name == "purpose":
+             action = "work"
+             target_loc = persona["work"]
+             duration = 4
+         elif need_name == "comfort":
+             action = "relax"
+             target_loc = random.choice([persona["home"], "park", "library"])
+             duration = 2
+         elif need_name == "fun":
+             action = random.choice(["relax", "exercise", "wander"])
+             target_loc = random.choice(["park", "gym", "cinema", "bar", "sports_field"])
+             duration = 2
+
+     # Priority 2: Time-of-day patterns
+     if action is None:
+         period = _time_period(hour)
+
+         if period == 0:  # Late night
+             action = "sleep"
+             target_loc = persona["home"]
+             duration = 8
+
+         elif period == 1:  # Early morning
+             r = random.random()
+             if needs["hunger"] < 0.5:
+                 action = "eat"
+                 target_loc = random.choice(["cafe", "bakery", persona["home"]])
+                 duration = 2
+             elif r < 0.3 and persona["E"] >= 6:
+                 action = "exercise"
+                 target_loc = random.choice(["gym", "park", "sports_field"])
+                 duration = 3
+             else:
+                 action = "move"
+                 target_loc = persona["work"]
+                 duration = 1
+
+         elif period in (2, 4):  # Mid-morning / Afternoon
+             if is_weekend:
+                 r = random.random()
+                 if r < 0.25:
+                     action = "relax"
+                     target_loc = random.choice(["park", "cafe", "library", persona["home"]])
+                 elif r < 0.45 and persona["E"] >= 6:
+                     action = "talk"
+                     target_loc = random.choice(["cafe", "park", "town_square"])
+                 elif r < 0.6:
+                     action = "shop"
+                     target_loc = random.choice(["grocery", "pharmacy"])
+                 elif r < 0.8:
+                     action = "exercise"
+                     target_loc = random.choice(["gym", "park", "sports_field"])
+                 else:
+                     action = "wander"
+                     target_loc = random.choice(["park", "town_square", "street_north", "street_south"])
+                 duration = random.choice([2, 3])
+             else:
+                 work_prob = 0.5 + persona["C"] * 0.05
+                 if random.random() < work_prob:
+                     action = "work"
+                     target_loc = persona["work"]
+                     duration = 4
+                 else:
+                     action = random.choice(["wander", "relax", "talk"])
+                     target_loc = random.choice(["cafe", "park", "town_square"])
+                     duration = 2
+
+         elif period == 3:  # Midday / lunch
+             if needs["hunger"] < 0.6:
+                 action = "eat"
+                 target_loc = random.choice(["cafe", "restaurant", "bakery", "diner", "park"])
+                 duration = 2
+             else:
+                 action = "relax"
+                 target_loc = random.choice(["park", "cafe"])
+                 duration = 1
+
+         elif period == 5:  # Evening
+             r = random.random()
+             social_bias = persona["E"] / 10.0
+             if r < social_bias * 0.5:
+                 action = "talk"
+                 target_loc = random.choice(["bar", "restaurant", "park", "cafe"])
+                 duration = 2
+             elif r < 0.4:
+                 action = "eat"
+                 target_loc = random.choice(["restaurant", "bar", "diner", persona["home"]])
+                 duration = 2
+             elif r < 0.55:
+                 action = "exercise"
+                 target_loc = random.choice(["gym", "park", "sports_field"])
+                 duration = 3
+             elif r < 0.7:
+                 action = "relax"
+                 target_loc = random.choice(["cinema", "bar", persona["home"], "library"])
+                 duration = 2
+             else:
+                 action = "relax"
+                 target_loc = persona["home"]
+                 duration = 2
+
+         elif period == 6:  # Night
+             if needs["energy"] < 0.4:
+                 action = "sleep"
+                 target_loc = persona["home"]
+                 duration = 8
+             else:
+                 action = "relax"
+                 target_loc = persona["home"]
+                 duration = 2
+
+     # 30% chance of picking "move" if target != current
+     if target_loc != current_loc and action != "move":
+         if random.random() < 0.3:
+             action = "move"
+             duration = 1
+
+     features = encode_features(
+         persona=persona, hour=hour, minute=minute, day=day,
+         needs=needs, mood=mood, current_loc=current_loc,
+         num_people_here=random.randint(0, 8),
+     )
+
+     return {
+         "features": features,
+         "action_idx": ACTION_TO_IDX[action],
+         "target_loc_idx": LOC_TO_IDX.get(target_loc, 0),
+         "duration": min(max(duration, 1), 8),
+     }
+
+
+ def generate_dataset(n: int) -> list[dict]:
+     """Generate n synthetic training examples."""
+     data = []
+     for _ in range(n):
+         persona = random.choice(PERSONAS)
+         data.append(generate_action_example(persona))
+     return data
+
+
+ # ══════════════════════════════════════════════════════════════════════════
+ # 5. Model Architecture — SociAgentTransformer
+ # ══════════════════════════════════════════════════════════════════════════
+
+ def build_model():
+     """Build the SociAgentTransformer model."""
+     import torch
+     import torch.nn as nn
+     import torch.nn.functional as F
+
+     class FeatureTokenizer(nn.Module):
+         GROUPS = [
+             ("personality", 0, 6),
+             ("time", 6, 12),
+             ("needs", 12, 21),
+             ("location", 21, 31),
+             ("time_period", 31, 38),
+             ("last_action", 38, 47),
+         ]
+
+         def __init__(self, d_model: int):
+             super().__init__()
+             self.projections = nn.ModuleList()
+             for name, start, end in self.GROUPS:
+                 self.projections.append(nn.Sequential(
+                     nn.Linear(end - start, d_model),
+                     nn.LayerNorm(d_model),
+                     nn.GELU(),
+                 ))
+             self.pos_embed = nn.Parameter(torch.randn(1, len(self.GROUPS), d_model) * 0.02)
+
+         def forward(self, features):
+             tokens = []
+             for i, (_, start, end) in enumerate(self.GROUPS):
+                 tokens.append(self.projections[i](features[:, start:end]))
+             tokens = torch.stack(tokens, dim=1)
+             return tokens + self.pos_embed
+
+     class MoEFeedForward(nn.Module):
+         def __init__(self, d_model, d_ff, num_experts=4, top_k=2):
+             super().__init__()
+             self.num_experts = num_experts
+             self.top_k = top_k
+             self.gate = nn.Linear(d_model, num_experts, bias=False)
+             self.experts = nn.ModuleList([
+                 nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
+                 for _ in range(num_experts)
+             ])
+
+         def forward(self, x):
+             B, S, D = x.shape
+             gate_probs = F.softmax(self.gate(x), dim=-1)
+             top_k_probs, top_k_idx = gate_probs.topk(self.top_k, dim=-1)
+             top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
+             output = torch.zeros_like(x)
+             for k in range(self.top_k):
+                 eidx = top_k_idx[:, :, k]
+                 w = top_k_probs[:, :, k].unsqueeze(-1)
+                 for e in range(self.num_experts):
+                     mask = (eidx == e).unsqueeze(-1)
+                     if mask.any():
+                         output = output + mask.float() * w * self.experts[e](x)
+             return output
+
+     class TransformerBlock(nn.Module):
+         def __init__(self, d_model, nhead, d_ff, num_experts=4, dropout=0.1):
+             super().__init__()
+             self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
+             self.norm1 = nn.LayerNorm(d_model)
+             self.moe_ff = MoEFeedForward(d_model, d_ff, num_experts)
+             self.norm2 = nn.LayerNorm(d_model)
+             self.dropout = nn.Dropout(dropout)
+
+         def forward(self, x):
+             attn_out, _ = self.attn(x, x, x)
+             x = self.norm1(x + self.dropout(attn_out))
+             ff_out = self.moe_ff(x)
+             return self.norm2(x + self.dropout(ff_out))
+
+     class SociAgentTransformer(nn.Module):
+         def __init__(self, d_model=128, nhead=8, num_layers=4, d_ff=256,
+                      num_experts=4, dropout=0.1):
+             super().__init__()
+             self.tokenizer = FeatureTokenizer(d_model)
+             self.layers = nn.ModuleList([
+                 TransformerBlock(d_model, nhead, d_ff, num_experts, dropout)
+                 for _ in range(num_layers)
+             ])
+             self.cls_query = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
+             self.cls_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
+             self.cls_norm = nn.LayerNorm(d_model)
+             self.action_head = nn.Sequential(
+                 nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(dropout),
+                 nn.Linear(d_model, NUM_ACTIONS),
+             )
+             self.location_head = nn.Sequential(
+                 nn.Linear(d_model + NUM_ACTIONS, d_model), nn.GELU(), nn.Dropout(dropout),
+                 nn.Linear(d_model, NUM_LOCATIONS),
+             )
+             self.duration_head = nn.Sequential(
+                 nn.Linear(d_model + NUM_ACTIONS, d_model // 2), nn.GELU(),
+                 nn.Linear(d_model // 2, 1),
+             )
+
+         def forward(self, features):
+             tokens = self.tokenizer(features)
+             for layer in self.layers:
+                 tokens = layer(tokens)
+             B = features.shape[0]
+             cls = self.cls_query.expand(B, -1, -1)
+             cls_out, _ = self.cls_attn(cls, tokens, tokens)
+             h = self.cls_norm(cls_out.squeeze(1))
+             action_logits = self.action_head(h)
+             action_probs = F.softmax(action_logits.detach(), dim=-1)
+             h_a = torch.cat([h, action_probs], dim=-1)
+             location_logits = self.location_head(h_a)
+             duration = torch.sigmoid(self.duration_head(h_a)) * 7.0 + 1.0
+             return {
+                 "action_logits": action_logits,
+                 "location_logits": location_logits,
+                 "duration": duration.squeeze(-1),
+             }
+
+     return SociAgentTransformer()
+
+
+ # ══════════════════════════════════════════════════════════════════════════
+ # 6. Training
+ # ══════════════════════════════════════════════════════════════════════════
+
+ def train(
+     epochs: int = 30,
+     batch_size: int = 512,
+     lr: float = 3e-4,
+     num_train: int = 100_000,
+     num_val: int = 10_000,
+     data_dir: str | None = None,
+     resume: bool = False,
+     push: bool = False,
+     repo_id: str = "RayMelius/soci-agent-nn",
+ ):
+     """Full training pipeline: generate/load data, train, export ONNX, optionally push."""
+     import torch
+     import torch.nn as nn
+     from torch.utils.data import Dataset, DataLoader
+
+     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     logger.info(f"Device: {DEVICE}")
+     if DEVICE.type == "cuda":
+         logger.info(f"GPU: {torch.cuda.get_device_name()}")
+
+     MODEL_DIR.mkdir(parents=True, exist_ok=True)
+     best_pt = MODEL_DIR / "soci_agent_best.pt"
+     onnx_path = MODEL_DIR / "soci_agent.onnx"
+
+     # ── Load / generate data ─────────────────────────────────────────
+     collected = []
+     source_counts: dict[str, int] = {}
+
+     # Load collected samples from live sim (if available)
+     samples_file = (Path(data_dir) / "collected_samples.jsonl") if data_dir else SAMPLES_FILE
+     if samples_file.exists():
+         with open(samples_file) as f:
+             for line in f:
+                 line = line.strip()
+                 if line:
+                     sample = json.loads(line)
+                     collected.append(sample)
+                     src = sample.get("source", "unknown")
+                     source_counts[src] = source_counts.get(src, 0) + 1
+         logger.info(f"Loaded {len(collected):,} collected samples — sources: {source_counts}")
+
+     # Oversample LLM-sourced data 3x (higher quality than NN/routine)
+     llm_sources = {"gemini", "claude", "groq"}
+     llm_samples = [s for s in collected if s.get("source", "") in llm_sources]
+     if llm_samples:
+         logger.info(f"Oversampling {len(llm_samples):,} LLM-sourced samples (3x weight)")
+         collected.extend(llm_samples * 2)
+
+     # Generate synthetic data to fill up to target size
+     total_target = num_train + num_val
+     synthetic_needed = max(0, total_target - len(collected))
+     if synthetic_needed > 0:
+         logger.info(f"Generating {synthetic_needed:,} synthetic samples...")
+         random.seed(42)
+         collected.extend(generate_dataset(synthetic_needed))
+
+     random.shuffle(collected)
+     split = int(len(collected) * 0.9)
+     train_data = collected[:split]
+     val_data = collected[split:]
+
+     # ── Dataset ──────────────────────────────────────────────────────
+     class ActionDataset(Dataset):
+         def __init__(self, data):
+             self.features = torch.tensor([d["features"] for d in data], dtype=torch.float32)
+             self.actions = torch.tensor([d["action_idx"] for d in data], dtype=torch.long)
+             self.locations = torch.tensor([d["target_loc_idx"] for d in data], dtype=torch.long)
+             self.durations = torch.tensor([d["duration"] for d in data], dtype=torch.float32)
+
+         def __len__(self):
+             return len(self.actions)
+
+         def __getitem__(self, idx):
+             return {
+                 "features": self.features[idx],
+                 "action": self.actions[idx],
+                 "location": self.locations[idx],
+                 "duration": self.durations[idx],
+             }
+
+     train_ds = ActionDataset(train_data)
+     val_ds = ActionDataset(val_data)
+     train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
+                               num_workers=0, pin_memory=(DEVICE.type == "cuda"))
+     val_loader = DataLoader(val_ds, batch_size=1024, shuffle=False,
+                             num_workers=0, pin_memory=(DEVICE.type == "cuda"))
+     logger.info(f"Train: {len(train_ds):,}, Val: {len(val_ds):,}")
+
+     # ── Model ────────────────────────────────────────────────────────
+     model = build_model().to(DEVICE)
+
+     total_params = sum(p.numel() for p in model.parameters())
+     logger.info(f"Model parameters: {total_params:,} ({total_params * 4 / 1024 / 1024:.1f} MB fp32)")
+
+     if resume and best_pt.exists():
+         model.load_state_dict(torch.load(str(best_pt), map_location=DEVICE, weights_only=True))
+         logger.info(f"Resumed from {best_pt}")
+
+     # ── Class weights ────────────────────────────────────────────────
+     action_counts = torch.zeros(NUM_ACTIONS)
+     for d in train_data:
+         action_counts[d["action_idx"]] += 1
+     action_weights = 1.0 / (action_counts + 1.0)
+     action_weights = action_weights / action_weights.sum() * NUM_ACTIONS
+     action_weights = action_weights.to(DEVICE)
+
+     logger.info("Action distribution:")
+     for idx in range(NUM_ACTIONS):
+         count = int(action_counts[idx])
+         pct = count / len(train_data) * 100
+         logger.info(f"  {ACTION_TYPES[idx]:>10s}: {count:6d} ({pct:.1f}%)")
+
+     # ── Loss & optimizer ─────────────────────────────────────────────
+     action_loss_fn = nn.CrossEntropyLoss(weight=action_weights)
+     location_loss_fn = nn.CrossEntropyLoss()
+     duration_loss_fn = nn.MSELoss()
+
+     W_ACTION = 1.0
+     W_LOCATION = 0.5
+     W_DURATION = 0.2
+
+     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
+
+     logger.info(f"Training for {epochs} epochs, LR={lr}, batch_size={batch_size}")
+
+     # ── Training loop ────────────────────────────────────────────────
+     best_val_acc = 0.0
+     history = {"train_loss": [], "val_loss": [], "val_action_acc": [], "val_loc_acc": []}
+
+     for epoch in range(epochs):
+         # Train
+         model.train()
+         total_loss = 0.0
+         n_batches = 0
+         for batch in train_loader:
+             feat = batch["features"].to(DEVICE)
+             out = model(feat)
+             loss = (
+                 W_ACTION * action_loss_fn(out["action_logits"], batch["action"].to(DEVICE))
+                 + W_LOCATION * location_loss_fn(out["location_logits"], batch["location"].to(DEVICE))
+                 + W_DURATION * duration_loss_fn(out["duration"], batch["duration"].to(DEVICE))
+             )
+             optimizer.zero_grad()
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             optimizer.step()
+             total_loss += loss.item()
+             n_batches += 1
+         scheduler.step()
+         avg_train_loss = total_loss / n_batches
+
+         # Validate
+         model.eval()
+         val_loss = 0.0
+         correct_action = 0
+         correct_loc = 0
+         total = 0
+         with torch.no_grad():
+             for batch in val_loader:
+                 feat = batch["features"].to(DEVICE)
+                 out = model(feat)
+                 loss = (
+                     W_ACTION * action_loss_fn(out["action_logits"], batch["action"].to(DEVICE))
+                     + W_LOCATION * location_loss_fn(out["location_logits"], batch["location"].to(DEVICE))
+                     + W_DURATION * duration_loss_fn(out["duration"], batch["duration"].to(DEVICE))
+                 )
+                 val_loss += loss.item()
+                 pred_action = out["action_logits"].argmax(dim=-1)
+                 pred_loc = out["location_logits"].argmax(dim=-1)
+                 correct_action += (pred_action == batch["action"].to(DEVICE)).sum().item()
+                 correct_loc += (pred_loc == batch["location"].to(DEVICE)).sum().item()
+                 total += feat.shape[0]
+
+         avg_val_loss = val_loss / len(val_loader)
+         action_acc = correct_action / total if total > 0 else 0
+         loc_acc = correct_loc / total if total > 0 else 0
+
+         history["train_loss"].append(avg_train_loss)
+         history["val_loss"].append(avg_val_loss)
+         history["val_action_acc"].append(action_acc)
+         history["val_loc_acc"].append(loc_acc)
+
+         if action_acc > best_val_acc:
+             best_val_acc = action_acc
+             torch.save(model.state_dict(), str(best_pt))
+
+         if (epoch + 1) % 5 == 0 or epoch == 0:
+             lr_now = scheduler.get_last_lr()[0]
+             logger.info(
+                 f"Epoch {epoch+1:3d}/{epochs} | "
+                 f"Train: {avg_train_loss:.4f} | "
+                 f"Val: {avg_val_loss:.4f} | "
+                 f"Act Acc: {action_acc:.1%} | "
+                 f"Loc Acc: {loc_acc:.1%} | "
+                 f"LR: {lr_now:.2e}"
+             )
+
+     logger.info(f"Best validation action accuracy: {best_val_acc:.1%}")
+
+     # ── Per-action accuracy ──────────────────────────────────────────
+     model.load_state_dict(torch.load(str(best_pt), map_location=DEVICE, weights_only=True))
+     model.eval()
+     cm = np.zeros((NUM_ACTIONS, NUM_ACTIONS), dtype=int)
+     with torch.no_grad():
+         for batch in val_loader:
+             feat = batch["features"].to(DEVICE)
+             out = model(feat)
+             preds = out["action_logits"].argmax(dim=-1).cpu().numpy()
+             labels = batch["action"].numpy()
+             for p, l in zip(preds, labels):
+                 cm[l][p] += 1
+
+     logger.info("Per-action accuracy:")
+     for i, action in enumerate(ACTION_TYPES):
+         row_total = cm[i].sum()
+         correct = cm[i][i]
+         acc = correct / row_total if row_total > 0 else 0
+         logger.info(f"  {action:>10s}: {acc:.1%} ({correct}/{row_total})")
+
+     # ── Test scenarios ───────────────────────────────────────────────
+     import torch.nn.functional as F
+
+     @torch.no_grad()
+     def predict(persona, hour, minute, day, needs, mood, loc, num_people=0):
+         features = encode_features(persona, hour, minute, day, needs, mood, loc, num_people)
+         feat_t = torch.tensor([features], dtype=torch.float32, device=DEVICE)
+         out = model(feat_t)
+         action_probs = F.softmax(out["action_logits"][0] / 0.7, dim=-1)
+         action_idx = action_probs.argmax().item()
+         loc_idx = out["location_logits"][0].argmax().item()
+         dur = max(1, min(8, round(out["duration"][0].item())))
+         return ACTION_TYPES[action_idx], LOCATIONS[loc_idx], dur, action_probs[action_idx].item()
+
+     logger.info("Test scenarios:")
+     a, l, d, c = predict(PERSONAS[0], 0, 30, 5,
+                          {"hunger": 0.5, "energy": 0.05, "social": 0.4, "purpose": 0.6, "comfort": 0.3, "fun": 0.3},
+                          -0.3, "office")
+     logger.info(f"  Elena midnight exhausted: {a} -> {l} ({d} ticks, {c:.0%})")
+
+     a, l, d, c = predict(PERSONAS[2], 12, 30, 3,
+                          {"hunger": 0.05, "energy": 0.7, "social": 0.5, "purpose": 0.6, "comfort": 0.5, "fun": 0.4},
+                          0.2, "gym", 5)
+     logger.info(f"  Marcus lunchtime starving: {a} -> {l} ({d} ticks, {c:.0%})")
+
+     a, l, d, c = predict(PERSONAS[8], 10, 0, 6,
+                          {"hunger": 0.6, "energy": 0.7, "social": 0.5, "purpose": 0.5, "comfort": 0.7, "fun": 0.4},
+                          0.5, "house_kai")
+     logger.info(f"  Kai Saturday morning: {a} -> {l} ({d} ticks, {c:.0%})")
+
+     # ── Export to ONNX ───────────────────────────────────────────────
+     logger.info("Exporting to ONNX...")
+     model.cpu().eval()
+     dummy = torch.randn(1, FEATURE_DIM)
+     torch.onnx.export(
+         model, dummy, str(onnx_path),
+         input_names=["features"],
+         output_names=["action_logits", "location_logits", "duration"],
+         dynamic_axes={"features": {0: "batch"}},
+         opset_version=17,
+         dynamo=False,
+     )
+
+     # Verify ONNX
+     import onnx
+     onnx_model = onnx.load(str(onnx_path))
+     onnx.checker.check_model(onnx_model)
+     onnx_size = onnx_path.stat().st_size / 1024
+     logger.info(f"ONNX exported: {onnx_path} ({onnx_size:.0f} KB)")
+
+     # Benchmark ONNX
+     import onnxruntime as ort
804
+ session = ort.InferenceSession(str(onnx_path))
805
+ batch_input = np.random.randn(50, FEATURE_DIM).astype(np.float32)
806
+ start = time.perf_counter()
807
+ for _ in range(100):
808
+ session.run(None, {"features": batch_input})
809
+ elapsed = (time.perf_counter() - start) / 100
810
+ logger.info(f"ONNX inference (50 agents): {elapsed*1000:.1f} ms per batch")
811
+
812
+ # ── Save training stats ──────────────────────────────────────────
813
+ stats = {
814
+ "best_val_action_acc": best_val_acc,
815
+ "epochs": epochs,
816
+ "train_samples": len(train_ds),
817
+ "val_samples": len(val_ds),
818
+ "collected_samples": sum(source_counts.values()),
819
+ "source_counts": source_counts,
820
+ "model_size_kb": onnx_size,
821
+ "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
822
+ "history": history,
823
+ }
824
+ stats_path = MODEL_DIR / "training_stats.json"
825
+ stats_path.write_text(json.dumps(stats, indent=2))
826
+ logger.info(f"Stats saved to {stats_path}")
827
+
828
+ # ── Push to HF Hub ───────────────────────────────────────────────
829
+ if push:
830
+ _push_to_hub(best_pt, onnx_path, stats_path, repo_id, best_val_acc, epochs, len(train_ds))
831
+
832
+ return best_val_acc
833
+
834
+
835
+ def _push_to_hub(best_pt, onnx_path, stats_path, repo_id, best_val_acc, epochs, num_train):
836
+ """Upload model files to HuggingFace Hub."""
837
+ from huggingface_hub import HfApi, login
838
+
839
+ token = os.environ.get("HF_TOKEN", "")
840
+ if not token:
841
+ logger.error("HF_TOKEN not set — cannot push. Export it: export HF_TOKEN=hf_...")
842
+ return
843
+
844
+ login(token=token)
845
+ api = HfApi()
846
+ api.create_repo(repo_id, exist_ok=True)
847
+
848
+ # Config
849
+ config = {
850
+ "architecture": "SociAgentTransformer",
851
+ "d_model": 128, "nhead": 8, "num_layers": 4, "d_ff": 256, "num_experts": 4,
852
+ "feature_dim": FEATURE_DIM, "num_actions": NUM_ACTIONS, "num_locations": NUM_LOCATIONS,
853
+ "action_types": ACTION_TYPES, "locations": LOCATIONS,
854
+ "action_durations": ACTION_DURATIONS, "need_names": NEED_NAMES,
855
+ "personality_names": PERSONALITY_NAMES,
856
+ "best_val_action_acc": best_val_acc,
857
+ "training_samples": num_train, "epochs": epochs,
858
+ }
859
+ config_path = MODEL_DIR / "config.json"
860
+ config_path.write_text(json.dumps(config, indent=2))
861
+
862
+ for local, remote in [
863
+ (onnx_path, "soci_agent.onnx"),
864
+ (best_pt, "soci_agent_best.pt"),
865
+ (config_path, "config.json"),
866
+ (stats_path, "training_stats.json"),
867
+ ]:
868
+ if local.exists():
869
+ api.upload_file(
870
+ path_or_fileobj=str(local),
871
+ path_in_repo=remote,
872
+ repo_id=repo_id,
873
+ commit_message=f"Train: acc={best_val_acc:.1%}, {epochs} epochs",
874
+ )
875
+ logger.info(f"Uploaded {remote}")
876
+
877
+ logger.info(f"Model pushed to https://huggingface.co/{repo_id}")
878
+
879
+
880
+ # ══════════════════════════════════════════════════════════════════════════
881
+ # CLI
882
+ # ══════════════════════════════════════════════════════════════════════════
883
+
884
+ def main():
885
+ parser = argparse.ArgumentParser(
886
+ description="Soci Agent NN — Local Training Script",
887
+ formatter_class=argparse.RawDescriptionHelpFormatter,
888
+ epilog="""Examples:
889
+ python scripts/nn_train.py # Train from scratch
890
+ python scripts/nn_train.py --resume --epochs 50 # Continue training
891
+ python scripts/nn_train.py --data data/nn_training # Use collected samples
892
+ python scripts/nn_train.py --push --repo RayMelius/soci-agent-nn # Train + push
893
+ """,
894
+ )
895
+ parser.add_argument("--epochs", type=int, default=30, help="Training epochs (default: 30)")
896
+ parser.add_argument("--batch-size", type=int, default=512, help="Batch size (default: 512)")
897
+ parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate (default: 3e-4)")
898
+ parser.add_argument("--train-samples", type=int, default=100_000,
899
+ help="Number of synthetic training samples (default: 100000)")
900
+ parser.add_argument("--val-samples", type=int, default=10_000,
901
+ help="Number of validation samples (default: 10000)")
902
+ parser.add_argument("--data", type=str, default=None,
903
+ help="Path to directory with collected_samples.jsonl")
904
+ parser.add_argument("--resume", action="store_true",
905
+ help="Resume from existing weights in models/")
906
+ parser.add_argument("--push", action="store_true",
907
+ help="Push trained model to HuggingFace Hub")
908
+ parser.add_argument("--repo", default="RayMelius/soci-agent-nn",
909
+ help="HF Hub repo ID (default: RayMelius/soci-agent-nn)")
910
+ args = parser.parse_args()
911
+
912
+ train(
913
+ epochs=args.epochs,
914
+ batch_size=args.batch_size,
915
+ lr=args.lr,
916
+ num_train=args.train_samples,
917
+ num_val=args.val_samples,
918
+ data_dir=args.data,
919
+ resume=args.resume,
920
+ push=args.push,
921
+ repo_id=args.repo,
922
+ )
923
+
924
+
925
+ if __name__ == "__main__":
926
+ main()
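The `predict()` helper in nn_train.py does two small pieces of post-processing: a temperature-sharpened softmax for the reported confidence, and clamping of the regressed duration to the 1–8 tick range. That logic can be exercised in isolation; the sketch below uses hypothetical helper names and NumPy in place of torch:

```python
import numpy as np

def softmax_with_temperature(logits, temperature=0.7):
    # Divide logits by T < 1 to sharpen the distribution, as predict() does
    z = np.asarray(logits, dtype=np.float64) / temperature
    z -= z.max()  # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()

def clamp_duration(raw):
    # Round the regressed duration, then clamp it into the 1..8 tick range
    return max(1, min(8, round(raw)))
```

Since `argmax` is taken after a monotonic transform, the temperature only affects the reported confidence value, never which action is selected.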
src/soci/agents/routine.py CHANGED
@@ -134,6 +134,25 @@ class DailyRoutine:
         t = self._add(h, m, "relax", home, 2, "Morning routine — getting ready",
                       {"comfort": 0.1, "energy": 0.05})
 
+        # Morning exercise for active personas (30% chance if conscientious or extravert)
+        if (persona.conscientiousness >= 7 or persona.extraversion >= 7) and self._rng.random() < 0.3:
+            morning_spot = self._rng.choice(["park", "park", "gym", "sports_field"])
+            morning_exercise = {
+                "park": self._rng.choice(["Morning jog in the park", "Early walk in the park"]),
+                "gym": "Morning gym session",
+                "sports_field": self._rng.choice(["Morning run at the sports field",
+                                                  "Early workout at the sports field"]),
+            }.get(morning_spot, f"Morning exercise at {morning_spot}")
+            h, m = t // 60, t % 60
+            t = self._add(h, m, "move", morning_spot, 1, f"Heading to {morning_spot}",
+                          {})
+            h, m = t // 60, t % 60
+            t = self._add(h, m, "exercise", morning_spot, 2, morning_exercise,
+                          {"fun": 0.1, "energy": -0.05})
+            h, m = t // 60, t % 60
+            t = self._add(h, m, "move", home, 1, "Back home to freshen up",
+                          {})
+
         # Breakfast
         h, m = t // 60, t % 60
         t = self._add(h, m, "eat", home, 2, "Having breakfast at home",
@@ -152,8 +171,8 @@ class DailyRoutine:
                       f"{work_label} — morning block",
                       {"purpose": 0.3})
 
-        # Lunch — pick a food place or stay at work
-        food_places = ["cafe", "restaurant", "grocery", "bakery"]
+        # Lunch — pick a food place, park, or stay at work
+        food_places = ["cafe", "restaurant", "grocery", "bakery", "park", "park"]
         lunch_spot = self._rng.choice(food_places)
         h, m = t // 60, t % 60
         t = self._add(h, m, "move", lunch_spot, 1, f"Walking to lunch at {lunch_spot}",
@@ -174,6 +193,25 @@ class DailyRoutine:
                       f"{work_label} — afternoon block",
                       {"purpose": 0.3})
 
+        # Post-work exercise for active personas (conscientiousness >= 6 or extraversion >= 7)
+        if (persona.conscientiousness >= 6 or persona.extraversion >= 7) and self._rng.random() < 0.4:
+            exercise_spot = self._rng.choice(["gym", "park", "sports_field", "park"])
+            exercise_details = {
+                "gym": "Post-work gym session",
+                "park": self._rng.choice(["Jogging in the park", "Evening walk in the park",
+                                          "Stretching and walking in the park"]),
+                "sports_field": self._rng.choice(["Playing pickup soccer after work",
+                                                  "Evening run at the sports field",
+                                                  "Shooting hoops at the sports field"]),
+            }
+            h, m = t // 60, t % 60
+            t = self._add(h, m, "move", exercise_spot, 1, f"Heading to {exercise_spot}",
+                          {})
+            h, m = t // 60, t % 60
+            t = self._add(h, m, "exercise", exercise_spot, self._rng.randint(2, 4),
+                          exercise_details.get(exercise_spot, f"Exercising at {exercise_spot}"),
+                          {"fun": 0.2, "energy": -0.1})
+
         # Commute home
         h, m = t // 60, t % 60
         t = self._add(h, m, "move", home, 1, "Heading home",
@@ -310,10 +348,19 @@ class DailyRoutine:
             t = self._add(h, m, "move", place, 1, f"Going to {place}",
                           {})
             act_ticks = self._rng.randint(3, 6)
-            act_type = "exercise" if place == "gym" else "relax"
+            act_type = "exercise" if place in ("gym", "sports_field") else "relax"
+            if place == "park" and self._rng.random() < 0.4:
+                act_type = "exercise"
+            act_detail = {
+                "park": self._rng.choice(["Walking around the park", "Jogging in the park",
+                                          "Relaxing in the park"]),
+                "sports_field": self._rng.choice(["Playing soccer", "Shooting hoops",
+                                                  "Running laps", "Playing frisbee"]),
+                "gym": "Working out",
+            }.get(place, f"Hanging out at {place}")
             h, m = t // 60, t % 60
             t = self._add(h, m, act_type, place, act_ticks,
-                          f"Hanging out at {place}",
+                          act_detail,
                           {"social": 0.2, "fun": 0.25})
         else:
             # Quiet afternoon
@@ -368,7 +415,8 @@ class DailyRoutine:
 
         if e >= 6:
             # Extroverts: go out, stay until ~30-45 min before sleep, then come home
-            venue = self._rng.choice(["bar", "restaurant", "park", "cinema", "town_square"])
+            venue = self._rng.choice(["bar", "restaurant", "park", "cinema",
+                                      "town_square", "sports_field", "park"])
             h, m = t // 60, t % 60
             t = self._add(h, m, "move", venue, 1, f"Heading to {venue}", {})
             wind_down_start = sleep_t - self._rng.randint(2, 3) * 15
@@ -399,15 +447,21 @@ class DailyRoutine:
     def _add_leisure_block(self, persona: Persona, home: str,
                            t: int, end_t: int) -> int:
         """Fill a leisure period with activities based on personality."""
-        activities = []
+        # Base activities available to everyone
+        activities = ["park", "park"]  # Park is always a strong option
+
         if persona.extraversion >= 6:
-            activities.extend(["park", "cafe", "gym", "town_square", "sports_field"])
+            activities.extend(["cafe", "gym", "town_square", "sports_field",
+                               "sports_field", "park", "bar"])
+        elif persona.extraversion >= 4:
+            activities.extend(["cafe", "park", "sports_field", "town_square"])
         else:
-            activities.extend(["library", "park", "church"])
+            activities.extend(["library", "church", "park"])
+
         if persona.conscientiousness >= 6:
-            activities.append("gym")
+            activities.extend(["gym", "sports_field"])
         if persona.openness >= 6:
-            activities.extend(["library", "park", "cinema"])
+            activities.extend(["library", "park", "cinema", "town_square"])
 
         dest = self._rng.choice(activities)
         available_ticks = max(0, (end_t - t) // 15)
@@ -424,26 +478,56 @@ class DailyRoutine:
         # Activity there
         act_ticks = min(available_ticks - 1, self._rng.randint(2, max(3, available_ticks - 1)))
         if act_ticks > 0:
-            act_type = "exercise" if dest == "gym" else "relax"
+            act_type = "exercise" if dest in ("gym", "sports_field") else "relax"
+            # Park can be exercise too (jogging, walking)
+            if dest == "park" and self._rng.random() < 0.5:
+                act_type = "exercise"
+
             act_detail = {
-                "park": "Taking a walk in the park",
-                "cafe": "Hanging out at the cafe",
-                "gym": "Working out at the gym",
-                "library": "Reading at the library",
+                "park": self._rng.choice([
+                    "Taking a walk in the park", "Jogging through the park",
+                    "Strolling along the park paths", "Sitting on a bench in the park",
+                    "Walking the trails at Willow Park", "Enjoying nature in the park",
+                    "Doing yoga in the park", "Reading on a park bench",
+                ]),
+                "cafe": self._rng.choice([
+                    "Hanging out at the cafe", "Having coffee at the cafe",
+                    "Working on a laptop at the cafe", "Chatting at the cafe",
+                ]),
+                "gym": self._rng.choice([
+                    "Working out at the gym", "Lifting weights at the gym",
+                    "Doing cardio at the gym", "Fitness class at the gym",
+                ]),
+                "library": self._rng.choice([
+                    "Reading at the library", "Browsing books at the library",
+                    "Studying at the library", "Quiet time at the library",
+                ]),
                 "cinema": "Watching a movie",
-                "town_square": "People-watching at the square",
-                "sports_field": "Playing sports at the field",
+                "town_square": self._rng.choice([
+                    "People-watching at the square", "Hanging out at the square",
+                    "Sitting by the fountain in town square",
+                ]),
+                "sports_field": self._rng.choice([
+                    "Playing soccer at the sports field",
+                    "Shooting hoops at the sports field",
+                    "Playing catch at the sports field",
+                    "Running laps at the sports field",
+                    "Playing frisbee at the sports field",
+                    "Doing drills at the sports field",
+                ]),
                 "church": "Quiet time at the church",
+                "bar": "Having a drink at the bar",
             }.get(dest, f"Spending time at {dest}")
             needs = {
-                "park": {"fun": 0.2, "comfort": 0.1},
+                "park": {"fun": 0.2, "comfort": 0.15},
                 "cafe": {"social": 0.2, "fun": 0.1},
                 "gym": {"energy": -0.1, "fun": 0.2},
                 "library": {"fun": 0.15, "comfort": 0.1},
                 "cinema": {"fun": 0.3, "social": 0.1},
                 "town_square": {"social": 0.2, "fun": 0.15},
-                "sports_field": {"fun": 0.25, "energy": -0.1},
+                "sports_field": {"fun": 0.3, "social": 0.15, "energy": -0.1},
                 "church": {"comfort": 0.2, "purpose": 0.1},
+                "bar": {"social": 0.2, "fun": 0.15},
             }.get(dest, {"fun": 0.1})
             h, m = t // 60, t % 60
             t = self._add(h, m, act_type, dest, act_ticks, act_detail, needs)
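Several choices in routine.py bias `self._rng.choice` by simply repeating an entry (e.g. `"park", "park"` in the destination lists). The same bias can be written explicitly with the stdlib `random.choices`; the sketch below uses hypothetical names and is an illustration, not code from the commit:

```python
import random

# Duplicated-entry form, as used in the routine code above
spots_duplicated = ["gym", "park", "sports_field", "park"]

# Explicit-weight form: "park" gets weight 2, matching its two entries
spots_weighted = (["gym", "sports_field", "park"], [1, 1, 2])

def pick(rng: random.Random) -> str:
    # random.choices draws one item using the given per-item weights
    places, weights = spots_weighted
    return rng.choices(places, weights=weights, k=1)[0]
```

Duplication keeps the call sites short; explicit weights make the bias easier to audit as the lists grow.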
src/soci/api/routes.py CHANGED
@@ -310,7 +310,7 @@ async def test_llm():
310
  async def set_llm_provider(req: SwitchProviderRequest):
311
  """Hot-swap the active LLM provider."""
312
  from soci.api.server import switch_llm_provider
313
- valid = {"claude", "groq", "gemini", "hf", "ollama"}
314
  if req.provider not in valid:
315
  raise HTTPException(status_code=400, detail=f"Unknown provider '{req.provider}'")
316
  try:
 
310
  async def set_llm_provider(req: SwitchProviderRequest):
311
  """Hot-swap the active LLM provider."""
312
  from soci.api.server import switch_llm_provider
313
+ valid = {"claude", "groq", "gemini", "nn", "ollama"}
314
  if req.provider not in valid:
315
  raise HTTPException(status_code=400, detail=f"Unknown provider '{req.provider}'")
316
  try:
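The routes.py change replaces the stale `"hf"` entry with `"nn"`, matching the new neural-network provider. The guard can be exercised in isolation; `check_provider` below is a hypothetical standalone helper (the route itself raises `HTTPException` rather than `ValueError`):

```python
VALID_PROVIDERS = {"claude", "groq", "gemini", "nn", "ollama"}

def check_provider(name: str) -> str:
    # Same validation as set_llm_provider(): unknown names are rejected
    # before any provider switch is attempted.
    if name not in VALID_PROVIDERS:
        raise ValueError(f"Unknown provider '{name}'")
    return name
```

With this fix, a request for the removed `"hf"` backend now fails fast instead of attempting a switch to a provider that no longer exists.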