
Commit a4220a0

Add MAnet (#310)

* add MAnet arch
* update docs, readme, docstring
* fix docstring, rename decoder parameters

1 parent c885803 · commit a4220a0

7 files changed: +299 −8 lines changed

README.md (+2 −1)
@@ -11,7 +11,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
 The main features of this library are:

 - High level API (just two lines to create neural network)
-- 8 models architectures for binary and multi class segmentation (including legendary Unet)
+- 9 models architectures for binary and multi class segmentation (including legendary Unet)
 - 99 available encoders
 - All encoders have pre-trained weights for faster and better convergence

@@ -76,6 +76,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
 #### Architectures <a name="architectires"></a>
 - Unet [[paper](https://arxiv.org/abs/1505.04597)] [[docs](https://smp.readthedocs.io/en/latest/models.html#unet)]
 - Unet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id2)]
+- MAnet [[paper](https://ieeexplore.ieee.org/abstract/document/9201310)] [[docs](https://smp.readthedocs.io/en/latest/models.html#manet)]
 - Linknet [[paper](https://arxiv.org/abs/1707.03718)] [[docs](https://smp.readthedocs.io/en/latest/models.html#linknet)]
 - FPN [[paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#fpn)]
 - PSPNet [[paper](https://arxiv.org/abs/1612.01105)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pspnet)]
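Since the README advertises a two-line high-level API, the newly listed MAnet is created the same way as the existing architectures. A minimal sketch (parameter names come from the MAnet constructor added in this commit; the values are only illustrative):

import segmentation_models_pytorch as smp

# Any encoder from the encoders table can be used as the backbone.
model = smp.MAnet(
    encoder_name="resnet34",      # backbone that extracts multi-scale features
    encoder_weights="imagenet",   # or None for random initialization
    in_channels=3,                # RGB input
    classes=1,                    # single-channel (binary) output mask
)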

docs/models.rst (+4)

@@ -9,6 +9,10 @@ Unet++
 ~~~~~~
 .. autoclass:: segmentation_models_pytorch.UnetPlusPlus

+MAnet
+~~~~~~
+.. autoclass:: segmentation_models_pytorch.MAnet
+
 Linknet
 ~~~~~~~
 .. autoclass:: segmentation_models_pytorch.Linknet

segmentation_models_pytorch/__init__.py (+4 −3)

@@ -1,5 +1,6 @@
 from .unet import Unet
 from .unetplusplus import UnetPlusPlus
+from .manet import MAnet
 from .linknet import Linknet
 from .fpn import FPN
 from .pspnet import PSPNet

@@ -24,10 +25,10 @@ def create_model(
     **kwargs,
 ) -> torch.nn.Module:
     """Models wrapper. Allows to create any model just with parametes
-
+
     """
-
-    archs = [Unet, UnetPlusPlus, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
+
+    archs = [Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
     archs_dict = {a.__name__.lower(): a for a in archs}
     try:
         model_class = archs_dict[arch.lower()]
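With MAnet in the archs list, its lowercase class name becomes a valid arch string for create_model. A hedged sketch (only the arch lookup is visible in this hunk; the remaining keyword arguments are assumed to be forwarded to the model constructor, as for the other architectures):

import segmentation_models_pytorch as smp

# "manet" resolves to the MAnet class via archs_dict[arch.lower()]
model = smp.create_model(arch="manet", encoder_name="resnet34", encoder_weights=None, classes=2)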
segmentation_models_pytorch/manet/__init__.py (new file, +1 line)

from .model import MAnet
segmentation_models_pytorch/manet/decoder.py (new file, +188 lines)

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules as md


class PAB(nn.Module):
    def __init__(self, in_channels, out_channels, pab_channels=64):
        super(PAB, self).__init__()
        # Series of 1x1 conv to generate attention feature maps
        self.pab_channels = pab_channels
        self.in_channels = in_channels
        self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.map_softmax = nn.Softmax(dim=1)
        self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        bsize = x.size()[0]
        h = x.size()[2]
        w = x.size()[3]
        x_top = self.top_conv(x)
        x_center = self.center_conv(x)
        x_bottom = self.bottom_conv(x)

        x_top = x_top.flatten(2)
        x_center = x_center.flatten(2).transpose(1, 2)
        x_bottom = x_bottom.flatten(2).transpose(1, 2)

        sp_map = torch.matmul(x_center, x_top)
        sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w, h*w)
        sp_map = torch.matmul(sp_map, x_bottom)
        sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
        x = x + sp_map
        x = self.out_conv(x)
        return x


class MFAB(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
        # MFAB is just a modified version of SE-blocks, one for skip, one for input
        super(MFAB, self).__init__()
        self.hl_conv = nn.Sequential(
            md.Conv2dReLU(
                in_channels,
                in_channels,
                kernel_size=3,
                padding=1,
                use_batchnorm=use_batchnorm,
            ),
            md.Conv2dReLU(
                in_channels,
                skip_channels,
                kernel_size=1,
                use_batchnorm=use_batchnorm,
            )
        )
        self.SE_ll = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(skip_channels, skip_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(skip_channels // reduction, skip_channels, 1),
            nn.Sigmoid(),
        )
        self.SE_hl = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(skip_channels, skip_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(skip_channels // reduction, skip_channels, 1),
            nn.Sigmoid(),
        )
        self.conv1 = md.Conv2dReLU(
            skip_channels + skip_channels,  # we transform C-prime from high level to C from skip connection
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

    def forward(self, x, skip=None):
        x = self.hl_conv(x)
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        attention_hl = self.SE_hl(x)
        if skip is not None:
            attention_ll = self.SE_ll(skip)
            attention_hl = attention_hl + attention_ll
            x = x * attention_hl
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            skip_channels,
            out_channels,
            use_batchnorm=True
    ):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class MAnetDecoder(nn.Module):
    def __init__(
            self,
            encoder_channels,
            decoder_channels,
            n_blocks=5,
            reduction=16,
            use_batchnorm=True,
            pab_channels=64
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        encoder_channels = encoder_channels[1:]  # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        self.center = PAB(head_channels, head_channels, pab_channels=pab_channels)

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm)  # no attention type here
        blocks = [
            MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 else
            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        # for the last block we don't have a skip connection -> use simple decoder block
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features):

        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        return x
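A standalone sanity check of the PAB block above (not part of the commit; the import path is assumed from this commit's package layout). PAB projects the input with two 1x1 convolutions to build an (H*W) x (H*W) position-attention map, applies it to a 3x3-projected value tensor, and adds the result back to the input, so the output keeps the input shape:

import torch
from segmentation_models_pytorch.manet.decoder import PAB  # path assumed from the new manet package

pab = PAB(in_channels=512, out_channels=512, pab_channels=64)
x = torch.rand(2, 512, 8, 8)   # e.g. the encoder head for a 256x256 input with a depth-5 encoder

y = pab(x)                     # attention map here is (8*8) x (8*8)
assert y.shape == x.shape      # residual add + 3x3 out_conv keep (N, C, H, W) unchanged

Because the attention map grows quadratically with H*W, the decoder applies PAB only once, on the smallest (head) feature map, before the MFAB blocks take over.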
segmentation_models_pytorch/manet/model.py (new file, +96 lines)

from typing import Optional, Union, List
from .decoder import MAnetDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class MAnet(SegmentationModel):
    """MAnet_ : Multi-scale Attention Net.
    The MA-Net can capture rich contextual dependencies based on the attention mechanism, using two blocks:
    Position-wise Attention Block (PAB, which captures the spatial dependencies between pixels in a global view)
    and Multi-scale Fusion Attention Block (MFAB, which captures the channel dependencies between any feature map by
    multi-scale semantic feature fusion)

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
            Length of the list should be the same as **encoder_depth**
        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
            Available options are **True, False, "inplace"**
        decoder_pab_channels: A number of channels for PAB module in decoder.
            Default is 64.
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"identity"**, **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits)

    Returns:
        ``torch.nn.Module``: **MAnet**

    .. _MAnet:
        https://ieeexplore.ieee.org/abstract/document/9201310

    """

    def __init__(
            self,
            encoder_name: str = "resnet34",
            encoder_depth: int = 5,
            encoder_weights: str = "imagenet",
            decoder_use_batchnorm: bool = True,
            decoder_channels: List[int] = (256, 128, 64, 32, 16),
            decoder_pab_channels: int = 64,
            in_channels: int = 3,
            classes: int = 1,
            activation: Optional[Union[str, callable]] = None,
            aux_params: Optional[dict] = None
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = MAnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            pab_channels=decoder_pab_channels
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "manet-{}".format(encoder_name)
        self.initialize()
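An end-to-end sketch mirroring the defaults above (random weights; output shape follows the usual smp convention of per-class logits at the input resolution):

import torch
import segmentation_models_pytorch as smp

model = smp.MAnet("resnet34", encoder_weights=None, classes=2, decoder_pab_channels=64)
model.eval()

x = torch.rand(1, 3, 64, 64)   # spatial size should be divisible by 2**encoder_depth (32 here)
with torch.no_grad():
    mask = model(x)

print(mask.shape)              # expected: torch.Size([1, 2, 64, 64])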

tests/test_models.py (+4 −4)

@@ -29,7 +29,7 @@ def get_encoders():


 def get_sample(model_class):
-    if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus]:
+    if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus, smp.MAnet]:
         sample = torch.ones([1, 3, 64, 64])
     elif model_class == smp.PAN:
         sample = torch.ones([2, 3, 256, 256])

@@ -58,7 +58,7 @@ def _test_forward_backward(model, sample, test_shape=False):
 @pytest.mark.parametrize("encoder_depth", [3, 5])
 @pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
 def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
-    if model_class is smp.Unet or model_class is smp.UnetPlusPlus:
+    if model_class is smp.Unet or model_class is smp.UnetPlusPlus or model_class is smp.MAnet:
         kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
     model = model_class(
         encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs

@@ -75,15 +75,15 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):

 @pytest.mark.parametrize(
     "model_class",
-    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.DeepLabV3]
+    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet, smp.DeepLabV3]
 )
 def test_forward_backward(model_class):
     sample = get_sample(model_class)
     model = model_class(DEFAULT_ENCODER, encoder_weights=None)
     _test_forward_backward(model, sample)


-@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
+@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet])
 def test_aux_output(model_class):
     model = model_class(
         DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)