| from agents.Base import SequentialBaseAgent, BaseAgent |
| from utils.utils import clear_code |
| from prompts import prompt_for_reflection |
| from memories.Memory import ReflexionMemory |
| from models.Base import BaseModel |
|
|
|
|
|
|
class Reflexion(SequentialBaseAgent):
    """Agent that retries each problem using feedback from its previous attempt.

    One ``ReflexionMemory`` is kept per problem state of the dataset. Each
    call to :meth:`run_single_pass` asks the model for a new solution, runs
    it through the dataset's checker and, when the check fails, asks the
    model for a reflection that is fed into the next attempt's prompt.
    """

    def __init__(self, model: BaseModel, dataset):
        """Store the model and dataset and create the per-problem memories.

        Args:
            model: Object whose ``generate(messages)`` method is used both
                for producing solutions and for producing reflections.
            dataset: Object exposing ``problem_states`` and
                ``run_single_call(problem_state)``.
        """
        # NOTE(review): the base-class initializer is not invoked here;
        # confirm that SequentialBaseAgent does not require it.
        self.model = model
        self.dataset = dataset
        self.memories = self.memory_init()

    def memory_init(self):
        """Return a new list with one ``ReflexionMemory`` per problem state."""
        return [ReflexionMemory(ps) for ps in self.dataset.problem_states]

    def run_single_pass(self, mem: ReflexionMemory, verbose=False):
        """Run one generate / test / reflect round for a single problem.

        Does nothing when ``mem.ps.pass_call`` is already truthy. Otherwise
        the prompt is the problem instruction followed by whichever of the
        previous solution, previous test messages and previous reflection
        are non-empty, and finally a request to output code only.

        Side effects:
            * ``mem.ps.solution`` is overwritten with the cleaned response.
            * When the check fails, ``mem.err_msg`` and ``mem.reflection``
              are overwritten with the new test messages and reflection.

        Args:
            mem: Memory holding the problem state and the feedback gathered
                from the previous attempt.
            verbose: Currently unused; kept for interface compatibility.

        Returns:
            None.
        """
        problem = mem.ps
        if problem.pass_call:
            return

        # Assemble the prompt; every optional section starts on its own line.
        sections = [problem.instruction]
        if problem.solution:
            sections.append(
                f"\nPrevious attempt implementation:{problem.solution}"
            )
        if mem.err_msg:
            sections.append(
                f"\nTest messages for previous attempt:{mem.err_msg}"
            )
        if mem.reflection:
            sections.append(
                f"\nReflection on previous attempt:{mem.reflection}"
            )
        # The leading newline keeps this request from being glued onto the
        # last character of the preceding section.
        sections.append(
            "\nPlease output the codes only without explanation, "
            "which we can run directly."
        )

        request = [{"role": "user", "content": "".join(sections)}]
        response = self.model.generate(request)
        problem.solution = clear_code(response)

        is_pass, err_msg = self.dataset.run_single_call(problem)
        if is_pass:
            return

        # Failed attempt: remember the test output and ask the model to
        # reflect on it so the next pass can use both.
        mem.err_msg = err_msg
        reflection_prompt = prompt_for_reflection.prompt.format(
            problem=problem.instruction,
            solution=problem.solution,
            test_result=err_msg,
        )
        mem.reflection = self.model.generate(
            [{"role": "user", "content": reflection_prompt}]
        )
|
|
|
|