import tarfile
from collections import defaultdict
from pathlib import Path

import faiss
import numpy as np
import pyarrow as pa
import requests
from tqdm import tqdm

__all__ = ["RetrievalDatabase", "download_retrieval_databases"]

RETRIEVAL_DATABASES_URLS = {
    "cc12m": {
        "url": "https://storage-cased.alessandroconti.me/cc12m.tar.gz",
        "cache_subdir": "./cc12m/vit-l-14/",
    },
}


def download_retrieval_databases(cache_dir: str = "~/.cache/cased"):
    """Download the retrieval databases if they are not already cached.

    Args:
        cache_dir (str): Path to the cache directory. Defaults to "~/.cache/cased".
    """
    databases_path = Path(cache_dir).expanduser() / "databases"

    for name, items in RETRIEVAL_DATABASES_URLS.items():
        url = items["url"]
        database_path = Path(databases_path, name)
        if database_path.exists():
            continue

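        # stream the archive to disk in chunks, with a progress bar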
        target_path = Path(databases_path, name + ".tar.gz")
        target_path.parent.mkdir(parents=True, exist_ok=True)
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            total_bytes_size = int(r.headers.get("content-length", 0))
            chunk_size = 8192
            p_bar = tqdm(
                desc=f"Downloading {name} index",
                total=total_bytes_size,
                unit="iB",
                unit_scale=True,
            )
            with open(target_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    p_bar.update(len(chunk))
            p_bar.close()

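        # unpack the archive next to it, then delete the tarball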
        with tarfile.open(target_path, "r:gz") as tar:
            tar.extractall(target_path.parent)
        target_path.unlink()


class RetrievalDatabaseMetadataProvider:
    """Metadata provider for the retrieval database.

    Args:
        metadata_dir (str): Path to the metadata directory.
    """

    def __init__(self, metadata_dir: str):
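        # memory-map every Arrow IPC file under the metadata directory and
        # concatenate them into a single table, addressable by row index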
        metadatas = [str(a) for a in sorted(Path(metadata_dir).glob("**/*")) if a.is_file()]
        self.table = pa.concat_tables(
            [
                pa.ipc.RecordBatchFileReader(pa.memory_map(metadata, "r")).read_all()
                for metadata in metadatas
            ]
        )

    def get(self, ids):
        """Get the metadata for the given ids.

        Args:
            ids (list): List of ids.
        """
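        # each id addresses one row; concatenate the single-row slices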
        columns = self.table.schema.names
        end_ids = [i + 1 for i in ids]
        t = pa.concat_tables([self.table[start:end] for start, end in zip(ids, end_ids)])
        return t.select(columns).to_pandas().to_dict("records")


class RetrievalDatabase:
    """Retrieval database.

    Args:
        database_name (str): Name of the database.
        cache_dir (str): Path to the cache directory. Defaults to "~/.cache/cased".
    """

    def __init__(self, database_name: str, cache_dir: str = "~/.cache/cased"):
        assert database_name in RETRIEVAL_DATABASES_URLS, (
            f"Database name should be one of "
            f"{list(RETRIEVAL_DATABASES_URLS.keys())}, got {database_name}."
        )

        database_dir = Path(cache_dir).expanduser() / "databases"
        database_dir = database_dir / RETRIEVAL_DATABASES_URLS[database_name]["cache_subdir"]
        self._database_dir = database_dir

        image_index_fp = Path(database_dir) / "image.index"
        text_index_fp = Path(database_dir) / "text.index"

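        # memory-map the faiss indices read-only instead of loading them fully
        # into RAM; a modality may have no index on disk, in which case it is None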
        image_index = (
            faiss.read_index(str(image_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
            if image_index_fp.exists()
            else None
        )
        text_index = (
            faiss.read_index(str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
            if text_index_fp.exists()
            else None
        )

        metadata_dir = str(Path(database_dir) / "metadata")
        metadata_provider = RetrievalDatabaseMetadataProvider(metadata_dir)

        self._image_index = image_index
        self._text_index = text_index
        self._metadata_provider = metadata_provider

    def _map_to_metadata(self, indices: list, distances: list, embs: list, num_images: int):
        """Map the result indices to their metadata.

        Args:
            indices (list): List of indices.
            distances (list): List of distances.
            embs (list): List of result embeddings.
            num_images (int): Number of results to fetch metadata for.
        """
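        # metadata is fetched only for the first num_images results; any result
        # beyond that keeps just its id, similarity, and embedding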
        results = []
        metas = self._metadata_provider.get(indices[:num_images])
        for key, (d, i, emb) in enumerate(zip(distances, indices, embs)):
            output = {}
            meta = None if key + 1 > len(metas) else metas[key]
            if meta is not None:
                output.update(self._meta_to_dict(meta))
            output["id"] = i.item()
            output["similarity"] = d.item()
            output["sample_z"] = emb.tolist()
            results.append(output)

        return results

    def _meta_to_dict(self, metadata):
        """Convert metadata to dict.

        Args:
            metadata (dict): Metadata.
        """
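        # decode byte strings and unwrap numpy scalars into native Python types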
        output = {}
        for k, v in metadata.items():
            if isinstance(v, bytes):
                v = v.decode()
            elif type(v).__module__ == np.__name__:
                v = v.item()
            output[k] = v
        return output

    def _get_connected_components(self, neighbors):
        """Find the connected components of a graph.

        Args:
            neighbors (dict): Dictionary mapping each node to its neighbors.
        """
        seen = set()

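        # breadth-first walk collecting every node reachable from `node`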
        def component(node):
            r = []
            nodes = {node}
            while nodes:
                node = nodes.pop()
                seen.add(node)
                nodes |= set(neighbors[node]) - seen
                r.append(node)
            return r

        components = []
        for node in neighbors:
            if node not in seen:
                components.append(component(node))
        return components

    def _deduplicate_embeddings(self, embeddings, threshold=0.94):
        """Find near-duplicate embeddings to discard.

        Args:
            embeddings (np.ndarray): Embeddings to deduplicate.
            threshold (float): Similarity threshold above which two embeddings
                count as duplicates. Default is 0.94.
        """
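        # range_search returns CSR-style limits: the neighbors of embedding i
        # (inner product above the threshold) are indices[lims[i]:lims[i + 1]]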
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        lims, _, indices = index.range_search(embeddings, threshold)

        same_mapping = defaultdict(list)

        for i in range(embeddings.shape[0]):
            start = lims[i]
            end = lims[i + 1]
            for j in indices[start:end]:
                same_mapping[int(i)].append(int(j))

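        # group mutual near-duplicates, keep the first member of each group,
        # and mark every other member for removal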
        groups = self._get_connected_components(same_mapping)
        non_uniques = set()
        for g in groups:
            for e in g[1:]:
                non_uniques.add(e)

        return non_uniques

    def query(
        self, query: np.ndarray, modality: str = "text", num_samples: int = 10
    ) -> list[list[dict]]:
        """Query the database.

        Args:
            query (np.ndarray): Query embeddings to search with.
            modality (str): Modality to search. One of `image` or `text`. Defaults to `text`.
            num_samples (int): Number of samples to return. Defaults to 10.
        """
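        # search_and_reconstruct returns, for each query, the distances, the
        # database ids, and the stored embeddings of the retrieved entries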
        index = self._image_index if modality == "image" else self._text_index
        assert index is not None, f"No {modality} index found in {self._database_dir}."

        distances, indices, embeddings = index.search_and_reconstruct(query, num_samples)
        results = [indices[i] for i in range(len(indices))]

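        # faiss pads missing results with id -1; truncate each row at the first -1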
        nb_results = [np.where(r == -1)[0] for r in results]
        total_distances = []
        total_indices = []
        total_embeddings = []
        for i in range(len(results)):
            num_res = nb_results[i][0] if len(nb_results[i]) > 0 else len(results[i])

            result_indices = results[i][:num_res]
            result_distances = distances[i][:num_res]
            result_embeddings = embeddings[i][:num_res]

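            # l2-normalize the reconstructed embeddings, guarding against
            # zero vectors to avoid division by zero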
            l2 = np.atleast_1d(np.linalg.norm(result_embeddings, 2, -1))
            l2[l2 == 0] = 1
            result_embeddings = result_embeddings / np.expand_dims(l2, -1)

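            # drop near-duplicate entries, both within this result list and
            # across repeated database ids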
            local_indices_to_remove = self._deduplicate_embeddings(result_embeddings)
            indices_to_remove = set()
            for local_index in local_indices_to_remove:
                indices_to_remove.add(result_indices[local_index])

            curr_indices = []
            curr_distances = []
            curr_embeddings = []
            for ind, dis, emb in zip(result_indices, result_distances, result_embeddings):
                if ind not in indices_to_remove:
                    indices_to_remove.add(ind)
                    curr_indices.append(ind)
                    curr_distances.append(dis)
                    curr_embeddings.append(emb)

            total_indices.append(curr_indices)
            total_distances.append(curr_distances)
            total_embeddings.append(curr_embeddings)

        if len(total_distances) == 0:
            return []

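        # attach metadata records to each query's surviving results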
        total_results = []
        for i in range(len(total_distances)):
            results = self._map_to_metadata(
                total_indices[i], total_distances[i], total_embeddings[i], num_samples
            )
            total_results.append(results)

        return total_results
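

# Example usage (a sketch; `query_embs` is assumed to be a float32 array of
# query embeddings with shape (num_queries, embedding_dim), produced by the
# encoder matching the index -- the cache subdir suggests CLIP ViT-L/14):
#
#   download_retrieval_databases()
#   db = RetrievalDatabase("cc12m")
#   results = db.query(query_embs, modality="text", num_samples=10)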