| import os |
| from tqdm import tqdm |
| from loguru import logger |
| import json |
| from dataclasses import asdict |
| from agents.Reflexion import Reflexion |
| from utils.utils import extract_function_signatures, clear_code, extract_function_calls |
| from prompts import prompt_for_reflection |
| from memories.Memory import MemoryClassMeta |
| from models.Base import BaseModel |
| from retrievers.retriever import BM25Retriever |
| from prompts import prompt_for_generation |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
|
|
class Reflexion_Oneshot(Reflexion):
    """Reflexion-style agent whose generation prompt carries one retrieved example.

    Each problem state of the dataset gets a ``Memory`` record holding the last
    error message, the last reflection, the expected function signatures, the
    retrieved one-shot example and a ``pass_call`` flag. ``run`` then loops
    over generate -> test -> reflect for a given number of iterations.
    """

    def __init__(self, model: BaseModel, dataset, corpus_path, mem_file=None, descendant_num=1):
        """Store the collaborators, build both retrievers and initialise memories.

        Args:
            model: Object whose ``generate(messages, temperature=...)`` returns
                the model response.
            dataset: Object exposing ``problem_states``, ``run_single_call``,
                ``write_file`` and ``__len__``.
            corpus_path: Path handed to ``BM25Retriever.process`` for both the
                instruction retriever and the code retriever.
            mem_file: Optional ``.json`` file with previously saved memories.
            descendant_num: Forwarded to ``memory_init`` (unused there).
        """
        self.model = model
        self.dataset = dataset
        self.memories = []

        # Two retrievers over the same corpus: the default mode is queried with
        # the task instruction, the "code" mode with a previous solution.
        self.instruction_retriever = BM25Retriever()
        self.instruction_retriever.process(content_input_path=corpus_path)
        self.code_retriever = BM25Retriever(mode="code")
        self.code_retriever.process(content_input_path=corpus_path)

        self.memory_init(mem_file, descendant_num)

    def memory_init(self, mem_file=None, descendant_num=1):
        """Append one ``Memory`` per dataset problem state to ``self.memories``.

        Without ``mem_file`` the one-shot example is the ``"code"`` entry of the
        top instruction-retriever hit and all other fields start empty. With
        ``mem_file`` the fields ``err_msg``, ``reflection``, ``oneshot`` and
        ``pass_call`` are restored from the JSON entry stored under the problem
        state's ``filename``.

        Args:
            mem_file: Optional path to a ``.json`` memory dump containing
                exactly ``len(self.dataset)`` entries.
            descendant_num: Accepted for interface compatibility; not used.

        Raises:
            AssertionError: If ``mem_file`` is not a ``.json`` path or its
                entry count differs from the dataset size.
        """
        class Memory(metaclass=MemoryClassMeta, field_names=["ps",
                                                             "err_msg",
                                                             "reflection",
                                                             "function_signatures",
                                                             "oneshot",
                                                             "pass_call",
                                                             ]):
            pass

        input_mems = None
        if mem_file is not None:
            assert mem_file.endswith(".json"), f"expect a json file, but got {mem_file} instead"
            # Explicit encoding so loading does not depend on the locale.
            with open(mem_file, "r", encoding="utf-8") as f:
                input_mems = json.load(f)
            assert len(input_mems) == len(self.dataset), f"expect {len(self.dataset)} samples, but got {len(input_mems)} instead"

        for ps in self.dataset.problem_states:
            # Signatures can only be extracted when a reference label exists.
            fs_mem = extract_function_signatures(ps.label) if ps.label else None

            if input_mems is None:
                os_mem = self.instruction_retriever.query(ps.instruction)[0]
                tmp_mem = Memory(ps=ps,
                                 err_msg=None,
                                 reflection=None,
                                 function_signatures=fs_mem,
                                 oneshot=os_mem["code"],
                                 pass_call=False,
                                 )
            else:
                input_mem = input_mems[ps.filename]
                tmp_mem = Memory(ps=ps,
                                 err_msg=input_mem["err_msg"],
                                 reflection=input_mem["reflection"],
                                 function_signatures=fs_mem,
                                 oneshot=input_mem["oneshot"],
                                 pass_call=input_mem["pass_call"],
                                 )
            self.memories.append(tmp_mem)

    def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None,
            iteration_num=0, temperature=0, thread_num=3):
        """Run ``iteration_num`` rounds of generate -> test -> reflect.

        Args:
            output_path: If given, the dataset is written after every iteration
                to ``<root>_<iteration><ext>`` derived from this path.
            multi_thread: Use a thread pool for the generation and reflection
                stages instead of a sequential loop.
            verbose: Accepted for interface compatibility; not used.
            datalen: Number of leading memories to process; a falsy value
                means the whole dataset.
            iteration_num: Number of rounds; the default ``0`` does nothing.
            temperature: Sampling temperature forwarded to the model.
            thread_num: Worker count of the thread pool (previously fixed at 3).
        """
        data_len = datalen if datalen else len(self.dataset)
        mems = self.memories[:data_len]

        for iteration in range(iteration_num):
            logger.info(f"\n=== Iteration {iteration} ===")
            iter_path = None
            if output_path is not None:
                root, extension = os.path.splitext(output_path)
                iter_path = f"{root}_{iteration}{extension}"

            logger.info(f"\ngenerate solution")
            self._run_stage(self.generate_solution, mems, temperature, multi_thread, thread_num)

            logger.info(f"\nrun scripts on gpu")
            for mem in tqdm(mems):
                if mem.pass_call:
                    continue
                is_pass, err_msg = self.dataset.run_single_call(mem.ps)
                if is_pass:
                    # Record the success so later stages and iterations skip
                    # this sample instead of overwriting a working solution.
                    mem.pass_call = True
                    mem.err_msg = None
                else:
                    mem.err_msg = err_msg

            logger.info(f"\ngenerate reflections")
            self._run_stage(self.generate_reflexion, mems, temperature, multi_thread, thread_num)

            if iter_path is not None:
                self.dataset.write_file(iter_path)

    def _run_stage(self, worker, mems, temperature, multi_thread, thread_num):
        """Call ``worker(mem, temperature)`` for every memory behind a progress bar."""
        with tqdm(total=len(mems)) as pbar:
            if not multi_thread:
                for mem in mems:
                    worker(mem, temperature=temperature)
                    pbar.update(1)
                return

            with ThreadPoolExecutor(max_workers=thread_num) as executor:
                futures = {executor.submit(worker, mem, temperature): mem for mem in mems}
                for future in as_completed(futures):
                    try:
                        # result() re-raises whatever the worker raised; without
                        # this call a failure would vanish silently.
                        future.result()
                    except Exception:
                        # Keep the batch going, but leave a trace of the failure.
                        logger.exception(
                            f"{worker.__name__} failed for {futures[future].ps.filename}"
                        )
                    pbar.update(1)

    def generate_solution(self, mem, temperature=0):
        """Ask the model for a new solution and store it in ``mem.ps.solution``.

        The prompt is built from the instruction and the function signatures,
        followed by an example snippet (the stored one-shot on the first
        attempt, otherwise the top code-retriever hit for the previous
        solution), the previous attempt, its test messages and the reflection
        when available. Memories with ``pass_call`` set are left untouched.

        Args:
            mem: Memory record to update.
            temperature: Sampling temperature forwarded to the model.
        """
        if mem.pass_call:
            return

        # function_signatures is None when the problem state has no label.
        signatures = mem.function_signatures or []
        fss_text = "".join(f"* {sig}" + "\n" for sig in signatures)
        text = prompt_for_generation.prompt.format(
            instruction=mem.ps.instruction,
            function_signatures=fss_text
        )

        if not mem.ps.solution:
            text += f"\nHere is an example snippet of code: {mem.oneshot}"
        else:
            one_shot = self.code_retriever.query(mem.ps.solution)[0]["code"]
            text += f"\nHere is an example snippet of code: {one_shot}"
            text += f"\nPrevious attempt implementation:{mem.ps.solution}"

        if mem.err_msg:
            text += f"\nTest messages for previous attempt:{mem.err_msg}"

        if mem.reflection:
            text += f"\nReflection on previous attempt:{mem.reflection}"

        # Leading newline keeps the instruction from running into the code or
        # reflection text appended just before it.
        text += "\nPlease output the codes only without explanation, which we can run directly."
        msg = [
            {"role": "user", "content": text},
        ]
        response = self.model.generate(msg, temperature=temperature)
        mem.ps.solution = clear_code(response)

    def generate_reflexion(self, mem, temperature):
        """Ask the model to reflect on the latest attempt; store it in ``mem.reflection``.

        The reflection prompt is filled with the instruction, the current
        solution and the stored test message. Memories with ``pass_call`` set
        are skipped.

        Args:
            mem: Memory record to update.
            temperature: Sampling temperature forwarded to the model.
        """
        if mem.pass_call:
            return

        reflect_txt = prompt_for_reflection.prompt.format(
            problem=mem.ps.instruction,
            solution=mem.ps.solution,
            test_result=mem.err_msg
        )
        reflect_msg = [
            {
                "role": "user",
                "content": reflect_txt
            }
        ]
        mem.reflection = self.model.generate(reflect_msg, temperature=temperature)