| import json
|
| import os
|
| from typing import Any, Dict, List, Type, Union
|
|
|
| import anthropic
|
| import weave
|
| from anthropic import APIStatusError, AsyncAnthropic
|
| from pydantic import BaseModel
|
|
|
| from app.config import get_settings
|
| from app.core import errors
|
| from app.core.errors import BadRequestError, VendorError
|
| from app.core.prompts import get_prompts
|
| from app.services.base import BaseAttributionService
|
| from app.utils.converter import product_data_to_str
|
| from app.utils.image_processing import get_data_format, get_image_data
|
| from app.utils.logger import exception_to_str, setup_logger
|
|
|
# Deployment environment, read from the ENV environment variable
# (defaults to "LOCAL" when unset).
ENV = os.getenv("ENV", "LOCAL")

# Weave project name for each deployment environment. PROD intentionally has
# no entry.
WEAVE_PROJECT_NAMES: Dict[str, str] = {
    "LOCAL": "cfai/attribution-exp",
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
}

# PROD and unrecognised ENV values map to None, so the name is always bound
# (referencing it can never raise NameError).
weave_project_name = WEAVE_PROJECT_NAMES.get(ENV)
|
|
|
|
|
|
|
# Module-level singletons created once at import time and used by every
# AnthropicService method below: application settings (API key and default
# model), the prompt templates, and this module's logger.
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
|
|
|
|
|
class AnthropicService(BaseAttributionService):
    """Attribution service backed by Anthropic's Messages API with tool use.

    Both public methods offer the model a single tool whose ``input_schema``
    is the desired output schema. They then read the arguments of the first
    ``tool_use`` block in the response. ``tool_choice`` is not set, so the
    model is offered the tool but not required to call it.
    """

    def __init__(self):
        # Async client authenticated with the API key from application settings.
        self.client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)

    @staticmethod
    def _build_image_messages(
        img_urls: Union[List[str], None],
        img_paths: Union[List[str], None],
    ) -> List[Dict[str, Any]]:
        """Return Anthropic ``image`` content blocks for the given images.

        ``img_urls`` takes precedence: when it is not None, ``img_paths`` is
        ignored, even if ``img_urls`` is an empty list. Each local path
        becomes a ``base64`` source. Its media type is
        ``image/<get_data_format(path)>`` and its data comes from
        ``get_image_data(path)``. When neither argument is provided, an empty
        list is returned and the request carries text content only.
        """
        if img_urls is not None:
            return [
                {"type": "image", "source": {"type": "url", "url": url}}
                for url in img_urls
            ]
        if img_paths is not None:
            return [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": f"image/{get_data_format(path)}",
                        "data": get_image_data(path),
                    },
                }
                for path in img_paths
            ]
        # No images supplied. Returning [] keeps the request valid; the old
        # code left the variable unbound here and crashed.
        return []

    @weave.op
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: Union[List[str], None],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Union[List[Any], None] = None,
        img_paths: Union[List[str], None] = None,
    ) -> Dict[str, Any]:
        """Extract product attributes from images plus product data.

        Args:
            attributes_model: Pydantic model whose JSON schema becomes the
                tool's ``input_schema``. The model is asked to fill it in;
                the result is not validated against it here.
            ai_model: Anthropic model identifier to call.
            img_urls: Image URLs sent as ``url`` image sources. Takes
                precedence over ``img_paths`` when not None.
            product_taxonomy: Interpolated into GET_PERCENTAGE_HUMAN_MESSAGE.
            product_data: Rendered with ``product_data_to_str`` and
                interpolated into GET_PERCENTAGE_HUMAN_MESSAGE.
            pil_images: Not used by this implementation. It is presumably
                kept for signature compatibility with BaseAttributionService
                (confirm).
            img_paths: Local image paths sent as base64 image sources. Used
                only when ``img_urls`` is None.

        Returns:
            The ``input`` of the first ``tool_use`` block in the response.

        Raises:
            BadRequestError: if Anthropic rejects the request as a bad request.
            VendorError: on any other API failure, if the tool input is empty,
                or if the response contains no ``tool_use`` block.
        """
        logger.info("Extracting info via Anthropic...")
        tools = [
            {
                "name": "extract_garment_info",
                "description": "Extracts key information from the image.",
                "input_schema": attributes_model.model_json_schema(),
                # Marks the tool definition as cacheable for prompt caching.
                "cache_control": {"type": "ephemeral"},
            }
        ]

        image_messages = self._build_image_messages(img_urls, img_paths)

        system_message = [{"type": "text", "text": prompts.GET_PERCENTAGE_SYSTEM_MESSAGE}]

        text_messages = [
            {
                "type": "text",
                "text": prompts.GET_PERCENTAGE_HUMAN_MESSAGE.format(
                    product_taxonomy=product_taxonomy,
                    product_data=product_data_to_str(product_data),
                ),
            }
        ]

        # The text instruction comes first, followed by any images.
        messages = [{"role": "user", "content": text_messages + image_messages}]

        try:
            response = await self.client.messages.create(
                model=ai_model,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
                # Sample only from the single most likely token (near-greedy).
                top_k=1,
            )
        except anthropic.BadRequestError as e:
            raise BadRequestError(e.message) from e
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        for content in response.content:
            if content.type == "tool_use":
                # `not` covers both None and an empty dict.
                if not content.input:
                    raise VendorError(
                        errors.VENDOR_THROW_ERROR.format(
                            error_message="content.input is None or content.input is empty"
                        )
                    )
                return content.input

        raise VendorError(
            errors.VENDOR_THROW_ERROR.format(error_message="No tool_use found")
        )

    @weave.op
    async def follow_schema(self, schema, data):
        """Ask the model to restate ``data`` so that it conforms to ``schema``.

        Args:
            schema: JSON schema used verbatim as the tool's ``input_schema``.
            data: Value interpolated, via ``str.format``, into
                FOLLOW_SCHEMA_HUMAN_MESSAGE as ``json_info``.

        Returns:
            The ``"json_info"`` entry of the first ``tool_use`` block's input.
            If the response has no ``tool_use`` block, returns the dict
            ``{"status": "ERROR: no tool_use found"}`` instead of raising.

        Raises:
            VendorError: if the API call fails for any reason.
            KeyError: if the tool input has no ``"json_info"`` key.
        """
        logger.info("Following structure via Anthropic...")
        tools = [
            {
                "name": "extract_garment_info",
                # NOTE(review): the unformatted human-message template is used
                # as the tool description. Confirm this is intended rather
                # than a dedicated description string.
                "description": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE,
                "input_schema": schema,
                "cache_control": {"type": "ephemeral"},
            }
        ]

        text_messages = [
            {
                "type": "text",
                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
            }
        ]

        system_message = [
            {"type": "text", "text": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE}
        ]

        messages = [{"role": "user", "content": text_messages}]
        try:
            response = await self.client.messages.create(
                model=settings.ANTHROPIC_DEFAULT_MODEL,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
            )
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        for content in response.content:
            if content.type == "tool_use":
                return content.input["json_info"]

        return {"status": "ERROR: no tool_use found"}
|
|
|