import os
import json

import requests


BASE_URL = "http://localhost:7860"
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "ollama").lower()
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434")
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "qwen3:1.7b")
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
OLLAMA_TIMEOUT = int(os.getenv("OLLAMA_TIMEOUT", "180"))
OLLAMA_RETRIES = int(os.getenv("OLLAMA_RETRIES", "2"))


SYSTEM_PROMPT = """
You are an expert Git Merge Conflict Resolver.
You have the following tools available:
- READ_FILE (requires 'filepath')
- WRITE_FILE (requires 'filepath' and 'content')
- RUN_LINTER (no arguments)
- RUN_TESTS (no arguments)
- SUBMIT (no arguments)

Always format your response as valid JSON matching this schema:
{"tool": "TOOL_NAME", "filepath": "optional_filename", "content": "optional_full_file_content"}

Strategy:
1. READ_FILE to see the conflict markers (<<<<<<< HEAD).
2. Think through the logic the merged code requires.
3. WRITE_FILE with the fully resolved content (remove the markers!).
4. RUN_TESTS to verify the resolution.
5. SUBMIT to finish.
"""
def parse_action(action_text: str) -> dict:
    """Parse the model's reply into an action dict, tolerating ``` code fences."""
    try:
        return json.loads(action_text)
    except json.JSONDecodeError:
        # Models often wrap JSON in a fenced code block; strip the backticks
        # (and an optional "json" language tag), then parse again.
        clean_text = action_text.strip()
        if clean_text.startswith("```"):
            clean_text = clean_text.strip("`")
            if clean_text.startswith("json"):
                clean_text = clean_text[4:]
        return json.loads(clean_text.strip())
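# e.g. parse_action('```json\n{"tool": "RUN_TESTS"}\n```') -> {"tool": "RUN_TESTS"}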


class GeminiRunner:
    """Chat session backed by the google.generativeai SDK."""

    def __init__(self):
        # Imported lazily so the dependency is only required when Gemini is selected.
        import google.generativeai as genai

        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise RuntimeError("GEMINI_API_KEY environment variable not set for LLM_PROVIDER=gemini.")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(
            model_name=GEMINI_MODEL,
            system_instruction=SYSTEM_PROMPT,
            generation_config={"response_mime_type": "application/json"},
        )
        self.chat = model.start_chat(history=[])

    def send(self, prompt: str) -> str:
        response = self.chat.send_message(prompt)
        return response.text
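

# For reference, a non-streaming /api/chat reply from Ollama looks roughly like
# (abridged; additional metadata fields are omitted):
#   {"model": "...", "message": {"role": "assistant", "content": "..."}, "done": true}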
class OllamaRunner:
    """Minimal client for a local Ollama server's /api/chat endpoint."""

    def __init__(self):
        self.model = OLLAMA_MODEL
        self.host = OLLAMA_HOST.rstrip("/")
        # Ollama keeps no server-side session, so the full transcript lives here.
        self.messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    def send(self, prompt: str) -> str:
        self.messages.append({"role": "user", "content": prompt})
        last_error = None
        for attempt in range(OLLAMA_RETRIES + 1):
            try:
                response = requests.post(
                    f"{self.host}/api/chat",
                    json={
                        "model": self.model,
                        "messages": self.messages,
                        "format": "json",  # constrain the reply to valid JSON
                        "stream": False,
                        "options": {"temperature": 0},
                    },
                    timeout=OLLAMA_TIMEOUT,
                )
                response.raise_for_status()
                content = response.json()["message"]["content"]
                self.messages.append({"role": "assistant", "content": content})
                return content
            except requests.RequestException as err:
                # Retry immediately on connection/HTTP errors.
                last_error = err
        raise RuntimeError(f"Ollama request failed after retries: {last_error}")


def create_runner():
    """Instantiate the runner selected by LLM_PROVIDER."""
    if LLM_PROVIDER == "gemini":
        return GeminiRunner()
    if LLM_PROVIDER == "ollama":
        return OllamaRunner()
    raise ValueError(f"Unsupported LLM_PROVIDER={LLM_PROVIDER}. Use 'ollama' or 'gemini'.")
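

# Provider selection is driven entirely by the environment, e.g.
# (the script filename here is assumed):
#   LLM_PROVIDER=gemini GEMINI_API_KEY=... python baseline.py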


def run_task(task_id: str) -> float:
    print(f"\n--- Running Baseline for {task_id} with {LLM_PROVIDER} ---")

    # Reset the environment to the requested task and grab the first observation.
    obs = requests.post(f"{BASE_URL}/reset", json={"task_id": task_id}).json()
    done = False
    runner = create_runner()
    prompt = f"Task started. Observation:\n{json.dumps(obs, indent=2)}"

    step_count = 0
    while not done and step_count < 10:  # hard cap of 10 steps per task
        step_count += 1
        try:
            action_text = runner.send(prompt)
            action_dict = parse_action(action_text)

            print(f"Step {step_count} Action: {action_dict.get('tool')} on {action_dict.get('filepath')}")

            step_res = requests.post(f"{BASE_URL}/step", json=action_dict).json()
            obs = step_res["observation"]
            done = step_res["done"]

            # Feed the tool result back to the model as the next turn.
            prompt = f"Result:\n{json.dumps(obs, indent=2)}"

        except Exception as e:
            print(f"Error during loop: {e}")
            break

    # Grade whatever state the workspace ended in, even after an error.
    score = requests.get(f"{BASE_URL}/grader").json().get("score", 0.0)
    print(f"Final Score for {task_id}: {score}")
    return score
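

# Environment HTTP API as exercised above:
#   POST /reset   {"task_id": ...}  -> initial observation (JSON)
#   POST /step    {action dict}     -> {"observation": ..., "done": ...}
#   GET  /grader                    -> {"score": <float>}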


if __name__ == "__main__":
    tasks = ["task_1_easy", "task_2_medium", "task_3_hard"]
    total_score = 0.0

    for t in tasks:
        total_score += run_task(t)

    print(f"\nBaseline Completed! Average Score: {total_score / len(tasks):.2f}")