# mergeconflict_openenv/baseline.py
import os
import json
import requests
BASE_URL = "http://localhost:7860"
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "ollama").lower()
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434")
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "qwen3:1.7b")
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
OLLAMA_TIMEOUT = int(os.getenv("OLLAMA_TIMEOUT", "180"))
OLLAMA_RETRIES = int(os.getenv("OLLAMA_RETRIES", "2"))
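
# Example invocations (illustrative values; adjust to your setup):
#   LLM_PROVIDER=ollama OLLAMA_MODEL=qwen3:1.7b python baseline.py
#   LLM_PROVIDER=gemini GEMINI_API_KEY=<your-key> python baseline.py
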
SYSTEM_PROMPT = """
You are an expert Git Merge Conflict Resolver.
You have the following tools available:
- READ_FILE (requires 'filepath')
- WRITE_FILE (requires 'filepath' and 'content')
- RUN_LINTER (no arguments)
- RUN_TESTS (no arguments)
- SUBMIT (no arguments)
Always format your response as a single valid JSON object matching this schema:
{"tool": "TOOL_NAME", "filepath": "optional_filename", "content": "optional_full_file_content"}
Strategy:
1. READ_FILE to see the markers (<<<<<<< HEAD).
2. Think about the logic required.
3. WRITE_FILE with the fully resolved content (remove markers!).
4. RUN_TESTS to verify.
5. SUBMIT to finish.
"""

def parse_action(action_text: str) -> dict:
    """Parse the model's JSON action, tolerating Markdown code fences."""
    try:
        return json.loads(action_text)
    except json.JSONDecodeError:
        # Fall back to stripping a ```json ... ``` fence before re-parsing.
        clean_text = action_text.strip()
        if clean_text.startswith("```"):
            clean_text = clean_text.strip("`")
            if clean_text.startswith("json"):
                clean_text = clean_text[4:]
        return json.loads(clean_text.strip())
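
# parse_action accepts both bare JSON and fenced model output, e.g.:
#   parse_action('{"tool": "RUN_TESTS"}')
#   parse_action('```json\n{"tool": "RUN_TESTS"}\n```')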

class GeminiRunner:
    """Stateful chat wrapper around the google-generativeai SDK."""

    def __init__(self):
        import google.generativeai as genai

        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise RuntimeError("GEMINI_API_KEY environment variable not set for LLM_PROVIDER=gemini.")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(
            model_name=GEMINI_MODEL,
            system_instruction=SYSTEM_PROMPT,
            # Ask the SDK to constrain output to JSON.
            generation_config={"response_mime_type": "application/json"},
        )
        self.chat = model.start_chat(history=[])

    def send(self, prompt: str) -> str:
        response = self.chat.send_message(prompt)
        return response.text
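
# GeminiRunner assumes the google-generativeai package is installed
# (pip install google-generativeai) in addition to the requests dependency.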

class OllamaRunner:
    """Chat loop against a local Ollama server via its /api/chat endpoint."""

    def __init__(self):
        self.model = OLLAMA_MODEL
        self.host = OLLAMA_HOST.rstrip("/")
        # Ollama has no separate system-instruction slot, so the system
        # prompt rides along as the first message in the history.
        self.messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    def send(self, prompt: str) -> str:
        self.messages.append({"role": "user", "content": prompt})
        last_error = None
        for attempt in range(OLLAMA_RETRIES + 1):
            try:
                response = requests.post(
                    f"{self.host}/api/chat",
                    json={
                        "model": self.model,
                        "messages": self.messages,
                        "format": "json",   # constrain the reply to JSON
                        "stream": False,
                        "options": {"temperature": 0},
                    },
                    timeout=OLLAMA_TIMEOUT,
                )
                response.raise_for_status()
                content = response.json()["message"]["content"]
                self.messages.append({"role": "assistant", "content": content})
                return content
            except requests.RequestException as err:
                last_error = err
                if attempt == OLLAMA_RETRIES:
                    break
        raise RuntimeError(f"Ollama request failed after retries: {last_error}")
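
# For reference, a non-streaming /api/chat reply looks roughly like this
# (shape per the Ollama API docs, fields abbreviated):
# {"model": "qwen3:1.7b", "message": {"role": "assistant", "content": "..."}, "done": true}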

def create_runner():
    if LLM_PROVIDER == "gemini":
        return GeminiRunner()
    if LLM_PROVIDER == "ollama":
        return OllamaRunner()
    raise ValueError(f"Unsupported LLM_PROVIDER={LLM_PROVIDER}. Use 'ollama' or 'gemini'.")

def run_task(task_id: str):
    print(f"\n--- Running Baseline for {task_id} with {LLM_PROVIDER} ---")
    # Reset the environment to the task's initial (conflicted) state.
    obs = requests.post(f"{BASE_URL}/reset", json={"task_id": task_id}).json()
    done = False
    runner = create_runner()
    prompt = f"Task started. Observation:\n{json.dumps(obs, indent=2)}"
    step_count = 0
    while not done and step_count < 10:  # cap each episode at 10 steps
        step_count += 1
        try:
            action_text = runner.send(prompt)
            action_dict = parse_action(action_text)
            print(f"Step {step_count} Action: {action_dict.get('tool')} on {action_dict.get('filepath')}")
            # Step the environment with the chosen action.
            step_res = requests.post(f"{BASE_URL}/step", json=action_dict).json()
            obs = step_res["observation"]
            done = step_res["done"]
            # Feed the environment's observation back as the next prompt.
            prompt = f"Result:\n{json.dumps(obs, indent=2)}"
        except Exception as e:
            print(f"Error during loop: {e}")
            break
    # Fetch the final grader score for this task.
    score = requests.get(f"{BASE_URL}/grader").json().get("score", 0.0)
    print(f"Final Score for {task_id}: {score}")
    return score

if __name__ == "__main__":
    tasks = ["task_1_easy", "task_2_medium", "task_3_hard"]
    total_score = 0.0
    for t in tasks:
        total_score += run_task(t)
    print(f"\nBaseline Completed! Average Score: {total_score / len(tasks):.2f}")