# logiflow-rl / train_grpo.py
from __future__ import annotations
import argparse
import inspect
import json
import random
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean
from typing import Any, Dict, Iterable, List
try:
import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
except ImportError as exc:
raise RuntimeError(
"Missing GRPO training dependencies. Install with:\n"
" pip install -e .[train]\n"
"or install torch/trl/transformers/datasets/peft/accelerate manually."
) from exc
try:
import matplotlib.pyplot as plt
except ImportError:
plt = None
try:
from crisis_logistics_env.models import CrisisLogisticsAction
from crisis_logistics_env.server.crisis_logistics_env_environment import (
CrisisLogisticsEnvironment,
choose_network_action,
)
from crisis_logistics_env.tasks import list_tasks
except ImportError:
from models import CrisisLogisticsAction
from server.crisis_logistics_env_environment import (
CrisisLogisticsEnvironment,
choose_network_action,
)
from tasks import list_tasks
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
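# Seed local RNGs so rollout collection is reproducible; the HF Trainer applies
# its own TrainingArguments seed again when training starts.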
@dataclass
class EvalResult:
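    """Per-task metrics for one full environment rollout under a single policy."""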
policy: str
task_id: str
score: float
total_reward: float
retail_delivered: float
sla_success_rate: float
priority_service_rate: float
invalid_actions: int
def build_prompt(observation: Any, task_title: str) -> str:
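    """Render one observation as the flat-text prompt shown to the policy model."""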
return (
f"Task: {task_title}\n"
f"Objective: {observation.objective}\n"
f"Step: {observation.step_count + 1}/{observation.max_steps}\n"
f"Visible nodes: {observation.visible_node_ids}\n"
f"Observed node loads: {observation.observed_node_loads}\n"
f"Node capacities: {observation.node_capacities}\n"
f"Visible connectivity: {observation.visible_connectivity}\n"
f"Active disruptions: {observation.active_disruptions}\n"
f"In-transit shipments: {observation.in_transit_shipments[:8]}\n"
f"Incoming shipment: source={observation.pending_source_node}, volume={observation.incoming_load}\n"
f"Traffic event: {observation.event_label}\n"
f"Dynamic pressure: {observation.dynamic_pressure}\n"
f"Priority target: {observation.priority_target_name} (node {observation.priority_target_node})\n"
"Return exactly one JSON object with keys: reasoning, source_node, dest_node, shipment_volume."
)
def build_training_rows(samples_per_task: int = 42) -> List[Dict[str, Any]]:
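    """Collect prompt rows for GRPO by rolling each task forward with the heuristic policy.

    Each row carries the metadata (source node, valid destinations, incoming
    load, priority target) that the reward function needs to score a completion.
    """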
rows: List[Dict[str, Any]] = []
for task in list_tasks():
env = CrisisLogisticsEnvironment()
obs = env.reset(task_id=task.task_id)
for _ in range(samples_per_task):
source = int(obs.pending_source_node)
rows.append(
{
"prompt": build_prompt(obs, task.title),
"task_id": task.task_id,
"source_node": source,
"valid_dests": json.dumps(obs.connectivity.get(str(source), [])),
"incoming_load": float(obs.incoming_load),
"priority_target_node": int(obs.priority_target_node),
}
)
obs = env.step(choose_network_action(obs))
if obs.done:
break
return rows
def completion_to_text(completion: Any) -> str:
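    """Normalize a TRL completion (plain string, chat dict, or message list) to text."""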
if isinstance(completion, str):
return completion
if isinstance(completion, dict):
content = completion.get("content", completion)
if isinstance(content, list):
chunks = []
for item in content:
if isinstance(item, dict) and "text" in item:
chunks.append(str(item["text"]))
else:
chunks.append(str(item))
return "\n".join(chunks)
return str(content)
if isinstance(completion, list):
chunks = []
for item in completion:
if isinstance(item, dict):
content = item.get("content", "")
if isinstance(content, list):
for part in content:
if isinstance(part, dict) and "text" in part:
chunks.append(str(part["text"]))
else:
chunks.append(str(part))
else:
chunks.append(str(content))
else:
chunks.append(str(item))
return "\n".join(chunks)
return str(completion)
def extract_json(text: str) -> Dict[str, Any]:
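    """Scan ``text`` for JSON objects and return the best candidate.

    Prefers the last object containing all four action keys (reasoning,
    source_node, dest_node, shipment_volume); otherwise returns the last
    parseable object, or ``{}`` when nothing decodes.
    """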
if not isinstance(text, str):
return {}
decoder = json.JSONDecoder()
candidates: List[Dict[str, Any]] = []
for index, char in enumerate(text):
if char != "{":
continue
try:
payload, _ = decoder.raw_decode(text[index:])
except Exception:
continue
if isinstance(payload, dict):
candidates.append(payload)
if not candidates:
return {}
required = {"reasoning", "source_node", "dest_node", "shipment_volume"}
for payload in reversed(candidates):
if required.issubset(payload.keys()):
return payload
return candidates[-1]
def as_batch(values: Any, n: int) -> List[Any]:
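    """Broadcast a scalar or length-1 sequence to a list of length ``n``."""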
if values is None:
return [None] * n
if isinstance(values, (list, tuple)):
if len(values) == n:
return list(values)
if len(values) == 1:
return [values[0]] * n
return [values] * n
def action_reward(
completions,
source_node=None,
valid_dests=None,
incoming_load=None,
priority_target_node=None,
**kwargs,
):
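    """Shaped reward in [0, 1] for a batch of completions.

    Components: +0.20 parseable JSON, +0.20 all required keys present, +0.20
    correct source node, +0.20 destination in the valid set, +0.15 for a volume
    in (0, 60] plus up to +0.03 for matching the incoming load, and +0.02 for
    hitting the priority target. Unparseable completions clip to 0.0.
    """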
n = len(completions)
source_batch = as_batch(source_node, n)
dest_batch = as_batch(valid_dests, n)
incoming_batch = as_batch(incoming_load, n)
priority_batch = as_batch(priority_target_node, n)
rewards = []
for index, completion in enumerate(completions):
reward = 0.0
parsed = extract_json(completion_to_text(completion))
if parsed:
reward += 0.20
required = {"reasoning", "source_node", "dest_node", "shipment_volume"}
if required.issubset(parsed.keys()):
reward += 0.20
try:
src = int(parsed.get("source_node", -999))
expected = int(source_batch[index]) if source_batch[index] is not None else None
if expected is not None and src == expected:
reward += 0.20
except Exception:
pass
try:
allowed = (
json.loads(dest_batch[index])
if isinstance(dest_batch[index], str)
else (dest_batch[index] or [])
)
dest = int(parsed.get("dest_node", -999))
if dest in allowed:
reward += 0.20
except Exception:
pass
try:
volume = float(parsed.get("shipment_volume", -1))
incoming = float(incoming_batch[index]) if incoming_batch[index] is not None else 10.0
if 0 < volume <= 60:
reward += 0.15
closeness = max(0.0, 1.0 - abs(volume - incoming) / max(1.0, incoming))
reward += 0.03 * closeness
except Exception:
pass
try:
target = int(priority_batch[index]) if priority_batch[index] is not None else None
dest = int(parsed.get("dest_node", -999))
if target is not None and dest == target:
reward += 0.02
except Exception:
pass
if not parsed:
reward -= 0.05
rewards.append(float(max(0.0, min(1.0, reward))))
return rewards
def reward_sanity_check() -> None:
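    """Smoke-test the reward function and abort if it cannot separate a
    well-formed action from junk before training starts."""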
good = ['{"reasoning":"route","source_node":2,"dest_node":4,"shipment_volume":11}']
bad = ["not json at all"]
good_score = action_reward(
good,
source_node=[2],
valid_dests=["[4,5]"],
incoming_load=[10],
priority_target_node=[4],
)[0]
bad_score = action_reward(
bad,
source_node=[2],
valid_dests=["[4,5]"],
incoming_load=[10],
priority_target_node=[4],
)[0]
print(f"reward_sanity good={good_score:.3f} bad={bad_score:.3f}")
if good_score <= 0.60 or bad_score >= 0.20:
raise RuntimeError("Reward sanity check failed; refusing to start GRPO training.")
def supported_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
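    """Drop kwargs that this version of ``cls`` does not accept, so the script
    survives TRL/transformers API drift across releases."""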
supported = set(inspect.signature(cls).parameters)
dropped = sorted(set(kwargs) - supported)
if dropped:
print(f"Skipping unsupported {cls.__name__} args: {dropped}")
return {key: value for key, value in kwargs.items() if key in supported}
def generate_action(model, tokenizer, prompt: str) -> CrisisLogisticsAction:
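    """Greedy-decode one action from the model; fall back to ``target_hub=0``
    when the output contains no usable JSON."""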
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=110,
            do_sample=False,  # greedy decoding; temperature is ignored (and warned about) when sampling is off
            pad_token_id=tokenizer.eos_token_id,
        )
prompt_tokens = int(inputs["input_ids"].shape[1])
generated_tokens = output[0][prompt_tokens:]
text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
if not text:
text = tokenizer.decode(output[0], skip_special_tokens=True).strip()
payload = extract_json(text)
if not payload:
return CrisisLogisticsAction(target_hub=0)
try:
return CrisisLogisticsAction(**payload)
except Exception:
return CrisisLogisticsAction(target_hub=0)
def run_policy(policy: str, task_id: str, model=None, tokenizer=None) -> EvalResult:
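    """Roll one task to completion under a policy (round_robin, heuristic, or
    model) and collect its evaluation metrics."""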
env = CrisisLogisticsEnvironment()
obs = env.reset(task_id=task_id)
round_robin_step = 0
while not obs.done:
if policy == "round_robin":
action = CrisisLogisticsAction(target_hub=round_robin_step % 3)
round_robin_step += 1
elif policy == "heuristic":
action = choose_network_action(obs)
elif policy == "model":
if model is None or tokenizer is None:
raise ValueError("model policy requires model and tokenizer")
action = generate_action(model, tokenizer, build_prompt(obs, env.task.title))
else:
raise ValueError(f"Unknown policy: {policy}")
obs = env.step(action)
return EvalResult(
policy=policy,
task_id=task_id,
score=float(env.score),
total_reward=float(env.total_reward),
retail_delivered=float(env.retail_delivered),
sla_success_rate=float(env._sla_success_rate()),
priority_service_rate=float(env._priority_service_rate()),
invalid_actions=int(env.invalid_actions),
)
def summarize_policy(results: Iterable[EvalResult], phase: str) -> pd.DataFrame:
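    """Aggregate per-task results into mean score/reward/SLA rows per policy."""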
frame = pd.DataFrame([asdict(item) for item in results])
summary = frame.groupby("policy", as_index=False).agg(
avg_score=("score", "mean"),
avg_reward=("total_reward", "mean"),
avg_sla=("sla_success_rate", "mean"),
)
summary["phase"] = phase
return summary
def extract_reward_history(log_history: List[Dict[str, Any]]) -> List[float]:
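    """Pull mean-reward points from trainer logs, trying the first matching key
    per row to cover the names used by different TRL versions."""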
keys = ["rewards/mean", "reward", "train/reward", "objective/reward"]
history: List[float] = []
for row in log_history:
for key in keys:
if key in row and isinstance(row[key], (int, float)):
history.append(float(row[key]))
break
return history
def save_reward_curve_png(reward_history: List[float], output_path: Path) -> None:
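    """Plot the reward curve with matplotlib when available, otherwise draw a
    minimal line chart with Pillow."""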
if plt is not None:
plt.figure(figsize=(10, 4))
plt.plot(reward_history, linewidth=2)
plt.xlabel("Logging Step")
plt.ylabel("Mean Reward")
plt.title("GRPO Reward Curve - LogiFlow-RL")
plt.grid(alpha=0.2)
plt.savefig(output_path, dpi=160, bbox_inches="tight")
plt.close()
return
try:
from PIL import Image, ImageDraw
width, height = 1100, 540
margin_left, margin_right = 70, 30
margin_top, margin_bottom = 40, 60
plot_w = width - margin_left - margin_right
plot_h = height - margin_top - margin_bottom
image = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(image)
y_min = min(reward_history)
y_max = max(reward_history)
if abs(y_max - y_min) < 1e-6:
y_max = y_min + 1.0
def map_x(step: int) -> int:
if len(reward_history) <= 1:
return margin_left
return int(margin_left + (step / (len(reward_history) - 1)) * plot_w)
def map_y(value: float) -> int:
return int(margin_top + (1 - (value - y_min) / (y_max - y_min)) * plot_h)
draw.rectangle(
[margin_left, margin_top, margin_left + plot_w, margin_top + plot_h],
outline="#333333",
width=1,
)
points = [(map_x(step), map_y(value)) for step, value in enumerate(reward_history)]
if len(points) >= 2:
draw.line(points, fill="#1565C0", width=3)
elif len(points) == 1:
x, y = points[0]
draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill="#1565C0")
draw.text((margin_left, 10), "GRPO Reward Curve - LogiFlow-RL", fill="#111111")
draw.text((margin_left, height - 45), "Logging Step", fill="#111111")
draw.text((10, margin_top), "Mean Reward", fill="#111111")
image.save(output_path)
except Exception as exc:
raise RuntimeError(f"Could not save reward curve png without matplotlib: {exc}") from exc
def main() -> None:
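    """Train a LoRA policy with GRPO, evaluate every task before and after
    training, and write all artifacts to the output directory."""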
parser = argparse.ArgumentParser(description="GRPO training script for LogiFlow-RL.")
parser.add_argument("--model-name", default="Qwen/Qwen2.5-0.5B-Instruct")
parser.add_argument("--samples-per-task", type=int, default=42)
parser.add_argument("--max-steps", type=int, default=140)
parser.add_argument("--num-generations", type=int, default=4)
parser.add_argument("--output-dir", default="outputs/logiflow-grpo-script")
args = parser.parse_args()
output_dir = Path(args.output_dir)
artifacts_dir = output_dir / "artifacts"
artifacts_dir.mkdir(parents=True, exist_ok=True)
reward_sanity_check()
train_rows = build_training_rows(samples_per_task=args.samples_per_task)
train_ds = Dataset.from_list(train_rows)
print(f"dataset rows={len(train_ds)}")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
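    # target_modules is left unset, so PEFT falls back to its per-architecture
    # defaults (recent PEFT releases include a mapping for Qwen2-style models).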
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
training_args = GRPOConfig(
**supported_kwargs(
GRPOConfig,
dict(
output_dir=str(output_dir),
learning_rate=1e-5,
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
num_generations=args.num_generations,
max_prompt_length=1024,
max_completion_length=128,
max_steps=args.max_steps,
logging_steps=5,
save_steps=max(20, args.max_steps // 2),
report_to=[],
remove_unused_columns=False,
bf16=torch.cuda.is_available(),
),
)
)
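    # Note: TRL versions differ in exactly which product of batch settings must
    # divide evenly by num_generations; if the trainer aborts with a divisibility
    # error, raise per_device_train_batch_size or lower --num-generations.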
trainer = GRPOTrainer(
**supported_kwargs(
GRPOTrainer,
dict(
model=args.model_name,
args=training_args,
train_dataset=train_ds,
reward_funcs=[action_reward],
peft_config=peft_config,
processing_class=tokenizer,
),
)
)
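    # GRPOTrainer accepts a model id string and loads the base model itself;
    # passing the tokenizer as processing_class keeps tokenization consistent
    # between rollout generation and reward scoring.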
before_results: List[EvalResult] = []
for task in list_tasks():
        before_results.append(run_policy("round_robin", task.task_id))
        before_results.append(run_policy("heuristic", task.task_id))
        before_results.append(run_policy("model", task.task_id, trainer.model, tokenizer))
trainer.train()
reward_history = extract_reward_history(trainer.state.log_history)
if not reward_history:
raise RuntimeError("No reward points found in trainer logs.")
reward_curve_path = artifacts_dir / "reward_curve.png"
save_reward_curve_png(reward_history, reward_curve_path)
after_results: List[EvalResult] = []
for task in list_tasks():
        after_results.append(run_policy("round_robin", task.task_id))
        after_results.append(run_policy("heuristic", task.task_id))
        after_results.append(run_policy("model", task.task_id, trainer.model, tokenizer))
before_df = pd.DataFrame([asdict(item) for item in before_results])
after_df = pd.DataFrame([asdict(item) for item in after_results])
summary_before = summarize_policy(before_results, "before")
summary_after = summarize_policy(after_results, "after")
summary = pd.concat([summary_before, summary_after], ignore_index=True)
model_before = float(summary_before.loc[summary_before["policy"] == "model", "avg_score"].iloc[0])
model_after = float(summary_after.loc[summary_after["policy"] == "model", "avg_score"].iloc[0])
improvement = {
"model_avg_score_before": model_before,
"model_avg_score_after": model_after,
"delta_score": model_after - model_before,
"reward_history_points": len(reward_history),
"reward_history_mean": round(mean(reward_history), 6),
}
before_df.to_csv(artifacts_dir / "evaluation_before.csv", index=False)
after_df.to_csv(artifacts_dir / "evaluation_after.csv", index=False)
summary.to_csv(artifacts_dir / "evaluation_summary.csv", index=False)
payload = {
"timestamp": datetime.utcnow().isoformat() + "Z",
"model_name": args.model_name,
"seed": SEED,
"before": before_df.to_dict(orient="records"),
"after": after_df.to_dict(orient="records"),
"summary": summary.to_dict(orient="records"),
"improvement": improvement,
}
(artifacts_dir / "evaluation_summary.json").write_text(json.dumps(payload, indent=2), encoding="utf-8")
(artifacts_dir / "improvement.json").write_text(json.dumps(improvement, indent=2), encoding="utf-8")
trainer.save_model(str(output_dir))
print("Saved artifacts:")
print(f"- {reward_curve_path}")
print(f"- {artifacts_dir / 'evaluation_before.csv'}")
print(f"- {artifacts_dir / 'evaluation_after.csv'}")
print(f"- {artifacts_dir / 'evaluation_summary.csv'}")
print(f"- {artifacts_dir / 'evaluation_summary.json'}")
print(f"- {artifacts_dir / 'improvement.json'}")
if __name__ == "__main__":
main()
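
# Example invocation (defaults shown; adjust for your hardware):
#   python train_grpo.py --model-name Qwen/Qwen2.5-0.5B-Instruct \
#       --samples-per-task 42 --max-steps 140 --num-generations 4 \
#       --output-dir outputs/logiflow-grpo-script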