| import openai |
| import base64 |
| from pathlib import Path |
| import random |
| import os |
|
|
|
|
|
|
# Scoring rubrics sent to GPT-4o, one per evaluation axis.
# Every template instructs the model to answer with a "Score: [1-5]" line
# followed by a "Reasoning:" line — parse_score() depends on that exact
# "Score:" prefix to extract the integer.
# The "modification" entry is a str.format template with a {prompt}
# placeholder (filled in evaluate_subject_driven_generation); the other
# four are used verbatim.
evaluation_prompts = {
    "identity": """
    Compare the original subject image with the generated image.
    Rate on a scale of 1-5 how well the essential identifying features
    are preserved (logos, brand marks, distinctive patterns).
    Score: [1-5]
    Reasoning: [explanation]
    """,

    "material": """
    Evaluate the material quality and surface characteristics.
    Rate on a scale of 1-5 how accurately materials are represented
    (textures, reflections, surface properties).
    Score: [1-5]
    Reasoning: [explanation]
    """,

    "color": """
    Assess color fidelity in regions NOT specified for modification.
    Rate on a scale of 1-5 how consistent colors remain.
    Score: [1-5]
    Reasoning: [explanation]
    """,

    "appearance": """
    Evaluate the overall realism and coherence of the generated image.
    Rate on a scale of 1-5 how realistic and natural it appears.
    Score: [1-5]
    Reasoning: [explanation]
    """,

    "modification": """
    Given the text prompt: "{prompt}"
    Rate on a scale of 1-5 how well the specified changes are executed.
    Score: [1-5]
    Reasoning: [explanation]
    """
}
|
|
|
|
def encode_image(image_path):
    """Return the contents of *image_path* as a base64 ASCII string."""
    raw_bytes = Path(image_path).read_bytes()
    return base64.b64encode(raw_bytes).decode('utf-8')
|
|
def evaluate_subject_driven_generation(
    original_image_path,
    generated_image_path,
    text_prompt,
    client
):
    """
    Evaluate a subject-driven generation along five axes using GPT-4o vision.

    One chat-completion request is issued per metric (identity, material,
    color, appearance, modification); each response is parsed for its
    "Score:" line by parse_score().

    Parameters
    ----------
    original_image_path : str or Path
        Path to the reference subject image.
    generated_image_path : str or Path
        Path to the generated image to be scored.
    text_prompt : str
        The text prompt that produced the generated image (used by the
        "modification" rubric).
    client : openai.Client
        An initialized OpenAI v1 client.

    Returns
    -------
    dict
        Metric name -> int score in 1-5, or None if the score line could
        not be parsed from the model's reply.
    """
    original_img = encode_image(original_image_path)
    generated_img = encode_image(generated_image_path)

    def _image(b64):
        # NOTE(review): the data URL always claims image/png even though the
        # __main__ example passes .jpg files; most backends sniff the real
        # format, but confirm this is intended.
        return {"type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{b64}"}}

    def _text(text):
        return {"type": "text", "text": text}

    def _score(content):
        # Single-turn request; 300 tokens is enough for "Score:" + reasoning.
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": content}],
            max_tokens=300,
        )
        return parse_score(response.choices[0].message.content)

    results = {}

    # Identity: needs both images for a side-by-side comparison.
    results['identity'] = _score([
        _text("Original subject image:"),
        _image(original_img),
        _text("Generated image:"),
        _image(generated_img),
        _text(evaluation_prompts["identity"]),
    ])

    # Material: judged from the generated image alone.
    results['material'] = _score([
        _text("Evaluate this generated image:"),
        _image(generated_img),
        _text(evaluation_prompts["material"]),
    ])

    # Color: compares unmodified regions, so both images are sent.
    results['color'] = _score([
        _text("Original:"),
        _image(original_img),
        _text("Generated:"),
        _image(generated_img),
        _text(evaluation_prompts["color"]),
    ])

    # Appearance: overall realism of the generated image alone.
    results['appearance'] = _score([
        _image(generated_img),
        _text(evaluation_prompts["appearance"]),
    ])

    # Modification: the rubric template embeds the text prompt.
    results['modification'] = _score([
        _text(f"Text prompt: {text_prompt}"),
        _image(generated_img),
        _text(evaluation_prompts["modification"].format(prompt=text_prompt)),
    ])

    return results
|
|
def parse_score(response_text):
    """
    Extract the integer score from a GPT-4o evaluation response.

    Tolerates common formatting variations in model output: lowercase
    "score:", markdown bold ("**Score:** 4"), and missing/extra spacing.
    The original pattern required an exact case-sensitive "Score:" and
    silently returned None for these variants, dropping valid scores.

    Parameters
    ----------
    response_text : str
        Raw text content of the model reply.

    Returns
    -------
    int or None
        The first score found, or None when no score line is present.
    """
    import re
    match = re.search(
        r'\**Score\**\s*:?\s*\**\s*(\d+)',
        response_text,
        re.IGNORECASE,
    )
    if match:
        return int(match.group(1))
    return None
|
|
# The 30 DreamBooth benchmark subjects; index == subject_id everywhere below.
subject_names = [
    "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
    "candle", "cat", "cat2", "clock", "colorful_sneaker",
    "dog", "dog2", "dog3", "dog5", "dog6",
    "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
    "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
    "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie"
]

# Inanimate-object subject IDs. The remaining IDs (6, 7, 10-16: the cats and
# dogs) receive clothing/outfit prompt variants instead of the extra scene
# placements.
_OBJECT_SUBJECT_IDS = frozenset(
    [0, 1, 2, 3, 4, 5, 8, 9, 17, 18, 19, 20,
     21, 22, 23, 24, 25, 26, 27, 28, 29]
)


def get_prompt(subject_id, prompt_id):
    """
    Return the benchmark prompt for a (subject_id, prompt_id) pair.

    Prompts 0-9 (scene placements) and 20-24 (property edits) are shared by
    all subjects; prompts 10-19 differ: extra scene placements for inanimate
    objects, clothing/outfit variants for live subjects.

    Parameters
    ----------
    subject_id : int
        Index into subject_names (0-29).
    prompt_id : int
        Prompt index (0-24).

    Returns
    -------
    str
        The fully formatted prompt.
    """
    # BUG FIX: subject_name was previously assigned only inside the
    # object branch, so every live-subject ID raised NameError.
    subject_name = subject_names[subject_id]

    shared_scenes = [
        f"a {subject_name} in the jungle",
        f"a {subject_name} in the snow",
        f"a {subject_name} on the beach",
        f"a {subject_name} on a cobblestone street",
        f"a {subject_name} on top of pink fabric",
        f"a {subject_name} on top of a wooden floor",
        f"a {subject_name} with a city in the background",
        f"a {subject_name} with a mountain in the background",
        f"a {subject_name} with a blue house in the background",
        f"a {subject_name} on top of a purple rug in a forest",
    ]

    if subject_id in _OBJECT_SUBJECT_IDS:
        middle = [
            f"a {subject_name} with a wheat field in the background",
            f"a {subject_name} with a tree and autumn leaves in the background",
            f"a {subject_name} with the Eiffel Tower in the background",
            f"a {subject_name} floating on top of water",
            f"a {subject_name} floating in an ocean of milk",
            f"a {subject_name} on top of green grass with sunflowers around it",
            f"a {subject_name} on top of a mirror",
            f"a {subject_name} on top of the sidewalk in a crowded street",
            f"a {subject_name} on top of a dirt road",
            f"a {subject_name} on top of a white rug",
        ]
    else:
        middle = [
            f"a {subject_name} wearing a red hat",
            f"a {subject_name} wearing a santa hat",
            f"a {subject_name} wearing a rainbow scarf",
            f"a {subject_name} wearing a black top hat and a monocle",
            f"a {subject_name} in a chef outfit",
            f"a {subject_name} in a firefighter outfit",
            f"a {subject_name} in a police outfit",
            f"a {subject_name} wearing pink glasses",
            f"a {subject_name} wearing a yellow shirt",
            f"a {subject_name} in a purple wizard outfit",
        ]

    property_edits = [
        f"a red {subject_name}",
        f"a purple {subject_name}",
        f"a shiny {subject_name}",
        f"a wet {subject_name}",
        f"a cube shaped {subject_name}",
    ]

    return (shared_scenes + middle + property_edits)[prompt_id]
|
|
|
|
|
|
|
|
|
|
def batch_evaluate_dreambooth(client, generate_fn, dataset_path, output_csv):
    """
    Run the full DreamBooth evaluation grid: 30 subjects x 25 prompts x 5
    seeds (3750 generations), scoring each with GPT-4o.

    Parameters
    ----------
    client : openai.Client
        Client passed through to evaluate_subject_driven_generation.
    generate_fn : callable
        Called with keyword arguments (prompt, subject_image_path,
        output_image_path, seed); must write the image to output_image_path.
    dataset_path : str or Path
        Root directory containing one sub-directory per subject name.
    output_csv : str or Path
        Destination for the per-sample results CSV.
    """
    import pandas as pd

    score_columns = ['identity', 'material', 'color', 'appearance', 'modification']
    results_list = []

    for subject_id in range(30):
        subject_name = subject_names[subject_id]
        subject_dir = Path(dataset_path) / subject_name

        # BUG FIX: accept .jpg as well as .png reference images — the
        # __main__ example shows the dataset ships .jpg files, which the
        # original .png-only glob would report as "no original images".
        # Hoisted out of the prompt loop (it is prompt-independent) and
        # sorted for a deterministic choice of reference image.
        original_files = sorted(subject_dir.glob("*.png")) + sorted(subject_dir.glob("*.jpg"))
        if not original_files:
            raise ValueError(f"No original images found in {subject_dir}")
        original = str(original_files[0])

        generated_dir = subject_dir / "generated"
        os.makedirs(generated_dir, exist_ok=True)

        for prompt_id in range(25):
            prompt = get_prompt(subject_id, prompt_id)

            for seed in range(5):
                generated = str(generated_dir / f"gen_seed{seed}_prompt{prompt_id}.png")

                generate_fn(
                    prompt=prompt,
                    subject_image_path=original,
                    output_image_path=generated,
                    seed=seed,
                )

                scores = evaluate_subject_driven_generation(
                    original, generated, prompt, client
                )

                results_list.append({
                    'subject_id': subject_id,
                    'subject_name': subject_name,
                    'prompt_id': prompt_id,
                    'seed': seed,
                    'prompt': prompt,
                    **scores,
                })

    df = pd.DataFrame(results_list)
    df.to_csv(output_csv, index=False)

    # BUG FIX: restrict aggregation to the numeric score columns — on
    # pandas >= 2.0, .mean() over the string columns (subject_name, prompt)
    # raises TypeError.
    print(df.groupby('subject_id')[score_columns].mean())
    print("\nOverall averages:")
    print(df[score_columns].mean())
| |
| |
def evaluate_omini_control():
    """
    Build a generation callable backed by OminiControl on FLUX.1-schnell.

    Loads the FLUX pipeline onto CUDA, attaches the OminiControl
    subject-512 LoRA, and returns a closure suitable for
    batch_evaluate_dreambooth's generate_fn parameter.

    Returns
    -------
    callable
        generate_fn(prompt, subject_image_path, output_image_path, seed):
        generates one image conditioned on the subject image and saves it
        to output_image_path.
    """
    import torch
    from diffusers.pipelines import FluxPipeline
    from PIL import Image

    from omini.pipeline.flux_omini import Condition, generate, seed_everything

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    )

    pipe = pipe.to("cuda")
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name="omini/subject_512.safetensors",
        adapter_name="subject",
    )

    # BUG FIX: the original closure was defined as
    # generate_fn(image_path, prompt, seed, output_path), but
    # batch_evaluate_dreambooth calls it with the keyword arguments
    # prompt=, subject_image_path=, output_image_path=, seed= — which would
    # raise TypeError on the very first call. The parameter names now match
    # the call site.
    def generate_fn(prompt, subject_image_path, output_image_path, seed):
        seed_everything(seed)

        # OminiControl subject conditioning expects a 512x512 RGB image.
        image = Image.open(subject_image_path).convert("RGB").resize((512, 512))
        condition = Condition.from_image(
            image,
            "subject", position_delta=(0, 32)
        )

        result_img = generate(
            pipe,
            prompt=prompt,
            conditions=[condition],
        ).images[0]

        result_img.save(output_image_path)

    return generate_fn
|
|
|
|
if __name__ == "__main__":
    # BUG FIX: the original code set the legacy v0-style module global
    # `openai.api_key` and then constructed a bare `openai.Client()`. The
    # v1 client does not read that module attribute; pass the key to the
    # client explicitly (it also falls back to OPENAI_API_KEY on its own).
    client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))

    # Smoke test: score a single original/generated pair for one prompt.
    result = evaluate_subject_driven_generation(
        "data/dreambooth/backpack/00.jpg",
        "data/dreambooth/backpack/01.jpg",
        "a backpack in the jungle",
        client
    )

    print(result)