| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import tqdm |
| from PIL import Image |
| import hashlib |
| import torch |
| import fitz |
|
|
|
|
def get_image_md5(img: Image.Image):
    """Return the hexadecimal MD5 digest of ``img``'s raw pixel bytes.

    Only ``img.tobytes()`` is hashed; the image mode and size are not part of
    the digest, so two images whose pixel buffers are byte-identical map to the
    same value.
    """
    return hashlib.md5(img.tobytes()).hexdigest()
|
|
def pdf_to_images(pdf_path, dpi=200):
    """Render every page of a PDF to an RGB PIL image.

    Args:
        pdf_path: Path of the PDF file, passed to ``fitz.open``.
        dpi: Rendering resolution forwarded to ``Page.get_pixmap``.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per page, in page order.
    """
    images = []
    # The document is used as a context manager so the underlying file is
    # closed even if rendering one of the pages raises.
    with fitz.open(pdf_path) as doc:
        for page in tqdm.tqdm(doc):
            # alpha=False gives 3 samples per pixel, matching the "RGB" mode
            # used to interpret ``pix.samples`` below.
            pix = page.get_pixmap(dpi=dpi, alpha=False)
            images.append(Image.frombytes("RGB", (pix.width, pix.height), pix.samples))
    return images
|
|
|
|
class PDFVisualRetrieval:
    """In-memory visual retriever over page images, grouped by knowledge base.

    Page images are embedded with ``model`` and stored per knowledge base,
    keyed by the MD5 of their pixel bytes (``get_image_md5``). A query is
    embedded with the same model and scored against the stored page
    embeddings with a dot product (``torch.matmul``).
    """

    def __init__(self, model, tokenizer):
        self.tokenizer = tokenizer
        self.model = model
        # knowledge base name -> {image md5 -> embedding tensor}
        self.reps = {}
        # knowledge base name -> {image md5 -> image}
        self.images = {}

    def add_visual_documents(self, knowledge_base_name: str, images: "list[Image.Image]"):
        """Embed page images and store them in ``knowledge_base_name``.

        The knowledge base is created on first use. Images whose pixel-byte
        MD5 is already present in the knowledge base are not re-encoded.

        Args:
            knowledge_base_name: Name of the knowledge base to add to.
            images: Iterable of PIL images (e.g. the output of ``pdf_to_images``).
        """
        kb_reps = self.reps.setdefault(knowledge_base_name, {})
        kb_images = self.images.setdefault(knowledge_base_name, {})
        for image in tqdm.tqdm(images):
            image_md5 = get_image_md5(image)
            if image_md5 in kb_reps and image_md5 in kb_images:
                # Same pixel bytes as a page that is already indexed; encoding
                # it again would only replace the stored entry.
                continue
            with torch.no_grad():
                reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
            # squeeze(0) drops the size-1 batch dimension of the single-image
            # call (assumes the model returns reps shaped (1, dim) — confirm).
            kb_reps[image_md5] = reps.squeeze(0)
            kb_images[image_md5] = image
        return

    def retrieve(self, knowledge_base: str, query: str, topk: int):
        """Return the pages of ``knowledge_base`` scoring highest for ``query``.

        Args:
            knowledge_base: Name of a knowledge base previously filled via
                ``add_visual_documents`` / ``add_pdf``.
            query: Natural-language query.
            topk: Maximum number of results. Values larger than the number of
                stored pages are clamped to that number.

        Returns:
            ``(doc_ids, scores, images)``. ``doc_ids`` is a numpy array of
            indices into the knowledge base's insertion order. ``scores`` is a
            numpy float array of dot-product scores in descending order.
            ``images`` is the list of the corresponding stored images. All three
            are empty if the knowledge base holds no pages.

        Raises:
            KeyError: if ``knowledge_base`` has never been added.
        """
        kb_reps = self.reps[knowledge_base]
        kb_images = self.images[knowledge_base]
        doc_keys = list(kb_reps)
        if not doc_keys:
            # torch.stack would fail on an empty list; report "no hits" instead.
            no_hits = torch.empty(0)
            return no_hits.long().numpy(), no_hits.numpy(), []

        # NOTE(review): the instruction text, including its spelling, is model
        # input and is kept verbatim; it may need to match the prompt the
        # embedding model was trained with — confirm before changing it.
        query_with_instruction = "Represent this query for retrieving relavant document: " + query
        with torch.no_grad():
            query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
            doc_reps_cat = torch.stack([kb_reps[key] for key in doc_keys], dim=0)
            similarities = torch.matmul(query_rep, doc_reps_cat.T)
            # torch.topk raises when k exceeds the number of candidates.
            k = min(topk, len(doc_keys))
            topk_values, topk_doc_ids = torch.topk(similarities, k=k)

        # .float() first: numpy has no bfloat16, so .numpy() would fail on
        # reduced-precision score tensors.
        topk_values_np = topk_values.float().cpu().numpy()
        topk_doc_ids_np = topk_doc_ids.cpu().numpy()
        # Resolve images through the same keys used to build doc_reps_cat so the
        # indices always refer to the matching image.
        images_topk = [kb_images[doc_keys[idx]] for idx in topk_doc_ids_np]
        return topk_doc_ids_np, topk_values_np, images_topk

    def add_pdf(self, knowledge_base_name: str, pdf_file_path: str, dpi: int = 200):
        """Render every page of a PDF and add the page images to a knowledge base.

        Args:
            knowledge_base_name: Name of the knowledge base to add to.
            pdf_file_path: Path of the PDF file.
            dpi: Rendering resolution passed to ``pdf_to_images``.
        """
        print("[1/2] rendering pdf to images..")
        images = pdf_to_images(pdf_file_path, dpi=dpi)
        print("[2/2] model encoding images..")
        self.add_visual_documents(knowledge_base_name=knowledge_base_name, images=images)
        print("add pdf ok.")
        return
|
|
|
|
if __name__ == "__main__":
    # Demo: index one PDF into the knowledge base "test" and print the top
    # results for a few sample queries. Defaults reproduce the original
    # hard-coded paths and device; override them on the command line.
    import argparse

    from transformers import AutoModel
    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser(description="Index a PDF and run sample visual-retrieval queries.")
    parser.add_argument("--model-path", default='/home/jeeves/xubokai/minicpm-visual-embedding-v0')
    parser.add_argument("--pdf-path", default="/home/jeeves/xubokai/minicpm-visual-embedding-v0/2406.07422v1.pdf")
    parser.add_argument("--device", default='cuda:0')
    parser.add_argument("--topk", type=int, default=5)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
    model.to(args.device)
    model.eval()

    retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)
    retriever.add_pdf('test', args.pdf_path)

    queries = [
        'what is the number of VQ of this kind of codec method?',
        'the training loss curve of this paper?',
        'the experiment table?',
    ]
    for query in queries:
        topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query=query, topk=args.topk)
        print(f"query: {query}")
        # doc ids index the knowledge base's insertion order (duplicate pages
        # are stored once, so they are not necessarily page numbers).
        for rank, (doc_id, score) in enumerate(zip(topk_doc_ids_np, topk_values_np), start=1):
            print(f"  #{rank}: doc {int(doc_id)} score={float(score):.4f}")
| |
|
|
|
|