import json
import os
import re
import time

from dotenv import load_dotenv
from google import genai

from context_pruning_env.env import ContextPruningEnv
from context_pruning_env.models import ContextAction

# Load .env so GEMINI_API_KEY / GOOGLE_API_KEY are visible below.
load_dotenv()

# --- SDK CONFIGURATION ---
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
client = genai.Client(api_key=API_KEY)

# Fallback sequence for 2026 availability & quota limits.
# MODEL_NAME from the environment takes priority, then progressively
# cheaper/older models.
MODEL_SEQUENCE = [
    os.environ.get("MODEL_NAME", "gemini-2.0-flash"),
    "gemini-2.5-flash",
    "gemini-3.1-flash-live-preview",
    "gemini-1.5-flash-8b",
]


def call_with_retry(prompt):
    """Calls Gemini with exponential backoff and model fallback for 429 errors.

    Tries each model in MODEL_SEQUENCE in order; quota errors (429) are
    retried up to 3 times with doubling backoff, 404s and other errors
    skip straight to the next model.

    Returns:
        (response_text, model_name) on success, or (None, None) if every
        model in the sequence failed.
    """
    for model_name in MODEL_SEQUENCE:
        retries = 3
        backoff = 5  # Start with 5s for 2026 free tier
        for attempt in range(retries):
            try:
                print(f"DEBUG: Attempting {model_name} (Attempt {attempt+1}/{retries})...")
                response = client.models.generate_content(
                    model=model_name,
                    config={
                        'temperature': 0.1,
                        'top_p': 0.95,
                        'max_output_tokens': 512,
                    },
                    contents=prompt
                )
                if response and response.text:
                    return response.text, model_name
            except Exception as e:
                err_str = str(e).lower()
                if "429" in err_str or "quota" in err_str or "resource" in err_str:
                    print(f"DEBUG: QUOTA EXCEEDED for {model_name}. Retrying in {backoff}s...")
                    time.sleep(backoff)
                    backoff *= 2
                elif "404" in err_str or "not found" in err_str:
                    print(f"DEBUG: MODEL {model_name} NOT FOUND. Falling back to next model.")
                    break  # Try next model in sequence
                else:
                    print(f"LOUD ERROR: {e}")
                    # Non-retryable error: give up on this model, try the next one.
                    break
    return None, None


def _parse_mask(raw_text, n):
    """Extract a length-*n* binary keep/drop mask from a model response.

    Looks for the first bracketed list of digits in *raw_text*. Falls back
    to the identity mask (all 1s — keep every chunk) when nothing parseable
    is found. Values are clamped to {0, 1} and the result is padded with 1s
    or truncated so it is exactly *n* entries long.
    """
    mask = [1] * n  # Default fallback: keep everything
    match = re.search(r"\[([\d\s,]+)\]", raw_text)
    if match:
        try:
            parsed = json.loads(match.group(0))
        except (ValueError, json.JSONDecodeError):
            # Manual parse if JSON fails: pick out the 0/1 digits directly.
            parsed = [int(x) for x in re.findall(r'[01]', match.group(1))]
        # Clamp stray values to binary, then robustly pad/truncate to n.
        mask = [1 if v else 0 for v in parsed]
        mask = (mask + [1] * n)[:n]
    return mask


def run_inference():
    """Run every pruning task through ContextPruningEnv and print scores.

    For each task: builds a prompt listing the chunks, asks the model for a
    binary keep/drop mask, steps the environment with it, and accumulates
    the per-task eval score into a final average.
    """
    env = ContextPruningEnv()
    tasks = ["noise_purge", "dedupe_arena", "signal_extract"]
    total_score = 0
    for task in tasks:
        print(f"\n--- Starting Task: {task} ---")
        print(f"[START] task={task}")
        obs = env.reset(task_name=task)

        # PROMPT: question, instructions, then one "[i]: chunk" line per chunk.
        prompt = (
            f"Query: {obs.question}\n\n"
            f"TASKS: Prune the following {len(obs.chunks)} chunks. "
            f"Output EXACTLY {len(obs.chunks)} binary integers [0 or 1] as a JSON list.\n"
            "Chunks:\n"
        )
        for i, c in enumerate(obs.chunks):
            prompt += f"[{i}]: {c}\n"

        raw_text, used_model = call_with_retry(prompt)

        if raw_text:
            print(f"DEBUG: Used Model: {used_model} | RAW RESP: {raw_text}")
            mask = _parse_mask(raw_text, len(obs.chunks))
        else:
            print(f"DEBUG: ALL MODELS FAILED for {task}. Using identity mask.")
            mask = [1] * len(obs.chunks)

        action = ContextAction(mask=mask)
        obs = env.step(action)
        score = obs.metadata.get("eval_score", 0.0)
        print(f"[STEP] reward={getattr(obs, 'reward', 0.0):.2f} mask={mask}")
        print(f"[END] task={task} score={score:.2f} success={str(score > 0.5).lower()}")
        total_score += score

    print(f"\nINFO:--- ALL TASKS COMPLETE. FINAL AVG SCORE: {total_score / len(tasks):.2f} ---")


if __name__ == "__main__":
    run_inference()