prithic07 commited on
Commit
582387d
·
1 Parent(s): 3838887

Deploy: Synchronize ports to 7860 and add Hugging Face Space metadata

Browse files
Dockerfile CHANGED
@@ -12,6 +12,6 @@ RUN pip install --no-cache-dir --upgrade pip && \
12
 
13
  COPY . /app
14
 
15
- EXPOSE 8000
16
 
17
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
 
12
 
13
  COPY . /app
14
 
15
+ EXPOSE 7860
16
 
17
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -75,7 +75,7 @@ ContextPrune includes three canonical tasks that simulate high-pressure operatio
75
  ## 5. Technical Components
76
 
77
  - **`rag_optimizer_env/`**: Core state management, hybrid retrieval (Keyword + Semantic), and token estimation using `llm_runtime`.
78
- - **`app.py`**: A standard FastAPI implementation.
79
  - **`inference.py`**: A baseline agent script demonstrating how to use the OpenAI-compatible interface.
80
  - **`validate.py`**: A robust validation suite that runs a full episode lifecycle locally to ensure 100% environment compliance.
81
 
@@ -88,4 +88,3 @@ ContextPrune includes three canonical tasks that simulate high-pressure operatio
88
  3. **Control Panel**: `streamlit run optimizer_ui.py`
89
  4. **Validation**: `python validate.py`
90
 
91
- Built for Context Optimization Research.
 
75
  ## 5. Technical Components
76
 
77
  - **`rag_optimizer_env/`**: Core state management, hybrid retrieval (Keyword + Semantic), and token estimation using `llm_runtime`.
78
+ - **`app.py`**: A standard FastAPI implementation. Built for Context Optimization Research.
79
  - **`inference.py`**: A baseline agent script demonstrating how to use the OpenAI-compatible interface.
80
  - **`validate.py`**: A robust validation suite that runs a full episode lifecycle locally to ensure 100% environment compliance.
81
 
 
88
  3. **Control Panel**: `streamlit run optimizer_ui.py`
89
  4. **Validation**: `python validate.py`
90
 
 
app.py CHANGED
@@ -387,4 +387,4 @@ async def optimize_prompt_endpoint(payload: OptimizePromptRequest):
387
  if __name__ == "__main__":
388
  import uvicorn
389
 
390
- uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
 
387
  if __name__ == "__main__":
388
  import uvicorn
389
 
390
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
context_pruning_env/env.py DELETED
@@ -1,152 +0,0 @@
1
- from __future__ import annotations
2
- from typing import Any, Optional, List, Dict
3
- from uuid import uuid4
4
-
5
- from openenv.core.env_server.interfaces import Environment
6
- from context_pruning_env.models import (
7
- ContextAction,
8
- ContextObservation,
9
- ContextReward,
10
- PruningState,
11
- ChunkItem
12
- )
13
- from context_pruning_env.utils import SQuADLoader, count_tokens
14
- from context_pruning_env.graders import (
15
- grade_noise_purge,
16
- grade_dedupe_arena,
17
- grade_signal_extract
18
- )
19
-
20
- class ContextPruningEnv(Environment[ContextAction, ContextObservation, PruningState]):
21
- """
22
- Hackathon-compliant Context Pruning Environment.
23
- """
24
-
25
- def __init__(self, squad_split: str = "train"):
26
- super().__init__(transform=None, rubric=None)
27
- self.loader = SQuADLoader(split=squad_split)
28
- self._state = None
29
-
30
- def reset(
31
- self,
32
- seed: Optional[int] = None,
33
- episode_id: Optional[str] = None,
34
- task_name: Optional[str] = "noise_purge",
35
- **kwargs: Any,
36
- ) -> ContextObservation:
37
- """
38
- Starts a new episode with the specified task.
39
- """
40
- task_name = task_name or "noise_purge"
41
- question, chunks_data = self.loader.get_episode(task_name)
42
-
43
- chunks = []
44
- total_tokens = 0
45
- for item in chunks_data:
46
- tokens = count_tokens(item["content"])
47
- total_tokens += tokens
48
- chunks.append(ChunkItem(
49
- content=item["content"],
50
- is_gold=item["is_gold"],
51
- is_duplicate=item["is_duplicate"],
52
- tokens=tokens
53
- ))
54
-
55
- self._state = PruningState(
56
- episode_id=episode_id or str(uuid4()),
57
- task_name=task_name,
58
- question=question,
59
- chunks=chunks,
60
- initial_tokens=total_tokens,
61
- step_count=0,
62
- done=False
63
- )
64
-
65
- return self._observe(message=f"Task '{task_name}' initialized.")
66
-
67
- def _observe(self, message: str = "") -> ContextObservation:
68
- """Create observation from state."""
69
- return ContextObservation(
70
- done=self._state.done,
71
- question=self._state.question,
72
- chunks=[c.content for c in self._state.chunks],
73
- initial_token_count=self._state.initial_tokens,
74
- current_token_count=sum(c.tokens for c in self._state.chunks),
75
- task_name=self._state.task_name,
76
- message=message
77
- )
78
-
79
- def step(
80
- self,
81
- action: ContextAction,
82
- **kwargs: Any,
83
- ) -> ContextObservation:
84
- """
85
- Takes a binary mask and calculates rewards based on trajectory signals.
86
- """
87
- if self._state.done:
88
- return self._observe(message="Episode is already done.")
89
-
90
- mask = action.mask
91
- if len(mask) != len(self._state.chunks):
92
- # Pad with 0 (Prune) instead of 1 (Keep) to ensure agent optimization
93
- mask = (mask + [0] * len(self._state.chunks))[:len(self._state.chunks)]
94
-
95
- # Trajectory Simulation Logic
96
- total_reward = 0.0
97
- efficiency_reward = 0.0
98
- accuracy_reward = 0.0
99
- gold_penalty = 0.0
100
- success = True
101
-
102
- for i, kept in enumerate(mask):
103
- chunk = self._state.chunks[i]
104
- if not kept: # Pruned
105
- if chunk.is_gold:
106
- # Critical Failure
107
- gold_penalty = -1.0
108
- success = False
109
- break # Immediate stop
110
- else:
111
- # Correctly pruned noise/duplicate
112
- efficiency_reward += 0.1
113
- else: # Kept
114
- pass
115
-
116
- # Final Accuracy Bonus
117
- if success:
118
- accuracy_reward = 0.7
119
-
120
- total_reward = efficiency_reward + accuracy_reward + gold_penalty
121
-
122
- # Task Score (Normalized 0.0 to 1.0 for the evaluator)
123
- if self._state.task_name == "noise_purge":
124
- score_obj = grade_noise_purge(mask, self._state.chunks)
125
- elif self._state.task_name == "dedupe_arena":
126
- score_obj = grade_dedupe_arena(mask, self._state.chunks)
127
- elif self._state.task_name == "signal_extract":
128
- score_obj = grade_signal_extract(mask, self._state.chunks)
129
- else:
130
- score_obj = grade_noise_purge(mask, self._state.chunks)
131
-
132
- self._state.done = True
133
- self._state.step_count += 1
134
-
135
- obs = self._observe(message=score_obj.message)
136
- obs.reward = total_reward # Trajectory reward
137
-
138
- if not obs.metadata:
139
- obs.metadata = {}
140
- obs.metadata["eval_score"] = score_obj.score # Grader score
141
- obs.metadata["reward_detail"] = {
142
- "efficiency": efficiency_reward,
143
- "accuracy": accuracy_reward,
144
- "penalty": gold_penalty
145
- }
146
-
147
- return obs
148
-
149
- @property
150
- def state(self) -> PruningState:
151
- """Official state access as required by openenv-core."""
152
- return self._state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
context_pruning_env/graders.py DELETED
@@ -1,56 +0,0 @@
1
- from typing import List
2
- from context_pruning_env.models import ChunkItem, ContextReward
3
-
4
- def grade_noise_purge(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
5
- """
6
- Easy Task: Score 1.0 if gold kept AND noise pruned.
7
- """
8
- gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
9
- noise_pruned = all(mask[i] == 0 for i in range(len(mask)) if not chunks[i].is_gold)
10
-
11
- if not gold_kept:
12
- return ContextReward(score=0.0, gold_penalty=-1.0, message="Critical: Gold chunk lost.")
13
-
14
- if noise_pruned:
15
- return ContextReward(score=1.0, message="Perfect: All noise purged.")
16
- else:
17
- return ContextReward(score=0.5, message="Partial: Gold kept but noise remains.")
18
-
19
- def grade_dedupe_arena(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
20
- """
21
- Medium Task: 1.0 if word count reduced > 50% AND gold kept.
22
- """
23
- initial_words = sum(len(c.content.split()) for c in chunks)
24
- final_words = sum(len(chunks[i].content.split()) for i, kept in enumerate(mask) if kept)
25
-
26
- gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
27
- reduction = 1.0 - (final_words / initial_words) if initial_words > 0 else 1.0
28
-
29
- if not gold_kept:
30
- return ContextReward(score=0.0, message="Critical: Answer lost during deduplication.")
31
-
32
- if reduction >= 0.5:
33
- return ContextReward(score=1.0, message=f"Great: {reduction:.1%} word reduction achieved.")
34
- else:
35
- return ContextReward(score=0.5, message=f"Partial: Only {reduction:.1%} reduction.")
36
-
37
- def grade_signal_extract(mask: List[int], chunks: List[ChunkItem]) -> ContextReward:
38
- """
39
- Hard Task: 1 - (FinalTokens/InitialTokens) if gold kept.
40
- """
41
- initial_tokens = sum(c.tokens for c in chunks)
42
- final_tokens = sum(chunks[i].tokens for i, kept in enumerate(mask) if kept)
43
-
44
- gold_kept = any(mask[i] == 1 and chunks[i].is_gold for i in range(len(mask)))
45
-
46
- if not gold_kept:
47
- return ContextReward(score=0.0, message="Critical: Signal lost in noise.")
48
-
49
- reduction_score = 1.0 - (final_tokens / initial_tokens) if initial_tokens > 0 else 0.0
50
- # Ensure score is at least positive if gold is kept
51
- final_score = max(0.1, reduction_score)
52
-
53
- return ContextReward(
54
- score=final_score,
55
- message=f"Signal Extracted: {reduction_score:.1%} compression."
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
context_pruning_env/models.py DELETED
@@ -1,54 +0,0 @@
1
- from __future__ import annotations
2
- from typing import List, Optional, Any, Dict
3
- from pydantic import BaseModel, Field
4
- from openenv.core.env_server.types import Action, Observation, State
5
-
6
- class ContextAction(Action):
7
- """
8
- Action space: A binary mask of N values (1 = keep, 0 = prune).
9
- """
10
- mask: List[int] = Field(
11
- ...,
12
- min_length=1,
13
- description="Binary mask of integers (0 or 1) indicating which chunks to keep."
14
- )
15
-
16
- class ContextObservation(Observation):
17
- """
18
- Observation provided to the agent.
19
- """
20
- question: str
21
- chunks: List[str] = Field(default_factory=list, description="Current context chunks.")
22
- initial_token_count: int = 0
23
- current_token_count: int = 0
24
- task_name: str = ""
25
- message: str = ""
26
-
27
- class ContextReward(BaseModel):
28
- """
29
- Detailed reward breakdown for Meta x Scaler audit.
30
- """
31
- score: float = Field(0.0, ge=0.0, le=1.0, description="Overall task score (0 to 1).")
32
- efficiency_reward: float = 0.0
33
- accuracy_reward: float = 0.0
34
- gold_penalty: float = 0.0
35
- message: str = ""
36
-
37
- class ChunkItem(BaseModel):
38
- """Internal representation of a context chunk."""
39
- content: str
40
- is_gold: bool = False
41
- tokens: int = 0
42
- is_duplicate: bool = False
43
-
44
- class PruningState(State):
45
- """
46
- Internal state for ContextPrune.
47
- """
48
- task_name: str
49
- question: str
50
- chunks: List[ChunkItem]
51
- initial_tokens: int
52
- step_count: int = 0
53
- done: bool = False
54
- metadata: Dict[str, Any] = Field(default_factory=dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
context_pruning_env/utils.py DELETED
@@ -1,80 +0,0 @@
1
- import random
2
- import re
3
- from typing import List, Tuple, Dict, Any
4
- from datasets import load_dataset
5
- import logging
6
-
7
- logger = logging.getLogger(__name__)
8
-
9
- class SQuADLoader:
10
- def __init__(self, split: str = "train"):
11
- try:
12
- self.dataset = load_dataset("squad", split=split)
13
- except Exception as e:
14
- logger.error(f"Failed to load SQuAD: {e}")
15
- self.dataset = []
16
- self.indices = list(range(len(self.dataset)))
17
- random.shuffle(self.indices)
18
- self.current_ptr = 0
19
-
20
- def _get_next_entry(self):
21
- if self.current_ptr >= len(self.indices):
22
- random.shuffle(self.indices)
23
- self.current_ptr = 0
24
- idx = self.indices[self.current_ptr]
25
- self.current_ptr += 1
26
- return idx, self.dataset[idx]
27
-
28
- def get_episode(self, task_name: str) -> Tuple[str, List[Dict[str, Any]]]:
29
- """
30
- Returns (question, List[Dict(content, is_gold, is_duplicate)])
31
- """
32
- idx, entry = self._get_next_entry()
33
- question = entry["question"]
34
- gold_context = entry["context"]
35
-
36
- chunks = []
37
-
38
- if task_name == "noise_purge":
39
- # Easy: 1 Gold + 1 Irrelevant
40
- chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
41
- _, noise_entry = self._get_next_entry()
42
- chunks.append({"content": noise_entry["context"], "is_gold": False, "is_duplicate": False})
43
-
44
- elif task_name == "dedupe_arena":
45
- # Medium: 1 Gold + 2 Near-Duplicates (Simulated by repeating gold)
46
- chunks.append({"content": gold_context, "is_gold": True, "is_duplicate": False})
47
- # Duplicate 1: slightly modified or identical
48
- chunks.append({"content": gold_context + " ", "is_gold": True, "is_duplicate": True})
49
- # Duplicate 2: slightly modified
50
- chunks.append({"content": "Actually, " + gold_context, "is_gold": True, "is_duplicate": True})
51
-
52
- elif task_name == "signal_extract":
53
- # Hard: 1 Gold context + multiple noise (2,000+ words total)
54
- long_context_parts = [gold_context]
55
- current_words = len(gold_context.split())
56
- while current_words < 2200: # Ensure 2,000+ words
57
- _, noise_entry = self._get_next_entry()
58
- content = noise_entry["context"]
59
- long_context_parts.append(content)
60
- current_words += len(content.split())
61
-
62
- # Shuffling the parts so the gold one isn't first
63
- random.shuffle(long_context_parts)
64
- for part in long_context_parts:
65
- is_gold = (part == gold_context)
66
- chunks.append({"content": part, "is_gold": is_gold, "is_duplicate": False})
67
-
68
- else:
69
- # Default to noise_purge
70
- return self.get_episode("noise_purge")
71
-
72
- # Shuffle chunks for non-signal tasks
73
- if task_name != "signal_extract":
74
- random.shuffle(chunks)
75
-
76
- return question, chunks
77
-
78
- def count_tokens(text: str) -> int:
79
- """Standard token counter for efficiency rewards."""
80
- return len(text.split())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openenv.yaml CHANGED
@@ -14,5 +14,5 @@ tasks:
14
  action_space: ["inspect_artifact", "prioritize_artifact", "summarize_artifact", "set_resolution_plan", "submit_report"]
15
  observation_space: ["case_summary", "objective", "workflow_stage", "available_artifacts", "reviewed_artifacts", "prioritized_artifacts", "plan_draft", "total_tokens_used", "token_budget"]
16
  reward_range: [0.0, 1.0]
17
- port: 8000
18
  app: app.py
 
14
  action_space: ["inspect_artifact", "prioritize_artifact", "summarize_artifact", "set_resolution_plan", "submit_report"]
15
  observation_space: ["case_summary", "objective", "workflow_stage", "available_artifacts", "reviewed_artifacts", "prioritized_artifacts", "plan_draft", "total_tokens_used", "token_budget"]
16
  reward_range: [0.0, 1.0]
17
+ port: 7860
18
  app: app.py
rag_gc_env/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- from rag_gc_env.models import RAGGCAction, RAGGCObservation, RAGGCReward, RAGGCState
2
- from rag_gc_env.environment import RAGGCEnvironment
3
-
4
- __all__ = [
5
- "RAGGCAction",
6
- "RAGGCObservation",
7
- "RAGGCReward",
8
- "RAGGCState",
9
- "RAGGCEnvironment",
10
- ]
11
-
 
 
 
 
 
 
 
 
 
 
 
 
rag_gc_env/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (440 Bytes)
 
rag_gc_env/__pycache__/environment.cpython-311.pyc DELETED
Binary file (9.34 kB)
 
rag_gc_env/__pycache__/grader.cpython-311.pyc DELETED
Binary file (2.75 kB)
 
rag_gc_env/__pycache__/inference.cpython-311.pyc DELETED
Binary file (2.66 kB)
 
rag_gc_env/__pycache__/models.cpython-311.pyc DELETED
Binary file (3.51 kB)
 
rag_gc_env/__pycache__/rewards.cpython-311.pyc DELETED
Binary file (3.92 kB)
 
rag_gc_env/__pycache__/tasks.cpython-311.pyc DELETED
Binary file (5.11 kB)
 
rag_gc_env/environment.py DELETED
@@ -1,187 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any, Optional
4
- from uuid import uuid4
5
-
6
- from openenv.core.env_server.interfaces import Environment
7
-
8
- from rag_gc_env.grader import grade_context
9
- from rag_gc_env.models import DocumentItem, RAGGCAction, RAGGCObservation, RAGGCReward, RAGGCState
10
- from rag_gc_env.rewards import step_reward, summarize_deterministic
11
- from rag_gc_env.tasks import ALL_TASKS, TaskSpec, task_by_seed
12
-
13
-
14
- class RAGGCEnvironment(Environment[RAGGCAction, RAGGCObservation, RAGGCState]):
15
- SUPPORTS_CONCURRENT_SESSIONS = True
16
-
17
- def __init__(self) -> None:
18
- super().__init__(transform=None, rubric=None)
19
- self._state = RAGGCState(episode_id=str(uuid4()), step_count=0)
20
- self._task: TaskSpec = task_by_seed(0)
21
- self._docs: dict[str, DocumentItem] = {}
22
- self._removed_critical = False
23
-
24
- def _load_task(self, spec: TaskSpec) -> None:
25
- self._docs = {}
26
- for did, text, tok, _meta in spec.documents:
27
- self._docs[did] = DocumentItem(document_id=did, text=text, tokens=tok)
28
-
29
- def reset(
30
- self,
31
- seed: Optional[int] = None,
32
- episode_id: Optional[str] = None,
33
- task_name: Optional[str] = None,
34
- **kwargs: Any,
35
- ) -> RAGGCObservation:
36
- self._reset_rubric()
37
- sid = episode_id or str(uuid4())
38
- if task_name and task_name in ALL_TASKS:
39
- self._task = ALL_TASKS[task_name]
40
- elif seed is not None:
41
- self._task = task_by_seed(int(seed))
42
- else:
43
- self._task = task_by_seed(0)
44
- self._load_task(self._task)
45
- self._removed_critical = False
46
- self._state = RAGGCState(
47
- episode_id=sid,
48
- step_count=0,
49
- task_name=self._task.name,
50
- max_steps=64,
51
- removed_critical=False,
52
- submitted=False,
53
- )
54
- return self._observe(done=False, reward_value=0.0, msg="ready")
55
-
56
- def _total_tokens(self) -> int:
57
- return sum(d.tokens for d in self._docs.values())
58
-
59
- def _observe(
60
- self,
61
- done: bool,
62
- reward_value: float,
63
- msg: str,
64
- reward_detail: Optional[RAGGCReward] = None,
65
- grader_score: Optional[float] = None,
66
- ) -> RAGGCObservation:
67
- docs = sorted(self._docs.values(), key=lambda x: x.document_id)
68
- return RAGGCObservation(
69
- done=done,
70
- reward=reward_value,
71
- query=self._task.query,
72
- documents=docs,
73
- token_count=self._total_tokens(),
74
- token_budget=self._task.token_budget,
75
- task_name=self._task.name,
76
- message=msg,
77
- grader_score=grader_score,
78
- reward_detail=reward_detail,
79
- metadata={
80
- "relevance": {
81
- row[0]: row[3].get("relevance", 0.5)
82
- for row in self._task.documents
83
- if row[0] in self._docs
84
- },
85
- "hints": {row[0]: row[3].get("hint", "") for row in self._task.documents},
86
- },
87
- )
88
-
89
- def step(
90
- self,
91
- action: RAGGCAction,
92
- timeout_s: Optional[float] = None,
93
- **kwargs: Any,
94
- ) -> RAGGCObservation:
95
- self._state.step_count += 1
96
- docs_before = dict(self._docs)
97
-
98
- if action.verb == "submit":
99
- score = grade_context(self._task, list(self._docs.values()))
100
- self._state.submitted = True
101
- r = RAGGCReward(
102
- step_reward=score,
103
- final_score=score,
104
- )
105
- obs = self._observe(
106
- done=True,
107
- reward_value=score,
108
- msg="submitted",
109
- reward_detail=r,
110
- grader_score=score,
111
- )
112
- return self._apply_transform(obs)
113
-
114
- if action.document_id is None or action.document_id not in self._docs:
115
- obs = self._observe(
116
- done=False,
117
- reward_value=-0.1,
118
- msg="unknown_document",
119
- )
120
- return self._apply_transform(obs)
121
-
122
- did = action.document_id
123
- removed_critical = False
124
-
125
- if action.verb == "delete":
126
- if did in self._task.critical_document_ids:
127
- self._removed_critical = True
128
- removed_critical = True
129
- self._docs.pop(did, None)
130
-
131
- elif action.verb == "keep":
132
- pass
133
-
134
- elif action.verb == "summarize":
135
- item = self._docs[did]
136
- new_text, new_tok = summarize_deterministic(item.text)
137
- self._docs[did] = DocumentItem(
138
- document_id=did,
139
- text=new_text,
140
- tokens=new_tok,
141
- )
142
- if did in self._task.critical_document_ids:
143
- for p in self._task.required_phrases:
144
- if p not in new_text:
145
- self._removed_critical = True
146
- removed_critical = True
147
-
148
- rdetail = step_reward(
149
- self._task,
150
- action.verb,
151
- did,
152
- docs_before,
153
- self._docs,
154
- removed_critical,
155
- )
156
- self._state.removed_critical = self._removed_critical
157
-
158
- over = self._total_tokens() > self._task.token_budget
159
- if over:
160
- penalty = -0.08 * (self._total_tokens() - self._task.token_budget)
161
- rdetail.token_penalty += penalty
162
- rdetail.step_reward += penalty
163
- done = self._state.step_count >= self._state.max_steps
164
- final_score: Optional[float] = None
165
- if done:
166
- final_score = grade_context(self._task, list(self._docs.values()))
167
- rdetail.final_score = final_score
168
- rdetail.step_reward += final_score * 0.5
169
-
170
- reward_val = rdetail.step_reward
171
- if done:
172
- # When done, the reward is primarily the final grader score,
173
- # but we can preserve the step-specific bonus we added.
174
- # final_score is the main signal.
175
- reward_val = final_score if final_score is not None else rdetail.step_reward
176
- obs = self._observe(
177
- done=done,
178
- reward_value=reward_val,
179
- msg="over_budget" if over else ("graded" if done else "ok"),
180
- reward_detail=rdetail,
181
- grader_score=final_score if done else None,
182
- )
183
- return self._apply_transform(obs)
184
-
185
- @property
186
- def state(self) -> RAGGCState:
187
- return self._state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag_gc_env/grader.py DELETED
@@ -1,43 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from rag_gc_env.models import DocumentItem
4
- from rag_gc_env.tasks import TaskSpec
5
-
6
-
7
- def _joined_text(docs: list[DocumentItem]) -> str:
8
- return " ".join(d.text for d in docs)
9
-
10
-
11
- def grade_context(task: TaskSpec, final_documents: list[DocumentItem]) -> float:
12
- """
13
- Deterministic score in [0.0, 0.5, 1.0]:
14
- 1.0 — required facts present, budget respected, efficient (near optimal tokens)
15
- 0.5 — required facts present but inefficient or borderline budget
16
- 0.0 — missing facts, forbidden content present, or critical docs removed incorrectly
17
- """
18
- text = _joined_text(final_documents)
19
- total_tokens = sum(d.tokens for d in final_documents)
20
-
21
- for phrase in task.required_phrases:
22
- if phrase not in text:
23
- return 0.0
24
-
25
- for phrase in task.forbidden_phrases:
26
- if phrase in text:
27
- return 0.0
28
-
29
- for pid in task.poison_document_ids:
30
- still = any(d.document_id == pid for d in final_documents)
31
- if still:
32
- return 0.0
33
-
34
- if total_tokens > task.token_budget:
35
- return 0.0
36
-
37
- if not task.critical_document_ids.issubset({d.document_id for d in final_documents}):
38
- return 0.0
39
-
40
- if total_tokens <= task.optimal_max_tokens:
41
- return 1.0
42
-
43
- return 0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag_gc_env/inference.py DELETED
@@ -1,49 +0,0 @@
1
- """
2
- Reproducible baseline policy for Adaptive Context Optimization (RAG GC).
3
- Deterministic: fixed action sequences per task derived from metadata.
4
- """
5
-
6
- from __future__ import annotations
7
-
8
- from rag_gc_env.environment import RAGGCEnvironment
9
- from rag_gc_env.models import RAGGCAction
10
-
11
-
12
- def run_baseline(task_name: str, seed: int = 0) -> tuple[float, list[str]]:
13
- env = RAGGCEnvironment()
14
- obs = env.reset(seed=seed, task_name=task_name)
15
- log: list[str] = ["reset"]
16
-
17
- def step(verb: str, doc_id: str | None) -> None:
18
- nonlocal obs
19
- obs = env.step(RAGGCAction(verb=verb, document_id=doc_id))
20
- log.append(f"{verb}:{doc_id}")
21
-
22
- if task_name == "easy_irrelevant_removal":
23
- step("delete", "d1")
24
- step("submit", None)
25
- elif task_name == "medium_token_compression":
26
- step("delete", "m2")
27
- while obs.token_count > obs.token_budget and not obs.done:
28
- step("summarize", "m0")
29
- if len(log) > 40:
30
- break
31
- step("submit", None)
32
- elif task_name == "hard_contradiction_removal":
33
- step("delete", "h1")
34
- step("submit", None)
35
- else:
36
- step("submit", None)
37
-
38
- score = float(obs.grader_score or obs.reward or 0.0)
39
- return score, log
40
-
41
-
42
- if __name__ == "__main__":
43
- for name in (
44
- "easy_irrelevant_removal",
45
- "medium_token_compression",
46
- "hard_contradiction_removal",
47
- ):
48
- s, lg = run_baseline(name, seed=0)
49
- print(name, "score=", s, "trace=", lg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag_gc_env/models.py DELETED
@@ -1,53 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any, Literal, Optional
4
-
5
- from openenv.core.env_server.types import Action, Observation, State
6
- from pydantic import BaseModel, Field
7
-
8
-
9
- class DocumentItem(BaseModel):
10
- document_id: str
11
- text: str
12
- tokens: int = Field(description="Estimated tokens for this snippet")
13
-
14
-
15
- class RAGGCAction(Action):
16
- verb: Literal["keep", "delete", "summarize", "submit"] = Field(
17
- description="Document operation or submit to finalize and grade"
18
- )
19
- document_id: Optional[str] = Field(
20
- default=None,
21
- description="Target document for keep/delete/summarize; omit for submit",
22
- )
23
-
24
-
25
- class RAGGCReward(BaseModel):
26
- step_reward: float = 0.0
27
- relevance: float = 0.0
28
- compression: float = 0.0
29
- token_penalty: float = 0.0
30
- critical_penalty: float = 0.0
31
- final_score: Optional[float] = Field(
32
- default=None, description="0.0–1.0 after submit; aligns with grader"
33
- )
34
-
35
-
36
- class RAGGCObservation(Observation):
37
- query: str = ""
38
- documents: list[DocumentItem] = Field(default_factory=list)
39
- token_count: int = 0
40
- token_budget: int = 0
41
- task_name: str = ""
42
- reward_detail: Optional[RAGGCReward] = None
43
- message: str = ""
44
- grader_score: Optional[float] = Field(
45
- default=None, description="Deterministic score after episode ends"
46
- )
47
-
48
-
49
- class RAGGCState(State):
50
- task_name: str = ""
51
- max_steps: int = 64
52
- removed_critical: bool = False
53
- submitted: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag_gc_env/rewards.py DELETED
@@ -1,89 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from rag_gc_env.models import DocumentItem, RAGGCReward
4
- from rag_gc_env.tasks import TaskSpec
5
-
6
-
7
- def summarize_deterministic(text: str) -> tuple[str, int]:
8
- """Deterministic compression: first sentence or capped prefix."""
9
- stripped = text.strip()
10
- if not stripped:
11
- return "", 1
12
- cut = stripped.split(". ")
13
- first = cut[0] + ("." if not cut[0].endswith(".") else "")
14
- if len(first) < 40 and len(cut) > 1:
15
- first = cut[0] + ". " + cut[1] + ("." if not cut[1].endswith(".") else "")
16
- cap = 280
17
- out = first[:cap] + ("..." if len(first) > cap else "")
18
- tokens = max(1, len(out) // 4)
19
- return out, tokens
20
-
21
-
22
- def estimate_tokens(text: str) -> int:
23
- return max(1, len(text) // 4)
24
-
25
-
26
- def step_reward(
27
- task: TaskSpec,
28
- verb: str,
29
- doc_id: str | None,
30
- docs_before: dict[str, DocumentItem],
31
- docs_after: dict[str, DocumentItem],
32
- removed_critical_flag: bool,
33
- ) -> RAGGCReward:
34
- rel = 0.0
35
- comp = 0.0
36
- tok_pen = 0.0
37
- crit = 0.0
38
-
39
- if removed_critical_flag:
40
- crit = -3.0
41
-
42
- if verb == "delete" and doc_id in docs_before:
43
- meta = next(
44
- (m for did, _, _, m in task.documents if did == doc_id),
45
- {},
46
- )
47
- # Reward deleting irrelevant or poison documents
48
- if doc_id in task.irrelevant_document_ids:
49
- rel += 0.4
50
- elif doc_id in task.poison_document_ids:
51
- rel += 0.6
52
- elif doc_id in task.critical_document_ids:
53
- crit -= 3.0
54
- elif meta.get("hint") == "fluff":
55
- rel += 0.2
56
-
57
- # Deleting tokens should NOT result in a penalty proportional to the deleted tokens;
58
- # instead, it removes the 'keep' penalty they would have incurred.
59
- # We can add a small constant 'action cost' for deleting if desired, but 0.0 is fine here.
60
- tok_pen = 0.0
61
-
62
- if verb == "summarize" and doc_id in docs_before:
63
- before_t = docs_before[doc_id].tokens
64
- after = docs_after.get(doc_id)
65
- if after is not None:
66
- # Reward for the reduction in size (efficiency)
67
- reduction_ratio = (before_t - after.tokens) / max(before_t, 1)
68
- comp += 0.3 * max(0.0, reduction_ratio)
69
-
70
- # The remaining tokens still incur a small penalty
71
- tok_pen -= 0.01 * after.tokens
72
-
73
- if doc_id in task.critical_document_ids:
74
- for p in task.required_phrases:
75
- if p not in after.text:
76
- crit -= 2.5
77
-
78
- if verb == "keep" and doc_id in docs_before:
79
- # Standard penalty for keeping tokens in context
80
- tok_pen -= 0.01 * docs_before[doc_id].tokens
81
-
82
- step = rel + comp + tok_pen + crit
83
- return RAGGCReward(
84
- step_reward=step,
85
- relevance=rel,
86
- compression=comp,
87
- token_penalty=tok_pen,
88
- critical_penalty=crit,
89
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag_gc_env/server/__init__.py DELETED
@@ -1 +0,0 @@
1
- # Server package for OpenEnv HTTP deployment
 
 
rag_gc_env/server/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (154 Bytes)
 
rag_gc_env/server/__pycache__/app.cpython-311.pyc DELETED
Binary file (1.04 kB)
 
rag_gc_env/server/app.py DELETED
@@ -1,23 +0,0 @@
1
- import os
2
-
3
- from openenv.core.env_server.http_server import create_fastapi_app
4
-
5
- from rag_gc_env.environment import RAGGCEnvironment
6
- from rag_gc_env.models import RAGGCAction, RAGGCObservation
7
-
8
- app = create_fastapi_app(
9
- RAGGCEnvironment,
10
- RAGGCAction,
11
- RAGGCObservation,
12
- )
13
-
14
-
15
- def main() -> None:
16
- import uvicorn
17
-
18
- port = int(os.environ.get("PORT", "8000"))
19
- uvicorn.run(app, host="0.0.0.0", port=port)
20
-
21
-
22
- if __name__ == "__main__":
23
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag_gc_env/tasks.py DELETED
@@ -1,144 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass, field
4
- from typing import Any, FrozenSet
5
-
6
-
7
- @dataclass(frozen=True)
8
- class TaskSpec:
9
- name: str
10
- query: str
11
- token_budget: int
12
- documents: list[tuple[str, str, int, dict[str, Any]]]
13
- # document_id, text, tokens, metadata (relevance, flags)
14
- required_phrases: FrozenSet[str] = field(default_factory=frozenset)
15
- forbidden_phrases: FrozenSet[str] = field(default_factory=frozenset)
16
- critical_document_ids: FrozenSet[str] = field(default_factory=frozenset)
17
- irrelevant_document_ids: FrozenSet[str] = field(default_factory=frozenset)
18
- poison_document_ids: FrozenSet[str] = field(default_factory=frozenset)
19
- optimal_max_tokens: int = 0
20
-
21
-
22
- def _docs(
23
- rows: list[tuple[str, str, int, dict[str, Any]]],
24
- ) -> list[tuple[str, str, int, dict[str, Any]]]:
25
- return rows
26
-
27
-
28
- TASK_EASY = TaskSpec(
29
- name="easy_irrelevant_removal",
30
- query="What is the capital city of France?",
31
- token_budget=400,
32
- documents=_docs(
33
- [
34
- (
35
- "d0",
36
- "Paris has been the capital of France since political centralization in the country.",
37
- 24,
38
- {"relevance": 0.95, "hint": "high"},
39
- ),
40
- (
41
- "d1",
42
- "Penguins thrive in Antarctica and are unrelated to European geography.",
43
- 18,
44
- {"relevance": 0.08, "hint": "noise"},
45
- ),
46
- (
47
- "d2",
48
- "Lyon is a major French city but not the national capital.",
49
- 16,
50
- {"relevance": 0.55, "hint": "partial"},
51
- ),
52
- ]
53
- ),
54
- required_phrases=frozenset({"Paris"}),
55
- forbidden_phrases=frozenset(),
56
- critical_document_ids=frozenset({"d0"}),
57
- irrelevant_document_ids=frozenset({"d1"}),
58
- poison_document_ids=frozenset(),
59
- optimal_max_tokens=120,
60
- )
61
-
62
- _LONG_DUP = (
63
- "Paris is the capital of France. " * 18
64
- + "This repetition exists only to inflate token usage for compression tests."
65
- )
66
-
67
- TASK_MEDIUM = TaskSpec(
68
- name="medium_token_compression",
69
- query="Which city is the capital of France?",
70
- token_budget=120,
71
- documents=_docs(
72
- [
73
- (
74
- "m0",
75
- _LONG_DUP,
76
- max(1, len(_LONG_DUP) // 4),
77
- {"relevance": 0.9, "hint": "verbose"},
78
- ),
79
- (
80
- "m1",
81
- "Administrative records list a capital city but this line omits the name intentionally.",
82
- 14,
83
- {"relevance": 0.55, "hint": "no_answer"},
84
- ),
85
- (
86
- "m2",
87
- "French cuisine is diverse; it does not change which city is the capital.",
88
- 14,
89
- {"relevance": 0.35, "hint": "fluff"},
90
- ),
91
- ]
92
- ),
93
- required_phrases=frozenset({"Paris"}),
94
- forbidden_phrases=frozenset(),
95
- critical_document_ids=frozenset(),
96
- irrelevant_document_ids=frozenset({"m2"}),
97
- poison_document_ids=frozenset(),
98
- optimal_max_tokens=90,
99
- )
100
-
101
- TASK_HARD = TaskSpec(
102
- name="hard_contradiction_removal",
103
- query="Compute the integer result of 17 + 25.",
104
- token_budget=350,
105
- documents=_docs(
106
- [
107
- (
108
- "h0",
109
- "Arithmetic trace: 17 + 25 = 42.",
110
- 12,
111
- {"relevance": 0.9, "hint": "consistent", "trust": 0.95},
112
- ),
113
- (
114
- "h1",
115
- "Quick math note: 17 + 25 equals 43 for budgeting purposes.",
116
- 14,
117
- {"relevance": 0.88, "hint": "poison", "trust": 0.2},
118
- ),
119
- (
120
- "h2",
121
- "Addition of integers is associative and commutative.",
122
- 10,
123
- {"relevance": 0.4, "hint": "generic"},
124
- ),
125
- ]
126
- ),
127
- required_phrases=frozenset({"42"}),
128
- forbidden_phrases=frozenset({"43"}),
129
- critical_document_ids=frozenset({"h0"}),
130
- irrelevant_document_ids=frozenset(),
131
- poison_document_ids=frozenset({"h1"}),
132
- optimal_max_tokens=200,
133
- )
134
-
135
- ALL_TASKS: dict[str, TaskSpec] = {
136
- TASK_EASY.name: TASK_EASY,
137
- TASK_MEDIUM.name: TASK_MEDIUM,
138
- TASK_HARD.name: TASK_HARD,
139
- }
140
-
141
-
142
- def task_by_seed(seed: int) -> TaskSpec:
143
- order = [TASK_EASY, TASK_MEDIUM, TASK_HARD]
144
- return order[seed % 3]