import logging
import os
import secrets
import tempfile
from typing import List, Dict

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from pydantic import BaseModel
from pydub import AudioSegment
|
|
app = FastAPI()

# Bearer-token scheme consumed by verify_token below.
security = HTTPBearer()

# Fallback used when the API_TOKEN environment variable is missing. It is
# visible in the source, so anyone who reads this file can authenticate
# against a server that was started without configuring a real token.
_PLACEHOLDER_API_TOKEN = "your-default-token-here"
API_TOKEN = os.getenv("API_TOKEN", _PLACEHOLDER_API_TOKEN)
if API_TOKEN == _PLACEHOLDER_API_TOKEN:
    # Kept for backward compatibility, but make the insecure setup loud.
    logging.getLogger(__name__).warning(
        "API_TOKEN is not set; using the insecure placeholder token. "
        "Set the API_TOKEN environment variable before exposing this service."
    )
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that authorizes a request by its bearer token.

    Returns:
        The presented token string when it matches ``API_TOKEN``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token does not match.
    """
    # compare_digest takes time independent of where the inputs differ, so
    # response timing cannot be used to recover the token piece by piece.
    # Compare bytes because compare_digest rejects non-ASCII str arguments.
    presented = credentials.credentials.encode("utf-8")
    expected = API_TOKEN.encode("utf-8")
    if not secrets.compare_digest(presented, expected):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
|
|
def convert_audio_to_wav(audio_path: str) -> str:
    """Return a WAV path for ``audio_path``, transcoding MP3 input.

    Paths that do not end in ``.mp3`` (case-insensitive) are returned as-is.
    For MP3 input the sibling ``<stem>.wav`` path is returned. The WAV file is
    produced with pydub only when it does not already exist; an existing file
    at that path is reused without checking its contents.
    """
    if not audio_path.lower().endswith('.mp3'):
        return audio_path

    # The path ends in a 4-character ".mp3"/".MP3" extension; swap it for ".wav".
    wav_path = audio_path[:-4] + '.wav'
    if os.path.exists(wav_path):
        return wav_path

    AudioSegment.from_mp3(audio_path).export(wav_path, format='wav')
    return wav_path
|
|
| class EmbeddingManager: |
| def __init__(self): |
| self.device = "cuda:0" if torch.cuda.is_available() else "cpu" |
| self.model = imagebind_model.imagebind_huge(pretrained=True) |
| self.model.eval() |
| self.model.to(self.device) |
| |
| def compute_embeddings(self, |
| images: List[str] = None, |
| audio_files: List[str] = None, |
| texts: List[str] = None) -> dict: |
| """Compute embeddings for provided modalities only.""" |
| with torch.no_grad(): |
| inputs = {} |
| |
| if texts: |
| inputs[ModalityType.TEXT] = data.load_and_transform_text(texts, self.device) |
| if images: |
| inputs[ModalityType.VISION] = data.load_and_transform_vision_data(images, self.device) |
| if audio_files: |
| inputs[ModalityType.AUDIO] = data.load_and_transform_audio_data(audio_files, self.device) |
| |
| if not inputs: |
| return {} |
| |
| embeddings = self.model(inputs) |
| |
| result = {} |
| if ModalityType.VISION in inputs: |
| result['vision'] = embeddings[ModalityType.VISION].cpu().numpy().tolist() |
| if ModalityType.AUDIO in inputs: |
| result['audio'] = embeddings[ModalityType.AUDIO].cpu().numpy().tolist() |
| if ModalityType.TEXT in inputs: |
| result['text'] = embeddings[ModalityType.TEXT].cpu().numpy().tolist() |
| |
| return result |
|
|
| @staticmethod |
| def compute_similarities(embeddings: Dict[str, List[List[float]]]) -> dict: |
| """Compute similarities between available embeddings.""" |
| similarities = {} |
| |
| |
| tensors = { |
| k: torch.tensor(v) for k, v in embeddings.items() |
| if isinstance(v, (list, np.ndarray)) and len(v) > 0 |
| } |
| |
| |
| modality_pairs = [ |
| ('vision', 'audio', 'vision_audio'), |
| ('vision', 'text', 'vision_text'), |
| ('audio', 'text', 'audio_text') |
| ] |
| |
| for mod1, mod2, key in modality_pairs: |
| if mod1 in tensors and mod2 in tensors: |
| similarities[key] = torch.softmax( |
| tensors[mod1] @ tensors[mod2].T, |
| dim=-1 |
| ).numpy().tolist() |
| |
| |
| for modality in ['vision', 'audio', 'text']: |
| if modality in tensors: |
| key = f'{modality}_{modality}' |
| similarities[key] = torch.softmax( |
| tensors[modality] @ tensors[modality].T, |
| dim=-1 |
| ).numpy().tolist() |
| |
| return similarities |
|
|
# Single shared instance used by the endpoints below. Constructing it loads the
# pretrained ImageBind model, so that cost is paid once when this module is
# imported rather than on every request.
embedding_manager = EmbeddingManager()
|
|
# Response body of POST /compute_embeddings.
class EmbeddingResponse(BaseModel):
    # Modality key ("vision", "audio", "text") -> embeddings as nested lists.
    embeddings: dict
    # Input key ("images", "audio", "texts") -> uploaded file names / text lines.
    # Note these keys differ from the modality keys used in `embeddings`.
    file_names: dict
|
|
# Scoring options (and embeddings) for POST /compute_similarities.
class SimilarityRequest(BaseModel):
    # Modality name -> list of embedding vectors; row i is item i of that modality.
    embeddings: Dict[str, List[List[float]]]
    # Matches scoring below this value are dropped.
    threshold: float = 0.5
    # Maximum matches taken per modality pair (before thresholding); None keeps all.
    top_k: int | None = None
    # When true, each modality is also compared with itself (e.g. vision vs vision).
    include_self_similarity: bool = False
    # When true, embeddings are L2-normalized first so scores are cosine similarities;
    # otherwise scores are raw dot products.
    normalize_scores: bool = True
|
|
# One scored pair of items returned by POST /compute_similarities.
class SimilarityMatch(BaseModel):
    index_a: int  # row index of the item within modality_a's embeddings
    index_b: int  # row index of the item within modality_b's embeddings
    score: float  # similarity score (cosine when normalize_scores is true)
    modality_a: str
    modality_b: str
    item_a: str  # original file name / text of item a
    item_b: str  # original file name / text of item b
|
|
# Response body of POST /compute_similarities.
class SimilarityResponse(BaseModel):
    # Matches sorted by descending score.
    matches: List[SimilarityMatch]
    # avg_score / max_score / min_score over the returned matches, plus
    # total_comparisons (the number of returned matches).
    # NOTE(review): total_comparisons is produced as an int but the declared
    # value type is float; confirm how pydantic coerces/serializes it.
    statistics: Dict[str, float]
    # Canonical "<a>_to_<b>" names of the modality pairs that were compared.
    modality_pairs: List[str]
|
|
class ModalityPair:
    """Unordered pair of modality names with a canonical string form.

    The two names are stored in sorted order, so ("vision", "audio") and
    ("audio", "vision") both render as ``"audio_to_vision"``.
    """

    def __init__(self, mod1: str, mod2: str):
        first, second = sorted((mod1, mod2))
        self.mod1 = first
        self.mod2 = second

    def __str__(self):
        return "{}_to_{}".format(self.mod1, self.mod2)
|
|
def compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Score every row of ``tensor1`` against every row of ``tensor2``.

    Args:
        tensor1: 2-D tensor of shape ``(n, d)``.
        tensor2: 2-D tensor of shape ``(m, d)``.
        normalize: If true (the default), rows are L2-normalized first, so the
            result is cosine similarity. If false, the result is the raw dot
            product. Integer inputs are promoted to float32 when normalizing.

    Returns:
        Tensor of shape ``(n, m)`` whose entry ``[i, j]`` scores
        ``tensor1[i]`` against ``tensor2[j]``.
    """
    if normalize:
        # torch's vector norms reject integer tensors, so promote any
        # non-floating (and non-complex) input before normalizing.
        tensor1, tensor2 = (
            t if t.is_floating_point() or t.is_complex() else t.float()
            for t in (tensor1, tensor2)
        )
        tensor1 = torch.nn.functional.normalize(tensor1, dim=1)
        tensor2 = torch.nn.functional.normalize(tensor2, dim=1)

    return torch.matmul(tensor1, tensor2.T)
|
|
| def get_top_k_matches(similarity_matrix: torch.Tensor, top_k: int | None = None) -> List[tuple]: |
| """Get top-k matches from a similarity matrix.""" |
| if top_k is None: |
| top_k = similarity_matrix.numel() |
| |
| |
| flat_sim = similarity_matrix.flatten() |
| top_k = min(top_k, flat_sim.numel()) |
| values, indices = torch.topk(flat_sim, k=top_k) |
| |
| |
| rows = indices // similarity_matrix.size(1) |
| cols = indices % similarity_matrix.size(1) |
| |
| return [(r.item(), c.item(), v.item()) for r, c, v in zip(rows, cols, values)] |
|
|
async def _save_upload(upload: UploadFile, temp_files: List[str]) -> str:
    """Write an uploaded file to a new named temp file and return its path.

    The path is appended to ``temp_files`` as soon as the file exists, so the
    caller's cleanup removes it even if reading the upload fails. The file is
    closed (and therefore flushed) before this returns.
    """
    # Keep the original extension: convert_audio_to_wav keys off ".mp3".
    # UploadFile.filename may be None, hence the "" fallback.
    suffix = os.path.splitext(upload.filename or "")[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        temp_files.append(tmp.name)
        tmp.write(await upload.read())
    return tmp.name


@app.post("/compute_embeddings", response_model=EmbeddingResponse)
async def generate_embeddings(
    credentials: HTTPAuthorizationCredentials = Depends(verify_token),
    texts: str | None = Form(None),
    images: List[UploadFile] | None = File(default=None),
    audio_files: List[UploadFile] | None = File(default=None)
):
    """Generate embeddings for any provided files and texts.

    `texts` is split on newlines and blank lines are ignored. MP3 audio is
    converted to WAV before embedding. All temporary files are removed before
    the response is returned.
    """
    # NOTE(review): model inference and the pydub conversion are synchronous
    # calls inside this async handler, so they block the event loop while running.
    temp_files: List[str] = []

    try:
        image_paths: List[str] = []
        image_names = []
        for img in images or []:
            image_paths.append(await _save_upload(img, temp_files))
            image_names.append(img.filename)

        audio_paths: List[str] = []
        audio_names = []
        for audio in audio_files or []:
            # Convert only after the upload file has been closed, so every
            # byte is on disk when the converter opens it.
            raw_path = await _save_upload(audio, temp_files)
            wav_path = convert_audio_to_wav(raw_path)
            if wav_path != raw_path:
                temp_files.append(wav_path)
            audio_paths.append(wav_path)
            audio_names.append(audio.filename)

        text_list = [line.strip() for line in texts.split('\n') if line.strip()] if texts else []

        if not any([image_paths, audio_paths, text_list]):
            return EmbeddingResponse(
                embeddings={},
                file_names={}
            )

        embeddings = embedding_manager.compute_embeddings(
            image_paths or None,
            audio_paths or None,
            text_list or None
        )

        file_names = {}
        if image_names:
            file_names['images'] = image_names
        if audio_names:
            file_names['audio'] = audio_names
        if text_list:
            file_names['texts'] = text_list

        return EmbeddingResponse(
            embeddings=embeddings,
            file_names=file_names
        )

    finally:
        for temp_file in temp_files:
            try:
                os.unlink(temp_file)
            except OSError:
                # Best-effort cleanup: a file that is already gone (or cannot
                # be removed) must not turn a finished request into an error.
                pass
|
|
# /compute_embeddings reports file names under "images" / "texts" while its
# embeddings use the modality keys "vision" / "text". Accept both spellings so
# that response can be sent back to /compute_similarities unchanged.
_FILE_NAME_ALIASES = {"vision": "images", "text": "texts"}


def _embeddings_to_tensors(embeddings: Dict[str, List[List[float]]]) -> Dict[str, torch.Tensor]:
    """Convert each non-empty modality's embedding list into a 2-D tensor.

    Raises:
        HTTPException: 422 when a modality's vectors have differing lengths.
    """
    tensors = {}
    for modality, vectors in embeddings.items():
        if not isinstance(vectors, (list, np.ndarray)) or len(vectors) == 0:
            continue
        try:
            tensors[modality] = torch.tensor(vectors)
        except (TypeError, ValueError) as err:
            raise HTTPException(
                status_code=422,
                detail=f"Embeddings for '{modality}' must all have the same length",
            ) from err
    return tensors


def _item_name(file_names: Dict[str, List[str]], modality: str, index: int) -> str:
    """Return the original name of item ``index`` of ``modality``.

    Raises:
        HTTPException: 422 when ``file_names`` has no entry for that item.
    """
    names = file_names.get(modality)
    if names is None:
        names = file_names.get(_FILE_NAME_ALIASES.get(modality, modality), [])
    if index >= len(names):
        raise HTTPException(
            status_code=422,
            detail=f"file_names has no entry for {modality} item {index}",
        )
    return names[index]


@app.post("/compute_similarities", response_model=SimilarityResponse)
async def compute_similarities(
    request: SimilarityRequest,
    file_names: Dict[str, List[str]],
    credentials: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """
    Compute cross-modal similarities with advanced filtering and matching options.

    Parameters:
    - embeddings: Dict mapping modality to embedding tensors
    - threshold: Minimum similarity score to include in results
    - top_k: Maximum number of matches to return (per modality pair)
    - include_self_similarity: Whether to also compare each modality with itself
      (item-with-itself pairs are then included too)
    - normalize_scores: Whether to normalize embeddings before comparison
    - file_names: Dict mapping modality to list of original file/text names
      (the "images"/"texts" keys returned by /compute_embeddings are accepted)

    Matches are returned in descending score order. `total_comparisons` in the
    statistics counts the returned matches. Ragged or dimension-mismatched
    embeddings and missing file names are rejected with HTTP 422.
    """
    tensors = _embeddings_to_tensors(request.embeddings)
    modalities = list(tensors)

    matches: List[SimilarityMatch] = []
    modality_pairs: List[str] = []
    all_scores: List[float] = []

    # Visit each unordered pair of modalities once, including mod1 == mod2.
    for i, mod1 in enumerate(modalities):
        for mod2 in modalities[i:]:
            # Same-modality comparisons are opt-in. Once enabled, the diagonal
            # (an item scored against itself) is kept like any other cell.
            if mod1 == mod2 and not request.include_self_similarity:
                continue

            modality_pairs.append(str(ModalityPair(mod1, mod2)))

            if tensors[mod1].shape[-1] != tensors[mod2].shape[-1]:
                raise HTTPException(
                    status_code=422,
                    detail=f"Embedding sizes of '{mod1}' and '{mod2}' do not match",
                )

            sim_matrix = compute_similarity_matrix(
                tensors[mod1],
                tensors[mod2],
                normalize=request.normalize_scores
            )

            for idx_a, idx_b, score in get_top_k_matches(sim_matrix, request.top_k):
                if score < request.threshold:
                    continue

                matches.append(SimilarityMatch(
                    index_a=idx_a,
                    index_b=idx_b,
                    score=float(score),
                    modality_a=mod1,
                    modality_b=mod2,
                    item_a=_item_name(file_names, mod1, idx_a),
                    item_b=_item_name(file_names, mod2, idx_b)
                ))
                all_scores.append(score)

    if all_scores:
        statistics = {
            "avg_score": float(np.mean(all_scores)),
            "max_score": float(np.max(all_scores)),
            "min_score": float(np.min(all_scores)),
            "total_comparisons": len(all_scores)
        }
    else:
        statistics = {
            "avg_score": 0.0,
            "max_score": 0.0,
            "min_score": 1.0,
            "total_comparisons": 0
        }

    matches.sort(key=lambda match: match.score, reverse=True)

    return SimilarityResponse(
        matches=matches,
        statistics=statistics,
        modality_pairs=modality_pairs
    )
|
|
@app.get("/health")
async def health_check(
    credentials: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """Report that the service is up and which device the model is on."""
    device = embedding_manager.device
    return dict(status="healthy", model_device=device)
|
|
if __name__ == "__main__":
    # Serve the app on every network interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")