
Add MAnet #310

Merged: 4 commits, Dec 15, 2020
3 changes: 2 additions & 1 deletion README.md
@@ -11,7 +11,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
The main features of this library are:

- High level API (just two lines to create neural network)
- 8 models architectures for binary and multi class segmentation (including legendary Unet)
- 9 model architectures for binary and multi-class segmentation (including legendary Unet)
- 99 available encoders
- All encoders have pre-trained weights for faster and better convergence

@@ -76,6 +76,7 @@ Congratulations! You are done! Now you can train your model with your favorite framework!
#### Architectures <a name="architectires"></a>
- Unet [[paper](https://arxiv.org/abs/1505.04597)] [[docs](https://smp.readthedocs.io/en/latest/models.html#unet)]
- Unet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id2)]
- MAnet [[paper](https://ieeexplore.ieee.org/abstract/document/9201310)] [[docs](https://smp.readthedocs.io/en/latest/models.html#manet)]
- Linknet [[paper](https://arxiv.org/abs/1707.03718)] [[docs](https://smp.readthedocs.io/en/latest/models.html#linknet)]
- FPN [[paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#fpn)]
- PSPNet [[paper](https://arxiv.org/abs/1612.01105)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pspnet)]
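With this merged, MAnet is available through the same two-line API as the other architectures in the list above. A minimal sketch (not part of the diff; `encoder_weights=None` only to skip the ImageNet download):

```python
import segmentation_models_pytorch as smp

# Create the new architecture exactly like any other model in the README list
model = smp.MAnet("resnet34", encoder_weights=None, classes=1)
```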
4 changes: 4 additions & 0 deletions docs/models.rst
@@ -9,6 +9,10 @@ Unet++
~~~~~~
.. autoclass:: segmentation_models_pytorch.UnetPlusPlus

MAnet
~~~~~~
.. autoclass:: segmentation_models_pytorch.MAnet

Linknet
~~~~~~~
.. autoclass:: segmentation_models_pytorch.Linknet
7 changes: 4 additions & 3 deletions segmentation_models_pytorch/__init__.py
@@ -1,5 +1,6 @@
from .unet import Unet
from .unetplusplus import UnetPlusPlus
from .manet import MAnet
from .linknet import Linknet
from .fpn import FPN
from .pspnet import PSPNet
@@ -24,10 +25,10 @@ def create_model(
**kwargs,
) -> torch.nn.Module:
"""Models wrapper. Allows to create any model just with parametes

"""
archs = [Unet, UnetPlusPlus, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]

archs = [Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
archs_dict = {a.__name__.lower(): a for a in archs}
try:
model_class = archs_dict[arch.lower()]
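Because `MAnet` is now listed in `archs`, the string-based factory resolves it too. A quick sketch, assuming `create_model` accepts `arch` plus the usual model keyword arguments (as the `**kwargs` in the signature above suggests):

```python
import segmentation_models_pytorch as smp

# Lookup is case-insensitive: the keys of archs_dict are lowercased class names
model = smp.create_model(arch="MAnet", encoder_name="resnet34", encoder_weights=None, classes=2)
```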
1 change: 1 addition & 0 deletions segmentation_models_pytorch/manet/__init__.py
@@ -0,0 +1 @@
from .model import MAnet
188 changes: 188 additions & 0 deletions segmentation_models_pytorch/manet/decoder.py
@@ -0,0 +1,188 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules as md


class PAB(nn.Module):
def __init__(self, in_channels, out_channels, pab_channels=64):
super(PAB, self).__init__()
        # 1x1 convs generate the attention "query"/"key" maps; 3x3 convs generate the value and output
self.pab_channels = pab_channels
self.in_channels = in_channels
self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.map_softmax = nn.Softmax(dim=1)
self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

def forward(self, x):
bsize = x.size()[0]
h = x.size()[2]
w = x.size()[3]
x_top = self.top_conv(x)
x_center = self.center_conv(x)
x_bottom = self.bottom_conv(x)

x_top = x_top.flatten(2)
x_center = x_center.flatten(2).transpose(1, 2)
x_bottom = x_bottom.flatten(2).transpose(1, 2)

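        # (B, HW, pab) @ (B, pab, HW) -> (B, HW, HW): softmax-normalized affinity between
        # every pair of spatial positions, then used to re-weight the (B, HW, C) bottom values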
sp_map = torch.matmul(x_center, x_top)
sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w, h*w)
sp_map = torch.matmul(sp_map, x_bottom)
sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
x = x + sp_map
x = self.out_conv(x)
return x


class MFAB(nn.Module):
def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
        # MFAB is essentially a pair of modified SE blocks: one for the skip connection, one for the high-level input
super(MFAB, self).__init__()
self.hl_conv = nn.Sequential(
md.Conv2dReLU(
in_channels,
in_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
),
md.Conv2dReLU(
in_channels,
skip_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
)
)
self.SE_ll = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
nn.Sigmoid(),
)
self.SE_hl = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
nn.Sigmoid(),
)
self.conv1 = md.Conv2dReLU(
            skip_channels + skip_channels,  # we transform C' from the high-level path to C from the skip connection
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)

def forward(self, x, skip=None):
x = self.hl_conv(x)
x = F.interpolate(x, scale_factor=2, mode="nearest")
attention_hl = self.SE_hl(x)
if skip is not None:
attention_ll = self.SE_ll(skip)
attention_hl = attention_hl + attention_ll
x = x * attention_hl
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x


class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)

def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x


class MAnetDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
reduction=16,
use_batchnorm=True,
pab_channels=64
):
super().__init__()

if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)

encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder

# computing blocks input and output channels
head_channels = encoder_channels[0]
in_channels = [head_channels] + list(decoder_channels[:-1])
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels

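        # PAB acts as the "center" block on the deepest encoder feature map (global spatial attention)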
self.center = PAB(head_channels, head_channels, pab_channels=pab_channels)

# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here
blocks = [
MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 else
DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
]
        # the last block has no skip connection -> use a simple decoder block
self.blocks = nn.ModuleList(blocks)

def forward(self, *features):

features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder

head = features[0]
skips = features[1:]

x = self.center(head)
for i, decoder_block in enumerate(self.blocks):
skip = skips[i] if i < len(skips) else None
x = decoder_block(x, skip)

return x
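A quick shape check of the new blocks (a sketch, not part of the PR; H*W is kept small because PAB materializes an (H*W) x (H*W) affinity matrix):

```python
import torch
from segmentation_models_pytorch.manet.decoder import PAB, MFAB

x = torch.rand(2, 512, 8, 8)                 # deepest encoder feature map
pab = PAB(in_channels=512, out_channels=512)
assert pab(x).shape == x.shape               # residual attention preserves the shape

skip = torch.rand(2, 256, 16, 16)            # skip connection one stage earlier
mfab = MFAB(in_channels=512, skip_channels=256, out_channels=128)
assert mfab(x, skip).shape == (2, 128, 16, 16)  # MFAB upsamples x2 and fuses the skip
```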
96 changes: 96 additions & 0 deletions segmentation_models_pytorch/manet/model.py
@@ -0,0 +1,96 @@
from typing import Optional, Union, List
from .decoder import MAnetDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class MAnet(SegmentationModel):
"""MAnet_ : Multi-scale Attention Net.
The MA-Net can capture rich contextual dependencies based on the attention mechanism, using two blocks:
    Position-wise Attention Block (PAB), which captures spatial dependencies between pixels in a global view,
    and Multi-scale Fusion Attention Block (MFAB), which captures channel dependencies between feature maps via
    multi-scale semantic feature fusion.

Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W)], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
            Length of the list should be the same as **encoder_depth**
        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
            is used. If **"inplace"**, InplaceABN will be used, which allows decreasing memory consumption.
            Available options are **True, False, "inplace"**
decoder_pab_channels: A number of channels for PAB module in decoder.
Default is 64.
in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think of it as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"identity"**, **callable** and **None**.
Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
            on top of encoder if **aux_params** is not **None** (default). Supported params:
- classes (int): A number of classes
- pooling (str): One of "max", "avg". Default is "avg"
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits)

Returns:
``torch.nn.Module``: **MAnet**

.. _MAnet:
https://ieeexplore.ieee.org/abstract/document/9201310

"""

def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: str = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_channels: List[int] = (256, 128, 64, 32, 16),
decoder_pab_channels: int = 64,
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None
):
super().__init__()

self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
)

self.decoder = MAnetDecoder(
encoder_channels=self.encoder.out_channels,
decoder_channels=decoder_channels,
n_blocks=encoder_depth,
use_batchnorm=decoder_use_batchnorm,
pab_channels=decoder_pab_channels
)

self.segmentation_head = SegmentationHead(
in_channels=decoder_channels[-1],
out_channels=classes,
activation=activation,
kernel_size=3,
)

if aux_params is not None:
self.classification_head = ClassificationHead(
in_channels=self.encoder.out_channels[-1], **aux_params
)
else:
self.classification_head = None

self.name = "manet-{}".format(encoder_name)
self.initialize()
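End-to-end, the new model runs like the rest of the library. A sketch mirroring the 64x64 sample used in the tests below (input sides must be divisible by 32 for the default depth-5 encoder):

```python
import torch
import segmentation_models_pytorch as smp

model = smp.MAnet(encoder_name="resnet34", encoder_weights=None, classes=2)
model.eval()
with torch.no_grad():
    mask = model(torch.rand(1, 3, 64, 64))
assert mask.shape == (1, 2, 64, 64)  # logits at full input resolution
```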
8 changes: 4 additions & 4 deletions tests/test_models.py
@@ -29,7 +29,7 @@ def get_encoders():


def get_sample(model_class):
if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus]:
if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus, smp.MAnet]:
sample = torch.ones([1, 3, 64, 64])
elif model_class == smp.PAN:
sample = torch.ones([2, 3, 256, 256])
@@ -58,7 +58,7 @@ def _test_forward_backward(model, sample, test_shape=False):
@pytest.mark.parametrize("encoder_depth", [3, 5])
@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
if model_class is smp.Unet or model_class is smp.UnetPlusPlus:
if model_class is smp.Unet or model_class is smp.UnetPlusPlus or model_class is smp.MAnet:
kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
model = model_class(
encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
@@ -75,15 +75,15 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):

@pytest.mark.parametrize(
"model_class",
[smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.DeepLabV3]
[smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet, smp.DeepLabV3]
)
def test_forward_backward(model_class):
sample = get_sample(model_class)
model = model_class(DEFAULT_ENCODER, encoder_weights=None)
_test_forward_backward(model, sample)


@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet])
def test_aux_output(model_class):
model = model_class(
DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)