# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long) import argparse import os import sys from pathlib import Path import faiss import torch import torchvision.transforms as T from PIL import Image from torch import nn from tqdm import tqdm CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) SALAD_ROOT = os.path.join(CURRENT_DIR, "salad") if SALAD_ROOT not in sys.path: sys.path.insert(0, SALAD_ROOT) from loop_utils.salad.models import helper class VPRModel(nn.Module): """This is the main model for Visual Place Recognition we use Pytorch Lightning for modularity purposes. Args: pl (_type_): _description_ """ def __init__( self, # ---- Backbone backbone_arch="resnet50", backbone_config={}, # ---- Aggregator agg_arch="ConvAP", agg_config={}, ): super().__init__() # Backbone self.encoder_arch = backbone_arch self.backbone_config = backbone_config # Aggregator self.agg_arch = agg_arch self.agg_config = agg_config # ---------------------------------- # get the backbone and the aggregator self.backbone = helper.get_backbone(backbone_arch, backbone_config) self.aggregator = helper.get_aggregator(agg_arch, agg_config) # the forward pass of the lightning model def forward(self, x): x = self.backbone(x) x = self.aggregator(x) return x class LoopDetector: """Loop detector class for detecting loop closures in image sequences""" def __init__(self, image_dir, output="loop_closures.txt", config=None): """Initialize the loop detector Args: image_dir: Directory path containing images ckpt_path: Model checkpoint path image_size: Image resize dimensions [height width] batch_size: Batch size for processing similarity_threshold: Similarity threshold for loop closure top_k: Number of nearest neighbors to check for each image use_nms: Whether to use Non-Maximum Suppression (NMS) filtering nms_threshold: NMS threshold for minimum frame difference between loop pairs output: Output file path """ self.config = config self.image_dir = image_dir self.ckpt_path = self.config["Weights"]["SALAD"] self.image_size = self.config["Loop"]["SALAD"]["image_size"] self.batch_size = self.config["Loop"]["SALAD"]["batch_size"] self.similarity_threshold = self.config["Loop"]["SALAD"]["similarity_threshold"] self.top_k = self.config["Loop"]["SALAD"]["top_k"] self.use_nms = self.config["Loop"]["SALAD"]["use_nms"] self.nms_threshold = self.config["Loop"]["SALAD"]["nms_threshold"] self.output = output self.model = None self.device = None self.image_paths = None self.descriptors = None self.loop_closures = None def _input_transform(self, image_size=None): """Create image transformation function""" MEAN = [0.485, 0.456, 0.406] STD = [0.229, 0.224, 0.225] if image_size: return T.Compose( [ T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR), T.ToTensor(), T.Normalize(mean=MEAN, std=STD), ] ) else: return T.Compose([T.ToTensor(), T.Normalize(mean=MEAN, std=STD)]) def load_model(self): """Load model""" model = VPRModel( backbone_arch="dinov2_vitb14", backbone_config={ "num_trainable_blocks": 4, "return_token": True, "norm_layer": True, }, agg_arch="SALAD", agg_config={ "num_channels": 768, "num_clusters": 64, "cluster_dim": 128, "token_dim": 256, }, ) model.load_state_dict(torch.load(self.ckpt_path)) model = model.eval() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) print(f"Model loaded: {self.ckpt_path}") self.model = model self.device = device return model, device def get_image_paths(self): """Get paths of all image files in directory""" image_extensions = [".jpg", ".jpeg", ".png"] image_paths = [] for ext in image_extensions: image_paths.extend(list(Path(self.image_dir).glob(f"*{ext}"))) image_paths.extend(list(Path(self.image_dir).glob(f"*{ext.upper()}"))) image_paths = sorted(image_paths) self.image_paths = image_paths return image_paths def extract_descriptors(self): """Extract image feature descriptors""" if self.model is None or self.device is None: self.load_model() if self.image_paths is None: self.get_image_paths() transform = self._input_transform(self.image_size) descriptors = [] for i in tqdm( range(0, len(self.image_paths), self.batch_size), desc="Extracting features" ): batch_paths = self.image_paths[i : i + self.batch_size] batch_imgs = [] for path in batch_paths: try: img = Image.open(path).convert("RGB") img = transform(img) batch_imgs.append(img) except Exception as e: print(f"Error processing image {path}: {e}") img = ( torch.zeros(3, 224, 224) if self.image_size is None else torch.zeros(3, self.image_size[0], self.image_size[1]) ) batch_imgs.append(img) batch_tensor = torch.stack(batch_imgs).to(self.device) with torch.no_grad(): with torch.autocast( device_type="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.float16 ): batch_descriptors = self.model(batch_tensor).cpu() descriptors.append(batch_descriptors) self.descriptors = torch.cat(descriptors) return self.descriptors def _apply_nms_filter(self, loop_closures, nms_threshold): """Apply Non-Maximum Suppression (NMS) filtering to loop pairs""" if not loop_closures or nms_threshold <= 0: return loop_closures sorted_loops = sorted(loop_closures, key=lambda x: x[2], reverse=True) filtered_loops = [] suppressed = set() max_frame = max(max(idx1, idx2) for idx1, idx2, _ in loop_closures) for idx1, idx2, sim in sorted_loops: if idx1 in suppressed or idx2 in suppressed: continue filtered_loops.append((idx1, idx2, sim)) suppress_range = set() start1 = max(0, idx1 - nms_threshold) end1 = min(idx1 + nms_threshold + 1, idx2) suppress_range.update(range(start1, end1)) start2 = max(idx1 + 1, idx2 - nms_threshold) end2 = min(idx2 + nms_threshold + 1, max_frame + 1) suppress_range.update(range(start2, end2)) suppressed.update(suppress_range) return filtered_loops def _ensure_decending_order(self, tuples_list): return [(max(a, b), min(a, b), score) for a, b, score in tuples_list] def find_loop_closures(self): """Find loop closures""" if self.descriptors is None: self.extract_descriptors() embed_size = self.descriptors.shape[1] faiss_index = faiss.IndexFlatIP(embed_size) normalized_descriptors = self.descriptors.numpy() faiss_index.add(normalized_descriptors) similarities, indices = faiss_index.search( normalized_descriptors, self.top_k + 1 ) # +1 because self is most similar loop_closures = [] for i in range(len(self.descriptors)): # Skip first result (self) for j in range(1, self.top_k + 1): neighbor_idx = indices[i, j] similarity = similarities[i, j] if similarity > self.similarity_threshold and abs(i - neighbor_idx) > 10: if i < neighbor_idx: loop_closures.append((i, neighbor_idx, similarity)) else: loop_closures.append((neighbor_idx, i, similarity)) loop_closures = list(set(loop_closures)) loop_closures.sort(key=lambda x: x[2], reverse=True) if self.use_nms and self.nms_threshold > 0: loop_closures = self._apply_nms_filter(loop_closures, self.nms_threshold) self.loop_closures = self._ensure_decending_order(loop_closures) return self.loop_closures def save_results(self): """Save loop detection results to file""" if self.loop_closures is None: self.find_loop_closures() with open(self.output, "w") as f: f.write("# Loop Detection Results (index1, index2, similarity)\n") if self.use_nms: f.write(f"# NMS filtering applied, threshold: {self.nms_threshold}\n") f.write("\n# Loop pairs:\n") for i, j, sim in self.loop_closures: f.write(f"{i}, {j}, {sim:.4f}\n") f.write("\n# Image path list:\n") for i, path in enumerate(self.image_paths): f.write(f"# {i}: {path}\n") print(f"Found {len(self.loop_closures)} loop pairs, results saved to {self.output}") if self.use_nms: print(f"NMS filtering applied, threshold: {self.nms_threshold}") if self.loop_closures: print("\nTop 10 loop pairs:") for i, (idx1, idx2, sim) in enumerate(self.loop_closures[:10]): print(f"{idx1}, {idx2}, similarity: {sim:.4f}") if i >= 9: break def get_loop_list(self): return [(idx1, idx2) for idx1, idx2, _ in self.loop_closures] def run(self): """Run complete loop detection pipeline""" print("Loading model...") if self.model is None: self.load_model() self.get_image_paths() if not self.image_paths: print(f"No image files found in {self.image_dir}") return print(f"Found {len(self.image_paths)} image files") self.extract_descriptors() self.find_loop_closures() self.save_results() return self.loop_closures def main(): parser = argparse.ArgumentParser(description="Loop detection using SALAD model") parser.add_argument( "--image_dir", type=str, default="/media/deng/Data/KITTIdataset/data_odometry_color/dataset/sequences/00/image_2", help="Directory path containing images", ) parser.add_argument( "--ckpt_path", type=str, default="./weights/dino_salad.ckpt", help="Model checkpoint path" ) parser.add_argument( "--image_size", nargs=2, type=int, default=[336, 336], help="Image resize dimensions [height width]", ) parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing") parser.add_argument( "--similarity_threshold", type=float, default=0.7, help="Similarity threshold for loop closure", ) parser.add_argument( "--top_k", type=int, default=5, help="Number of nearest neighbors to check for each image" ) parser.add_argument("--output", type=str, default="loop_closures.txt", help="Output file path") parser.add_argument( "--use_nms", action="store_true", default=True, help="Whether to use Non-Maximum Suppression (NMS) filtering", ) parser.add_argument( "--nms_threshold", type=int, default=25, help="NMS threshold for minimum frame difference between loop pairs", ) args = parser.parse_args() detector = LoopDetector( image_dir=args.image_dir, ckpt_path=args.ckpt_path, image_size=args.image_size, batch_size=args.batch_size, similarity_threshold=args.similarity_threshold, top_k=args.top_k, use_nms=args.use_nms, nms_threshold=args.nms_threshold, output=args.output, ) detector.run() if __name__ == "__main__": main()