| from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor |
| from core.helper.encrypter import decrypt_token, encrypt_token |
| from extensions.ext_database import db |
| from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint |
|
|
|
|
class APIBasedExtensionService:
    """Query, validate and persist tenant-scoped ``APIBasedExtension`` records.

    API keys are written to the database encrypted with ``encrypt_token`` and are
    decrypted with ``decrypt_token`` onto the instances returned by the read methods.
    """

    @staticmethod
    def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
        """Return every extension belonging to ``tenant_id``, newest ``created_at`` first.

        Each returned instance has its ``api_key`` replaced by the decrypted value.

        :param tenant_id: owning tenant.
        :return: the (possibly empty) list of extensions.
        """
        extensions = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=tenant_id)
            .order_by(APIBasedExtension.created_at.desc())
            .all()
        )

        # NOTE(review): these instances were loaded through db.session and are
        # modified in place; if the same session is committed afterwards the
        # plaintext key may be flushed back to the database — confirm callers
        # never commit after reading.
        for extension in extensions:
            extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

        return extensions

    @classmethod
    def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
        """Validate, encrypt and persist ``extension_data``, then commit the session.

        ``extension_data.api_key`` must hold the plaintext key on entry. After a
        successful call it holds the encrypted value, including on the returned
        instance.

        :param extension_data: a new extension (falsy ``id``) or an existing one.
        :return: the same instance, added to the session and committed.
        :raises ValueError: if any validation check or the connectivity ping fails.
        """
        cls._validation(extension_data)

        # Encrypt only after validation: the ping inside _validation needs the plaintext key.
        extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)

        db.session.add(extension_data)
        db.session.commit()
        return extension_data

    @staticmethod
    def delete(extension_data: APIBasedExtension) -> None:
        """Delete ``extension_data`` from the database and commit the session."""
        db.session.delete(extension_data)
        db.session.commit()

    @staticmethod
    def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
        """Fetch one extension by id, scoped to ``tenant_id``, with its key decrypted.

        :param tenant_id: owning tenant; an id belonging to another tenant is not found.
        :param api_based_extension_id: primary key of the extension.
        :return: the matching extension with ``api_key`` decrypted.
        :raises ValueError: if no extension matches both ``tenant_id`` and the id.
        """
        extension = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=tenant_id)
            .filter_by(id=api_based_extension_id)
            .first()
        )

        if not extension:
            raise ValueError("API based extension is not found")

        # NOTE(review): same in-place decryption on a session-attached instance as
        # get_all_by_tenant_id — a later commit may persist the plaintext key.
        extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

        return extension

    @classmethod
    def _validation(cls, extension_data: APIBasedExtension) -> None:
        """Check required fields and per-tenant name uniqueness, then ping the endpoint.

        Cheap local checks run first; the network ping runs only if they all pass.

        :raises ValueError: on the first failing check.
        """
        if not extension_data.name:
            raise ValueError("name must not be empty")

        # Names are unique per tenant. When updating an existing record (truthy id),
        # exclude that record itself so that keeping its current name is allowed.
        duplicate_query = db.session.query(APIBasedExtension).filter_by(
            tenant_id=extension_data.tenant_id,
            name=extension_data.name,
        )
        if extension_data.id:
            duplicate_query = duplicate_query.filter(APIBasedExtension.id != extension_data.id)
        if duplicate_query.first() is not None:
            raise ValueError("name must be unique, it is already existed")

        if not extension_data.api_endpoint:
            raise ValueError("api_endpoint must not be empty")

        if not extension_data.api_key:
            raise ValueError("api_key must not be empty")

        if len(extension_data.api_key) < 5:
            raise ValueError("api_key must be at least 5 characters")

        cls._ping_connection(extension_data)

    @staticmethod
    def _ping_connection(extension_data: APIBasedExtension) -> None:
        """Send a PING to the extension endpoint and require a ``"pong"`` result.

        :raises ValueError: ``"connection error: ..."`` if the request fails or the
            response's ``result`` is not ``"pong"``. The underlying exception is
            chained as ``__cause__`` so its traceback is preserved.
        """
        try:
            client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
            resp = client.request(point=APIBasedExtensionPoint.PING, params={})
            if resp.get("result") != "pong":
                # Raised inside the try on purpose, so that an unexpected response is
                # reported with the same "connection error: ..." prefix as a transport failure.
                raise ValueError(resp)
        except Exception as e:
            raise ValueError(f"connection error: {e}") from e
|
|