Spaces:
Sleeping
Sleeping
Deploy: Synchronize ports to 7860 and add Hugging Face Space metadata
Browse files- Dockerfile +2 -2
- README.md +1 -2
- app.py +1 -1
- context_pruning_env/env.py +0 -152
- context_pruning_env/graders.py +0 -56
- context_pruning_env/models.py +0 -54
- context_pruning_env/utils.py +0 -80
- openenv.yaml +1 -1
- rag_gc_env/__init__.py +0 -11
- rag_gc_env/__pycache__/__init__.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/environment.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/grader.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/inference.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/models.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/rewards.cpython-311.pyc +0 -0
- rag_gc_env/__pycache__/tasks.cpython-311.pyc +0 -0
- rag_gc_env/environment.py +0 -187
- rag_gc_env/grader.py +0 -43
- rag_gc_env/inference.py +0 -49
- rag_gc_env/models.py +0 -53
- rag_gc_env/rewards.py +0 -89
- rag_gc_env/server/__init__.py +0 -1
- rag_gc_env/server/__pycache__/__init__.cpython-311.pyc +0 -0
- rag_gc_env/server/__pycache__/app.cpython-311.pyc +0 -0
- rag_gc_env/server/app.py +0 -23
- rag_gc_env/tasks.py +0 -144
Dockerfile
CHANGED
|
@@ -12,6 +12,6 @@ RUN pip install --no-cache-dir --upgrade pip && \
|
|
| 12 |
|
| 13 |
COPY . /app
|
| 14 |
|
| 15 |
-
EXPOSE
|
| 16 |
|
| 17 |
-
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "
|
|
|
|
| 12 |
|
| 13 |
COPY . /app
|
| 14 |
|
| 15 |
+
EXPOSE 7860
|
| 16 |
|
| 17 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -75,7 +75,7 @@ ContextPrune includes three canonical tasks that simulate high-pressure operatio
|
|
| 75 |
## 5. Technical Components
|
| 76 |
|
| 77 |
- **`rag_optimizer_env/`**: Core state management, hybrid retrieval (Keyword + Semantic), and token estimation using `llm_runtime`.
|
| 78 |
-
- **`app.py`**: A standard FastAPI implementation.
|
| 79 |
- **`inference.py`**: A baseline agent script demonstrating how to use the OpenAI-compatible interface.
|
| 80 |
- **`validate.py`**: A robust validation suite that runs a full episode lifecycle locally to ensure 100% environment compliance.
|
| 81 |
|
|
@@ -88,4 +88,3 @@ ContextPrune includes three canonical tasks that simulate high-pressure operatio
|
|
| 88 |
3. **Control Panel**: `streamlit run optimizer_ui.py`
|
| 89 |
4. **Validation**: `python validate.py`
|
| 90 |
|
| 91 |
-
Built for Context Optimization Research.
|
|
|
|
| 75 |
## 5. Technical Components
|
| 76 |
|
| 77 |
- **`rag_optimizer_env/`**: Core state management, hybrid retrieval (Keyword + Semantic), and token estimation using `llm_runtime`.
|
| 78 |
+
- **`app.py`**: A standard FastAPI implementation. Built for Context Optimization Research.
|
| 79 |
- **`inference.py`**: A baseline agent script demonstrating how to use the OpenAI-compatible interface.
|
| 80 |
- **`validate.py`**: A robust validation suite that runs a full episode lifecycle locally to ensure 100% environment compliance.
|
| 81 |
|
|
|
|
| 88 |
3. **Control Panel**: `streamlit run optimizer_ui.py`
|
| 89 |
4. **Validation**: `python validate.py`
|
| 90 |
|
|
|
app.py
CHANGED
|
@@ -387,4 +387,4 @@ async def optimize_prompt_endpoint(payload: OptimizePromptRequest):
|
|
| 387 |
if __name__ == "__main__":
|
| 388 |
import uvicorn
|
| 389 |
|
| 390 |
-
uvicorn.run("app:app", host="0.0.0.0", port=
|
|
|
|
| 387 |
if __name__ == "__main__":
|
| 388 |
import uvicorn
|
| 389 |
|
| 390 |
+
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
|
context_pruning_env/env.py
DELETED
|
@@ -1,152 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
from typing import Any, Optional, List, Dict
|
| 3 |
-
from uuid import uuid4
|
| 4 |
-
|
| 5 |
-
from openenv.core.env_server.interfaces import Environment
|
| 6 |
-
from context_pruning_env.models import (
|
| 7 |
-
ContextAction,
|
| 8 |
-
ContextObservation,
|
| 9 |
-
ContextReward,
|
| 10 |
-
PruningState,
|
| 11 |
-
ChunkItem
|
| 12 |
-
)
|
| 13 |
-
from context_pruning_env.utils import SQuADLoader, count_tokens
|
| 14 |
-
from context_pruning_env.graders import (
|
| 15 |
-
grade_noise_purge,
|
| 16 |
-
grade_dedupe_arena,
|
| 17 |
-
grade_signal_extract
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningState]):
|
| 21 |
-
"""
|
| 22 |
-
Hackathon-compliant Context Pruning Environment.
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
def __init__(self, squad_split: str = "train"):
|
| 26 |
-
super().__init__(transform=None, rubric=None)
|
| 27 |
-
self.loader = SQuADLoader(split=squad_split)
|
| 28 |
-
self._state = None
|
| 29 |
-
|
| 30 |
-
def reset(
|
| 31 |
-
self,
|
| 32 |
-
seed: Optional[int] = None,
|
| 33 |
-
episode_id: Optional[str] = None,
|
| 34 |
-
task_name: Optional[str] = "noise_purge",
|
| 35 |
-
**kwargs: Any,
|
| 36 |
-
) -> ContextObservation:
|
| 37 |
-
"""
|
| 38 |
-
Starts a new episode with the specified task.
|
| 39 |
-
"""
|
| 40 |
-
task_name = task_name or "noise_purge"
|
| 41 |
-
question, chunks_data = self.loader.get_episode(task_name)
|
| 42 |
-
|
| 43 |
-
chunks = []
|
| 44 |
-
total_tokens = 0
|
| 45 |
-
for item in chunks_data:
|
| 46 |
-
tokens = count_tokens(item["content"])
|
| 47 |
-
total_tokens += tokens
|
| 48 |
-
chunks.append(ChunkItem(
|
| 49 |
-
content=item["content"],
|
| 50 |
-
is_gold=item["is_gold"],
|
| 51 |
-
is_duplicate=item["is_duplicate"],
|
| 52 |
-
tokens=tokens
|
| 53 |
-
))
|
| 54 |
-
|
| 55 |
-
self._state = PruningState(
|
| 56 |
-
episode_id=episode_id or str(uuid4()),
|
| 57 |
-
task_name=task_name,
|
| 58 |
-
question=question,
|
| 59 |
-
chunks=chunks,
|
| 60 |
-
initial_tokens=total_tokens,
|
| 61 |
-
step_count=0,
|
| 62 |
-
done=False
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
return self._observe(message=f"Task '{task_name}' initialized.")
|
| 66 |
-
|
| 67 |
-
def _observe(self, message: str = "") -> ContextObservation:
|
| 68 |
-
"""Create observation from state."""
|
| 69 |
-
return ContextObservation(
|
| 70 |
-
done=self._state.done,
|
| 71 |
-
question=self._state.question,
|
| 72 |
-
chunks=[c.content for c in self._state.chunks],
|
| 73 |
-
initial_token_count=self._state.initial_tokens,
|
| 74 |
-
current_token_count=sum(c.tokens for c in self._state.chunks),
|
| 75 |
-
task_name=self._state.task_name,
|
| 76 |
-
message=message
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
def step(
|
| 80 |
-
self,
|
| 81 |
-
action: ContextAction,
|
| 82 |
-
**kwargs: Any,
|
| 83 |
-
) -> ContextObservation:
|
| 84 |
-
"""
|
| 85 |
-
Takes a binary mask and calculates rewards based on trajectory signals.
|
| 86 |
-
"""
|
| 87 |
-
if self._state.done:
|
| 88 |
-
return self._observe(message="Episode is already done.")
|
| 89 |
-
|
| 90 |
-
mask = action.mask
|
| 91 |
-
if len(mask) != len(self._state.chunks):
|
| 92 |
-
# Pad with 0 (Prune) instead of 1 (Keep) to ensure agent optimization
|
| 93 |
-
mask = (mask + [0] * len(self._state.chunks))[:len(self._state.chunks)]
|
| 94 |
-
|
| 95 |
-
# Trajectory Simulation Logic
|
| 96 |
-
total_reward = 0.0
|
| 97 |
-
efficiency_reward = 0.0
|
| 98 |
-
accuracy_reward = 0.0
|
| 99 |
-
gold_penalty = 0.0
|
| 100 |
-
success = True
|
| 101 |
-
|
| 102 |
-
for i, kept in enumerate(mask):
|
| 103 |
-
chunk = self._state.chunks[i]
|
| 104 |
-
if not kept: # Pruned
|
| 105 |
-
if chunk.is_gold:
|
| 106 |
-
# Critical Failure
|
| 107 |
-
gold_penalty = -1.0
|
| 108 |
-
success = False
|
| 109 |
-
break # Immediate stop
|
| 110 |
-
else:
|
| 111 |
-
# Correctly pruned noise/duplicate
|
| 112 |
-
efficiency_reward += 0.1
|
| 113 |
-
else: # Kept
|
| 114 |
-
pass
|
| 115 |
-
|
| 116 |
-
# Final Accuracy Bonus
|
| 117 |
-
if success:
|
| 118 |
-
accuracy_reward = 0.7
|
| 119 |
-
|
| 120 |
-
total_reward = efficiency_reward + accuracy_reward + gold_penalty
|
| 121 |
-
|
| 122 |
-
# Task Score (Normalized 0.0 to 1.0 for the evaluator)
|
| 123 |
-
if self._state.task_name == "noise_purge":
|
| 124 |
-
score_obj = grade_noise_purge(mask, self._state.chunks)
|
| 125 |
-
elif self._state.task_name == "dedupe_arena":
|
| 126 |
-
score_obj = grade_dedupe_arena(mask, self._state.chunks)
|
| 127 |
-
elif self._state.task_name == "signal_extract":
|
| 128 |
-
score_obj = grade_signal_extract(mask, self._state.chunks)
|
| 129 |
-
else:
|
| 130 |
-
score_obj = grade_noise_purge(mask, self._state.chunks)
|
| 131 |
-
|
| 132 |
-
self._state.done = True
|
| 133 |
-
self._state.step_count += 1
|
| 134 |
-
|
| 135 |
-
obs = self._observe(message=score_obj.message)
|
| 136 |
-
obs.reward = total_reward # Trajectory reward
|
| 137 |
-
|
| 138 |
-
if not obs.metadata:
|
| 139 |
-
obs.metadata = {}
|
| 140 |
-
obs.metadata["eval_score"] = score_obj.score # Grader score
|
| 141 |
-
obs.metadata["reward_detail"] = {
|
| 142 |
-
"efficiency": efficiency_reward,
|
| 143 |
-
"accuracy": accuracy_reward,
|
| 144 |
-
"penalty": gold_penalty
|
| 145 |
-
}
|
| 146 |
-
|
| 147 |
-
return obs
|
| 148 |
-
|
| 149 |
-
@property
|
| 150 |
-
def state(self) -> PruningState:
|
| 151 |
-
"""Official state access as required by openenv-core."""
|
| 152 |
-
return self._state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
context_pruning_env/graders.py
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
from typing import List
|
| 2 |
-
from context_pruning_env.models import ChunkItem, ContextReward
|
| 3 |
-
|
| 4 |
-
def grade_noise_purge(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
|
| 5 |
-
"""
|
| 6 |
-
Easy Task: Score 1.0 if gold kept AND noise pruned.
|
| 7 |
-
"""
|
| 8 |
-
gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
|
| 9 |
-
noise_pruned = all(mask[i] == 0 for i in range(len(mask)) if not chunks[i].is_gold)
|
| 10 |
-
|
| 11 |
-
if not gold_kept:
|
| 12 |
-
return ContextReward(score=0.0, gold_penalty=-1.0, message="Critical: Gold chunk lost.")
|
| 13 |
-
|
| 14 |
-
if noise_pruned:
|
| 15 |
-
return ContextReward(score=1.0, message="Perfect: All noise purged.")
|
| 16 |
-
else:
|
| 17 |
-
return ContextReward(score=0.5, message="Partial: Gold kept but noise remains.")
|
| 18 |
-
|
| 19 |
-
def grade_dedupe_arena(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
|
| 20 |
-
"""
|
| 21 |
-
Medium Task: 1.0 if word count reduced > 50% AND gold kept.
|
| 22 |
-
"""
|
| 23 |
-
initial_words = sum(len(c.content.split()) for c in chunks)
|
| 24 |
-
final_words = sum(len(chunks[i].content.split()) for i, kept in enumerate(mask) if kept)
|
| 25 |
-
|
| 26 |
-
gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
|
| 27 |
-
reduction = 1.0 - (final_words / initial_words) if initial_words > 0 else 1.0
|
| 28 |
-
|
| 29 |
-
if not gold_kept:
|
| 30 |
-
return ContextReward(score=0.0, message="Critical: Answer lost during deduplication.")
|
| 31 |
-
|
| 32 |
-
if reduction >= 0.5:
|
| 33 |
-
return ContextReward(score=1.0, message=f"Great: {reduction:.1%} word reduction achieved.")
|
| 34 |
-
else:
|
| 35 |
-
return ContextReward(score=0.5, message=f"Partial: Only {reduction:.1%} reduction.")
|
| 36 |
-
|
| 37 |
-
def grade_signal_extract(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
|
| 38 |
-
"""
|
| 39 |
-
Hard Task: 1 - (FinalTokens/InitialTokens) if gold kept.
|
| 40 |
-
"""
|
| 41 |
-
initial_tokens = sum(c.tokens for c in chunks)
|
| 42 |
-
final_tokens = sum(chunks[i].tokens for i, kept in enumerate(mask) if kept)
|
| 43 |
-
|
| 44 |
-
gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
|
| 45 |
-
|
| 46 |
-
if not gold_kept:
|
| 47 |
-
return ContextReward(score=0.0, message="Critical: Signal lost in noise.")
|
| 48 |
-
|
| 49 |
-
reduction_score = 1.0 - (final_tokens / initial_tokens) if initial_tokens > 0 else 0.0
|
| 50 |
-
# Ensure score is at least positive if gold is kept
|
| 51 |
-
final_score = max(0.1, reduction_score)
|
| 52 |
-
|
| 53 |
-
return ContextReward(
|
| 54 |
-
score=final_score,
|
| 55 |
-
message=f"Signal Extracted: {reduction_score:.1%} compression."
|
| 56 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
context_pruning_env/models.py
DELETED
|
@@ -1,54 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
from typing import List, Optional, Any, Dict
|
| 3 |
-
from pydantic import BaseModel, Field
|
| 4 |
-
from openenv.core.env_server.types import Action, Observation, State
|
| 5 |
-
|
| 6 |
-
class ContextAction(Action):
|
| 7 |
-
"""
|
| 8 |
-
Action space: A binary mask of N values (1 = keep, 0 = prune).
|
| 9 |
-
"""
|
| 10 |
-
mask: List[int] = Field(
|
| 11 |
-
...,
|
| 12 |
-
min_length=1,
|
| 13 |
-
description="Binary mask of integers (0 or 1) indicating which chunks to keep."
|
| 14 |
-
)
|
| 15 |
-
|
| 16 |
-
class ContextObservation(Observation):
|
| 17 |
-
"""
|
| 18 |
-
Observation provided to the agent.
|
| 19 |
-
"""
|
| 20 |
-
question: str
|
| 21 |
-
chunks: List[str] = Field(default_factory=list, description="Current context chunks.")
|
| 22 |
-
initial_token_count: int = 0
|
| 23 |
-
current_token_count: int = 0
|
| 24 |
-
task_name: str = ""
|
| 25 |
-
message: str = ""
|
| 26 |
-
|
| 27 |
-
class ContextReward(BaseModel):
|
| 28 |
-
"""
|
| 29 |
-
Detailed reward breakdown for Meta x Scaler audit.
|
| 30 |
-
"""
|
| 31 |
-
score: float = Field(0.0, ge=0.0, le=1.0, description="Overall task score (0 to 1).")
|
| 32 |
-
efficiency_reward: float = 0.0
|
| 33 |
-
accuracy_reward: float = 0.0
|
| 34 |
-
gold_penalty: float = 0.0
|
| 35 |
-
message: str = ""
|
| 36 |
-
|
| 37 |
-
class ChunkItem(BaseModel):
|
| 38 |
-
"""Internal representation of a context chunk."""
|
| 39 |
-
content: str
|
| 40 |
-
is_gold: bool = False
|
| 41 |
-
tokens: int = 0
|
| 42 |
-
is_duplicate: bool = False
|
| 43 |
-
|
| 44 |
-
class PruningState(State):
|
| 45 |
-
"""
|
| 46 |
-
Internal state for ContextPrune.
|
| 47 |
-
"""
|
| 48 |
-
task_name: str
|
| 49 |
-
question: str
|
| 50 |
-
chunks: List[ChunkItem]
|
| 51 |
-
initial_tokens: int
|
| 52 |
-
step_count: int = 0
|
| 53 |
-
done: bool = False
|
| 54 |
-
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
context_pruning_env/utils.py
DELETED
|
@@ -1,80 +0,0 @@
|
|
| 1 |
-
import random
|
| 2 |
-
import re
|
| 3 |
-
from typing import List, Tuple, Dict, Any
|
| 4 |
-
from datasets import load_dataset
|
| 5 |
-
import logging
|
| 6 |
-
|
| 7 |
-
logger = logging.getLogger(__name__)
|
| 8 |
-
|
| 9 |
-
class SQuADLoader:
|
| 10 |
-
def __init__(self, split: str = "train"):
|
| 11 |
-
try:
|
| 12 |
-
self.dataset = load_dataset("squad", split=split)
|
| 13 |
-
except Exception as e:
|
| 14 |
-
logger.error(f"Failed to load SQuAD: {e}")
|
| 15 |
-
self.dataset = []
|
| 16 |
-
self.indices = list(range(len(self.dataset)))
|
| 17 |
-
random.shuffle(self.indices)
|
| 18 |
-
self.current_ptr = 0
|
| 19 |
-
|
| 20 |
-
def _get_next_entry(self):
|
| 21 |
-
if self.current_ptr >= len(self.indices):
|
| 22 |
-
random.shuffle(self.indices)
|
| 23 |
-
self.current_ptr = 0
|
| 24 |
-
idx = self.indices[self.current_ptr]
|
| 25 |
-
self.current_ptr += 1
|
| 26 |
-
return idx, self.dataset[idx]
|
| 27 |
-
|
| 28 |
-
def get_episode(self, task_name: str) -> Tuple[str, List[Dict[str, Any]]]:
|
| 29 |
-
"""
|
| 30 |
-
Returns (question, List[Dict(content, is_gold, is_duplicate)])
|
| 31 |
-
"""
|
| 32 |
-
idx, entry = self._get_next_entry()
|
| 33 |
-
question = entry["question"]
|
| 34 |
-
gold_context = entry["context"]
|
| 35 |
-
|
| 36 |
-
chunks = []
|
| 37 |
-
|
| 38 |
-
if task_name == "noise_purge":
|
| 39 |
-
# Easy: 1 Gold + 1 Irrelevant
|
| 40 |
-
chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
|
| 41 |
-
_, noise_entry = self._get_next_entry()
|
| 42 |
-
chunks.append({"content": noise_entry["context"], "is_gold": False, "is_duplicate": False})
|
| 43 |
-
|
| 44 |
-
elif task_name == "dedupe_arena":
|
| 45 |
-
# Medium: 1 Gold + 2 Near-Duplicates (Simulated by repeating gold)
|
| 46 |
-
chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
|
| 47 |
-
# Duplicate 1: slightly modified or identical
|
| 48 |
-
chunks.append({"content": gold_context + " ", "is_gold": True, "is_duplicate": True})
|
| 49 |
-
# Duplicate 2: slightly modified
|
| 50 |
-
chunks.append({"content": "Actually, " + gold_context, "is_gold": True, "is_duplicate": True})
|
| 51 |
-
|
| 52 |
-
elif task_name == "signal_extract":
|
| 53 |
-
# Hard: 1 Gold context + multiple noise (2,000+ words total)
|
| 54 |
-
long_context_parts = [gold_context]
|
| 55 |
-
current_words = len(gold_context.split())
|
| 56 |
-
while current_words < 2200: # Ensure 2,000+ words
|
| 57 |
-
_, noise_entry = self._get_next_entry()
|
| 58 |
-
content = noise_entry["context"]
|
| 59 |
-
long_context_parts.append(content)
|
| 60 |
-
current_words += len(content.split())
|
| 61 |
-
|
| 62 |
-
# Shuffling the parts so the gold one isn't first
|
| 63 |
-
random.shuffle(long_context_parts)
|
| 64 |
-
for part in long_context_parts:
|
| 65 |
-
is_gold = (part == gold_context)
|
| 66 |
-
chunks.append({"content": part, "is_gold": is_gold, "is_duplicate": False})
|
| 67 |
-
|
| 68 |
-
else:
|
| 69 |
-
# Default to noise_purge
|
| 70 |
-
return self.get_episode("noise_purge")
|
| 71 |
-
|
| 72 |
-
# Shuffle chunks for non-signal tasks
|
| 73 |
-
if task_name != "signal_extract":
|
| 74 |
-
random.shuffle(chunks)
|
| 75 |
-
|
| 76 |
-
return question, chunks
|
| 77 |
-
|
| 78 |
-
def count_tokens(text: str) -> int:
|
| 79 |
-
"""Standard token counter for efficiency rewards."""
|
| 80 |
-
return len(text.split())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
openenv.yaml
CHANGED
|
@@ -14,5 +14,5 @@ tasks:
|
|
| 14 |
action_space: ["inspect_artifact", "prioritize_artifact", "summarize_artifact", "set_resolution_plan", "submit_report"]
|
| 15 |
observation_space: ["case_summary", "objective", "workflow_stage", "available_artifacts", "reviewed_artifacts", "prioritized_artifacts", "plan_draft", "total_tokens_used", "token_budget"]
|
| 16 |
reward_range: [0.0, 1.0]
|
| 17 |
-
port:
|
| 18 |
app: app.py
|
|
|
|
| 14 |
action_space: ["inspect_artifact", "prioritize_artifact", "summarize_artifact", "set_resolution_plan", "submit_report"]
|
| 15 |
observation_space: ["case_summary", "objective", "workflow_stage", "available_artifacts", "reviewed_artifacts", "prioritized_artifacts", "plan_draft", "total_tokens_used", "token_budget"]
|
| 16 |
reward_range: [0.0, 1.0]
|
| 17 |
+
port: 7860
|
| 18 |
app: app.py
|
rag_gc_env/__init__.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
from rag_gc_env.models import RAGGCAction, RAGGCObservation, RAGGCReward, RAGGCState
|
| 2 |
-
from rag_gc_env.environment import RAGGCEnvironment
|
| 3 |
-
|
| 4 |
-
__all__ = [
|
| 5 |
-
"RAGGCAction",
|
| 6 |
-
"RAGGCObservation",
|
| 7 |
-
"RAGGCReward",
|
| 8 |
-
"RAGGCState",
|
| 9 |
-
"RAGGCEnvironment",
|
| 10 |
-
]
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag_gc_env/__pycache__/__init__.cpython-311.pyc
DELETED
|
Binary file (440 Bytes)
|
|
|
rag_gc_env/__pycache__/environment.cpython-311.pyc
DELETED
|
Binary file (9.34 kB)
|
|
|
rag_gc_env/__pycache__/grader.cpython-311.pyc
DELETED
|
Binary file (2.75 kB)
|
|
|
rag_gc_env/__pycache__/inference.cpython-311.pyc
DELETED
|
Binary file (2.66 kB)
|
|
|
rag_gc_env/__pycache__/models.cpython-311.pyc
DELETED
|
Binary file (3.51 kB)
|
|
|
rag_gc_env/__pycache__/rewards.cpython-311.pyc
DELETED
|
Binary file (3.92 kB)
|
|
|
rag_gc_env/__pycache__/tasks.cpython-311.pyc
DELETED
|
Binary file (5.11 kB)
|
|
|
rag_gc_env/environment.py
DELETED
|
@@ -1,187 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from typing import Any, Optional
|
| 4 |
-
from uuid import uuid4
|
| 5 |
-
|
| 6 |
-
from openenv.core.env_server.interfaces import Environment
|
| 7 |
-
|
| 8 |
-
from rag_gc_env.grader import grade_context
|
| 9 |
-
from rag_gc_env.models import DocumentItem, RAGGCAction, RAGGCObservation, RAGGCReward, RAGGCState
|
| 10 |
-
from rag_gc_env.rewards import step_reward, summarize_deterministic
|
| 11 |
-
from rag_gc_env.tasks import ALL_TASKS, TaskSpec, task_by_seed
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class RAGGCEnvironment(Environment[RAGGCAction, RAGGCObservation, RAGGCState]):
|
| 15 |
-
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 16 |
-
|
| 17 |
-
def __init__(self) -> None:
|
| 18 |
-
super().__init__(transform=None, rubric=None)
|
| 19 |
-
self._state = RAGGCState(episode_id=str(uuid4()), step_count=0)
|
| 20 |
-
self._task: TaskSpec = task_by_seed(0)
|
| 21 |
-
self._docs: dict[str, DocumentItem] = {}
|
| 22 |
-
self._removed_critical = False
|
| 23 |
-
|
| 24 |
-
def _load_task(self, spec: TaskSpec) -> None:
|
| 25 |
-
self._docs = {}
|
| 26 |
-
for did, text, tok, _meta in spec.documents:
|
| 27 |
-
self._docs[did] = DocumentItem(document_id=did, text=text, tokens=tok)
|
| 28 |
-
|
| 29 |
-
def reset(
|
| 30 |
-
self,
|
| 31 |
-
seed: Optional[int] = None,
|
| 32 |
-
episode_id: Optional[str] = None,
|
| 33 |
-
task_name: Optional[str] = None,
|
| 34 |
-
**kwargs: Any,
|
| 35 |
-
) -> RAGGCObservation:
|
| 36 |
-
self._reset_rubric()
|
| 37 |
-
sid = episode_id or str(uuid4())
|
| 38 |
-
if task_name and task_name in ALL_TASKS:
|
| 39 |
-
self._task = ALL_TASKS[task_name]
|
| 40 |
-
elif seed is not None:
|
| 41 |
-
self._task = task_by_seed(int(seed))
|
| 42 |
-
else:
|
| 43 |
-
self._task = task_by_seed(0)
|
| 44 |
-
self._load_task(self._task)
|
| 45 |
-
self._removed_critical = False
|
| 46 |
-
self._state = RAGGCState(
|
| 47 |
-
episode_id=sid,
|
| 48 |
-
step_count=0,
|
| 49 |
-
task_name=self._task.name,
|
| 50 |
-
max_steps=64,
|
| 51 |
-
removed_critical=False,
|
| 52 |
-
submitted=False,
|
| 53 |
-
)
|
| 54 |
-
return self._observe(done=False, reward_value=0.0, msg="ready")
|
| 55 |
-
|
| 56 |
-
def _total_tokens(self) -> int:
|
| 57 |
-
return sum(d.tokens for d in self._docs.values())
|
| 58 |
-
|
| 59 |
-
def _observe(
|
| 60 |
-
self,
|
| 61 |
-
done: bool,
|
| 62 |
-
reward_value: float,
|
| 63 |
-
msg: str,
|
| 64 |
-
reward_detail: Optional[RAGGCReward] = None,
|
| 65 |
-
grader_score: Optional[float] = None,
|
| 66 |
-
) -> RAGGCObservation:
|
| 67 |
-
docs = sorted(self._docs.values(), key=lambda x: x.document_id)
|
| 68 |
-
return RAGGCObservation(
|
| 69 |
-
done=done,
|
| 70 |
-
reward=reward_value,
|
| 71 |
-
query=self._task.query,
|
| 72 |
-
documents=docs,
|
| 73 |
-
token_count=self._total_tokens(),
|
| 74 |
-
token_budget=self._task.token_budget,
|
| 75 |
-
task_name=self._task.name,
|
| 76 |
-
message=msg,
|
| 77 |
-
grader_score=grader_score,
|
| 78 |
-
reward_detail=reward_detail,
|
| 79 |
-
metadata={
|
| 80 |
-
"relevance": {
|
| 81 |
-
row[0]: row[3].get("relevance", 0.5)
|
| 82 |
-
for row in self._task.documents
|
| 83 |
-
if row[0] in self._docs
|
| 84 |
-
},
|
| 85 |
-
"hints": {row[0]: row[3].get("hint", "") for row in self._task.documents},
|
| 86 |
-
},
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
def step(
|
| 90 |
-
self,
|
| 91 |
-
action: RAGGCAction,
|
| 92 |
-
timeout_s: Optional[float] = None,
|
| 93 |
-
**kwargs: Any,
|
| 94 |
-
) -> RAGGCObservation:
|
| 95 |
-
self._state.step_count += 1
|
| 96 |
-
docs_before = dict(self._docs)
|
| 97 |
-
|
| 98 |
-
if action.verb == "submit":
|
| 99 |
-
score = grade_context(self._task, list(self._docs.values()))
|
| 100 |
-
self._state.submitted = True
|
| 101 |
-
r = RAGGCReward(
|
| 102 |
-
step_reward=score,
|
| 103 |
-
final_score=score,
|
| 104 |
-
)
|
| 105 |
-
obs = self._observe(
|
| 106 |
-
done=True,
|
| 107 |
-
reward_value=score,
|
| 108 |
-
msg="submitted",
|
| 109 |
-
reward_detail=r,
|
| 110 |
-
grader_score=score,
|
| 111 |
-
)
|
| 112 |
-
return self._apply_transform(obs)
|
| 113 |
-
|
| 114 |
-
if action.document_id is None or action.document_id not in self._docs:
|
| 115 |
-
obs = self._observe(
|
| 116 |
-
done=False,
|
| 117 |
-
reward_value=-0.1,
|
| 118 |
-
msg="unknown_document",
|
| 119 |
-
)
|
| 120 |
-
return self._apply_transform(obs)
|
| 121 |
-
|
| 122 |
-
did = action.document_id
|
| 123 |
-
removed_critical = False
|
| 124 |
-
|
| 125 |
-
if action.verb == "delete":
|
| 126 |
-
if did in self._task.critical_document_ids:
|
| 127 |
-
self._removed_critical = True
|
| 128 |
-
removed_critical = True
|
| 129 |
-
self._docs.pop(did, None)
|
| 130 |
-
|
| 131 |
-
elif action.verb == "keep":
|
| 132 |
-
pass
|
| 133 |
-
|
| 134 |
-
elif action.verb == "summarize":
|
| 135 |
-
item = self._docs[did]
|
| 136 |
-
new_text, new_tok = summarize_deterministic(item.text)
|
| 137 |
-
self._docs[did] = DocumentItem(
|
| 138 |
-
document_id=did,
|
| 139 |
-
text=new_text,
|
| 140 |
-
tokens=new_tok,
|
| 141 |
-
)
|
| 142 |
-
if did in self._task.critical_document_ids:
|
| 143 |
-
for p in self._task.required_phrases:
|
| 144 |
-
if p not in new_text:
|
| 145 |
-
self._removed_critical = True
|
| 146 |
-
removed_critical = True
|
| 147 |
-
|
| 148 |
-
rdetail = step_reward(
|
| 149 |
-
self._task,
|
| 150 |
-
action.verb,
|
| 151 |
-
did,
|
| 152 |
-
docs_before,
|
| 153 |
-
self._docs,
|
| 154 |
-
removed_critical,
|
| 155 |
-
)
|
| 156 |
-
self._state.removed_critical = self._removed_critical
|
| 157 |
-
|
| 158 |
-
over = self._total_tokens() > self._task.token_budget
|
| 159 |
-
if over:
|
| 160 |
-
penalty = -0.08 * (self._total_tokens() - self._task.token_budget)
|
| 161 |
-
rdetail.token_penalty += penalty
|
| 162 |
-
rdetail.step_reward += penalty
|
| 163 |
-
done = self._state.step_count >= self._state.max_steps
|
| 164 |
-
final_score: Optional[float] = None
|
| 165 |
-
if done:
|
| 166 |
-
final_score = grade_context(self._task, list(self._docs.values()))
|
| 167 |
-
rdetail.final_score = final_score
|
| 168 |
-
rdetail.step_reward += final_score * 0.5
|
| 169 |
-
|
| 170 |
-
reward_val = rdetail.step_reward
|
| 171 |
-
if done:
|
| 172 |
-
# When done, the reward is primarily the final grader score,
|
| 173 |
-
# but we can preserve the step-specific bonus we added.
|
| 174 |
-
# final_score is the main signal.
|
| 175 |
-
reward_val = final_score if final_score is not None else rdetail.step_reward
|
| 176 |
-
obs = self._observe(
|
| 177 |
-
done=done,
|
| 178 |
-
reward_value=reward_val,
|
| 179 |
-
msg="over_budget" if over else ("graded" if done else "ok"),
|
| 180 |
-
reward_detail=rdetail,
|
| 181 |
-
grader_score=final_score if done else None,
|
| 182 |
-
)
|
| 183 |
-
return self._apply_transform(obs)
|
| 184 |
-
|
| 185 |
-
@property
|
| 186 |
-
def state(self) -> RAGGCState:
|
| 187 |
-
return self._state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag_gc_env/grader.py
DELETED
|
@@ -1,43 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from rag_gc_env.models import DocumentItem
|
| 4 |
-
from rag_gc_env.tasks import TaskSpec
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def _joined_text(docs: list[DocumentItem]) -> str:
|
| 8 |
-
return " ".join(d.text for d in docs)
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def grade_context(task: TaskSpec, final_documents: list[DocumentItem]) -> float:
|
| 12 |
-
"""
|
| 13 |
-
Deterministic score in [0.0, 0.5, 1.0]:
|
| 14 |
-
1.0 — required facts present, budget respected, efficient (near optimal tokens)
|
| 15 |
-
0.5 — required facts present but inefficient or borderline budget
|
| 16 |
-
0.0 — missing facts, forbidden content present, or critical docs removed incorrectly
|
| 17 |
-
"""
|
| 18 |
-
text = _joined_text(final_documents)
|
| 19 |
-
total_tokens = sum(d.tokens for d in final_documents)
|
| 20 |
-
|
| 21 |
-
for phrase in task.required_phrases:
|
| 22 |
-
if phrase not in text:
|
| 23 |
-
return 0.0
|
| 24 |
-
|
| 25 |
-
for phrase in task.forbidden_phrases:
|
| 26 |
-
if phrase in text:
|
| 27 |
-
return 0.0
|
| 28 |
-
|
| 29 |
-
for pid in task.poison_document_ids:
|
| 30 |
-
still = any(d.document_id == pid for d in final_documents)
|
| 31 |
-
if still:
|
| 32 |
-
return 0.0
|
| 33 |
-
|
| 34 |
-
if total_tokens > task.token_budget:
|
| 35 |
-
return 0.0
|
| 36 |
-
|
| 37 |
-
if not task.critical_document_ids.issubset({d.document_id for d in final_documents}):
|
| 38 |
-
return 0.0
|
| 39 |
-
|
| 40 |
-
if total_tokens <= task.optimal_max_tokens:
|
| 41 |
-
return 1.0
|
| 42 |
-
|
| 43 |
-
return 0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag_gc_env/inference.py
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Reproducible baseline policy for Adaptive Context Optimization (RAG GC).
|
| 3 |
-
Deterministic: fixed action sequences per task derived from metadata.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
from __future__ import annotations
|
| 7 |
-
|
| 8 |
-
from rag_gc_env.environment import RAGGCEnvironment
|
| 9 |
-
from rag_gc_env.models import RAGGCAction
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def run_baseline(task_name: str, seed: int = 0) -> tuple[float, list[str]]:
|
| 13 |
-
env = RAGGCEnvironment()
|
| 14 |
-
obs = env.reset(seed=seed, task_name=task_name)
|
| 15 |
-
log: list[str] = ["reset"]
|
| 16 |
-
|
| 17 |
-
def step(verb: str, doc_id: str | None) -> None:
|
| 18 |
-
nonlocal obs
|
| 19 |
-
obs = env.step(RAGGCAction(verb=verb, document_id=doc_id))
|
| 20 |
-
log.append(f"{verb}:{doc_id}")
|
| 21 |
-
|
| 22 |
-
if task_name == "easy_irrelevant_removal":
|
| 23 |
-
step("delete", "d1")
|
| 24 |
-
step("submit", None)
|
| 25 |
-
elif task_name == "medium_token_compression":
|
| 26 |
-
step("delete", "m2")
|
| 27 |
-
while obs.token_count > obs.token_budget and not obs.done:
|
| 28 |
-
step("summarize", "m0")
|
| 29 |
-
if len(log) > 40:
|
| 30 |
-
break
|
| 31 |
-
step("submit", None)
|
| 32 |
-
elif task_name == "hard_contradiction_removal":
|
| 33 |
-
step("delete", "h1")
|
| 34 |
-
step("submit", None)
|
| 35 |
-
else:
|
| 36 |
-
step("submit", None)
|
| 37 |
-
|
| 38 |
-
score = float(obs.grader_score or obs.reward or 0.0)
|
| 39 |
-
return score, log
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
if __name__ == "__main__":
|
| 43 |
-
for name in (
|
| 44 |
-
"easy_irrelevant_removal",
|
| 45 |
-
"medium_token_compression",
|
| 46 |
-
"hard_contradiction_removal",
|
| 47 |
-
):
|
| 48 |
-
s, lg = run_baseline(name, seed=0)
|
| 49 |
-
print(name, "score=", s, "trace=", lg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag_gc_env/models.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from typing import Any, Literal, Optional
|
| 4 |
-
|
| 5 |
-
from openenv.core.env_server.types import Action, Observation, State
|
| 6 |
-
from pydantic import BaseModel, Field
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class DocumentItem(BaseModel):
|
| 10 |
-
document_id: str
|
| 11 |
-
text: str
|
| 12 |
-
tokens: int = Field(description="Estimated tokens for this snippet")
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class RAGGCAction(Action):
|
| 16 |
-
verb: Literal["keep", "delete", "summarize", "submit"] = Field(
|
| 17 |
-
description="Document operation or submit to finalize and grade"
|
| 18 |
-
)
|
| 19 |
-
document_id: Optional[str] = Field(
|
| 20 |
-
default=None,
|
| 21 |
-
description="Target document for keep/delete/summarize; omit for submit",
|
| 22 |
-
)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class RAGGCReward(BaseModel):
|
| 26 |
-
step_reward: float = 0.0
|
| 27 |
-
relevance: float = 0.0
|
| 28 |
-
compression: float = 0.0
|
| 29 |
-
token_penalty: float = 0.0
|
| 30 |
-
critical_penalty: float = 0.0
|
| 31 |
-
final_score: Optional[float] = Field(
|
| 32 |
-
default=None, description="0.0–1.0 after submit; aligns with grader"
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
class RAGGCObservation(Observation):
|
| 37 |
-
query: str = ""
|
| 38 |
-
documents: list[DocumentItem] = Field(default_factory=list)
|
| 39 |
-
token_count: int = 0
|
| 40 |
-
token_budget: int = 0
|
| 41 |
-
task_name: str = ""
|
| 42 |
-
reward_detail: Optional[RAGGCReward] = None
|
| 43 |
-
message: str = ""
|
| 44 |
-
grader_score: Optional[float] = Field(
|
| 45 |
-
default=None, description="Deterministic score after episode ends"
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class RAGGCState(State):
|
| 50 |
-
task_name: str = ""
|
| 51 |
-
max_steps: int = 64
|
| 52 |
-
removed_critical: bool = False
|
| 53 |
-
submitted: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag_gc_env/rewards.py
DELETED
|
@@ -1,89 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from rag_gc_env.models import DocumentItem, RAGGCReward
|
| 4 |
-
from rag_gc_env.tasks import TaskSpec
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def summarize_deterministic(text: str) -> tuple[str, int]:
|
| 8 |
-
"""Deterministic compression: first sentence or capped prefix."""
|
| 9 |
-
stripped = text.strip()
|
| 10 |
-
if not stripped:
|
| 11 |
-
return "", 1
|
| 12 |
-
cut = stripped.split(". ")
|
| 13 |
-
first = cut[0] + ("." if not cut[0].endswith(".") else "")
|
| 14 |
-
if len(first) < 40 and len(cut) > 1:
|
| 15 |
-
first = cut[0] + ". " + cut[1] + ("." if not cut[1].endswith(".") else "")
|
| 16 |
-
cap = 280
|
| 17 |
-
out = first[:cap] + ("..." if len(first) > cap else "")
|
| 18 |
-
tokens = max(1, len(out) // 4)
|
| 19 |
-
return out, tokens
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def estimate_tokens(text: str) -> int:
|
| 23 |
-
return max(1, len(text) // 4)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def step_reward(
|
| 27 |
-
task: TaskSpec,
|
| 28 |
-
verb: str,
|
| 29 |
-
doc_id: str | None,
|
| 30 |
-
docs_before: dict[str, DocumentItem],
|
| 31 |
-
docs_after: dict[str, DocumentItem],
|
| 32 |
-
removed_critical_flag: bool,
|
| 33 |
-
) -> RAGGCReward:
|
| 34 |
-
rel = 0.0
|
| 35 |
-
comp = 0.0
|
| 36 |
-
tok_pen = 0.0
|
| 37 |
-
crit = 0.0
|
| 38 |
-
|
| 39 |
-
if removed_critical_flag:
|
| 40 |
-
crit = -3.0
|
| 41 |
-
|
| 42 |
-
if verb == "delete" and doc_id in docs_before:
|
| 43 |
-
meta = next(
|
| 44 |
-
(m for did, _, _, m in task.documents if did == doc_id),
|
| 45 |
-
{},
|
| 46 |
-
)
|
| 47 |
-
# Reward deleting irrelevant or poison documents
|
| 48 |
-
if doc_id in task.irrelevant_document_ids:
|
| 49 |
-
rel += 0.4
|
| 50 |
-
elif doc_id in task.poison_document_ids:
|
| 51 |
-
rel += 0.6
|
| 52 |
-
elif doc_id in task.critical_document_ids:
|
| 53 |
-
crit -= 3.0
|
| 54 |
-
elif meta.get("hint") == "fluff":
|
| 55 |
-
rel += 0.2
|
| 56 |
-
|
| 57 |
-
# Deleting tokens should NOT result in a penalty proportional to the deleted tokens;
|
| 58 |
-
# instead, it removes the 'keep' penalty they would have incurred.
|
| 59 |
-
# We can add a small constant 'action cost' for deleting if desired, but 0.0 is fine here.
|
| 60 |
-
tok_pen = 0.0
|
| 61 |
-
|
| 62 |
-
if verb == "summarize" and doc_id in docs_before:
|
| 63 |
-
before_t = docs_before[doc_id].tokens
|
| 64 |
-
after = docs_after.get(doc_id)
|
| 65 |
-
if after is not None:
|
| 66 |
-
# Reward for the reduction in size (efficiency)
|
| 67 |
-
reduction_ratio = (before_t - after.tokens) / max(before_t, 1)
|
| 68 |
-
comp += 0.3 * max(0.0, reduction_ratio)
|
| 69 |
-
|
| 70 |
-
# The remaining tokens still incur a small penalty
|
| 71 |
-
tok_pen -= 0.01 * after.tokens
|
| 72 |
-
|
| 73 |
-
if doc_id in task.critical_document_ids:
|
| 74 |
-
for p in task.required_phrases:
|
| 75 |
-
if p not in after.text:
|
| 76 |
-
crit -= 2.5
|
| 77 |
-
|
| 78 |
-
if verb == "keep" and doc_id in docs_before:
|
| 79 |
-
# Standard penalty for keeping tokens in context
|
| 80 |
-
tok_pen -= 0.01 * docs_before[doc_id].tokens
|
| 81 |
-
|
| 82 |
-
step = rel + comp + tok_pen + crit
|
| 83 |
-
return RAGGCReward(
|
| 84 |
-
step_reward=step,
|
| 85 |
-
relevance=rel,
|
| 86 |
-
compression=comp,
|
| 87 |
-
token_penalty=tok_pen,
|
| 88 |
-
critical_penalty=crit,
|
| 89 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag_gc_env/server/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# Server package for OpenEnv HTTP deployment
|
|
|
|
|
|
rag_gc_env/server/__pycache__/__init__.cpython-311.pyc
DELETED
|
Binary file (154 Bytes)
|
|
|
rag_gc_env/server/__pycache__/app.cpython-311.pyc
DELETED
|
Binary file (1.04 kB)
|
|
|
rag_gc_env/server/app.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
|
| 3 |
-
from openenv.core.env_server.http_server import create_fastapi_app
|
| 4 |
-
|
| 5 |
-
from rag_gc_env.environment import RAGGCEnvironment
|
| 6 |
-
from rag_gc_env.models import RAGGCAction, RAGGCObservation
|
| 7 |
-
|
| 8 |
-
app = create_fastapi_app(
|
| 9 |
-
RAGGCEnvironment,
|
| 10 |
-
RAGGCAction,
|
| 11 |
-
RAGGCObservation,
|
| 12 |
-
)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def main() -> None:
|
| 16 |
-
import uvicorn
|
| 17 |
-
|
| 18 |
-
port = int(os.environ.get("PORT", "8000"))
|
| 19 |
-
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
if __name__ == "__main__":
|
| 23 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rag_gc_env/tasks.py
DELETED
|
@@ -1,144 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from dataclasses import dataclass, field
|
| 4 |
-
from typing import Any, FrozenSet
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@dataclass(frozen=True)
|
| 8 |
-
class TaskSpec:
|
| 9 |
-
name: str
|
| 10 |
-
query: str
|
| 11 |
-
token_budget: int
|
| 12 |
-
documents: list[tuple[str, str, int, dict[str, Any]]]
|
| 13 |
-
# document_id, text, tokens, metadata (relevance, flags)
|
| 14 |
-
required_phrases: FrozenSet[str] = field(default_factory=frozenset)
|
| 15 |
-
forbidden_phrases: FrozenSet[str] = field(default_factory=frozenset)
|
| 16 |
-
critical_document_ids: FrozenSet[str] = field(default_factory=frozenset)
|
| 17 |
-
irrelevant_document_ids: FrozenSet[str] = field(default_factory=frozenset)
|
| 18 |
-
poison_document_ids: FrozenSet[str] = field(default_factory=frozenset)
|
| 19 |
-
optimal_max_tokens: int = 0
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def _docs(
|
| 23 |
-
rows: list[tuple[str, str, int, dict[str, Any]]],
|
| 24 |
-
) -> list[tuple[str, str, int, dict[str, Any]]]:
|
| 25 |
-
return rows
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
TASK_EASY = TaskSpec(
|
| 29 |
-
name="easy_irrelevant_removal",
|
| 30 |
-
query="What is the capital city of France?",
|
| 31 |
-
token_budget=400,
|
| 32 |
-
documents=_docs(
|
| 33 |
-
[
|
| 34 |
-
(
|
| 35 |
-
"d0",
|
| 36 |
-
"Paris has been the capital of France since political centralization in the country.",
|
| 37 |
-
24,
|
| 38 |
-
{"relevance": 0.95, "hint": "high"},
|
| 39 |
-
),
|
| 40 |
-
(
|
| 41 |
-
"d1",
|
| 42 |
-
"Penguins thrive in Antarctica and are unrelated to European geography.",
|
| 43 |
-
18,
|
| 44 |
-
{"relevance": 0.08, "hint": "noise"},
|
| 45 |
-
),
|
| 46 |
-
(
|
| 47 |
-
"d2",
|
| 48 |
-
"Lyon is a major French city but not the national capital.",
|
| 49 |
-
16,
|
| 50 |
-
{"relevance": 0.55, "hint": "partial"},
|
| 51 |
-
),
|
| 52 |
-
]
|
| 53 |
-
),
|
| 54 |
-
required_phrases=frozenset({"Paris"}),
|
| 55 |
-
forbidden_phrases=frozenset(),
|
| 56 |
-
critical_document_ids=frozenset({"d0"}),
|
| 57 |
-
irrelevant_document_ids=frozenset({"d1"}),
|
| 58 |
-
poison_document_ids=frozenset(),
|
| 59 |
-
optimal_max_tokens=120,
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
-
_LONG_DUP = (
|
| 63 |
-
"Paris is the capital of France. " * 18
|
| 64 |
-
+ "This repetition exists only to inflate token usage for compression tests."
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
TASK_MEDIUM = TaskSpec(
|
| 68 |
-
name="medium_token_compression",
|
| 69 |
-
query="Which city is the capital of France?",
|
| 70 |
-
token_budget=120,
|
| 71 |
-
documents=_docs(
|
| 72 |
-
[
|
| 73 |
-
(
|
| 74 |
-
"m0",
|
| 75 |
-
_LONG_DUP,
|
| 76 |
-
max(1, len(_LONG_DUP) // 4),
|
| 77 |
-
{"relevance": 0.9, "hint": "verbose"},
|
| 78 |
-
),
|
| 79 |
-
(
|
| 80 |
-
"m1",
|
| 81 |
-
"Administrative records list a capital city but this line omits the name intentionally.",
|
| 82 |
-
14,
|
| 83 |
-
{"relevance": 0.55, "hint": "no_answer"},
|
| 84 |
-
),
|
| 85 |
-
(
|
| 86 |
-
"m2",
|
| 87 |
-
"French cuisine is diverse; it does not change which city is the capital.",
|
| 88 |
-
14,
|
| 89 |
-
{"relevance": 0.35, "hint": "fluff"},
|
| 90 |
-
),
|
| 91 |
-
]
|
| 92 |
-
),
|
| 93 |
-
required_phrases=frozenset({"Paris"}),
|
| 94 |
-
forbidden_phrases=frozenset(),
|
| 95 |
-
critical_document_ids=frozenset(),
|
| 96 |
-
irrelevant_document_ids=frozenset({"m2"}),
|
| 97 |
-
poison_document_ids=frozenset(),
|
| 98 |
-
optimal_max_tokens=90,
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
TASK_HARD = TaskSpec(
|
| 102 |
-
name="hard_contradiction_removal",
|
| 103 |
-
query="Compute the integer result of 17 + 25.",
|
| 104 |
-
token_budget=350,
|
| 105 |
-
documents=_docs(
|
| 106 |
-
[
|
| 107 |
-
(
|
| 108 |
-
"h0",
|
| 109 |
-
"Arithmetic trace: 17 + 25 = 42.",
|
| 110 |
-
12,
|
| 111 |
-
{"relevance": 0.9, "hint": "consistent", "trust": 0.95},
|
| 112 |
-
),
|
| 113 |
-
(
|
| 114 |
-
"h1",
|
| 115 |
-
"Quick math note: 17 + 25 equals 43 for budgeting purposes.",
|
| 116 |
-
14,
|
| 117 |
-
{"relevance": 0.88, "hint": "poison", "trust": 0.2},
|
| 118 |
-
),
|
| 119 |
-
(
|
| 120 |
-
"h2",
|
| 121 |
-
"Addition of integers is associative and commutative.",
|
| 122 |
-
10,
|
| 123 |
-
{"relevance": 0.4, "hint": "generic"},
|
| 124 |
-
),
|
| 125 |
-
]
|
| 126 |
-
),
|
| 127 |
-
required_phrases=frozenset({"42"}),
|
| 128 |
-
forbidden_phrases=frozenset({"43"}),
|
| 129 |
-
critical_document_ids=frozenset({"h0"}),
|
| 130 |
-
irrelevant_document_ids=frozenset(),
|
| 131 |
-
poison_document_ids=frozenset({"h1"}),
|
| 132 |
-
optimal_max_tokens=200,
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
ALL_TASKS: dict[str, TaskSpec] = {
|
| 136 |
-
TASK_EASY.name: TASK_EASY,
|
| 137 |
-
TASK_MEDIUM.name: TASK_MEDIUM,
|
| 138 |
-
TASK_HARD.name: TASK_HARD,
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def task_by_seed(seed: int) -> TaskSpec:
|
| 143 |
-
order = [TASK_EASY, TASK_MEDIUM, TASK_HARD]
|
| 144 |
-
return order[seed % 3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|