| from .unet import Unet |
| from .unetplusplus import UnetPlusPlus |
| from .manet import MAnet |
| from .linknet import Linknet |
| from .fpn import FPN |
| from .pspnet import PSPNet |
| from .deeplabv3 import DeepLabV3, DeepLabV3Plus |
| from .pan import PAN |
|
|
| from . import encoders |
| from . import utils |
| from . import losses |
|
|
| from .__version__ import __version__ |
|
|
| from typing import Optional |
| import torch |
|
|
|
|
def create_model(
    arch: str,
    encoder_name: str = "resnet34",
    encoder_weights: Optional[str] = "imagenet",
    in_channels: int = 3,
    classes: int = 1,
    **kwargs,
) -> torch.nn.Module:
    """Create any of the package's segmentation models by architecture name.

    ``arch`` is matched case-insensitively against the class names of the
    architectures exported by this package, and the matching class is
    constructed with the given encoder settings.

    Args:
        arch: Architecture name, matched case-insensitively against
            ``Unet``, ``UnetPlusPlus``, ``MAnet``, ``Linknet``, ``FPN``,
            ``PSPNet``, ``DeepLabV3``, ``DeepLabV3Plus`` and ``PAN``.
        encoder_name: Encoder identifier, forwarded to the model constructor.
        encoder_weights: Pretrained-weights identifier, forwarded to the model
            constructor. ``None`` is allowed by the type hint (presumably
            meaning no pretrained weights; confirm in the model classes).
        in_channels: Number of input channels, forwarded to the model.
        classes: Number of output classes, forwarded to the model.
        **kwargs: Additional keyword arguments, forwarded unchanged to the
            model constructor.

    Returns:
        A new instance of the selected architecture class.

    Raises:
        KeyError: If ``arch`` does not name any known architecture. The
            message lists the accepted (lower-cased) names.
    """
    archs = [Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
    # Key by lower-cased class name so the lookup ignores the caller's casing.
    archs_dict = {a.__name__.lower(): a for a in archs}

    # Use .get() instead of try/except: raising inside an `except KeyError`
    # handler implicitly chains the internal lookup error onto the user-facing
    # one, producing a confusing double traceback.
    model_class = archs_dict.get(arch.lower())
    if model_class is None:
        raise KeyError("Wrong architecture type `{}`. Available options are: {}".format(
            arch, list(archs_dict.keys()),
        ))

    return model_class(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        **kwargs,
    )