| """ | |
| MDP data structures for PPO self-improvement loop. | |
| Defines the core components of the Markov Decision Process: | |
| - State : text sequence at time t | |
| - Action : token sampled from π_θ(·|s_t) | |
| - Transition : (s_t, a_t, r_t, s_{t+1}, V(s_t), done) | |
| - Trajectory : full episode τ = (s_0, a_0, r_0, …, s_T) | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, Iterator, List | |
| import torch | |

@dataclass
class State:
    """
    Represents s_t = context token sequence at generation step t.

    Attributes:
        text          : Decoded string (includes prompt).
        input_ids     : 1-D token-id tensor [seq_len].
        attention_mask: 1-D mask tensor [seq_len].
        phase         : "question_generation" | "solution".
    """

    text: str
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    phase: str
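
# Illustrative sketch (not part of the original module): one plausible way to
# construct a State, assuming a Hugging Face-style tokenizer whose __call__
# returns "input_ids" and "attention_mask" tensors. The helper name is
# hypothetical.
def state_from_tokenizer(tokenizer: Any, text: str, phase: str) -> State:
    """Encode `text` into a State; `tokenizer` is an assumed dependency."""
    enc = tokenizer(text, return_tensors="pt")
    return State(
        text=text,
        input_ids=enc["input_ids"][0],            # drop batch dim -> [seq_len]
        attention_mask=enc["attention_mask"][0],
        phase=phase,
    )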

@dataclass
class Action:
    """
    Represents a_t = single token selected at step t.

    Attributes:
        token_id : Vocabulary index of the chosen token.
        log_prob : log π_θ(a_t | s_t) (used for the importance ratio).
        entropy  : H(π_θ(·|s_t)) (used for the entropy bonus).
    """

    token_id: int
    log_prob: float
    entropy: float
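
# Illustrative sketch (not part of the original module): sampling an Action
# from a [vocab_size] logits vector with torch.distributions.Categorical,
# which yields exactly the log-prob and entropy terms the docstring above
# refers to. The helper name is hypothetical.
def action_from_logits(logits: torch.Tensor) -> Action:
    dist = torch.distributions.Categorical(logits=logits)
    token = dist.sample()                         # scalar tensor
    return Action(
        token_id=int(token.item()),
        log_prob=float(dist.log_prob(token).item()),
        entropy=float(dist.entropy().item()),
    )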

@dataclass
class Transition:
    """
    Single step in the MDP: (s_t, a_t, r_t, s_{t+1}, V(s_t), done).

    Attributes:
        state     : s_t
        action    : a_t
        reward    : r_t (0.0 for non-terminal steps; the sparse reward
                    arrives at episode end)
        next_state: s_{t+1}
        value     : V(s_t) from the critic at step t
        done      : Whether this is the terminal transition
    """

    state: State
    action: Action
    reward: float
    next_state: State
    value: float
    done: bool
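
# Illustrative sketch (not part of the original module): packaging one decoding
# step into a Transition. It reuses the hypothetical action_from_logits helper
# above and assumes a Hugging Face-style tokenizer for decoding.
def make_transition(state: State, logits: torch.Tensor, value: float,
                    tokenizer: Any, reward: float = 0.0,
                    done: bool = False) -> Transition:
    action = action_from_logits(logits)
    new_ids = torch.cat([state.input_ids, torch.tensor([action.token_id])])
    new_mask = torch.cat(
        [state.attention_mask, torch.ones(1, dtype=state.attention_mask.dtype)]
    )
    next_state = State(
        text=tokenizer.decode(new_ids),
        input_ids=new_ids,
        attention_mask=new_mask,
        phase=state.phase,
    )
    return Transition(state, action, reward, next_state, value, done)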

@dataclass
class Trajectory:
    """
    Complete episode τ = (s_0, a_0, r_0, …, s_T).

    Provides helpers for reward summation and iteration.
    """

    transitions: List[Transition] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # ------------------------------------------------------------------
    # Mutation
    # ------------------------------------------------------------------
    def add(self, transition: Transition) -> None:
        """Append a transition to the episode."""
        self.transitions.append(transition)

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------
    @property
    def total_reward(self) -> float:
        """Sum of all rewards in the episode: R(τ) = Σ_t r_t."""
        return sum(t.reward for t in self.transitions)

    @property
    def rewards(self) -> List[float]:
        return [t.reward for t in self.transitions]

    @property
    def values(self) -> List[float]:
        return [t.value for t in self.transitions]

    @property
    def log_probs(self) -> List[float]:
        return [t.action.log_prob for t in self.transitions]

    @property
    def entropies(self) -> List[float]:
        return [t.action.entropy for t in self.transitions]

    @property
    def dones(self) -> List[bool]:
        return [t.done for t in self.transitions]

    # ------------------------------------------------------------------
    # Dunder helpers
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        return len(self.transitions)

    def __iter__(self) -> Iterator[Transition]:
        return iter(self.transitions)

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"Trajectory(len={len(self)}, "
            f"total_reward={self.total_reward:.3f})"
        )
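
# Illustrative sketch (not part of the original module): how a Trajectory's
# reward/value/done buffers would typically feed generalized advantage
# estimation (GAE) in a PPO update. The function name and the gamma/lam
# defaults are assumptions, not this module's API.
def gae_advantages(traj: Trajectory, gamma: float = 0.99,
                   lam: float = 0.95) -> List[float]:
    """A_t = Σ_l (γλ)^l δ_{t+l}, where δ_t = r_t + γ V(s_{t+1}) - V(s_t)."""
    rewards, values, dones = traj.rewards, traj.values, traj.dones
    advantages = [0.0] * len(traj)
    next_advantage = 0.0
    next_value = 0.0  # value after the terminal state bootstraps to zero
    for t in reversed(range(len(traj))):
        mask = 0.0 if dones[t] else 1.0           # cut bootstrap at episode end
        delta = rewards[t] + gamma * next_value * mask - values[t]
        next_advantage = delta + gamma * lam * mask * next_advantage
        advantages[t] = next_advantage
        next_value = values[t]
    return advantages
# Usage after collecting a rollout: advs = gae_advantages(trajectory)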