Spaces:
Sleeping
Sleeping
feat: Implement Context-Pruning-Env with SQuAD dataset and GRPOTrainer support
Browse files- Dockerfile.openenv +27 -0
- context_pruning_env/env.py +124 -0
- context_pruning_env/models.py +43 -0
- context_pruning_env/server/__init__.py +1 -0
- context_pruning_env/server/app.py +18 -0
- context_pruning_env/utils.py +45 -0
- requirements.txt +4 -0
- test_env.py +64 -0
- train_grpo.py +67 -0
Dockerfile.openenv
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim

# Set environment variables:
# - PYTHONUNBUFFERED so logs stream immediately in container output
# - PYTHONPATH so `context_pruning_env` is importable from /app
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app

WORKDIR /app

# Install system dependencies (build-essential for wheels that compile C extensions)
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy and install Python dependencies first (keeps this layer cached
# while the application code below changes)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the environment code
COPY context_pruning_env ./context_pruning_env

# Expose the default OpenEnv port
EXPOSE 7860

# Command to run the environment server (standardized OpenEnv entrypoint)
# In a real environment, you'd use a server wrapper mapping Gymnasium resets/steps to API calls.
CMD ["uvicorn", "context_pruning_env.server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
context_pruning_env/env.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import Any, Optional, List
|
| 3 |
+
from uuid import uuid4
|
| 4 |
+
|
| 5 |
+
from openenv.core.env_server.interfaces import Environment
|
| 6 |
+
from context_pruning_env.models import (
|
| 7 |
+
PruningAction,
|
| 8 |
+
PruningObservation,
|
| 9 |
+
PruningState,
|
| 10 |
+
ChunkItem
|
| 11 |
+
)
|
| 12 |
+
from context_pruning_env.utils import SQuADLoader, count_tokens
|
| 13 |
+
|
| 14 |
+
class ContextPruningEnv(Environment[PruningAction, PruningObservation, PruningState]):
    """
    OpenEnv Reinforcement Learning Environment for RAG Context Pruning.

    Each episode serves one SQuAD question plus 5 context chunks
    (1 gold, 4 noise). The agent responds with a binary keep/prune mask:
    keeping the gold chunk earns +10 plus 0.01 per pruned token, while
    pruning the gold chunk costs -20. Episodes are single-step.
    """

    def __init__(self, squad_split: str = "train"):
        super().__init__(transform=None, rubric=None)
        self.loader = SQuADLoader(split=squad_split)
        # No episode in flight until reset() is called.
        self._state: Optional[PruningState] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> PruningObservation:
        """
        Loads a new question and 5 context chunks from SQuAD.
        Returns the initial observation.

        Raises:
            ValueError: if the loaded episode contains no gold chunk.
        """
        question, chunks_data = self.loader.get_episode()

        # Prepare internal state chunks and locate the gold chunk.
        chunks: List[ChunkItem] = []
        total_tokens = 0
        # Track the gold position explicitly; the original code left
        # `gold_index` unbound (NameError) when no chunk was flagged gold.
        gold_index: Optional[int] = None
        for i, (text, is_gold) in enumerate(chunks_data):
            tokens = count_tokens(text)
            total_tokens += tokens
            chunks.append(ChunkItem(
                content=text,
                is_gold=is_gold,
                tokens=tokens
            ))
            if is_gold:
                gold_index = i

        if gold_index is None:
            raise ValueError("Episode data contains no gold chunk.")

        self._state = PruningState(
            episode_id=episode_id or str(uuid4()),
            question=question,
            gold_index=gold_index,
            chunks=chunks,
            initial_tokens=total_tokens,
            step_count=0,
            done=False
        )

        return self._observe(
            message="Environment reset. 5 chunks loaded (1 gold, 4 noise)."
        )

    def _observe(self, message: str = "") -> PruningObservation:
        """Helper to create observation from current state."""
        return PruningObservation(
            done=self._state.done,
            question=self._state.question,
            chunks=[c.content for c in self._state.chunks],
            token_count=sum(c.tokens for c in self._state.chunks),
            message=message
        )

    def step(
        self,
        action: PruningAction,
        **kwargs: Any,
    ) -> PruningObservation:
        """
        Evaluates the binary mask, calculates token reduction,
        checks gold chunk presence, and returns the observation.

        Raises:
            RuntimeError: if called before reset().
        """
        # Guard: the original crashed with AttributeError on `None.done`
        # when step() was called before the first reset().
        if self._state is None:
            raise RuntimeError("step() called before reset().")
        if self._state.done:
            return self._observe(message="Episode is already done.")

        mask = action.mask
        # Validate length AND values: Pydantic constrains the list length
        # but not that each entry is 0 or 1.
        if len(mask) != 5 or any(v not in (0, 1) for v in mask):
            # Should not happen if using Gymnasium space or Pydantic validation
            self._state.done = True
            return self._observe(message="Invalid action space size.")

        # 1. Identify Gold Chunk Status
        gold_kept = (mask[self._state.gold_index] == 1)

        # 2. Calculate Token reduction
        pruned_tokens = 0
        for i, keep in enumerate(mask):
            if keep == 0:
                pruned_tokens += self._state.chunks[i].tokens

        # 3. Reward Logic
        reward = 0.0

        if gold_kept:
            reward += 10.0  # Accuracy Bonus
            reward += 0.01 * pruned_tokens  # Efficiency Bonus
            msg = f"Task Success: Gold chunk kept. Pruned {pruned_tokens} tokens."
        else:
            reward -= 20.0  # Penalty: Lost the game
            msg = "Task Failure: Gold chunk was pruned. Mission failed."

        self._state.done = True
        self._state.step_count += 1

        # In OpenEnv, the reward is often part of the Observation or signaled via a rubric.
        # We manually update the reward field here.
        obs = self._observe(message=msg)
        obs.reward = reward

        return obs

    @property
    def state(self) -> PruningState:
        # NOTE: is None before the first reset().
        return self._state
|
context_pruning_env/models.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import List, Optional, Any
|
| 3 |
+
from pydantic import Field
|
| 4 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 5 |
+
|
| 6 |
+
class PruningAction(Action):
|
| 7 |
+
"""
|
| 8 |
+
Action space: A binary mask of 5 values (1 = keep, 0 = prune).
|
| 9 |
+
Example: [1, 0, 1, 1, 0]
|
| 10 |
+
"""
|
| 11 |
+
mask: List[int] = Field(
|
| 12 |
+
...,
|
| 13 |
+
min_items=5,
|
| 14 |
+
max_items=5,
|
| 15 |
+
description="Binary mask of 5 integers (0 or 1) indicating which chunks to keep."
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
class ChunkItem(BaseModel):
|
| 19 |
+
"""Represent a single context chunk."""
|
| 20 |
+
content: str
|
| 21 |
+
is_gold: bool = False
|
| 22 |
+
tokens: int = 0
|
| 23 |
+
|
| 24 |
+
class PruningObservation(Observation):
|
| 25 |
+
"""
|
| 26 |
+
Observation provided to the agent.
|
| 27 |
+
Contains the question and the 5 context chunks.
|
| 28 |
+
"""
|
| 29 |
+
question: str
|
| 30 |
+
chunks: List[str] = Field(default_factory=list, description="List of 5 context strings.")
|
| 31 |
+
token_count: int = 0
|
| 32 |
+
message: str = ""
|
| 33 |
+
|
| 34 |
+
class PruningState(State):
|
| 35 |
+
"""
|
| 36 |
+
Internal state of the environment.
|
| 37 |
+
"""
|
| 38 |
+
question: str
|
| 39 |
+
gold_index: int
|
| 40 |
+
chunks: List[ChunkItem]
|
| 41 |
+
initial_tokens: int
|
| 42 |
+
step_count: int = 0
|
| 43 |
+
done: bool = False
|
context_pruning_env/server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# server package
|
context_pruning_env/server/app.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

from openenv.core.env_server.http_server import create_fastapi_app

from context_pruning_env.env import ContextPruningEnv
from context_pruning_env.models import PruningAction, PruningObservation

# FastAPI application exposing ContextPruningEnv over the OpenEnv HTTP protocol.
app = create_fastapi_app(
    ContextPruningEnv,
    PruningAction,
    PruningObservation,
)


def main() -> None:
    """Serve the app locally, honoring the PORT env var (default 7860)."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))


if __name__ == "__main__":
    main()
|
context_pruning_env/utils.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
class SQuADLoader:
    """Streams SQuAD entries as pruning episodes (1 gold + 4 noise chunks).

    Serves the dataset in a shuffled order and reshuffles once exhausted,
    so no entry repeats within an epoch.
    """

    def __init__(self, split: str = "train"):
        self.dataset = load_dataset("squad", split=split)
        self.indices = list(range(len(self.dataset)))
        random.shuffle(self.indices)
        # Cursor into self.indices for the current epoch.
        self.current_ptr = 0

    def get_episode(self) -> Tuple[str, List[Tuple[str, bool]]]:
        """
        Returns (question, List[(chunk_text, is_gold)])

        Exactly one chunk is flagged gold; chunk order is randomized.
        """
        # Epoch exhausted: reshuffle and restart.
        if self.current_ptr >= len(self.indices):
            random.shuffle(self.indices)
            self.current_ptr = 0

        idx = self.indices[self.current_ptr]
        self.current_ptr += 1

        entry = self.dataset[idx]
        question = entry["question"]
        gold_context = entry["context"]

        # 1 Gold + 4 Noise
        chunks = [(gold_context, True)]

        # Sample 4 noise contexts. SQuAD reuses the same paragraph across
        # many questions, so comparing indices (i != idx) is not enough:
        # a "noise" chunk could be textually identical to the gold chunk.
        # Rejection-sample by context TEXT to keep the gold unique.
        n = len(self.dataset)
        noise_indices: List[int] = []
        attempts = 0
        while len(noise_indices) < 4 and attempts < 200:
            attempts += 1
            cand = random.randrange(n)
            if cand == idx or cand in noise_indices:
                continue
            if self.dataset[cand]["context"] == gold_context:
                continue
            noise_indices.append(cand)

        # Degenerate datasets (nearly all contexts equal the gold): fall
        # back to index-based sampling so we still return 5 chunks.
        if len(noise_indices) < 4:
            backfill = [i for i in range(n) if i != idx and i not in noise_indices]
            k = min(4 - len(noise_indices), len(backfill))
            noise_indices.extend(random.sample(backfill, k))

        for nid in noise_indices:
            chunks.append((self.dataset[nid]["context"], False))

        # Shuffle chunks to avoid gold being first
        random.shuffle(chunks)

        return question, chunks
| 42 |
+
|
| 43 |
+
def count_tokens(text: str) -> int:
    """Approximate a token count by splitting on runs of whitespace."""
    words = text.split()
    return len(words)
|
requirements.txt
CHANGED
|
@@ -3,3 +3,7 @@ pydantic>=2.0
|
|
| 3 |
fastapi>=0.104.0
|
| 4 |
uvicorn[standard]>=0.24.0
|
| 5 |
typing_extensions>=4.8.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
fastapi>=0.104.0
|
| 4 |
uvicorn[standard]>=0.24.0
|
| 5 |
typing_extensions>=4.8.0
|
| 6 |
+
datasets>=2.15.0
|
| 7 |
+
transformers>=4.35.0
|
| 8 |
+
trl>=0.7.4
|
| 9 |
+
torch>=2.1.0
|
test_env.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
from unittest.mock import MagicMock
|
| 3 |
+
from context_pruning_env.env import ContextPruningEnv
|
| 4 |
+
from context_pruning_env.models import PruningAction, ChunkItem
|
| 5 |
+
|
| 6 |
+
class TestContextPruningEnv(unittest.TestCase):
    """Unit tests for ContextPruningEnv with a fully mocked SQuAD loader."""

    def setUp(self):
        # BUG FIX: patch SQuADLoader *during* construction. The original
        # built the env first and mocked the loader afterwards, so
        # __init__ had already triggered the real HF dataset download.
        with unittest.mock.patch("context_pruning_env.env.SQuADLoader"):
            self.env = ContextPruningEnv(squad_split="train")
        self.env.loader = MagicMock()

        # Mock episode data: 1 Gold, 4 Noise
        self.mock_question = "What color is the sky?"
        self.mock_chunks = [
            ("The sky appears blue due to Rayleigh scattering.", True),
            ("Grass is usually green.", False),
            ("Pizza is delicious.", False),
            ("Computers process binary data.", False),
            ("Antarctica is cold.", False)
        ]
        self.env.loader.get_episode.return_value = (self.mock_question, self.mock_chunks)

    def test_reset(self):
        """reset() surfaces the mocked question and all 5 chunks."""
        obs = self.env.reset()
        self.assertEqual(obs.question, self.mock_question)
        self.assertEqual(len(obs.chunks), 5)
        self.assertFalse(obs.done)

    def test_step_keep_gold(self):
        """Keeping everything yields only the accuracy bonus (no pruning)."""
        self.env.reset()
        # Gold is at index 0; action: keep all
        action = PruningAction(mask=[1, 1, 1, 1, 1])
        obs = self.env.step(action)

        self.assertTrue(obs.done)
        # Accuracy bonus + tokens saved (0 in this case)
        self.assertEqual(obs.reward, 10.0)
        self.assertIn("Success", obs.message)

    def test_step_prune_gold(self):
        """Pruning the gold chunk triggers the flat failure penalty."""
        self.env.reset()
        # Gold is at index 0; action: prune gold, keep others
        action = PruningAction(mask=[0, 1, 1, 1, 1])
        obs = self.env.step(action)

        self.assertTrue(obs.done)
        self.assertEqual(obs.reward, -20.0)
        self.assertIn("Failure", obs.message)

    def test_step_efficiency(self):
        """Keeping only the gold chunk adds an efficiency bonus on top."""
        self.env.reset()
        # Gold is at index 0; action: keep gold, prune others
        action = PruningAction(mask=[1, 0, 0, 0, 0])
        obs = self.env.step(action)

        self.assertTrue(obs.done)
        self.assertGreater(obs.reward, 10.0)  # Accuracy (10) + Efficiency (>0)
        self.assertIn("Success", obs.message)


if __name__ == "__main__":
    unittest.main()
|
train_grpo.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from trl import GRPOTrainer, GRPOConfig
|
| 3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 4 |
+
from context_pruning_env.env import ContextPruningEnv
|
| 5 |
+
from context_pruning_env.models import PruningAction
|
| 6 |
+
|
| 7 |
+
# 1. Setup Environment
|
| 8 |
+
env = ContextPruningEnv(squad_split="train")
|
| 9 |
+
|
| 10 |
+
def reward_func(prompts, completions, **kwargs):
    """
    Reward function wrapper for GRPOTrainer.

    Parses a 5-element binary keep/prune mask from each completion, plays
    one (single-step) environment episode with it, and returns one scalar
    reward per completion.
    """
    rewards = []
    for prompt, completion in zip(prompts, completions):
        # 1. Extract action mask from completion (LLM output), e.g.
        #    "Action: [1, 0, 1, 1, 0]". Any malformed output falls back to
        #    keeping everything (the safe zero-pruning action).
        mask = [1, 1, 1, 1, 1]
        try:
            if "[" in completion and "]" in completion:
                mask_str = completion.split("[")[1].split("]")[0]
                parsed = [int(x.strip()) for x in mask_str.split(",")]
                # Only accept a well-formed binary mask of length 5;
                # otherwise the env would end the episode with no reward.
                if len(parsed) == 5 and all(v in (0, 1) for v in parsed):
                    mask = parsed
        except (ValueError, IndexError):
            # Narrowed from the original bare `except:` which also
            # swallowed KeyboardInterrupt/SystemExit.
            pass

        # 2. BUG FIX: episodes are single-step and terminal, so the env
        #    must be reset before every completion. Without this, every
        #    step after the first short-circuits with "already done" and
        #    its reward stays None, corrupting the rewards list.
        env.reset()
        action = PruningAction(mask=mask)
        obs = env.step(action)
        rewards.append(obs.reward if obs.reward is not None else 0.0)

    return rewards
|
| 39 |
+
|
| 40 |
+
def main():
    """Configure and launch GRPO training for the context-pruning policy."""
    model_id = "meta-llama/Llama-3-8B"  # Reference model

    # 2. Config for GRPO
    # BUG FIX: TRL's GRPOConfig has no `per_batch_size` or `group_size`
    # parameters (the original raised TypeError). The correct names are
    # `per_device_train_batch_size` and `num_generations` (the number of
    # completions sampled per prompt for group-relative advantages).
    training_args = GRPOConfig(
        output_dir="./llama-3-rag-pruning",
        learning_rate=5e-6,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        num_train_epochs=3,
        logging_steps=10,
        num_generations=8,
    )

    # 3. Initialize Trainer
    # Note: In a real implementation, you'd need the dataset formatted for the trainer
    trainer = GRPOTrainer(
        model=model_id,
        reward_funcs=[reward_func],
        args=training_args,
        # train_dataset=rag_pruning_dataset, # Pre-formatted dataset
    )

    print("Starting Training with GRPOTrainer...")
    # trainer.train()


if __name__ == "__main__":
    main()
|