prithic07 committed on
Commit
2d5dd85
·
1 Parent(s): 92edb88

feat: Implement Context-Pruning-Env with SQuAD dataset and GRPOTrainer support

Browse files
Dockerfile.openenv ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# Runtime settings: unbuffered logs, app code importable from /app.
ENV PYTHONUNBUFFERED=1 \
    PYTHONPATH=/app

WORKDIR /app

# System build tools required by some Python wheels.
RUN apt-get update \
    && apt-get install -y build-essential curl \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies first so this layer is cached across code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Application code.
COPY context_pruning_env ./context_pruning_env

# Expose the default OpenEnv port
EXPOSE 7860

# Command to run the environment server (standardized OpenEnv entrypoint)
# In a real environment, you'd use a server wrapper mapping Gymnasium resets/steps to API calls.
CMD ["uvicorn", "context_pruning_env.server.app:app", "--host", "0.0.0.0", "--port", "7860"]
context_pruning_env/env.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Any, Optional, List
3
+ from uuid import uuid4
4
+
5
+ from openenv.core.env_server.interfaces import Environment
6
+ from context_pruning_env.models import (
7
+ PruningAction,
8
+ PruningObservation,
9
+ PruningState,
10
+ ChunkItem
11
+ )
12
+ from context_pruning_env.utils import SQuADLoader, count_tokens
13
+
14
class ContextPruningEnv(Environment[PruningAction, PruningObservation, PruningState]):
    """
    OpenEnv Reinforcement Learning Environment for RAG Context Pruning.

    Each episode presents one question with 5 context chunks (1 gold, 4
    noise, per SQuADLoader). The agent answers with a binary keep/prune
    mask: keeping the gold chunk earns +10 plus +0.01 per pruned token;
    pruning the gold chunk scores -20. Episodes are single-step — the
    environment is done after the first step().
    """

    def __init__(self, squad_split: str = "train"):
        """Create the environment; episodes are drawn from the given SQuAD split."""
        super().__init__(transform=None, rubric=None)
        self.loader = SQuADLoader(split=squad_split)
        self._state: Optional[PruningState] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> PruningObservation:
        """
        Loads a new question and 5 context chunks from SQuAD.
        Returns the initial observation.

        Raises:
            ValueError: if the loader yields no gold chunk (previously this
                surfaced as an UnboundLocalError on ``gold_index``).
        """
        question, chunks_data = self.loader.get_episode()

        # Prepare internal state chunks
        chunks: List[ChunkItem] = []
        total_tokens = 0
        gold_index: Optional[int] = None
        for i, (text, is_gold) in enumerate(chunks_data):
            tokens = count_tokens(text)
            total_tokens += tokens
            chunks.append(ChunkItem(
                content=text,
                is_gold=is_gold,
                tokens=tokens
            ))
            if is_gold:
                gold_index = i

        if gold_index is None:
            # Fail loudly instead of building a state with an unbound index.
            raise ValueError("Episode data contains no gold chunk.")

        self._state = PruningState(
            episode_id=episode_id or str(uuid4()),
            question=question,
            gold_index=gold_index,
            chunks=chunks,
            initial_tokens=total_tokens,
            step_count=0,
            done=False
        )

        return self._observe(
            message="Environment reset. 5 chunks loaded (1 gold, 4 noise)."
        )

    def _observe(self, message: str = "") -> PruningObservation:
        """Helper to create observation from current state."""
        return PruningObservation(
            done=self._state.done,
            question=self._state.question,
            chunks=[c.content for c in self._state.chunks],
            token_count=sum(c.tokens for c in self._state.chunks),
            message=message
        )

    def step(
        self,
        action: PruningAction,
        **kwargs: Any,
    ) -> PruningObservation:
        """
        Evaluates the binary mask, calculates token reduction,
        checks gold chunk presence, and returns the observation.

        Raises:
            RuntimeError: if called before reset() (previously crashed with
                AttributeError on the None state).
        """
        if self._state is None:
            raise RuntimeError("step() called before reset().")
        if self._state.done:
            return self._observe(message="Episode is already done.")

        mask = action.mask
        # Validate length against the actual chunk count (not a hard-coded 5)
        # and reject non-binary entries, which the logic below would otherwise
        # misinterpret (e.g. 2 counted as "kept" for pruning but not for the
        # gold check).
        if len(mask) != len(self._state.chunks) or any(v not in (0, 1) for v in mask):
            # Should not happen if using Gymnasium space or Pydantic validation
            self._state.done = True
            return self._observe(message="Invalid action space size.")

        # 1. Identify Gold Chunk Status
        gold_kept = (mask[self._state.gold_index] == 1)

        # 2. Calculate Token reduction
        pruned_tokens = sum(
            chunk.tokens
            for keep, chunk in zip(mask, self._state.chunks)
            if keep == 0
        )

        # 3. Reward Logic
        if gold_kept:
            reward = 10.0  # Accuracy Bonus
            reward += 0.01 * pruned_tokens  # Efficiency Bonus
            msg = f"Task Success: Gold chunk kept. Pruned {pruned_tokens} tokens."
        else:
            reward = -20.0  # Penalty: Lost the game
            msg = "Task Failure: Gold chunk was pruned. Mission failed."

        # Single-step episode: one mask decision ends it.
        self._state.done = True
        self._state.step_count += 1

        # In OpenEnv, the reward is often part of the Observation or signaled via a rubric.
        # We manually update the reward field here.
        obs = self._observe(message=msg)
        obs.reward = reward

        return obs

    @property
    def state(self) -> PruningState:
        # NOTE: None before the first reset(); annotation kept for interface parity.
        return self._state
context_pruning_env/models.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import List, Optional, Any
3
+ from pydantic import Field
4
+ from openenv.core.env_server.types import Action, Observation, State
5
+
6
+ class PruningAction(Action):
7
+ """
8
+ Action space: A binary mask of 5 values (1 = keep, 0 = prune).
9
+ Example: [1, 0, 1, 1, 0]
10
+ """
11
+ mask: List[int] = Field(
12
+ ...,
13
+ min_items=5,
14
+ max_items=5,
15
+ description="Binary mask of 5 integers (0 or 1) indicating which chunks to keep."
16
+ )
17
+
18
+ class ChunkItem(BaseModel):
19
+ """Represent a single context chunk."""
20
+ content: str
21
+ is_gold: bool = False
22
+ tokens: int = 0
23
+
24
+ class PruningObservation(Observation):
25
+ """
26
+ Observation provided to the agent.
27
+ Contains the question and the 5 context chunks.
28
+ """
29
+ question: str
30
+ chunks: List[str] = Field(default_factory=list, description="List of 5 context strings.")
31
+ token_count: int = 0
32
+ message: str = ""
33
+
34
+ class PruningState(State):
35
+ """
36
+ Internal state of the environment.
37
+ """
38
+ question: str
39
+ gold_index: int
40
+ chunks: List[ChunkItem]
41
+ initial_tokens: int
42
+ step_count: int = 0
43
+ done: bool = False
context_pruning_env/server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # server package
context_pruning_env/server/app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from openenv.core.env_server.http_server import create_fastapi_app
from context_pruning_env.env import ContextPruningEnv
from context_pruning_env.models import PruningAction, PruningObservation

# FastAPI application exposing ContextPruningEnv over the OpenEnv HTTP API.
app = create_fastapi_app(ContextPruningEnv, PruningAction, PruningObservation)


def main() -> None:
    """Serve the app; the port comes from $PORT (default 7860)."""
    import uvicorn

    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


if __name__ == "__main__":
    main()
context_pruning_env/utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import List, Tuple
3
+ from datasets import load_dataset
4
+ import logging
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
class SQuADLoader:
    """Serves SQuAD episodes: one question plus 1 gold and 4 noise contexts."""

    def __init__(self, split: str = "train"):
        """Load the SQuAD split and prepare a shuffled traversal order."""
        self.dataset = load_dataset("squad", split=split)
        # Shuffled index order gives random, non-repeating episode traversal.
        self.indices = list(range(len(self.dataset)))
        random.shuffle(self.indices)
        self.current_ptr = 0

    def get_episode(self) -> Tuple[str, List[Tuple[str, bool]]]:
        """
        Returns (question, List[(chunk_text, is_gold)])

        Exactly one chunk is gold; the 5 chunks are returned in shuffled order.
        """
        # Reshuffle and restart once every entry has been served.
        if self.current_ptr >= len(self.indices):
            random.shuffle(self.indices)
            self.current_ptr = 0

        idx = self.indices[self.current_ptr]
        self.current_ptr += 1

        entry = self.dataset[idx]
        question = entry["question"]
        gold_context = entry["context"]

        # 1 Gold + 4 Noise
        chunks = [(gold_context, True)]

        # Sample 4 distinct noise indices without materializing an O(n)
        # filtered list on every call: draw 5, drop the gold index if hit.
        candidates = random.sample(range(len(self.dataset)), 5)
        noise_indices = [i for i in candidates if i != idx][:4]
        for nid in noise_indices:
            chunks.append((self.dataset[nid]["context"], False))

        # Shuffle chunks to avoid gold being first
        random.shuffle(chunks)

        return question, chunks
43
def count_tokens(text: str) -> int:
    """Approximate token count: the number of whitespace-separated words."""
    words = text.split()
    return len(words)
requirements.txt CHANGED
@@ -3,3 +3,7 @@ pydantic>=2.0
3
  fastapi>=0.104.0
4
  uvicorn[standard]>=0.24.0
5
  typing_extensions>=4.8.0
 
 
 
 
 
3
  fastapi>=0.104.0
4
  uvicorn[standard]>=0.24.0
5
  typing_extensions>=4.8.0
6
+ datasets>=2.15.0
7
+ transformers>=4.35.0
8
+ trl>=0.7.4
9
+ torch>=2.1.0
test_env.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import unittest
from unittest.mock import MagicMock, patch
from context_pruning_env.env import ContextPruningEnv
from context_pruning_env.models import PruningAction, ChunkItem


class TestContextPruningEnv(unittest.TestCase):
    def setUp(self):
        # Patch SQuADLoader during construction so the test never downloads
        # the real HF dataset (previously the real loader was built first
        # and only replaced by the mock afterwards).
        with patch("context_pruning_env.env.SQuADLoader"):
            self.env = ContextPruningEnv(squad_split="train")
        self.env.loader = MagicMock()

        # Mock episode data: 1 Gold, 4 Noise
        self.mock_question = "What color is the sky?"
        self.mock_chunks = [
            ("The sky appears blue due to Rayleigh scattering.", True),
            ("Grass is usually green.", False),
            ("Pizza is delicious.", False),
            ("Computers process binary data.", False),
            ("Antarctica is cold.", False)
        ]
        self.env.loader.get_episode.return_value = (self.mock_question, self.mock_chunks)

    def test_reset(self):
        """reset() surfaces the mocked question and all 5 chunks, not done."""
        obs = self.env.reset()
        self.assertEqual(obs.question, self.mock_question)
        self.assertEqual(len(obs.chunks), 5)
        self.assertFalse(obs.done)

    def test_step_keep_gold(self):
        """Keeping everything earns only the accuracy bonus (nothing pruned)."""
        self.env.reset()
        # Gold is at index 0 (the mocked loader does not shuffle)
        action = PruningAction(mask=[1, 1, 1, 1, 1])
        obs = self.env.step(action)

        self.assertTrue(obs.done)
        # Accuracy bonus + tokens saved (0 in this case)
        self.assertEqual(obs.reward, 10.0)
        self.assertIn("Success", obs.message)

    def test_step_prune_gold(self):
        """Pruning the gold chunk yields the flat failure penalty."""
        self.env.reset()
        # Gold is at index 0; prune gold, keep others
        action = PruningAction(mask=[0, 1, 1, 1, 1])
        obs = self.env.step(action)

        self.assertTrue(obs.done)
        self.assertEqual(obs.reward, -20.0)
        self.assertIn("Failure", obs.message)

    def test_step_efficiency(self):
        """Keeping only gold adds a positive efficiency bonus on top of +10."""
        self.env.reset()
        # Gold is at index 0; keep gold, prune others
        action = PruningAction(mask=[1, 0, 0, 0, 0])
        obs = self.env.step(action)

        self.assertTrue(obs.done)
        self.assertGreater(obs.reward, 10.0)  # Accuracy (10) + Efficiency (>0)
        self.assertIn("Success", obs.message)


if __name__ == "__main__":
    unittest.main()
train_grpo.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from trl import GRPOTrainer, GRPOConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from context_pruning_env.env import ContextPruningEnv
from context_pruning_env.models import PruningAction

# 1. Setup Environment (module-level so reward_func can reach it)
env = ContextPruningEnv(squad_split="train")

# Fallback action: keep every chunk when the completion is unparseable.
_KEEP_ALL = [1, 1, 1, 1, 1]


def _parse_mask(completion):
    """Extract a bracketed binary mask like "[1, 0, 1, 1, 0]" from LLM text.

    Returns a copy of the keep-all fallback when no parseable bracketed
    integer list is present.
    """
    try:
        if "[" in completion and "]" in completion:
            mask_str = completion.split("[")[1].split("]")[0]
            return [int(x.strip()) for x in mask_str.split(",")]
    except ValueError:
        # Non-integer items inside the brackets: treat as unparseable
        # (narrowed from the original bare `except`, which hid real bugs).
        pass
    return list(_KEEP_ALL)


def reward_func(prompts, completions, **kwargs):
    """
    Reward function wrapper for GRPOTrainer.

    Maps each completion to a pruning mask, runs one env episode, and
    returns the per-completion rewards.
    """
    rewards = []
    for prompt, completion in zip(prompts, completions):
        # In a real GRPOTrainer setup, we process multiple completions for the same prompt.
        # Here we simulate the interface mapping back to our environment logic.
        mask = _parse_mask(completion)

        # Reset before every completion: the env finishes after a single
        # step, and stepping a done episode returns an observation whose
        # reward was never set (the original code did exactly that for
        # every completion after the first).
        env.reset()
        action = PruningAction(mask=mask)
        obs = env.step(action)
        rewards.append(obs.reward)

    return rewards


def main():
    """Configure GRPO training for the pruning task (train call left disabled)."""
    model_id = "meta-llama/Llama-3-8B"  # Reference model

    # 2. Config for GRPO
    training_args = GRPOConfig(
        output_dir="./llama-3-rag-pruning",
        learning_rate=5e-6,
        per_device_train_batch_size=1,  # was the invalid kwarg `per_batch_size`
        gradient_accumulation_steps=16,
        num_train_epochs=3,
        logging_steps=10,
        num_generations=8,  # GRPO group size (was the invalid kwarg `group_size`)
    )

    # 3. Initialize Trainer
    # Note: In a real implementation, you'd need the dataset formatted for the trainer
    trainer = GRPOTrainer(
        model=model_id,
        reward_funcs=[reward_func],
        args=training_args,
        # train_dataset=rag_pruning_dataset,  # Pre-formatted dataset
    )

    print("Starting Training with GRPOTrainer...")
    # trainer.train()


if __name__ == "__main__":
    main()