Prasham.Jain
feat(training): Phase C5 — evaluation harness, baselines, plots, readme table
93e68bc
"""ZeroShotPolicy — any HF model in zero-shot inference mode.
All GPU-heavy imports (unsloth, torch) are lazy so the module is importable without GPU.
"""
from __future__ import annotations
class ZeroShotPolicy:
    """Policy that queries a Hugging Face chat model in zero-shot mode.

    The model and tokenizer are loaded once, at construction time, through
    unsloth's ``FastLanguageModel`` and switched to inference mode. Each call
    to :meth:`act` renders the observation and history as a chat prompt,
    generates a reply with sampling disabled, and parses the reply into an
    action dictionary.

    Attributes:
        name: Identifier of this policy; either the ``name`` passed to the
            constructor or ``zero_shot_<last path segment of model_name>``.
        model: Model object returned by ``FastLanguageModel.from_pretrained``.
        tokenizer: Tokenizer returned alongside the model.
        system_prompt: Text sent as the ``system`` message of every prompt.
    """

    def __init__(
        self,
        model_name: str,
        system_prompt: str,
        name: str | None = None,
    ) -> None:
        """Load ``model_name`` and prepare it for inference.

        Args:
            model_name: Model identifier handed to
                ``FastLanguageModel.from_pretrained``.
            system_prompt: System message prepended to every conversation.
            name: Optional policy name. When omitted (or empty), it is derived
                from the part of ``model_name`` after the last ``/``.
        """
        self.name = name or f"zero_shot_{model_name.split('/')[-1]}"
        self.system_prompt = system_prompt

        # Imported lazily so that importing this module does not require the
        # GPU stack to be installed.
        from unsloth import FastLanguageModel  # type: ignore[import]

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=8192,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(self.model)

    def act(self, obs, history: list) -> dict:
        """Generate the next action for the current observation.

        Args:
            obs: Current observation; only its ``failure_summary`` attribute
                is read (see :meth:`_build_messages`).
            history: Earlier entries of the episode, replayed into the prompt.

        Returns:
            ``action.model_dump()`` of the action parsed from the model's
            reply by ``TrajectoryGenerator._parse_action``, or — when the
            parser returns ``None`` — a fallback ``submit_diagnosis`` action
            with diagnosis ``"ambiguous"`` and confidence ``0.5``.
        """
        import torch  # type: ignore[import]

        from ci_triage_env.training.trajectory_gen import TrajectoryGenerator

        messages = self._build_messages(obs, history)

        # Ask for a dict explicitly: it yields the attention mask next to the
        # token ids and does not depend on the library's default return type
        # for apply_chat_template.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.model.device)
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")
        if attention_mask is None:
            # Single, unpadded prompt: every position is a real token.
            attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            # Only input_ids / attention_mask are forwarded, so tokenizers
            # that emit extra keys cannot trip generate()'s kwarg validation.
            # The mask is passed explicitly because pad_token_id is set to the
            # EOS id, which prevents generate() from inferring it.
            out = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=400,
                do_sample=False,
                # Has no effect with sampling disabled; presumably kept to
                # override a temperature from the model's generation config.
                temperature=1.0,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # generate() returns prompt + completion; keep only the new tokens.
        prompt_length = input_ids.shape[1]
        response = self.tokenizer.decode(
            out[0][prompt_length:], skip_special_tokens=True
        )

        action = TrajectoryGenerator._parse_action(response)
        if action is not None:
            return action.model_dump()

        # The reply could not be parsed into an action: end the episode with a
        # neutral diagnosis instead of raising.
        return {
            "action_type": "submit_diagnosis",
            "diagnosis": "ambiguous",
            "confidence": 0.5,
            "secondary_actions": [],
        }

    def _build_messages(self, obs, history: list) -> list[dict]:
        """Render the observation and history as a chat message list.

        The result always starts with the system prompt followed by one user
        message. When ``obs.failure_summary`` is truthy the user message
        contains its test name, suite, branch and initial log excerpt;
        otherwise a generic prompt is used. Every ``history`` entry is then
        appended, in order, as an ``assistant`` message holding ``str(entry)``.

        Args:
            obs: Object exposing a ``failure_summary`` attribute.
            history: Entries to replay after the user message.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dictionaries.
        """
        summary = obs.failure_summary
        if summary:
            user_content = (
                f"CI FAILURE ALERT\nTest: {summary.test_name}\nSuite: {summary.suite}\n"
                f"Branch: {summary.branch}\nLog excerpt:\n{summary.initial_log_excerpt}\n\n"
                "Investigate and submit your diagnosis."
            )
        else:
            user_content = "CI failure detected. Begin investigation."

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_content},
        ]
        # NOTE(review): every history entry becomes an assistant turn, so tool
        # results (if history contains any) are not shown as user/tool turns,
        # and chat templates that enforce alternating roles may reject the
        # result — confirm against the history schema.
        messages.extend(
            {"role": "assistant", "content": str(entry)} for entry in history
        )
        return messages