# src/rl/mdp_components.py
"""
MDP data structures for PPO self-improvement loop.
Defines the core components of the Markov Decision Process:
- State : text sequence at time t
- Action : token sampled from π_θ(·|s_t)
- Transition : (s_t, a_t, r_t, s_{t+1}, V(s_t), done)
- Trajectory : full episode τ = (s_0, a_0, r_0, …, s_T)
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List
import torch
@dataclass
class State:
"""
Represents s_t = context token sequence at generation step t.
Attributes:
text : Decoded string (includes prompt).
input_ids : 1-D token-id tensor [seq_len].
attention_mask: 1-D mask tensor [seq_len].
phase : "question_generation" | "solution".
"""
text: str
input_ids: torch.Tensor
attention_mask: torch.Tensor
phase: str
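

# Illustrative usage sketch (an assumption, not part of the original module):
# with a Hugging Face tokenizer, the initial state would be built roughly as
# the commented lines below; `tokenizer` and `prompt` are hypothetical names.
#
#   enc = tokenizer(prompt, return_tensors="pt")
#   s0 = State(text=prompt,
#              input_ids=enc["input_ids"][0],            # drop batch dim -> [seq_len]
#              attention_mask=enc["attention_mask"][0],
#              phase="question_generation")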
@dataclass
class Action:
"""
Represents a_t = single token selected at step t.
Attributes:
token_id : Vocabulary index of the chosen token.
        log_prob : log π_θ(a_t | s_t) (used for the PPO importance ratio).
        entropy  : H(π_θ(·|s_t)) (used for the entropy bonus).
"""
token_id: int
log_prob: float
entropy: float
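

# Hedged example (added for illustration, not the repo's confirmed API): how
# the Action fields are typically produced from next-token policy logits with
# torch.distributions.Categorical. The name `sample_action` is an assumption.
def sample_action(logits: torch.Tensor) -> Action:
    """Sketch: build an Action from logits of shape [vocab_size]."""
    dist = torch.distributions.Categorical(logits=logits)
    token = dist.sample()                                  # a_t ~ π_θ(·|s_t)
    return Action(
        token_id=int(token.item()),
        log_prob=float(dist.log_prob(token).item()),       # log π_θ(a_t|s_t)
        entropy=float(dist.entropy().item()),              # H(π_θ(·|s_t))
    )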
@dataclass
class Transition:
"""
Single step in the MDP: (s_t, a_t, r_t, s_{t+1}, V(s_t), done).
Attributes:
state : s_t
action : a_t
        reward    : r_t (0.0 for non-terminal steps; the sparse episode reward is assigned at the terminal step)
next_state: s_{t+1}
value : V(s_t) from critic at step t
done : Whether this is the terminal transition
"""
state: State
action: Action
reward: float
next_state: State
value: float
done: bool
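

# Illustrative rollout sketch (an assumption about the surrounding loop, not
# code confirmed by this repo): at each decoding step the collector would pack
# the pieces above into a Transition, leaving the reward at 0.0 until the
# terminal step carries the episode's sparse score, e.g.
#
#   step = Transition(state=s_t, action=a_t, reward=0.0,
#                     next_state=s_t1, value=critic_value, done=is_eos)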
class Trajectory:
"""
Complete episode τ = (s_0, a_0, r_0, …, s_T).
Provides helpers for reward summation and iteration.
"""
def __init__(self) -> None:
self.transitions: List[Transition] = []
self.metadata: Dict[str, Any] = {}
# ------------------------------------------------------------------
# Mutation
# ------------------------------------------------------------------
def add(self, transition: Transition) -> None:
"""Append a transition to the episode."""
self.transitions.append(transition)
# ------------------------------------------------------------------
# Properties
# ------------------------------------------------------------------
@property
def total_reward(self) -> float:
"""Sum of all rewards in the episode R(τ) = Σ r_t."""
return sum(t.reward for t in self.transitions)
    @property
    def rewards(self) -> List[float]:
        """Per-step rewards [r_0, …, r_T]."""
        return [t.reward for t in self.transitions]

    @property
    def values(self) -> List[float]:
        """Critic estimates [V(s_0), …, V(s_T)]."""
        return [t.value for t in self.transitions]

    @property
    def log_probs(self) -> List[float]:
        """Per-step log π_θ(a_t | s_t) under the behaviour policy."""
        return [t.action.log_prob for t in self.transitions]

    @property
    def entropies(self) -> List[float]:
        """Per-step policy entropies H(π_θ(·|s_t))."""
        return [t.action.entropy for t in self.transitions]

    @property
    def dones(self) -> List[bool]:
        """Per-step terminal flags (True at the terminal transition)."""
        return [t.done for t in self.transitions]
# ------------------------------------------------------------------
# Dunder helpers
# ------------------------------------------------------------------
def __len__(self) -> int:
return len(self.transitions)
def __iter__(self) -> Iterator[Transition]:
return iter(self.transitions)
def __repr__(self) -> str: # pragma: no cover
return (
f"Trajectory(len={len(self)}, "
f"total_reward={self.total_reward:.3f})"
)
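

# Hedged example (added for illustration, not the repo's trainer): how the
# per-step lists above typically feed PPO advantage estimation via GAE.
# The `gamma`/`lam` defaults and the zero terminal bootstrap are assumptions.
def compute_gae(traj: Trajectory, gamma: float = 0.99, lam: float = 0.95) -> List[float]:
    """Sketch: GAE advantages A_t = Σ_k (γλ)^k δ_{t+k}, with
    δ_t = r_t + γ V(s_{t+1}) - V(s_t) and V := 0 past the terminal step."""
    rewards, values, dones = traj.rewards, traj.values, traj.dones
    advantages: List[float] = [0.0] * len(traj)
    gae, next_value = 0.0, 0.0
    for t in reversed(range(len(traj))):
        mask = 0.0 if dones[t] else 1.0          # cut the bootstrap at episode end
        delta = rewards[t] + gamma * next_value * mask - values[t]
        gae = delta + gamma * lam * mask * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages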