| import logging |
|
|
| from flask_login import current_user |
| from flask_restful import marshal, reqparse |
| from werkzeug.exceptions import Forbidden, InternalServerError, NotFound |
|
|
| import services.dataset_service |
| from controllers.console.app.error import ( |
| CompletionRequestError, |
| ProviderModelCurrentlyNotSupportError, |
| ProviderNotInitializeError, |
| ProviderQuotaExceededError, |
| ) |
| from controllers.console.datasets.error import DatasetNotInitializedError |
| from core.errors.error import ( |
| LLMBadRequestError, |
| ModelCurrentlyNotSupportError, |
| ProviderTokenNotInitError, |
| QuotaExceededError, |
| ) |
| from core.model_runtime.errors.invoke import InvokeError |
| from fields.hit_testing_fields import hit_testing_record_fields |
| from services.dataset_service import DatasetService |
| from services.hit_testing_service import HitTestingService |
|
|
|
|
class DatasetsHitTestingBase:
    """Static helpers for running a hit test (retrieval test) against a dataset.

    The helpers cover the four steps of a hit-testing request: loading and
    authorizing the dataset, validating the arguments, parsing the JSON
    request body, and running the retrieval while translating service-level
    errors into the controller-level exceptions imported by this module.
    """

    @staticmethod
    def get_and_validate_dataset(dataset_id: str):
        """Load a dataset and verify the current user may access it.

        Args:
            dataset_id: Identifier passed to ``DatasetService.get_dataset``.

        Returns:
            The dataset object returned by ``DatasetService.get_dataset``.

        Raises:
            NotFound: If no dataset exists for ``dataset_id``.
            Forbidden: If ``DatasetService.check_dataset_permission`` raises
                ``NoPermissionError`` for ``current_user``; the original
                message is reused as the description.
        """
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            # Chain explicitly so the permission error stays visible as the cause.
            raise Forbidden(str(e)) from e

        return dataset

    @staticmethod
    def hit_testing_args_check(args) -> None:
        """Validate hit-testing arguments.

        Delegates to ``HitTestingService.hit_testing_args_check``; any
        exception raised there propagates unchanged.

        Args:
            args: Parsed request arguments (e.g. the result of ``parse_args``).
        """
        HitTestingService.hit_testing_args_check(args)

    @staticmethod
    def parse_args():
        """Parse the hit-testing arguments from the JSON request body.

        Reads three fields from the JSON payload: ``query`` (str),
        ``retrieval_model`` (dict) and ``external_retrieval_model`` (dict).
        None of them is marked as required at the parser level.

        Returns:
            The value returned by ``RequestParser.parse_args()``.
        """
        parser = reqparse.RequestParser()

        parser.add_argument("query", type=str, location="json")
        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
        return parser.parse_args()

    @staticmethod
    def perform_hit_testing(dataset, args) -> dict:
        """Run a hit test on ``dataset`` and return the marshalled result.

        Calls ``HitTestingService.retrieve`` on behalf of ``current_user``
        with a fixed ``limit`` of 10, then marshals the returned records with
        ``hit_testing_record_fields``.

        Args:
            dataset: Dataset to query (e.g. from ``get_and_validate_dataset``).
            args: Mapping providing the ``"query"``, ``"retrieval_model"`` and
                ``"external_retrieval_model"`` keys.

        Returns:
            A dict with the keys ``"query"`` and ``"records"``.

        Raises:
            DatasetNotInitializedError: On ``IndexNotInitializedError``.
            ProviderNotInitializeError: On ``ProviderTokenNotInitError`` or
                ``LLMBadRequestError``.
            ProviderQuotaExceededError: On ``QuotaExceededError``.
            ProviderModelCurrentlyNotSupportError: On
                ``ModelCurrentlyNotSupportError``.
            CompletionRequestError: On ``InvokeError``.
            ValueError: Re-raised as a plain ``ValueError`` with the same
                message.
            InternalServerError: On any other exception, after logging it.

        Every translated exception is chained to the original one with
        ``raise ... from``, so the root cause is kept in tracebacks and logs.
        """
        # NOTE: the handler order below is significant if any of these
        # exception classes inherit from one another; keep it as is.
        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=args["query"],
                account=current_user,
                retrieval_model=args["retrieval_model"],
                external_retrieval_model=args["external_retrieval_model"],
                limit=10,
            )
            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
        except services.errors.index.IndexNotInitializedError as e:
            raise DatasetNotInitializedError() from e
        except ProviderTokenNotInitError as e:
            raise ProviderNotInitializeError(e.description) from e
        except QuotaExceededError as e:
            raise ProviderQuotaExceededError() from e
        except ModelCurrentlyNotSupportError as e:
            raise ProviderModelCurrentlyNotSupportError() from e
        except LLMBadRequestError as e:
            raise ProviderNotInitializeError(
                "No Embedding Model or Reranking Model available. Please configure a valid provider "
                "in the Settings -> Model Provider."
            ) from e
        except InvokeError as e:
            raise CompletionRequestError(e.description) from e
        except ValueError as e:
            # Handled separately so that ValueError does not fall through to
            # the generic handler below and become an InternalServerError.
            raise ValueError(str(e)) from e
        except Exception as e:
            # Top-level boundary: log the full traceback, then surface a 500.
            logging.exception("Hit testing failed.")
            raise InternalServerError(str(e)) from e
|
|