
Commit 225823b

markson14 and qubvel authored
add timm-MobileNetV3 as an Encoder (#355)
* add timm-mobilenetv3 as encoder
* fix import bug

Co-authored-by: Pavel Yakubovskiy <qubvel@gmail.com>
1 parent 23a54b4 commit 225823b

File tree

4 files changed: +202 −1 lines changed


README.md (+17 −1)
@@ -12,7 +12,7 @@ The main features of this library are:
 - High level API (just two lines to create a neural network)
 - 9 models architectures for binary and multi class segmentation (including legendary Unet)
-- 109 available encoders
+- 115 available encoders
 - All encoders have pre-trained weights for faster and better convergence

 ### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -337,6 +337,22 @@ The following is a list of supported encoders in the SMP. Select the appropriate
 </div>
 </details>

+<details>
+<summary style="margin-left: 25px;">MobileNetV3</summary>
+<div style="margin-left: 25px;">
+
+|Encoder                            |Weights   |Params, M |
+|-----------------------------------|:--------:|:--------:|
+|timm-mobilenetv3_large_075         |imagenet  |1.78M     |
+|timm-mobilenetv3_large_100         |imagenet  |2.97M     |
+|timm-mobilenetv3_large_minimal_100 |imagenet  |1.41M     |
+|timm-mobilenetv3_small_075         |imagenet  |0.57M     |
+|timm-mobilenetv3_small_100         |imagenet  |0.93M     |
+|timm-mobilenetv3_small_minimal_100 |imagenet  |0.43M     |
+
+</div>
+</details>
+

 \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
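With this change the new encoders plug into the usual two-line API. A minimal sketch (the encoder name comes from the table above; the rest is the library's documented interface):

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="timm-mobilenetv3_large_100",  # any name from the MobileNetV3 table
    encoder_weights="imagenet",                 # pulls the timm checkpoint registered below
    classes=1,
)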

docs/encoders.rst (+19)
@@ -316,3 +316,22 @@ VGG
 +-------------+------------+-------------+
 | vgg19\_bn   | imagenet   | 20M         |
 +-------------+------------+-------------+
+
+MobileNetV3
+~~~~~~~~~~~
+
++-----------------------------------+------------+-------------+
+| Encoder                           | Weights    | Params, M   |
++===================================+============+=============+
+| timm-mobilenetv3_large_075        | imagenet   | 1.78M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_large_100        | imagenet   | 2.97M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_large_minimal_100| imagenet   | 1.41M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_small_075        | imagenet   | 0.57M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_small_100        | imagenet   | 0.93M       |
++-----------------------------------+------------+-------------+
+| timm-mobilenetv3_small_minimal_100| imagenet   | 0.43M       |
++-----------------------------------+------------+-------------+

segmentation_models_pytorch/encoders/__init__.py (+2)
@@ -17,6 +17,7 @@
 from .timm_res2net import timm_res2net_encoders
 from .timm_regnet import timm_regnet_encoders
 from .timm_sknet import timm_sknet_encoders
+from .timm_mobilenetv3 import timm_mobilenetv3_encoders
 try:
     from .timm_gernet import timm_gernet_encoders
 except ImportError as e:
@@ -43,6 +44,7 @@
 encoders.update(timm_res2net_encoders)
 encoders.update(timm_regnet_encoders)
 encoders.update(timm_sknet_encoders)
+encoders.update(timm_mobilenetv3_encoders)
 encoders.update(timm_gernet_encoders)

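Once registered, the new names resolve through the package's encoder factory. A quick sketch (assumes the existing get_encoder helper, which looks names up in this encoders dict):

from segmentation_models_pytorch.encoders import get_encoder

# weights="imagenet" would fetch the checkpoint URL registered in the new module below
encoder = get_encoder("timm-mobilenetv3_large_075", weights=None)
print(encoder.out_channels)  # (3, 16, 24, 32, 88, 720): width_mult=0.75, rounded by make_divisible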

segmentation_models_pytorch/encoders/timm_mobilenetv3.py (new file, +164)
@@ -0,0 +1,164 @@
from timm import create_model
import torch.nn as nn
from ._base import EncoderMixin


def make_divisible(x, divisible_by=8):
    import numpy as np
    return int(np.ceil(x * 1. / divisible_by) * divisible_by)
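
# Worked example (illustrative comment, not part of the committed diff):
#   make_divisible(16 * 0.75)  -> 16   since ceil(12 / 8) * 8 = 16
#   make_divisible(576 * 0.75) -> 432  since 432 is already a multiple of 8
# i.e. channel counts are rounded up to the nearest multiple of 8 after the
# width multiplier is applied, matching MobileNetV3's channel rounding rule.
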
class MobileNetV3Encoder(nn.Module, EncoderMixin):
    def __init__(self, model, width_mult, depth=5, **kwargs):
        super().__init__()
        self._depth = depth
        if 'small' in str(model):
            self.mode = 'small'
            self._out_channels = (16*width_mult, 16*width_mult, 24*width_mult, 48*width_mult, 576*width_mult)
            self._out_channels = tuple(map(make_divisible, self._out_channels))
        elif 'large' in str(model):
            self.mode = 'large'
            self._out_channels = (16*width_mult, 24*width_mult, 40*width_mult, 112*width_mult, 960*width_mult)
            self._out_channels = tuple(map(make_divisible, self._out_channels))
        else:
            raise ValueError(
                'MobileNetV3 mode should be small or large, got {}'.format(model))
        self._out_channels = (3,) + self._out_channels
        self._in_channels = 3
        # minimal models replace hardswish with relu
        model = create_model(model_name=model,
                             scriptable=True,   # torch.jit scriptable
                             exportable=True,   # onnx export
                             features_only=True)
        self.conv_stem = model.conv_stem
        self.bn1 = model.bn1
        self.act1 = model.act1
        self.blocks = model.blocks
    def get_stages(self):
        if self.mode == 'small':
            return [
                nn.Identity(),
                nn.Sequential(self.conv_stem, self.bn1, self.act1),
                self.blocks[0],
                self.blocks[1],
                self.blocks[2:4],
                self.blocks[4:],
            ]
        elif self.mode == 'large':
            return [
                nn.Identity(),
                nn.Sequential(self.conv_stem, self.bn1, self.act1, self.blocks[0]),
                self.blocks[1],
                self.blocks[2],
                self.blocks[3:5],
                self.blocks[5:],
            ]
        else:
            # original code constructed the ValueError without raising it
            raise ValueError('MobileNetV3 mode should be small or large, got {}'.format(self.mode))
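
    # Editorial note, not in the original diff: the six stages above emit
    # features at strides 1, 2, 4, 8, 16 and 32 relative to the input,
    # which is the resolution ladder SMP decoders expect for depth=5.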
    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features
    def load_state_dict(self, state_dict, **kwargs):
        # drop the classification head present in timm checkpoints,
        # which this trimmed-down encoder does not keep
        state_dict.pop('conv_head.weight')
        state_dict.pop('conv_head.bias')
        state_dict.pop('classifier.weight')
        state_dict.pop('classifier.bias')
        super().load_state_dict(state_dict, **kwargs)
mobilenetv3_weights = {
    'tf_mobilenetv3_large_075': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth'
    },
    'tf_mobilenetv3_large_100': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth'
    },
    'tf_mobilenetv3_large_minimal_100': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth'
    },
    'tf_mobilenetv3_small_075': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth'
    },
    'tf_mobilenetv3_small_100': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth'
    },
    'tf_mobilenetv3_small_minimal_100': {
        'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth'
    },
}
pretrained_settings = {}
for model_name, sources in mobilenetv3_weights.items():
    pretrained_settings[model_name] = {}
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'input_space': 'RGB',
        }
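
# Editorial note: the mean/std values above are the standard ImageNet
# statistics consumed by the library's preprocessing helper, e.g.:
#   from segmentation_models_pytorch.encoders import get_preprocessing_fn
#   preprocess = get_preprocessing_fn("timm-mobilenetv3_large_100", pretrained="imagenet")
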
timm_mobilenetv3_encoders = {
    'timm-mobilenetv3_large_075': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_075'],
        'params': {
            'model': 'tf_mobilenetv3_large_075',
            'width_mult': 0.75
        }
    },
    'timm-mobilenetv3_large_100': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_100'],
        'params': {
            'model': 'tf_mobilenetv3_large_100',
            'width_mult': 1.0
        }
    },
    'timm-mobilenetv3_large_minimal_100': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_large_minimal_100'],
        'params': {
            'model': 'tf_mobilenetv3_large_minimal_100',
            'width_mult': 1.0
        }
    },
    'timm-mobilenetv3_small_075': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_075'],
        'params': {
            'model': 'tf_mobilenetv3_small_075',
            'width_mult': 0.75
        }
    },
    'timm-mobilenetv3_small_100': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_100'],
        'params': {
            'model': 'tf_mobilenetv3_small_100',
            'width_mult': 1.0
        }
    },
    'timm-mobilenetv3_small_minimal_100': {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_mobilenetv3_small_minimal_100'],
        'params': {
            'model': 'tf_mobilenetv3_small_minimal_100',
            'width_mult': 1.0
        }
    },
}
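
A quick smoke test of the new encoder could look like this (a sketch, not part of the commit; it assumes the get_encoder factory and the out_channels property provided by EncoderMixin, and passes weights=None to skip the checkpoint download):

import torch
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("timm-mobilenetv3_small_100", weights=None)
x = torch.randn(2, 3, 224, 224)
features = encoder(x)  # one feature map per stage: identity, stem, four block groups

for feat, channels in zip(features, encoder.out_channels):
    assert feat.shape[1] == channels
print([tuple(feat.shape) for feat in features])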
