Spaces:
Sleeping
Sleeping
Jayant-Kernel committed on
Commit ·
f80bdfb
1
Parent(s): cdb2e06
fix: match exact [START]/[STEP]/[END] format, fix openenv.yaml app path
Browse files- inference.py +54 -50
- openenv.yaml +1 -1
inference.py
CHANGED
|
@@ -1,36 +1,25 @@
|
|
| 1 |
# inference.py — run this from the project root
|
| 2 |
-
"""
|
| 3 |
-
LLM agent for the CI/CD Failure Diagnosis environment.
|
| 4 |
-
|
| 5 |
-
Env vars:
|
| 6 |
-
API_BASE_URL OpenAI-compatible base URL (e.g. https://api.openai.com/v1)
|
| 7 |
-
MODEL_NAME model to call (e.g. gpt-4o-mini)
|
| 8 |
-
HF_TOKEN HuggingFace token — used as API key when running on HF Spaces
|
| 9 |
-
ENV_URL running server URL (default: http://localhost:8000)
|
| 10 |
-
NUM_EPISODES how many episodes to run (default: 10)
|
| 11 |
-
"""
|
| 12 |
import json
|
| 13 |
import os
|
| 14 |
import sys
|
| 15 |
-
import time
|
| 16 |
|
| 17 |
from openai import OpenAI
|
| 18 |
|
| 19 |
from cicd_diagnosis_env.client import CICDEnv
|
| 20 |
from cicd_diagnosis_env.models import DiagnoseAction
|
| 21 |
|
| 22 |
-
API_BASE_URL = os.
|
| 23 |
-
MODEL_NAME = os.
|
| 24 |
-
HF_TOKEN = os.
|
| 25 |
-
ENV_URL = os.
|
| 26 |
-
NUM_EPISODES = int(os.
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
_api_key = HF_TOKEN or os.environ.get("OPENAI_API_KEY", "no-key")
|
| 30 |
llm = OpenAI(api_key=_api_key, base_url=API_BASE_URL)
|
| 31 |
|
| 32 |
_SYSTEM = """You are an expert CI/CD engineer diagnosing pipeline failures.
|
| 33 |
-
|
| 34 |
{
|
| 35 |
"failure_category": "<dependency|config|flaky|code_bug|infra>",
|
| 36 |
"root_cause": "<concise 1-2 sentence explanation>",
|
|
@@ -39,7 +28,21 @@ You will receive a failure log. Respond ONLY with valid JSON (no markdown):
|
|
| 39 |
}"""
|
| 40 |
|
| 41 |
|
| 42 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
msg = f"Error summary: {summary}\n\nFull log:\n{log}"
|
| 44 |
resp = llm.chat.completions.create(
|
| 45 |
model=MODEL_NAME,
|
|
@@ -48,11 +51,8 @@ def diagnose(log, summary):
|
|
| 48 |
max_tokens=300,
|
| 49 |
)
|
| 50 |
raw = resp.choices[0].message.content.strip()
|
| 51 |
-
# strip markdown fences — some models add them even when told not to
|
| 52 |
-
# this is a bit fragile but works for the models we're targeting
|
| 53 |
if raw.startswith("```"):
|
| 54 |
lines = raw.splitlines()
|
| 55 |
-
# drop first line (```json or ```) and last line (```)
|
| 56 |
raw = "\n".join(lines[1:-1]).strip()
|
| 57 |
parsed = json.loads(raw)
|
| 58 |
return DiagnoseAction(
|
|
@@ -63,50 +63,54 @@ def diagnose(log, summary):
|
|
| 63 |
)
|
| 64 |
|
| 65 |
|
| 66 |
-
def run_episode(env,
|
| 67 |
obs = env.reset()
|
| 68 |
-
|
|
|
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
try:
|
| 73 |
-
action =
|
|
|
|
| 74 |
except Exception as e:
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
break
|
| 77 |
|
| 78 |
obs = env.step(action)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
if obs.done:
|
|
|
|
| 87 |
break
|
| 88 |
|
| 89 |
-
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def main():
|
| 93 |
-
print(f"[START] model={MODEL_NAME} episodes={NUM_EPISODES} env={ENV_URL}")
|
| 94 |
env = CICDEnv(base_url=ENV_URL)
|
| 95 |
-
scores = []
|
| 96 |
-
t0 = time.time()
|
| 97 |
-
|
| 98 |
for ep in range(1, NUM_EPISODES + 1):
|
| 99 |
try:
|
| 100 |
-
|
| 101 |
-
scores.append(s)
|
| 102 |
except Exception as e:
|
| 103 |
-
print(f"[
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
elapsed = time.time() - t0
|
| 107 |
-
avg = sum(scores) / len(scores) if scores else 0.0
|
| 108 |
-
print(f"[END] episodes={NUM_EPISODES} avg_score={avg:.4f} elapsed={elapsed:.1f}s")
|
| 109 |
|
| 110 |
|
| 111 |
-
if __name__ ==
|
| 112 |
main()
|
|
|
|
| 1 |
# inference.py — run this from the project root
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
import sys
|
|
|
|
| 5 |
|
| 6 |
from openai import OpenAI
|
| 7 |
|
| 8 |
from cicd_diagnosis_env.client import CICDEnv
|
| 9 |
from cicd_diagnosis_env.models import DiagnoseAction
|
| 10 |
|
| 11 |
+
# --- Configuration (all overridable via environment variables) ---
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")  # OpenAI-compatible endpoint
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")  # model passed to chat.completions
HF_TOKEN = os.getenv("HF_TOKEN")  # used as the API key when running on HF Spaces
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")  # CI/CD diagnosis env server
NUM_EPISODES = int(os.getenv("NUM_EPISODES", "10"))  # episodes to run in main()
MAX_STEPS = 3  # max diagnose attempts per episode (see run_episode)

# Prefer the HF token; fall back to OPENAI_API_KEY, then a placeholder string.
_api_key = HF_TOKEN or os.getenv("OPENAI_API_KEY", "no-key")
llm = OpenAI(api_key=_api_key, base_url=API_BASE_URL)
|
| 20 |
|
| 21 |
_SYSTEM = """You are an expert CI/CD engineer diagnosing pipeline failures.
|
| 22 |
+
Respond ONLY with valid JSON (no markdown):
|
| 23 |
{
|
| 24 |
"failure_category": "<dependency|config|flaky|code_bug|infra>",
|
| 25 |
"root_cause": "<concise 1-2 sentence explanation>",
|
|
|
|
| 28 |
}"""
|
| 29 |
|
| 30 |
|
| 31 |
+
def log_start(task, env, model):
|
| 32 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def log_step(step, action, reward, done, error=None):
|
| 36 |
+
err = error if error else "null"
|
| 37 |
+
print(f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={err}", flush=True)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def log_end(success, steps, score, rewards):
|
| 41 |
+
rstr = ",".join(f"{r:.2f}" for r in rewards)
|
| 42 |
+
print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rstr}", flush=True)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def call_llm(log, summary):
|
| 46 |
msg = f"Error summary: {summary}\n\nFull log:\n{log}"
|
| 47 |
resp = llm.chat.completions.create(
|
| 48 |
model=MODEL_NAME,
|
|
|
|
| 51 |
max_tokens=300,
|
| 52 |
)
|
| 53 |
raw = resp.choices[0].message.content.strip()
|
|
|
|
|
|
|
| 54 |
if raw.startswith("```"):
|
| 55 |
lines = raw.splitlines()
|
|
|
|
| 56 |
raw = "\n".join(lines[1:-1]).strip()
|
| 57 |
parsed = json.loads(raw)
|
| 58 |
return DiagnoseAction(
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
|
| 66 |
+
def run_episode(env, ep_num):
    """Run a single episode: reset the env, then diagnose for up to MAX_STEPS attempts.

    Emits [START]/[STEP]/[END] marker lines via the log_* helpers and returns
    the final episode score (0.0 if the LLM call fails before any env step).
    """
    obs = env.reset()
    log_start(task=f"task{obs.task_id}", env="cicd_diagnosis_env", model=MODEL_NAME)

    reward_history = []
    steps_taken = 0
    episode_success = False
    final_score = 0.0

    for attempt in range(1, MAX_STEPS + 1):
        llm_error = None
        try:
            # Both the LLM call and the label construction stay inside the
            # try so any failure (including a malformed action) is recorded
            # as a zero-reward terminal step instead of crashing the episode.
            action = call_llm(obs.pipeline_log, obs.error_summary)
            action_label = f"diagnose(category={action.failure_category})"
        except Exception as exc:
            llm_error = str(exc)
            log_step(attempt, "diagnose(error)", 0.0, True, llm_error)
            reward_history.append(0.0)
            steps_taken = attempt
            break

        obs = env.step(action)
        step_reward = 0.0 if obs.reward is None else obs.reward
        reward_history.append(step_reward)
        steps_taken = attempt
        final_score = obs.score

        log_step(attempt, action_label, step_reward, obs.done, llm_error)

        if obs.done:
            # NOTE(review): 0.5 looks like the success cutoff — confirm with env docs.
            episode_success = final_score >= 0.5
            break

    log_end(success=episode_success, steps=steps_taken, score=final_score, rewards=reward_history)
    return final_score
|
| 103 |
|
| 104 |
|
| 105 |
def main():
    """Drive NUM_EPISODES episodes against the running env server.

    An episode that raises is reported on stderr and recorded as a failed
    [END] line, so the log stream always contains one [END] per episode.
    """
    env = CICDEnv(base_url=ENV_URL)
    for episode_num in range(1, NUM_EPISODES + 1):
        try:
            run_episode(env, episode_num)
        except Exception as exc:
            print(
                f"[DEBUG] episode {episode_num} error: {exc}",
                file=sys.stderr,
                flush=True,
            )
            log_end(success=False, steps=0, score=0.0, rewards=[])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
|
| 115 |
+
# Script entry point: run the agent only when executed directly.
if __name__ == '__main__':
    main()
|
openenv.yaml
CHANGED
|
@@ -4,7 +4,7 @@ version: "0.1.0"
|
|
| 4 |
description: "RL environment for diagnosing CI/CD pipeline failures"
|
| 5 |
type: space
|
| 6 |
runtime: fastapi
|
| 7 |
-
app: server.app:app
|
| 8 |
port: 7860
|
| 9 |
action: DiagnoseAction
|
| 10 |
observation: PipelineObservation
|
|
|
|
| 4 |
description: "RL environment for diagnosing CI/CD pipeline failures"
|
| 5 |
type: space
|
| 6 |
runtime: fastapi
|
| 7 |
+
app: cicd_diagnosis_env.server.app:app
|
| 8 |
port: 7860
|
| 9 |
action: DiagnoseAction
|
| 10 |
observation: PipelineObservation
|