| import argparse |
| import os |
| import re |
|
|
| import bleach |
| import cv2 |
| import jsonlines |
| import numpy as np |
| import torch |
| from loguru import logger |
| from PIL import Image |
| from tqdm import tqdm |
| from transformers import AutoTokenizer, CLIPImageProcessor, PreTrainedTokenizer |
|
|
| from eval.utils import grounding_image_ecoder_preprocess |
| from model.Legion import LegionForCls |
| from model.llava import conversation as conversation_lib |
| from model.llava.mm_utils import tokenizer_image_token |
| from model.SAM.utils.transforms import ResizeLongestSide |
| from tools.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX |
|
|
|
|
def parse_args():
    """Define the LEGION batch-inference command line and parse ``sys.argv``.

    Returns:
        argparse.Namespace: with attributes ``model_path``, ``image_size``,
        ``model_max_length``, ``image_root`` and ``save_root``.
    """
    # (flag, add_argument keyword options), registered in this order.
    cli_spec = (
        ("--model_path", {"required": True, "help": "The directory to your legion ckpt"}),
        ("--image_size", {"default": 1024, "type": int, "help": "image size"}),
        ("--model_max_length", {"default": 512, "type": int}),
        ("--image_root", {"required": True, "help": "The directory containing images to run inference."}),
        ("--save_root", {"required": True, "help": "The directory to store the inference result."}),
    )

    cli = argparse.ArgumentParser(description="LEGION Inference")
    for flag, options in cli_spec:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
class LEGION:
    """A simple wrapper for LEGION model loading and inference.

    Construction loads the tokenizer, the ``LegionForCls`` checkpoint (cast to
    bfloat16, moved to CUDA and switched to eval mode), the CLIP image
    processor named by the model config, and a ``ResizeLongestSide``
    transform. :meth:`infer` then produces an explanation, a localization
    mask image and a real/fake label for one image path.

    Args:
        model_path (str): Path to the model checkpoint (also used to load the tokenizer).
        image_size (int): Size passed to ``ResizeLongestSide`` for the
            grounding-encoder input.
        model_max_length (int): Maximum length of the model input sequence
            (forwarded to the tokenizer).
    """

    # Fixed prompt sent together with every image.
    INSTRUCTION = (
        "Please provide a detailed analysis of artifacts in this photo, considering "
        "physical artifacts (e.g., optical display issues, violations of physical laws, "
        "and spatial/perspective errors), structural artifacts (e.g., deformed objects, asymmetry, or distorted text), "
        "and distortion artifacts (e.g., color/texture distortion, noise/blur, artistic style errors, and material misrepresentation). "
        "Output with interleaved segmentation masks for the corresponding parts of the answer."
    )

    def __init__(self, model_path: str, image_size: int = 1024, model_max_length: int = 512):
        self.model_path = model_path
        self.image_size = image_size
        self.model_max_length = model_max_length

        # --- Tokenizer (slow tokenizer, right padding) ---
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            cache_dir=None,
            model_max_length=self.model_max_length,
            padding_side="right",
            use_fast=False,
        )
        # Reuse the unk token for padding (presumably the tokenizer has no pad token of its own).
        self.tokenizer.pad_token = self.tokenizer.unk_token
        # Id of the segmentation trigger token. Only the first id is kept, which
        # assumes "[SEG]" encodes to a single token in this vocabulary.
        seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
        logger.info("Tokenizer loaded successfully.")

        # --- Model (bfloat16 weights) ---
        self.model: LegionForCls = LegionForCls.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            seg_token_idx=seg_token_idx,
            torch_dtype=torch.bfloat16,
        )
        # Keep the model's special-token ids in sync with the tokenizer.
        self.model.config.eos_token_id = self.tokenizer.eos_token_id
        self.model.config.bos_token_id = self.tokenizer.bos_token_id
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

        self.model.get_model().initialize_vision_modules(self.model.get_model().config)
        vision_tower = self.model.get_model().get_vision_tower()
        vision_tower.to(dtype=torch.bfloat16)

        self.model = self.model.bfloat16().cuda()
        vision_tower.to(device="cuda")
        self.model.eval()
        logger.info("Model loaded successfully.")

        # --- Preprocessors for the two image branches ---
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model.config.vision_tower)
        self.transform = ResizeLongestSide(self.image_size)
        logger.info("Image processor initialized successfully.")

    @staticmethod
    def _clean_explanation(text_output: str) -> str:
        """Extract the assistant's answer from decoded output and strip markup.

        Args:
            text_output: Decoded model output, possibly including the echoed prompt.

        Returns:
            The text after the last ``"ASSISTANT: "`` marker, with every
            ``<...>`` span and ``[SEG]`` token removed, whitespace runs collapsed
            to single spaces, and surrounding single quotes/spaces stripped.
        """
        # Newlines are dropped outright (not replaced by a space).
        text = text_output.replace("\n", "")
        # Keep only the part after the last assistant-turn marker.
        text = text.split("ASSISTANT: ")[-1]
        # Remove any tag-like <...> spans (non-greedy, so each tag separately).
        text = re.sub(r"<.*?>", "", text)
        text = text.replace("[SEG]", "")
        # Collapse all whitespace runs, then trim quotes and spaces.
        text = " ".join(text.split()).strip("'")
        return text.strip()

    @staticmethod
    def _mask_to_image(mask: torch.Tensor) -> "Image.Image":
        """Merge a stack of masks into one binary 8-bit grayscale PIL image.

        Args:
            mask: Mask tensor whose first dimension is OR-reduced; assumed to be
                ``(num_masks, H, W)`` (TODO confirm against ``model.evaluate``).
                Values > 0 count as foreground.

        Returns:
            PIL image (mode ``"L"``): 255 where any mask is positive, else 0.
        """
        binary = torch.any(mask.cpu() > 0, dim=0)
        # A 2-D uint8 array is converted to mode "L" automatically; the explicit
        # ``mode=`` argument is deprecated in recent Pillow releases.
        return Image.fromarray((binary.numpy() * 255).astype(np.uint8))

    @torch.inference_mode()
    def _infer(self, raw_image: np.ndarray):
        """Run inference on a single image.

        Args:
            raw_image (np.ndarray): BGR image, as returned by ``cv2.imread``.

        Returns:
            tuple: ``(explanation, pred_masks, pred_cls)``. These are the cleaned
            explanation string, the masks returned by ``model.evaluate``, and the
            predicted class index (int) from the classification head.
        """
        # bleach.clean HTML-escapes "<" / ">" as "&lt;" / "&gt;"; turn them back so
        # any angle brackets in the instruction reach the model verbatim.
        instructions = bleach.clean(LEGION.INSTRUCTION)
        instructions = instructions.replace("&lt;", "<").replace("&gt;", ">")

        # Build a single-turn llava_v1 conversation with an empty assistant reply.
        conv = conversation_lib.conv_templates["llava_v1"].copy()
        conv.messages = []
        prompt = (
            f"The {DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN} "
            "provides an overview of the picture.\n" + instructions
        )
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        # Global (CLIP) branch input.
        image_np = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]
        image_clip = self.image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
        image_clip = image_clip.unsqueeze(0).cuda().bfloat16()

        # Grounding-encoder branch: resize, then HWC -> CHW and preprocess.
        image = self.transform.apply_image(image_np)
        resize_list = [image.shape[:2]]
        image = grounding_image_ecoder_preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        image = image.unsqueeze(0).cuda().bfloat16()

        input_ids = tokenizer_image_token(prompt, self.tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        output_ids, pred_masks = self.model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_tokens_new=512,
            bboxes=None,
        )
        # Drop image-placeholder ids before decoding.
        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
        text_output = self.tokenizer.decode(output_ids, skip_special_tokens=False)
        cleaned_str = self._clean_explanation(text_output)

        # Separate forward pass through the classification head.
        logits = self.model(global_enc_images=image_clip, inference_cls=True)["logits"].cpu()
        _, pred_cls = torch.max(logits, dim=1)
        return cleaned_str, pred_masks, int(pred_cls)

    @torch.inference_mode()
    def infer(self, image_path: str):
        """Run inference on a single image file.

        Args:
            image_path (str): Path to the input image.

        Returns:
            dict: ``{"explanation": str, "localization": PIL.Image.Image,
            "detection": "real" | "fake"}``. The localization entry is the merged
            binary mask of the first ``pred_masks`` entry.

        Raises:
            ValueError: If OpenCV cannot read the file.
        """
        raw_image = cv2.imread(image_path)
        if raw_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Failed to read image: {image_path}")

        explanation, pred_masks, pred_cls = self._infer(raw_image.astype(np.uint8))

        return {
            "explanation": explanation,
            "localization": self._mask_to_image(pred_masks[0]),
            # Class index 1 is reported as "real"; any other index as "fake".
            "detection": "real" if pred_cls == 1 else "fake",
        }
|
|
|
|
def main(args):
    """Run LEGION on every image in ``args.image_root`` and save the results.

    Writes one JSON line per image to ``<save_root>/explanations.jsonl`` with
    keys ``explanation``, ``detection``, ``localization`` (path of the saved
    mask) and ``image_path``. Also writes one ``<stem>_mask.png`` per image to
    ``<save_root>/localization/``. Images already recorded in an existing
    ``explanations.jsonl`` are skipped, so an interrupted run can be resumed.

    Args:
        args (argparse.Namespace): Parsed CLI arguments (see ``parse_args``).
    """
    suffixes = (".jpg", ".jpeg", ".png")
    image_paths = sorted(
        name
        for name in os.listdir(args.image_root)
        if os.path.splitext(name)[-1].lower() in suffixes
    )
    logger.info(f"Found {len(image_paths)} images for inference.")

    os.makedirs(args.save_root, exist_ok=True)
    localization_save_dir = os.path.join(args.save_root, "localization")
    os.makedirs(localization_save_dir, exist_ok=True)
    explanation_save_path = os.path.join(args.save_root, "explanations.jsonl")

    # Resume. Skip images whose file name already appears in the output. Matching
    # on the recorded name, instead of dropping the first N sorted entries, stays
    # correct if files are added or removed between runs. The reader is closed
    # via `with`. This runs before the expensive model load so a corrupt results
    # file fails fast.
    num_processed_images = 0
    processed_names = set()
    if os.path.exists(explanation_save_path):
        with jsonlines.open(explanation_save_path) as reader:
            for record in reader:
                num_processed_images += 1
                processed_names.add(os.path.basename(record.get("image_path", "")))
    logger.info(f"Resuming from {num_processed_images} processed images.")
    image_paths = [name for name in image_paths if name not in processed_names]

    legion = LEGION(args.model_path, args.image_size, args.model_max_length)

    # Append mode with flush=True so each finished image is persisted immediately.
    with jsonlines.open(explanation_save_path, mode="a", flush=True) as writer:
        for image_path in tqdm(image_paths):
            image_name = os.path.splitext(image_path)[0]
            full_image_path = os.path.join(args.image_root, image_path)
            result = legion.infer(full_image_path)

            # NOTE(review): files sharing a stem (e.g. a.jpg and a.png) write to the
            # same mask path, and the later one overwrites the earlier one.
            this_localization_save_path = os.path.join(localization_save_dir, f"{image_name}_mask.png")
            result["localization"].save(this_localization_save_path)
            result["localization"] = this_localization_save_path

            result["image_path"] = full_image_path

            writer.write(result)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI flags and run batch inference.
    main(parse_args())
|
|