From 3179751aa326d49e54eaa0973868f0efe0c75065 Mon Sep 17 00:00:00 2001 From: Munehiro Kobayashi Date: Sat, 12 Feb 2022 22:25:02 +0900 Subject: [PATCH 1/3] fix issue qubvel/segmentation_models.pytorch#377 --- .../decoders/deeplabv3/decoder.py | 12 ++++++++++-- .../decoders/deeplabv3/model.py | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index caeb95d1..54661ef8 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -71,6 +71,7 @@ class DeepLabV3PlusDecoder(nn.Module): def __init__( self, encoder_channels: Sequence[int, ...], + encoder_depth: Literal[3, 4, 5], out_channels: int, atrous_rates: Iterable[int], output_stride: Literal[8, 16], @@ -104,7 +105,14 @@ def __init__( scale_factor = 2 if output_stride == 8 else 4 self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor) - highres_in_channels = encoder_channels[-4] + if encoder_depth == 3 and output_stride == 8: + self.highres_input_index = -2 + elif encoder_depth == 3 or encoder_depth == 4: + self.highres_input_index = -3 + else: + self.highres_input_index = -4 + + highres_in_channels = encoder_channels[self.highres_input_index] highres_out_channels = 48 # proposed by authors of paper self.block1 = nn.Sequential( nn.Conv2d( @@ -128,7 +136,7 @@ def __init__( def forward(self, *features): aspp_features = self.aspp(features[-1]) aspp_features = self.up(aspp_features) - high_res_features = self.block1(features[-4]) + high_res_features = self.block1(features[self.highres_input_index]) concat_features = torch.cat([aspp_features, high_res_features], dim=1) fused_features = self.block2(concat_features) return fused_features diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index d67a3be3..08f980a5 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -150,7 +150,7 @@ class DeepLabV3Plus(SegmentationModel): def __init__( self, encoder_name: str = "resnet34", - encoder_depth: int = 5, + encoder_depth: Literal[3, 4, 5] = 5, encoder_weights: Optional[str] = "imagenet", encoder_output_stride: Literal[8, 16] = 16, decoder_channels: int = 256, @@ -177,6 +177,7 @@ def __init__( self.decoder = DeepLabV3PlusDecoder( encoder_channels=self.encoder.out_channels, + encoder_depth=encoder_depth, out_channels=decoder_channels, atrous_rates=decoder_atrous_rates, output_stride=encoder_output_stride, From 2efd97406080d5191dad5c2aef471dda71ed5aa0 Mon Sep 17 00:00:00 2001 From: Munehiro Kobayashi Date: Sun, 13 Feb 2022 12:02:06 +0900 Subject: [PATCH 2/3] modify docstring for upsampling of DeepLabV3Plus --- segmentation_models_pytorch/decoders/deeplabv3/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index 08f980a5..c0ef1238 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -129,7 +129,8 @@ class DeepLabV3Plus(SegmentationModel): Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**. Default is **None** - upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity + upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. In case + **encoder_depth** and **encoder_output_stride** are 3 and 16 resp., set **upsampling** to 2 to preserve. aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: - classes (int): A number of classes From b39b8f351cda874bee4cd3d4699a2aa7468c1772 Mon Sep 17 00:00:00 2001 From: Munehiro Kobayashi Date: Mon, 25 Nov 2024 10:58:29 +0900 Subject: [PATCH 3/3] modify type hint and value check --- segmentation_models_pytorch/decoders/deeplabv3/decoder.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index 54661ef8..e20acf3f 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -70,7 +70,7 @@ def forward(self, *features): class DeepLabV3PlusDecoder(nn.Module): def __init__( self, - encoder_channels: Sequence[int, ...], + encoder_channels: Sequence[int], encoder_depth: Literal[3, 4, 5], out_channels: int, atrous_rates: Iterable[int], @@ -79,7 +79,11 @@ def __init__( aspp_dropout: float, ): super().__init__() - if output_stride not in {8, 16}: + if encoder_depth not in (3, 4, 5): + raise ValueError( + "Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth) + ) + if output_stride not in (8, 16): raise ValueError( "Output stride should be 8 or 16, got {}.".format(output_stride) )