import argparse
import glob
import json
import multiprocessing as mp
import os
from collections import defaultdict

import numpy as np
import torch
import ImageReward as RM
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModel

from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
from hpsv3.inference import HPSv3RewardInferencer


def initialize_model_hpsv2(device, cp):
    """Load the HPSv2 reward model (an open_clip ViT-H-14) and its tokenizer."""
    model_dict = {}
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        'ViT-H-14',
        'laion2B-s32B-b79K',
        precision='amp',
        device=device,
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=False,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        light_augmentation=True,
        aug_cfg={},
        output_dict=True,
        with_score_predictor=False,
        with_region_predictor=False
    )

    checkpoint = torch.load(cp, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()
    tokenizer = get_tokenizer('ViT-H-14')

    model_dict['model'] = model
    model_dict['preprocess_val'] = preprocess_val
    return model_dict, tokenizer
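
# Usage sketch (the checkpoint path is a placeholder; the HPSv2 reward
# checkpoint is distributed separately from the base LAION weights loaded above):
#   model_dict, tokenizer = initialize_model_hpsv2('cuda:0', '/path/to/HPS_v2.pt')
#   scores = score_hpsv2_batch(model_dict, tokenizer, 'cuda:0',
#                              ['a.png'], ['a photo of a dog'])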


def initialize_pickscore(device, checkpoint_path):
    """Load a PickScore model; the processor comes from the base CLIP it was finetuned from."""
    processor = AutoProcessor.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
    model = AutoModel.from_pretrained(checkpoint_path).eval().to(device)
    return model, processor


def initialize_aesthetic_model():
    import open_clip
    from os.path import expanduser
    from urllib.request import urlretrieve
    import torch.nn as nn

    def get_aesthetic_model(clip_model="vit_l_14"):
        """Load the LAION aesthetic predictor head, downloading and caching it if needed."""
        home = expanduser("~")
        cache_folder = home + "/.cache/emb_reader"
        path_to_model = cache_folder + "/sa_0_4_" + clip_model + "_linear.pth"
        if not os.path.exists(path_to_model):
            os.makedirs(cache_folder, exist_ok=True)
            url_model = (
                "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_"
                + clip_model + "_linear.pth?raw=true"
            )
            urlretrieve(url_model, path_to_model)

        # The predictor is a single linear layer over CLIP image embeddings.
        if clip_model == "vit_l_14":
            m = nn.Linear(768, 1)
        elif clip_model == "vit_b_32":
            m = nn.Linear(512, 1)
        else:
            raise ValueError(f"Unsupported CLIP backbone: {clip_model}")
        m.load_state_dict(torch.load(path_to_model))
        m.eval()
        return m

    model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
    amodel = get_aesthetic_model(clip_model="vit_l_14")
    return model, preprocess, amodel
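
# Usage sketch (device placement is the caller's job, as in worker_process below):
#   model, preprocess, amodel = initialize_aesthetic_model()
#   preds = score_aesthetic_batch(model.to('cuda:0'), preprocess,
#                                 amodel.to('cuda:0'), 'cuda:0', ['a.png'])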


def initialize_clip(device):
    """Initialize the CLIP model and processor."""
    model = AutoModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    return model.to(device), processor


def score_hpsv2_batch(model_dict, tokenizer, device, img_paths: list, prompts: list) -> list:
    """Score each (image, prompt) pair with HPSv2; returns a 1-D tensor of scores."""
    model = model_dict['model']
    preprocess_val = model_dict['preprocess_val']

    # Convert to RGB so palette/RGBA images yield exactly three channels.
    images = [preprocess_val(Image.open(p).convert("RGB")).unsqueeze(0) for p in img_paths]
    images = torch.cat(images, dim=0).to(device=device)
    texts = tokenizer(prompts).to(device=device)
    with torch.no_grad():
        outputs = model(images, texts)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        logits_per_image = image_features @ text_features.T
        # The diagonal pairs image i with prompt i.
        hps_scores = torch.diagonal(logits_per_image).cpu()
    return hps_scores
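
# Note: with output_dict=True, open_clip's forward pass returns L2-normalized
# features (assuming HPSv2's fork keeps that behavior), so each diagonal entry
# above is effectively a cosine similarity in [-1, 1].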


def score_pick_score_batch(prompts, images, model, processor, device):
    """Score each (prompt, image) pair with PickScore; returns a 1-D tensor of scores."""
    pil_images = [Image.open(p).convert("RGB") for p in images]
    image_inputs = processor(
        images=pil_images,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)

    text_inputs = processor(
        text=prompts,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

        # PickScore scales the cosine similarity by the model's learned logit scale.
        scores = model.logit_scale.exp() * (text_embs @ image_embs.T)
        scores = torch.diagonal(scores).cpu()

    return scores
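
# Usage sketch (the Hub checkpoint name is an assumption; PickScore's published
# weights are commonly referenced as 'yuvalkirstain/PickScore_v1'):
#   model, processor = initialize_pickscore('cuda:0', 'yuvalkirstain/PickScore_v1')
#   scores = score_pick_score_batch(['a photo of a dog'], ['a.png'],
#                                   model, processor, 'cuda:0')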


def score_aesthetic_batch(model, preprocess, aesthetic_model, device, img_paths: list) -> list:
    """Scores a batch of images using the aesthetic model."""
    images = [preprocess(Image.open(p).convert("RGB")).unsqueeze(0) for p in img_paths]
    images = torch.cat(images, dim=0).to(device=device)
    with torch.no_grad():
        feat = model.encode_image(images)
        feat = feat / feat.norm(dim=-1, keepdim=True)
        # The aesthetic head is prompt-free: it scores the image embedding alone.
        pred = aesthetic_model(feat).cpu()
    return pred


def score_clip_batch(model, processor, device, img_paths: list, prompts: list) -> list:
    """Scores a batch of images against prompts using CLIP."""
    pil_images = [Image.open(p).convert("RGB") for p in img_paths]
    image_inputs = processor(
        images=pil_images,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)

    text_inputs = processor(
        text=prompts,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

        scores = image_embs @ text_embs.T
        scores = torch.diagonal(scores).cpu()

    return scores
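
# Note: unlike score_pick_score_batch, the CLIP score above is the raw cosine
# similarity (no logit-scale multiplier), so values land in [-1, 1].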


def calculate_category_stats(data_dict):
    """Calculate statistics for each category."""
    stats = {}
    for category, data_list in data_dict.items():
        if not data_list:
            stats[category] = {
                'count': 0,
                'mean': 0.0,
                'std': 0.0,
                'min': 0.0,
                'max': 0.0
            }
            continue

        rewards = [item['reward'] for item in data_list]
        stats[category] = {
            'count': len(rewards),
            'mean': float(np.mean(rewards)),
            'std': float(np.std(rewards)),
            'min': float(np.min(rewards)),
            'max': float(np.max(rewards))
        }

    # OVERALL is a macro average: the mean (and std) of the per-category means,
    # not statistics over the pooled per-image rewards.
    nonempty = [stat for stat in stats.values() if stat['count'] > 0]
    stats['OVERALL'] = {
        'count': sum(stat['count'] for stat in stats.values()),
        'mean': float(np.mean([stat['mean'] for stat in nonempty])) if nonempty else 0.0,
        'std': float(np.std([stat['mean'] for stat in nonempty])) if nonempty else 0.0,
        'min': float(min(stat['min'] for stat in nonempty)) if nonempty else 0.0,
        'max': float(max(stat['max'] for stat in nonempty)) if nonempty else 0.0,
    }
    return stats
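
# Example (toy data):
#   s = calculate_category_stats({'cats': [{'reward': 1.0}, {'reward': 3.0}]})
#   s['cats']['mean']     -> 2.0
#   s['OVERALL']['count'] -> 2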


def print_stats(stats):
    print(f"{'Category':<30} {'Count':<8} {'Mean':<10} {'Std':<10} {'Min':<10} {'Max':<10}")
    print("-" * 78)
    # calculate_category_stats already appends an 'OVERALL' row; print it last,
    # after the separator, so it is neither repeated nor double-counted.
    for category, stat in stats.items():
        if category == 'OVERALL':
            continue
        print(f"{category:<30} {stat['count']:<8} {stat['mean']:<10.4f} {stat['std']:<10.4f} {stat['min']:<10.4f} {stat['max']:<10.4f}")

    if 'OVERALL' in stats:
        stat = stats['OVERALL']
        print("-" * 78)
        print(f"{'OVERALL':<30} {stat['count']:<8} {stat['mean']:<10.4f} {stat['std']:<10.4f} {stat['min']:<10.4f} {stat['max']:<10.4f}")


def worker_process(process_id, process_dict, config_path, checkpoint_path, mode, device_id, dtype, batch_size, return_dict):
    """Score one chunk of {folder: image paths} on a single device.

    `dtype` is accepted for signature compatibility but currently unused.
    """
    category_rewards = defaultdict(list)

    device = f"cuda:{device_id}" if torch.cuda.is_available() else "cpu"
    if mode == 'imagereward':
        model = RM.load("ImageReward-v1.0")
    elif mode == 'hpsv2':
        model_dict, tokenizer = initialize_model_hpsv2(device, checkpoint_path)
    elif mode == 'hpsv3':
        inferencer = HPSv3RewardInferencer(config_path=config_path, checkpoint_path=checkpoint_path, device=device)
    elif mode == 'pickscore':
        model, processor = initialize_pickscore(device, checkpoint_path)
    elif mode == 'aesthetic':
        model, preprocess, aesthetic_model = initialize_aesthetic_model()
        model = model.to(device)
        aesthetic_model = aesthetic_model.to(device)
    elif mode == 'clip':
        model, processor = initialize_clip(device)
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    # Only rank 0 draws progress bars, to keep multi-process output readable.
    for category, chunk_data in tqdm(process_dict.items(), total=len(process_dict), desc='Total', disable=process_id != 0):
        processed_data = []
        for batch_start in tqdm(range(0, len(chunk_data), batch_size),
                                total=(len(chunk_data) + batch_size - 1) // batch_size,
                                desc=f"Category {category}", disable=process_id != 0):
            batch_end = min(batch_start + batch_size, len(chunk_data))
            image_paths = chunk_data[batch_start:batch_end]
            # Each image is paired with a same-named .txt prompt file.
            text_paths = [os.path.splitext(p)[0] + '.txt' for p in image_paths]
            prompts = [open(p, 'r').read().strip() for p in text_paths]

            with torch.no_grad():
                if mode == 'imagereward':
                    rewards = torch.tensor([model.score(prompt, image_path) for prompt, image_path in zip(prompts, image_paths)])
                elif mode == 'hpsv2':
                    rewards = score_hpsv2_batch(model_dict, tokenizer, device, image_paths, prompts)
                elif mode == 'hpsv3':
                    rewards = inferencer.reward(image_paths, prompts)
                elif mode == 'pickscore':
                    rewards = score_pick_score_batch(prompts, image_paths, model, processor, device)
                elif mode == 'aesthetic':
                    rewards = score_aesthetic_batch(model, preprocess, aesthetic_model, device, image_paths)
                elif mode == 'clip':
                    rewards = score_clip_batch(model, processor, device, image_paths, prompts)
                else:
                    raise ValueError(f"Unsupported mode: {mode}")

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            for i in range(len(image_paths)):
                # Some scorers return a 2-D result (one column per output head);
                # keep only the first column.
                if rewards.ndim == 2:
                    reward = rewards[i][0].item()
                else:
                    reward = rewards[i].item()
                processed_data.append({
                    'image_path': image_paths[i],
                    'reward': reward,
                    'prompt': prompts[i]
                })

        category_rewards[category] = processed_data

    return_dict[process_id] = {
        'data': category_rewards,
    }
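
# Expected data layout: every image has a sibling prompt file with the same
# stem, e.g.
#   folder/cats/0001.png
#   folder/cats/0001.txt   <- prompt used to generate 0001.png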


def chunk_list(data_list, num_chunks):
    """Split list into roughly equal chunks."""
    chunk_size = len(data_list) // num_chunks
    remainder = len(data_list) % num_chunks

    chunks = []
    start = 0
    for i in range(num_chunks):
        # The first `remainder` chunks each get one extra element.
        current_chunk_size = chunk_size + (1 if i < remainder else 0)
        end = start + current_chunk_size
        chunks.append(data_list[start:end])
        start = end

    return chunks
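
# Example:
#   chunk_list([1, 2, 3, 4, 5], 2) -> [[1, 2, 3], [4, 5]]
#   chunk_list([1, 2, 3], 5)       -> [[1], [2], [3], [], []]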


def main(config_path, checkpoint_path, mode, image_folders, output_path, batch_size=16, num_processes=8, num_machines=1, machine_id=0):
    print(f"Config path: {config_path}")

    dtype = torch.bfloat16

    # Gather images per folder, take this machine's shard, then split it across processes.
    folder_dict = {}
    for folder in image_folders:
        images = []
        for ext in ['.png', '.jpg']:
            images.extend(glob.glob(os.path.join(folder, "**", f"*{ext}"), recursive=True))
        machine_image_chunks = chunk_list(images, num_machines)
        image_list = machine_image_chunks[machine_id] if machine_id < len(machine_image_chunks) else []
        print(f"Folder {folder} data points on this machine: {len(image_list)}")
        data_chunks = chunk_list(image_list, num_processes)
        print(f"Folder {folder} data split into {num_processes} chunks with sizes: {[len(chunk) for chunk in data_chunks]}")
        folder_dict[folder] = data_chunks

    per_process_folder_dict = []
    for i in range(num_processes):
        one_dict = {}
        for key, value in folder_dict.items():
            one_dict[key] = value[i] if i < len(value) else []
        per_process_folder_dict.append(one_dict)

    with mp.Manager() as manager:
        return_dict = manager.dict()
        processes = []

        # Round-robin workers over the visible GPUs.
        for i in range(num_processes):
            device_id = i % torch.cuda.device_count() if torch.cuda.is_available() else 0
            p = mp.Process(target=worker_process,
                           args=(i, per_process_folder_dict[i], config_path, checkpoint_path, mode, device_id, dtype, batch_size, return_dict))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        all_processed_data = {}
        for i in range(num_processes):
            if i in return_dict:
                result = return_dict[i]
                process_data = result['data']
                for category, data_list in process_data.items():
                    if category not in all_processed_data:
                        all_processed_data[category] = []
                    all_processed_data[category].extend(data_list)
            else:
                print(f"No result from process {i}")

        if all_processed_data:
            stats = calculate_category_stats(all_processed_data)
            print(f"\n=== Machine {machine_id} Statistics ===")
            print_stats(stats)

        if num_machines > 1:
            # Each machine writes its own shard of results.
            machine_output_path = output_path.replace('.json', f'_machine_{machine_id}.json')
            with open(machine_output_path, "w") as f:
                json.dump(all_processed_data, f, indent=4)
            print(f"Machine {machine_id} results saved to {machine_output_path}")

            if machine_id == 0:
                # Machine 0 merges whatever shard files already exist. There is no
                # synchronization barrier here, so run (or re-run) machine 0 after
                # the other machines have finished.
                print("Merging results from all machines...")
                final_result = {}
                for i in range(num_machines):
                    machine_file = output_path.replace('.json', f'_machine_{i}.json')
                    if os.path.exists(machine_file):
                        print(f"Loading results from machine {i}")
                        with open(machine_file, 'r') as f:
                            machine_data = json.load(f)
                        for category, data_list in machine_data.items():
                            if category not in final_result:
                                final_result[category] = []
                            final_result[category].extend(data_list)
                    else:
                        print(f"Warning: Machine {i} results file not found: {machine_file}")

                stats = calculate_category_stats(final_result)
                print("\n=== Final Combined Statistics ===")
                print_stats(stats)

                final_output = {
                    'statistics': stats,
                    'data': final_result,
                }
                with open(output_path, "w") as f:
                    json.dump(final_output, f, indent=4)
                print(f"Final combined results saved to {output_path}")
        else:
            stats = calculate_category_stats(all_processed_data)
            print("\n=== Statistics ===")
            print_stats(stats)

            output_data = {
                'statistics': stats,
                'data': all_processed_data,
            }
            with open(output_path, "w") as f:
                json.dump(output_data, f, indent=4)
            print(f"Results saved to {output_path}")


def parse_args():
    parser = argparse.ArgumentParser(description='Score generated images with a reward model (HPSv3 by default)')
    parser.add_argument('--config_path', type=str, help='Path to the configuration file (used by the hpsv3 mode)')
    parser.add_argument('--checkpoint_path', type=str, help='Path to the model checkpoint (used by the hpsv2, hpsv3, and pickscore modes)')
    parser.add_argument('--mode', type=str, choices=['imagereward', 'hpsv2', 'hpsv3', 'pickscore', 'aesthetic', 'clip'], default='hpsv3')
    parser.add_argument('--image_folders', type=str, nargs='+', required=True, help='List of image folder paths to process')
    parser.add_argument('--output_path', type=str, required=True, help='Path to save the output JSON file')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for processing (default: 16)')
    parser.add_argument('--num_processes', type=int, default=8, help='Number of worker processes (default: 8)')
    parser.add_argument('--num_machines', type=int, default=1, help='Total number of machines (default: 1)')
    parser.add_argument('--machine_id', type=int, default=0, help='ID of current machine (default: 0)')

    return parser.parse_args()


if __name__ == "__main__":
    # CUDA state does not survive fork; spawn fresh interpreters for the workers.
    mp.set_start_method('spawn', force=True)

    args = parse_args()
    main(
        config_path=args.config_path,
        checkpoint_path=args.checkpoint_path,
        mode=args.mode,
        image_folders=args.image_folders,
        output_path=args.output_path,
        batch_size=args.batch_size,
        num_processes=args.num_processes,
        num_machines=args.num_machines,
        machine_id=args.machine_id
    )
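
# Example invocation (the script name and paths are placeholders):
#   python score_images.py --mode clip \
#       --image_folders /data/gen/model_a /data/gen/model_b \
#       --output_path results.json --batch_size 16 --num_processes 8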