download
raw
6.69 kB
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from response import TurboResponser
from dataset_all import get_datasets_by_name
from is_correct import test_output_comparison
from config import cfg
from prompt import code_system_prompt, test_prompt, code_prompt
from typing import List, Optional
import re
import json
import multiprocessing
import traceback
from execute_tool_linux import run_cpp_code_linux, run_func_code, run_python_code_linux, run_multiple_tests
def write_json_to_file(data, filepath):
    """Dump ``data`` to ``filepath`` as human-readable JSON.

    The target file is opened in write mode with UTF-8 encoding, so any
    existing content is replaced. Non-ASCII characters are written as-is
    (not ``\\uXXXX``-escaped) and nested structures are indented by four
    spaces.

    Args:
        data: Any JSON-serializable object.
        filepath: Path of the file to (over)write.
    """
    with open(filepath, mode='w', encoding='utf-8') as out_file:
        json.dump(obj=data, fp=out_file, ensure_ascii=False, indent=4)
import datetime
def write_log(message: str, log_file: str = "log-lcb.txt"):
    """
    Append one timestamped line to a log file.

    The line has the form ``[YYYY-MM-DD HH:MM:SS] <message>`` followed by a
    newline; the timestamp is the local time at the moment of the call.

    Args:
        message (str): Text to record.
        log_file (str): Path of the log file, opened in append mode with
            UTF-8 encoding (default is 'log-lcb.txt').

    Returns:
        None
    """
    stamp = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    line = f"{stamp} {message}\n"
    with open(log_file, "a", encoding="utf-8") as log_handle:
        log_handle.write(line)
def extract_jsonl_from_markdown(markdown_text):
    """Return the first parseable JSON payload fenced as ```json in the text.

    Every span between a ```json marker and the next ``` marker is tried in
    document order; spans that are not valid JSON are skipped. Parsing stops
    at the first span that decodes successfully.

    Args:
        markdown_text: Markdown string that may contain ```json fences.

    Returns:
        The decoded Python object of the first valid fenced JSON block.

    Raises:
        IndexError: If the text contains no fenced block that parses as
            JSON. The exception type matches what indexing an empty result
            list used to raise, but now carries a descriptive message.
    """
    # Scan lazily so later blocks are never parsed once a valid one is found.
    for fenced in re.finditer(r'```json(.*?)```', markdown_text, re.DOTALL):
        try:
            return json.loads(fenced.group(1).strip())
        except json.JSONDecodeError:
            # Malformed block: fall through to the next candidate.
            continue
    raise IndexError("no parseable ```json block found in markdown text")
def extract_code(ans_str):
    """Collect the bodies of all ```python fenced blocks in ``ans_str``.

    A block must open with ```python immediately followed by a newline; the
    returned text is everything up to (not including) the closing fence.

    Returns:
        A list of code strings in document order; empty when there is no
        matching fence.
    """
    python_fence = re.compile(r'```python\n(.*?)```', re.DOTALL)
    return [hit.group(1) for hit in python_fence.finditer(ans_str)]
def extract_json(ans_str):
    """Return the raw text of the last ```json fenced block in ``ans_str``.

    A block must open with ```json immediately followed by a newline. The
    text is returned undecoded, exactly as it appears before the closing
    fence.

    Raises:
        IndexError: If ``ans_str`` contains no such block.
    """
    fenced_payloads = [
        hit.group(1)
        for hit in re.finditer(r'```json\n(.*?)```', ans_str, re.DOTALL)
    ]
    return fenced_payloads[-1]
import argparse
# SECURITY(review): API keys are hard-coded in source. Anyone with read access
# to this file can use them; consider loading them from the environment or a
# secrets store and rotating these values.
keys = [
    "ak-8f3d147b2c9a5e6m0n4p8x2v7y1k3l9",
    "ak-58d7efgh23i4jkl67mno89pqrs01tuv6k5",
    "ak-63d1efgh47i8jkl26mno95pqrs34tuv7x2",
    "ak-3f8a2c9e1b7d4f6h5j2k8m3n9p4r6t7",
]

# The batch index is currently fixed. To take it from the command line
# instead, re-enable the following lines:
# parser = argparse.ArgumentParser(description="Accept one command-line argument")
# parser.add_argument("param1", type=str, help="The first argument")
# args = parser.parse_args()
# batch = int(args.param1)
batch = 0

# Three consecutive batches share one key (index = batch // 3).
api_key = keys[batch // 3]
log_file = cfg.log_file.format(batch)
responser = TurboResponser(api_key=api_key, api_base=cfg.api_base, model=cfg.model_name)

# Each batch covers a window of 100 consecutive dataset items; the last
# window is clamped to the dataset length.
al_dataset = get_datasets_by_name(cfg.dataset_name)
start_pos = batch * 100
end_pos = min((batch + 1) * 100, len(al_dataset))
al_dataset = al_dataset[start_pos:end_pos]
# NOTE(review): only the first two items of the window are kept, which looks
# like a debugging limit -- confirm before a full run.
al_dataset = al_dataset[0:2]

count = 0
testcases = {}
generators = {}
testcases_pass_rate = {}
# Main pipeline. For every dataset item:
#   1. ask the model for test cases (up to 3 attempts),
#   2. ask the model for candidate solutions (up to 3 attempts),
#   3. run each candidate solution against the generated tests,
#   4. keep only the generated tests whose expected output matches what one of
#      the item's reference solutions actually prints,
#   5. record how many tests were generated vs. kept.
for item in al_dataset:
    query = item['query_en']
    user_prompt_test = test_prompt.replace("{num_of_test}", str(cfg.gen_test_nums)) + query
    user_prompt_code = code_prompt.format(cfg.pass_k) + query
    solutions_gen = None
    testcase_list = []

    # Two-branch method: one branch generates k candidate codes, the other
    # generates test cases. Generation can fail, so each branch is tried at
    # most 3 times. No few-shot examples are given.
    attempt = 0
    while attempt < 3 and len(testcase_list) < cfg.gen_test_nums:
        attempt += 1
        try:
            test_res = responser.respond(
                code_system_prompt,
                user_prompt=user_prompt_test)
            if '```json' in test_res:
                testcase_list += list(extract_jsonl_from_markdown(test_res))
            else:
                testcase_list += list(json.loads(test_res))
        except Exception as e:
            write_log(f"error request{e}", log_file)

    attempt = 0
    # Retry while nothing usable was extracted: an empty list is as much a
    # failure as an exception.
    while attempt < 3 and not solutions_gen:
        attempt += 1
        # Reset per attempt so the except handler never reads an unbound or
        # stale response.
        code_res = None
        try:
            code_res = responser.respond(
                code_system_prompt,
                user_prompt=user_prompt_code)
            solutions_gen = extract_code(code_res)
        except Exception as e:
            write_log(f"fail to extrack {e} \n{code_res}", log_file)

    # Run every generated solution against the generated tests, 10 tests per
    # call. Each row of `matrix` is one solution's pass/fail vector.
    # NOTE(review): `matrix` is built but not consumed below; the subset
    # selection described for this method (largest |tests| * |solutions|
    # agreeing subset) is not implemented here.
    batch_size = 10
    matrix = []
    for code in solutions_gen or []:  # all attempts may have failed
        exe_result = []
        for i in range(0, len(testcase_list), batch_size):
            # Must not be named `batch`: that is the module-level batch index
            # used to build the output file names below.
            test_batch = testcase_list[i:i + batch_size]
            batch_inputs = [test['input'] for test in test_batch]
            batch_outputs = [test['output'] for test in test_batch]
            results = run_multiple_tests(code, batch_inputs, time_limit=3)
            exe_result += [expected == actual for expected, actual in zip(batch_outputs, results)]
        matrix.append(exe_result)

    # Validate the generated tests: a test is kept as soon as one reference
    # solution reproduces its expected output.
    problem_id = item['problem_id']
    testcases[problem_id] = []
    for test in testcase_list:
        input_str = test["input"]
        output_str = test["output"]
        try:
            for solution in item["solutions"]:
                try:
                    execute_res = run_cpp_code_linux(solution, input_str, time_limit=3)
                    print(f"res: {execute_res}")
                except Exception:
                    write_log(f"{problem_id} solution execute fail", log_file)
                    continue
                if "stdout" in execute_res.keys() and test_output_comparison(output_str, execute_res["stdout"]):
                    testcases[problem_id].append(test)
                    break
        except Exception as e:
            write_log(f"{problem_id} fail to generate gloden-code output {e}", log_file)

    total_gen_count = len(testcase_list)
    success_gen_count = len(testcases[problem_id])
    testcases_pass_rate[problem_id] = {
        "gen_nums": total_gen_count,
        "set_nums": success_gen_count
    }

    # Count processed items so each 20-problem chunk gets its own file index
    # instead of overwriting chunk 0.
    count += 1
    if len(testcases) % 20 == 0:
        write_json_to_file(testcases, cfg.file_crux.format('ours', batch, count // 20))
        testcases = {}

# Flush whatever is left after the last full chunk, then the pass-rate summary.
# NOTE(review): this call passes two format arguments while the chunk write
# above passes three -- confirm against the cfg.file_crux template.
write_json_to_file(testcases, cfg.file_crux.format('ours', f"final{batch}"))
write_json_to_file(testcases_pass_rate, cfg.pass_rate_file.format('ours', batch))

Xet Storage Details

Size:
6.69 kB
·
Xet hash:
020475a0345f5f3f023df49e4a54e7a3e060f259c7352e6b587a6c6b8ab8893f

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.