| import os |
# Use the process working directory as the Hugging Face home/cache root.
# This runs before transformers / sentence_transformers are imported,
# presumably so those libraries pick the value up at import time.
current_dir = os.getcwd()
os.environ['HF_HOME'] = current_dir
| from sentence_transformers import SentenceTransformer, util |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
| from PIL import Image |
| from serpapi import GoogleSearch |
| from keybert import KeyBERT |
| from typing import Dict, Any, List |
| import base64 |
| import torch |
# Vision-language model used to describe the uploaded product image,
# pinned to a fixed Hub revision.
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"

# Prefer the GPU, but fall back to CPU so importing this module does not
# fail on machines without CUDA (both models were hard-coded to 'cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, revision=revision
)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

# Multilingual sentence encoder used to match the user's message against a
# fixed list of intent questions.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
sentence_model = SentenceTransformer(model_name, device=device)
|
|
class ProductSearcher:
    """Describe a product photo and extract search keyphrases from the description.

    Pipeline (see ``run``): choose a prompt by matching the user's message
    against ``predefined_questions``, ask the module-level ``model`` about the
    image with that prompt, then run KeyBERT over the answer.

    Args:
        user_input: The user's free-text message.
        image_path: Path of the image file on disk.
    """

    # KeyBERT instance shared by every ProductSearcher. A new searcher is built
    # per request, so calling KeyBERT() in __init__ each time would reload the
    # keyword model on every call. Created lazily on first construction.
    _shared_kw_model = None

    def __init__(self, user_input, image_path):
        self.user_input = user_input
        self.image_path = image_path
        # Known user intents (Vietnamese): "I want to buy this product",
        # "I want information about the product", "I want to know the price
        # of this".
        self.predefined_questions = [
            "tôi muốn mua sản phẩm này",
            "tôi muốn thông tin về sản phẩm",
            "tôi muốn biết giá cái này",
        ]
        # Prompt for the intent at the same index in predefined_questions, so
        # both lists must have the same length. Keep the commas: adjacent
        # string literals silently concatenate into a single entry.
        self.prompts = [
            "Descibe product in image with it color. Only answer in one sentence",
            "Describe the product in detail and provide information about the product. If you don't know the product, you can describe the image",
            "Estimate the price of the product and provide a detailed description of the product",
        ]
        self.description = ''
        self.keyphrases = []
        if ProductSearcher._shared_kw_model is None:
            ProductSearcher._shared_kw_model = KeyBERT()
        self.kw_model = ProductSearcher._shared_kw_model

    def get_most_similar_sentence(self):
        """Return the prompt paired with the predefined question closest to the input.

        Closeness is the cosine similarity between ``sentence_model``
        embeddings of the user input and of each predefined question.
        """
        user_input_embedding = sentence_model.encode(self.user_input)
        predefined_embeddings = sentence_model.encode(self.predefined_questions)
        similarity_scores = util.pytorch_cos_sim(user_input_embedding, predefined_embeddings)
        most_similar_index = similarity_scores.argmax().item()
        return self.prompts[most_similar_index]

    def generate_description(self):
        """Ask the vision model about the image and store its answer in ``self.description``."""
        prompt = self.get_most_similar_sentence()
        # The context manager closes the image file once it has been encoded.
        with Image.open(self.image_path) as image:
            enc_image = model.encode_image(image)
        self.description = model.answer_question(enc_image, prompt, tokenizer)
        del enc_image

    def extract_keyphrases(self):
        """Run KeyBERT over ``self.description`` and store the result in ``self.keyphrases``.

        The entries are presumably ``(keyword, score)`` pairs; ``search_products``
        reads ``keyword[0]`` from each.
        """
        self.keyphrases = self.kw_model.extract_keywords(self.description)

    def search_products(self, k=3):
        """Search Google Shopping (via SerpAPI) using the extracted keyphrases.

        The keyphrase ``'image'`` is excluded from the query.

        Args:
            k: Maximum number of shopping results to return.

        Returns:
            Up to ``k`` entries of the response's ``shopping_results`` list, or
            ``[]`` when the response has none or no keyphrases remain to search.

        Raises:
            KeyError: If the ``API_KEY`` environment variable is not set.
        """
        terms = [keyword[0] for keyword in self.keyphrases if keyword[0] != 'image']
        question = " ".join(terms)
        if not question:
            # Nothing to search for; avoid spending an API call on an empty query.
            return []
        search = GoogleSearch({
            "engine": "google",
            "q": question,
            "tbm": "shop",
            "api_key": os.environ["API_KEY"],
        })
        results = search.get_dict()
        return results.get('shopping_results', [])[:k]

    def run(self, k=3):
        """Describe the image, extract keyphrases and return ``self.keyphrases``.

        NOTE(review): ``k`` is currently unused and ``search_products`` is not
        called here, so this returns the raw keyphrases, not product results.
        Confirm whether the search step was intentionally disabled.
        """
        self.generate_description()
        self.extract_keyphrases()
        return self.keyphrases
|
|
|
|
|
|
class EndpointHandler:
    """Request handler: decodes an uploaded image and runs ProductSearcher on it."""

    def __init__(self, path=""):
        # ``path`` is accepted for interface compatibility; it is not used.
        pass

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one request.

        data args:
            inputs (:obj: dict): A dictionary containing the inputs.
                message (:obj: str): The user message.
                image (:obj: str): The base64-encoded image content.
        Return:
            The result of ``ProductSearcher.run(k=3)``, i.e. the keyphrases
            extracted from the image description. NOTE(review): this is not a
            list of product search results, despite the return annotation.
        Raises:
            ValueError: If ``message`` or ``image`` is missing from ``inputs``.
        """
        import tempfile

        inputs = data.get("inputs", {})
        message = inputs.get("message")
        image_content = inputs.get("image")
        if message is None:
            raise ValueError("inputs.message is required")
        if image_content is None:
            raise ValueError("inputs.image (base64-encoded) is required")

        image_bytes = base64.b64decode(image_content)

        # Use a unique file per request: a fixed path would let concurrent
        # requests overwrite each other's image, and would never be cleaned up.
        os.makedirs("input", exist_ok=True)
        fd, image_path = tempfile.mkstemp(suffix=".jpg", dir="input")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(image_bytes)
            searcher = ProductSearcher(message, image_path)
            return searcher.run(k=3)
        finally:
            os.remove(image_path)
|
|
| |