| |
| import gradio as gr |
| import torch |
| import torchaudio |
| from transformers import ( |
| pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, |
| AutoImageProcessor, AutoModelForObjectDetection, |
| BlipForQuestionAnswering, BlipProcessor, CLIPModel, CLIPProcessor, |
| VitsModel, AutoTokenizer |
| ) |
| from PIL import Image, ImageDraw |
| import requests |
| import numpy as np |
| import soundfile as sf |
| from gtts import gTTS |
| import tempfile |
| import os |
| from sentence_transformers import SentenceTransformer |
|
|
| |
# Process-wide cache of loaded models and processors, keyed by short model
# name (e.g. "whisper", "clip", "clip_processor"), so repeated requests reuse
# an already-initialised checkpoint instead of loading it again.
models: dict = dict()
|
|
# Hugging Face pipeline (task, checkpoint) for every audio model key used by the UI.
_AUDIO_PIPELINES = {
    "whisper": ("automatic-speech-recognition", "openai/whisper-small"),
    "wav2vec2": ("automatic-speech-recognition", "bond005/wav2vec2-large-ru-golos"),
    "audio_classifier": ("audio-classification", "MIT/ast-finetuned-audioset-10-10-0.4593"),
    "emotion_classifier": ("audio-classification", "superb/hubert-large-superb-er"),
}


def load_audio_model(model_name):
    """Return the cached audio pipeline for ``model_name``, building it on first use.

    Args:
        model_name: One of the keys of ``_AUDIO_PIPELINES``.

    Returns:
        The ``transformers`` pipeline stored in the module-level ``models`` cache.

    Raises:
        KeyError: If ``model_name`` is neither cached nor a known audio model.
    """
    if model_name not in models:
        try:
            task, checkpoint = _AUDIO_PIPELINES[model_name]
        except KeyError:
            # Keep the KeyError type, but say what went wrong instead of
            # surfacing a bare cache-miss on the model name.
            raise KeyError(f"Unknown audio model: {model_name!r}") from None
        models[model_name] = pipeline(task, model=checkpoint)
    return models[model_name]
|
|
# Hugging Face pipeline (task, checkpoint) for every pipeline-based image model key.
_IMAGE_PIPELINES = {
    "object_detection": ("object-detection", "facebook/detr-resnet-50"),
    "segmentation": ("image-segmentation", "nvidia/segformer-b0-finetuned-ade-512-512"),
    "captioning": ("image-to-text", "Salesforce/blip-image-captioning-base"),
    "vqa": ("visual-question-answering", "dandelin/vilt-b32-finetuned-vqa"),
}
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"


def load_image_model(model_name):
    """Return the cached image model for ``model_name``, building it on first use.

    For ``"clip"`` the raw ``CLIPModel`` is returned, and its processor is also
    cached under ``"clip_processor"`` for callers to fetch from ``models``.
    Every other key maps to a ``transformers`` pipeline from ``_IMAGE_PIPELINES``.

    Raises:
        KeyError: If ``model_name`` is neither cached nor a known image model.
    """
    if model_name not in models:
        if model_name == "clip":
            # Load both objects before touching the cache. "clip" is the key
            # callers test, so it must only appear once the processor exists too.
            clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
            clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
            models["clip_processor"] = clip_processor
            models["clip"] = clip_model
        elif model_name in _IMAGE_PIPELINES:
            task, checkpoint = _IMAGE_PIPELINES[model_name]
            models[model_name] = pipeline(task, model=checkpoint)
        else:
            raise KeyError(f"Unknown image model: {model_name!r}")
    return models[model_name]
|
|
| |
def audio_classification(audio_file, model_type):
    """Classify an audio file and format its five highest-ranked labels.

    Args:
        audio_file: Path to the uploaded audio (Gradio ``type="filepath"``),
            or None/empty when nothing was uploaded.
        model_type: Audio model key understood by ``load_audio_model``.

    Returns:
        A header line followed by one ``"<rank>. <label>: <score>"`` line per
        prediction (at most five), or a prompt to upload a file.
    """
    if not audio_file:
        # Gradio passes None when the button is clicked before uploading.
        return "Пожалуйста, загрузите аудиофайл"

    classifier = load_audio_model(model_type)
    predictions = classifier(audio_file)

    ranked_lines = "".join(
        f"{rank}. {item['label']}: {item['score']:.4f}\n"
        for rank, item in enumerate(predictions[:5], start=1)
    )
    return "Топ-5 предсказаний:\n" + ranked_lines
|
|
def speech_recognition(audio_file, model_type):
    """Transcribe Russian speech from an audio file.

    Args:
        audio_file: Path to the uploaded audio, or None/empty if nothing was uploaded.
        model_type: ``"whisper"`` or another ASR key understood by ``load_audio_model``.

    Returns:
        The transcription text, or a prompt to upload audio.
    """
    if not audio_file:
        return "Пожалуйста, загрузите аудио с речью"

    asr_pipeline = load_audio_model(model_type)

    if model_type == "whisper":
        # Pin the language so Whisper does not auto-detect it. Whisper works on
        # 30-second windows (see the transformers ASR pipeline docs), so chunk
        # the input; otherwise longer recordings are cut off or rejected.
        result = asr_pipeline(
            audio_file,
            chunk_length_s=30,
            generate_kwargs={"language": "russian"},
        )
    else:
        result = asr_pipeline(audio_file)

    return result["text"]
|
|
def _temp_audio_path(suffix):
    """Create an empty, already-closed temp file and return its path (never deleted here)."""
    # mkstemp + close lets the TTS libraries reopen the path themselves,
    # which a still-open NamedTemporaryFile does not allow on Windows.
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _synthesize_silero(text):
    """Synthesize ``text`` with the cached Silero ``ru_v3`` model into a 48 kHz WAV file."""
    if "tts_silero" not in models:
        silero_model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
                                         model='silero_tts',
                                         language='ru',
                                         speaker='ru_v3')
        models["tts_silero"] = silero_model
    path = _temp_audio_path(".wav")
    models["tts_silero"].save_wav(text=text, speaker='aidar', sample_rate=48000, audio_path=path)
    return path


def _synthesize_gtts(text):
    """Synthesize ``text`` with Google TTS into an MP3 file."""
    speech = gTTS(text=text, lang='ru')
    # gTTS.save writes MP3 data, so the file must carry a .mp3 extension
    # (the previous .wav name mislabelled the content).
    path = _temp_audio_path(".mp3")
    speech.save(path)
    return path


def _synthesize_mms(text):
    """Synthesize ``text`` with the cached Facebook MMS Russian VITS model into a WAV file."""
    if "tts_mms" not in models:
        models["tts_mms_tokenizer"] = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
        models["tts_mms"] = VitsModel.from_pretrained("facebook/mms-tts-rus")
    vits = models["tts_mms"]
    tokenizer = models["tts_mms_tokenizer"]

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = vits(**inputs).waveform

    path = _temp_audio_path(".wav")
    sf.write(path, waveform.numpy().squeeze(), vits.config.sampling_rate)
    return path


# UI model key -> synthesis function returning the path of the generated audio file.
_TTS_ENGINES = {
    "silero": _synthesize_silero,
    "gtts": _synthesize_gtts,
    "mms": _synthesize_mms,
}


def text_to_speech(text, model_type):
    """Synthesize Russian speech for ``text`` and return the audio file path.

    Args:
        text: Text to speak. Empty or whitespace-only text yields None.
        model_type: ``"silero"``, ``"gtts"`` or ``"mms"``.

    Returns:
        Path to a temporary audio file (WAV, or MP3 for gTTS), or None for empty text.
        The caller is responsible for removing the file.

    Raises:
        ValueError: If ``model_type`` is not a known TTS engine.
    """
    try:
        synthesize = _TTS_ENGINES[model_type]
    except KeyError:
        raise ValueError(f"Unknown TTS model: {model_type!r}") from None

    if not text or not text.strip():
        # Nothing to speak; every engine fails on empty input.
        return None
    return synthesize(text)
|
|
| |
def object_detection(image):
    """Detect objects and return a copy of the image annotated with boxes and labels.

    Args:
        image: PIL image from Gradio, or None when nothing was uploaded.

    Returns:
        A new RGB PIL image with a red rectangle and ``"label: score"`` text per
        detection, or None when no image was given.
    """
    if image is None:
        return None

    detector = load_image_model("object_detection")
    detections = detector(image)

    # Draw on an RGB copy. This leaves the caller's image untouched and makes
    # the 'red' colour render as colour even for L/P/RGBA inputs.
    annotated = image.convert("RGB")
    draw = ImageDraw.Draw(annotated)
    for det in detections:
        box = det['box']
        corners = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
        draw.rectangle(corners, outline='red', width=3)
        draw.text((box['xmin'], box['ymin']),
                  f"{det['label']}: {det['score']:.2f}", fill='red')

    return annotated
|
|
def image_segmentation(image):
    """Segment an image and render every predicted segment as a colour-coded map.

    The original handler returned only the first segment's mask, which hid all
    other classes. Here each segment's mask is painted in its own colour.

    Args:
        image: PIL image from Gradio, or None when nothing was uploaded.

    Returns:
        An RGB PIL image the size of the masks, or None when there is no image
        or the model returned no segments.
    """
    if image is None:
        return None

    segmenter = load_image_model("segmentation")
    segments = segmenter(image)
    if not segments:
        return None

    height, width = np.asarray(segments[0]['mask']).shape[:2]
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    # Fixed seed keeps the palette reproducible across calls. Colours follow the
    # segment's position in the result list, not its label.
    palette = np.random.default_rng(0).integers(0, 256, size=(len(segments), 3), dtype=np.uint8)
    for colour, segment in zip(palette, segments):
        # Pipeline masks are single-channel images; non-zero pixels belong to the segment.
        canvas[np.asarray(segment['mask']) > 0] = colour

    return Image.fromarray(canvas)
|
|
def image_captioning(image):
    """Generate a caption for an image.

    Args:
        image: PIL image from Gradio, or None when nothing was uploaded.

    Returns:
        The generated caption text, or a prompt to upload an image.
    """
    if image is None:
        return "Пожалуйста, загрузите изображение"

    captioner = load_image_model("captioning")
    captions = captioner(image)
    return captions[0]['generated_text']
|
|
def visual_question_answering(image, question):
    """Answer a free-text question about an image.

    Args:
        image: PIL image from Gradio, or None when nothing was uploaded.
        question: Question text. Empty or whitespace-only questions are rejected.

    Returns:
        ``"<answer> (confidence: <score>)"`` for the top answer, or a prompt
        asking for both an image and a question.
    """
    if image is None or not question or not question.strip():
        return "Пожалуйста, загрузите изображение и введите вопрос"

    vqa_pipeline = load_image_model("vqa")
    # NOTE(review): the ViLT checkpoint is presumably English-only; Russian
    # questions may give poor answers. Confirm against the model card.
    answers = vqa_pipeline(image, question=question)
    best = answers[0]
    return f"{best['answer']} (confidence: {best['score']:.3f})"
|
|
def zero_shot_classification(image, classes):
    """Score an image against user-supplied class names with CLIP.

    Args:
        image: PIL image from Gradio, or None when nothing was uploaded.
        classes: Comma-separated class names. Blank entries (e.g. from a
            trailing comma) are ignored.

    Returns:
        A header line plus one ``"<class>: <probability>"`` line per class,
        sorted from most to least likely, or a prompt when the image or
        class list is missing.
    """
    class_list = [name.strip() for name in (classes or "").split(",") if name.strip()]
    if image is None or not class_list:
        return "Пожалуйста, загрузите изображение и введите классы через запятую"

    model = load_image_model("clip")
    processor = models["clip_processor"]

    inputs = processor(text=class_list, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the candidate classes for the single input image.
    probabilities = outputs.logits_per_image.softmax(dim=1)[0].tolist()

    ranked = sorted(zip(class_list, probabilities), key=lambda pair: pair[1], reverse=True)
    body = "".join(f"{name}: {prob:.4f}\n" for name, prob in ranked)
    return "Zero-Shot Classification Results:\n" + body
|
|
def _gallery_images(items):
    """Return bare images from a Gradio Gallery value (unwrapping (image, caption) pairs)."""
    # Gradio 4+ passes Gallery input as (image, caption) tuples; older versions
    # pass the images directly. Accept both.
    return [item[0] if isinstance(item, (tuple, list)) else item for item in items]


def image_retrieval(images, query):
    """Find the uploaded image that best matches a text query using CLIP.

    Args:
        images: Gallery value, i.e. a list of PIL images or (image, caption) pairs.
        query: Free-text description of the wanted image.

    Returns:
        A ``(message, image)`` pair matching the two Gradio outputs. ``image``
        is the best match, or None when images or query are missing.
    """
    if not images or not query or not query.strip():
        # Two outputs are wired to this handler, so always return two values.
        return "Пожалуйста, загрузите изображения и введите запрос", None

    pictures = _gallery_images(images)

    model = load_image_model("clip")
    processor = models["clip_processor"]

    with torch.no_grad():
        image_inputs = processor(images=pictures, return_tensors="pt")
        image_embeddings = model.get_image_features(**image_inputs)
        text_inputs = processor(text=[query], return_tensors="pt", padding=True)
        text_embeddings = model.get_text_features(**text_inputs)

    # L2-normalise so the dot product below is cosine similarity.
    image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

    # (num_images, dim) @ (dim, 1) -> (num_images, 1); flatten to one score per image.
    similarities = (image_embeddings @ text_embeddings.T).squeeze(-1)
    best_idx = int(similarities.argmax())
    best_score = similarities[best_idx].item()

    return f"Лучшее изображение: #{best_idx + 1} (схожесть: {best_score:.4f})", pictures[best_idx]
|
|
| |
# Assemble the tabbed demo UI. Each tab wires its input widgets to one of the
# handler functions in this module; `demo` is launched from the __main__ guard.
with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 Мультимодальные AI модели")
    gr.Markdown("Демонстрация различных задач компьютерного зрения и обработки звука с использованием Hugging Face Transformers")

    with gr.Tab("🎵 Классификация аудио"):
        # Both models predict from their own fixed label sets and the user gives
        # no candidate labels, so this is ordinary (not zero-shot) classification.
        gr.Markdown("## Audio Classification")
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Загрузите аудиофайл", type="filepath")
                audio_model_dropdown = gr.Dropdown(
                    choices=["audio_classifier", "emotion_classifier"],
                    label="Выберите модель",
                    value="audio_classifier",
                    info="audio_classifier - общая классификация, emotion_classifier - эмоции в речи"
                )
                classify_btn = gr.Button("Классифицировать")
            with gr.Column():
                audio_output = gr.Textbox(label="Результаты классификации", lines=10)

        classify_btn.click(
            fn=audio_classification,
            inputs=[audio_input, audio_model_dropdown],
            outputs=audio_output
        )

    with gr.Tab("🗣️ Распознавание речи"):
        gr.Markdown("## Automatic Speech Recognition (ASR)")
        with gr.Row():
            with gr.Column():
                asr_audio_input = gr.Audio(label="Загрузите аудио с речью", type="filepath")
                asr_model_dropdown = gr.Dropdown(
                    choices=["whisper", "wav2vec2"],
                    label="Выберите модель",
                    value="whisper",
                    info="whisper - многоязычная, wav2vec2 - специализированная для русского"
                )
                transcribe_btn = gr.Button("Транскрибировать")
            with gr.Column():
                asr_output = gr.Textbox(label="Транскрипция", lines=5)

        transcribe_btn.click(
            fn=speech_recognition,
            inputs=[asr_audio_input, asr_model_dropdown],
            outputs=asr_output
        )

    with gr.Tab("🔊 Синтез речи"):
        gr.Markdown("## Text-to-Speech (TTS)")
        with gr.Row():
            with gr.Column():
                tts_text_input = gr.Textbox(
                    label="Введите текст для синтеза",
                    placeholder="Введите текст на русском языке...",
                    lines=3
                )
                tts_model_dropdown = gr.Dropdown(
                    choices=["silero", "gtts", "mms"],
                    label="Выберите модель",
                    value="silero",
                    info="silero - высокое качество, gtts - Google TTS, mms - Facebook MMS"
                )
                synthesize_btn = gr.Button("Синтезировать речь")
            with gr.Column():
                tts_output = gr.Audio(label="Синтезированная речь")

        synthesize_btn.click(
            fn=text_to_speech,
            inputs=[tts_text_input, tts_model_dropdown],
            outputs=tts_output
        )

    with gr.Tab("📦 Детекция объектов"):
        gr.Markdown("## Object Detection")
        with gr.Row():
            with gr.Column():
                obj_detection_input = gr.Image(label="Загрузите изображение", type="pil")
                detect_btn = gr.Button("Обнаружить объекты")
            with gr.Column():
                obj_detection_output = gr.Image(label="Результат детекции")

        detect_btn.click(
            fn=object_detection,
            inputs=obj_detection_input,
            outputs=obj_detection_output
        )

    with gr.Tab("🎨 Сегментация"):
        gr.Markdown("## Image Segmentation")
        with gr.Row():
            with gr.Column():
                seg_input = gr.Image(label="Загрузите изображение", type="pil")
                segment_btn = gr.Button("Сегментировать")
            with gr.Column():
                seg_output = gr.Image(label="Маска сегментации")

        segment_btn.click(
            fn=image_segmentation,
            inputs=seg_input,
            outputs=seg_output
        )

    with gr.Tab("📝 Описание изображений"):
        gr.Markdown("## Image Captioning")
        with gr.Row():
            with gr.Column():
                caption_input = gr.Image(label="Загрузите изображение", type="pil")
                caption_btn = gr.Button("Сгенерировать описание")
            with gr.Column():
                caption_output = gr.Textbox(label="Описание изображения", lines=3)

        caption_btn.click(
            fn=image_captioning,
            inputs=caption_input,
            outputs=caption_output
        )

    with gr.Tab("❓ Визуальные вопросы"):
        gr.Markdown("## Visual Question Answering")
        with gr.Row():
            with gr.Column():
                vqa_image_input = gr.Image(label="Загрузите изображение", type="pil")
                # The ViLT VQA checkpoint expects English questions, so the
                # example and hint steer users toward English input.
                vqa_question_input = gr.Textbox(
                    label="Вопрос об изображении",
                    placeholder="What is happening in this image?",
                    info="Модель понимает вопросы на английском языке",
                    lines=2
                )
                vqa_btn = gr.Button("Ответить на вопрос")
            with gr.Column():
                vqa_output = gr.Textbox(label="Ответ", lines=3)

        vqa_btn.click(
            fn=visual_question_answering,
            inputs=[vqa_image_input, vqa_question_input],
            outputs=vqa_output
        )

    with gr.Tab("🎯 Zero-Shot классификация"):
        gr.Markdown("## Zero-Shot Image Classification")
        with gr.Row():
            with gr.Column():
                zs_image_input = gr.Image(label="Загрузите изображение", type="pil")
                # openai/clip-vit-base-patch32 is English-trained; English class names score far better.
                zs_classes_input = gr.Textbox(
                    label="Классы для классификации (через запятую)",
                    placeholder="person, car, tree, building, animal",
                    info="Указывайте классы на английском языке",
                    lines=2
                )
                zs_classify_btn = gr.Button("Классифицировать")
            with gr.Column():
                zs_output = gr.Textbox(label="Результаты классификации", lines=10)

        zs_classify_btn.click(
            fn=zero_shot_classification,
            inputs=[zs_image_input, zs_classes_input],
            outputs=zs_output
        )

    with gr.Tab("🔍 Поиск изображений"):
        gr.Markdown("## Image Retrieval")
        with gr.Row():
            with gr.Column():
                retrieval_images_input = gr.Gallery(
                    label="Загрузите изображения для поиска",
                    type="pil"
                )
                # Same English-trained CLIP model as the zero-shot tab.
                retrieval_query_input = gr.Textbox(
                    label="Текстовый запрос",
                    placeholder="a description of what you are looking for...",
                    info="Вводите запрос на английском языке",
                    lines=2
                )
                retrieval_btn = gr.Button("Найти изображение")
            with gr.Column():
                retrieval_output_text = gr.Textbox(label="Результат поиска")
                retrieval_output_image = gr.Image(label="Найденное изображение")

        # The handler returns a (message, image) pair, one value per output.
        retrieval_btn.click(
            fn=image_retrieval,
            inputs=[retrieval_images_input, retrieval_query_input],
            outputs=[retrieval_output_text, retrieval_output_image]
        )

    gr.Markdown("---")
    gr.Markdown("### 📊 Поддерживаемые задачи:")
    gr.Markdown("""
- **🎵 Аудио**: Классификация, распознавание речи, синтез речи
- **👁️ Компьютерное зрение**: Детекция объектов, сегментация, описание изображений
- **🤖 Мультимодальные**: Визуальные вопросы, zero-shot классификация, поиск по изображениям
""")
|
|
if __name__ == "__main__":
    # share=True asks Gradio for a publicly reachable share link. That stays the
    # default, but GRADIO_SHARE=0/false/no/off keeps the demo local-only.
    share = os.environ.get("GRADIO_SHARE", "true").strip().lower() not in {"0", "false", "no", "off"}
    demo.launch(share=share)