From 00b04dd2a3ccf2f9ea39d34fac12f5b2ee70017c Mon Sep 17 00:00:00 2001
From: Sergey Kolchenko
Date: Thu, 19 Nov 2020 10:01:10 -0600
Subject: [PATCH 1/4] add unet++

---
 segmentation_models_pytorch/__init__.py |   1 +
 .../unetplusplus/__init__.py            |   1 +
 .../unetplusplus/decoder.py             | 136 ++++++++++++++++++
 .../unetplusplus/model.py               |  90 ++++++++++++
 4 files changed, 228 insertions(+)
 create mode 100644 segmentation_models_pytorch/unetplusplus/__init__.py
 create mode 100644 segmentation_models_pytorch/unetplusplus/decoder.py
 create mode 100644 segmentation_models_pytorch/unetplusplus/model.py

diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py
index 244868fd..cbf685fa 100644
--- a/segmentation_models_pytorch/__init__.py
+++ b/segmentation_models_pytorch/__init__.py
@@ -1,4 +1,5 @@
 from .unet import Unet
+from .unetplusplus import UnetPlusPlus
 from .linknet import Linknet
 from .fpn import FPN
 from .pspnet import PSPNet
diff --git a/segmentation_models_pytorch/unetplusplus/__init__.py b/segmentation_models_pytorch/unetplusplus/__init__.py
new file mode 100644
index 00000000..bda62b70
--- /dev/null
+++ b/segmentation_models_pytorch/unetplusplus/__init__.py
@@ -0,0 +1 @@
+from .model import UnetPlusPlus
diff --git a/segmentation_models_pytorch/unetplusplus/decoder.py b/segmentation_models_pytorch/unetplusplus/decoder.py
new file mode 100644
index 00000000..0cca88ad
--- /dev/null
+++ b/segmentation_models_pytorch/unetplusplus/decoder.py
@@ -0,0 +1,136 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..base import modules as md
+
+
+class DecoderBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        skip_channels,
+        out_channels,
+        use_batchnorm=True,
+        attention_type=None,
+    ):
+        super().__init__()
+        self.conv1 = md.Conv2dReLU(
+            in_channels + skip_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
+        self.conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        self.attention2 = md.Attention(attention_type, in_channels=out_channels)
+
+    def forward(self, x, skip=None):
+        x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if skip is not None:
+            x = torch.cat([x, skip], dim=1)
+            x = self.attention1(x)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.attention2(x)
+        return x
+
+
+class CenterBlock(nn.Sequential):
+    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+        conv1 = md.Conv2dReLU(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        super().__init__(conv1, conv2)
+
+
+class UnetPlusPlusDecoder(nn.Module):
+    def __init__(
+        self,
+        encoder_channels,
+        decoder_channels,
+        n_blocks=5,
+        use_batchnorm=True,
+        attention_type=None,
+        center=False,
+    ):
+        super().__init__()
+
+        if n_blocks != len(decoder_channels):
+            raise ValueError(
+                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
+                    n_blocks, len(decoder_channels)
+                )
+            )
+
+        encoder_channels = encoder_channels[1:]  # remove first skip with same spatial resolution
+        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder
+        # computing blocks input and output channels
+        head_channels = encoder_channels[0]
+        self.in_channels = [head_channels] + list(decoder_channels[:-1])
+        self.skip_channels = list(encoder_channels[1:]) + [0]
+        self.out_channels = decoder_channels
+        if center:
+            self.center = CenterBlock(
+                head_channels, head_channels, use_batchnorm=use_batchnorm
+            )
+        else:
+            self.center = nn.Identity()
+
+        # combine decoder keyword arguments
+        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
+
+        blocks = {}
+        for layer_idx in range(len(self.in_channels) - 1):
+            for depth_idx in range(layer_idx+1):
+                if depth_idx == 0:
+                    in_ch = self.in_channels[layer_idx]
+                    skip_ch = self.skip_channels[layer_idx] * (layer_idx+1)
+                    out_ch = self.out_channels[layer_idx]
+                else:
+                    out_ch = self.skip_channels[layer_idx]
+                    skip_ch = self.skip_channels[layer_idx] * (layer_idx+1-depth_idx)
+                    in_ch = self.skip_channels[layer_idx - 1]
+                blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
+        blocks[f'x_{0}_{len(self.in_channels)-1}'] =\
+            DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1], **kwargs)
+        self.blocks = nn.ModuleDict(blocks)
+        self.depth = len(self.in_channels) - 1
+
+    def forward(self, *features):
+
+        features = features[1:]  # remove first skip with same spatial resolution
+        features = features[::-1]  # reverse channels to start from head of encoder
+        # start building dense connections
+        dense_x = {}
+        for layer_idx in range(len(self.in_channels)-1):
+            for depth_idx in range(self.depth-layer_idx):
+                if layer_idx == 0:
+                    output = self.blocks[f'x_{depth_idx}_{depth_idx}'](features[depth_idx], features[depth_idx+1])
+                    dense_x[f'x_{depth_idx}_{depth_idx}'] = output
+                else:
+                    dense_l_i = depth_idx + layer_idx
+                    cat_features = [dense_x[f'x_{idx}_{dense_l_i}'] for idx in range(depth_idx+1, dense_l_i+1)]
+                    cat_features = torch.cat(cat_features + [features[dense_l_i+1]], dim=1)
+                    dense_x[f'x_{depth_idx}_{dense_l_i}'] =\
+                        self.blocks[f'x_{depth_idx}_{dense_l_i}'](dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features)
+        dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](dense_x[f'x_{0}_{self.depth-1}'])
+        return dense_x[f'x_{0}_{self.depth}']
diff --git a/segmentation_models_pytorch/unetplusplus/model.py b/segmentation_models_pytorch/unetplusplus/model.py
new file mode 100644
index 00000000..040827cd
--- /dev/null
+++ b/segmentation_models_pytorch/unetplusplus/model.py
@@ -0,0 +1,90 @@
+from typing import Optional, Union, List
+from .decoder import UnetPlusPlusDecoder
+from ..encoders import get_encoder
+from ..base import SegmentationModel
+from ..base import SegmentationHead, ClassificationHead
+
+
+class UnetPlusPlus(SegmentationModel):
+    """UnetPlusPlus_ (Unet++) is a fully convolutional neural network for image semantic segmentation
+
+    Args:
+        encoder_name: name of classification model (without last dense layers) used as feature
+            extractor to build segmentation model.
+        encoder_depth (int): number of stages used in decoder; a larger depth generates more features.
+            e.g. for depth=3 the encoder will generate a list of features with the following spatial shapes
+            [(H, W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have
+            spatial resolution (H/(2^depth), W/(2^depth))
+        encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
+        decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks
+        decoder_use_batchnorm: if ``True``, a ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
+            is used. If ``'inplace'``, InplaceABN will be used, which decreases memory consumption.
+            One of [True, False, 'inplace']
+        decoder_attention_type: attention module used in the decoder of the model.
+            One of [``None``, ``scse``]
+        in_channels: number of input channels for model, default is 3.
+        classes: number of classes for output (output shape - ``(batch, classes, h, w)``).
+        activation: activation function to apply after final convolution;
+            One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None]
+        aux_params: if specified, the model will have an additional classification auxiliary output
+            built on top of the encoder; supported params:
+                - classes (int): number of classes
+                - pooling (str): one of 'max', 'avg'. Default is 'avg'.
+                - dropout (float): dropout factor in [0, 1)
+                - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits)
+
+    Returns:
+        ``torch.nn.Module``: **Unet**
+
+    .. _UnetPlusPlus:
+        https://arxiv.org/pdf/1807.10165.pdf
+
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: str = "imagenet",
+        decoder_use_batchnorm: bool = True,
+        decoder_channels: List[int] = (256, 128, 64, 32, 16),
+        decoder_attention_type: Optional[str] = None,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[Union[str, callable]] = None,
+        aux_params: Optional[dict] = None,
+    ):
+        super().__init__()
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        self.decoder = UnetPlusPlusDecoder(
+            encoder_channels=self.encoder.out_channels,
+            decoder_channels=decoder_channels,
+            n_blocks=encoder_depth,
+            use_batchnorm=decoder_use_batchnorm,
+            center=True if encoder_name.startswith("vgg") else False,
+            attention_type=decoder_attention_type,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=decoder_channels[-1],
+            out_channels=classes,
+            activation=activation,
+            kernel_size=3,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "u-{}".format(encoder_name)
+        self.initialize()

From 2d8565dbc6063eb0f35ec5395f9b421dc2b2bc77 Mon Sep 17 00:00:00 2001
From: Sergey Kolchenko
Date: Mon, 23 Nov 2020 09:07:47 -0600
Subject: [PATCH 2/4] update README.md

---
 README.md                                         | 2 +-
 segmentation_models_pytorch/unetplusplus/model.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 4aa4ce69..88c572b2 100644
--- a/README.md
+++ b/README.md
@@ -63,7 +63,7 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
 ### Models

 #### Architectures
- - [Unet](https://arxiv.org/abs/1505.04597)
+ - [Unet](https://arxiv.org/abs/1505.04597) and [Unet++](https://arxiv.org/pdf/1807.10165.pdf)
 - [Linknet](https://arxiv.org/abs/1707.03718)
 - [FPN](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)
 - [PSPNet](https://arxiv.org/abs/1612.01105)
diff --git a/segmentation_models_pytorch/unetplusplus/model.py b/segmentation_models_pytorch/unetplusplus/model.py
index 040827cd..bc0d8d21 100644
--- a/segmentation_models_pytorch/unetplusplus/model.py
+++ b/segmentation_models_pytorch/unetplusplus/model.py
@@ -34,7 +34,7 @@ class UnetPlusPlus(SegmentationModel):
             - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits)

     Returns:
-        ``torch.nn.Module``: **Unet**
+        ``torch.nn.Module``: **Unet++**

     .. _UnetPlusPlus:
         https://arxiv.org/pdf/1807.10165.pdf

From 6e49ca5a3cea04da735f12e1cf53f6b61bc05a8d Mon Sep 17 00:00:00 2001
From: Sergey Kolchenko
Date: Mon, 23 Nov 2020 09:14:43 -0600
Subject: [PATCH 3/4] update tests for unet++

---
 tests/test_models.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 76c8d8ac..d2162c57 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -30,7 +30,7 @@ def get_encoders():


 def get_sample(model_class):
-    if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet]:
+    if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus]:
         sample = torch.ones([1, 3, 64, 64])
     elif model_class == smp.PAN:
         sample = torch.ones([2, 3, 256, 256])
@@ -57,7 +57,7 @@ def _test_forward_backward(model, sample, test_shape=False):

 @pytest.mark.parametrize("encoder_name", ENCODERS)
 @pytest.mark.parametrize("encoder_depth", [3, 5])
-@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet])
+@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
 def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
     if model_class is smp.Unet:
         kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
@@ -76,7 +76,7 @@

 @pytest.mark.parametrize(
     "model_class",
-    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.DeepLabV3]
+    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.DeepLabV3]
 )
 def test_forward_backward(model_class):
     sample = get_sample(model_class)
     model = get_model(model_class)
     _test_forward_backward(model, sample)


-@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet])
+@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
 def test_aux_output(model_class):
     model = model_class(
         DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)

From b2f506ac42229ed0d391a2c8d788b2e9c6f930ff Mon Sep 17 00:00:00 2001
From: Sergey Kolchenko
Date: Mon, 23 Nov 2020 09:52:16 -0600
Subject: [PATCH 4/4] fixed test behaviour for unet++

---
 tests/test_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index d2162c57..da6c3168 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -59,7 +59,7 @@ def _test_forward_backward(model, sample, test_shape=False):
 @pytest.mark.parametrize("encoder_depth", [3, 5])
 @pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
 def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
-    if model_class is smp.Unet:
+    if model_class is smp.Unet or model_class is smp.UnetPlusPlus:
         kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
     model = model_class(
         encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
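
Usage note (not part of the patches above): a minimal sketch of how the new model can be constructed and run once the series is applied. It assumes a local install of segmentation_models_pytorch; the encoder choice, class count, and input size below are illustrative only.

import torch
import segmentation_models_pytorch as smp

# Unet++ with a ResNet-34 encoder; encoder_weights=None keeps random
# initialization (pass "imagenet" to load pretrained encoder weights).
model = smp.UnetPlusPlus(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=2,
)

# With the default encoder_depth=5, input H and W should be divisible by
# 2**5 = 32, matching the 64x64 sample used in tests/test_models.py.
x = torch.ones([1, 3, 64, 64])
with torch.no_grad():
    mask = model(x)  # shape (1, 2, 64, 64); raw logits, since activation=None

print(mask.shape)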