| import sys
|
| import os
|
| import torch
|
| import torchvision.transforms as T
|
| from typing import List, Tuple
|
|
|
| from torch.hub import download_url_to_file
|
| import urllib.parse
|
|
|
|
|
|
|
# Third-party packages this module's entry points rely on.
# NOTE(review): presumably consumed by torch.hub as the hubconf `dependencies`
# list (checked for importability before entry points load) — confirm.
dependencies = [
    'tomesd', 'omegaconf', 'numpy',
    'rich', 'yapf', 'addict',
    'tqdm', 'packaging', 'torchvision',
]
|
|
|
|
|
# Directory, located next to this file, that holds the `segformer_plusplus` package.
model_dir = os.path.join(os.path.dirname(__file__), 'model_without_OpenMMLab')

# Expose model_dir on the import path only for the duration of the two imports below.
sys.path.insert(0, model_dir)
try:
    from segformer_plusplus.build_model import create_model
    from segformer_plusplus.random_benchmark import random_benchmark
finally:
    # Always undo the path change, even when an import raises, and remove the
    # entry by value: an imported module may itself have prepended to sys.path,
    # in which case index 0 would no longer be the entry added above.
    if model_dir in sys.path:
        sys.path.remove(model_dir)
|
|
|
|
|
def _get_local_cache_path(url: str, filename: str) -> str:
    """
    Build the local cache path for the checkpoint downloaded from ``url``.

    The file is placed in ``<torch home>/checkpoints`` (the directory is created
    when missing) and is named ``<filename>_<digest>.pt``, where ``<digest>`` is
    the first 10 hex characters of the SHA-256 hash of the full URL.

    A digest of the whole URL is used on purpose: a prefix of the percent-encoded
    URL is identical for every URL sharing the same scheme and host, which would
    make different checkpoints collide on one cache file.

    Args:
        url (str): The checkpoint URL; only used to derive the file name.
        filename (str): Human-readable prefix for the cached file name.

    Returns:
        str: Path of the (possibly not yet existing) cached checkpoint file.
    """
    # Function-scope import keeps this helper self-contained.
    import hashlib

    # NOTE(review): _get_torch_home is a private torch.hub helper; kept so the
    # cache directory stays where it was before.
    torch_home = torch.hub._get_torch_home()

    checkpoint_dir = os.path.join(torch_home, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    url_digest = hashlib.sha256(url.encode('utf-8')).hexdigest()[:10]
    local_filename = f"{filename}_{url_digest}.pt"

    return os.path.join(checkpoint_dir, local_filename)
|
|
|
|
|
|
|
def segformer_plusplus(
        backbone: str = 'b5',
        tome_strategy: str = 'bsm_hq',
        out_channels: int = 19,
        pretrained: bool = True,
        checkpoint_url: str = None,
        **kwargs
) -> torch.nn.Module:
    """
    Entry point that builds a SegFormer++ model and optionally loads a checkpoint.

    The model is created through ``create_model``. When ``checkpoint_url`` is
    given, the checkpoint is downloaded into the local cache (unless a cached
    copy already exists) and its weights are loaded strictly into the model.
    Loading is best-effort: download or load failures are printed and the
    model created by ``create_model`` is returned anyway.

    Install requirements via:
        pip install tomesd omegaconf numpy rich yapf addict tqdm packaging torchvision

    Args:
        backbone (str): The backbone type. Selectable from: ['b0', 'b1', 'b2', 'b3', 'b4', 'b5'].
        tome_strategy (str): The token merging strategy. Selectable from: ['bsm_hq', 'bsm_fast', 'n2d_2x2'].
        out_channels (int): Number of output classes (e.g., 19 for Cityscapes).
        pretrained (bool): Forwarded to ``create_model``; whether to load the
            ImageNet pre-trained weights.
        checkpoint_url (str, optional): URL of a checkpoint to load on top of the
            created model. It is fetched with torch.hub.download_url_to_file().
        **kwargs: Accepted but not used by this function.

    Returns:
        torch.nn.Module: The instantiated SegFormer++ model.
    """
    net = create_model(
        backbone=backbone,
        tome_strategy=tome_strategy,
        out_channels=out_channels,
        pretrained=pretrained
    )

    # Nothing more to do when no explicit checkpoint was requested.
    if not checkpoint_url:
        return net

    cached_file = _get_local_cache_path(
        url=checkpoint_url,
        filename=f"segformer_plusplus_{backbone}"
    )
    print(f"Attempting to load checkpoint from {checkpoint_url}...")

    # Step 1: fetch the file unless a cached copy is already present.
    if not os.path.exists(cached_file):
        try:
            print(f"File not in cache. Downloading to {cached_file}...")
            download_url_to_file(checkpoint_url, cached_file, progress=True)
            print("Download successful.")
        except Exception as e:
            print(f"Error downloading checkpoint from {checkpoint_url}. Check the URL or use a direct asset link. Error: {e}")
            return net

    # Step 2: read the weights and apply them to the model.
    try:
        # NOTE(review): torch.load may unpickle arbitrary objects; only pass
        # checkpoint URLs from trusted sources.
        weights = torch.load(cached_file, map_location='cpu')

        # Some checkpoints wrap the weights under a 'state_dict' key.
        if 'state_dict' in weights:
            weights = weights['state_dict']

        net.load_state_dict(weights, strict=True)
        print("Checkpoint loaded successfully.")
    except Exception as e:
        print(f"Error loading state dict from file {cached_file}: {e}")
        print("The model was instantiated, but the checkpoint could not be loaded.")

    return net
|
|
|
|
|
|
|
def data_transforms(
        resolution: Tuple[int, int] = (1024, 1024),
        mean: List[float] = None,
        std: List[float] = None,
) -> T.Compose:
    """
    Provides the preprocessing transforms for input images.

    The returned pipeline resizes the image, converts it to a tensor and
    normalizes it with the given per-channel statistics.

    Args:
        resolution (Tuple[int, int]): The target image size, passed unchanged to
            ``T.Resize``, which interprets a sequence as (height, width).
            Defaults to (1024, 1024).
        mean (List[float], optional): The mean values for normalization.
            Defaults to [0.485, 0.456, 0.406] (typical ImageNet means).
        std (List[float], optional): The standard deviations for normalization.
            Defaults to [0.229, 0.224, 0.225] (typical ImageNet values).

    Returns:
        torchvision.transforms.Compose: A composition of transforms
        that can be applied to input images.

    Example:
        >>> # Load transforms with default parameters
        >>> transform = torch.hub.load('user/repo_name', 'data_transforms')
        >>>
        >>> # Load transforms with resize to custom image resolution and default normalization
        >>> transform_small = torch.hub.load('user/repo_name', 'data_transforms', resolution=(512, 512))
    """
    # Build the defaults per call: a list used as a default argument would be one
    # shared object handed by reference to every T.Normalize created here.
    if mean is None:
        mean = [0.485, 0.456, 0.406]
    if std is None:
        std = [0.229, 0.224, 0.225]

    transform = T.Compose([
        T.Resize(resolution),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std)
    ])
    return transform
|
|
|
|
|
|
|
def random_benchmark_entrypoint(**kwargs):
    """
    Entry point for the SegFormer++ random benchmark.

    All keyword arguments are forwarded unchanged to ``random_benchmark`` and
    its result is returned as-is.
    """
    benchmark_result = random_benchmark(**kwargs)
    return benchmark_result