From 351122122738b8e0cd69fd74d61c25842773b9c5 Mon Sep 17 00:00:00 2001 From: Alexander Yaroshevich Date: Thu, 11 Feb 2021 13:18:12 +0300 Subject: [PATCH 1/7] gernet from regnet --- requirements.txt | 2 +- .../encoders/timm_gernet.py | 332 ++++++++++++++++++ 2 files changed, 333 insertions(+), 1 deletion(-) create mode 100644 segmentation_models_pytorch/encoders/timm_gernet.py diff --git a/requirements.txt b/requirements.txt index a88a7a87..c1bbde72 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ torchvision>=0.3.0 pretrainedmodels==0.7.4 efficientnet-pytorch==0.6.3 -timm==0.3.2 +git+https://github.com/rwightman/pytorch-image-models@d8e69206be253892b2956341fea09fdebfaae4e3 diff --git a/segmentation_models_pytorch/encoders/timm_gernet.py b/segmentation_models_pytorch/encoders/timm_gernet.py new file mode 100644 index 00000000..e02ad59b --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_gernet.py @@ -0,0 +1,332 @@ +from ._base import EncoderMixin +from timm.models.regnet import RegNet +import torch.nn as nn + + +class RegNetEncoder(RegNet, EncoderMixin): + def __init__(self, out_channels, depth=5, **kwargs): + super().__init__(**kwargs) + self._depth = depth + self._out_channels = out_channels + self._in_channels = 3 + + del self.head + + def get_stages(self): + return [ + nn.Identity(), + self.stem, + self.s1, + self.s2, + self.s3, + self.s4, + ] + + def forward(self, x): + stages = self.get_stages() + + features = [] + for i in range(self._depth + 1): + x = stages[i](x) + features.append(x) + + return features + + def load_state_dict(self, state_dict, **kwargs): + state_dict.pop("head.fc.weight") + state_dict.pop("head.fc.bias") + super().load_state_dict(state_dict, **kwargs) + + +regnet_weights = { + 'timm-regnetx_002': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth', + }, + 'timm-regnetx_004': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth', + }, + 'timm-regnetx_006': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth', + }, + 'timm-regnetx_008': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth', + }, + 'timm-regnetx_016': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth', + }, + 'timm-regnetx_032': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth', + }, + 'timm-regnetx_040': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth', + }, + 'timm-regnetx_064': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth', + }, + 'timm-regnetx_080': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth', + }, + 'timm-regnetx_120': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth', + }, + 'timm-regnetx_160': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth', + }, + 'timm-regnetx_320': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth', + }, + 'timm-regnety_002': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth', + }, + 'timm-regnety_004': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth', + }, + 'timm-regnety_006': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth', + }, + 'timm-regnety_008': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth', + }, + 'timm-regnety_016': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth', + }, + 'timm-regnety_032': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth' + }, + 'timm-regnety_040': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth' + }, + 'timm-regnety_064': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth' + }, + 'timm-regnety_080': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth', + }, + 'timm-regnety_120': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth', + }, + 'timm-regnety_160': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth', + }, + 'timm-regnety_320': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth' + } +} + +pretrained_settings = {} +for model_name, sources in regnet_weights.items(): + pretrained_settings[model_name] = {} + for source_name, source_url in sources.items(): + pretrained_settings[model_name][source_name] = { + "url": source_url, + 'input_size': [3, 224, 224], + 'input_range': [0, 1], + 'mean': [0.485, 0.456, 0.406], + 'std': [0.229, 0.224, 0.225], + 'num_classes': 1000 + } + +# at this point I am too lazy to copy configs, so I just used the same configs from timm's repo + + +def _mcfg(**kwargs): + cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32) + cfg.update(**kwargs) + return cfg + + +timm_regnet_encoders = { + 'timm-regnetx_002': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_002"], + 'params': { + 'out_channels': (3, 32, 24, 56, 152, 368), + 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13) + }, + }, + 'timm-regnetx_004': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_004"], + 'params': { + 'out_channels': (3, 32, 32, 64, 160, 384), + 'cfg': _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22) + }, + }, + 'timm-regnetx_006': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_006"], + 'params': { + 'out_channels': (3, 32, 48, 96, 240, 528), + 'cfg': _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16) + }, + }, + 'timm-regnetx_008': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_008"], + 'params': { + 'out_channels': (3, 32, 64, 128, 288, 672), + 'cfg': _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16) + }, + }, + 'timm-regnetx_016': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_016"], + 'params': { + 'out_channels': (3, 32, 72, 168, 408, 912), + 'cfg': _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18) + }, + }, + 'timm-regnetx_032': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_032"], + 'params': { + 'out_channels': (3, 32, 96, 192, 432, 1008), + 'cfg': _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25) + }, + }, + 'timm-regnetx_040': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_040"], + 'params': { + 'out_channels': (3, 32, 80, 240, 560, 1360), + 'cfg': _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23) + }, + }, + 'timm-regnetx_064': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_064"], + 'params': { + 'out_channels': (3, 32, 168, 392, 784, 1624), + 'cfg': _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17) + }, + }, + 'timm-regnetx_080': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_080"], + 'params': { + 'out_channels': (3, 32, 80, 240, 720, 1920), + 'cfg': _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23) + }, + }, + 'timm-regnetx_120': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_120"], + 'params': { + 'out_channels': (3, 32, 224, 448, 896, 2240), + 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19) + }, + }, + 'timm-regnetx_160': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_160"], + 'params': { + 'out_channels': (3, 32, 256, 512, 896, 2048), + 'cfg': _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22) + }, + }, + 'timm-regnetx_320': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnetx_320"], + 'params': { + 'out_channels': (3, 32, 336, 672, 1344, 2520), + 'cfg': _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23) + }, + }, + #regnety + 'timm-regnety_002': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_002"], + 'params': { + 'out_channels': (3, 32, 24, 56, 152, 368), + 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25) + }, + }, + 'timm-regnety_004': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_004"], + 'params': { + 'out_channels': (3, 32, 48, 104, 208, 440), + 'cfg': _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25) + }, + }, + 'timm-regnety_006': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_006"], + 'params': { + 'out_channels': (3, 32, 48, 112, 256, 608), + 'cfg': _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25) + }, + }, + 'timm-regnety_008': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_008"], + 'params': { + 'out_channels': (3, 32, 64, 128, 320, 768), + 'cfg': _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25) + }, + }, + 'timm-regnety_016': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_016"], + 'params': { + 'out_channels': (3, 32, 48, 120, 336, 888), + 'cfg': _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25) + }, + }, + 'timm-regnety_032': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_032"], + 'params': { + 'out_channels': (3, 32, 72, 216, 576, 1512), + 'cfg': _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25) + }, + }, + 'timm-regnety_040': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_040"], + 'params': { + 'out_channels': (3, 32, 128, 192, 512, 1088), + 'cfg': _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25) + }, + }, + 'timm-regnety_064': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_064"], + 'params': { + 'out_channels': (3, 32, 144, 288, 576, 1296), + 'cfg': _mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25) + }, + }, + 'timm-regnety_080': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_080"], + 'params': { + 'out_channels': (3, 32, 168, 448, 896, 2016), + 'cfg': _mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25) + }, + }, + 'timm-regnety_120': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_120"], + 'params': { + 'out_channels': (3, 32, 224, 448, 896, 2240), + 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25) + }, + }, + 'timm-regnety_160': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_160"], + 'params': { + 'out_channels': (3, 32, 224, 448, 1232, 3024), + 'cfg': _mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25) + }, + }, + 'timm-regnety_320': { + 'encoder': RegNetEncoder, + "pretrained_settings": pretrained_settings["timm-regnety_320"], + 'params': { + 'out_channels': (3, 32, 232, 696, 1392, 3712), + 'cfg': _mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25) + }, + }, +} From 099c08fb35d2b863d6e4cb258efb48ae3d7162b3 Mon Sep 17 00:00:00 2001 From: Alexander Yaroshevich Date: Thu, 11 Feb 2021 15:04:44 +0300 Subject: [PATCH 2/7] basic gernet --- .../encoders/__init__.py | 2 + .../encoders/timm_gernet.py | 344 ++++-------------- 2 files changed, 68 insertions(+), 278 deletions(-) diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index 5b153da8..3df33e11 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -16,6 +16,7 @@ from .timm_res2net import timm_res2net_encoders from .timm_regnet import timm_regnet_encoders from .timm_sknet import timm_sknet_encoders +from .timm_gernet import timm_gernet_encoders from ._preprocessing import preprocess_input encoders = {} @@ -34,6 +35,7 @@ encoders.update(timm_res2net_encoders) encoders.update(timm_regnet_encoders) encoders.update(timm_sknet_encoders) +encoders.update(timm_gernet_encoders) def get_encoder(name, in_channels=3, depth=5, weights=None): diff --git a/segmentation_models_pytorch/encoders/timm_gernet.py b/segmentation_models_pytorch/encoders/timm_gernet.py index e02ad59b..55f5c3ce 100644 --- a/segmentation_models_pytorch/encoders/timm_gernet.py +++ b/segmentation_models_pytorch/encoders/timm_gernet.py @@ -1,10 +1,11 @@ +from timm.models import ByobCfg, BlocksCfg, ByobNet + from ._base import EncoderMixin -from timm.models.regnet import RegNet import torch.nn as nn -class RegNetEncoder(RegNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): +class GERNetEncoder(ByobNet, EncoderMixin): + def __init__(self, out_channels, depth=6, **kwargs): super().__init__(**kwargs) self._depth = depth self._out_channels = out_channels @@ -16,10 +17,8 @@ def get_stages(self): return [ nn.Identity(), self.stem, - self.s1, - self.s2, - self.s3, - self.s4, + *self.stages[:-1], + nn.Sequential(self.stages[-1], self.final_conv) ] def forward(self, x): @@ -39,78 +38,15 @@ def load_state_dict(self, state_dict, **kwargs): regnet_weights = { - 'timm-regnetx_002': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth', - }, - 'timm-regnetx_004': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth', - }, - 'timm-regnetx_006': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth', - }, - 'timm-regnetx_008': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth', - }, - 'timm-regnetx_016': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth', - }, - 'timm-regnetx_032': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth', - }, - 'timm-regnetx_040': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth', - }, - 'timm-regnetx_064': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth', - }, - 'timm-regnetx_080': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth', - }, - 'timm-regnetx_120': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth', - }, - 'timm-regnetx_160': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth', - }, - 'timm-regnetx_320': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth', + 'timm-gernet_s': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth', }, - 'timm-regnety_002': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth', + 'timm-gernet_m': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth', }, - 'timm-regnety_004': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth', + 'timm-gernet_l': { + 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth', }, - 'timm-regnety_006': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth', - }, - 'timm-regnety_008': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth', - }, - 'timm-regnety_016': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth', - }, - 'timm-regnety_032': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth' - }, - 'timm-regnety_040': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth' - }, - 'timm-regnety_064': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth' - }, - 'timm-regnety_080': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth', - }, - 'timm-regnety_120': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth', - }, - 'timm-regnety_160': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth', - }, - 'timm-regnety_320': { - 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth' - } } pretrained_settings = {} @@ -119,214 +55,66 @@ def load_state_dict(self, state_dict, **kwargs): for source_name, source_url in sources.items(): pretrained_settings[model_name][source_name] = { "url": source_url, - 'input_size': [3, 224, 224], + 'input_size': [3, 224, 224] if not model_name == 'timm-gernet_l' else [3, 256, 256], 'input_range': [0, 1], 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'num_classes': 1000 } -# at this point I am too lazy to copy configs, so I just used the same configs from timm's repo - - -def _mcfg(**kwargs): - cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32) - cfg.update(**kwargs) - return cfg - - -timm_regnet_encoders = { - 'timm-regnetx_002': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_002"], - 'params': { - 'out_channels': (3, 32, 24, 56, 152, 368), - 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13) - }, - }, - 'timm-regnetx_004': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_004"], - 'params': { - 'out_channels': (3, 32, 32, 64, 160, 384), - 'cfg': _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22) - }, - }, - 'timm-regnetx_006': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_006"], - 'params': { - 'out_channels': (3, 32, 48, 96, 240, 528), - 'cfg': _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16) - }, - }, - 'timm-regnetx_008': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_008"], - 'params': { - 'out_channels': (3, 32, 64, 128, 288, 672), - 'cfg': _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16) - }, - }, - 'timm-regnetx_016': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_016"], - 'params': { - 'out_channels': (3, 32, 72, 168, 408, 912), - 'cfg': _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18) - }, - }, - 'timm-regnetx_032': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_032"], - 'params': { - 'out_channels': (3, 32, 96, 192, 432, 1008), - 'cfg': _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25) - }, - }, - 'timm-regnetx_040': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_040"], - 'params': { - 'out_channels': (3, 32, 80, 240, 560, 1360), - 'cfg': _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23) - }, - }, - 'timm-regnetx_064': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_064"], - 'params': { - 'out_channels': (3, 32, 168, 392, 784, 1624), - 'cfg': _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17) - }, - }, - 'timm-regnetx_080': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_080"], - 'params': { - 'out_channels': (3, 32, 80, 240, 720, 1920), - 'cfg': _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23) - }, - }, - 'timm-regnetx_120': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_120"], - 'params': { - 'out_channels': (3, 32, 224, 448, 896, 2240), - 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19) - }, - }, - 'timm-regnetx_160': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_160"], - 'params': { - 'out_channels': (3, 32, 256, 512, 896, 2048), - 'cfg': _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22) - }, - }, - 'timm-regnetx_320': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_320"], - 'params': { - 'out_channels': (3, 32, 336, 672, 1344, 2520), - 'cfg': _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23) - }, - }, - #regnety - 'timm-regnety_002': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_002"], - 'params': { - 'out_channels': (3, 32, 24, 56, 152, 368), - 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25) - }, - }, - 'timm-regnety_004': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_004"], - 'params': { - 'out_channels': (3, 32, 48, 104, 208, 440), - 'cfg': _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25) - }, - }, - 'timm-regnety_006': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_006"], - 'params': { - 'out_channels': (3, 32, 48, 112, 256, 608), - 'cfg': _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25) - }, - }, - 'timm-regnety_008': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_008"], - 'params': { - 'out_channels': (3, 32, 64, 128, 320, 768), - 'cfg': _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25) - }, - }, - 'timm-regnety_016': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_016"], - 'params': { - 'out_channels': (3, 32, 48, 120, 336, 888), - 'cfg': _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25) - }, - }, - 'timm-regnety_032': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_032"], - 'params': { - 'out_channels': (3, 32, 72, 216, 576, 1512), - 'cfg': _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25) - }, - }, - 'timm-regnety_040': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_040"], - 'params': { - 'out_channels': (3, 32, 128, 192, 512, 1088), - 'cfg': _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25) - }, - }, - 'timm-regnety_064': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_064"], - 'params': { - 'out_channels': (3, 32, 144, 288, 576, 1296), - 'cfg': _mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25) - }, - }, - 'timm-regnety_080': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_080"], - 'params': { - 'out_channels': (3, 32, 168, 448, 896, 2016), - 'cfg': _mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25) - }, - }, - 'timm-regnety_120': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_120"], - 'params': { - 'out_channels': (3, 32, 224, 448, 896, 2240), - 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25) - }, - }, - 'timm-regnety_160': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_160"], - 'params': { - 'out_channels': (3, 32, 224, 448, 1232, 3024), - 'cfg': _mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25) - }, - }, - 'timm-regnety_320': { - 'encoder': RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_320"], - 'params': { - 'out_channels': (3, 32, 232, 696, 1392, 3712), - 'cfg': _mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25) +timm_gernet_encoders = { + 'timm-gernet_s': { + 'encoder': GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_s"], + 'params': { + 'out_channels': (3, 13, 48, 48, 384, 560, 1920), + 'cfg': ByobCfg( + blocks=( + BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), + BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), + BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), + BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), + BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), + ), + stem_chs=13, + num_features=1920, + ) + }, + }, + 'timm-gernet_m': { + 'encoder': GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_m"], + 'params': { + 'out_channels': (3, 32, 128, 192, 640, 640, 2560), + 'cfg': ByobCfg( + blocks=( + BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), + BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), + ), + stem_chs=32, + num_features=2560, + ) + }, + }, + 'timm-gernet_l': { + 'encoder': GERNetEncoder, + "pretrained_settings": pretrained_settings["timm-gernet_l"], + 'params': { + 'out_channels': (3, 32, 128, 192, 640, 640, 2560), + 'cfg': ByobCfg( + blocks=( + BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), + BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), + ), + stem_chs=32, + num_features=2560, + ) }, }, } From cc42d1c8d341a90332fec39ba61af57c62a0a7d6 Mon Sep 17 00:00:00 2001 From: Alexander Yaroshevich Date: Thu, 11 Feb 2021 17:07:18 +0300 Subject: [PATCH 3/7] depth set to 5, and requirements+import update --- requirements.txt | 2 +- segmentation_models_pytorch/encoders/__init__.py | 8 +++++++- .../encoders/timm_gernet.py | 14 ++++++++------ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index c1bbde72..a88a7a87 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ torchvision>=0.3.0 pretrainedmodels==0.7.4 efficientnet-pytorch==0.6.3 -git+https://github.com/rwightman/pytorch-image-models@d8e69206be253892b2956341fea09fdebfaae4e3 +timm==0.3.2 diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index 3df33e11..2b192271 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -16,7 +16,13 @@ from .timm_res2net import timm_res2net_encoders from .timm_regnet import timm_regnet_encoders from .timm_sknet import timm_sknet_encoders -from .timm_gernet import timm_gernet_encoders +try: + from .timm_gernet import timm_gernet_encoders +except ImportError as e: + timm_gernet_encoders = {} + print("Current timm version doesn't support GERNet." + "If GERNet support is needed please update timm") + from ._preprocessing import preprocess_input encoders = {} diff --git a/segmentation_models_pytorch/encoders/timm_gernet.py b/segmentation_models_pytorch/encoders/timm_gernet.py index 55f5c3ce..39be994b 100644 --- a/segmentation_models_pytorch/encoders/timm_gernet.py +++ b/segmentation_models_pytorch/encoders/timm_gernet.py @@ -5,7 +5,7 @@ class GERNetEncoder(ByobNet, EncoderMixin): - def __init__(self, out_channels, depth=6, **kwargs): + def __init__(self, out_channels, depth=5, **kwargs): super().__init__(**kwargs) self._depth = depth self._out_channels = out_channels @@ -17,8 +17,10 @@ def get_stages(self): return [ nn.Identity(), self.stem, - *self.stages[:-1], - nn.Sequential(self.stages[-1], self.final_conv) + self.stages[0], + self.stages[1], + self.stages[2], + nn.Sequential(self.stages[3], self.stages[4], self.final_conv) ] def forward(self, x): @@ -67,7 +69,7 @@ def load_state_dict(self, state_dict, **kwargs): 'encoder': GERNetEncoder, "pretrained_settings": pretrained_settings["timm-gernet_s"], 'params': { - 'out_channels': (3, 13, 48, 48, 384, 560, 1920), + 'out_channels': (3, 13, 48, 48, 384, 1920), 'cfg': ByobCfg( blocks=( BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), @@ -85,7 +87,7 @@ def load_state_dict(self, state_dict, **kwargs): 'encoder': GERNetEncoder, "pretrained_settings": pretrained_settings["timm-gernet_m"], 'params': { - 'out_channels': (3, 32, 128, 192, 640, 640, 2560), + 'out_channels': (3, 32, 128, 192, 640, 2560), 'cfg': ByobCfg( blocks=( BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), @@ -103,7 +105,7 @@ def load_state_dict(self, state_dict, **kwargs): 'encoder': GERNetEncoder, "pretrained_settings": pretrained_settings["timm-gernet_l"], 'params': { - 'out_channels': (3, 32, 128, 192, 640, 640, 2560), + 'out_channels': (3, 32, 128, 192, 640, 2560), 'cfg': ByobCfg( blocks=( BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), From bc92a9fd92ec5e9e4d9622ac8176719dfd60f796 Mon Sep 17 00:00:00 2001 From: Alexander Yaroshevich Date: Thu, 11 Feb 2021 17:17:45 +0300 Subject: [PATCH 4/7] docs --- README.md | 15 ++++++++++++++- docs/encoders.rst | 13 +++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7093a43e..937536d5 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The main features of this library are: - High level API (just two lines to create a neural network) - 9 models architectures for binary and multi class segmentation (including legendary Unet) - - 104 available encoders + - 107 available encoders - All encoders have pre-trained weights for faster and better convergence ### [📚 Project Documentation 📚](http://smp.readthedocs.io/) @@ -188,6 +188,19 @@ The following is a list of supported encoders in the SMP. Select the appropriate +
+RegNet(x/y) +
+ +|Encoder |Weights |Params, M | +|--------------------------------|:------------------------------:|:------------------------------:| +|timm-gernet_s |imagenet |6M | +|timm-gernet_m |imagenet |18M | +|timm-gernet_l |imagenet |28M | + +
+
+
SE-Net
diff --git a/docs/encoders.rst b/docs/encoders.rst index 7c55373b..cfc2f9c1 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -136,6 +136,19 @@ RegNet(x/y) | timm-regnety\_320 | imagenet | 141M | +---------------------+------------+-------------+ +GERNet +~~~~~~ + ++-------------------------+------------+-------------+ +| Encoder | Weights | Params, M | ++=========================+============+=============+ +| timm-gernet\_s | imagenet | 6M | ++-------------------------+------------+-------------+ +| timm-gernet\_m | imagenet | 18M | ++-------------------------+------------+-------------+ +| timm-gernet\_l | imagenet | 28M | ++-------------------------+------------+-------------+ + SE-Net ~~~~~~ From 3b1961baf5347b2be8295d86b2a2bd7b50152d7d Mon Sep 17 00:00:00 2001 From: Alexander Yaroshevich Date: Thu, 11 Feb 2021 17:21:06 +0300 Subject: [PATCH 5/7] Fix summary error --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 937536d5..8086c755 100644 --- a/README.md +++ b/README.md @@ -189,7 +189,7 @@ The following is a list of supported encoders in the SMP. Select the appropriate
-RegNet(x/y) +GERNet
|Encoder |Weights |Params, M | From 97f811424c6b87c9a1c9807317a76c8ba9934040 Mon Sep 17 00:00:00 2001 From: Alexander Yaroshevich Date: Thu, 11 Feb 2021 17:23:04 +0300 Subject: [PATCH 6/7] remove input size --- segmentation_models_pytorch/encoders/timm_gernet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/segmentation_models_pytorch/encoders/timm_gernet.py b/segmentation_models_pytorch/encoders/timm_gernet.py index 39be994b..93cb94d1 100644 --- a/segmentation_models_pytorch/encoders/timm_gernet.py +++ b/segmentation_models_pytorch/encoders/timm_gernet.py @@ -57,7 +57,6 @@ def load_state_dict(self, state_dict, **kwargs): for source_name, source_url in sources.items(): pretrained_settings[model_name][source_name] = { "url": source_url, - 'input_size': [3, 224, 224] if not model_name == 'timm-gernet_l' else [3, 256, 256], 'input_range': [0, 1], 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], From e5e19dcfd51cff6760777be268acefe39238b8a4 Mon Sep 17 00:00:00 2001 From: Alexander Yaroshevich Date: Thu, 11 Feb 2021 20:11:06 +0300 Subject: [PATCH 7/7] manet fix and test with latest timm --- .github/workflows/tests.yml | 1 + segmentation_models_pytorch/manet/decoder.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index be4c4462..c4200863 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,6 +29,7 @@ jobs: python -m pip install codecov pytest mock pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html pip install . + pip install -U git+https://github.com/rwightman/pytorch-image-models - name: Test run: | python -m pytest -s tests diff --git a/segmentation_models_pytorch/manet/decoder.py b/segmentation_models_pytorch/manet/decoder.py index 2d587671..81822091 100644 --- a/segmentation_models_pytorch/manet/decoder.py +++ b/segmentation_models_pytorch/manet/decoder.py @@ -56,18 +56,19 @@ def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, use_batchnorm=use_batchnorm, ) ) + reduced_channels = max(1, skip_channels // reduction) self.SE_ll = nn.Sequential( nn.AdaptiveAvgPool2d(1), - nn.Conv2d(skip_channels, skip_channels // reduction, 1), + nn.Conv2d(skip_channels, reduced_channels, 1), nn.ReLU(inplace=True), - nn.Conv2d(skip_channels // reduction, skip_channels, 1), + nn.Conv2d(reduced_channels, skip_channels, 1), nn.Sigmoid(), ) self.SE_hl = nn.Sequential( nn.AdaptiveAvgPool2d(1), - nn.Conv2d(skip_channels, skip_channels // reduction, 1), + nn.Conv2d(skip_channels, reduced_channels, 1), nn.ReLU(inplace=True), - nn.Conv2d(skip_channels // reduction, skip_channels, 1), + nn.Conv2d(reduced_channels, skip_channels, 1), nn.Sigmoid(), ) self.conv1 = md.Conv2dReLU(