From bbd363fc4a57e8405e63484f7dbb9d6dbeaa4482 Mon Sep 17 00:00:00 2001 From: Dimitris Mantas Date: Tue, 22 Oct 2024 14:32:28 +0300 Subject: [PATCH 1/3] Expose timm constructor arguments --- .../decoders/deeplabv3/model.py | 57 +++++++++++++------ .../decoders/fpn/model.py | 10 +++- .../decoders/linknet/model.py | 8 ++- .../decoders/manet/model.py | 12 ++-- .../decoders/pan/model.py | 12 ++-- .../decoders/pspnet/model.py | 12 ++-- .../decoders/unet/model.py | 12 ++-- .../decoders/unetplusplus/model.py | 12 ++-- .../decoders/upernet/model.py | 12 ++-- .../encoders/timm_universal.py | 28 +++++++-- 10 files changed, 124 insertions(+), 51 deletions(-) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index ad422dbc..796dd0dc 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -1,11 +1,13 @@ -from typing import Optional +from collections.abc import Iterable +from typing import Any, Literal, Optional from segmentation_models_pytorch.base import ( - SegmentationModel, - SegmentationHead, ClassificationHead, + SegmentationHead, + SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder + from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder @@ -22,13 +24,17 @@ class DeepLabV3(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) decoder_channels: A number of convolution filters in ASPP module. Default is 256 + encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) + decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values) + decoder_aspp_separable: Use separable convolutions in ASPP module. Default is False + decoder_aspp_dropout: Use dropout in ASPP module projection layer. Default is 0.5 in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. Default is **None** - upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity + upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity). aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: - classes (int): A number of classes @@ -36,6 +42,8 @@ class DeepLabV3(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. + Returns: ``torch.nn.Module``: **DeepLabV3** @@ -49,12 +57,17 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", + encoder_output_stride: Literal[8, 16] = 8, decoder_channels: int = 256, + decoder_atrous_rates: Iterable[int] = (12, 24, 36), + decoder_aspp_separable: bool = False, + decoder_aspp_dropout: float = 0.5, in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, - upsampling: int = 8, + upsampling: Optional[int] = None, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() @@ -63,11 +76,16 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, - output_stride=8, + output_stride=encoder_output_stride, + **kwargs, ) self.decoder = DeepLabV3Decoder( - in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels + in_channels=self.encoder.out_channels[-1], + out_channels=decoder_channels, + atrous_rates=decoder_atrous_rates, + aspp_separable=decoder_aspp_separable, + aspp_dropout=decoder_aspp_dropout, ) self.segmentation_head = SegmentationHead( @@ -75,7 +93,7 @@ def __init__( out_channels=classes, activation=activation, kernel_size=1, - upsampling=upsampling, + upsampling=encoder_output_stride if upsampling is None else upsampling, ) if aux_params is not None: @@ -100,7 +118,9 @@ class DeepLabV3Plus(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) - decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values) + decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values) + decoder_aspp_separable: Use separable convolutions in ASPP module. Default is True + decoder_aspp_dropout: Use dropout in ASPP module projection layer. Default is 0.5 decoder_channels: A number of convolution filters in ASPP module. Default is 256 in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) @@ -116,6 +136,8 @@ class DeepLabV3Plus(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. + Returns: ``torch.nn.Module``: **DeepLabV3Plus** @@ -129,30 +151,27 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - encoder_output_stride: int = 16, + encoder_output_stride: Literal[8, 16] = 16, decoder_channels: int = 256, - decoder_atrous_rates: tuple = (12, 24, 36), + decoder_atrous_rates: Iterable[int] = (12, 24, 36), + decoder_aspp_separable: bool = True, + decoder_aspp_dropout: float = 0.5, in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, upsampling: int = 4, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() - if encoder_output_stride not in [8, 16]: - raise ValueError( - "Encoder output stride should be 8 or 16, got {}".format( - encoder_output_stride - ) - ) - self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, output_stride=encoder_output_stride, + **kwargs, ) self.decoder = DeepLabV3PlusDecoder( @@ -160,6 +179,8 @@ def __init__( out_channels=decoder_channels, atrous_rates=decoder_atrous_rates, output_stride=encoder_output_stride, + aspp_separable=decoder_aspp_separable, + aspp_dropout=decoder_aspp_dropout, ) self.segmentation_head = SegmentationHead( diff --git a/segmentation_models_pytorch/decoders/fpn/model.py b/segmentation_models_pytorch/decoders/fpn/model.py index f18457d5..373269c5 100644 --- a/segmentation_models_pytorch/decoders/fpn/model.py +++ b/segmentation_models_pytorch/decoders/fpn/model.py @@ -1,11 +1,12 @@ -from typing import Optional +from typing import Any, Optional from segmentation_models_pytorch.base import ( - SegmentationModel, - SegmentationHead, ClassificationHead, + SegmentationHead, + SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder + from .decoder import FPNDecoder @@ -40,6 +41,7 @@ class FPN(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: **FPN** @@ -63,6 +65,7 @@ def __init__( activation: Optional[str] = None, upsampling: int = 4, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() @@ -77,6 +80,7 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, + **kwargs, ) self.decoder = FPNDecoder( diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index b8c3139f..708ea562 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -1,11 +1,12 @@ -from typing import Optional, Union +from typing import Any, Optional, Union from segmentation_models_pytorch.base import ( + ClassificationHead, SegmentationHead, SegmentationModel, - ClassificationHead, ) from segmentation_models_pytorch.encoders import get_encoder + from .decoder import LinknetDecoder @@ -43,6 +44,7 @@ class Linknet(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: **Linknet** @@ -61,6 +63,7 @@ def __init__( classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() @@ -74,6 +77,7 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, + **kwargs, ) self.decoder = LinknetDecoder( diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 08e64a2a..6651dee6 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -1,11 +1,12 @@ -from typing import Optional, Union, List +from typing import Any, List, Optional, Union -from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base import ( - SegmentationModel, - SegmentationHead, ClassificationHead, + SegmentationHead, + SegmentationModel, ) +from segmentation_models_pytorch.encoders import get_encoder + from .decoder import MAnetDecoder @@ -45,6 +46,7 @@ class MAnet(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: **MAnet** @@ -66,6 +68,7 @@ def __init__( classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() @@ -74,6 +77,7 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, + **kwargs, ) self.decoder = MAnetDecoder( diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py index 8086d024..5c46f489 100644 --- a/segmentation_models_pytorch/decoders/pan/model.py +++ b/segmentation_models_pytorch/decoders/pan/model.py @@ -1,11 +1,12 @@ -from typing import Optional, Union +from typing import Any, Optional, Union -from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base import ( - SegmentationModel, - SegmentationHead, ClassificationHead, + SegmentationHead, + SegmentationModel, ) +from segmentation_models_pytorch.encoders import get_encoder + from .decoder import PANDecoder @@ -38,6 +39,7 @@ class PAN(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: **PAN** @@ -58,6 +60,7 @@ def __init__( activation: Optional[Union[str, callable]] = None, upsampling: int = 4, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() @@ -74,6 +77,7 @@ def __init__( depth=5, weights=encoder_weights, output_stride=encoder_output_stride, + **kwargs, ) self.decoder = PANDecoder( diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index 9f9997f8..dbf04ea4 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -1,11 +1,12 @@ -from typing import Optional, Union +from typing import Any, Optional, Union -from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base import ( - SegmentationModel, - SegmentationHead, ClassificationHead, + SegmentationHead, + SegmentationModel, ) +from segmentation_models_pytorch.encoders import get_encoder + from .decoder import PSPDecoder @@ -44,6 +45,7 @@ class PSPNet(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: **PSPNet** @@ -65,6 +67,7 @@ def __init__( activation: Optional[Union[str, callable]] = None, upsampling: int = 8, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() @@ -73,6 +76,7 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, + **kwargs, ) self.decoder = PSPDecoder( diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 46528c5a..0ac7b5bd 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -1,11 +1,12 @@ -from typing import Optional, Union, List +from typing import Any, List, Optional, Union -from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base import ( - SegmentationModel, - SegmentationHead, ClassificationHead, + SegmentationHead, + SegmentationModel, ) +from segmentation_models_pytorch.encoders import get_encoder + from .decoder import UnetDecoder @@ -44,6 +45,7 @@ class Unet(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: Unet @@ -65,6 +67,7 @@ def __init__( classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() @@ -73,6 +76,7 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, + **kwargs, ) self.decoder = UnetDecoder( diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 60d591f0..9ba72321 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -1,11 +1,12 @@ -from typing import Optional, Union, List +from typing import Any, List, Optional, Union -from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base import ( - SegmentationModel, - SegmentationHead, ClassificationHead, + SegmentationHead, + SegmentationModel, ) +from segmentation_models_pytorch.encoders import get_encoder + from .decoder import UnetPlusPlusDecoder @@ -44,6 +45,7 @@ class UnetPlusPlus(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: **Unet++** @@ -65,6 +67,7 @@ def __init__( classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() @@ -78,6 +81,7 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, + **kwargs, ) self.decoder = UnetPlusPlusDecoder( diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index 18b97a94..de30a7bb 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -1,11 +1,12 @@ -from typing import Optional, Union +from typing import Any, Optional, Union -from segmentation_models_pytorch.encoders import get_encoder from segmentation_models_pytorch.base import ( - SegmentationModel, - SegmentationHead, ClassificationHead, + SegmentationHead, + SegmentationModel, ) +from segmentation_models_pytorch.encoders import get_encoder + from .decoder import UPerNetDecoder @@ -36,6 +37,7 @@ class UPerNet(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: **UPerNet** @@ -56,6 +58,7 @@ def __init__( classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], ): super().__init__() @@ -64,6 +67,7 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, + **kwargs, ) self.decoder = UPerNetDecoder( diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 9702a7c3..b901f1a5 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -1,11 +1,21 @@ +from typing import Any + import timm import torch.nn as nn class TimmUniversalEncoder(nn.Module): - def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32): + def __init__( + self, + name: str, + pretrained: bool = True, + in_channels: int = 3, + depth: int = 5, + output_stride: int = 32, + **kwargs: dict[str, Any], + ): super().__init__() - kwargs = dict( + common_kwargs = dict( in_chans=in_channels, features_only=True, output_stride=output_stride, @@ -15,9 +25,11 @@ def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride= # not all models support output stride argument, drop it by default if output_stride == 32: - kwargs.pop("output_stride") + common_kwargs.pop("output_stride") - self.model = timm.create_model(name, **kwargs) + self.model = timm.create_model( + name, **_merge_kwargs_no_dupls(common_kwargs, kwargs) + ) self._in_channels = in_channels self._out_channels = [in_channels] + self.model.feature_info.channels() @@ -36,3 +48,11 @@ def out_channels(self): @property def output_stride(self): return min(self._output_stride, 2**self._depth) + + +def _merge_kwargs_no_dupls(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: + dupls = a.keys() & b.keys() + if dupls: + raise ValueError(f"'{dupls}' already specified internally") + + return a | b From 2d654c4da04c61a10012b2ebf8c39e86aaf6f293 Mon Sep 17 00:00:00 2001 From: Dimitris Mantas Date: Tue, 22 Oct 2024 14:43:06 +0300 Subject: [PATCH 2/3] Remove leak from other branch --- .../decoders/deeplabv3/model.py | 44 +++++++------------ 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index 796dd0dc..6527c7a7 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -1,5 +1,4 @@ -from collections.abc import Iterable -from typing import Any, Literal, Optional +from typing import Any, Optional from segmentation_models_pytorch.base import ( ClassificationHead, @@ -24,17 +23,13 @@ class DeepLabV3(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) decoder_channels: A number of convolution filters in ASPP module. Default is 256 - encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) - decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values) - decoder_aspp_separable: Use separable convolutions in ASPP module. Default is False - decoder_aspp_dropout: Use dropout in ASPP module projection layer. Default is 0.5 in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. Default is **None** - upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity). + upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: - classes (int): A number of classes @@ -57,15 +52,11 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - encoder_output_stride: Literal[8, 16] = 8, decoder_channels: int = 256, - decoder_atrous_rates: Iterable[int] = (12, 24, 36), - decoder_aspp_separable: bool = False, - decoder_aspp_dropout: float = 0.5, in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, - upsampling: Optional[int] = None, + upsampling: int = 8, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -76,16 +67,12 @@ def __init__( in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, - output_stride=encoder_output_stride, + output_stride=8, **kwargs, ) self.decoder = DeepLabV3Decoder( - in_channels=self.encoder.out_channels[-1], - out_channels=decoder_channels, - atrous_rates=decoder_atrous_rates, - aspp_separable=decoder_aspp_separable, - aspp_dropout=decoder_aspp_dropout, + in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels ) self.segmentation_head = SegmentationHead( @@ -93,7 +80,7 @@ def __init__( out_channels=classes, activation=activation, kernel_size=1, - upsampling=encoder_output_stride if upsampling is None else upsampling, + upsampling=upsampling, ) if aux_params is not None: @@ -118,9 +105,7 @@ class DeepLabV3Plus(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation) - decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values) - decoder_aspp_separable: Use separable convolutions in ASPP module. Default is True - decoder_aspp_dropout: Use dropout in ASPP module projection layer. Default is 0.5 + decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values) decoder_channels: A number of convolution filters in ASPP module. Default is 256 in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) @@ -151,11 +136,9 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - encoder_output_stride: Literal[8, 16] = 16, + encoder_output_stride: int = 16, decoder_channels: int = 256, - decoder_atrous_rates: Iterable[int] = (12, 24, 36), - decoder_aspp_separable: bool = True, - decoder_aspp_dropout: float = 0.5, + decoder_atrous_rates: tuple = (12, 24, 36), in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, @@ -165,6 +148,13 @@ def __init__( ): super().__init__() + if encoder_output_stride not in [8, 16]: + raise ValueError( + "Encoder output stride should be 8 or 16, got {}".format( + encoder_output_stride + ) + ) + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -179,8 +169,6 @@ def __init__( out_channels=decoder_channels, atrous_rates=decoder_atrous_rates, output_stride=encoder_output_stride, - aspp_separable=decoder_aspp_separable, - aspp_dropout=decoder_aspp_dropout, ) self.segmentation_head = SegmentationHead( From 80d087e55a81a6bca4d85a36846ac464202bc2ad Mon Sep 17 00:00:00 2001 From: Dimitris Mantas Date: Wed, 6 Nov 2024 13:31:00 +0200 Subject: [PATCH 3/3] Rename dupls to duplicates --- segmentation_models_pytorch/encoders/timm_universal.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index b901f1a5..eb008221 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -28,7 +28,7 @@ def __init__( common_kwargs.pop("output_stride") self.model = timm.create_model( - name, **_merge_kwargs_no_dupls(common_kwargs, kwargs) + name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) ) self._in_channels = in_channels @@ -50,9 +50,9 @@ def output_stride(self): return min(self._output_stride, 2**self._depth) -def _merge_kwargs_no_dupls(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: - dupls = a.keys() & b.keys() - if dupls: - raise ValueError(f"'{dupls}' already specified internally") +def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: + duplicates = a.keys() & b.keys() + if duplicates: + raise ValueError(f"'{duplicates}' already specified internally") return a | b