| import sys | |
| sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils") | |
| from response import TurboResponser | |
| from dataset_all import get_datasets_by_name | |
| from is_correct import test_output_comparison | |
| from config import cfg | |
| from prompt import code_system_prompt, test_prompt, code_prompt | |
| from typing import List, Optional | |
| import re | |
| import json | |
| import multiprocessing | |
| import traceback | |
| from execute_tool_linux import run_cpp_code_linux, run_func_code, run_python_code_linux, run_multiple_tests | |
def write_json_to_file(data, filepath):
    """Serialize *data* as pretty-printed, UTF-8 JSON and write it to *filepath*."""
    serialized = json.dumps(data, ensure_ascii=False, indent=4)
    with open(filepath, 'w', encoding='utf-8') as sink:
        sink.write(serialized)
| import datetime | |
def write_log(message: str, log_file: str = "log-lcb.txt") -> None:
    """Append *message* to *log_file*, prefixed with a [YYYY-mm-dd HH:MM:SS] timestamp.

    Args:
        message: The text to record.
        log_file: Path of the log file to append to.
    """
    stamped = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]") + " " + message
    with open(log_file, "a", encoding="utf-8") as handle:
        # print() supplies the trailing newline.
        print(stamped, file=handle)
def extract_jsonl_from_markdown(markdown_text):
    """Return the first parsable JSON payload found in a ```json fenced block.

    Every ```json ... ``` block in *markdown_text* is tried in order; blocks
    that fail to parse are skipped.

    Args:
        markdown_text: Markdown text possibly containing fenced JSON blocks.

    Returns:
        The parsed payload (list or dict) of the first valid block.

    Raises:
        ValueError: if no fenced block contains valid JSON.
            (BUGFIX: this previously surfaced as an opaque IndexError from
            ``json_data[0]`` on an empty result list.)
    """
    json_pattern = re.compile(r'```json(.*?)```', re.DOTALL)
    for candidate in json_pattern.findall(markdown_text):
        try:
            # Return immediately instead of collecting every block: the
            # original code parsed all blocks but only ever used the first.
            return json.loads(candidate.strip())
        except json.JSONDecodeError:
            continue
    raise ValueError("no valid ```json``` block found in markdown text")
def extract_code(ans_str):
    """Return the contents of every ```python fenced block in *ans_str*.

    Args:
        ans_str: Model response text possibly containing fenced Python code.

    Returns:
        A list of code strings (empty if no block is present).
    """
    pattern = r'```python\n(.*?)```'
    # findall already returns a list; the original wrapped it in a pointless
    # identity comprehension.
    return re.findall(pattern, ans_str, re.DOTALL)
def extract_json(ans_str):
    """Return the raw contents of the last ```json fenced block in *ans_str*.

    Raises IndexError when *ans_str* contains no such block.
    """
    blocks = re.findall(r'```json\n(.*?)```', ans_str, re.DOTALL)
    return blocks[-1]
| import argparse | |
# Hard-coded API keys, one shared by every three consecutive batches.
# SECURITY NOTE(review): credentials should come from the environment or a
# secrets manager, not source control.
keys = [
    "ak-8f3d147b2c9a5e6m0n4p8x2v7y1k3l9",
    "ak-58d7efgh23i4jkl67mno89pqrs01tuv6k5",
    "ak-63d1efgh47i8jkl26mno95pqrs34tuv7x2",
    "ak-3f8a2c9e1b7d4f6h5j2k8m3n9p4r6t7",
]
# The batch index can alternatively be taken from the command line:
#   parser = argparse.ArgumentParser(description="accepts one CLI argument")
#   parser.add_argument("param1", type=str, help="batch index")
#   batch = int(parser.parse_args().param1)
batch = 0
api_key = keys[batch // 3]
# BUGFIX: log_file was assigned twice with the identical expression;
# the duplicate assignment has been removed.
log_file = cfg.log_file.format(batch)
responser = TurboResponser(api_key=api_key, api_base=cfg.api_base, model=cfg.model_name)
# Each batch handles a 100-item slice of the dataset.
al_dataset = get_datasets_by_name(cfg.dataset_name)
start_pos = batch * 100
end_pos = min((batch + 1) * 100, len(al_dataset))
al_dataset = al_dataset[start_pos:end_pos]
# NOTE(review): debug truncation to the first 2 items — remove for a full run.
al_dataset = al_dataset[0:2]
count = 0
testcases = {}
generators = {}
testcases_pass_rate = {}
for item in al_dataset:
    query = item['query_en']
    user_prompt_test = test_prompt.replace("{num_of_test}", str(cfg.gen_test_nums)) + query
    user_prompt_code = code_prompt.format(cfg.pass_k) + query
    solutions_gen = None
    # BUGFIX: code_res must exist before the retry loop below; otherwise a
    # respond() failure on the first attempt made the except-branch log raise
    # NameError instead of recording the real error.
    code_res = None
    testcase_list = []
    # Dual-branch approach: independently generate k candidate solutions and a
    # set of test cases, retrying each request up to 3 times on failure.
    attempt = 0
    while attempt < 3 and len(testcase_list) < cfg.gen_test_nums:
        attempt += 1
        try:
            # Generate test cases; no few-shot examples are supplied.
            test_res = responser.respond(
                code_system_prompt,
                user_prompt=user_prompt_test)
            if '```json' in test_res:
                testcase_list += list(extract_jsonl_from_markdown(test_res))
            else:
                testcase_list += list(json.loads(test_res))
        except Exception as e:
            write_log(f"error request{e}", log_file)
            continue
    attempt = 0
    while attempt < 3 and solutions_gen is None:
        attempt += 1
        try:
            code_res = responser.respond(
                code_system_prompt,
                user_prompt=user_prompt_code)
            solutions_gen = extract_code(code_res)
        except Exception as e:
            write_log(f"fail to extrack {e} \n{code_res}", log_file)
            continue
    # Run every candidate solution against the generated test cases,
    # batch_size inputs at a time.
    batch_size = 10
    matrix = []
    for code in solutions_gen:
        exe_result = []
        for i in range(0, len(testcase_list), batch_size):
            # BUGFIX: this chunk was previously named `batch`, clobbering the
            # module-level batch index used in the output filenames below.
            # (Slicing already clamps at the end of the list, so the old
            # conditional was redundant.)
            chunk = testcase_list[i:i + batch_size]
            chunk_inputs = [test['input'] for test in chunk]
            chunk_outputs = [test['output'] for test in chunk]
            results = run_multiple_tests(code, chunk_inputs, time_limit=3)
            # NOTE(review): raw string equality here, while the golden-solution
            # check below uses test_output_comparison — confirm this is intended.
            exe_result += [expected == actual for expected, actual in zip(chunk_outputs, results)]
        matrix.append(exe_result)
    # Validate each generated test case against the reference solutions:
    # keep it as soon as any reference solution reproduces its expected output.
    testcases[item['problem_id']] = []
    for test in testcase_list:
        input_str = test["input"]
        output_str = test["output"]
        try:
            for solution in item["solutions"]:
                try:
                    execute_res = run_cpp_code_linux(solution, input_str, time_limit=3)
                    print(f"res: {execute_res}")
                except Exception as e:
                    write_log(f"{item['problem_id']} solution execute fail", log_file)
                    continue
                if "stdout" in execute_res.keys() and test_output_comparison(output_str, execute_res["stdout"]):
                    testcases[item['problem_id']].append(
                        test
                    )
                    break
        except Exception as e:
            write_log(f"{item['problem_id']} fail to generate gloden-code output {e}", log_file)
    testcases_pass_rate[item['problem_id']] = {
        "gen_nums": len(testcase_list),
        "set_nums": len(testcases[item['problem_id']]),
    }
    # BUGFIX: count was never incremented, so the checkpoint index
    # (count // 20) stayed at 0 and every checkpoint dump overwrote the same
    # file while `testcases` was reset — losing earlier results.
    count += 1
    # Checkpoint every 20 problems, then reset the in-memory buffer.
    if len(testcases.keys()) % 20 == 0:
        write_json_to_file(testcases, cfg.file_crux.format('ours', batch, count // 20))
        testcases = {}
# Final flush of any remaining results plus the per-problem pass-rate stats.
# NOTE(review): file_crux is formatted with 2 args here but 3 in the checkpoint
# above — confirm the template's placeholder count.
write_json_to_file(testcases, cfg.file_crux.format('ours', f"final{batch}"))
write_json_to_file(testcases_pass_rate, cfg.pass_rate_file.format('ours', batch))
Xet Storage Details
- Size:
- 6.69 kB
- Xet hash:
- 020475a0345f5f3f023df49e4a54e7a3e060f259c7352e6b587a6c6b8ab8893f
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.