| import asyncio |
| import copy |
| import pdb |
|
|
| from factool.knowledge_qa.pipeline import knowledge_qa_pipeline |
| from factool.code.pipeline import code_pipeline |
| from factool.math.pipeline import math_pipeline |
| from factool.scientific.pipeline import scientific_pipeline |
|
|
class Factool():
    """Dispatch factuality-check samples to per-category pipelines.

    Each sample is a dict with at least a ``'category'``, ``'prompt'`` and
    ``'response'`` key. ``'code'`` samples additionally need
    ``'entry_point'``; ``'kbqa'`` samples may carry ``'search_type'``,
    ``'data_link'`` and ``'embedding_link'``.
    """

    def __init__(self, foundation_model):
        """Build the long-lived pipelines shared by every call.

        Args:
            foundation_model: Passed through unchanged to every pipeline
                constructor and kept for building local kbqa pipelines
                on demand in :meth:`run`.
        """
        self.foundation_model = foundation_model
        # NOTE(review): the positional constants (10, "online", 3, 3) are
        # forwarded verbatim; their meaning is defined by the pipeline
        # constructors, which are not visible from this file.
        self.pipelines = {
            "kbqa_online": knowledge_qa_pipeline(
                foundation_model, 10, "online"
            ),
            "code": code_pipeline(
                foundation_model, 3, 3
            ),
            "math": math_pipeline(
                foundation_model
            ),
            "scientific": scientific_pipeline(
                foundation_model
            ),
        }

    @staticmethod
    def _group_consecutive(samples, key):
        """Split *samples* into runs of neighbours sharing the same *key*.

        Order is preserved and no empty batch is ever produced, so the
        concatenation of the returned batches equals ``list(samples)``.
        """
        batches = []
        previous_key = None
        for sample in samples:
            sample_key = key(sample)
            if batches and sample_key == previous_key:
                batches[-1].append(sample)
            else:
                batches.append([sample])
                previous_key = sample_key
        return batches

    @staticmethod
    def _run_batch_key(sample):
        """Return the key deciding which samples :meth:`run` may batch.

        Non-kbqa samples batch by category alone. kbqa samples without a
        ``search_type`` are dispatched to the online pipeline, so they
        share a key with explicit ``"online"`` samples. Any other kbqa
        sample only batches with neighbours using the same search type,
        data link and embedding link, because one pipeline is built per
        batch from the first sample's links.
        """
        category = sample['category']
        if category != 'kbqa':
            return (category,)
        search_type = sample.get('search_type', None)
        if search_type is None or search_type == "online":
            return (category, "online")
        return (
            category,
            search_type,
            sample.get('data_link', None),
            sample.get('embedding_link', None),
        )

    @staticmethod
    def _merge_batch_results(outputs, start, batch, batch_results):
        """Merge *batch_results* into the outputs belonging to *batch*.

        Results are aligned to the batch's own slice of *outputs*, so a
        pipeline returning too few or too many results cannot shift the
        results of later batches onto the wrong samples.
        """
        batch_outputs = outputs[start:start + len(batch)]
        for output, result in zip(batch_outputs, batch_results):
            output.update(result)

    def run(self, inputs):
        """Check every sample synchronously and aggregate factuality.

        Consecutive compatible samples are sent to their pipeline as one
        batch; each batch is driven to completion with ``asyncio.run``.

        Args:
            inputs: List of sample dicts. It is not modified; results are
                merged into a deep copy.

        Returns:
            A dict with ``"average_claim_level_factuality"``,
            ``"average_response_level_factuality"`` and
            ``"detailed_information"`` (the copied samples updated with
            the pipeline results). Both averages are ``0.0`` when there
            is nothing to average over.

        Raises:
            KeyError: If a sample lacks a required key or names a
                category with no registered pipeline.
        """
        outputs = copy.deepcopy(inputs)

        start = 0
        for batch in self._group_consecutive(inputs, self._run_batch_key):
            first = batch[0]
            category = first['category']
            prompts = [sample['prompt'] for sample in batch]
            responses = [sample['response'] for sample in batch]

            if category == 'code':
                pending = self.pipelines[category].run_with_tool_api_call(
                    prompts,
                    responses,
                    [sample['entry_point'] for sample in batch],
                )
            elif category == 'kbqa':
                search_type = first.get('search_type', None)
                if search_type is None or search_type == "online":
                    pipeline = self.pipelines[category + "_online"]
                else:
                    # Local retrieval depends on per-sample links, so the
                    # pipeline is built for this batch rather than cached.
                    pipeline = knowledge_qa_pipeline(
                        self.foundation_model,
                        2,
                        "local",
                        first.get("data_link"),
                        first.get("embedding_link"),
                    )
                pending = pipeline.run_with_tool_api_call(prompts, responses)
            else:
                pending = self.pipelines[category].run_with_tool_api_call(
                    prompts, responses
                )

            batch_results = asyncio.run(pending)
            self._merge_batch_results(outputs, start, batch, batch_results)
            start += len(batch)

        total_response_factuality = sum(
            output['response_level_factuality'] for output in outputs
        )
        # Guard the division: an empty input list yields 0.0, not a crash.
        avg_response_level_factuality = (
            total_response_factuality / len(outputs) if outputs else 0.0
        )

        num_claims = 0
        total_claim_factuality = 0
        for output in outputs:
            category = output['category']
            if category in ('kbqa', 'scientific'):
                # One dict per claim, each carrying a 'factuality' value.
                claims = output['claim_level_factuality']
                num_claims += len(claims)
                total_claim_factuality += sum(
                    claim['factuality'] for claim in claims
                )
            elif category == 'code':
                # A code response counts as a single claim.
                num_claims += 1
                total_claim_factuality += output['claim_level_factuality']
            elif category == 'math':
                # A sequence of per-claim values that can be summed directly.
                claims = output['claim_level_factuality']
                num_claims += len(claims)
                total_claim_factuality += sum(claims)

        avg_claim_level_factuality = (
            total_claim_factuality / num_claims if num_claims else 0.0
        )

        return {
            "average_claim_level_factuality": avg_claim_level_factuality,
            "average_response_level_factuality": avg_response_level_factuality,
            "detailed_information": outputs,
        }

    async def run_for_plugin(self, inputs):
        """Check every sample from inside an already-running event loop.

        Samples are batched by consecutive category only. ``'kbqa'``
        samples always use the ``"kbqa_online"`` pipeline, the only kbqa
        pipeline registered in ``__init__``; ``'search_type'`` is not
        consulted here.

        Args:
            inputs: List of sample dicts. It is not modified.

        Returns:
            A deep copy of *inputs* with each sample updated by its
            pipeline result. No averages are computed.

        Raises:
            KeyError: If a sample lacks a required key or names a
                category with no registered pipeline.
        """
        outputs = copy.deepcopy(inputs)

        start = 0
        batches = self._group_consecutive(
            inputs, lambda sample: sample['category']
        )
        for batch in batches:
            category = batch[0]['category']
            prompts = [sample['prompt'] for sample in batch]
            responses = [sample['response'] for sample in batch]

            if category == 'code':
                batch_results = await self.pipelines[category].run_with_tool_api_call(
                    prompts,
                    responses,
                    [sample['entry_point'] for sample in batch],
                )
            else:
                # The kbqa pipeline is stored under "kbqa_online"; looking
                # it up by the bare category name would raise KeyError.
                pipeline_name = "kbqa_online" if category == 'kbqa' else category
                batch_results = await self.pipelines[pipeline_name].run_with_tool_api_call(
                    prompts,
                    responses,
                )

            self._merge_batch_results(outputs, start, batch, batch_results)
            start += len(batch)

        return outputs