import torch
import pandas as pd
import gradio as gr
from PIL import Image
from sd_parsers import ParserManager
from torchvision import transforms
from transformers import (
    CLIPProcessor,
    CLIPModel,
    Blip2Processor,
    Blip2ForConditionalGeneration,
    BitsAndBytesConfig,
)
import lpips
import piq
import plotly.express as px
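
# Rough dependency list (assumed package names, adjust versions to taste):
#   pip install torch torchvision transformers accelerate bitsandbytes \
#       sd-parsers lpips piq gradio plotly pandas pillow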


# Run every metric model on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CLIP: used for prompt-image alignment (CLIPScore) and the aesthetic proxy.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# BLIP-2 captioner; quantize to 8-bit on GPU to reduce VRAM usage.
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
if torch.cuda.is_available():
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        quantization_config=bnb_config,
        device_map="auto",
    )
else:
    # float16 is poorly supported on CPU, so fall back to full precision there.
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xl",
        torch_dtype=torch.float32,
    ).to(device)
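# Note: 8-bit weights roughly halve GPU memory versus fp16 for this ~4B-parameter
# checkpoint, usually with only a minor effect on caption quality.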

# LPIPS (AlexNet backbone) measures perceptual distance between image pairs.
lpips_model = lpips.LPIPS(net='alex').to(device)


def extract_metadata(file):
    """Extract the prompt and model name from generation metadata via sd-parsers."""
    path = getattr(file, 'name', file)  # Gradio may pass a file wrapper or a plain path
    parser = ParserManager()
    info = parser.parse(path)
    if info is None:  # no recognizable generation metadata
        return '', ''
    # prompts/models are collections, so take the first entry via an iterator.
    prompt = next(iter(info.prompts)).value if info.prompts else ''
    model_name = ''
    if getattr(info, 'models', None):
        first = next(iter(info.models))
        model_name = first.name if hasattr(first, 'name') else str(first)
    return prompt, model_name
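
# Example (hedged): for an AUTOMATIC1111-style PNG, parse() returns a PromptInfo
# whose .prompts and .models are collections of Prompt/Model objects; which
# fields are populated depends on the tool that wrote the metadata.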

# Manual CLIP-style preprocessing with CLIP's normalization constants. Note that
# CLIPProcessor already resizes and normalizes internally, so the metrics below
# do not use this pipeline; it is kept for ad-hoc experiments with raw tensors.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711)
    )
])


@torch.no_grad()
def compute_clip_score(img: Image.Image, text: str) -> float:
    """CLIPScore: image-text cosine similarity, clamped at 0 and scaled to [0, 100]."""
    # truncation=True keeps long SD prompts within CLIP's 77-token limit.
    inputs = clip_processor(
        text=[text], images=img, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    outputs = clip_model(**inputs)
    score = torch.cosine_similarity(outputs.image_embeds, outputs.text_embeds)
    return float((score.clamp(min=0) * 100).mean())
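
# This follows the common CLIPScore convention, 100 * max(0, cos(E_img, E_text));
# well-matched image/prompt pairs typically score in roughly the 25-35 range.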


@torch.no_grad()
def compute_caption_similarity(img: Image.Image, prompt: str) -> float:
    """Caption the image with BLIP-2 and score the caption against the prompt via CLIP."""
    # fp16 inputs on GPU match the 8-bit model's compute dtype; fp32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    inputs = blip_processor(images=img, return_tensors="pt").to(device, dtype)
    out = blip_model.generate(**inputs)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    # Compare the generated caption with the original prompt in CLIP text space.
    text_inputs = clip_processor(text=[caption, prompt], return_tensors="pt",
                                 padding=True, truncation=True).to(device)
    embeds = clip_model.get_text_features(**text_inputs)
    sim = torch.cosine_similarity(embeds[0:1], embeds[1:2])
    return float((sim.clamp(min=0) * 100).mean())


@torch.no_grad()
def compute_iqa_metrics(img: Image.Image):
    """No-reference quality scores; lower is better for both BRISQUE and NIQE."""
    tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
    brisque = float(piq.brisque(tensor).cpu())
    niqe = float(piq.niqe(tensor).cpu())
    return brisque, niqe
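
# Caveat: BRISQUE and NIQE were tuned on natural-photo statistics, so treat their
# scores on heavily stylized generations as rough signals rather than ground truth.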


@torch.no_grad()
def compute_lpips_pair(img1: Image.Image, img2: Image.Image) -> float:
    """Perceptual LPIPS distance between two images (higher = more dissimilar)."""
    if img1.size != img2.size:  # LPIPS needs matching spatial dimensions
        img2 = img2.resize(img1.size)
    t1 = transforms.ToTensor()(img1).unsqueeze(0).to(device)
    t2 = transforms.ToTensor()(img2).unsqueeze(0).to(device)
    # ToTensor() yields [0, 1]; normalize=True rescales to the [-1, 1] LPIPS expects.
    return float(lpips_model(t1, t2, normalize=True).cpu())
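
# LPIPS distances fall roughly in [0, 1]; analyze_images() below averages them over
# all image pairs per model as a diversity proxy, which scales O(n^2) with image count.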


def analyze_images(files):
    """Score every uploaded image, then aggregate the metrics per model."""
    records = []
    imgs_by_model = {}

    for f in files:
        path = getattr(f, 'name', f)  # Gradio may pass file wrappers or plain paths
        img = Image.open(path).convert('RGB')
        prompt, model = extract_metadata(f)

        cs = compute_clip_score(img, prompt)
        cap_sim = compute_caption_similarity(img, prompt)
        brisque, niqe = compute_iqa_metrics(img)
        # Aesthetic proxy: CLIP similarity to a generic "high quality" description.
        aesthetic = compute_clip_score(img, "a beautiful high quality image")

        records.append({
            'model': model,
            'prompt': prompt,
            'clip_score': cs,
            'caption_sim': cap_sim,
            'brisque': brisque,
            'niqe': niqe,
            'aesthetic': aesthetic,
        })
        imgs_by_model.setdefault(model, []).append(img)

    df = pd.DataFrame(records)

    # Intra-model diversity: mean pairwise LPIPS over all image pairs per model.
    diversity = {}
    for model, imgs in imgs_by_model.items():
        if len(imgs) < 2:
            diversity[model] = 0.0
        else:
            pairs = [compute_lpips_pair(imgs[i], imgs[j])
                     for i in range(len(imgs)) for j in range(i + 1, len(imgs))]
            diversity[model] = sum(pairs) / len(pairs)

    # Per-model means of every metric, plus the diversity column.
    agg = df.groupby('model').agg(
        clip_score_mean=('clip_score', 'mean'),
        caption_sim_mean=('caption_sim', 'mean'),
        brisque_mean=('brisque', 'mean'),
        niqe_mean=('niqe', 'mean'),
        aesthetic_mean=('aesthetic', 'mean'),
    ).reset_index()
    agg['diversity'] = agg['model'].map(diversity)

    return df, agg


def plot_metrics(agg: pd.DataFrame):
    """Grouped bar chart of the per-model metric means."""
    return px.bar(
        agg,
        x='model',
        y=['aesthetic_mean', 'clip_score_mean', 'caption_sim_mean', 'diversity'],
        barmode='group',
        title='Model comparison by metric',
    )


def run_analysis(files):
    # The UI table shows the per-model summary, so return `agg` rather than the
    # per-image dataframe (its columns would not match the table headers).
    _, agg = analyze_images(files)
    fig = plot_metrics(agg)
    return agg, fig


with gr.Blocks() as demo:
    gr.Markdown("# AI Image Quality Evaluator")
    gr.Markdown("Upload PNG images with embedded Stable Diffusion metadata to analyze and compare models.")

    with gr.Row():
        input_files = gr.File(file_count="multiple", label="Select PNG files")
        output_table = gr.Dataframe(
            headers=[
                "model", "clip_score_mean", "caption_sim_mean", "brisque_mean",
                "niqe_mean", "aesthetic_mean", "diversity"
            ],
            label="Summary table"
        )

    plot_output = gr.Plot(label="Metrics chart")

    run_btn = gr.Button("Run analysis")
    run_btn.click(run_analysis, inputs=[input_files], outputs=[output_table, plot_output])


if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0', share=False)
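
# Usage sketch (assuming this file is saved as app.py):
#   python app.py        # then open http://localhost:7860, Gradio's default port
# server_name='0.0.0.0' additionally exposes the app on the local network.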