From e2a791f9ed4b141273d7cf69491136446031df5b Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 7 Oct 2024 10:16:28 +0100 Subject: [PATCH 1/3] Fix dims=None in loss --- segmentation_models_pytorch/losses/_functional.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/segmentation_models_pytorch/losses/_functional.py b/segmentation_models_pytorch/losses/_functional.py index a26f3f48..26928348 100644 --- a/segmentation_models_pytorch/losses/_functional.py +++ b/segmentation_models_pytorch/losses/_functional.py @@ -191,10 +191,15 @@ def soft_tversky_score( """ assert output.size() == target.size() - - output_sum = torch.sum(output, dim=dims) - target_sum = torch.sum(target, dim=dims) - difference = LA.vector_norm(output - target, ord=1, dim=dims) + + if dims is not None: + output_sum = torch.sum(output, dim=dims) + target_sum = torch.sum(target, dim=dims) + difference = LA.vector_norm(output - target, ord=1, dim=dims) + else: + output_sum = torch.sum(output) + target_sum = torch.sum(target) + difference = LA.vector_norm(output - target, ord=1) intersection = (output_sum + target_sum - difference) / 2 # TP fp = output_sum - intersection From d80dce4a995e40cf24b713ccc3be60d8bf8ad055 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 7 Oct 2024 10:21:58 +0100 Subject: [PATCH 2/3] Fixup --- segmentation_models_pytorch/losses/_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/losses/_functional.py b/segmentation_models_pytorch/losses/_functional.py index 26928348..791901f0 100644 --- a/segmentation_models_pytorch/losses/_functional.py +++ b/segmentation_models_pytorch/losses/_functional.py @@ -191,7 +191,7 @@ def soft_tversky_score( """ assert output.size() == target.size() - + if dims is not None: output_sum = torch.sum(output, dim=dims) target_sum = torch.sum(target, dim=dims) From ebe54a4cedf5c1e1fdfb04300a0c63b8fa4e7183 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 7 Oct 2024 10:31:33 +0100 Subject: [PATCH 3/3] Bump reqs --- requirements/minimum.old | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/minimum.old b/requirements/minimum.old index ee655eb0..3f687871 100644 --- a/requirements/minimum.old +++ b/requirements/minimum.old @@ -5,6 +5,6 @@ pillow==8.0.0 pretrainedmodels==0.7.1 six==1.5.0 timm==0.9.0 -torch==1.8.0 -torchvision==0.9.0 +torch==1.9.0 +torchvision==0.10.0 tqdm==4.42.1