| from typing import Literal
|
| from models import *
|
| from utils import *
|
| from modules import *
|
| from construct import *
|
|
|
|
|
class Pipeline:
    """Run the schema, extraction and reflection agents over one input.

    All three agents share the same LLM and a single ``CaseRepositoryHandler``.
    """

    def __init__(self, llm: BaseEngine):
        """Create the agents, wiring them to ``llm`` and a shared case repository."""
        self.llm = llm
        self.case_repo = CaseRepositoryHandler(llm = llm)
        self.schema_agent = SchemaAgent(llm = llm)
        self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
        self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)

    def __check_consistency(self, llm, task, mode, update_case):
        """Coerce ``mode``/``update_case`` to what ``llm`` supports.

        When ``llm.name`` is "OneKE", only quick extraction without case
        update is allowed, and "Base"/"Triple" tasks are rejected.

        Returns:
            The possibly overridden ``(mode, update_case)`` pair.

        Raises:
            ValueError: if OneKE is asked to run a "Base" or "Triple" task.
        """
        if llm.name != "OneKE":
            return mode, update_case
        if task in ("Base", "Triple"):
            raise ValueError("The finetuned OneKE only supports quick extraction mode for NER, RE and EE Task.")
        print("The fine-tuned OneKE defaults to quick extraction mode without case update.")
        return "quick", False

    def __init_method(self, data: DataPoint, process_method2):
        """Fill in default agent methods and order them schema -> extraction -> reflection.

        Works on a copy of ``process_method2`` so neither the caller's dict
        nor a shared default/config dict is mutated.

        Returns:
            A new ``{agent_name: method_name}`` dict containing only the known
            agents, in execution order.
        """
        methods = dict(process_method2)
        if "schema_agent" not in methods:
            # "Base" tasks use the default schema; other tasks retrieve one.
            methods["schema_agent"] = "get_retrieved_schema" if data.task != "Base" else "get_default_schema"
        methods.setdefault("extraction_agent", "extract_information_direct")
        default_order = ("schema_agent", "extraction_agent", "reflection_agent")
        return {key: methods[key] for key in default_order if key in methods}

    def __init_data(self, data: DataPoint):
        """Set the task-specific default instruction and output schema on ``data``.

        For NER/RE/EE/Triple, ``data.instruction`` is overwritten with the
        matching ``config['agent']['default_*']`` entry. ``data.output_schema``
        is set to the corresponding list type name. Other tasks are left as-is.
        """
        if data.task == "NER":
            data.instruction = config['agent']['default_ner']
            data.output_schema = "EntityList"
        elif data.task == "RE":
            data.instruction = config['agent']['default_re']
            data.output_schema = "RelationList"
        elif data.task == "EE":
            data.instruction = config['agent']['default_ee']
            data.output_schema = "EventList"
        elif data.task == "Triple":
            data.instruction = config['agent']['default_triple']
            data.output_schema = "TripleList"
        return data

    def __resolve_process_method(self, mode, three_agents, isgui):
        """Return a fresh ``{agent_name: method_name}`` mapping for ``mode``.

        Resolution order:
        - GUI "customized" mode uses ``three_agents``.
        - A dict ``mode`` is used directly.
        - Otherwise ``mode`` must name a preset in ``config['agent']['mode']``.

        Raises:
            ValueError: if ``mode`` is neither a dict nor a known preset name.
        """
        if isgui and mode == "customized":
            print("Customized 3-Agents: ", three_agents)
            return dict(three_agents)
        if isinstance(mode, dict):
            return dict(mode)
        presets = config['agent']['mode']
        if mode in presets.keys():
            return presets[mode].copy()
        raise ValueError(f"Unknown extraction mode: {mode!r}. Expected one of {list(presets.keys())} or a dict of agent methods.")

    def __run_agents(self, data, sorted_process_method):
        """Call each agent's configured method on ``data`` in order, then summarize.

        Returns:
            ``(data, frontend_schema)``. ``frontend_schema`` is the first
            non-empty ``data.print_schema`` seen, or ``""`` if none.

        Raises:
            AttributeError: if an agent or one of its methods does not exist.
        """
        frontend_schema = ""
        schema_printed = False
        for agent_name, method_name in sorted_process_method.items():
            agent = getattr(self, agent_name, None)
            if agent is None:
                raise AttributeError(f"{agent_name} does not exist.")
            method = getattr(agent, method_name, None)
            if method is None:
                raise AttributeError(f"Method '{method_name}' not found in {agent_name}.")
            data = method(data)
            # Print/record the schema only the first time an agent produces one.
            if not schema_printed and data.print_schema:
                print("Schema: \n", data.print_schema)
                frontend_schema = data.print_schema
                schema_printed = True
        data = self.extraction_agent.summarize_answer(data)
        return data, frontend_schema

    def __construct_kg(self, construct, extraction_result):
        """Turn ``extraction_result`` into Cypher statements and execute them.

        ``construct`` must provide 'url', 'username', 'password' and
        'database'. A missing key raises ``KeyError``.
        """
        myurl = construct['url']
        myusername = construct['username']
        mypassword = construct['password']
        print(f"Construct KG in your {construct['database']} now...")
        cypher_statements = generate_cypher_statements(extraction_result)
        execute_cypher_statements(uri=myurl, user=myusername, password=mypassword, cypher_statements=cypher_statements)

    def __update_case_repository(self, data):
        """Store ``data`` in the case repository, asking for a reference answer if needed.

        When ``data.truth`` is empty, the user is prompted on stdin:
        - An empty reply accepts ``data.pred`` as the truth.
        - Any other reply is parsed with ``extract_json_dict``.
        """
        if data.truth == "":
            answer = input("Please enter the correct answer you prefer, or just press Enter to accept the current answer: ")
            if answer.strip() == "":
                data.truth = data.pred
            else:
                data.truth = extract_json_dict(answer)
        self.case_repo.update_case(data)

    def get_extract_result(self,
                           task: TaskType,
                           three_agents=None,
                           construct=None,
                           instruction: str = "",
                           text: str = "",
                           output_schema: str = "",
                           constraint: str = "",
                           use_file: bool = False,
                           file_path: str = "",
                           truth: str = "",
                           mode: str = "quick",
                           update_case: bool = False,
                           show_trajectory: bool = False,
                           isgui: bool = False,
                           iskg: bool = False,
                           ):
        """Run the full extraction pipeline for one input.

        Args:
            task: Task type. It drives the default instruction/output schema
                (NER/RE/EE/Triple) and the schema-agent method ("Base" uses
                the default schema).
            three_agents: ``{agent_name: method_name}`` mapping. It is used only
                when ``isgui`` is true and ``mode == "customized"``.
                ``None`` means an empty mapping.
            construct: Graph-database settings ('url', 'username', 'password',
                'database'). They are read only when ``iskg`` is true.
            instruction, text, output_schema, constraint, use_file, file_path,
                truth: Forwarded to ``DataPoint``. ``instruction`` and
                ``output_schema`` are overridden for NER/RE/EE/Triple tasks.
            mode: Preset name from ``config['agent']['mode']``, or a dict
                mapping agent names to method names.
            update_case: Store the result in the case repository afterwards.
                This prompts on stdin if ``truth`` is empty.
            show_trajectory: Print the extraction trajectory as JSON.
            isgui: Enables the GUI "customized" mode.
            iskg: Build a knowledge graph from the extraction result.

        Returns:
            ``(result, trajectory, frontend_schema, frontend_res)``.

        Raises:
            ValueError: for an unsupported OneKE task or an unknown ``mode``.
            AttributeError: if a configured agent method does not exist.
        """
        # None defaults avoid sharing one mutable dict across calls.
        three_agents = {} if three_agents is None else three_agents
        construct = {} if construct is None else construct

        mode, update_case = self.__check_consistency(self.llm, task, mode, update_case)

        data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, file_path=file_path, truth=truth)
        data = self.__init_data(data)

        process_method = self.__resolve_process_method(mode, three_agents, isgui)
        sorted_process_method = self.__init_method(data, process_method)
        print("Process Method: ", sorted_process_method)

        data, frontend_schema = self.__run_agents(data, sorted_process_method)

        if show_trajectory:
            print("Extraction Trajectory: \n", json.dumps(data.get_result_trajectory(), indent=2))
        extraction_result = json.dumps(data.pred, indent=2)
        print("Extraction Result: \n", extraction_result)

        if iskg:
            self.__construct_kg(construct, extraction_result)

        # Captured before the case update, matching the original ordering.
        frontend_res = data.pred

        if update_case:
            self.__update_case_repository(data)

        result = data.pred
        trajectory = data.get_result_trajectory()
        return result, trajectory, frontend_schema, frontend_res
|
|
|