RayMelius Claude Opus 4.6 commited on
Commit
b596ed6
·
1 Parent(s): f64d3dc

NN training: realistic persona data, graphs, hot-reload endpoint

Browse files

- Update all 20 personas in nn_train.py and nn_selfimprove.py to match
enhanced personas.yaml (corrected names, ages, traits, occupations)
- Add persona-specific behavioral patterns: night shift schedule (George),
nightly bar routine (Frank), crush proximity (Lila→Elena), taxi wandering
(Omar), overworked doctor (Priya), morning exercise (Marcus/Yuki), etc.
- Add persona-aware needs generation, mood calculation, starting locations
- Add training graph plots (loss + accuracy curves) after training
- Add --graph flag to display graphs from last training run
- Decouple --push from training (push existing model without retraining)
- Add POST /api/nn/reload endpoint for hot-reloading ONNX model from HF Hub
- NNClient.reload() deletes cache, re-downloads, swaps ONNX session in-place
- Both scripts auto-trigger /api/nn/reload after pushing to HF Hub

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

scripts/nn_selfimprove.py CHANGED
@@ -395,6 +395,7 @@ def train(epochs: int = 20, batch_size: int = 512, lr: float = 3e-4):
395
 
396
  best_acc = 0.0
397
  MODEL_DIR.mkdir(parents=True, exist_ok=True)
 
398
 
399
  for epoch in range(epochs):
400
  model.train()
@@ -415,19 +416,32 @@ def train(epochs: int = 20, batch_size: int = 512, lr: float = 3e-4):
415
  total_loss += loss.item()
416
  n += 1
417
  scheduler.step()
 
418
 
419
  # Validate
420
  model.eval()
421
  correct = 0
422
  total = 0
 
423
  with torch.no_grad():
424
  for batch in val_loader:
425
  feat = batch["features"].to(DEVICE)
426
  out = model(feat)
 
 
 
 
 
 
427
  pred = out["action_logits"].argmax(dim=-1)
428
  correct += (pred == batch["action"].to(DEVICE)).sum().item()
429
  total += feat.shape[0]
430
  acc = correct / total if total > 0 else 0
 
 
 
 
 
431
 
432
  if acc > best_acc:
433
  best_acc = acc
@@ -436,8 +450,9 @@ def train(epochs: int = 20, batch_size: int = 512, lr: float = 3e-4):
436
  if (epoch + 1) % 5 == 0 or epoch == 0:
437
  logger.info(
438
  f"Epoch {epoch+1}/{epochs} | "
439
- f"Loss: {total_loss/n:.4f} | "
440
- f"Val Acc: {acc:.1%} | "
 
441
  f"Best: {best_acc:.1%}"
442
  )
443
 
@@ -456,7 +471,27 @@ def train(epochs: int = 20, batch_size: int = 512, lr: float = 3e-4):
456
  opset_version=17,
457
  dynamo=False,
458
  )
459
- logger.info(f"ONNX exported: {ONNX_PATH} ({ONNX_PATH.stat().st_size / 1024:.0f} KB)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
461
  return best_acc
462
 
@@ -465,8 +500,9 @@ def train(epochs: int = 20, batch_size: int = 512, lr: float = 3e-4):
465
  # STEP 3: PUSH — Upload improved model to HuggingFace Hub
466
  # ════════════════════════════════════════════════════════════════════════
467
 
468
- def push(repo_id: str = "RayMelius/soci-agent-nn", accuracy: float = None):
469
- """Push the retrained ONNX model to HuggingFace Hub."""
 
470
  from huggingface_hub import HfApi, login
471
 
472
  token = os.environ.get("HF_TOKEN", "")
@@ -535,6 +571,114 @@ def push(repo_id: str = "RayMelius/soci-agent-nn", accuracy: float = None):
535
 
536
  logger.info("Push complete!")
537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
  # ════════════════════════════════════════════════════════════════════════
540
  # Model architecture (inline to avoid import dependency)
@@ -659,23 +803,71 @@ def _build_model():
659
  # Synthetic data fallback (when not enough collected samples)
660
  # ════════════════════════════════════════════════════════════════════════
661
 
662
- # Inline personas for synthetic generation
663
  _PERSONAS = [
664
- {"O": 8, "C": 7, "E": 4, "A": 6, "N": 5, "age": 34, "home": "house_elena", "work": "office"},
665
- {"O": 10, "C": 3, "E": 6, "A": 7, "N": 7, "age": 33, "home": "house_elena", "work": "library"},
666
- {"O": 6, "C": 7, "E": 9, "A": 5, "N": 3, "age": 32, "home": "house_marcus", "work": "gym"},
667
- {"O": 7, "C": 6, "E": 3, "A": 8, "N": 4, "age": 68, "home": "house_helen", "work": "library"},
668
- {"O": 5, "C": 8, "E": 5, "A": 8, "N": 3, "age": 58, "home": "house_helen", "work": "bakery"},
669
- {"O": 9, "C": 3, "E": 8, "A": 5, "N": 5, "age": 22, "home": "house_kai", "work": "cafe"},
670
- {"O": 7, "C": 8, "E": 5, "A": 7, "N": 6, "age": 38, "home": "house_priya", "work": "hospital"},
671
- {"O": 5, "C": 7, "E": 7, "A": 9, "N": 4, "age": 62, "home": "house_rosa", "work": "restaurant"},
672
- {"O": 3, "C": 6, "E": 4, "A": 4, "N": 5, "age": 72, "home": "house_frank", "work": "bar"},
673
- {"O": 6, "C": 8, "E": 3, "A": 7, "N": 5, "age": 35, "home": "house_frank", "work": "school"},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674
  ]
675
 
676
 
 
 
 
 
 
 
 
 
677
  def _generate_synthetic(n: int) -> list[dict]:
678
- """Generate synthetic training samples (same logic as notebook)."""
679
  data = []
680
  for _ in range(n):
681
  p = random.choice(_PERSONAS)
@@ -683,52 +875,230 @@ def _generate_synthetic(n: int) -> list[dict]:
683
  "openness": p["O"], "conscientiousness": p["C"], "extraversion": p["E"],
684
  "agreeableness": p["A"], "neuroticism": p["N"],
685
  }
 
 
 
 
 
686
  hour = random.randint(0, 23)
687
  minute = random.choice([0, 15, 30, 45])
688
  day = random.randint(1, 30)
 
 
 
 
689
  needs = {}
690
  for nm in NEED_NAMES:
691
- needs[nm] = round(random.uniform(0.0, 1.0), 2)
692
- mood = round(random.uniform(-1.0, 1.0), 2)
693
- loc = random.choice(LOCATIONS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694
 
695
- # Simple rule-based label
696
  urgent = [(nm, needs[nm]) for nm in NEED_NAMES if needs[nm] < 0.15]
697
  urgent.sort(key=lambda x: x[1])
698
  action = None
699
  target = loc
700
 
 
701
  if urgent:
702
  need_name = urgent[0][0]
703
  if need_name == "hunger":
704
- action, target = "eat", random.choice(["cafe", "restaurant", "bakery"])
 
 
 
 
 
705
  elif need_name == "energy":
706
  action, target = "sleep", p["home"]
707
  elif need_name == "social":
708
- action, target = "talk", random.choice(["cafe", "bar", "park"])
 
 
 
 
 
709
  elif need_name == "purpose":
710
  action, target = "work", p["work"]
711
  elif need_name == "comfort":
712
- action, target = "relax", p["home"]
713
  elif need_name == "fun":
714
- action, target = "relax", random.choice(["park", "cinema"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
716
  if action is None:
717
- period = _time_period(hour)
718
  if period == 0:
719
  action, target = "sleep", p["home"]
 
 
 
 
 
 
 
720
  elif period in (2, 4):
721
- action, target = "work", p["work"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722
  elif period == 3:
723
- action, target = "eat", random.choice(["cafe", "restaurant"])
 
 
 
724
  elif period == 5:
725
- action = random.choice(["talk", "eat", "relax"])
726
- target = random.choice(["bar", "restaurant", "park", p["home"]])
 
 
 
 
 
 
 
 
 
 
727
  elif period == 6:
728
- action, target = "sleep", p["home"]
729
- else:
730
- action = random.choice(["eat", "exercise", "move"])
731
- target = random.choice(["cafe", "gym", p["work"]])
 
 
 
 
 
 
 
 
 
732
 
733
  features = encode_features(
734
  personality=persona, age=p["age"],
@@ -741,7 +1111,7 @@ def _generate_synthetic(n: int) -> list[dict]:
741
  "features": features,
742
  "action_idx": ACTION_TO_IDX.get(action, 0),
743
  "target_loc_idx": LOC_TO_IDX.get(target, 0),
744
- "duration": ACTION_DURATIONS.get(action, 2),
745
  })
746
 
747
  return data
@@ -928,7 +1298,7 @@ async def scheduled(
928
  # 7. Push improved model
929
  if os.environ.get("HF_TOKEN"):
930
  logger.info("Pushing improved model to HF Hub...")
931
- push(repo_id=repo_id, accuracy=best_acc)
932
  else:
933
  logger.warning("HF_TOKEN not set — skipping push")
934
 
@@ -1017,10 +1387,11 @@ async def budget(
1017
 
1018
  def main():
1019
  parser = argparse.ArgumentParser(description="Soci Agent NN — Self-Improvement Pipeline")
1020
- parser.add_argument("mode", choices=["collect", "train", "push", "all", "scheduled", "budget"],
1021
  help="collect=watch live sim, train=retrain NN, push=upload to HF, "
1022
  "all=full pipeline, scheduled=daily Gemini cycle, "
1023
- "budget=check quota & set probability for target duration")
 
1024
  parser.add_argument("--url", default="https://raymelius-soci2.hf.space",
1025
  help="Live simulation URL (default: HF Space)")
1026
  parser.add_argument("--minutes", type=int, default=60,
@@ -1035,6 +1406,10 @@ def main():
1035
  help="HF Hub repo ID")
1036
  args = parser.parse_args()
1037
 
 
 
 
 
1038
  if args.mode in ("collect", "all"):
1039
  asyncio.run(collect(base_url=args.url, duration_minutes=args.minutes))
1040
 
@@ -1043,7 +1418,7 @@ def main():
1043
 
1044
  if args.mode in ("push", "all"):
1045
  acc = best_acc if args.mode == "all" else None
1046
- push(repo_id=args.repo, accuracy=acc)
1047
 
1048
  if args.mode == "scheduled":
1049
  asyncio.run(scheduled(
 
395
 
396
  best_acc = 0.0
397
  MODEL_DIR.mkdir(parents=True, exist_ok=True)
398
+ history = {"train_loss": [], "val_loss": [], "val_action_acc": []}
399
 
400
  for epoch in range(epochs):
401
  model.train()
 
416
  total_loss += loss.item()
417
  n += 1
418
  scheduler.step()
419
+ avg_train_loss = total_loss / n
420
 
421
  # Validate
422
  model.eval()
423
  correct = 0
424
  total = 0
425
+ val_loss = 0.0
426
  with torch.no_grad():
427
  for batch in val_loader:
428
  feat = batch["features"].to(DEVICE)
429
  out = model(feat)
430
+ loss = (
431
+ 1.0 * action_loss_fn(out["action_logits"], batch["action"].to(DEVICE))
432
+ + 0.5 * location_loss_fn(out["location_logits"], batch["location"].to(DEVICE))
433
+ + 0.2 * duration_loss_fn(out["duration"], batch["duration"].to(DEVICE))
434
+ )
435
+ val_loss += loss.item()
436
  pred = out["action_logits"].argmax(dim=-1)
437
  correct += (pred == batch["action"].to(DEVICE)).sum().item()
438
  total += feat.shape[0]
439
  acc = correct / total if total > 0 else 0
440
+ avg_val_loss = val_loss / len(val_loader)
441
+
442
+ history["train_loss"].append(avg_train_loss)
443
+ history["val_loss"].append(avg_val_loss)
444
+ history["val_action_acc"].append(acc)
445
 
446
  if acc > best_acc:
447
  best_acc = acc
 
450
  if (epoch + 1) % 5 == 0 or epoch == 0:
451
  logger.info(
452
  f"Epoch {epoch+1}/{epochs} | "
453
+ f"Train: {avg_train_loss:.4f} | "
454
+ f"Val: {avg_val_loss:.4f} | "
455
+ f"Acc: {acc:.1%} | "
456
  f"Best: {best_acc:.1%}"
457
  )
458
 
 
471
  opset_version=17,
472
  dynamo=False,
473
  )
474
+ onnx_size = ONNX_PATH.stat().st_size / 1024
475
+ logger.info(f"ONNX exported: {ONNX_PATH} ({onnx_size:.0f} KB)")
476
+
477
+ # ── Save training stats ───────────────────────────────────────────
478
+ stats = {
479
+ "best_val_action_acc": best_acc,
480
+ "epochs": epochs,
481
+ "train_samples": len(train_ds),
482
+ "val_samples": len(val_ds),
483
+ "collected_samples": sum(source_counts.values()),
484
+ "source_counts": source_counts,
485
+ "model_size_kb": onnx_size,
486
+ "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
487
+ "history": history,
488
+ }
489
+ stats_path = MODEL_DIR / "training_stats.json"
490
+ stats_path.write_text(json.dumps(stats, indent=2))
491
+ logger.info(f"Stats saved to {stats_path}")
492
+
493
+ # ── Plot training graphs ──────────────────────────────────────────
494
+ plot_training_graphs(stats_path)
495
 
496
  return best_acc
497
 
 
500
  # STEP 3: PUSH — Upload improved model to HuggingFace Hub
501
  # ════════════════════════════════════════════════════════════════════════
502
 
503
+ def push(repo_id: str = "RayMelius/soci-agent-nn", accuracy: float = None,
504
+ base_url: str = "https://raymelius-soci2.hf.space"):
505
+ """Push the retrained ONNX model to HuggingFace Hub, then trigger live reload."""
506
  from huggingface_hub import HfApi, login
507
 
508
  token = os.environ.get("HF_TOKEN", "")
 
571
 
572
  logger.info("Push complete!")
573
 
574
+ # Trigger hot-reload on the live simulation if reachable
575
+ try:
576
+ resp = httpx.post(f"{base_url}/api/nn/reload", timeout=30.0)
577
+ if resp.status_code == 200:
578
+ logger.info(f"Live sim NN reloaded: {resp.json().get('message', 'ok')}")
579
+ else:
580
+ logger.warning(f"Could not reload live sim NN: HTTP {resp.status_code}")
581
+ except Exception as e:
582
+ logger.warning(f"Could not reach live sim for reload: {e}")
583
+
584
+
585
+ # ════════════════════════════════════════════════════════════════════════
586
+ # Training Graphs
587
+ # ════════════════════════════════════════════════════════════════════════
588
+
589
+ def plot_training_graphs(stats_path: Path | str | None = None):
590
+ """Plot training loss and accuracy curves from saved training stats.
591
+
592
+ Saves the plot to models/training_graphs.png and displays it.
593
+ """
594
+ import matplotlib
595
+ matplotlib.use("Agg")
596
+ import matplotlib.pyplot as plt
597
+
598
+ stats_path = Path(stats_path) if stats_path else MODEL_DIR / "training_stats.json"
599
+ if not stats_path.exists():
600
+ logger.error(f"No training stats found at {stats_path}")
601
+ return
602
+
603
+ stats = json.loads(stats_path.read_text())
604
+ history = stats.get("history", {})
605
+
606
+ train_loss = history.get("train_loss", [])
607
+ val_loss = history.get("val_loss", [])
608
+ val_action_acc = history.get("val_action_acc", [])
609
+
610
+ if not train_loss:
611
+ logger.error("No training history found in stats file")
612
+ return
613
+
614
+ epochs_range = list(range(1, len(train_loss) + 1))
615
+
616
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
617
+ fig.suptitle(
618
+ f"Soci Self-Improve Training — {stats.get('timestamp', '?')} | "
619
+ f"Best Acc: {stats.get('best_val_action_acc', stats.get('best_accuracy', 0)):.1%}",
620
+ fontsize=13, fontweight="bold",
621
+ )
622
+
623
+ # Loss curves
624
+ ax = axes[0]
625
+ ax.plot(epochs_range, train_loss, label="Train Loss", color="#2196F3", linewidth=2)
626
+ if val_loss:
627
+ ax.plot(epochs_range, val_loss, label="Val Loss", color="#F44336", linewidth=2)
628
+ ax.set_xlabel("Epoch")
629
+ ax.set_ylabel("Loss")
630
+ ax.set_title("Training & Validation Loss")
631
+ ax.legend()
632
+ ax.grid(True, alpha=0.3)
633
+ ax.set_xlim(1, len(train_loss))
634
+
635
+ # Action accuracy
636
+ ax = axes[1]
637
+ if val_action_acc:
638
+ ax.plot(epochs_range, [a * 100 for a in val_action_acc], label="Action Accuracy",
639
+ color="#4CAF50", linewidth=2)
640
+ best_epoch = int(np.argmax(val_action_acc)) + 1
641
+ best_acc = max(val_action_acc) * 100
642
+ ax.axhline(y=best_acc, color="#4CAF50", linestyle="--", alpha=0.4)
643
+ ax.annotate(f"Best: {best_acc:.1f}% (epoch {best_epoch})",
644
+ xy=(best_epoch, best_acc), fontsize=9,
645
+ xytext=(best_epoch + 1, best_acc - 3),
646
+ arrowprops=dict(arrowstyle="->", color="#4CAF50"),
647
+ color="#4CAF50")
648
+ ax.set_xlabel("Epoch")
649
+ ax.set_ylabel("Accuracy (%)")
650
+ ax.set_title("Action Prediction Accuracy")
651
+ ax.legend()
652
+ ax.grid(True, alpha=0.3)
653
+ ax.set_xlim(1, len(train_loss))
654
+
655
+ # Footer
656
+ footer = (
657
+ f"Train: {stats.get('train_samples', '?'):,} samples | "
658
+ f"Val: {stats.get('val_samples', '?'):,} samples | "
659
+ f"Collected: {stats.get('collected_samples', 0):,} | "
660
+ f"Model: {stats.get('model_size_kb', 0):.0f} KB"
661
+ )
662
+ fig.text(0.5, 0.01, footer, ha="center", fontsize=9, color="gray")
663
+
664
+ plt.tight_layout(rect=[0, 0.03, 1, 0.95])
665
+
666
+ graph_path = MODEL_DIR / "training_graphs.png"
667
+ fig.savefig(str(graph_path), dpi=150, bbox_inches="tight")
668
+ logger.info(f"Training graphs saved to {graph_path}")
669
+
670
+ try:
671
+ import warnings
672
+ with warnings.catch_warnings():
673
+ warnings.simplefilter("ignore")
674
+ matplotlib.use("TkAgg")
675
+ plt.show(block=False)
676
+ plt.pause(0.5)
677
+ except Exception:
678
+ pass
679
+
680
+ plt.close(fig)
681
+
682
 
683
  # ════════════════════════════════════════════════════════════════════════
684
  # Model architecture (inline to avoid import dependency)
 
803
  # Synthetic data fallback (when not enough collected samples)
804
  # ════════════════════════════════════════════════════════════════════════
805
 
806
+ # Inline personas for synthetic generation — must match personas.yaml
807
  _PERSONAS = [
808
+ # House 1 Elena & Lila (roommates)
809
+ {"id": "elena", "O": 8, "C": 7, "E": 4, "A": 6, "N": 5, "age": 34, "home": "house_elena", "work": "office",
810
+ "tags": ["freelance", "introvert", "tech"], "hangouts": ["cafe", "library"]},
811
+ {"id": "lila", "O":10, "C": 3, "E": 6, "A": 7, "N": 7, "age": 33, "home": "house_elena", "work": "library",
812
+ "tags": ["creative", "emotional", "crush_elena"], "hangouts": ["park", "cafe", "library"]},
813
+ # House 2 Marcus & Zoe (siblings)
814
+ {"id": "marcus", "O": 5, "C": 8, "E": 9, "A": 7, "N": 3, "age": 28, "home": "house_marcus", "work": "gym",
815
+ "tags": ["athletic", "extrovert", "community"], "hangouts": ["park", "sports_field", "cafe"]},
816
+ {"id": "zoe", "O": 8, "C": 4, "E": 8, "A": 6, "N": 7, "age": 19, "home": "house_marcus", "work": "library",
817
+ "tags": ["student", "social_media", "young"], "hangouts": ["cafe", "cinema", "park", "town_square"]},
818
+ # House 3 — Helen & Alice (close friends)
819
+ {"id": "helen", "O": 6, "C": 8, "E": 6, "A": 8, "N": 4, "age": 67, "home": "house_helen", "work": "library",
820
+ "tags": ["retired", "bookworm", "widow"], "hangouts": ["library", "park", "bakery", "church"]},
821
+ {"id": "alice", "O": 5, "C": 8, "E": 6, "A": 8, "N": 3, "age": 58, "home": "house_helen", "work": "bakery",
822
+ "tags": ["retired", "baker", "nurturing"], "hangouts": ["bakery", "grocery", "church"]},
823
+ # House 4 — Diana & Marco (mother & son)
824
+ {"id": "diana", "O": 4, "C": 9, "E": 5, "A": 6, "N": 7, "age": 41, "home": "house_diana", "work": "grocery",
825
+ "tags": ["business_owner", "single_mother", "protective"], "hangouts": ["grocery"]},
826
+ {"id": "marco", "O": 7, "C": 4, "E": 6, "A": 5, "N": 6, "age": 16, "home": "house_diana", "work": "school",
827
+ "tags": ["student", "teen", "gamer"], "hangouts": ["park", "cinema", "cafe", "sports_field"]},
828
+ # House 5 — Kai (lives alone)
829
+ {"id": "kai", "O": 9, "C": 3, "E": 7, "A": 5, "N": 6, "age": 22, "home": "house_kai", "work": "cafe",
830
+ "tags": ["musician", "creative", "dropout"], "hangouts": ["bar", "park", "town_square"]},
831
+ # House 6 — Priya & Nina (flatmates)
832
+ {"id": "priya", "O": 7, "C": 9, "E": 5, "A": 8, "N": 6, "age": 38, "home": "house_priya", "work": "hospital",
833
+ "tags": ["overworked", "caring", "guilt"], "hangouts": ["hospital", "pharmacy"]},
834
+ {"id": "nina", "O": 5, "C": 8, "E": 9, "A": 4, "N": 5, "age": 29, "home": "house_priya", "work": "office",
835
+ "tags": ["ambitious", "networker", "suspicious"], "hangouts": ["cafe", "restaurant", "office_tower"]},
836
+ # House 7 — James & Theo (housemates)
837
+ {"id": "james", "O": 5, "C": 6, "E": 8, "A": 7, "N": 4, "age": 55, "home": "house_james", "work": "bar",
838
+ "tags": ["social_hub", "divorced", "storyteller"], "hangouts": ["bar"]},
839
+ {"id": "theo", "O": 3, "C": 7, "E": 4, "A": 5, "N": 5, "age": 45, "home": "house_james", "work": "factory",
840
+ "tags": ["blue_collar", "stoic", "handy"], "hangouts": ["bar", "diner"]},
841
+ # House 8 — Rosa & Omar
842
+ {"id": "rosa", "O": 6, "C": 9, "E": 7, "A": 8, "N": 5, "age": 62, "home": "house_rosa", "work": "restaurant",
843
+ "tags": ["nurturing", "italian", "community_mother"], "hangouts": ["restaurant", "grocery"]},
844
+ {"id": "omar", "O": 6, "C": 6, "E": 7, "A": 7, "N": 4, "age": 50, "home": "house_rosa", "work": "restaurant",
845
+ "tags": ["immigrant", "philosophical", "hardworking"], "hangouts": ["restaurant", "cafe", "park"]},
846
+ # House 9 — Yuki & Devon (flatmates)
847
+ {"id": "yuki", "O": 8, "C": 6, "E": 5, "A": 9, "N": 3, "age": 26, "home": "house_yuki", "work": "gym",
848
+ "tags": ["mindful", "calm", "empathetic"], "hangouts": ["park", "gym", "library"]},
849
+ {"id": "devon", "O": 9, "C": 5, "E": 6, "A": 4, "N": 6, "age": 30, "home": "house_yuki", "work": "office",
850
+ "tags": ["investigative", "paranoid", "curious"], "hangouts": ["cafe", "bar", "library", "town_square"]},
851
+ # House 10 — Frank, George & Sam
852
+ {"id": "frank", "O": 3, "C": 7, "E": 5, "A": 4, "N": 5, "age": 72, "home": "house_frank", "work": "bar",
853
+ "tags": ["retired", "cantankerous", "creature_of_habit"], "hangouts": ["bar", "diner"]},
854
+ {"id": "george", "O": 4, "C": 7, "E": 3, "A": 6, "N": 4, "age": 47, "home": "house_frank", "work": "factory",
855
+ "tags": ["night_shift", "widower", "observant"], "hangouts": ["park"]},
856
+ {"id": "sam", "O": 7, "C": 8, "E": 3, "A": 7, "N": 4, "age": 40, "home": "house_frank", "work": "library",
857
+ "tags": ["quiet", "bookish", "inclusive"], "hangouts": ["library", "park", "cafe"]},
858
  ]
859
 
860
 
861
+ def _persona_hangout(p: dict, fallbacks: list[str]) -> str:
862
+ """Pick a location the persona naturally gravitates toward."""
863
+ hangouts = p.get("hangouts", [])
864
+ if hangouts and random.random() < 0.6:
865
+ return random.choice(hangouts)
866
+ return random.choice(fallbacks)
867
+
868
+
869
  def _generate_synthetic(n: int) -> list[dict]:
870
+ """Generate persona-aware synthetic training samples."""
871
  data = []
872
  for _ in range(n):
873
  p = random.choice(_PERSONAS)
 
875
  "openness": p["O"], "conscientiousness": p["C"], "extraversion": p["E"],
876
  "agreeableness": p["A"], "neuroticism": p["N"],
877
  }
878
+ tags = p.get("tags", [])
879
+ is_night_shift = "night_shift" in tags
880
+ is_retired = "retired" in tags
881
+ is_student = "student" in tags
882
+
883
  hour = random.randint(0, 23)
884
  minute = random.choice([0, 15, 30, 45])
885
  day = random.randint(1, 30)
886
+ is_weekend = ((day - 1) % 7) >= 5
887
+ period = _time_period(hour)
888
+
889
+ # Persona-aware needs generation
890
  needs = {}
891
  for nm in NEED_NAMES:
892
+ if random.random() < 0.15:
893
+ needs[nm] = round(random.uniform(0.0, 0.2), 2)
894
+ else:
895
+ needs[nm] = round(random.uniform(0.2, 1.0), 2)
896
+
897
+ if "overworked" in tags:
898
+ needs["energy"] = round(min(needs["energy"], random.uniform(0.1, 0.5)), 2)
899
+ needs["social"] = round(min(needs["social"], random.uniform(0.1, 0.5)), 2)
900
+ if "athletic" in tags:
901
+ needs["energy"] = round(max(needs["energy"], random.uniform(0.5, 0.9)), 2)
902
+ if "emotional" in tags:
903
+ swing = random.choice(NEED_NAMES)
904
+ needs[swing] = round(random.uniform(0.0, 0.3), 2)
905
+ if "creature_of_habit" in tags:
906
+ for nm in NEED_NAMES:
907
+ needs[nm] = round(needs[nm] * 0.7 + 0.2, 2)
908
+ if is_night_shift and 6 <= hour <= 18:
909
+ needs["energy"] = round(min(needs["energy"], random.uniform(0.05, 0.35)), 2)
910
+ if "mindful" in tags:
911
+ for nm in NEED_NAMES:
912
+ needs[nm] = round(max(needs[nm], 0.2), 2)
913
+ if is_student:
914
+ needs["social"] = round(max(needs["social"], random.uniform(0.3, 0.7)), 2)
915
+
916
+ # Persona-aware mood
917
+ avg_need = sum(needs.values()) / len(needs)
918
+ mood = round(max(-1.0, min(1.0,
919
+ (avg_need - 0.5) * 2 + random.uniform(-0.5, 0.5) * (p["N"] / 10.0)
920
+ )), 2)
921
+
922
+ # Persona-aware starting location
923
+ if is_night_shift:
924
+ if period in (0, 6):
925
+ loc = p["work"]
926
+ elif period in (2, 3):
927
+ loc = p["home"]
928
+ else:
929
+ loc = random.choice([p["home"], "park"] if random.random() < 0.7 else [p["home"]])
930
+ elif period == 0:
931
+ loc = p["home"]
932
+ elif period in (2, 4) and not is_weekend:
933
+ if is_retired:
934
+ loc = random.choice([p["home"]] + p.get("hangouts", ["park"]))
935
+ else:
936
+ loc = random.choice([p["work"], p["work"], _persona_hangout(p, ["cafe"])])
937
+ elif period == 5:
938
+ loc = random.choice([p["home"], _persona_hangout(p, ["bar", "cafe"])])
939
+ else:
940
+ loc = random.choice([p["home"], p["work"]])
941
 
942
+ # --- Determine action ---
943
  urgent = [(nm, needs[nm]) for nm in NEED_NAMES if needs[nm] < 0.15]
944
  urgent.sort(key=lambda x: x[1])
945
  action = None
946
  target = loc
947
 
948
+ # Priority 1: Critical needs
949
  if urgent:
950
  need_name = urgent[0][0]
951
  if need_name == "hunger":
952
+ eat_locs = ["cafe", "restaurant", "bakery", "diner", p["home"]]
953
+ if "community_mother" in tags:
954
+ eat_locs = ["restaurant", p["home"]]
955
+ elif "baker" in tags:
956
+ eat_locs = ["bakery", p["home"]]
957
+ action, target = "eat", random.choice(eat_locs)
958
  elif need_name == "energy":
959
  action, target = "sleep", p["home"]
960
  elif need_name == "social":
961
+ social_locs = ["cafe", "bar", "park", "town_square"]
962
+ if "social_hub" in tags:
963
+ social_locs = ["bar", "bar", "restaurant"]
964
+ elif "networker" in tags:
965
+ social_locs = ["cafe", "restaurant", "office"]
966
+ action, target = "talk", random.choice(social_locs)
967
  elif need_name == "purpose":
968
  action, target = "work", p["work"]
969
  elif need_name == "comfort":
970
+ action, target = "relax", random.choice([p["home"], "park", "library"])
971
  elif need_name == "fun":
972
+ fun_locs = ["park", "cinema", "bar", "sports_field"]
973
+ if is_student:
974
+ fun_locs = ["cinema", "park", "cafe", "town_square"]
975
+ action, target = random.choice(["relax", "exercise", "wander"]), random.choice(fun_locs)
976
+
977
+ # Priority 2: Night shift inverted schedule (George)
978
+ if action is None and is_night_shift:
979
+ if period in (0, 6):
980
+ action, target = "work", p["work"]
981
+ elif period == 1:
982
+ action, target = "move", p["home"]
983
+ elif period in (2, 3):
984
+ if needs["energy"] < 0.6:
985
+ action, target = "sleep", p["home"]
986
+ else:
987
+ action, target = "relax", random.choice([p["home"], "park"])
988
+ elif period in (4, 5):
989
+ if needs["hunger"] < 0.5:
990
+ action, target = "eat", random.choice(["diner", "restaurant", p["home"]])
991
+ else:
992
+ action, target = "move", p["work"]
993
 
994
+ # Priority 3: Persona-specific patterns
995
+ if action is None:
996
+ pid = p.get("id", "")
997
+ if pid == "frank" and period in (5, 6) and random.random() < 0.7:
998
+ action, target = "relax", "bar"
999
+ elif pid == "lila" and random.random() < 0.15:
1000
+ action = random.choice(["wander", "talk", "relax"])
1001
+ target = random.choice(["house_elena", "cafe", "library"])
1002
+ elif pid == "rosa" and period in (1, 2) and random.random() < 0.4:
1003
+ action, target = "shop", "grocery"
1004
+ elif pid == "omar" and period in (2, 3, 4) and not is_weekend and random.random() < 0.5:
1005
+ action, target = "wander", random.choice(["street_north", "street_south", "street_east", "street_west"])
1006
+ elif pid == "diana" and not is_weekend and period in (2, 3, 4) and random.random() < 0.7:
1007
+ action, target = "work", "grocery"
1008
+ elif pid == "marcus" and period == 1 and random.random() < 0.6:
1009
+ action, target = "exercise", random.choice(["gym", "park", "sports_field"])
1010
+ elif pid == "yuki" and period == 1 and random.random() < 0.5:
1011
+ action, target = "exercise", random.choice(["park", "gym"])
1012
+ elif pid == "devon" and period in (2, 4) and random.random() < 0.3:
1013
+ action = random.choice(["wander", "talk"])
1014
+ target = random.choice(["cafe", "bar", "town_square", "library"])
1015
+
1016
+ # Priority 4: General time-of-day patterns
1017
  if action is None:
 
1018
  if period == 0:
1019
  action, target = "sleep", p["home"]
1020
+ elif period == 1:
1021
+ if needs["hunger"] < 0.5:
1022
+ action, target = "eat", random.choice(["cafe", "bakery", p["home"]])
1023
+ elif p["E"] >= 6 and random.random() < 0.3:
1024
+ action, target = "exercise", random.choice(["gym", "park", "sports_field"])
1025
+ else:
1026
+ action, target = "move", p["work"]
1027
  elif period in (2, 4):
1028
+ if is_weekend:
1029
+ r = random.random()
1030
+ if is_retired:
1031
+ if r < 0.35:
1032
+ action, target = "relax", _persona_hangout(p, ["park", "library", p["home"]])
1033
+ elif r < 0.55:
1034
+ action, target = "talk", _persona_hangout(p, ["cafe", "park", "church"])
1035
+ elif r < 0.7:
1036
+ action, target = "shop", random.choice(["grocery", "pharmacy", "bakery"])
1037
+ else:
1038
+ action, target = "wander", random.choice(["park", "town_square"])
1039
+ elif is_student:
1040
+ if r < 0.3:
1041
+ action, target = "talk", random.choice(["cafe", "park", "cinema", "town_square"])
1042
+ elif r < 0.5:
1043
+ action, target = "relax", random.choice(["cinema", "park", p["home"]])
1044
+ elif r < 0.7:
1045
+ action, target = "exercise", random.choice(["gym", "park", "sports_field"])
1046
+ else:
1047
+ action, target = "wander", random.choice(["town_square", "street_north"])
1048
+ else:
1049
+ if r < 0.25:
1050
+ action, target = "relax", _persona_hangout(p, ["park", "cafe", p["home"]])
1051
+ elif r < 0.45 and p["E"] >= 6:
1052
+ action, target = "talk", _persona_hangout(p, ["cafe", "park", "town_square"])
1053
+ elif r < 0.6:
1054
+ action, target = "shop", random.choice(["grocery", "pharmacy"])
1055
+ elif r < 0.8:
1056
+ action, target = "exercise", random.choice(["gym", "park"])
1057
+ else:
1058
+ action, target = "wander", random.choice(["park", "town_square"])
1059
+ else:
1060
+ work_prob = 0.5 + p["C"] * 0.05
1061
+ if "business_owner" in tags or "overworked" in tags:
1062
+ work_prob += 0.15
1063
+ if is_retired:
1064
+ work_prob = 0.15
1065
+ if random.random() < work_prob:
1066
+ action, target = "work", p["work"]
1067
+ else:
1068
+ action = random.choice(["wander", "relax", "talk"])
1069
+ target = _persona_hangout(p, ["cafe", "park", "town_square"])
1070
  elif period == 3:
1071
+ if needs["hunger"] < 0.6:
1072
+ action, target = "eat", random.choice(["cafe", "restaurant", "bakery", "diner"])
1073
+ else:
1074
+ action, target = "relax", random.choice(["park", "cafe"])
1075
  elif period == 5:
1076
+ social_bias = p["E"] / 10.0
1077
+ r = random.random()
1078
+ if r < social_bias * 0.5:
1079
+ action, target = "talk", random.choice(["bar", "restaurant", "park", "cafe"])
1080
+ elif r < 0.4:
1081
+ action, target = "eat", random.choice(["restaurant", "bar", "diner", p["home"]])
1082
+ elif r < 0.55:
1083
+ action, target = "exercise", random.choice(["gym", "park"])
1084
+ elif r < 0.7:
1085
+ action, target = "relax", _persona_hangout(p, ["cinema", "bar", p["home"]])
1086
+ else:
1087
+ action, target = "relax", p["home"]
1088
  elif period == 6:
1089
+ if needs["energy"] < 0.4:
1090
+ action, target = "sleep", p["home"]
1091
+ else:
1092
+ action, target = "relax", p["home"]
1093
+
1094
+ # Move override
1095
+ if target != loc and action != "move" and random.random() < 0.3:
1096
+ action = "move"
1097
+
1098
+ # Duration adjustments
1099
+ dur = ACTION_DURATIONS.get(action, 2)
1100
+ if is_retired and dur > 3 and action not in ("sleep", "work"):
1101
+ dur = min(dur, 3)
1102
 
1103
  features = encode_features(
1104
  personality=persona, age=p["age"],
 
1111
  "features": features,
1112
  "action_idx": ACTION_TO_IDX.get(action, 0),
1113
  "target_loc_idx": LOC_TO_IDX.get(target, 0),
1114
+ "duration": min(max(dur, 1), 8),
1115
  })
1116
 
1117
  return data
 
1298
  # 7. Push improved model
1299
  if os.environ.get("HF_TOKEN"):
1300
  logger.info("Pushing improved model to HF Hub...")
1301
+ push(repo_id=repo_id, accuracy=best_acc, base_url=base_url)
1302
  else:
1303
  logger.warning("HF_TOKEN not set — skipping push")
1304
 
 
1387
 
1388
  def main():
1389
  parser = argparse.ArgumentParser(description="Soci Agent NN — Self-Improvement Pipeline")
1390
+ parser.add_argument("mode", choices=["collect", "train", "push", "all", "scheduled", "budget", "graph"],
1391
  help="collect=watch live sim, train=retrain NN, push=upload to HF, "
1392
  "all=full pipeline, scheduled=daily Gemini cycle, "
1393
+ "budget=check quota & set probability, "
1394
+ "graph=display training graphs from last run")
1395
  parser.add_argument("--url", default="https://raymelius-soci2.hf.space",
1396
  help="Live simulation URL (default: HF Space)")
1397
  parser.add_argument("--minutes", type=int, default=60,
 
1406
  help="HF Hub repo ID")
1407
  args = parser.parse_args()
1408
 
1409
+ if args.mode == "graph":
1410
+ plot_training_graphs()
1411
+ return
1412
+
1413
  if args.mode in ("collect", "all"):
1414
  asyncio.run(collect(base_url=args.url, duration_minutes=args.minutes))
1415
 
 
1418
 
1419
  if args.mode in ("push", "all"):
1420
  acc = best_acc if args.mode == "all" else None
1421
+ push(repo_id=args.repo, accuracy=acc, base_url=args.url)
1422
 
1423
  if args.mode == "scheduled":
1424
  asyncio.run(scheduled(
scripts/nn_train.py CHANGED
@@ -105,26 +105,135 @@ FEATURE_DIM = 47
105
  # ══════════════════════════════════════════════════════════════════════════
106
 
107
  PERSONAS = [
108
- {"id": "elena", "name": "Elena Vasquez", "age": 34, "occ": "software engineer", "O": 8, "C": 7, "E": 4, "A": 6, "N": 5, "home": "house_elena", "work": "office"},
109
- {"id": "lila", "name": "Lila Santos", "age": 33, "occ": "artist", "O":10, "C": 3, "E": 6, "A": 7, "N": 7, "home": "house_elena", "work": "library"},
110
- {"id": "marcus", "name": "Marcus Chen-Williams", "age": 32, "occ": "personal trainer", "O": 6, "C": 7, "E": 9, "A": 5, "N": 3, "home": "house_marcus", "work": "gym"},
111
- {"id": "zoe", "name": "Zoe Chen-Williams", "age": 19, "occ": "college student", "O": 8, "C": 4, "E": 8, "A": 6, "N": 7, "home": "house_marcus", "work": "library"},
112
- {"id": "helen", "name": "Helen Park", "age": 68, "occ": "retired librarian", "O": 7, "C": 6, "E": 3, "A": 8, "N": 4, "home": "house_helen", "work": "library"},
113
- {"id": "alice", "name": "Alice Fontaine", "age": 58, "occ": "retired accountant", "O": 5, "C": 8, "E": 5, "A": 8, "N": 3, "home": "house_helen", "work": "bakery"},
114
- {"id": "diana", "name": "Diana Delgado", "age": 42, "occ": "grocery store owner", "O": 4, "C": 8, "E": 5, "A": 6, "N": 4, "home": "house_diana", "work": "grocery"},
115
- {"id": "marco", "name": "Marco Delgado", "age": 16, "occ": "high school student", "O": 9, "C": 4, "E": 6, "A": 4, "N": 6, "home": "house_diana", "work": "school"},
116
- {"id": "kai", "name": "Kai Okonkwo", "age": 22, "occ": "barista", "O": 9, "C": 3, "E": 8, "A": 5, "N": 5, "home": "house_kai", "work": "cafe"},
117
- {"id": "priya", "name": "Priya Sharma", "age": 38, "occ": "doctor", "O": 7, "C": 8, "E": 5, "A": 7, "N": 6, "home": "house_priya", "work": "hospital"},
118
- {"id": "nina", "name": "Nina Volkov", "age": 29, "occ": "real estate agent", "O": 5, "C": 7, "E": 8, "A": 5, "N": 5, "home": "house_priya", "work": "office"},
119
- {"id": "james", "name": "James O'Brien", "age": 40, "occ": "bar owner", "O": 6, "C": 5, "E": 7, "A": 6, "N": 4, "home": "house_james", "work": "bar"},
120
- {"id": "theo", "name": "Theo Blackwood", "age": 45, "occ": "construction worker", "O": 3, "C": 8, "E": 4, "A": 5, "N": 5, "home": "house_james", "work": "factory"},
121
- {"id": "rosa", "name": "Rosa Martelli", "age": 62, "occ": "restaurant owner", "O": 5, "C": 7, "E": 7, "A": 9, "N": 4, "home": "house_rosa", "work": "restaurant"},
122
- {"id": "omar", "name": "Omar Hassan", "age": 50, "occ": "taxi driver", "O": 6, "C": 6, "E": 7, "A": 7, "N": 4, "home": "house_rosa", "work": "restaurant"},
123
- {"id": "yuki", "name": "Yuki Tanaka", "age": 26, "occ": "yoga instructor", "O": 8, "C": 6, "E": 5, "A": 9, "N": 3, "home": "house_yuki", "work": "gym"},
124
- {"id": "devon", "name": "Devon Reeves", "age": 30, "occ": "freelance journalist", "O": 9, "C": 5, "E": 6, "A": 5, "N": 6, "home": "house_yuki", "work": "office"},
125
- {"id": "frank", "name": "Frank Kowalski", "age": 72, "occ": "retired mechanic", "O": 3, "C": 6, "E": 4, "A": 4, "N": 5, "home": "house_frank", "work": "bar"},
126
- {"id": "george", "name": "George Adeyemi", "age": 47, "occ": "night shift security", "O": 5, "C": 7, "E": 3, "A": 6, "N": 4, "home": "house_frank", "work": "factory"},
127
- {"id": "sam", "name": "Sam Torres", "age": 35, "occ": "elementary school teacher", "O": 6, "C": 8, "E": 3, "A": 7, "N": 5, "home": "house_frank", "work": "school"},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  ]
129
 
130
 
@@ -217,23 +326,155 @@ def encode_features(
217
  # 4. Synthetic Data Generator
218
  # ══════════════════════════════════════════════════════════════════════════
219
 
220
- def generate_action_example(persona: dict) -> dict:
221
- """Generate one training example with rule-based labels."""
222
- hour = random.randint(0, 23)
223
- minute = random.choice([0, 15, 30, 45])
224
- day = random.randint(1, 30)
225
- is_weekend = ((day - 1) % 7) >= 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- # Random needs (15% chance of critical)
 
 
228
  needs = {}
 
 
 
229
  for n in NEED_NAMES:
 
230
  if random.random() < 0.15:
231
  needs[n] = round(random.uniform(0.0, 0.2), 2)
232
  else:
233
  needs[n] = round(random.uniform(0.2, 1.0), 2)
234
 
235
- mood = round(random.uniform(-1.0, 1.0), 2)
236
- current_loc = random.choice(LOCATIONS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  # --- Determine action using rule-based logic ---
239
  # Priority 1: Critical needs
@@ -248,7 +489,13 @@ def generate_action_example(persona: dict) -> dict:
248
  need_name = urgent[0][0]
249
  if need_name == "hunger":
250
  action = "eat"
251
- target_loc = random.choice(["cafe", "restaurant", "grocery", "bakery", "diner", persona["home"]])
 
 
 
 
 
 
252
  duration = 2
253
  elif need_name == "energy":
254
  action = "sleep"
@@ -256,7 +503,12 @@ def generate_action_example(persona: dict) -> dict:
256
  duration = random.choice([4, 6, 8])
257
  elif need_name == "social":
258
  action = "talk"
259
- target_loc = random.choice(["cafe", "bar", "park", "town_square", current_loc])
 
 
 
 
 
260
  duration = 2
261
  elif need_name == "purpose":
262
  action = "work"
@@ -268,10 +520,117 @@ def generate_action_example(persona: dict) -> dict:
268
  duration = 2
269
  elif need_name == "fun":
270
  action = random.choice(["relax", "exercise", "wander"])
271
- target_loc = random.choice(["park", "gym", "cinema", "bar", "sports_field"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  duration = 2
273
 
274
- # Priority 2: Time-of-day patterns
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  if action is None:
276
  period = _time_period(hour)
277
 
@@ -298,37 +657,81 @@ def generate_action_example(persona: dict) -> dict:
298
  elif period in (2, 4): # Mid-morning / Afternoon
299
  if is_weekend:
300
  r = random.random()
301
- if r < 0.25:
302
- action = "relax"
303
- target_loc = random.choice(["park", "cafe", "library", persona["home"]])
304
- elif r < 0.45 and persona["E"] >= 6:
305
- action = "talk"
306
- target_loc = random.choice(["cafe", "park", "town_square"])
307
- elif r < 0.6:
308
- action = "shop"
309
- target_loc = random.choice(["grocery", "pharmacy"])
310
- elif r < 0.8:
311
- action = "exercise"
312
- target_loc = random.choice(["gym", "park", "sports_field"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  else:
314
- action = "wander"
315
- target_loc = random.choice(["park", "town_square", "street_north", "street_south"])
316
- duration = random.choice([2, 3])
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  else:
 
318
  work_prob = 0.5 + persona["C"] * 0.05
 
 
 
 
 
319
  if random.random() < work_prob:
320
  action = "work"
321
  target_loc = persona["work"]
322
  duration = 4
323
  else:
324
  action = random.choice(["wander", "relax", "talk"])
325
- target_loc = random.choice(["cafe", "park", "town_square"])
326
  duration = 2
327
 
328
  elif period == 3: # Midday / lunch
329
  if needs["hunger"] < 0.6:
330
  action = "eat"
331
- target_loc = random.choice(["cafe", "restaurant", "bakery", "diner", "park"])
 
 
 
 
332
  duration = 2
333
  else:
334
  action = "relax"
@@ -340,7 +743,10 @@ def generate_action_example(persona: dict) -> dict:
340
  social_bias = persona["E"] / 10.0
341
  if r < social_bias * 0.5:
342
  action = "talk"
343
- target_loc = random.choice(["bar", "restaurant", "park", "cafe"])
 
 
 
344
  duration = 2
345
  elif r < 0.4:
346
  action = "eat"
@@ -352,7 +758,7 @@ def generate_action_example(persona: dict) -> dict:
352
  duration = 3
353
  elif r < 0.7:
354
  action = "relax"
355
- target_loc = random.choice(["cinema", "bar", persona["home"], "library"])
356
  duration = 2
357
  else:
358
  action = "relax"
@@ -375,6 +781,14 @@ def generate_action_example(persona: dict) -> dict:
375
  action = "move"
376
  duration = 1
377
 
 
 
 
 
 
 
 
 
378
  features = encode_features(
379
  persona=persona, hour=hour, minute=minute, day=day,
380
  needs=needs, mood=mood, current_loc=current_loc,
@@ -536,10 +950,8 @@ def train(
536
  num_val: int = 10_000,
537
  data_dir: str | None = None,
538
  resume: bool = False,
539
- push: bool = False,
540
- repo_id: str = "RayMelius/soci-agent-nn",
541
  ):
542
- """Full training pipeline: generate/load data, train, export ONNX, optionally push."""
543
  import torch
544
  import torch.nn as nn
545
  from torch.utils.data import Dataset, DataLoader
@@ -767,17 +1179,38 @@ def train(
767
  a, l, d, c = predict(PERSONAS[0], 0, 30, 5,
768
  {"hunger": 0.5, "energy": 0.05, "social": 0.4, "purpose": 0.6, "comfort": 0.3, "fun": 0.3},
769
  -0.3, "office")
770
- logger.info(f" Elena midnight exhausted: {a} -> {l} ({d} ticks, {c:.0%})")
771
 
772
  a, l, d, c = predict(PERSONAS[2], 12, 30, 3,
773
  {"hunger": 0.05, "energy": 0.7, "social": 0.5, "purpose": 0.6, "comfort": 0.5, "fun": 0.4},
774
  0.2, "gym", 5)
775
- logger.info(f" Marcus lunchtime starving: {a} -> {l} ({d} ticks, {c:.0%})")
776
 
777
  a, l, d, c = predict(PERSONAS[8], 10, 0, 6,
778
  {"hunger": 0.6, "energy": 0.7, "social": 0.5, "purpose": 0.5, "comfort": 0.7, "fun": 0.4},
779
  0.5, "house_kai")
780
- logger.info(f" Kai Saturday morning: {a} -> {l} ({d} ticks, {c:.0%})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
 
782
  # ── Export to ONNX ───────────────────────────────────────────────
783
  logger.info("Exporting to ONNX...")
@@ -825,15 +1258,129 @@ def train(
825
  stats_path.write_text(json.dumps(stats, indent=2))
826
  logger.info(f"Stats saved to {stats_path}")
827
 
828
- # ── Push to HF Hub ───────────────────────────────────────────────
829
- if push:
830
- _push_to_hub(best_pt, onnx_path, stats_path, repo_id, best_val_acc, epochs, len(train_ds))
831
 
832
  return best_val_acc
833
 
834
 
835
- def _push_to_hub(best_pt, onnx_path, stats_path, repo_id, best_val_acc, epochs, num_train):
836
- """Upload model files to HuggingFace Hub."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
837
  from huggingface_hub import HfApi, login
838
 
839
  token = os.environ.get("HF_TOKEN", "")
@@ -876,6 +1423,17 @@ def _push_to_hub(best_pt, onnx_path, stats_path, repo_id, best_val_acc, epochs,
876
 
877
  logger.info(f"Model pushed to https://huggingface.co/{repo_id}")
878
 
 
 
 
 
 
 
 
 
 
 
 
879
 
880
  # ══════════════════════════════════════════════════════════════════════════
881
  # CLI
@@ -889,7 +1447,8 @@ def main():
889
  python scripts/nn_train.py # Train from scratch
890
  python scripts/nn_train.py --resume --epochs 50 # Continue training
891
  python scripts/nn_train.py --data data/nn_training # Use collected samples
892
- python scripts/nn_train.py --push --repo RayMelius/soci-agent-nn # Train + push
 
893
  """,
894
  )
895
  parser.add_argument("--epochs", type=int, default=30, help="Training epochs (default: 30)")
@@ -904,11 +1463,37 @@ def main():
904
  parser.add_argument("--resume", action="store_true",
905
  help="Resume from existing weights in models/")
906
  parser.add_argument("--push", action="store_true",
907
- help="Push trained model to HuggingFace Hub")
 
 
908
  parser.add_argument("--repo", default="RayMelius/soci-agent-nn",
909
  help="HF Hub repo ID (default: RayMelius/soci-agent-nn)")
 
 
910
  args = parser.parse_args()
911
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
912
  train(
913
  epochs=args.epochs,
914
  batch_size=args.batch_size,
@@ -917,8 +1502,6 @@ def main():
917
  num_val=args.val_samples,
918
  data_dir=args.data,
919
  resume=args.resume,
920
- push=args.push,
921
- repo_id=args.repo,
922
  )
923
 
924
 
 
105
  # ══════════════════════════════════════════════════════════════════════════
106
 
107
  PERSONAS = [
108
+ # House 1 Elena & Lila (roommates)
109
+ {"id": "elena", "name": "Elena Vasquez", "age": 34, "gender": "female", "occ": "software engineer",
110
+ "O": 8, "C": 7, "E": 4, "A": 6, "N": 5, "home": "house_elena", "work": "office",
111
+ "tags": ["freelance", "introvert", "tech"],
112
+ "hangouts": ["cafe", "library"], # where she goes to think/work remotely
113
+ "routine_bias": {}},
114
+
115
+ {"id": "lila", "name": "Lila Santos", "age": 33, "gender": "female", "occ": "artist",
116
+ "O": 10, "C": 3, "E": 6, "A": 7, "N": 7, "home": "house_elena", "work": "library",
117
+ "tags": ["creative", "emotional", "crush_elena"],
118
+ "hangouts": ["park", "cafe", "library"], # paints outdoors, hangs near Elena
119
+ "routine_bias": {"relax": 0.15, "wander": 0.10}},
120
+
121
+ # House 2 Marcus & Zoe (siblings)
122
+ {"id": "marcus", "name": "Marcus Chen", "age": 28, "gender": "male", "occ": "fitness trainer",
123
+ "O": 5, "C": 8, "E": 9, "A": 7, "N": 3, "home": "house_marcus", "work": "gym",
124
+ "tags": ["athletic", "extrovert", "community"],
125
+ "hangouts": ["park", "sports_field", "cafe"],
126
+ "routine_bias": {"exercise": 0.20, "talk": 0.10}},
127
+
128
+ {"id": "zoe", "name": "Zoe Chen-Williams", "age": 19, "gender": "female", "occ": "college student",
129
+ "O": 8, "C": 4, "E": 8, "A": 6, "N": 7, "home": "house_marcus", "work": "library",
130
+ "tags": ["student", "social_media", "young"],
131
+ "hangouts": ["cafe", "cinema", "park", "town_square"],
132
+ "routine_bias": {"talk": 0.15, "wander": 0.10}},
133
+
134
+ # House 3 — Helen & Alice (close friends)
135
+ {"id": "helen", "name": "Helen Park", "age": 67, "gender": "female", "occ": "retired teacher",
136
+ "O": 6, "C": 8, "E": 6, "A": 8, "N": 4, "home": "house_helen", "work": "library",
137
+ "tags": ["retired", "bookworm", "widow"],
138
+ "hangouts": ["library", "park", "bakery", "church"],
139
+ "routine_bias": {"relax": 0.15}},
140
+
141
+ {"id": "alice", "name": "Alice Fontaine", "age": 58, "gender": "female", "occ": "retired accountant",
142
+ "O": 5, "C": 8, "E": 6, "A": 8, "N": 3, "home": "house_helen", "work": "bakery",
143
+ "tags": ["retired", "baker", "nurturing"],
144
+ "hangouts": ["bakery", "grocery", "church"],
145
+ "routine_bias": {"work": 0.10}}, # loves baking, spends extra time at bakery
146
+
147
+ # House 4 — Diana & Marco (mother & son)
148
+ {"id": "diana", "name": "Diana Novak", "age": 41, "gender": "female", "occ": "grocery store owner",
149
+ "O": 4, "C": 9, "E": 5, "A": 6, "N": 7, "home": "house_diana", "work": "grocery",
150
+ "tags": ["business_owner", "single_mother", "protective"],
151
+ "hangouts": ["grocery"], # rarely leaves the store
152
+ "routine_bias": {"work": 0.20}},
153
+
154
+ {"id": "marco", "name": "Marco Delgado", "age": 16, "gender": "male", "occ": "high school student",
155
+ "O": 7, "C": 4, "E": 6, "A": 5, "N": 6, "home": "house_diana", "work": "school",
156
+ "tags": ["student", "teen", "gamer"],
157
+ "hangouts": ["park", "cinema", "cafe", "sports_field"],
158
+ "routine_bias": {"relax": 0.10, "wander": 0.10}},
159
+
160
+ # House 5 — Kai (lives alone)
161
+ {"id": "kai", "name": "Kai Okonkwo", "age": 22, "gender": "nonbinary", "occ": "barista",
162
+ "O": 9, "C": 3, "E": 7, "A": 5, "N": 6, "home": "house_kai", "work": "cafe",
163
+ "tags": ["musician", "creative", "dropout"],
164
+ "hangouts": ["bar", "park", "town_square"], # plays music, socializes
165
+ "routine_bias": {"relax": 0.10, "talk": 0.10}},
166
+
167
+ # House 6 — Priya & Nina (flatmates)
168
+ {"id": "priya", "name": "Priya Sharma", "age": 38, "gender": "female", "occ": "doctor",
169
+ "O": 7, "C": 9, "E": 5, "A": 8, "N": 6, "home": "house_priya", "work": "hospital",
170
+ "tags": ["overworked", "caring", "guilt"],
171
+ "hangouts": ["hospital", "pharmacy"], # rarely leaves work orbit
172
+ "routine_bias": {"work": 0.25}}, # long hospital hours
173
+
174
+ {"id": "nina", "name": "Nina Volkov", "age": 29, "gender": "female", "occ": "real estate agent",
175
+ "O": 5, "C": 8, "E": 9, "A": 4, "N": 5, "home": "house_priya", "work": "office",
176
+ "tags": ["ambitious", "networker", "suspicious"],
177
+ "hangouts": ["cafe", "restaurant", "office_tower"],
178
+ "routine_bias": {"talk": 0.15, "work": 0.10}},
179
+
180
+ # House 7 — James & Theo (housemates)
181
+ {"id": "james", "name": "James O'Brien", "age": 55, "gender": "male", "occ": "bar owner",
182
+ "O": 5, "C": 6, "E": 8, "A": 7, "N": 4, "home": "house_james", "work": "bar",
183
+ "tags": ["social_hub", "divorced", "storyteller"],
184
+ "hangouts": ["bar"], # his whole life revolves around the bar
185
+ "routine_bias": {"talk": 0.20}},
186
+
187
+ {"id": "theo", "name": "Theo Blackwood", "age": 45, "gender": "male", "occ": "construction worker",
188
+ "O": 3, "C": 7, "E": 4, "A": 5, "N": 5, "home": "house_james", "work": "factory",
189
+ "tags": ["blue_collar", "stoic", "handy"],
190
+ "hangouts": ["bar", "diner"], # bar after work
191
+ "routine_bias": {"work": 0.15}},
192
+
193
+ # House 8 — Rosa & Omar
194
+ {"id": "rosa", "name": "Rosa Martelli", "age": 62, "gender": "female", "occ": "restaurant owner",
195
+ "O": 6, "C": 9, "E": 7, "A": 8, "N": 5, "home": "house_rosa", "work": "restaurant",
196
+ "tags": ["nurturing", "italian", "community_mother"],
197
+ "hangouts": ["restaurant", "grocery"], # buys ingredients, feeds everyone
198
+ "routine_bias": {"work": 0.20, "eat": 0.05}},
199
+
200
+ {"id": "omar", "name": "Omar Hassan", "age": 50, "gender": "male", "occ": "taxi driver",
201
+ "O": 6, "C": 6, "E": 7, "A": 7, "N": 4, "home": "house_rosa", "work": "restaurant",
202
+ "tags": ["immigrant", "philosophical", "hardworking"],
203
+ "hangouts": ["restaurant", "cafe", "park"],
204
+ "routine_bias": {"wander": 0.15}}, # drives around town = wander
205
+
206
+ # House 9 — Yuki & Devon (flatmates)
207
+ {"id": "yuki", "name": "Yuki Tanaka", "age": 26, "gender": "female", "occ": "yoga instructor",
208
+ "O": 8, "C": 6, "E": 5, "A": 9, "N": 3, "home": "house_yuki", "work": "gym",
209
+ "tags": ["mindful", "calm", "empathetic"],
210
+ "hangouts": ["park", "gym", "library"], # meditates in park
211
+ "routine_bias": {"exercise": 0.15, "relax": 0.10}},
212
+
213
+ {"id": "devon", "name": "Devon Reeves", "age": 30, "gender": "male", "occ": "freelance journalist",
214
+ "O": 9, "C": 5, "E": 6, "A": 4, "N": 6, "home": "house_yuki", "work": "office",
215
+ "tags": ["investigative", "paranoid", "curious"],
216
+ "hangouts": ["cafe", "bar", "library", "town_square"], # interviews, research
217
+ "routine_bias": {"wander": 0.15, "talk": 0.10}},
218
+
219
+ # House 10 — Frank, George & Sam
220
+ {"id": "frank", "name": "Frank Kowalski", "age": 72, "gender": "male", "occ": "retired mechanic",
221
+ "O": 3, "C": 7, "E": 5, "A": 4, "N": 5, "home": "house_frank", "work": "bar",
222
+ "tags": ["retired", "cantankerous", "creature_of_habit"],
223
+ "hangouts": ["bar", "diner"], # same bar stool every night
224
+ "routine_bias": {"relax": 0.15}},
225
+
226
+ {"id": "george", "name": "George Adeyemi", "age": 47, "gender": "male", "occ": "night shift security",
227
+ "O": 4, "C": 7, "E": 3, "A": 6, "N": 4, "home": "house_frank", "work": "factory",
228
+ "tags": ["night_shift", "widower", "observant"],
229
+ "hangouts": ["park"], # naps in park during day
230
+ "routine_bias": {}}, # schedule handled by night_shift tag
231
+
232
+ {"id": "sam", "name": "Sam Nakamura", "age": 40, "gender": "nonbinary", "occ": "librarian",
233
+ "O": 7, "C": 8, "E": 3, "A": 7, "N": 4, "home": "house_frank", "work": "library",
234
+ "tags": ["quiet", "bookish", "inclusive"],
235
+ "hangouts": ["library", "park", "cafe"],
236
+ "routine_bias": {"work": 0.10, "relax": 0.05}},
237
  ]
238
 
239
 
 
326
  # 4. Synthetic Data Generator
327
  # ══════════════════════════════════════════════════════════════════════════
328
 
329
+ def _is_night_shift(persona: dict) -> bool:
330
+ return "night_shift" in persona.get("tags", [])
331
+
332
+
333
+ def _is_retired(persona: dict) -> bool:
334
+ return "retired" in persona.get("tags", [])
335
+
336
+
337
+ def _is_student(persona: dict) -> bool:
338
+ return "student" in persona.get("tags", [])
339
+
340
+
341
+ def _persona_hangout(persona: dict, fallbacks: list[str]) -> str:
342
+ """Pick a location the persona naturally gravitates toward."""
343
+ hangouts = persona.get("hangouts", [])
344
+ if hangouts and random.random() < 0.6:
345
+ return random.choice(hangouts)
346
+ return random.choice(fallbacks)
347
+
348
+
349
+ def _apply_routine_bias(persona: dict, action: str | None) -> str | None:
350
+ """Probabilistically override action based on persona routine_bias."""
351
+ bias = persona.get("routine_bias", {})
352
+ for biased_action, prob in bias.items():
353
+ if random.random() < prob:
354
+ return biased_action
355
+ return action
356
 
357
+
358
+ def _generate_needs_for_persona(persona: dict, hour: int) -> dict:
359
+ """Generate needs influenced by persona lifestyle, not purely random."""
360
  needs = {}
361
+ tags = persona.get("tags", [])
362
+ is_night = _is_night_shift(persona)
363
+
364
  for n in NEED_NAMES:
365
+ # Base: 15% chance critical, else moderate-to-full
366
  if random.random() < 0.15:
367
  needs[n] = round(random.uniform(0.0, 0.2), 2)
368
  else:
369
  needs[n] = round(random.uniform(0.2, 1.0), 2)
370
 
371
+ # Persona-specific need tendencies
372
+ if "overworked" in tags:
373
+ # Priya: chronically low energy, low social
374
+ needs["energy"] = round(min(needs["energy"], random.uniform(0.1, 0.5)), 2)
375
+ needs["social"] = round(min(needs["social"], random.uniform(0.1, 0.5)), 2)
376
+ if "athletic" in tags:
377
+ # Marcus: high energy baseline, low fun without exercise
378
+ needs["energy"] = round(max(needs["energy"], random.uniform(0.5, 0.9)), 2)
379
+ if "emotional" in tags:
380
+ # Lila: volatile needs
381
+ swing = random.choice(NEED_NAMES)
382
+ needs[swing] = round(random.uniform(0.0, 0.3), 2)
383
+ if "creature_of_habit" in tags:
384
+ # Frank: stable moderate needs
385
+ for n in NEED_NAMES:
386
+ needs[n] = round(needs[n] * 0.7 + 0.2, 2)
387
+ if is_night:
388
+ # George: energy inverted — tired during day, awake at night
389
+ if 6 <= hour <= 18:
390
+ needs["energy"] = round(min(needs["energy"], random.uniform(0.05, 0.35)), 2)
391
+ else:
392
+ needs["energy"] = round(max(needs["energy"], random.uniform(0.5, 0.9)), 2)
393
+ if "student" in tags:
394
+ # Students: higher social need, lower purpose
395
+ needs["social"] = round(max(needs["social"], random.uniform(0.3, 0.7)), 2)
396
+ needs["fun"] = round(max(needs["fun"], random.uniform(0.2, 0.5)), 2)
397
+ if "nurturing" in tags or "community_mother" in tags:
398
+ # Rosa, Alice: high comfort, purpose from feeding/helping others
399
+ needs["purpose"] = round(max(needs["purpose"], random.uniform(0.4, 0.8)), 2)
400
+ if "mindful" in tags:
401
+ # Yuki: generally balanced, rarely critical
402
+ for n in NEED_NAMES:
403
+ needs[n] = round(max(needs[n], 0.2), 2)
404
+
405
+ return needs
406
+
407
+
408
+ def _mood_for_persona(persona: dict, needs: dict) -> float:
409
+ """Generate mood influenced by personality and current needs."""
410
+ tags = persona.get("tags", [])
411
+ # Base mood from needs average
412
+ avg_need = sum(needs.values()) / len(needs)
413
+ base_mood = (avg_need - 0.5) * 2 # maps 0-1 to -1..+1
414
+
415
+ # Neuroticism makes mood more volatile
416
+ n_factor = persona.get("N", 5) / 10.0
417
+ volatility = random.uniform(-0.5, 0.5) * n_factor
418
+ base_mood += volatility
419
+
420
+ if "calm" in tags or "mindful" in tags:
421
+ base_mood = base_mood * 0.6 + 0.2 # dampen toward positive
422
+ if "emotional" in tags:
423
+ base_mood += random.uniform(-0.4, 0.4)
424
+
425
+ return round(max(-1.0, min(1.0, base_mood)), 2)
426
+
427
+
428
+ def _starting_location(persona: dict, hour: int, is_weekend: bool) -> str:
429
+ """Pick a realistic starting location based on time and persona."""
430
+ tags = persona.get("tags", [])
431
+ is_night = _is_night_shift(persona)
432
+ period = _time_period(hour)
433
+
434
+ # Night shift workers: at work during night, home during day
435
+ if is_night:
436
+ if period in (0, 6): # late night / night — at work
437
+ return persona["work"]
438
+ elif period in (1, 2): # morning — heading home or sleeping
439
+ return random.choice([persona["home"], persona["work"]])
440
+ else: # daytime — at home (sleeping) or park (napping)
441
+ return random.choice([persona["home"], "park"] if random.random() < 0.7
442
+ else [persona["home"]])
443
+
444
+ # Normal schedule
445
+ if period == 0: # late night — home
446
+ return persona["home"]
447
+ elif period == 1: # early morning — home or commuting
448
+ return random.choice([persona["home"], persona["work"]])
449
+ elif period in (2, 4) and not is_weekend: # working hours
450
+ if _is_retired(persona):
451
+ return random.choice([persona["home"]] + persona.get("hangouts", ["park"]))
452
+ if _is_student(persona):
453
+ return random.choice([persona["work"], "library", persona["home"]])
454
+ return random.choice([persona["work"], persona["work"], persona["work"],
455
+ _persona_hangout(persona, ["cafe"])])
456
+ elif period == 3: # lunch
457
+ return random.choice([persona["work"], "cafe", "restaurant", "diner", "park"])
458
+ elif period == 5: # evening
459
+ return random.choice([persona["home"], _persona_hangout(persona, ["bar", "cafe", "park"])])
460
+ elif period == 6: # night
461
+ return random.choice([persona["home"], persona["home"], _persona_hangout(persona, ["bar"])])
462
+
463
+ return persona["home"]
464
+
465
+
466
+ def generate_action_example(persona: dict) -> dict:
467
+ """Generate one training example with persona-aware rule-based labels."""
468
+ hour = random.randint(0, 23)
469
+ minute = random.choice([0, 15, 30, 45])
470
+ day = random.randint(1, 30)
471
+ is_weekend = ((day - 1) % 7) >= 5
472
+ tags = persona.get("tags", [])
473
+ is_night = _is_night_shift(persona)
474
+
475
+ needs = _generate_needs_for_persona(persona, hour)
476
+ mood = _mood_for_persona(persona, needs)
477
+ current_loc = _starting_location(persona, hour, is_weekend)
478
 
479
  # --- Determine action using rule-based logic ---
480
  # Priority 1: Critical needs
 
489
  need_name = urgent[0][0]
490
  if need_name == "hunger":
491
  action = "eat"
492
+ # Persona-aware eating locations
493
+ eat_locs = ["cafe", "restaurant", "grocery", "bakery", "diner", persona["home"]]
494
+ if "community_mother" in tags: # Rosa eats at her restaurant
495
+ eat_locs = ["restaurant", persona["home"]]
496
+ elif "baker" in tags: # Alice eats at bakery or home
497
+ eat_locs = ["bakery", persona["home"]]
498
+ target_loc = random.choice(eat_locs)
499
  duration = 2
500
  elif need_name == "energy":
501
  action = "sleep"
 
503
  duration = random.choice([4, 6, 8])
504
  elif need_name == "social":
505
  action = "talk"
506
+ social_locs = ["cafe", "bar", "park", "town_square", current_loc]
507
+ if "social_hub" in tags: # James talks at his bar
508
+ social_locs = ["bar", "bar", "restaurant", "park"]
509
+ elif "networker" in tags: # Nina networks everywhere
510
+ social_locs = ["cafe", "restaurant", "office", "office_tower"]
511
+ target_loc = random.choice(social_locs)
512
  duration = 2
513
  elif need_name == "purpose":
514
  action = "work"
 
520
  duration = 2
521
  elif need_name == "fun":
522
  action = random.choice(["relax", "exercise", "wander"])
523
+ fun_locs = ["park", "gym", "cinema", "bar", "sports_field"]
524
+ if "teen" in tags or "student" in tags:
525
+ fun_locs = ["cinema", "park", "cafe", "sports_field", "town_square"]
526
+ target_loc = random.choice(fun_locs)
527
+ duration = 2
528
+
529
+ # Priority 2: Night shift inverted schedule (George)
530
+ if action is None and is_night:
531
+ period = _time_period(hour)
532
+ if period in (0, 6): # night — George is at work
533
+ action = "work"
534
+ target_loc = persona["work"]
535
+ duration = 4
536
+ elif period == 1: # early morning — heading home
537
+ action = "move"
538
+ target_loc = persona["home"]
539
+ duration = 1
540
+ elif period in (2, 3): # day — sleeping
541
+ if needs["energy"] < 0.6:
542
+ action = "sleep"
543
+ target_loc = persona["home"]
544
+ duration = random.choice([4, 6, 8])
545
+ else:
546
+ # Sometimes naps in park
547
+ action = "relax"
548
+ target_loc = random.choice([persona["home"], "park"])
549
+ duration = 2
550
+ elif period in (4, 5): # afternoon/evening — wake up, eat, prep for work
551
+ r = random.random()
552
+ if needs["hunger"] < 0.5:
553
+ action = "eat"
554
+ target_loc = random.choice(["diner", "restaurant", persona["home"]])
555
+ duration = 2
556
+ elif r < 0.3:
557
+ action = "talk"
558
+ target_loc = random.choice(["park", "cafe"])
559
+ duration = 2
560
+ else:
561
+ action = "move"
562
+ target_loc = persona["work"]
563
+ duration = 1
564
+
565
+ # Priority 3: Persona-specific behavioral patterns
566
+ if action is None:
567
+ period = _time_period(hour)
568
+
569
+ # Frank: same bar stool every evening/night
570
+ if persona["id"] == "frank" and period in (5, 6):
571
+ if random.random() < 0.7:
572
+ action = "relax"
573
+ target_loc = "bar"
574
+ duration = 3
575
+
576
+ # Lila: gravitates toward Elena (crush) — seeks her hangouts
577
+ elif persona["id"] == "lila" and random.random() < 0.15:
578
+ action = random.choice(["wander", "talk", "relax"])
579
+ target_loc = random.choice(["house_elena", "cafe", "library", "office"])
580
  duration = 2
581
 
582
+ # Rosa: spends mornings buying ingredients, cooks all day
583
+ elif persona["id"] == "rosa" and period in (1, 2):
584
+ if random.random() < 0.4:
585
+ action = "shop"
586
+ target_loc = "grocery"
587
+ duration = 2
588
+
589
+ # Devon: investigative journalist, wanders and interviews
590
+ elif persona["id"] == "devon" and period in (2, 4):
591
+ if random.random() < 0.3:
592
+ action = random.choice(["wander", "talk"])
593
+ target_loc = random.choice(["cafe", "bar", "town_square", "library", "park"])
594
+ duration = 2
595
+
596
+ # Omar: taxi driver — wanders the streets during work hours
597
+ elif persona["id"] == "omar" and period in (2, 3, 4) and not is_weekend:
598
+ if random.random() < 0.5:
599
+ action = "wander"
600
+ target_loc = random.choice(["street_north", "street_south", "street_east", "street_west",
601
+ "town_square", "cafe", "restaurant"])
602
+ duration = 2
603
+
604
+ # Diana: barely leaves the grocery store on weekdays
605
+ elif persona["id"] == "diana" and not is_weekend and period in (2, 3, 4):
606
+ if random.random() < 0.7:
607
+ action = "work"
608
+ target_loc = "grocery"
609
+ duration = 4
610
+
611
+ # Marcus: morning exercise is sacred
612
+ elif persona["id"] == "marcus" and period == 1:
613
+ if random.random() < 0.6:
614
+ action = "exercise"
615
+ target_loc = random.choice(["gym", "park", "sports_field"])
616
+ duration = 3
617
+
618
+ # Yuki: morning meditation/yoga
619
+ elif persona["id"] == "yuki" and period == 1:
620
+ if random.random() < 0.5:
621
+ action = "exercise"
622
+ target_loc = random.choice(["park", "gym"])
623
+ duration = 3
624
+
625
+ # Priority 4: Apply routine_bias override
626
+ if action is None:
627
+ biased = _apply_routine_bias(persona, None)
628
+ if biased:
629
+ action = biased
630
+ target_loc = _persona_hangout(persona, ["park", "cafe", persona["home"]])
631
+ duration = 2
632
+
633
+ # Priority 5: General time-of-day patterns (fallback)
634
  if action is None:
635
  period = _time_period(hour)
636
 
 
657
  elif period in (2, 4): # Mid-morning / Afternoon
658
  if is_weekend:
659
  r = random.random()
660
+ if _is_retired(persona):
661
+ # Retired: relaxed weekend routine
662
+ if r < 0.35:
663
+ action = "relax"
664
+ target_loc = _persona_hangout(persona, ["park", "library", persona["home"]])
665
+ elif r < 0.55:
666
+ action = "talk"
667
+ target_loc = _persona_hangout(persona, ["cafe", "park", "church"])
668
+ elif r < 0.7:
669
+ action = "shop"
670
+ target_loc = random.choice(["grocery", "pharmacy", "bakery"])
671
+ else:
672
+ action = "wander"
673
+ target_loc = random.choice(["park", "town_square", "street_north"])
674
+ duration = random.choice([2, 3])
675
+ elif _is_student(persona):
676
+ # Students: social weekends
677
+ if r < 0.3:
678
+ action = "talk"
679
+ target_loc = random.choice(["cafe", "park", "cinema", "town_square"])
680
+ elif r < 0.5:
681
+ action = "relax"
682
+ target_loc = random.choice(["cinema", "park", persona["home"]])
683
+ elif r < 0.65:
684
+ action = "exercise"
685
+ target_loc = random.choice(["gym", "park", "sports_field"])
686
+ elif r < 0.8:
687
+ action = "wander"
688
+ target_loc = random.choice(["town_square", "street_north", "street_south"])
689
+ else:
690
+ action = "shop"
691
+ target_loc = random.choice(["grocery", "pharmacy"])
692
+ duration = random.choice([2, 3])
693
  else:
694
+ if r < 0.25:
695
+ action = "relax"
696
+ target_loc = _persona_hangout(persona, ["park", "cafe", "library", persona["home"]])
697
+ elif r < 0.45 and persona["E"] >= 6:
698
+ action = "talk"
699
+ target_loc = _persona_hangout(persona, ["cafe", "park", "town_square"])
700
+ elif r < 0.6:
701
+ action = "shop"
702
+ target_loc = random.choice(["grocery", "pharmacy"])
703
+ elif r < 0.8:
704
+ action = "exercise"
705
+ target_loc = random.choice(["gym", "park", "sports_field"])
706
+ else:
707
+ action = "wander"
708
+ target_loc = random.choice(["park", "town_square", "street_north", "street_south"])
709
+ duration = random.choice([2, 3])
710
  else:
711
+ # Weekday work hours
712
  work_prob = 0.5 + persona["C"] * 0.05
713
+ # Business owners and doctors work even harder
714
+ if "business_owner" in tags or persona["occ"] == "doctor":
715
+ work_prob += 0.15
716
+ if _is_retired(persona):
717
+ work_prob = 0.15 # retired people rarely "work"
718
  if random.random() < work_prob:
719
  action = "work"
720
  target_loc = persona["work"]
721
  duration = 4
722
  else:
723
  action = random.choice(["wander", "relax", "talk"])
724
+ target_loc = _persona_hangout(persona, ["cafe", "park", "town_square"])
725
  duration = 2
726
 
727
  elif period == 3: # Midday / lunch
728
  if needs["hunger"] < 0.6:
729
  action = "eat"
730
+ lunch_locs = ["cafe", "restaurant", "bakery", "diner", "park"]
731
+ # People eat near their workplace
732
+ if current_loc == persona["work"]:
733
+ lunch_locs = ["cafe", "restaurant", "diner", "bakery"]
734
+ target_loc = random.choice(lunch_locs)
735
  duration = 2
736
  else:
737
  action = "relax"
 
743
  social_bias = persona["E"] / 10.0
744
  if r < social_bias * 0.5:
745
  action = "talk"
746
+ evening_social = ["bar", "restaurant", "park", "cafe"]
747
+ if "social_hub" in tags:
748
+ evening_social = ["bar", "bar", "restaurant"]
749
+ target_loc = random.choice(evening_social)
750
  duration = 2
751
  elif r < 0.4:
752
  action = "eat"
 
758
  duration = 3
759
  elif r < 0.7:
760
  action = "relax"
761
+ target_loc = _persona_hangout(persona, ["cinema", "bar", persona["home"], "library"])
762
  duration = 2
763
  else:
764
  action = "relax"
 
781
  action = "move"
782
  duration = 1
783
 
784
+ # Retired and elderly people do shorter activities
785
+ if _is_retired(persona) and duration > 3 and action not in ("sleep", "work"):
786
+ duration = min(duration, 3)
787
+
788
+ # Teens/students have shorter attention spans for non-social activities
789
+ if _is_student(persona) and action in ("relax", "work") and random.random() < 0.3:
790
+ duration = max(1, duration - 1)
791
+
792
  features = encode_features(
793
  persona=persona, hour=hour, minute=minute, day=day,
794
  needs=needs, mood=mood, current_loc=current_loc,
 
950
  num_val: int = 10_000,
951
  data_dir: str | None = None,
952
  resume: bool = False,
 
 
953
  ):
954
+ """Full training pipeline: generate/load data, train, export ONNX."""
955
  import torch
956
  import torch.nn as nn
957
  from torch.utils.data import Dataset, DataLoader
 
1179
  a, l, d, c = predict(PERSONAS[0], 0, 30, 5,
1180
  {"hunger": 0.5, "energy": 0.05, "social": 0.4, "purpose": 0.6, "comfort": 0.3, "fun": 0.3},
1181
  -0.3, "office")
1182
+ logger.info(f" Elena midnight exhausted at office: {a} -> {l} ({d} ticks, {c:.0%})")
1183
 
1184
  a, l, d, c = predict(PERSONAS[2], 12, 30, 3,
1185
  {"hunger": 0.05, "energy": 0.7, "social": 0.5, "purpose": 0.6, "comfort": 0.5, "fun": 0.4},
1186
  0.2, "gym", 5)
1187
+ logger.info(f" Marcus lunchtime starving at gym: {a} -> {l} ({d} ticks, {c:.0%})")
1188
 
1189
  a, l, d, c = predict(PERSONAS[8], 10, 0, 6,
1190
  {"hunger": 0.6, "energy": 0.7, "social": 0.5, "purpose": 0.5, "comfort": 0.7, "fun": 0.4},
1191
  0.5, "house_kai")
1192
+ logger.info(f" Kai Saturday morning at home: {a} -> {l} ({d} ticks, {c:.0%})")
1193
+
1194
+ # George (night shift) — should sleep during the day
1195
+ george = [p for p in PERSONAS if p["id"] == "george"][0]
1196
+ a, l, d, c = predict(george, 11, 0, 3,
1197
+ {"hunger": 0.4, "energy": 0.15, "social": 0.5, "purpose": 0.7, "comfort": 0.5, "fun": 0.4},
1198
+ -0.1, "house_frank")
1199
+ logger.info(f" George midday after night shift: {a} -> {l} ({d} ticks, {c:.0%})")
1200
+
1201
+ # Frank — evening at the bar
1202
+ frank = [p for p in PERSONAS if p["id"] == "frank"][0]
1203
+ a, l, d, c = predict(frank, 20, 0, 4,
1204
+ {"hunger": 0.5, "energy": 0.4, "social": 0.3, "purpose": 0.6, "comfort": 0.5, "fun": 0.3},
1205
+ 0.1, "bar")
1206
+ logger.info(f" Frank evening at the bar: {a} -> {l} ({d} ticks, {c:.0%})")
1207
+
1208
+ # Priya — overworked at hospital
1209
+ priya = [p for p in PERSONAS if p["id"] == "priya"][0]
1210
+ a, l, d, c = predict(priya, 15, 0, 2,
1211
+ {"hunger": 0.3, "energy": 0.2, "social": 0.3, "purpose": 0.8, "comfort": 0.4, "fun": 0.2},
1212
+ -0.2, "hospital")
1213
+ logger.info(f" Priya afternoon exhausted at hospital: {a} -> {l} ({d} ticks, {c:.0%})")
1214
 
1215
  # ── Export to ONNX ───────────────────────────────────────────────
1216
  logger.info("Exporting to ONNX...")
 
1258
  stats_path.write_text(json.dumps(stats, indent=2))
1259
  logger.info(f"Stats saved to {stats_path}")
1260
 
1261
+ # ── Plot training graphs ──────────────────────────────────────────
1262
+ plot_training_graphs(stats_path)
 
1263
 
1264
  return best_val_acc
1265
 
1266
 
1267
+ def plot_training_graphs(stats_path: Path | str | None = None):
1268
+ """Plot training loss and accuracy curves from saved training stats.
1269
+
1270
+ Saves the plot to models/training_graphs.png and displays it.
1271
+ """
1272
+ import matplotlib
1273
+ matplotlib.use("Agg") # non-interactive backend as fallback
1274
+ import matplotlib.pyplot as plt
1275
+
1276
+ stats_path = Path(stats_path) if stats_path else MODEL_DIR / "training_stats.json"
1277
+ if not stats_path.exists():
1278
+ logger.error(f"No training stats found at {stats_path}")
1279
+ return
1280
+
1281
+ stats = json.loads(stats_path.read_text())
1282
+ history = stats.get("history", {})
1283
+
1284
+ train_loss = history.get("train_loss", [])
1285
+ val_loss = history.get("val_loss", [])
1286
+ val_action_acc = history.get("val_action_acc", [])
1287
+ val_loc_acc = history.get("val_loc_acc", [])
1288
+
1289
+ if not train_loss:
1290
+ logger.error("No training history found in stats file")
1291
+ return
1292
+
1293
+ epochs_range = list(range(1, len(train_loss) + 1))
1294
+
1295
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
1296
+ fig.suptitle(
1297
+ f"Soci Agent NN Training — {stats.get('timestamp', '?')} | "
1298
+ f"Best Action Acc: {stats.get('best_val_action_acc', 0):.1%}",
1299
+ fontsize=13, fontweight="bold",
1300
+ )
1301
+
1302
+ # Loss curves
1303
+ ax = axes[0]
1304
+ ax.plot(epochs_range, train_loss, label="Train Loss", color="#2196F3", linewidth=2)
1305
+ ax.plot(epochs_range, val_loss, label="Val Loss", color="#F44336", linewidth=2)
1306
+ ax.set_xlabel("Epoch")
1307
+ ax.set_ylabel("Loss")
1308
+ ax.set_title("Training & Validation Loss")
1309
+ ax.legend()
1310
+ ax.grid(True, alpha=0.3)
1311
+ ax.set_xlim(1, len(train_loss))
1312
+
1313
+ # Action accuracy
1314
+ ax = axes[1]
1315
+ ax.plot(epochs_range, [a * 100 for a in val_action_acc], label="Action Accuracy",
1316
+ color="#4CAF50", linewidth=2)
1317
+ best_epoch = int(np.argmax(val_action_acc)) + 1
1318
+ best_acc = max(val_action_acc) * 100
1319
+ ax.axhline(y=best_acc, color="#4CAF50", linestyle="--", alpha=0.4)
1320
+ ax.annotate(f"Best: {best_acc:.1f}% (epoch {best_epoch})",
1321
+ xy=(best_epoch, best_acc), fontsize=9,
1322
+ xytext=(best_epoch + 1, best_acc - 3),
1323
+ arrowprops=dict(arrowstyle="->", color="#4CAF50"),
1324
+ color="#4CAF50")
1325
+ ax.set_xlabel("Epoch")
1326
+ ax.set_ylabel("Accuracy (%)")
1327
+ ax.set_title("Action Prediction Accuracy")
1328
+ ax.legend()
1329
+ ax.grid(True, alpha=0.3)
1330
+ ax.set_xlim(1, len(train_loss))
1331
+
1332
+ # Location accuracy
1333
+ ax = axes[2]
1334
+ if val_loc_acc:
1335
+ ax.plot(epochs_range, [a * 100 for a in val_loc_acc], label="Location Accuracy",
1336
+ color="#FF9800", linewidth=2)
1337
+ best_loc_epoch = int(np.argmax(val_loc_acc)) + 1
1338
+ best_loc = max(val_loc_acc) * 100
1339
+ ax.axhline(y=best_loc, color="#FF9800", linestyle="--", alpha=0.4)
1340
+ ax.annotate(f"Best: {best_loc:.1f}% (epoch {best_loc_epoch})",
1341
+ xy=(best_loc_epoch, best_loc), fontsize=9,
1342
+ xytext=(best_loc_epoch + 1, best_loc - 3),
1343
+ arrowprops=dict(arrowstyle="->", color="#FF9800"),
1344
+ color="#FF9800")
1345
+ ax.set_xlabel("Epoch")
1346
+ ax.set_ylabel("Accuracy (%)")
1347
+ ax.set_title("Location Prediction Accuracy")
1348
+ ax.legend()
1349
+ ax.grid(True, alpha=0.3)
1350
+ ax.set_xlim(1, len(train_loss))
1351
+
1352
+ # Footer with training info
1353
+ footer = (
1354
+ f"Train: {stats.get('train_samples', '?'):,} samples | "
1355
+ f"Val: {stats.get('val_samples', '?'):,} samples | "
1356
+ f"Collected: {stats.get('collected_samples', 0):,} | "
1357
+ f"Model: {stats.get('model_size_kb', 0):.0f} KB"
1358
+ )
1359
+ fig.text(0.5, 0.01, footer, ha="center", fontsize=9, color="gray")
1360
+
1361
+ plt.tight_layout(rect=[0, 0.03, 1, 0.95])
1362
+
1363
+ graph_path = MODEL_DIR / "training_graphs.png"
1364
+ fig.savefig(str(graph_path), dpi=150, bbox_inches="tight")
1365
+ logger.info(f"Training graphs saved to {graph_path}")
1366
+
1367
+ # Try to display interactively
1368
+ try:
1369
+ import warnings
1370
+ with warnings.catch_warnings():
1371
+ warnings.simplefilter("ignore")
1372
+ matplotlib.use("TkAgg")
1373
+ plt.show(block=False)
1374
+ plt.pause(0.5)
1375
+ except Exception:
1376
+ pass # headless environment, PNG saved is enough
1377
+
1378
+ plt.close(fig)
1379
+
1380
+
1381
+ def _push_to_hub(best_pt, onnx_path, stats_path, repo_id, best_val_acc, epochs, num_train,
1382
+ base_url: str = "https://raymelius-soci2.hf.space"):
1383
+ """Upload model files to HuggingFace Hub, then trigger live reload."""
1384
  from huggingface_hub import HfApi, login
1385
 
1386
  token = os.environ.get("HF_TOKEN", "")
 
1423
 
1424
  logger.info(f"Model pushed to https://huggingface.co/{repo_id}")
1425
 
1426
+ # Trigger hot-reload on the live simulation
1427
+ try:
1428
+ import httpx
1429
+ resp = httpx.post(f"{base_url}/api/nn/reload", timeout=30.0)
1430
+ if resp.status_code == 200:
1431
+ logger.info(f"Live sim NN reloaded: {resp.json().get('message', 'ok')}")
1432
+ else:
1433
+ logger.warning(f"Could not reload live sim NN: HTTP {resp.status_code}")
1434
+ except Exception as e:
1435
+ logger.warning(f"Could not reach live sim for reload: {e}")
1436
+
1437
 
1438
  # ══════════════════════════════════════════════════════════════════════════
1439
  # CLI
 
1447
  python scripts/nn_train.py # Train from scratch
1448
  python scripts/nn_train.py --resume --epochs 50 # Continue training
1449
  python scripts/nn_train.py --data data/nn_training # Use collected samples
1450
+ python scripts/nn_train.py --push # Push existing model to HF Hub
1451
+ python scripts/nn_train.py --graph # Show graphs from last training
1452
  """,
1453
  )
1454
  parser.add_argument("--epochs", type=int, default=30, help="Training epochs (default: 30)")
 
1463
  parser.add_argument("--resume", action="store_true",
1464
  help="Resume from existing weights in models/")
1465
  parser.add_argument("--push", action="store_true",
1466
+ help="Push existing model to HuggingFace Hub (no training)")
1467
+ parser.add_argument("--graph", action="store_true",
1468
+ help="Display training graphs from last training run")
1469
  parser.add_argument("--repo", default="RayMelius/soci-agent-nn",
1470
  help="HF Hub repo ID (default: RayMelius/soci-agent-nn)")
1471
+ parser.add_argument("--url", default="https://raymelius-soci2.hf.space",
1472
+ help="Live simulation URL for hot-reload after push (default: HF Space)")
1473
  args = parser.parse_args()
1474
 
1475
+ # --graph: just display graphs and exit
1476
+ if args.graph:
1477
+ plot_training_graphs()
1478
+ return
1479
+
1480
+ # --push: just push existing model to HF Hub and exit
1481
+ if args.push:
1482
+ stats_path = MODEL_DIR / "training_stats.json"
1483
+ best_pt = MODEL_DIR / "soci_agent_best.pt"
1484
+ onnx_path = MODEL_DIR / "soci_agent.onnx"
1485
+ if stats_path.exists():
1486
+ stats = json.loads(stats_path.read_text())
1487
+ best_val_acc = stats.get("best_val_action_acc", 0)
1488
+ ep = stats.get("epochs", 0)
1489
+ n_train = stats.get("train_samples", 0)
1490
+ else:
1491
+ best_val_acc, ep, n_train = 0, 0, 0
1492
+ _push_to_hub(best_pt, onnx_path, stats_path, args.repo, best_val_acc, ep, n_train,
1493
+ base_url=args.url)
1494
+ return
1495
+
1496
+ # Default: train
1497
  train(
1498
  epochs=args.epochs,
1499
  batch_size=args.batch_size,
 
1502
  num_val=args.val_samples,
1503
  data_dir=args.data,
1504
  resume=args.resume,
 
 
1505
  )
1506
 
1507
 
src/soci/api/routes.py CHANGED
@@ -337,6 +337,38 @@ async def set_llm_provider(req: SwitchProviderRequest):
337
  raise HTTPException(status_code=500, detail=str(e))
338
 
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  @router.get("/llm/quota")
341
  async def get_llm_quota():
342
  """Return remaining daily quota and usage stats for budget planning.
 
337
  raise HTTPException(status_code=500, detail=str(e))
338
 
339
 
340
+ @router.post("/nn/reload")
341
+ async def reload_nn_model():
342
+ """Hot-reload the NN model from HuggingFace Hub without restarting."""
343
+ from soci.api.server import get_simulation, get_llm_provider
344
+ sim = get_simulation()
345
+
346
+ # If current provider is NN, reload directly
347
+ if get_llm_provider() == "nn":
348
+ from soci.engine.nn_client import NNClient
349
+ if isinstance(sim.llm, NNClient):
350
+ msg = sim.llm.reload()
351
+ return {"ok": True, "message": msg}
352
+
353
+ # NN not active — try to reload anyway if there's an NN client we can find
354
+ # or just re-download the model file for next time NN is activated
355
+ try:
356
+ from soci.engine.nn_client import _download_model, _MODEL_FILENAME
357
+ from pathlib import Path
358
+ local = Path("models") / _MODEL_FILENAME
359
+ if local.exists():
360
+ local.unlink()
361
+ path = _download_model()
362
+ size = Path(path).stat().st_size
363
+ return {
364
+ "ok": True,
365
+ "message": f"NN model re-downloaded ({size / 1024:.0f} KB). "
366
+ f"Switch to NN provider to use it.",
367
+ }
368
+ except Exception as e:
369
+ raise HTTPException(status_code=500, detail=f"Failed to reload NN model: {e}")
370
+
371
+
372
  @router.get("/llm/quota")
373
  async def get_llm_quota():
374
  """Return remaining daily quota and usage stats for budget planning.
src/soci/engine/nn_client.py CHANGED
@@ -292,8 +292,10 @@ class NNClient:
292
  "onnxruntime is required for the NN provider. "
293
  "Install it with: pip install onnxruntime"
294
  )
 
295
  if model_path is None:
296
  model_path = _download_model(repo_id)
 
297
  self.session = ort.InferenceSession(
298
  model_path,
299
  providers=["CPUExecutionProvider"],
@@ -302,6 +304,34 @@ class NNClient:
302
  self._last_error = ""
303
  logger.info(f"NN client loaded: {model_path}")
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  async def complete(
306
  self,
307
  system: str,
 
292
  "onnxruntime is required for the NN provider. "
293
  "Install it with: pip install onnxruntime"
294
  )
295
+ self._repo_id = repo_id
296
  if model_path is None:
297
  model_path = _download_model(repo_id)
298
+ self._model_path = model_path
299
  self.session = ort.InferenceSession(
300
  model_path,
301
  providers=["CPUExecutionProvider"],
 
304
  self._last_error = ""
305
  logger.info(f"NN client loaded: {model_path}")
306
 
307
+ def reload(self) -> str:
308
+ """Re-download the ONNX model from HF Hub and reload the session.
309
+
310
+ Returns a status message describing what happened.
311
+ """
312
+ local_path = Path(self._model_path)
313
+
314
+ # Delete cached model to force re-download
315
+ if local_path.exists():
316
+ old_size = local_path.stat().st_size
317
+ local_path.unlink()
318
+ logger.info(f"Deleted cached model ({old_size:,} bytes)")
319
+
320
+ # Re-download
321
+ new_path = _download_model(self._repo_id)
322
+ new_size = Path(new_path).stat().st_size
323
+
324
+ # Reload ONNX session
325
+ self.session = ort.InferenceSession(
326
+ new_path,
327
+ providers=["CPUExecutionProvider"],
328
+ )
329
+ self._model_path = new_path
330
+
331
+ msg = f"NN model reloaded from {self._repo_id} ({new_size / 1024:.0f} KB)"
332
+ logger.info(msg)
333
+ return msg
334
+
335
  async def complete(
336
  self,
337
  system: str,