| import click |
| from agents.GaAgent import GaAgent |
| from models.OpenAI import OpenAIModel |
| from models.Gemini import GeminiModel |
| from models.Claude import ClaudeModel |
| from models.Vllm import VLLMModel |
| from dataloaders.TritonBench import TritonBench |
| from args_config import load_config |
|
|
|
|
@click.command()
@click.option('-c', '--config', default='',
              help='Path to the configuration file passed to load_config.')
def main(config):
    """Run the GaAgent pipeline on TritonBench using a vLLM-served model.

    The configuration named by ``config`` is loaded with ``load_config``;
    its attributes supply every setting used to build the model, the
    dataset, and the agent, and to launch the agent run.

    Args:
        config: Value of the ``-c/--config`` option (defaults to ``''``),
            forwarded unchanged to ``load_config``.
    """
    args = load_config(config)

    # The original code passed ``args.model_id`` as ``base_url``, so the model
    # identifier was used as the server endpoint. Prefer a dedicated
    # ``base_url`` setting when the config provides a non-empty one, and fall
    # back to the previous value so existing configs keep working.
    # NOTE(review): confirm the config schema exposes ``base_url``.
    base_url = getattr(args, "base_url", None) or args.model_id

    # LLM backend used by the agent.
    model = VLLMModel(
        api_key=args.api_key,
        model_id=args.model_id,
        base_url=base_url,
    )

    # Benchmark dataset; keyword names (including ``statis_path``) follow the
    # TritonBench constructor as called in the original code.
    dataset = TritonBench(
        statis_path=args.statis_path,
        py_folder=args.py_folder,
        instruction_path=args.instruction_path,
        py_interpreter=args.py_interpreter,
        golden_metrics=args.golden_metrics,
        perf_ref_folder=args.perf_ref_folder,
        perf_G_path=args.perf_G_path,
        result_path=args.result_path,
    )

    # Agent wired to the model and dataset above.
    agent = GaAgent(
        model=model,
        dataset=dataset,
        corpus_path=args.corpus_path,
        mem_file=args.mem_file,
        descendant_num=args.descendant_num,
    )

    # Launch the run; ``args`` is also passed whole, as in the original call.
    agent.run(
        output_path=args.output_path,
        multi_thread=args.multi_thread,
        iteration_num=args.max_iteration,
        temperature=args.temperature,
        datalen=args.datalen,
        gpu_id=args.gpu_id,
        start_iter=args.start_iter,
        ancestor_num=args.ancestor_num,
        descendant_num=args.descendant_num,
        descendant_debug=args.descendant_debug,
        target_gpu=args.target_gpu,
        profiling=args.profiling,
        start_idx=args.start_idx,
        args=args,
    )
|
|
|
|
# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|