Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import re | |
| import time | |
| from google import genai | |
| from dotenv import load_dotenv | |
| from context_pruning_env.env import ContextPruningEnv | |
| from context_pruning_env.models import ContextAction | |
# Load environment variables from a local .env file (API keys, model overrides).
load_dotenv()

# --- SDK CONFIGURATION ---
# Accept either env var name for the API key.
# NOTE(review): no explicit check that a key was found — Client construction
# or the first request may fail later; confirm this is the desired behavior.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
client = genai.Client(api_key=API_KEY)

# Fallback sequence for 2026 availability & quota limits.
# dict.fromkeys() preserves order while dropping duplicates, so setting
# MODEL_NAME to one of the hard-coded fallbacks no longer causes a
# redundant second round of retries against the same model.
MODEL_SEQUENCE = list(dict.fromkeys([
    os.environ.get("MODEL_NAME", "gemini-2.0-flash"),
    "gemini-2.5-flash",
    "gemini-3.1-flash-live-preview",
    "gemini-1.5-flash-8b",
]))
def call_with_retry(prompt):
    """Call Gemini with exponential backoff and model fallback.

    Tries each model in MODEL_SEQUENCE in order. Quota errors (429) are
    retried with exponential backoff; "model not found" (404) and any
    other error skip straight to the next model in the sequence.

    Args:
        prompt: Text prompt passed as the request contents.

    Returns:
        Tuple of (response_text, model_name) on success, or (None, None)
        if every model/attempt fails or returns an empty response.
    """
    for model_name in MODEL_SEQUENCE:
        retries = 3
        backoff = 5  # Start with 5s for 2026 free tier
        for attempt in range(retries):
            try:
                print(f"DEBUG: Attempting {model_name} (Attempt {attempt+1}/{retries})...")
                response = client.models.generate_content(
                    model=model_name,
                    config={
                        'temperature': 0.1,
                        'top_p': 0.95,
                        'max_output_tokens': 512,
                    },
                    contents=prompt
                )
                if response and response.text:
                    return response.text, model_name
                # Empty/blocked response: log it instead of retrying silently.
                print(f"DEBUG: EMPTY RESPONSE from {model_name} (attempt {attempt+1}).")
            except Exception as e:
                err_str = str(e).lower()
                if "429" in err_str or "quota" in err_str or "resource" in err_str:
                    if attempt < retries - 1:
                        print(f"DEBUG: QUOTA EXCEEDED for {model_name}. Retrying in {backoff}s...")
                        time.sleep(backoff)
                        backoff *= 2
                    else:
                        # Fix: don't burn a full backoff sleep after the final
                        # attempt — we are about to move to the next model anyway.
                        print(f"DEBUG: QUOTA EXCEEDED for {model_name}. Out of retries.")
                elif "404" in err_str or "not found" in err_str:
                    print(f"DEBUG: MODEL {model_name} NOT FOUND. Falling back to next model.")
                    break  # Try next model in sequence
                else:
                    print(f"LOUD ERROR: {e}")
                    # If it's a non-retryable error, we still try the next model
                    break
    return None, None
| def run_inference(): | |
| env = ContextPruningEnv() | |
| tasks = ["noise_purge", "dedupe_arena", "signal_extract"] | |
| total_score = 0 | |
| for task in tasks: | |
| print(f"\n--- Starting Task: {task} ---") | |
| print(f"[START] task={task}") | |
| obs = env.reset(task_name=task) | |
| # PROMPT | |
| prompt = ( | |
| f"Query: {obs.question}\n\n" | |
| f"TASKS: Prune the following {len(obs.chunks)} chunks. Output EXACTLY {len(obs.chunks)} binary integers [0 or 1] as a JSON list.\n" | |
| "Chunks:\n" | |
| ) | |
| for i, c in enumerate(obs.chunks): | |
| prompt += f"[{i}]: {c}\n" | |
| raw_text, used_model = call_with_retry(prompt) | |
| mask = [1] * len(obs.chunks) # Default fallback | |
| if raw_text: | |
| print(f"DEBUG: Used Model: {used_model} | RAW RESP: {raw_text}") | |
| # Extract list | |
| match = re.search(r"\[([\d\s,]+)\]", raw_text) | |
| if match: | |
| try: | |
| mask = json.loads(match.group(0)) | |
| except: | |
| # manual parse if json fails | |
| mask = [int(x) for x in re.findall(r'[01]', match.group(1))] | |
| # Robust Pad/truncate | |
| mask = (mask + [1] * len(obs.chunks))[:len(obs.chunks)] | |
| else: | |
| print(f"DEBUG: ALL MODELS FAILED for {task}. Using identity mask.") | |
| action = ContextAction(mask=mask) | |
| obs = env.step(action) | |
| score = obs.metadata.get("eval_score", 0.0) | |
| print(f"[STEP] reward={getattr(obs, 'reward', 0.0):.2f} mask={mask}") | |
| print(f"[END] task={task} score={score:.2f} success={str(score > 0.5).lower()}") | |
| total_score += score | |
| print(f"\nINFO:--- ALL TASKS COMPLETE. FINAL AVG SCORE: {total_score / len(tasks):.2f} ---") | |
| if __name__ == "__main__": | |
| run_inference() | |