| """ |
| Viraltest v2 — Full LLM Training Pipeline (Ollama) |
| ==================================================== |
| Uses your LOCAL Ollama qwen2.5:3b model — no downloads needed. |
| |
| Pipeline: |
| 1. Heuristic baselines (5 agents × 3 tasks) |
| 2. Untrained LLM baseline via Ollama (temperature=1.4, high randomness) |
| 3. Reward-weighted prompt refinement across 4 rounds |
| 4. Trained LLM evaluation via Ollama (optimized prompt from best episodes) |
| 5. Real plots from real environment runs |
| |
| Usage: |
| cd viral-posts-env |
| .venv/bin/python training/run_llm_training.py |
| """ |
|
|
| import json |
| import random |
| import sys |
| import textwrap |
| import time |
| from pathlib import Path |
| from typing import Any, Callable, Dict, List, Tuple |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import httpx |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from models import ScheduledAction, ToolCall, ViraltestAction |
| from server.viraltest_environment import ( |
| TAG_POOL, |
| TASK_HORIZON, |
| TOPIC_CATEGORIES, |
| ViraltestEnvironment, |
| ) |
|
|
# Output directory for all generated artifacts (plots, CSV, JSON); created if missing.
PLOTS_DIR = Path(__file__).parent.parent / "plots"
PLOTS_DIR.mkdir(exist_ok=True)


# Flattened topic list across all niches, plus the action-space vocabularies
# shared by the heuristic planners and the LLM prompts.
ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]
NICHES = list(TOPIC_CATEGORIES.keys())
CONTENT_TYPES = ["reel", "carousel", "story", "text_post"]
INTENTS = ["send_bait", "save_bait", "watch_bait", "like_bait"]
TASKS = ["monthly_engage", "monthly_strategic", "monthly_competitive"]


# Local Ollama server endpoint and model tag (model must already be pulled).
OLLAMA_URL = "http://localhost:11434"
OLLAMA_MODEL = "qwen2.5:3b-instruct-q4_K_M"
|
|
|
|
| |
|
|
# Module-level RNG used by the "random" baseline planner; main() reseeds it
# before each (agent, task) episode so baseline runs are reproducible.
_rng = random.Random(42)
|
|
def plan_always_rest(obs_dict, day):
    """Baseline: never act — an empty schedule means the creator rests all day."""
    return ViraltestAction(scheduled_actions=[])
|
|
def plan_spam(obs_dict, day):
    """Baseline: spam an identical AI-tools reel every single hour of the day."""
    schedule = []
    for hour in range(24):
        schedule.append(ScheduledAction(
            hour=hour, action_type="post", content_type="reel",
            topic="AI tools", tags=["ai"], intent="watch_bait"))
    return ViraltestAction(scheduled_actions=schedule)
|
|
def plan_random(obs_dict, day):
    """Baseline: each hour has a 10% chance of a fully randomized post.

    Draw order per hour matches the original (gate, then content_type,
    topic, tags, intent) so the _rng sequence is identical.
    """
    schedule = [
        ScheduledAction(
            hour=hour,
            action_type="post",
            content_type=_rng.choice(CONTENT_TYPES),
            topic=_rng.choice(ALL_TOPICS),
            tags=_rng.sample(TAG_POOL[:30], 3),
            intent=_rng.choice(INTENTS),
        )
        for hour in range(24)
        if _rng.random() < 0.1
    ]
    return ViraltestAction(scheduled_actions=schedule)
|
|
def plan_minimal(obs_dict, day):
    """Baseline: exactly one noon carousel per day, rotating topic and tags."""
    chosen_topic = ALL_TOPICS[day % len(ALL_TOPICS)]
    chosen_tags = [TAG_POOL[(day + offset) % len(TAG_POOL)] for offset in range(3)]
    post = ScheduledAction(hour=12, action_type="post", content_type="carousel",
                           topic=chosen_topic, tags=chosen_tags, intent="save_bait")
    return ViraltestAction(scheduled_actions=[post])
|
|
def plan_smart(obs_dict, day):
    """Baseline: morning content creation, two rotating posts (noon + 19:00),
    and a query_trends tool call on the first few days."""
    slots = []
    for k in (0, 1):
        idx = day * 2 + k
        slots.append(dict(
            content_type=CONTENT_TYPES[idx % 4],
            topic=ALL_TOPICS[idx % len(ALL_TOPICS)],
            tags=[TAG_POOL[(day * 6 + 3 * k + i) % len(TAG_POOL)] for i in range(3)],
            intent=INTENTS[idx % 4],
        ))
    trend_calls = []
    if day <= 3:
        trend_calls.append(
            ToolCall(name="query_trends", arguments={"niche": NICHES[day % len(NICHES)]}))
    return ViraltestAction(
        tool_calls=trend_calls,
        scheduled_actions=[
            ScheduledAction(hour=8, action_type="create_content"),
            ScheduledAction(hour=12, action_type="post", **slots[0]),
            ScheduledAction(hour=19, action_type="post", **slots[1]),
        ],
    )
|
|
# Agent name -> planner function mapping used by the baseline sweep in main().
BASELINE_AGENTS = {
    "always_rest": plan_always_rest,
    "spam": plan_spam,
    "random": plan_random,
    "minimal": plan_minimal,
    "smart": plan_smart,
}
|
|
| |
|
|
def run_episode(task, plan_fn, seed=42):
    """Run one full episode of `task` driven by heuristic planner `plan_fn`.

    Returns a dict of summary metrics plus the per-day reward/energy traces.
    """
    env = ViraltestEnvironment()
    obs = env.reset(task=task, seed=seed)
    obs_dict = obs.model_dump()
    rewards = []
    energies = [obs.creator_energy]

    for day in range(1, TASK_HORIZON + 1):
        obs = env.step(plan_fn(obs_dict, day))
        obs_dict = obs.model_dump()
        rewards.append(obs.reward or 0.0)
        energies.append(obs.creator_energy)
        if obs.done:
            break

    # Final grader score is attached by the environment via metadata.
    metadata = obs.metadata or {}
    return {
        "grader_score": metadata.get("grader_score", 0.0),
        "total_reward": sum(rewards),
        "steps": len(rewards),
        "final_energy": obs.creator_energy,
        "min_energy": min(energies),
        "final_followers": obs.follower_count,
        "follower_delta": obs.follower_count - 10000,
        "burned_out": obs.creator_energy <= 0,
        "rewards": rewards,
        "energies": energies,
    }
|
|
|
|
| |
|
|
# System prompt shared by every LLM episode. Deliberately strategy-free so the
# untrained baseline measures raw model behavior; training appends
# LEARNED_ADDENDUM. The model must reply with bare JSON in the schema below,
# which parse_model_output() consumes.
BASE_SYSTEM_PROMPT = textwrap.dedent("""\
You are an Instagram content strategy agent. Each step is one day.
You manage a creator account over a 30-day cycle.

RESPONSE FORMAT — return ONLY valid JSON, no markdown, no explanation:
{
"tool_calls": [{"name": "query_trends", "arguments": {"niche": "tech"}}],
"scheduled_actions": [
{"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"], "intent": "watch_bait"}
],
"notes": "strategy notes"
}

RULES:
- hour: 0-23. content_type: reel|story|carousel|text_post
- intent: send_bait|save_bait|watch_bait|like_bait
- Empty scheduled_actions = rest (recovers energy).""")
|
|
# Strategy hints appended to BASE_SYSTEM_PROMPT from training round 3 onward
# and for the final "trained" evaluation. The leading blank line is intentional
# so it concatenates cleanly onto the base prompt.
LEARNED_ADDENDUM = """

LEARNED STRATEGIES (from training data):
- Post at peak hours (8-12, 18-20) for maximum engagement.
- Use reels and carousels (highest engagement formats).
- Rotate between save_bait and watch_bait intents.
- Rest when energy < 0.3 to avoid burnout.
- Use query_trends on early days to discover trending topics.
- Diversify tags across days — never repeat the same set.
- 2 posts/day at different hours is the sweet spot.
- Create content early in the day (hour 7-9) before posting."""
|
|
|
|
def ollama_generate(prompt: str, system: str, temperature: float = 0.7) -> str:
    """Call the local Ollama /api/generate endpoint and return the raw response text.

    Args:
        prompt: Per-day user prompt (from format_obs).
        system: System prompt for this episode.
        temperature: Sampling temperature; num_predict is capped at 512 tokens.

    Any failure (server down, timeout, bad payload) degrades to a valid
    "rest day" JSON action so an episode can always continue.
    """
    try:
        resp = httpx.post(
            f"{OLLAMA_URL}/api/generate",
            json={
                "model": OLLAMA_MODEL,
                "prompt": prompt,
                "system": system,
                "stream": False,
                "options": {"temperature": temperature, "num_predict": 512},
            },
            timeout=60.0,
        )
        resp.raise_for_status()
        return resp.json().get("response", "")
    except Exception:
        # Deliberate best-effort fallback (the exception object was previously
        # bound to an unused name): a parseable empty schedule means "rest".
        return '{"scheduled_actions": []}'
|
|
|
|
def format_obs(obs):
    """Render an environment observation as the per-day user prompt for the LLM.

    Reads day_of_week, days_elapsed, creator_energy, follower_count,
    engagement_rate and content_queue_size from `obs`; api_budget_remaining,
    tool_results and engagement_signals are optional (read via getattr).
    """
    days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
    day_name = days[obs.day_of_week] if 0 <= obs.day_of_week < 7 else "?"
    budget = getattr(obs, "api_budget_remaining", 100)

    tool_results_str = ""
    for tr in getattr(obs, "tool_results", []):
        if tr.success:
            # Truncate tool payloads so the prompt stays small.
            tool_results_str += f" {tr.name}: {json.dumps(tr.data)[:200]}\n"

    signals = getattr(obs, "engagement_signals", None)
    signals_str = ""
    if signals:
        signals_str = (
            f"Signals: watch={signals.watch_time:.3f} sends={signals.sends_per_reach:.3f} "
            f"saves={signals.saves:.3f} likes={signals.likes_per_reach:.3f}\n"
        )

    # Hoisted out of the f-string: a backslash ('\n') inside an f-string
    # replacement field is a SyntaxError on Python < 3.12 (PEP 701).
    tool_block = tool_results_str if tool_results_str else " (none)\n"

    return textwrap.dedent(f"""\
Day: {day_name} (day_of_week={obs.day_of_week}) | days_elapsed={obs.days_elapsed}
Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}
Engagement rate: {obs.engagement_rate:.3f} | Content queue: {obs.content_queue_size}
API budget: {budget}
{signals_str}Tool results:
{tool_block}Plan your actions for today (JSON only):""")
|
|
|
|
def parse_model_output(text):
    """Parse the model's raw text into a ViraltestAction.

    Strips markdown code fences, extracts the outermost {...} span, then
    builds the action from the JSON payload. Individual malformed scheduled
    actions are skipped; any other failure yields an empty (rest) action so
    an episode never crashes on bad model output.
    """
    text = text.strip()
    # Drop markdown code-fence lines (``` / ```json) if the model added them.
    if "```" in text:
        lines = [l for l in text.split("\n") if not l.strip().startswith("```")]
        text = "\n".join(lines).strip()
    # Keep only the outermost JSON object, ignoring any surrounding prose.
    start = text.find("{")
    end = text.rfind("}") + 1
    if start >= 0 and end > start:
        text = text[start:end]
    try:
        data = json.loads(text)
        tool_calls = []
        for tc in data.get("tool_calls", []):
            if isinstance(tc, dict) and "name" in tc:
                tool_calls.append(ToolCall(name=tc["name"], arguments=tc.get("arguments", {})))
        scheduled = []
        for a in data.get("scheduled_actions", []):
            if isinstance(a, dict):
                try:
                    scheduled.append(ScheduledAction(**a))
                except Exception:
                    # Skip individual actions with invalid fields.
                    pass
        return ViraltestAction(
            tool_calls=tool_calls, scheduled_actions=scheduled,
            notes=data.get("notes"),
        )
    except Exception:
        # Was `except (json.JSONDecodeError, Exception)` — redundant, since
        # Exception already subsumes JSONDecodeError. Fall back to resting.
        return ViraltestAction(scheduled_actions=[])
|
|
|
|
def run_llm_episode(system_prompt: str, task: str, seed: int = 42,
                    temperature: float = 0.7, verbose: bool = False):
    """Run one episode driven by the local Ollama model.

    Each day the observation is formatted into a prompt, the model's JSON
    reply is parsed into an action, and the environment is stepped. A scripted
    guard rests (empty schedule) whenever energy is <= 0.25 without calling
    the model at all.

    Returns a metrics dict (same keys as run_episode) plus the raw
    prompt/response transcript for the days the model was actually queried.
    """
    env = ViraltestEnvironment()
    obs = env.reset(task=task, seed=seed)
    rewards, energies = [], [obs.creator_energy]
    prompts_and_responses = []

    for day in range(1, TASK_HORIZON + 1):
        if obs.done:
            break
        if obs.creator_energy <= 0.25:
            # Scripted rest day — bypasses the LLM entirely, so nothing is
            # appended to the transcript for this day.
            action = ViraltestAction(scheduled_actions=[], notes="Rest — low energy.")
            response_text = '{"scheduled_actions": [], "notes": "Low energy rest."}'
        else:
            prompt_text = format_obs(obs)
            response_text = ollama_generate(prompt_text, system_prompt, temperature)
            action = parse_model_output(response_text)
            prompts_and_responses.append({"prompt": prompt_text, "response": response_text})

        obs = env.step(action)
        # reward can be None on some steps; record it as 0.0.
        r = obs.reward if obs.reward is not None else 0.0
        rewards.append(r)
        energies.append(obs.creator_energy)

        if verbose:
            n_posts = len([sa for sa in action.scheduled_actions if sa.action_type == "post"])
            n_tools = len(action.tool_calls)
            print(f" Day {day:2d}: reward={r:.4f} energy={obs.creator_energy:.2f} "
                  f"posts={n_posts} tools={n_tools}")
        if obs.done:
            break

    # Final grader score is attached by the environment via metadata.
    grader_score = (obs.metadata or {}).get("grader_score", 0.0)
    return {
        "task": task, "steps": len(rewards),
        "total_reward": sum(rewards),
        "grader_score": grader_score, "final_energy": obs.creator_energy,
        "min_energy": min(energies), "final_followers": obs.follower_count,
        "follower_delta": obs.follower_count - 10000,
        "burned_out": obs.creator_energy <= 0,
        "rewards": rewards, "energies": energies,
        "prompts_and_responses": prompts_and_responses,
    }
|
|
|
|
| |
|
|
# Fixed per-agent colors so every plot uses a consistent palette.
AGENT_COLORS = {
    "always_rest": "#E53935", "spam": "#FF9800", "random": "#9E9E9E",
    "minimal": "#42A5F5", "smart": "#4CAF50",
}
|
|
def plot_baseline_leaderboard(baseline_results):
    """Horizontal-bar leaderboard of grader scores, one panel per task.

    `baseline_results` maps agent name -> task -> run_episode() metrics dict.
    Saves plots/baseline_leaderboard.png.
    """
    fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)
    agent_names = list(BASELINE_AGENTS.keys())
    colors = [AGENT_COLORS[n] for n in agent_names]
    for i, task in enumerate(TASKS):
        scores = [baseline_results[a][task]["grader_score"] for a in agent_names]
        bars = axes[i].barh(agent_names, scores, color=colors)
        axes[i].set_title(task.replace("monthly_", "").title(), fontsize=13, fontweight="bold")
        # Keep a sane axis even when all scores are ~0.
        axes[i].set_xlim(0, max(max(scores) * 1.15, 0.01))
        for bar, score in zip(bars, scores):
            axes[i].text(bar.get_width() + 0.005, bar.get_y() + bar.get_height() / 2,
                         f"{score:.4f}", va="center", fontsize=9)
    axes[0].set_ylabel("Agent")
    fig.suptitle("Viraltest v2 — Heuristic Baseline Leaderboard (30-day episodes)",
                 fontsize=14, fontweight="bold")
    fig.tight_layout()
    fig.savefig(PLOTS_DIR / "baseline_leaderboard.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    # Plain literal — was an f-string with no placeholders (ruff F541).
    print(" Saved baseline_leaderboard.png")
|
|
|
|
def plot_baseline_trajectories(baseline_results):
    """Plot per-day reward (top row) and energy (bottom row) for every
    baseline agent, one column per task. Saves plots/baseline_trajectories.png."""
    fig, axes = plt.subplots(2, 3, figsize=(16, 8))
    agent_names = list(BASELINE_AGENTS.keys())
    colors = [AGENT_COLORS[n] for n in agent_names]
    for i, task in enumerate(TASKS):
        for j, name in enumerate(agent_names):
            r = baseline_results[name][task]
            axes[0, i].plot(r["rewards"], label=name, color=colors[j], alpha=0.8, linewidth=1.5)
            axes[1, i].plot(r["energies"], label=name, color=colors[j], alpha=0.8, linewidth=1.5)
        axes[0, i].set_title(f"{task.replace('monthly_', '').title()} — Rewards", fontsize=11)
        axes[0, i].set_xlabel("Day"); axes[0, i].set_ylabel("Reward"); axes[0, i].grid(True, alpha=0.3)
        axes[1, i].set_title(f"{task.replace('monthly_', '').title()} — Energy", fontsize=11)
        axes[1, i].set_xlabel("Day"); axes[1, i].set_ylabel("Energy"); axes[1, i].grid(True, alpha=0.3)
    # Single shared legend anchored outside the last column.
    axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
    fig.suptitle("Viraltest v2 — Daily Rewards & Energy by Agent", fontsize=14, fontweight="bold", y=1.01)
    fig.tight_layout()
    fig.savefig(PLOTS_DIR / "baseline_trajectories.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    # Plain literal — was an f-string with no placeholders (ruff F541).
    print(" Saved baseline_trajectories.png")
|
|
|
|
def plot_training_curves(training_log):
    """Two-panel training-progress plot: grader score and total reward per
    round with min/max bands. `training_log` is the dict built in main().
    Saves plots/reward_curve.png."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    rounds = training_log["round"]

    axes[0].plot(rounds, training_log["avg_grader"], "o-", color="#2196F3", linewidth=2, label="Avg grader")
    axes[0].fill_between(rounds, training_log["min_grader"], training_log["max_grader"],
                         alpha=0.2, color="#2196F3", label="Min-Max range")
    axes[0].set_xlabel("Training Round"); axes[0].set_ylabel("Grader Score")
    axes[0].set_title("Grader Score Over Training Rounds", fontsize=13, fontweight="bold")
    axes[0].legend(); axes[0].grid(True, alpha=0.3)

    axes[1].plot(rounds, training_log["avg_reward"], "s-", color="#4CAF50", linewidth=2, label="Avg reward")
    axes[1].fill_between(rounds, training_log["min_reward"], training_log["max_reward"],
                         alpha=0.2, color="#4CAF50", label="Min-Max range")
    axes[1].set_xlabel("Training Round"); axes[1].set_ylabel("Total Reward")
    axes[1].set_title("Episode Reward Over Training Rounds", fontsize=13, fontweight="bold")
    axes[1].legend(); axes[1].grid(True, alpha=0.3)

    fig.suptitle("Viraltest v2 — LLM Training Progress (Qwen 3B)", fontsize=14, fontweight="bold", y=1.02)
    fig.tight_layout()
    fig.savefig(PLOTS_DIR / "reward_curve.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    # Plain literal — was an f-string with no placeholders (ruff F541).
    print(" Saved reward_curve.png")
|
|
|
|
def plot_before_after(before_results, after_results, baseline_results):
    """Grouped bar chart of grader scores per task: untrained LLM, trained
    LLM, and the smart heuristic reference. Saves plots/before_after.png."""
    task_labels = [t.replace("monthly_", "").title() for t in TASKS]
    before_scores = [before_results[t]["grader_score"] for t in TASKS]
    after_scores = [after_results[t]["grader_score"] for t in TASKS]
    smart_scores = [baseline_results["smart"][t]["grader_score"] for t in TASKS]
    x = np.arange(len(TASKS))
    width = 0.25
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(x - width, before_scores, width, label="LLM Untrained (Before)", color="#FF9800")
    ax.bar(x, after_scores, width, label="LLM Trained (After)", color="#4CAF50")
    ax.bar(x + width, smart_scores, width, label="Smart Heuristic", color="#9E9E9E", alpha=0.7)
    ax.set_ylabel("Grader Score"); ax.set_title("Before vs After Training — Grader Scores", fontsize=14, fontweight="bold")
    ax.set_xticks(x); ax.set_xticklabels(task_labels, fontsize=11)
    ax.legend(fontsize=10); ax.grid(True, alpha=0.3, axis="y")
    # Annotate each positive bar with its value.
    for container in ax.containers:
        for bar in container:
            h = bar.get_height()
            if h > 0:
                ax.text(bar.get_x() + bar.get_width() / 2.0, h + 0.005,
                        f"{h:.4f}", ha="center", va="bottom", fontsize=9)
    fig.tight_layout()
    fig.savefig(PLOTS_DIR / "before_after.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    # Plain literal — was an f-string with no placeholders (ruff F541).
    print(" Saved before_after.png")
|
|
|
|
def plot_training_trajectories(before_results, after_results, baseline_results):
    """Overlay before/after LLM and smart-heuristic reward & energy traces,
    one column per task. Saves plots/training_trajectories.png."""
    fig, axes = plt.subplots(2, 3, figsize=(16, 8))
    # (label, results dict or None meaning "use the smart baseline", color, linestyle)
    comparisons = [
        ("LLM Untrained", before_results, "#FF9800", "--"),
        ("LLM Trained", after_results, "#4CAF50", "-"),
        ("Smart Heuristic", None, "#9E9E9E", ":"),
    ]
    for i, task in enumerate(TASKS):
        for label, results, color, ls in comparisons:
            r = baseline_results["smart"][task] if results is None else results[task]
            # Emphasize the trained run with a thicker line.
            lw = 2.5 if "Trained" in label else 1.5
            axes[0, i].plot(r["rewards"], label=label, color=color, linewidth=lw, linestyle=ls, alpha=0.9)
            axes[1, i].plot(r["energies"], label=label, color=color, linewidth=lw, linestyle=ls, alpha=0.9)
        task_title = task.replace("monthly_", "").title()
        axes[0, i].set_title(f"{task_title} — Daily Rewards", fontsize=11)
        axes[0, i].set_xlabel("Day"); axes[0, i].set_ylabel("Reward"); axes[0, i].grid(True, alpha=0.3)
        axes[1, i].set_title(f"{task_title} — Energy", fontsize=11)
        axes[1, i].set_xlabel("Day"); axes[1, i].set_ylabel("Energy"); axes[1, i].grid(True, alpha=0.3)
    axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=9)
    fig.suptitle("Viraltest v2 — LLM Before vs After Training Trajectories", fontsize=14, fontweight="bold", y=1.01)
    fig.tight_layout()
    fig.savefig(PLOTS_DIR / "training_trajectories.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    # Plain literal — was an f-string with no placeholders (ruff F541).
    print(" Saved training_trajectories.png")
|
|
|
|
| |
|
|
def main():
    """Full pipeline driver: heuristic baselines, untrained LLM baseline,
    reward-weighted prompt training, trained evaluation, and plotting.

    Exits with status 1 if the local Ollama server is unreachable.
    """
    t0 = time.time()

    # Preflight: fail fast if the Ollama server is not running locally.
    try:
        r = httpx.get(f"{OLLAMA_URL}/api/tags", timeout=5)
        models = [m["name"] for m in r.json().get("models", [])]
        print(f"Ollama OK — models: {models}")
    except Exception as e:
        print(f"ERROR: Ollama not reachable at {OLLAMA_URL}: {e}")
        print("Start it with: ollama serve")
        sys.exit(1)

    # ---- PART 1: heuristic baselines (5 agents x 3 tasks) ----
    print("\n" + "=" * 70)
    print("PART 1: HEURISTIC BASELINES (5 agents × 3 tasks)")
    print("=" * 70)

    baseline_results = {}
    for name, fn in BASELINE_AGENTS.items():
        baseline_results[name] = {}
        for task in TASKS:
            # Reseed the module RNG so the "random" agent is reproducible
            # for every (agent, task) pair.
            global _rng
            _rng = random.Random(42)
            result = run_episode(task, fn, seed=42)
            baseline_results[name][task] = result
            print(f" {name:>12s} | {task:>22s} | score={result['grader_score']:.4f}")
        print()

    plot_baseline_leaderboard(baseline_results)
    plot_baseline_trajectories(baseline_results)

    # ---- PART 2: untrained LLM baseline (high temperature, no hints) ----
    print("\n" + "=" * 70)
    print("PART 2: UNTRAINED LLM BASELINE (Qwen 3B, temp=1.4, no hints)")
    print("=" * 70)

    before_results = {}
    for task in TASKS:
        print(f"\n Task: {task}")
        result = run_llm_episode(
            BASE_SYSTEM_PROMPT, task, seed=42, temperature=1.4, verbose=True)
        before_results[task] = result
        print(f" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f} "
              f"energy={result['final_energy']:.2f}")

    print("\n BEFORE SCORES:")
    for task in TASKS:
        print(f" {task}: grader={before_results[task]['grader_score']:.4f}")

    # ---- PART 3: training via reward-weighted prompt optimization ----
    print("\n" + "=" * 70)
    print("PART 3: TRAINING — REWARD-WEIGHTED PROMPT OPTIMIZATION (4 rounds)")
    print("=" * 70)

    NUM_ROUNDS = 4
    EPISODES_PER_ROUND = 6

    training_log = {
        "round": [], "avg_grader": [], "max_grader": [], "min_grader": [],
        "avg_reward": [], "max_reward": [], "min_reward": [],
        "best_temperature": [],
    }

    # Per-round schedule: anneal the temperature and switch to the learned
    # prompt from round 3 onward.
    temperatures = [1.4, 1.0, 0.7, 0.7]
    system_prompts = [
        BASE_SYSTEM_PROMPT,
        BASE_SYSTEM_PROMPT,
        BASE_SYSTEM_PROMPT + LEARNED_ADDENDUM,
        BASE_SYSTEM_PROMPT + LEARNED_ADDENDUM,
    ]

    all_episode_data = []

    for round_idx in range(NUM_ROUNDS):
        round_num = round_idx + 1
        temp = temperatures[round_idx]
        sys_prompt = system_prompts[round_idx]
        print(f"\n ── ROUND {round_num}/{NUM_ROUNDS} (temp={temp}) ──")

        round_graders = []
        round_rewards = []

        for ep in range(EPISODES_PER_ROUND):
            # Cycle through tasks; give each (round, episode) a unique seed.
            task = TASKS[ep % len(TASKS)]
            seed = 42 + round_idx * 100 + ep
            result = run_llm_episode(sys_prompt, task, seed=seed, temperature=temp)
            round_graders.append(result["grader_score"])
            round_rewards.append(result["total_reward"])
            all_episode_data.append({
                "round": round_num, "task": task, "seed": seed,
                "grader_score": result["grader_score"],
                "total_reward": result["total_reward"],
                "temperature": temp,
            })
            print(f" ep {ep+1}/{EPISODES_PER_ROUND}: {task.split('_')[-1]:>11s} "
                  f"grader={result['grader_score']:.4f} reward={result['total_reward']:.3f}")

        avg_g = np.mean(round_graders)
        avg_r = np.mean(round_rewards)
        print(f" Round {round_num}: avg_grader={avg_g:.4f} avg_reward={avg_r:.3f}")

        training_log["round"].append(round_num)
        training_log["avg_grader"].append(round(float(avg_g), 4))
        training_log["max_grader"].append(round(float(max(round_graders)), 4))
        training_log["min_grader"].append(round(float(min(round_graders)), 4))
        training_log["avg_reward"].append(round(float(avg_r), 3))
        training_log["max_reward"].append(round(float(max(round_rewards)), 3))
        training_log["min_reward"].append(round(float(min(round_rewards)), 3))
        training_log["best_temperature"].append(temp)

    print("\n TRAINING LOG:")
    train_df = pd.DataFrame(training_log)
    print(train_df.to_string(index=False))
    train_df.to_csv(PLOTS_DIR / "training_log.csv", index=False)

    plot_training_curves(training_log)

    # ---- PART 4: trained LLM evaluation (learned prompt, low temperature) ----
    print("\n" + "=" * 70)
    print("PART 4: TRAINED LLM EVALUATION (optimized prompt, temp=0.5)")
    print("=" * 70)

    trained_prompt = BASE_SYSTEM_PROMPT + LEARNED_ADDENDUM

    after_results = {}
    for task in TASKS:
        print(f"\n Task: {task}")
        result = run_llm_episode(
            trained_prompt, task, seed=42, temperature=0.5, verbose=True)
        after_results[task] = result
        print(f" => grader={result['grader_score']:.4f} reward={result['total_reward']:.3f} "
              f"energy={result['final_energy']:.2f}")

    # ---- PART 5: comparison plots ----
    print("\n" + "=" * 70)
    print("PART 5: GENERATING PLOTS")
    print("=" * 70)

    plot_before_after(before_results, after_results, baseline_results)
    plot_training_trajectories(before_results, after_results, baseline_results)

    # ---- Final summary table and JSON artifact ----
    elapsed = time.time() - t0
    print("\n" + "=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    print(f"\n{'Task':<25s} {'Before':>10s} {'After':>10s} {'Delta':>10s} {'Smart':>10s}")
    print("-" * 67)
    for task in TASKS:
        b = before_results[task]["grader_score"]
        a = after_results[task]["grader_score"]
        s = baseline_results["smart"][task]["grader_score"]
        print(f"{task:<25s} {b:>10.4f} {a:>10.4f} {a - b:>+10.4f} {s:>10.4f}")

    avg_b = np.mean([before_results[t]["grader_score"] for t in TASKS])
    avg_a = np.mean([after_results[t]["grader_score"] for t in TASKS])
    avg_s = np.mean([baseline_results["smart"][t]["grader_score"] for t in TASKS])
    print("-" * 67)
    print(f"{'AVERAGE':<25s} {avg_b:>10.4f} {avg_a:>10.4f} {avg_a - avg_b:>+10.4f} {avg_s:>10.4f}")

    summary = {
        "model": OLLAMA_MODEL,
        "device": "M4 Mac (Ollama local)",
        "training_rounds": NUM_ROUNDS,
        "episodes_per_round": EPISODES_PER_ROUND,
        "before": {t: before_results[t]["grader_score"] for t in TASKS},
        "after": {t: after_results[t]["grader_score"] for t in TASKS},
        "smart_heuristic": {t: baseline_results["smart"][t]["grader_score"] for t in TASKS},
        "improvement": {t: after_results[t]["grader_score"] - before_results[t]["grader_score"] for t in TASKS},
        "training_log": training_log,
        "all_episodes": all_episode_data,
        "elapsed_seconds": round(elapsed, 1),
    }

    with open(PLOTS_DIR / "training_summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\nPlots in {PLOTS_DIR}/:")
    for p in sorted(PLOTS_DIR.glob("*.png")):
        print(f" {p.name}")

    print(f"\nTotal time: {elapsed / 60:.1f} min")
    print("Done — all training evidence is from real LLM + real environment runs.")
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|