| import collections |
| import heapq |
| import json |
| import os |
| import logging |
| import faiss |
| import requests |
| import gradio as gr |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from open_clip import create_model, get_tokenizer |
| from torchvision import transforms |
| from PIL import Image |
| import io |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
|
|
# Root-logger setup: every record renders as "[time] [LEVEL] [logger name] message".
log_format = " ".join(
    ("[%(asctime)s]", "[%(levelname)s]", "[%(name)s]", "%(message)s")
)
logging.basicConfig(format=log_format, level=logging.INFO)
logger = logging.getLogger(name=None)
|
|
| hf_token = os.getenv("HF_TOKEN") |
|
|
| model_str = "hf-hub:imageomics/bioclip" |
| tokenizer_str = "ViT-B-16" |
|
|
| txt_emb_npy = hf_hub_download(repo_id="pyesonekyaw/biome_lfs", filename='txt_emb_species.npy', repo_type="dataset") |
| txt_names_json = "txt_emb_species.json" |
|
|
| min_prob = 1e-9 |
| k = 5 |
|
|
| ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species") |
|
|
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
| preprocess_img = transforms.Compose( |
| [ |
| transforms.ToTensor(), |
| transforms.Resize((224, 224), antialias=True), |
| transforms.Normalize( |
| mean=(0.48145466, 0.4578275, 0.40821073), |
| std=(0.26862954, 0.26130258, 0.27577711), |
| ), |
| ] |
| ) |
|
|
| MIN_PROB = 1e-9 |
| TOP_K_PREDICTIONS = 5 |
| TOP_K_CANDIDATES = 250 |
| TOP_N_SIMILAR = 22 |
| SIMILARITY_BOOST = 0.2 |
| VOTE_THRESHOLD = 3 |
| SIMILARITY_THRESHOLD = 0.99 |
|
|
| |
# Local JSON lookup tables loaded at startup:
# image id -> photo URL, and image id -> list of species names.
PHOTO_LOOKUP_PATH = "./photo_lookup.json"
SPECIES_LOOKUP_PATH = "./species_lookup.json"
|
|
# Gradio theme: teal / blue / gray palette, large text, dark-teal primary button.
_THEME_OVERRIDES = {
    "button_primary_background_fill": "#114A56",
    "button_primary_background_fill_hover": "#114A56",
    "block_title_text_weight": "600",
    "block_label_text_weight": "600",
    "block_label_text_size": "*text_md",
}
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.teal,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    text_size=gr.themes.sizes.text_lg,
).set(**_THEME_OVERRIDES)
|
|
# Example photos offered in the UI: every .jpg under ./examples, as sorted path strings.
EXAMPLES_DIR = Path("examples")
example_images = sorted(map(str, EXAMPLES_DIR.glob("*.jpg")))
|
|
def indexed(lst, indices):
    """Return the items of ``lst`` found at each position in ``indices``, in order."""
    pick = lst.__getitem__
    return list(map(pick, indices))
|
|
def format_name(taxon, common):
    """Join the taxon parts with spaces, appending ``(common)`` when a common name is given."""
    scientific = " ".join(taxon)
    return f"{scientific} ({common})" if common else scientific
|
|
def combine_duplicate_predictions(predictions):
    """Merge predictions whose names overlap, then renormalise the scores.

    Names are compared case-insensitively. Whenever one name is a substring of
    another, their scores are summed and stored under the longer name. Among
    equal-length names, higher scores are visited first.

    Args:
        predictions: mapping of label -> score.

    Returns:
        A new dict of label -> score, rescaled so the scores sum to 1. Empty
        input returns ``{}``. If the scores sum to zero, the merged scores are
        returned unnormalised instead of dividing by zero.
    """
    # Lowercase each name once instead of on every inner-loop comparison.
    lowered = {name: name.lower() for name in predictions}
    # Longest names first, so each merged group is keyed by its most specific label.
    ordered = sorted(predictions.items(), key=lambda item: (-len(item[0]), -item[1]))

    combined = {}
    used = set()
    for name1, prob1 in ordered:
        if name1 in used:
            continue
        used.add(name1)
        name1_lower = lowered[name1]
        total_prob = prob1

        for name2, prob2 in predictions.items():
            if name2 in used:
                continue
            name2_lower = lowered[name2]
            if name1_lower in name2_lower or name2_lower in name1_lower:
                total_prob += prob2
                used.add(name2)

        combined[name1] = total_prob

    total = sum(combined.values())
    if total == 0:
        # Nothing to normalise by; avoid ZeroDivisionError / NaN.
        return combined
    return {name: score / total for name, score in combined.items()}
|
|
def _species_predictions(probs, species_votes):
    """Score species-level candidates, blending in retrieval votes.

    Takes the TOP_K_CANDIDATES most probable text labels, boosts those that
    overlap a voted species, keeps the top ``k`` and merges/normalises them.
    """
    # Never ask topk for more entries than there are labels.
    num_candidates = min(TOP_K_CANDIDATES, probs.numel())
    topk = probs.topk(num_candidates)
    predictions = {}
    for idx, prob in zip(topk.indices.tolist(), topk.values.tolist()):
        name = format_name(*txt_names[idx])
        # Different label rows can format to the same string. Accumulate them
        # rather than letting a later (smaller) probability overwrite an earlier one.
        predictions[name] = predictions.get(name, 0.0) + prob

    augmented = dict(predictions)
    lowered = {name: name.lower() for name in predictions}
    for voted_species, vote_count in species_votes.items():
        matched = False
        for pred_name, pred_lower in lowered.items():
            if voted_species in pred_lower or pred_lower in voted_species:
                augmented[pred_name] += SIMILARITY_BOOST * vote_count
                matched = True
        # Add the voted species as its own candidate only when no zero-shot
        # prediction already represents it. Otherwise it would be counted twice
        # once duplicates are merged.
        if not matched and vote_count >= VOTE_THRESHOLD:
            augmented[voted_species] = vote_count * SIMILARITY_BOOST

    top = sorted(augmented.items(), key=lambda item: item[1], reverse=True)[:k]
    # combine_duplicate_predictions also renormalises the scores to sum to 1.
    return combine_duplicate_predictions(dict(top))


def _higher_rank_predictions(probs, species_votes, rank):
    """Aggregate label probabilities up to ``ranks[rank]`` and add vote boosts."""
    output = collections.defaultdict(float)
    # flatten() (not squeeze()) keeps a 1-D index tensor even when only one
    # label clears MIN_PROB, so the loop never iterates a 0-d tensor.
    keep = torch.nonzero(probs > MIN_PROB).flatten()
    # One bulk transfer to Python floats instead of per-element tensor access.
    for idx, prob in zip(keep.tolist(), probs[keep].tolist()):
        output[" ".join(txt_names[idx][0][: rank + 1])] += prob

    for species, vote_count in species_votes.items():
        try:
            # Credit the vote to the first label whose taxonomy mentions the species.
            for taxonomy, _ in txt_names:
                if species in " ".join(taxonomy).lower():
                    higher_rank = " ".join(taxonomy[: rank + 1])
                    output[higher_rank] += SIMILARITY_BOOST * vote_count
                    break
        except Exception as e:
            logger.error(f"Error processing vote for species {species}: {e}")

    top_names = heapq.nlargest(k, output, key=output.get)
    # combine_duplicate_predictions also renormalises the scores to sum to 1.
    return combine_duplicate_predictions({name: output[name] for name in top_names})


@torch.no_grad()
def open_domain_classification(img, rank: int, return_all=False):
    """Classify ``img`` at taxonomic level ``ranks[rank]`` using zero-shot CLIP scores plus retrieval.

    Zero-shot probabilities over all text labels are combined with votes
    from the most similar database images (see get_similar_images_metadata).

    Args:
        img: input image accepted by ``preprocess_img`` (a PIL image from the UI).
        rank: index into ``ranks``; the last index means species level.
        return_all: unused; kept for backward compatibility.

    Returns:
        ``(predictions, similar_images)``. ``predictions`` maps label -> score,
        normalised to sum to 1. ``similar_images`` is the metadata list from
        get_similar_images_metadata.
    """
    logger.info(f"Starting open domain classification for rank: {rank}")
    img = preprocess_img(img).to(device)
    img_features = F.normalize(model.encode_image(img.unsqueeze(0)), dim=-1)

    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
    probs = F.softmax(logits, dim=0)

    species_votes, similar_images = get_similar_images_metadata(img_features, faiss_index, id_mapping, name_mapping)

    if rank + 1 == len(ranks):
        predictions = _species_predictions(probs, species_votes)
        logger.info(f"Top K predictions after combining duplicates: {predictions}")
    else:
        predictions = _higher_rank_predictions(probs, species_votes, rank)
        logger.info(f"Prediction dictionary after combining duplicates: {predictions}")

    return predictions, similar_images
|
|
|
|
def change_output(choice):
    """Return an empty prediction Label titled with the rank selected in the dropdown."""
    label_kwargs = dict(
        num_top_classes=k,
        label=ranks[choice],
        show_label=True,
        value=None,
    )
    return gr.Label(**label_kwargs)
|
|
def get_cache_paths(name="demo"):
    """Download (or reuse from the HF cache) the FAISS index and its ID mapping.

    Args:
        name: cache variant. Selects ``cache/faiss_cache_{name}.index`` and
            ``cache/faiss_cache_{name}_mapping.json`` in the dataset repo.
            Previously this argument was ignored and "demo" was always used;
            the default keeps that behaviour.

    Returns:
        ``{'index': <local path>, 'mapping': <local path>}``.
    """
    repo_id = "pyesonekyaw/biome_lfs"

    def _download(filename):
        return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")

    return {
        'index': _download(f"cache/faiss_cache_{name}.index"),
        'mapping': _download(f"cache/faiss_cache_{name}_mapping.json"),
    }
|
|
def build_name_mapping(txt_names):
    """Index species by both their binomial and their common name.

    For every ``(taxonomy, common_name)`` entry with a common name and at
    least two taxonomy levels, the lowercased "<second-last> <last>" name and
    the lowercased common name both map to ``(scientific, common)``. Later
    entries overwrite earlier ones on key collisions.
    """
    mapping = {}
    for taxonomy, common in txt_names:
        if not common or len(taxonomy) < 2:
            continue
        genus, epithet = taxonomy[-2], taxonomy[-1]
        scientific = f"{genus} {epithet}".lower()
        entry = (scientific, common.lower())
        mapping[scientific] = entry
        mapping[entry[1]] = entry
    return mapping
|
|
def load_faiss_index():
    """Read the cached FAISS index and its JSON position->image-id mapping.

    Returns:
        ``(index, id_mapping)``.
    """
    paths = get_cache_paths()
    logger.info("Loading FAISS index from cache...")
    index = faiss.read_index(paths['index'])
    with open(paths['mapping'], 'r') as mapping_file:
        mapping = json.load(mapping_file)
    return index, mapping
| |
def get_similar_images_metadata(img_embedding, faiss_index, id_mapping, name_mapping, max_results=5):
    """Find database images similar to the query and tally species votes.

    Args:
        img_embedding: query embedding tensor, 1-D or ``(1, dim)``. The caller
            L2-normalises it.
        faiss_index: FAISS index searched with the embedding. Returned distances
            are treated as similarities (larger = closer), so presumably an
            inner-product index — confirm against how the cache was built.
        id_mapping: sequence mapping FAISS row position -> image id.
        name_mapping: lowercase name -> ``(scientific, common)`` (see build_name_mapping).
        max_results: how many of the kept neighbours to turn into votes and
            metadata. Defaults to the previous hard-coded 5.

    Returns:
        ``(species_votes, similar_images)``. ``species_votes`` maps a
        (canonicalised) species name to the number of neighbours naming it.
        ``similar_images`` is a list of dicts with ``id``, ``species``,
        ``common_name`` and ``similarity``.

    Species names per image come from the module-level ``id_to_species_info``
    table, which is loaded at startup.
    """
    # FAISS expects a contiguous float32 (n, dim) array.
    query = np.ascontiguousarray(img_embedding.detach().cpu().numpy(), dtype=np.float32)
    if query.ndim == 1:
        query = query.reshape(1, -1)

    # Over-fetch so neighbours dropped below still leave enough candidates.
    distances, indices = faiss_index.search(query, TOP_N_SIMILAR * 2)

    neighbours = []
    for similarity, idx in zip(distances[0], indices[0]):
        if idx < 0:
            # FAISS pads with -1 when it has fewer results than requested.
            # Indexing id_mapping with -1 would silently return the last entry.
            continue
        if similarity > SIMILARITY_THRESHOLD:
            # Skip near-identical hits (presumably the query image itself).
            continue
        neighbours.append((int(idx), similarity))
        if len(neighbours) >= TOP_N_SIMILAR:
            break

    species_votes = {}
    similar_images = []

    for idx, similarity in neighbours[:max_results]:
        similar_img_id = id_mapping[idx]

        try:
            species_info = id_to_species_info.get(similar_img_id)
            species_names = [name for name in (species_info or []) if name]
            if not species_names:
                logger.warning(f"No species info for image {similar_img_id}; skipping")
                continue

            # Canonicalise each name to its scientific form. A dict (unlike a
            # set) keeps first-seen order, so the displayed species is deterministic.
            processed_names = {}
            for species in species_names:
                name_tuple = name_mapping.get(species)
                processed_names[name_tuple[0] if name_tuple else species] = None

            for species in processed_names:
                species_votes[species] = species_votes.get(species, 0) + 1

            similar_images.append({
                'id': similar_img_id,
                'species': next(iter(processed_names)),
                'common_name': species_names[-1],
                'similarity': similarity,
            })

        except Exception as e:
            logger.error(f"Error processing JSON for image {similar_img_id}: {e}")
            continue

    return species_votes, similar_images
|
|
|
|
if __name__ == "__main__":
    logger.info("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    # open_clip models start in training mode; switch to inference behaviour.
    model.eval()
    logger.info("Created model.")

    model = torch.compile(model)
    logger.info("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    # Image-id lookup tables. Use context managers so the file handles are closed.
    with open(PHOTO_LOOKUP_PATH, encoding="utf-8") as fd:
        id_to_photo_url = json.load(fd)
    with open(SPECIES_LOOKUP_PATH, encoding="utf-8") as fd:
        id_to_species_info = json.load(fd)
    logger.info(f"Loaded {len(id_to_photo_url)} photo mappings")
    logger.info(f"Loaded {len(id_to_species_info)} species mappings")

    txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
    with open(txt_names_json, encoding="utf-8") as fd:
        txt_names = json.load(fd)

    name_mapping = build_name_mapping(txt_names)

    faiss_index, id_mapping = load_faiss_index()

    NUM_SIMILAR_SLOTS = 5        # similar-image panels in the UI below
    PHOTO_FETCH_TIMEOUT_S = 10   # never let a slow photo host hang a request

    def _fetch_image(url):
        """Download ``url`` and decode it as a PIL image; return None on any failure."""
        try:
            response = requests.get(url, timeout=PHOTO_FETCH_TIMEOUT_S)
        except requests.RequestException as e:
            logger.info(f"Failed to fetch image from URL: {e}")
            return None
        if response.status_code != 200:
            logger.info(f"Failed to fetch image from URL: {response}")
            return None
        try:
            image = Image.open(io.BytesIO(response.content))
            image.load()  # decode now so a corrupt payload fails here, not later in Gradio
        except Exception as e:
            logger.info(f"Failed to load image from URL: {e}")
            return None
        return image

    def _format_label(img_info):
        """Build the Markdown caption shown under a similar image."""
        label = f"**{img_info['species']}**"
        if img_info['common_name']:
            label += f" ({img_info['common_name']})"
        label += f"\nSimilarity: {img_info['similarity']:.3f}"
        label += f"\n[View on iNaturalist](https://www.inaturalist.org/observations/{img_info['id']})"
        return label

    def process_output(img, rank):
        """Classify the uploaded image and fill the prediction and similar-image widgets.

        Returns ``[predictions] + images + labels``, with exactly
        NUM_SIMILAR_SLOTS images and NUM_SIMILAR_SLOTS labels. This matches the
        outputs wired to the Submit button.
        """
        if img is None:
            raise gr.Error("Please upload an image first.")

        predictions, similar_imgs = open_domain_classification(img, rank)
        logger.info(f"Number of similar images found: {len(similar_imgs)}")

        images = []
        labels = []
        for img_info in similar_imgs[:NUM_SIMILAR_SLOTS]:
            img_id = img_info['id']
            photo = None
            label = ""
            try:
                img_url = id_to_photo_url.get(img_id)
                if img_url:
                    img_url = img_url.replace("square", "small")
                    logger.info(f"Processing image URL: {img_url}")
                    photo = _fetch_image(img_url)
                else:
                    logger.info(f"No photo URL found for image {img_id}")
                label = _format_label(img_info)
            except Exception as e:
                logger.error(f"Error processing image {img_id}: {e}")
                photo, label = None, ""
            # Exactly one image and one label per result keeps the lists aligned.
            images.append(photo)
            labels.append(label)

        images += [None] * (NUM_SIMILAR_SLOTS - len(images))
        labels += [""] * (NUM_SIMILAR_SLOTS - len(labels))

        logger.info(f"Final number of images: {len(images)}")
        logger.info(f"Final number of labels: {len(labels)}")

        return [predictions] + images + labels

    with gr.Blocks(theme=theme) as app:
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                gr.Image("image.jpg", elem_id="logo-img", show_label=False)
            with gr.Column(scale=30):
                gr.Markdown("""Biome is a vision foundation model-powered tool customized to identify Singapore's local biodiversity.
<br/> <br/>
**Developed by**: Pye Sone Kyaw - AI Engineer @ Multimodal AI Team - AI Practice - GovTech SG
<br/> <br/>
Under the hood, Biome is using [BioCLIP](https://github.com/Imageomics/BioCLIP) augmented with multimodal search and retrieval to enhance its Singapore-specific biodiversity classification capabilities.
<br/> <br/>
Biome works best when the organism is clearly visible and takes up a substantial part of the image.
""")

        with gr.Row(variant="panel", elem_id="images_panel"):
            img_input = gr.Image(
                height=400,
                sources=["upload"],
                type="pil",
            )

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr.Examples(
                        examples=example_images,
                        inputs=img_input,
                        label="Example Images",
                    )
                rank_dropdown = gr.Dropdown(
                    label="Taxonomic Rank",
                    info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
                    choices=ranks,
                    value="Species",
                    type="index",
                )
                open_domain_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                open_domain_output = gr.Label(
                    num_top_classes=k,
                    label="Prediction",
                    show_label=True,
                    value=None,
                )

        with gr.Row(variant="panel"):
            with gr.Column():
                gr.Markdown("### Most Similar Images from Database")

                with gr.Row():
                    similar_images = [
                        gr.Image(label=f"Similar Image {i}", height=200, show_label=True)
                        for i in range(1, NUM_SIMILAR_SLOTS + 1)
                    ]

                with gr.Row():
                    similar_labels = [
                        gr.Markdown(f"Species {i}")
                        for i in range(1, NUM_SIMILAR_SLOTS + 1)
                    ]

        rank_dropdown.change(
            fn=change_output,
            inputs=rank_dropdown,
            outputs=[open_domain_output],
        )

        open_domain_btn.click(
            fn=process_output,
            inputs=[img_input, rank_dropdown],
            outputs=[open_domain_output] + similar_images + similar_labels,
        )

        with gr.Row(variant="panel"):
            gr.Markdown("""
**Disclaimer**: This is a proof-of-concept demo for non-commercial purposes. No data is stored or used for any form of training, and all data used for retrieval are from [iNaturalist](https://inaturalist.org/).
The adage of garbage in, garbage out applies here - uploading images not biodiversity-related will yield unpredictable results.
""")

    app.queue(max_size=20)
    app.launch(share=False, enable_monitoring=False, allowed_paths=["/app/"])