Skip to content

Commit cd5d3c2

Browse files
Expose timm constructor arguments (#960)
* Expose timm constructor arguments * Remove leak from other branch * Rename dupls to duplicates
1 parent b90b3c5 commit cd5d3c2

File tree

10 files changed

+97
-36
lines changed

10 files changed

+97
-36
lines changed

segmentation_models_pytorch/decoders/deeplabv3/model.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional
1+
from typing import Any, Optional
22

33
from segmentation_models_pytorch.base import (
4-
SegmentationModel,
5-
SegmentationHead,
64
ClassificationHead,
5+
SegmentationHead,
6+
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
910
from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
1011

1112

@@ -36,6 +37,8 @@ class DeepLabV3(SegmentationModel):
3637
- dropout (float): Dropout factor in [0, 1)
3738
- activation (str): An activation function to apply "sigmoid"/"softmax"
3839
(could be **None** to return logits)
40+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
41+
3942
Returns:
4043
``torch.nn.Module``: **DeepLabV3**
4144
@@ -55,6 +58,7 @@ def __init__(
5558
activation: Optional[str] = None,
5659
upsampling: int = 8,
5760
aux_params: Optional[dict] = None,
61+
**kwargs: dict[str, Any],
5862
):
5963
super().__init__()
6064

@@ -64,6 +68,7 @@ def __init__(
6468
depth=encoder_depth,
6569
weights=encoder_weights,
6670
output_stride=8,
71+
**kwargs,
6772
)
6873

6974
self.decoder = DeepLabV3Decoder(
@@ -116,6 +121,8 @@ class DeepLabV3Plus(SegmentationModel):
116121
- dropout (float): Dropout factor in [0, 1)
117122
- activation (str): An activation function to apply "sigmoid"/"softmax"
118123
(could be **None** to return logits)
124+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
125+
119126
Returns:
120127
``torch.nn.Module``: **DeepLabV3Plus**
121128
@@ -137,6 +144,7 @@ def __init__(
137144
activation: Optional[str] = None,
138145
upsampling: int = 4,
139146
aux_params: Optional[dict] = None,
147+
**kwargs: dict[str, Any],
140148
):
141149
super().__init__()
142150

@@ -153,6 +161,7 @@ def __init__(
153161
depth=encoder_depth,
154162
weights=encoder_weights,
155163
output_stride=encoder_output_stride,
164+
**kwargs,
156165
)
157166

158167
self.decoder = DeepLabV3PlusDecoder(

segmentation_models_pytorch/decoders/fpn/model.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional
1+
from typing import Any, Optional
22

33
from segmentation_models_pytorch.base import (
4-
SegmentationModel,
5-
SegmentationHead,
64
ClassificationHead,
5+
SegmentationHead,
6+
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
910
from .decoder import FPNDecoder
1011

1112

@@ -40,6 +41,7 @@ class FPN(SegmentationModel):
4041
- dropout (float): Dropout factor in [0, 1)
4142
- activation (str): An activation function to apply "sigmoid"/"softmax"
4243
(could be **None** to return logits)
44+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
4345
4446
Returns:
4547
``torch.nn.Module``: **FPN**
@@ -63,6 +65,7 @@ def __init__(
6365
activation: Optional[str] = None,
6466
upsampling: int = 4,
6567
aux_params: Optional[dict] = None,
68+
**kwargs: dict[str, Any],
6669
):
6770
super().__init__()
6871

@@ -77,6 +80,7 @@ def __init__(
7780
in_channels=in_channels,
7881
depth=encoder_depth,
7982
weights=encoder_weights,
83+
**kwargs,
8084
)
8185

8286
self.decoder = FPNDecoder(

segmentation_models_pytorch/decoders/linknet/model.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional, Union
1+
from typing import Any, Optional, Union
22

33
from segmentation_models_pytorch.base import (
4+
ClassificationHead,
45
SegmentationHead,
56
SegmentationModel,
6-
ClassificationHead,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
910
from .decoder import LinknetDecoder
1011

1112

@@ -43,6 +44,7 @@ class Linknet(SegmentationModel):
4344
- dropout (float): Dropout factor in [0, 1)
4445
- activation (str): An activation function to apply "sigmoid"/"softmax"
4546
(could be **None** to return logits)
47+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
4648
4749
Returns:
4850
``torch.nn.Module``: **Linknet**
@@ -61,6 +63,7 @@ def __init__(
6163
classes: int = 1,
6264
activation: Optional[Union[str, callable]] = None,
6365
aux_params: Optional[dict] = None,
66+
**kwargs: dict[str, Any],
6467
):
6568
super().__init__()
6669

@@ -74,6 +77,7 @@ def __init__(
7477
in_channels=in_channels,
7578
depth=encoder_depth,
7679
weights=encoder_weights,
80+
**kwargs,
7781
)
7882

7983
self.decoder = LinknetDecoder(

segmentation_models_pytorch/decoders/manet/model.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional, Union, List
1+
from typing import Any, List, Optional, Union
22

3-
from segmentation_models_pytorch.encoders import get_encoder
43
from segmentation_models_pytorch.base import (
5-
SegmentationModel,
6-
SegmentationHead,
74
ClassificationHead,
5+
SegmentationHead,
6+
SegmentationModel,
87
)
8+
from segmentation_models_pytorch.encoders import get_encoder
9+
910
from .decoder import MAnetDecoder
1011

1112

@@ -45,6 +46,7 @@ class MAnet(SegmentationModel):
4546
- dropout (float): Dropout factor in [0, 1)
4647
- activation (str): An activation function to apply "sigmoid"/"softmax"
4748
(could be **None** to return logits)
49+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
4850
4951
Returns:
5052
``torch.nn.Module``: **MAnet**
@@ -66,6 +68,7 @@ def __init__(
6668
classes: int = 1,
6769
activation: Optional[Union[str, callable]] = None,
6870
aux_params: Optional[dict] = None,
71+
**kwargs: dict[str, Any],
6972
):
7073
super().__init__()
7174

@@ -74,6 +77,7 @@ def __init__(
7477
in_channels=in_channels,
7578
depth=encoder_depth,
7679
weights=encoder_weights,
80+
**kwargs,
7781
)
7882

7983
self.decoder = MAnetDecoder(

segmentation_models_pytorch/decoders/pan/model.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional, Union
1+
from typing import Any, Optional, Union
22

3-
from segmentation_models_pytorch.encoders import get_encoder
43
from segmentation_models_pytorch.base import (
5-
SegmentationModel,
6-
SegmentationHead,
74
ClassificationHead,
5+
SegmentationHead,
6+
SegmentationModel,
87
)
8+
from segmentation_models_pytorch.encoders import get_encoder
9+
910
from .decoder import PANDecoder
1011

1112

@@ -38,6 +39,7 @@ class PAN(SegmentationModel):
3839
- dropout (float): Dropout factor in [0, 1)
3940
- activation (str): An activation function to apply "sigmoid"/"softmax"
4041
(could be **None** to return logits)
42+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
4143
4244
Returns:
4345
``torch.nn.Module``: **PAN**
@@ -58,6 +60,7 @@ def __init__(
5860
activation: Optional[Union[str, callable]] = None,
5961
upsampling: int = 4,
6062
aux_params: Optional[dict] = None,
63+
**kwargs: dict[str, Any],
6164
):
6265
super().__init__()
6366

@@ -74,6 +77,7 @@ def __init__(
7477
depth=5,
7578
weights=encoder_weights,
7679
output_stride=encoder_output_stride,
80+
**kwargs,
7781
)
7882

7983
self.decoder = PANDecoder(

segmentation_models_pytorch/decoders/pspnet/model.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional, Union
1+
from typing import Any, Optional, Union
22

3-
from segmentation_models_pytorch.encoders import get_encoder
43
from segmentation_models_pytorch.base import (
5-
SegmentationModel,
6-
SegmentationHead,
74
ClassificationHead,
5+
SegmentationHead,
6+
SegmentationModel,
87
)
8+
from segmentation_models_pytorch.encoders import get_encoder
9+
910
from .decoder import PSPDecoder
1011

1112

@@ -44,6 +45,7 @@ class PSPNet(SegmentationModel):
4445
- dropout (float): Dropout factor in [0, 1)
4546
- activation (str): An activation function to apply "sigmoid"/"softmax"
4647
(could be **None** to return logits)
48+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
4749
4850
Returns:
4951
``torch.nn.Module``: **PSPNet**
@@ -65,6 +67,7 @@ def __init__(
6567
activation: Optional[Union[str, callable]] = None,
6668
upsampling: int = 8,
6769
aux_params: Optional[dict] = None,
70+
**kwargs: dict[str, Any],
6871
):
6972
super().__init__()
7073

@@ -73,6 +76,7 @@ def __init__(
7376
in_channels=in_channels,
7477
depth=encoder_depth,
7578
weights=encoder_weights,
79+
**kwargs,
7680
)
7781

7882
self.decoder = PSPDecoder(

segmentation_models_pytorch/decoders/unet/model.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional, Union, List
1+
from typing import Any, List, Optional, Union
22

3-
from segmentation_models_pytorch.encoders import get_encoder
43
from segmentation_models_pytorch.base import (
5-
SegmentationModel,
6-
SegmentationHead,
74
ClassificationHead,
5+
SegmentationHead,
6+
SegmentationModel,
87
)
8+
from segmentation_models_pytorch.encoders import get_encoder
9+
910
from .decoder import UnetDecoder
1011

1112

@@ -44,6 +45,7 @@ class Unet(SegmentationModel):
4445
- dropout (float): Dropout factor in [0, 1)
4546
- activation (str): An activation function to apply "sigmoid"/"softmax"
4647
(could be **None** to return logits)
48+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
4749
4850
Returns:
4951
``torch.nn.Module``: Unet
@@ -65,6 +67,7 @@ def __init__(
6567
classes: int = 1,
6668
activation: Optional[Union[str, callable]] = None,
6769
aux_params: Optional[dict] = None,
70+
**kwargs: dict[str, Any],
6871
):
6972
super().__init__()
7073

@@ -73,6 +76,7 @@ def __init__(
7376
in_channels=in_channels,
7477
depth=encoder_depth,
7578
weights=encoder_weights,
79+
**kwargs,
7680
)
7781

7882
self.decoder = UnetDecoder(

segmentation_models_pytorch/decoders/unetplusplus/model.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional, Union, List
1+
from typing import Any, List, Optional, Union
22

3-
from segmentation_models_pytorch.encoders import get_encoder
43
from segmentation_models_pytorch.base import (
5-
SegmentationModel,
6-
SegmentationHead,
74
ClassificationHead,
5+
SegmentationHead,
6+
SegmentationModel,
87
)
8+
from segmentation_models_pytorch.encoders import get_encoder
9+
910
from .decoder import UnetPlusPlusDecoder
1011

1112

@@ -44,6 +45,7 @@ class UnetPlusPlus(SegmentationModel):
4445
- dropout (float): Dropout factor in [0, 1)
4546
- activation (str): An activation function to apply "sigmoid"/"softmax"
4647
(could be **None** to return logits)
48+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
4749
4850
Returns:
4951
``torch.nn.Module``: **Unet++**
@@ -65,6 +67,7 @@ def __init__(
6567
classes: int = 1,
6668
activation: Optional[Union[str, callable]] = None,
6769
aux_params: Optional[dict] = None,
70+
**kwargs: dict[str, Any],
6871
):
6972
super().__init__()
7073

@@ -78,6 +81,7 @@ def __init__(
7881
in_channels=in_channels,
7982
depth=encoder_depth,
8083
weights=encoder_weights,
84+
**kwargs,
8185
)
8286

8387
self.decoder = UnetPlusPlusDecoder(

segmentation_models_pytorch/decoders/upernet/model.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Optional, Union
1+
from typing import Any, Optional, Union
22

3-
from segmentation_models_pytorch.encoders import get_encoder
43
from segmentation_models_pytorch.base import (
5-
SegmentationModel,
6-
SegmentationHead,
74
ClassificationHead,
5+
SegmentationHead,
6+
SegmentationModel,
87
)
8+
from segmentation_models_pytorch.encoders import get_encoder
9+
910
from .decoder import UPerNetDecoder
1011

1112

@@ -36,6 +37,7 @@ class UPerNet(SegmentationModel):
3637
- dropout (float): Dropout factor in [0, 1)
3738
- activation (str): An activation function to apply "sigmoid"/"softmax"
3839
(could be **None** to return logits)
40+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
3941
4042
Returns:
4143
``torch.nn.Module``: **UPerNet**
@@ -56,6 +58,7 @@ def __init__(
5658
classes: int = 1,
5759
activation: Optional[Union[str, callable]] = None,
5860
aux_params: Optional[dict] = None,
61+
**kwargs: dict[str, Any],
5962
):
6063
super().__init__()
6164

@@ -64,6 +67,7 @@ def __init__(
6467
in_channels=in_channels,
6568
depth=encoder_depth,
6669
weights=encoder_weights,
70+
**kwargs,
6771
)
6872

6973
self.decoder = UPerNetDecoder(

0 commit comments

Comments
 (0)