Skip to content

add unet++ #279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
### Models <a name="models"></a>

#### Architectures <a name="architectires"></a>
- [Unet](https://arxiv.org/abs/1505.04597)
- [Unet](https://arxiv.org/abs/1505.04597) and [Unet++](https://arxiv.org/pdf/1807.10165.pdf)
- [Linknet](https://arxiv.org/abs/1707.03718)
- [FPN](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)
- [PSPNet](https://arxiv.org/abs/1612.01105)
Expand Down
1 change: 1 addition & 0 deletions segmentation_models_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .unet import Unet
from .unetplusplus import UnetPlusPlus
from .linknet import Linknet
from .fpn import FPN
from .pspnet import PSPNet
Expand Down
1 change: 1 addition & 0 deletions segmentation_models_pytorch/unetplusplus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .model import UnetPlusPlus
136 changes: 136 additions & 0 deletions segmentation_models_pytorch/unetplusplus/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules as md


class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True,
attention_type=None,
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)

def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.attention1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.attention2(x)
return x


class CenterBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, use_batchnorm=True):
conv1 = md.Conv2dReLU(
in_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
super().__init__(conv1, conv2)


class UnetPlusPlusDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
use_batchnorm=True,
attention_type=None,
center=False,
):
super().__init__()

if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)

encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
# computing blocks input and output channels
head_channels = encoder_channels[0]
self.in_channels = [head_channels] + list(decoder_channels[:-1])
self.skip_channels = list(encoder_channels[1:]) + [0]
self.out_channels = decoder_channels
if center:
self.center = CenterBlock(
head_channels, head_channels, use_batchnorm=use_batchnorm
)
else:
self.center = nn.Identity()

# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)

blocks = {}
for layer_idx in range(len(self.in_channels) - 1):
for depth_idx in range(layer_idx+1):
if depth_idx == 0:
in_ch = self.in_channels[layer_idx]
skip_ch = self.skip_channels[layer_idx] * (layer_idx+1)
out_ch = self.out_channels[layer_idx]
else:
out_ch = self.skip_channels[layer_idx]
skip_ch = self.skip_channels[layer_idx] * (layer_idx+1-depth_idx)
in_ch = self.skip_channels[layer_idx - 1]
blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
blocks[f'x_{0}_{len(self.in_channels)-1}'] =\
DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1], **kwargs)
self.blocks = nn.ModuleDict(blocks)
self.depth = len(self.in_channels) - 1

def forward(self, *features):

features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder
# start bulding dense connections
dense_x = {}
for layer_idx in range(len(self.in_channels)-1):
for depth_idx in range(self.depth-layer_idx):
if layer_idx == 0:
output = self.blocks[f'x_{depth_idx}_{depth_idx}'](features[depth_idx], features[depth_idx+1])
dense_x[f'x_{depth_idx}_{depth_idx}'] = output
else:
dense_l_i = depth_idx + layer_idx
cat_features = [dense_x[f'x_{idx}_{dense_l_i}'] for idx in range(depth_idx+1, dense_l_i+1)]
cat_features = torch.cat(cat_features + [features[dense_l_i+1]], dim=1)
dense_x[f'x_{depth_idx}_{dense_l_i}'] =\
self.blocks[f'x_{depth_idx}_{dense_l_i}'](dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features)
dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](dense_x[f'x_{0}_{self.depth-1}'])
return dense_x[f'x_{0}_{self.depth}']
90 changes: 90 additions & 0 deletions segmentation_models_pytorch/unetplusplus/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Optional, Union, List
from .decoder import UnetPlusPlusDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class UnetPlusPlus(SegmentationModel):
"""Unet++_ is a fully convolution neural network for image semantic segmentation

Args:
encoder_name: name of classification model (without last dense layers) used as feature
extractor to build segmentation model.
encoder_depth (int): number of stages used in decoder, larger depth - more features are generated.
e.g. for depth=3 encoder will generate list of features with following spatial shapes
[(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have
spatial resolution (H/(2^depth), W/(2^depth)]
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks
decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption.
One of [True, False, 'inplace']
decoder_attention_type: attention module used in decoder of the model
One of [``None``, ``scse``]
in_channels: number of input channels for model, default is 3.
classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
activation: activation function to apply after final convolution;
One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None]
aux_params: if specified model will have additional classification auxiliary output
build on top of encoder, supported params:
- classes (int): number of classes
- pooling (str): one of 'max', 'avg'. Default is 'avg'.
- dropout (float): dropout factor in [0, 1)
- activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits)

Returns:
``torch.nn.Module``: **Unet++**

.. _UnetPlusPlus:
https://arxiv.org/pdf/1807.10165.pdf

"""

def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: str = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_channels: List[int] = (256, 128, 64, 32, 16),
decoder_attention_type: Optional[str] = None,
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None,
):
super().__init__()

self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
)

self.decoder = UnetPlusPlusDecoder(
encoder_channels=self.encoder.out_channels,
decoder_channels=decoder_channels,
n_blocks=encoder_depth,
use_batchnorm=decoder_use_batchnorm,
center=True if encoder_name.startswith("vgg") else False,
attention_type=decoder_attention_type,
)

self.segmentation_head = SegmentationHead(
in_channels=decoder_channels[-1],
out_channels=classes,
activation=activation,
kernel_size=3,
)

if aux_params is not None:
self.classification_head = ClassificationHead(
in_channels=self.encoder.out_channels[-1], **aux_params
)
else:
self.classification_head = None

self.name = "u-{}".format(encoder_name)
self.initialize()
10 changes: 5 additions & 5 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_encoders():


def get_sample(model_class):
if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet]:
if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus]:
sample = torch.ones([1, 3, 64, 64])
elif model_class == smp.PAN:
sample = torch.ones([2, 3, 256, 256])
Expand All @@ -57,9 +57,9 @@ def _test_forward_backward(model, sample, test_shape=False):

@pytest.mark.parametrize("encoder_name", ENCODERS)
@pytest.mark.parametrize("encoder_depth", [3, 5])
@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet])
@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
if model_class is smp.Unet:
if model_class is smp.Unet or model_class is smp.UnetPlusPlus:
kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
model = model_class(
encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
Expand All @@ -76,15 +76,15 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):

@pytest.mark.parametrize(
"model_class",
[smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.DeepLabV3]
[smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.DeepLabV3]
)
def test_forward_backward(model_class):
sample = get_sample(model_class)
model = model_class(DEFAULT_ENCODER, encoder_weights=None)
_test_forward_backward(model, sample)


@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet])
@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
def test_aux_output(model_class):
model = model_class(
DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)
Expand Down