# Spaces: Paused  (page-header residue captured together with the source)
| from asyncio import sleep | |
| from base64 import b64decode | |
| from binascii import Error as BinasciiError | |
| from contextlib import asynccontextmanager | |
| from io import BytesIO | |
| from json import dumps, loads | |
| from logging import Formatter, INFO, StreamHandler, getLogger | |
| from pathlib import Path | |
| from random import choice | |
| from typing import AsyncGenerator | |
| from PIL.Image import open as image_open | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from httpx import AsyncClient | |
| from patchright.async_api import FilePayload, Request as PlaywrightRequest, async_playwright | |
| from prlps_fakeua import UserAgent | |
| from starlette.responses import Response | |
# Service logger: a single stream handler, both logger and handler at INFO,
# so the logger.debug(...) calls elsewhere in this file are not emitted.
logger = getLogger('RHYMES_AI_API')
logger.setLevel(INFO)
handler = StreamHandler()
handler.setLevel(INFO)
formatter = Formatter(fmt='%(asctime)s | %(levelname)s : %(message)s', datefmt='%d.%m.%Y %H:%M:%S')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.info('инициализация приложения...')

# Source of the `user_agent` passed to the browser context in browser_request.
# NOTE(review): presumably `os=[...]` restricts generated agents to these
# platforms — confirm against prlps_fakeua.
ua = UserAgent(os=['windows', 'mac'])

# Directory containing this file; the JSON file below lives next to it.
workdir = Path(__file__).parent

# JSON file written by take_infer_data and read by get_infer_data.
infer_data = workdir.joinpath('infer_data.json')

# Base URL used to build the Gradio API endpoints.
BASE_URL = 'https://akhaliq-anychat.hf.space'
def base64_to_jpeg_bytes(base64_str: str) -> bytes:
    """Re-encode a ``<prefix>,<base64 payload>`` image string as JPEG bytes.

    Everything after the first comma is base64-decoded, opened as an image,
    converted to RGB and saved as JPEG (quality 90, optimized).

    Args:
        base64_str: String containing a comma followed by the base64 payload.

    Returns:
        The JPEG-encoded image.

    Raises:
        ValueError: If the string has no comma, or if the payload cannot be
            base64-decoded or opened as an image.
    """
    try:
        _prefix, comma, payload = base64_str.partition(',')
        if not comma:
            raise ValueError("недопустимый формат строки base64")
        raw = b64decode(payload)
        with image_open(BytesIO(raw)) as img, BytesIO() as jpeg_bytes:
            rgb_image = img.convert('RGB')
            rgb_image.save(jpeg_bytes, format='JPEG', quality=90, optimize=True)
            return jpeg_bytes.getvalue()
    except (BinasciiError, OSError) as e:
        raise ValueError('данные не являются корректным изображением') from e
def image_bytes(base64_image_str: str) -> FilePayload:
    """Wrap a base64 image string as a ``FilePayload`` holding JPEG bytes.

    The payload gets a random 12-character name with a ``.jpeg`` suffix and
    the ``image/jpeg`` MIME type.

    Raises:
        ValueError: Propagated from ``base64_to_jpeg_bytes`` for bad input.
    """
    # The name is generated before the conversion, as in the keyword order of
    # the payload itself.
    file_name = generate_random_string(12) + '.jpeg'
    jpeg_payload = base64_to_jpeg_bytes(base64_image_str)
    return FilePayload(name=file_name, mimeType='image/jpeg', buffer=jpeg_payload)
def generate_random_string(length):
    """Return ``length`` random characters from lowercase ASCII letters and digits.

    Uses ``random.choice``, so the result is not suitable as a secret.
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789'
    picked = []
    for _ in range(length):
        picked.append(choice(alphabet))
    return ''.join(picked)
def get_infer_data() -> tuple[int, int, str]:
    """Load the queue parameters stored in the ``infer_data`` JSON file.

    Returns:
        ``(fn_index, trigger_id, session_hash)`` as found in the file.

    Raises:
        OSError: If the file cannot be read.
        ValueError: If the file does not contain valid JSON.
        KeyError: If one of the three keys is missing.
    """
    raw_text = infer_data.read_text()
    data = loads(raw_text)
    logger.debug(f'загруженные из файла данные `get_infer_data`: {data}')
    fn_index = data['fn_index']
    trigger_id = data['trigger_id']
    session_hash = data['session_hash']
    return fn_index, trigger_id, session_hash
def prepare_data(gradio_file_path: str, question: str, fn_index: int, trigger_id: int, session_hash: str) -> dict:
    """Build the JSON body posted to the ``queue/join`` endpoint.

    Args:
        gradio_file_path: Path of the uploaded image as returned by the upload
            endpoint.
        question: Text sent alongside the image.
        fn_index: Copied verbatim into the payload.
        trigger_id: Copied verbatim into the payload.
        session_hash: Copied verbatim into the payload.

    Returns:
        A dict with the keys ``data``, ``event_data``, ``fn_index``,
        ``trigger_id`` and ``session_hash``.
    """
    file_descriptor = {
        "path": gradio_file_path,
        "url": f"{BASE_URL}/gradio_api/file={gradio_file_path}",
        "size": None,
        "orig_name": None,
        "mime_type": "image/jpeg",
        "is_stream": False,
        "meta": {"_type": "gradio.FileData"},
    }
    # NOTE(review): looks like a two-entry chat history (image entry, then
    # text entry), each paired with an empty reply slot — the exact shape is
    # dictated by the remote app; confirm there.
    image_entry = [{"file": file_descriptor, "alt_text": None}, None]
    question_entry = [question, None]
    payload = {
        "data": [None, [image_entry, question_entry]],
        "event_data": None,
        "fn_index": fn_index,
        "trigger_id": trigger_id,
        "session_hash": session_hash,
    }
    return payload
async def fetch_result(base64_image_str: str, question: str) -> str | None:
    """Ask the remote app about an image using plain HTTP requests.

    Steps:
      1. Read ``fn_index``, ``trigger_id`` and ``session_hash`` via
         ``get_infer_data``.
      2. Upload the image (re-encoded as JPEG) to ``/gradio_api/upload`` and
         take the first element of the JSON response as the file path.
      3. Post the prepared payload to ``/gradio_api/queue/join``.
      4. Read the ``/gradio_api/queue/data`` event stream and return
         ``output.data[0][1][1]`` of the first ``process_completed`` event
         that carries non-empty output data.

    Args:
        base64_image_str: Image as ``<prefix>,<base64 payload>``.
        question: Text to send with the image.

    Returns:
        The extracted answer, or ``None`` if the stream ends without a
        ``process_completed`` event carrying data.

    Raises:
        ValueError: If the image string is invalid, or a ``data:`` line is
            not valid JSON.
        httpx.HTTPError: On transport errors or non-success status codes.
        Exception: Anything else raised while reading the stored parameters
            or indexing the response.
    """
    fn_index, trigger_id, session_hash = get_infer_data()
    async with AsyncClient(follow_redirects=True, timeout=40) as client:
        image_file = image_bytes(base64_image_str)
        boundary = f'----WebKitFormBoundary{generate_random_string(15).upper()}'
        upload_response = await client.post(
            f'{BASE_URL}/gradio_api/upload?upload_id={generate_random_string(11)}',
            headers={
                'Content-Type': f'multipart/form-data; boundary={boundary}',
                'accept': '*/*',
            },
            content=_build_multipart_body(boundary, image_file),
        )
        upload_response.raise_for_status()
        gradio_file_path = upload_response.json()[0]
        logger.debug(f'gradio_file_path: {gradio_file_path}')

        send_response = await client.post(
            f'{BASE_URL}/gradio_api/queue/join',
            headers={
                'accept': '*/*',
                'content-type': 'application/json',
            },
            json=prepare_data(gradio_file_path, question, fn_index, trigger_id, session_hash),
        )
        send_response.raise_for_status()
        logger.debug(f'send_response: {send_response.text}')

        async with client.stream(
            'GET',
            f'{BASE_URL}/gradio_api/queue/data?session_hash={session_hash}',
            headers={'accept': 'text/event-stream', 'content-type': 'application/json'},
        ) as result_response:
            result_response.raise_for_status()
            async for line in result_response.aiter_lines():
                if not line.startswith('data:'):
                    continue
                logger.debug(f'result_response line: {line}')
                # Cut exactly the `data:` prefix; json.loads tolerates the
                # optional space that may follow it, so both `data: {...}`
                # and `data:{...}` parse.
                event_data = loads(line[len('data:'):])
                if event_data.get('msg') != 'process_completed':
                    continue
                logger.debug(f'process_completed: {event_data}')
                data = event_data.get('output', {}).get('data', [])
                if data:
                    return data[0][1][1]
    return None


def _build_multipart_body(boundary: str, image_file: FilePayload) -> bytes:
    """Assemble a one-part ``multipart/form-data`` body for the ``files`` field."""
    file_name = image_file.get('name')
    mime_type = image_file.get('mimeType')
    head = (
        f'--{boundary}\r\n'
        f'Content-Disposition: form-data; name="files"; filename="{file_name}"\r\n'
        f'Content-Type: {mime_type}\r\n\r\n'
    ).encode('latin1')
    tail = f'\r\n--{boundary}--\r\n'.encode('latin1')
    # The image bytes are spliced in as-is; no text round-trip is needed.
    return b''.join((head, image_file.get('buffer'), tail))
def take_infer_data(request: PlaywrightRequest):
    """Persist queue parameters seen in a browser ``queue/join`` request.

    Registered as the page ``request`` event callback in ``browser_request``.
    For requests whose URL starts with ``{BASE_URL}/gradio_api/queue/join``
    and whose JSON body has a non-empty ``data`` field, the ``fn_index``,
    ``trigger_id`` and ``session_hash`` values are written to the
    ``infer_data`` file, overwriting its previous content.

    Best-effort: any error is logged and swallowed so that the callback never
    interrupts the page.

    Args:
        request: The intercepted browser request.
    """
    if not request.url.startswith(f'{BASE_URL}/gradio_api/queue/join'):
        return
    try:
        post_data = request.post_data
        if not post_data:
            # A request without a body carries nothing to store.
            return
        data = loads(post_data)
        if not data.get('data'):
            return
        fn_index = data.get('fn_index')
        trigger_id = data.get('trigger_id')
        session_hash = data.get('session_hash')
        # 0 is a legitimate numeric value, so only missing numbers are
        # rejected; the session hash must additionally be non-empty.
        if fn_index is None or trigger_id is None or not session_hash:
            return
        infer_data_json = {
            'fn_index': fn_index,
            'trigger_id': trigger_id,
            'session_hash': session_hash,
        }
        infer_data.write_text(dumps(infer_data_json, indent=4))
        logger.debug(f'полученные из браузера данные в `take_infer_data`: {infer_data_json}')
    except Exception as ext:
        logger.error(f'ошибка `take_infer_data`: {ext}')
async def browser_request(base64_image_str: str, question: str) -> str | None:
    """Ask the remote app about an image by driving a headless browser.

    Opens the app page, switches to the ``Grok`` tab, fills in the question,
    uploads the image, submits, waits for the ``Retry`` button to become
    visible and returns the joined text of all ``bot`` elements. While the
    page is open, ``take_infer_data`` observes outgoing requests.

    Args:
        base64_image_str: Image as ``<prefix>,<base64 payload>``.
        question: Text to send with the image.

    Returns:
        The stripped bot text, or ``None`` if it is empty or any step on the
        page failed (the failure is logged).

    Raises:
        Exception: Errors from launching the browser or creating the context
            propagate to the caller.
    """
    async with async_playwright() as playwright:
        browser = await playwright.chromium.launch(
            headless=True,
            args=['--disable-blink-features=AutomationControlled'],
        )
        context = None
        try:
            context = await browser.new_context(
                viewport={'width': 2560, 'height': 1440},
                screen={'width': 2560, 'height': 1286},
                color_scheme='dark',
                ignore_https_errors=True,
                locale='en-US',
                user_agent=ua.random,
            )
            try:
                page = await context.new_page()
                image_file = image_bytes(base64_image_str)
                page.on('request', take_infer_data)
                await page.goto(f'{BASE_URL}/?__theme=light')
                await page.get_by_role('tab', name='Grok').click()
                await page.get_by_role('textbox', name='Type a message...').fill(question)
                input_group = page.get_by_role('group', name='Multimedia input field')
                await input_group.get_by_test_id('file-upload').set_input_files(image_file)
                await page.wait_for_selector('img.thumbnail-image')
                submit_button = page.get_by_role('group', name='Multimedia input field').locator('.submit-button')
                await submit_button.click()
                await page.wait_for_selector('button[aria-label="Retry"]', state='visible')
                await submit_button.wait_for(state='visible')
                caption = ' '.join(await page.get_by_test_id('bot').all_text_contents()).strip()
            except Exception as exc:
                logger.error(f'ошибка `browser_request`: {exc}')
                return None
        finally:
            # Runs on success, on page failure and on context-creation
            # failure, so the browser is always closed explicitly.
            await _close_quietly(context, browser)
        if caption:
            logger.info('результат получен из `browser_request`')
            return caption
        return None


async def _close_quietly(*resources) -> None:
    """Close each non-``None`` resource in order, logging close errors."""
    for resource in resources:
        if resource is None:
            continue
        try:
            await resource.close()
        except Exception as exc:
            # A failed close must not discard a caption already obtained.
            logger.warning(f'ошибка закрытия ресурса в `browser_request`: {exc}')
async def httpx_request(base64_image_str: str, question: str) -> str | None:
    """Run ``fetch_result`` and turn any failure into ``None``.

    Args:
        base64_image_str: Image as ``<prefix>,<base64 payload>``.
        question: Text to send with the image.

    Returns:
        The non-empty answer, or ``None`` if the answer is empty or
        ``fetch_result`` raised (the error is logged).
    """
    try:
        caption = await fetch_result(base64_image_str, question)
    except Exception as exc:
        # The log label names this function so the failing path can be told
        # apart from browser_request in the logs.
        logger.error(f'ошибка `httpx_request`: {exc}')
        return None
    logger.debug(caption)
    if caption:
        logger.info('результат получен из `httpx_request`')
        return caption
    return None
async def get_grok_caption(base64_image_str: str, question: str) -> str | None:
    """Try to obtain an answer, alternating HTTP and browser strategies.

    Each of up to three attempts calls ``httpx_request`` first and
    ``browser_request`` second; the first truthy result is returned.
    Attempts are separated by a 1.5 second pause.

    Args:
        base64_image_str: Image as ``<prefix>,<base64 payload>``.
        question: Text to send with the image.

    Returns:
        The answer, or ``None`` once all attempts are exhausted.
    """
    attempts = 3
    for attempt in range(1, attempts + 1):
        for strategy in (httpx_request, browser_request):
            result = await strategy(base64_image_str, question)
            if result:
                return result
        # Pause only between attempts; waiting after the last one would just
        # delay the failure response.
        if attempt < attempts:
            await sleep(1.5)
    logger.error('превышено максимальное количество попыток')
    return None
@asynccontextmanager
async def app_lifespan(_) -> AsyncGenerator[None, None]:
    """Application lifespan: log startup, hand control to the app, log shutdown.

    Wrapped with ``asynccontextmanager`` so that the object passed as
    ``lifespan=`` is an async context manager factory rather than a bare
    async generator function.

    Args:
        _: The application instance (unused).
    """
    logger.info('запуск приложения')
    try:
        logger.info('старт API')
        yield
    finally:
        # Runs on normal shutdown and when an error ends the lifespan.
        logger.info('приложение завершено')
# The ASGI application; startup/shutdown logging is done by app_lifespan.
app = FastAPI(title='RHYMES_AI_API', lifespan=app_lifespan)

# Paths answered with 403 by block_banned_endpoints (exact match against the
# request path).
# NOTE(review): 'swagger_ui_redirect' has no leading slash, so it presumably
# never equals a request path — confirm whether a path was intended.
banned_endpoints = [
    '/openapi.json',
    '/docs',
    '/docs/oauth2-redirect',
    'swagger_ui_redirect',
    '/redoc',
]
@app.middleware('http')
async def block_banned_endpoints(request: Request, call_next):
    """HTTP middleware that answers 403 for paths listed in ``banned_endpoints``.

    Registered on ``app`` so that it runs for every request; without the
    registration the block list would never be consulted.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the rest of the app.

    Returns:
        An empty 403 response for a banned path, otherwise the downstream
        response unchanged.
    """
    logger.debug(f'получен запрос: {request.url.path}')
    if request.url.path in banned_endpoints:
        logger.warning(f'запрещенный endpoint: {request.url.path}')
        return Response(status_code=403)
    return await call_next(request)
async def describe_v1(request: Request):
    """Answer a question about an image given as a list of chat messages.

    The JSON body is expected to contain ``messages``; text and the image
    data URL are collected from ``system`` and ``user`` messages and passed
    to ``get_grok_caption``.

    NOTE(review): no route decorator is visible for this handler in this
    view — confirm how and under which path it is registered on ``app``.

    Args:
        request: The incoming request.

    Returns:
        ``JSONResponse`` with a ``caption`` key: status 400 when text or image
        is missing (including unparseable bodies), 500 with the error text
        when captioning raises, otherwise 200 with the caption (which is
        ``None`` when every attempt failed).
    """
    logger.info('запрос `describe_v1`')
    try:
        body = await request.json()
    except ValueError:
        # A malformed JSON body is treated like an empty one and ends in the
        # 400 response below instead of an unhandled error.
        body = None
    messages = body.get('messages') if isinstance(body, dict) else None
    content_text, image_data = _collect_prompt(messages)
    if not content_text or not image_data:
        return JSONResponse({'caption': 'изображение должно быть передано как строка base64 `data:image/jpeg;base64,{base64_img}` а также текст'}, status_code=400)
    try:
        caption = await get_grok_caption(image_data, content_text)
        return JSONResponse({'caption': caption}, status_code=200)
    except Exception as e:
        return JSONResponse({'caption': str(e)}, status_code=500)


def _collect_prompt(messages) -> tuple[str, str]:
    """Return ``(text, image_data_url)`` gathered from system/user messages.

    Text parts are joined with single spaces; the last ``data:image/`` URL
    wins. Entries of an unexpected type are skipped.
    """
    if not isinstance(messages, list):
        return '', ''
    text_parts = []
    image_data = ''
    for message in messages:
        if not isinstance(message, dict) or message.get('role') not in ('system', 'user'):
            continue
        content = message.get('content')
        if isinstance(content, str):
            text_parts.append(content)
            continue
        if not isinstance(content, list):
            continue
        for item in content:
            if not isinstance(item, dict):
                continue
            item_type = item.get('type')
            if item_type == 'text':
                text = item.get('text', '')
                # A null or non-string `text` value is ignored.
                if isinstance(text, str):
                    text_parts.append(text)
            elif item_type == 'image_url':
                image_url = item.get('image_url')
                # Accept both {"url": "..."} and a bare URL string.
                url = image_url.get('url') if isinstance(image_url, dict) else image_url
                if isinstance(url, str) and url.startswith('data:image/'):
                    image_data = url
    return ' '.join(text_parts).strip(), image_data.strip()
async def root():
    """Return a fixed HTML body with status 200.

    NOTE(review): no route decorator is visible for this handler in this
    view — confirm how and under which path it is registered on ``app``.
    """
    page = HTMLResponse(content='ну пролапс, ну и что', status_code=200)
    return page
# Script entry point: serve `app` with uvicorn on all interfaces, port 7860.
if __name__ == '__main__':
    # Imported here so that uvicorn is only needed when run as a script.
    from uvicorn import run as uvicorn_run

    logger.info('запуск сервера uvicorn')
    uvicorn_run(app=app, host='0.0.0.0', port=7860)