diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index e20acf3f..3fd73786 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -61,7 +61,6 @@ def __init__( nn.BatchNorm2d(out_channels), nn.ReLU(), ) - self.out_channels = out_channels def forward(self, *features): return super().forward(features[-1]) @@ -79,17 +78,12 @@ def __init__( aspp_dropout: float, ): super().__init__() - if encoder_depth not in (3, 4, 5): + if encoder_depth < 3: raise ValueError( - "Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth) + "Encoder depth for DeepLabV3Plus decoder cannot be less than 3, got {}.".format( + encoder_depth + ) ) - if output_stride not in (8, 16): - raise ValueError( - "Output stride should be 8 or 16, got {}.".format(output_stride) - ) - - self.out_channels = out_channels - self.output_stride = output_stride self.aspp = nn.Sequential( ASPP( @@ -106,17 +100,10 @@ def __init__( nn.ReLU(), ) - scale_factor = 2 if output_stride == 8 else 4 + scale_factor = 4 if output_stride == 16 and encoder_depth > 3 else 2 self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor) - if encoder_depth == 3 and output_stride == 8: - self.highres_input_index = -2 - elif encoder_depth == 3 or encoder_depth == 4: - self.highres_input_index = -3 - else: - self.highres_input_index = -4 - - highres_in_channels = encoder_channels[self.highres_input_index] + highres_in_channels = encoder_channels[2] highres_out_channels = 48 # proposed by authors of paper self.block1 = nn.Sequential( nn.Conv2d( @@ -140,7 +127,7 @@ def __init__( def forward(self, *features): aspp_features = self.aspp(features[-1]) aspp_features = self.up(aspp_features) - high_res_features = self.block1(features[self.highres_input_index]) + high_res_features = self.block1(features[2]) concat_features = torch.cat([aspp_features, high_res_features], dim=1) fused_features = self.block2(concat_features) return fused_features @@ -240,13 +227,13 @@ def forward(self, x): class SeparableConv2d(nn.Sequential): def __init__( self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - bias=True, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + bias: bool = True, ): dephtwise_conv = nn.Conv2d( in_channels, diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index c0ef1238..830906cb 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -35,7 +35,7 @@ class DeepLabV3(SegmentationModel): Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. Default is **None** - upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity). + upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: - classes (int): A number of classes @@ -43,7 +43,8 @@ class DeepLabV3(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) - kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. + Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: **DeepLabV3** @@ -72,6 +73,12 @@ def __init__( ): super().__init__() + if encoder_output_stride not in [8, 16]: + raise ValueError( + "DeeplabV3 support output stride 8 or 16, got {}.".format( + encoder_output_stride + ) + ) self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -81,6 +88,14 @@ def __init__( **kwargs, ) + if upsampling is None: + if encoder_depth <= 3: + scale_factor = 2**encoder_depth + else: + scale_factor = encoder_output_stride + else: + scale_factor = upsampling + self.decoder = DeepLabV3Decoder( in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels, @@ -90,11 +105,11 @@ def __init__( ) self.segmentation_head = SegmentationHead( - in_channels=self.decoder.out_channels, + in_channels=decoder_channels, out_channels=classes, activation=activation, kernel_size=1, - upsampling=encoder_output_stride if upsampling is None else upsampling, + upsampling=scale_factor, ) if aux_params is not None: @@ -129,8 +144,7 @@ class DeepLabV3Plus(SegmentationModel): Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. Default is **None** - upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. In case - **encoder_depth** and **encoder_output_stride** are 3 and 16 resp., set **upsampling** to 2 to preserve. + upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: - classes (int): A number of classes @@ -138,7 +152,8 @@ class DeepLabV3Plus(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) - kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. + Keys with ``None`` values are pruned before passing. Returns: ``torch.nn.Module``: **DeepLabV3Plus** @@ -167,6 +182,13 @@ def __init__( ): super().__init__() + if encoder_output_stride not in [8, 16]: + raise ValueError( + "DeeplabV3Plus support output stride 8 or 16, got {}.".format( + encoder_output_stride + ) + ) + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -187,7 +209,7 @@ def __init__( ) self.segmentation_head = SegmentationHead( - in_channels=self.decoder.out_channels, + in_channels=decoder_channels, out_channels=classes, activation=activation, kernel_size=1,