# Hugging Face Space (scrape residue: "Spaces — Sleeping" status header)
import os
import re
import json
import time
import asyncio
import gradio as gr
from google import genai
from dotenv import load_dotenv
from typing import List, Tuple
from context_pruning_env.utils import count_tokens

# Load API keys from .env
load_dotenv()

# --- Configuration ---
# Either env var name is accepted for the Gemini key.
# NOTE(review): the client is constructed even when API_KEY is None — API
# calls are guarded by the API_KEY check inside call_gemini_with_retry.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
client = genai.Client(api_key=API_KEY)

# Fallback sequence for 2026 availability & quota limits
# MODEL_NAME env var overrides the primary model; the rest are tried in order.
MODEL_SEQUENCE = [
    os.environ.get("MODEL_NAME", "gemini-2.0-flash"),
    "gemini-2.5-flash",
    "gemini-3.1-flash-live-preview",
    "gemini-1.5-flash-8b"
]
def call_gemini_with_retry(prompt: str) -> str:
    """Call Gemini with exponential backoff and model fallback.

    Each model in MODEL_SEQUENCE is tried in order.  Within one model,
    429/quota errors are retried with exponential backoff; any other
    error abandons that model immediately and moves to the next one.

    Args:
        prompt: The full text prompt to send.

    Returns:
        The model's text response, or a string starting with "ERROR:"
        when no key is configured or every model failed.
    """
    if not API_KEY:
        return "ERROR: API Key not found."
    for model_name in MODEL_SEQUENCE:
        retries = 2
        backoff = 3
        for attempt in range(retries):
            try:
                response = client.models.generate_content(
                    model=model_name,
                    config={
                        'temperature': 0.1,
                        'top_p': 0.95,
                        'max_output_tokens': 512,
                    },
                    contents=prompt,
                )
                # An empty/blocked response falls through and retries.
                if response and response.text:
                    return response.text
            except Exception as e:
                err_str = str(e).lower()
                if "429" in err_str or "quota" in err_str:
                    # Only sleep if another attempt remains; previously the
                    # code also slept after the final attempt, delaying the
                    # fallback to the next model for no benefit.
                    if attempt < retries - 1:
                        time.sleep(backoff)
                        backoff *= 2
                else:
                    break  # Non-quota error: try next model.
    return "ERROR: All models hit quota or failed."
def chunk_text(text: str, max_chunks: int = 20) -> List[str]:
    """Break text into paragraphs, then sentences, capped at max_chunks."""
    pieces: List[str] = []
    # First split on blank lines (paragraphs), then split each paragraph
    # on sentence-ending punctuation or single newlines.
    for para in re.split(r'\n\s*\n', text):
        para = para.strip()
        if not para:
            continue
        for sent in re.split(r'(?<=[.!?])\s+|\n', para):
            sent = sent.strip()
            if sent:
                pieces.append(sent)
    return pieces[:max_chunks]
async def prune_context(query: str, raw_text: str) -> Tuple[str, dict, str]:
    """Prune a noisy document down to query-relevant chunks via Gemini.

    Asks the model for a binary keep/drop mask over the chunks, joins the
    kept chunks, then runs a second "groundedness" check on the result.

    Args:
        query: The user's question.
        raw_text: The noisy document to compress.

    Returns:
        (optimized_text, metrics_dict, groundedness_result) — metrics_dict
        is empty and groundedness_result is "" or "FAIL" on early exits.
    """
    if not query or not raw_text:
        return "Please provide both.", {}, ""
    chunks = chunk_text(raw_text)
    selection_prompt = (
        f"Query: {query}\n\n"
        "TASK: AGGRESSIVE CONTEXT OPTIMIZATION. "
        "Goal: TOKEN REDUCTION. Prune noise and keep ONLY essential info.\n"
        f"OUTPUT: Output EXACTLY {len(chunks)} binary integers [0 or 1] as a JSON list.\n\n"
        "Chunks:\n"
    )
    for i, c in enumerate(chunks):
        selection_prompt += f"Chunk {i}: {c}\n"
    # Run the blocking SDK call off the event-loop thread.
    # get_running_loop() replaces the deprecated get_event_loop() —
    # we are always inside a coroutine here.
    loop = asyncio.get_running_loop()
    raw_response = await loop.run_in_executor(None, call_gemini_with_retry, selection_prompt)
    if "ERROR" in raw_response:
        return raw_response, {}, "FAIL"
    indices = []
    try:
        match = re.search(r"\[([\d\s,]+)\]", raw_response)
        if match:
            mask = json.loads(match.group(0))
            # Pad/truncate so the mask always matches the chunk count.
            mask = (mask + [0] * len(chunks))[:len(chunks)]
            indices = [i for i, m in enumerate(mask) if int(m) == 1]
    except (ValueError, TypeError):
        # Narrowed from a bare `except:`; JSONDecodeError is a ValueError.
        # Malformed model output simply yields an empty selection.
        indices = []
    if not indices:
        optimized_text = "No matches found or optimization too aggressive."
    else:
        optimized_text = " ".join([chunks[i] for i in sorted(indices)])
    orig_tokens = count_tokens(raw_text)
    final_tokens = count_tokens(optimized_text)
    reduction = ((orig_tokens - final_tokens) / orig_tokens * 100) if orig_tokens > 0 else 0
    metrics = {
        "Original Tokens": f"{orig_tokens}",
        "Final Tokens": f"{final_tokens}",
        "Reduction Score": f"{reduction:.1f}%"
    }
    ground_prompt = f"Question: {query}\nContext: {optimized_text}\n\nTask: Response with 'PASS' if info present, else 'FAIL'."
    ground_result = await loop.run_in_executor(None, call_gemini_with_retry, ground_prompt)
    return optimized_text, metrics, ground_result
# --- Gradio UI with Premium Styling ---
def get_status_html(result: str):
    """Render the groundedness verdict as a colored HTML banner."""
    success = "PASS" in result.upper()
    bg = "#059669" if success else "#dc2626"
    label = "🚀 GROUNDEDNESS SUCCESS" if success else "⚠️ GROUNDEDNESS FAILURE"
    return (
        f'<div style="background-color: {bg}; color: white; padding: 12px; '
        f'border-radius: 12px; font-weight: bold; text-align: center;">{label}</div>'
    )
# Dark-theme page styling; the #title rule targets the Markdown element
# created with elem_id="title" below.
# NOTE(review): CSS is currently passed to demo.launch(); Gradio expects
# custom css on the gr.Blocks(...) constructor — verify it is applied.
CSS = """
body { background-color: #0f172a; color: white; }
.gradio-container { border-radius: 20px !important; box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5) !important; }
#title { text-align: center; font-size: 2.5em; margin-bottom: 20px; color: #38bdf8; }
"""
# Build the UI.  Fix: `theme=` and `css=` are gr.Blocks() constructor
# arguments, not launch() arguments — previously they were passed to
# demo.launch(), which does not accept them, so the custom CSS/theme
# (including the #title rule) was never applied.
with gr.Blocks(
    title="ContextPrune",
    theme=gr.themes.Default(primary_hue="blue", neutral_hue="slate"),
    css=CSS,
) as demo:
    gr.Markdown("# 🧠 ContextPrune AI: Quota-Resilient Context Compression", elem_id="title")
    with gr.Tabs():
        with gr.TabItem("Optimizer"):
            with gr.Row():
                with gr.Column(scale=2):
                    query_in = gr.Textbox(label="🔍 User Query", placeholder="What are the key technical findings?", lines=2)
                    context_in = gr.Textbox(label="📄 Noisy Document Content", placeholder="Paste large blocks of text here...", lines=15)
                    btn = gr.Button("🔥 Prune Context Now", variant="primary", size="lg")
                with gr.Column(scale=1):
                    metrics_lbl = gr.Label(label="Optimization Efficiency")
                    status = gr.HTML()
            out = gr.Textbox(label="✨ Optimized Context (Ready for LLM)", interactive=False, lines=15)

    # Adapter between the Gradio callback signature and prune_context:
    # maps (query, context) -> (optimized text, status HTML, metrics).
    async def run_ui(q, c):
        txt, m, g = await prune_context(q, c)
        return txt, get_status_html(g), m

    btn.click(run_ui, [query_in, context_in], [out, status, metrics_lbl])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)