| import glob |
| import os |
| import PIL |
| import PIL.Image |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| import open_clip |
|
|
|
|
class SingleFolderDataset(torch.utils.data.Dataset):
    """Dataset over the regular files found directly inside one folder.

    Each item is a pair ``(image, file_name)``: ``image`` is the PIL image
    opened from disk (passed through ``transform`` when one was given) and
    ``file_name`` is the base name of the file it was read from.
    """

    def __init__(self, folder_path, transform=None):
        """Collect the files located directly inside ``folder_path``.

        Args:
            folder_path: Directory whose direct children are listed.
            transform: Optional callable applied to every opened image.
        """
        self.folder_path = folder_path
        self.transform = transform
        # Escape the folder so glob metacharacters in its name ("[", "?",
        # "*") are matched literally instead of being read as a pattern.
        pattern = os.path.join(glob.escape(folder_path), "*")
        # glob returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible between runs. Sub-directories
        # are dropped because they cannot be opened as images.
        self.image_paths = sorted(
            path for path in glob.glob(pattern) if os.path.isfile(path)
        )
        print('Found {} images in {}'.format(len(self.image_paths), folder_path))

    def __len__(self):
        """Return the number of files collected at construction time."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, file_name)`` for the file at position ``index``."""
        image_path = self.image_paths[index]
        # PIL opens files lazily; load the pixel data while the handle is
        # still open so the handle can be released deterministically.
        with PIL.Image.open(image_path) as image:
            image.load()
            if self.transform is not None:
                image = self.transform(image)
        return image, os.path.basename(image_path)
|
|
|
|
def extract_feats(index_config):
    """Encode every image of a folder with an open_clip image encoder.

    Args:
        index_config: Mapping providing
            ``'a1_config'``: model name given to
            ``open_clip.create_model_and_transforms``;
            ``'weight_path'``: value given as its ``pretrained`` argument;
            ``'img_dir'``: folder whose images are encoded.

    Returns:
        Tuple ``(im_hashes, im_feats)``: a numpy array holding the file base
        names and a numpy array holding the ``encode_image`` outputs
        concatenated along the first axis, both in the same order.

    Raises:
        ValueError: If ``img_dir`` contains no images.
    """
    # NOTE(review): the key is spelled 'a1_config' (digit one), possibly a
    # typo for 'ai_config' - kept as-is; confirm against the callers.
    ai_config = index_config['a1_config']
    weight_path = index_config['weight_path']
    img_dir = index_config['img_dir']
    batch_size = 1024

    # The third return value is the evaluation-time preprocessing transform.
    model, _, transform = open_clip.create_model_and_transforms(ai_config, pretrained=weight_path)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print('Using device:', device)
    model = model.to(device)
    # open_clip returns models in train mode; switch to eval so layers such
    # as BatchNorm or stochastic depth behave deterministically.
    model.eval()

    dataset = SingleFolderDataset(img_dir, transform=transform)
    if len(dataset) == 0:
        # Fail with a clear message instead of numpy's "need at least one
        # array to concatenate" once the loop below has produced nothing.
        raise ValueError('No images found in {}'.format(img_dir))

    dl = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    im_ids = []
    im_feats = []

    # Gradients are never needed for feature extraction.
    with torch.no_grad():
        # Iterating the loader directly lets tqdm display the batch total.
        for patched_tensor, img_id in tqdm(dl):
            patched_tensor = patched_tensor.to(device)
            out = model.encode_image(patched_tensor)

            im_ids.extend(img_id)
            im_feats.append(out.cpu().numpy())

    im_hashes = np.array(im_ids)
    im_feats = np.concatenate(im_feats)
    return im_hashes, im_feats
|
|
|
|
|
|