# context-prune / inference_pruning_legacy.py
# Author: prithic07
# Commit: "Upgrade RAG project with advanced Context Optimizer environment
# and baseline inference" (0b89610)
import os
import json
import re
import time
from google import genai
from dotenv import load_dotenv
from context_pruning_env.env import ContextPruningEnv
from context_pruning_env.models import ContextAction
# Load .env so GEMINI_API_KEY / GOOGLE_API_KEY / MODEL_NAME can be supplied locally.
load_dotenv()
# --- SDK CONFIGURATION ---
# Prefer GEMINI_API_KEY; fall back to GOOGLE_API_KEY. May be None if neither is
# set — genai.Client would then fail at first request, not at construction.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
client = genai.Client(api_key=API_KEY)
# Fallback sequence for 2026 availability & quota limits.
# Tried in order by call_with_retry(); MODEL_NAME env var overrides the first entry.
MODEL_SEQUENCE = [
    os.environ.get("MODEL_NAME", "gemini-2.0-flash"),
    "gemini-2.5-flash",
    "gemini-3.1-flash-live-preview",
    "gemini-1.5-flash-8b",
]
def call_with_retry(prompt):
    """Call Gemini with exponential backoff and model fallback.

    Walks ``MODEL_SEQUENCE`` in order. For each model, quota errors (429)
    are retried up to 3 times with a doubling backoff; a 404 / "not found"
    or any other non-retryable error falls through to the next model.

    Args:
        prompt: Text prompt passed as ``contents`` to the model.

    Returns:
        Tuple ``(response_text, model_name)`` on success, or
        ``(None, None)`` if every model in the sequence failed.
    """
    for model_name in MODEL_SEQUENCE:
        retries = 3
        backoff = 5  # seconds; initial delay for quota retries (free-tier friendly)
        for attempt in range(retries):
            try:
                print(f"DEBUG: Attempting {model_name} (Attempt {attempt+1}/{retries})...")
                response = client.models.generate_content(
                    model=model_name,
                    config={
                        'temperature': 0.1,
                        'top_p': 0.95,
                        'max_output_tokens': 512,
                    },
                    contents=prompt
                )
                # An empty/blocked response is treated as failure and retried.
                if response and response.text:
                    return response.text, model_name
            except Exception as e:
                err_str = str(e).lower()
                if "429" in err_str or "quota" in err_str or "resource" in err_str:
                    if attempt < retries - 1:
                        print(f"DEBUG: QUOTA EXCEEDED for {model_name}. Retrying in {backoff}s...")
                        time.sleep(backoff)
                        backoff *= 2
                    else:
                        # Fix: retries exhausted — don't sleep pointlessly before
                        # the inner loop ends and we fall through to the next model.
                        print(f"DEBUG: QUOTA EXCEEDED for {model_name}. Retries exhausted; falling back.")
                elif "404" in err_str or "not found" in err_str:
                    print(f"DEBUG: MODEL {model_name} NOT FOUND. Falling back to next model.")
                    break  # Try next model in sequence
                else:
                    print(f"LOUD ERROR: {e}")
                    # If it's a non-retryable error, we still try the next model
                    break
    # Every model/attempt failed.
    return None, None
def _parse_mask(raw_text, n):
    """Extract a binary pruning mask of length ``n`` from a model reply.

    Finds the first ``[...]`` list of digits in ``raw_text`` and parses it
    as JSON; if that fails, falls back to collecting raw 0/1 digits from the
    bracketed span. A successfully matched mask is padded with 1s (keep) and
    truncated to exactly ``n`` entries. Returns the identity mask
    ``[1] * n`` when no bracketed list is found at all.
    """
    mask = [1] * n  # default fallback: keep every chunk
    match = re.search(r"\[([\d\s,]+)\]", raw_text)
    if match:
        try:
            mask = json.loads(match.group(0))
        except ValueError:
            # Fix: was a bare `except:` (also swallowed KeyboardInterrupt etc.).
            # json.JSONDecodeError subclasses ValueError. Manual parse fallback.
            mask = [int(x) for x in re.findall(r'[01]', match.group(1))]
        # Robust pad/truncate so the mask length always matches the chunk count.
        mask = (mask + [1] * n)[:n]
    return mask


def run_inference():
    """Run the pruning benchmark over all tasks and print the average score.

    For each task: reset the env, prompt the model for a binary keep/drop
    mask over the context chunks, step the env with the parsed mask, and
    accumulate the evaluation score from ``obs.metadata``.
    """
    env = ContextPruningEnv()
    tasks = ["noise_purge", "dedupe_arena", "signal_extract"]
    total_score = 0
    for task in tasks:
        print(f"\n--- Starting Task: {task} ---")
        print(f"[START] task={task}")
        obs = env.reset(task_name=task)
        # PROMPT: question followed by the indexed chunks to prune.
        prompt = (
            f"Query: {obs.question}\n\n"
            f"TASKS: Prune the following {len(obs.chunks)} chunks. Output EXACTLY {len(obs.chunks)} binary integers [0 or 1] as a JSON list.\n"
            "Chunks:\n"
        )
        for i, c in enumerate(obs.chunks):
            prompt += f"[{i}]: {c}\n"
        raw_text, used_model = call_with_retry(prompt)
        if raw_text:
            print(f"DEBUG: Used Model: {used_model} | RAW RESP: {raw_text}")
            mask = _parse_mask(raw_text, len(obs.chunks))
        else:
            # All models failed: keep everything rather than crash the run.
            print(f"DEBUG: ALL MODELS FAILED for {task}. Using identity mask.")
            mask = [1] * len(obs.chunks)
        action = ContextAction(mask=mask)
        obs = env.step(action)
        score = obs.metadata.get("eval_score", 0.0)
        print(f"[STEP] reward={getattr(obs, 'reward', 0.0):.2f} mask={mask}")
        print(f"[END] task={task} score={score:.2f} success={str(score > 0.5).lower()}")
        total_score += score
    print(f"\nINFO:--- ALL TASKS COMPLETE. FINAL AVG SCORE: {total_score / len(tasks):.2f} ---")
# Script entry point: run the full benchmark when executed directly.
if __name__ == "__main__":
    run_inference()