prithic07 commited on
Commit
222f8ce
·
1 Parent(s): f7594d7

refactor: Migrate to Gemini 1.5 Flash exclusively for pruning and validation

Browse files
Files changed (6) hide show
  1. .dockerignore +14 -7
  2. .gitignore +42 -0
  3. README.md +3 -0
  4. app_ui.py +155 -0
  5. inference.py +23 -15
  6. requirements.txt +4 -1
.dockerignore CHANGED
@@ -1,8 +1,15 @@
1
- __pycache__
 
 
 
2
  *.pyc
3
- .git
4
- .venv
5
- venv
6
- *.md
7
- .pytest_cache
8
- .mypy_cache
 
 
 
 
 
1
+ # Docker Ignore
2
+ .env
3
+ .git/
4
+ __pycache__/
5
  *.pyc
6
+ .pytest_cache/
7
+ .vscode/
8
+ .idea/
9
+ venv/
10
+ .venv/
11
+ README.md
12
+ walkthrough.md
13
+ task.md
14
+ implementation_plan.md
15
+ # NOTE(review): machine-specific absolute path — .dockerignore entries are
+ # matched relative to the build context, so this line has no effect and it
+ # leaks a local username; remove it.
+ C:/Users/prith/.gemini/antigravity/brain/
.gitignore ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # API Keys & Sensitive Info
2
+ .env
3
+ *.pem
4
+ *.key
5
+
6
+ # Python
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+ *.so
11
+ .Python
12
+ env/
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+
29
+ # Virtual Environments
30
+ venv/
31
+ .venv/
32
+ ENV/
33
+
34
+ # Testing
35
+ .pytest_cache/
36
+ .coverage
37
+ htmlcov/
38
+
39
+ # IDEs
40
+ .vscode/
41
+ .idea/
42
+ .DS_Store
README.md CHANGED
@@ -20,6 +20,9 @@
20
  # Install dependencies
21
  pip install -r requirements.txt
22
 
 
 
 
23
  # Verify the environment and task logic
24
  pytest test_tasks.py
25
  ```
 
20
  # Install dependencies
21
  pip install -r requirements.txt
22
 
23
+ # Set your Gemini API Key
24
+ export GOOGLE_API_KEY=your_key_here
25
+
26
  # Verify the environment and task logic
27
  pytest test_tasks.py
28
  ```
app_ui.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import asyncio
5
+ import gradio as gr
6
+ import google.generativeai as genai
7
+ from dotenv import load_dotenv
8
+
9
+ # Load API keys from .env
10
+ load_dotenv()
11
+ from typing import List, Tuple
12
+ from context_pruning_env.utils import count_tokens
13
+
14
+ # --- Configuration ---
15
+ # Set these in your environment or replace with mock keys for testing
16
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
17
+ if GOOGLE_API_KEY:
18
+ genai.configure(api_key=GOOGLE_API_KEY)
19
+
20
+ # --- Core Logic ---
21
+
22
async def call_gemini(prompt: str, model_name: str = "gemini-1.5-flash") -> str:
    """Send *prompt* to the Gemini API and return the response text.

    Failures (missing API key, SDK/transport errors, blocked responses) are
    reported as a string starting with "ERROR:" instead of being raised, so
    the UI can display them directly.
    """
    if not GOOGLE_API_KEY:
        return "ERROR: GOOGLE_API_KEY not found."
    try:
        # .text is accessed inside the try: it can raise for blocked responses.
        client = genai.GenerativeModel(model_name)
        reply = await client.generate_content_async(prompt)
        return reply.text
    except Exception as exc:
        return f"ERROR: {str(exc)}"
32
+
33
+ def chunk_text(text: str, max_chunks: int = 5) -> List[str]:
34
+ """Split text into manageable chunks (paragraphs or sentences)."""
35
+ # Split by double newline first
36
+ chunks = [c.strip() for c in re.split(r'\n\s*\n', text) if c.strip()]
37
+ if len(chunks) < 2:
38
+ # Split by sentence if only one paragraph
39
+ chunks = [c.strip() for c in re.split(r'(?<=[.!?])\s+', text) if c.strip()]
40
+
41
+ # Simple limit to 5-10 chunks for the demo
42
+ return chunks[:10]
43
+
44
async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
    """Chunk raw context, ask Gemini which chunks to keep, and score the result.

    Pipeline: chunk the text -> LLM selects relevant chunk indices ->
    reassemble kept chunks -> compute size metrics -> run a groundedness
    check on the pruned context.

    Args:
        query: The user's question.
        raw_text: The raw (possibly noisy) context to prune.

    Returns:
        Tuple of (optimized context text, metrics dict, groundedness verdict).
    """
    if not query or not raw_text:
        return "Please provide both query and raw context.", {}, ""

    chunks = chunk_text(raw_text)

    # Ask the model for a JSON list of chunk indices to keep.
    selection_prompt = (
        f"Query: {query}\n\n"
        "Below are several context chunks. Identify which are RELEVANT and which are NOISE or DUPLICATES. "
        "Output a JSON list of indices (0-indexed) of the chunks to KEEP.\n"
        "Example output: [0, 2, 3]\n\n"
        "Chunks:\n"
    )
    for i, c in enumerate(chunks):
        selection_prompt += f"Chunk {i}: {c}\n\n"

    raw_response = await call_gemini(selection_prompt)

    # Extract the bracketed index list from the free-form model output.
    # Fall back to keeping everything if parsing fails or selects nothing.
    kept_chunks = chunks
    match = re.search(r"\[([\d\s,]+)\]", raw_response)
    if match:
        try:
            indices = json.loads(f"[{match.group(1)}]")
        except ValueError:
            # Bug fix: was a bare ``except:``, which also swallowed
            # KeyboardInterrupt/SystemExit. json.loads failures raise
            # JSONDecodeError, a ValueError subclass.
            indices = None
        if indices:
            # The regex admits only non-negative integers, but guard both
            # bounds so out-of-range indices cannot raise IndexError.
            selected = [chunks[i] for i in indices if 0 <= i < len(chunks)]
            # Robustness fix: previously an all-invalid/empty selection
            # produced an empty optimized context; keep everything instead.
            if selected:
                kept_chunks = selected

    optimized_text = " ".join(kept_chunks)

    # Size metrics (labels mirror the UI fields; count_tokens comes from
    # context_pruning_env.utils).
    orig_tokens = count_tokens(raw_text)
    final_tokens = count_tokens(optimized_text)
    reduction = ((orig_tokens - final_tokens) / orig_tokens * 100) if orig_tokens > 0 else 0

    metrics = {
        "Original Word Count": f"{orig_tokens} words",
        "Final Word Count": f"{final_tokens} words",
        "Reduction": f"{reduction:.1f}%"
    }

    # Second LLM pass: verify the pruned context can still answer the query.
    groundedness_prompt = (
        f"Question: {query}\n"
        f"Context: {optimized_text}\n\n"
        "Task: Check if the context contains enough information to answer the question. "
        "Respond with 'PASS' or 'FAIL' followed by a one-sentence reasoning."
    )
    ground_result = await call_gemini(groundedness_prompt)

    return optimized_text, metrics, ground_result
100
+
101
+ # --- UI Components ---
102
+
103
def get_status_html(result: str):
    """Render the groundedness verdict as a colored HTML banner.

    A verdict containing "PASS" gets a green banner, "FAIL" a red one
    (PASS wins if both appear); anything else is shown in a neutral box.
    """
    verdict = result.upper()
    if "PASS" in verdict:
        body = result.replace("PASS", "").strip()
        style = "background-color: #d1fae5; color: #065f46; padding: 10px; border-radius: 8px; border: 1px solid #10b981; font-weight: bold;"
        return f'<div style="{style}">✅ GROUNDEDNESS PASS: {body}</div>'
    if "FAIL" in verdict:
        body = result.replace("FAIL", "").strip()
        style = "background-color: #fee2e2; color: #991b1b; padding: 10px; border-radius: 8px; border: 1px solid #ef4444; font-weight: bold;"
        return f'<div style="{style}">❌ GROUNDEDNESS FAIL: {body}</div>'
    return f'<div style="background-color: #f3f4f6; padding: 10px; border-radius: 8px;">{result}</div>'
109
+
110
# --- UI layout: two-column Gradio app wired to the async pruning pipeline ---
with gr.Blocks(theme=gr.themes.Soft(), title="ContextPrune | Adaptive Context Optimization") as demo:
    gr.Markdown("""
    # 🧠 ContextPrune
    ### Adaptive Context Optimization Agent
    *Reduce noise and tokens in RAG pipelines while preserving answer quality.*
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Inputs: query + noisy context, pre-filled with a demo example.
            query_input = gr.Textbox(
                label="User Query",
                placeholder="e.g., When was the Eiffel Tower built?",
                value="Who was the first person to walk on the moon?",
            )
            context_input = gr.Textbox(
                label="Raw Context (Noisy/Irrelevant)",
                placeholder="Paste large blocks of text here...",
                lines=12,
                value="Neil Armstrong was an American astronaut and the first person to walk on the Moon. He was also a naval aviator, test pilot, and university professor. [IGNORE THIS] The sky is sometimes blue but often grey in London. Neil Armstrong set foot on the moon in 1969. Some say the moon is made of cheese, but that is a myth. Neil Armstrong was the first person to walk on the moon.",
            )
            submit_btn = gr.Button("Optimize Context", variant="primary")

        with gr.Column(scale=1):
            # Outputs: pruned context plus the groundedness verdict banner.
            optimized_output = gr.Textbox(label="Optimized Context", lines=10, interactive=False)
            status_output = gr.HTML(label="Groundedness Check")

    with gr.Row():
        word_count_orig = gr.Label(label="Original Word Count")
        word_count_final = gr.Label(label="Final Word Count")
        reduction_pct = gr.Label(label="% Token Reduction")

    def process(query, context):
        """Synchronous Gradio callback: run the async pipeline, unpack outputs.

        Bug fix: the previous version created a fresh event loop per click via
        ``asyncio.new_event_loop()`` and never closed it, leaking a loop (and
        its resources) on every submission. ``asyncio.run()`` creates and
        closes the loop for us.
        """
        opt_text, metrics, ground = asyncio.run(prune_context(query, context))

        status_html = get_status_html(ground)

        return (
            opt_text,
            status_html,
            metrics.get("Original Word Count", "0"),
            metrics.get("Final Word Count", "0"),
            metrics.get("Reduction", "0%"),
        )

    submit_btn.click(
        process,
        inputs=[query_input, context_input],
        outputs=[optimized_output, status_output, word_count_orig, word_count_final, reduction_pct],
    )

if __name__ == "__main__":
    # Fixed port so the local demo URL is predictable.
    demo.launch(server_port=7861)
inference.py CHANGED
@@ -1,18 +1,31 @@
1
  import os
2
  import json
3
  import logging
4
- from openai import OpenAI
 
 
5
  from context_pruning_env.env import ContextPruningEnv
 
 
 
6
  from context_pruning_env.models import ContextAction
7
 
8
  # Setup simple logging
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
 
 
 
 
 
12
  def main():
13
- # 1. Setup OpenAI Client
14
- # Ensure OPENAI_API_KEY is set in your environment
15
- client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key-here"))
 
 
 
16
 
17
  # 2. Initialize Environment
18
  env = ContextPruningEnv(squad_split="train")
@@ -27,11 +40,10 @@ def main():
27
  obs = env.reset(task_name=task_name)
28
  print(f"<OBSERVATION>{obs.model_dump_json()}</OBSERVATION>")
29
 
30
- # 4. Agent Logic (LLM Call)
31
- # Construct prompt for the model
32
  prompt = (
33
  f"Question: {obs.question}\n\n"
34
- "Below are 5 context chunks. Output a JSON list of 5 integers (0 or 1) "
35
  "where 1 means 'keep' and 0 means 'prune'. "
36
  "Prioritize keeping the answer while removing noise and duplicates.\n"
37
  f"Chunks: {json.dumps(obs.chunks, indent=2)}\n\n"
@@ -39,22 +51,18 @@ def main():
39
  )
40
 
41
  try:
42
- response = client.chat.completions.create(
43
- model="gpt-4o", # or gpt-3.5-turbo
44
- messages=[{"role": "user", "content": prompt}]
45
- )
46
- completion = response.choices[0].message.content
47
 
48
  # Simple extraction of the mask [x,x,x,x,x]
49
- import re
50
  match = re.search(r"\[\s*([01])\s*,\s*([01])\s*,\s*([01])\s*,\s*([01])\s*,\s*([01])\s*\]", completion)
51
  if match:
52
  mask = [int(m) for m in match.groups()]
53
  else:
54
- logger.warning("Failed to parse mask from LLM output, falling back to [1,1,1,1,1]")
55
  mask = [1, 1, 1, 1, 1]
56
  except Exception as e:
57
- logger.error(f"LLM Inference failed: {e}")
58
  mask = [1, 1, 1, 1, 1]
59
 
60
  # 5. Take Action
 
1
  import os
2
  import json
3
  import logging
4
+ import re
5
+ import google.generativeai as genai
6
+ from dotenv import load_dotenv
7
  from context_pruning_env.env import ContextPruningEnv
8
+
9
+ # Load API keys from .env
10
+ load_dotenv()
11
  from context_pruning_env.models import ContextAction
12
 
13
  # Setup simple logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
+ # Configure Gemini
18
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
19
+ if GOOGLE_API_KEY:
20
+ genai.configure(api_key=GOOGLE_API_KEY)
21
+
22
  def main():
23
+ if not GOOGLE_API_KEY:
24
+ logger.error("GOOGLE_API_KEY not found in environment or .env file.")
25
+ return
26
+
27
+ # 1. Setup Gemini Model
28
+ model = genai.GenerativeModel("gemini-1.5-flash")
29
 
30
  # 2. Initialize Environment
31
  env = ContextPruningEnv(squad_split="train")
 
40
  obs = env.reset(task_name=task_name)
41
  print(f"<OBSERVATION>{obs.model_dump_json()}</OBSERVATION>")
42
 
43
+ # 4. Agent Logic (Gemini Call)
 
44
  prompt = (
45
  f"Question: {obs.question}\n\n"
46
+ "Below are 5 context chunks. Output ONLY a JSON list of 5 integers (0 or 1) "
47
  "where 1 means 'keep' and 0 means 'prune'. "
48
  "Prioritize keeping the answer while removing noise and duplicates.\n"
49
  f"Chunks: {json.dumps(obs.chunks, indent=2)}\n\n"
 
51
  )
52
 
53
  try:
54
+ response = model.generate_content(prompt)
55
+ completion = response.text
 
 
 
56
 
57
  # Simple extraction of the mask [x,x,x,x,x]
 
58
  match = re.search(r"\[\s*([01])\s*,\s*([01])\s*,\s*([01])\s*,\s*([01])\s*,\s*([01])\s*\]", completion)
59
  if match:
60
  mask = [int(m) for m in match.groups()]
61
  else:
62
+ logger.warning(f"Failed to parse mask from Gemini output: {completion}. Falling back to [1,1,1,1,1]")
63
  mask = [1, 1, 1, 1, 1]
64
  except Exception as e:
65
+ logger.error(f"Gemini Inference failed: {e}")
66
  mask = [1, 1, 1, 1, 1]
67
 
68
  # 5. Take Action
requirements.txt CHANGED
@@ -7,5 +7,8 @@ datasets>=2.15.0
7
  transformers>=4.35.0
8
  trl>=0.7.4
9
  torch>=2.1.0
10
- openai>=1.5.0
11
  pytest>=7.4.0
 
 
 
 
7
  transformers>=4.35.0
8
  trl>=0.7.4
9
  torch>=2.1.0
10
+ python-dotenv>=1.0.0
11
  pytest>=7.4.0
12
+ gradio>=4.0.0
13
+ google-generativeai>=0.3.0
14
+ python-dotenv>=1.0.0  # NOTE(review): duplicate — python-dotenv is already listed on line 10 above; keep only one entry