From 58a0a8fea30fdb8f9d7229d1cd97df809e5c19dd Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Sun, 15 Sep 2024 22:37:35 +0200 Subject: [PATCH 1/2] Modify Jaccard, Dice and Tversky losses --- .../losses/_functional.py | 43 +++++++++---------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/segmentation_models_pytorch/losses/_functional.py b/segmentation_models_pytorch/losses/_functional.py index 74301e6d..93d266de 100644 --- a/segmentation_models_pytorch/losses/_functional.py +++ b/segmentation_models_pytorch/losses/_functional.py @@ -157,15 +157,7 @@ def soft_jaccard_score( dims=None, ) -> torch.Tensor: assert output.size() == target.size() - if dims is not None: - intersection = torch.sum(output * target, dim=dims) - cardinality = torch.sum(output + target, dim=dims) - else: - intersection = torch.sum(output * target) - cardinality = torch.sum(output + target) - - union = cardinality - intersection - jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps) + jaccard_score = soft_tversky_score(output, target, 1.0, 1.0, smooth, eps, dims) return jaccard_score @@ -177,13 +169,7 @@ def soft_dice_score( dims=None, ) -> torch.Tensor: assert output.size() == target.size() - if dims is not None: - intersection = torch.sum(output * target, dim=dims) - cardinality = torch.sum(output + target, dim=dims) - else: - intersection = torch.sum(output * target) - cardinality = torch.sum(output + target) - dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps) + dice_score = soft_tversky_score(output, target, 0.5, 0.5, smooth, eps, dims) return dice_score @@ -196,15 +182,28 @@ def soft_tversky_score( eps: float = 1e-7, dims=None, ) -> torch.Tensor: + """Tversky loss + + References: + https://arxiv.org/pdf/2302.05666 + https://arxiv.org/pdf/2303.16296 + + """ assert output.size() == target.size() if dims is not None: - intersection = torch.sum(output * target, dim=dims) # TP - fp = torch.sum(output * (1.0 - target), dim=dims) - fn = torch.sum((1 - output) * target, dim=dims) + difference = torch.norm(output - target, p=1, dim=dims) + output_sum = torch.sum(output, dim=dims) + target_sum = torch.sum(target, dim=dims) + intersection = (output_sum + target_sum - difference) / 2 # TP + fp = output_sum - intersection + fn = target_sum - intersection else: - intersection = torch.sum(output * target) # TP - fp = torch.sum(output * (1.0 - target)) - fn = torch.sum((1 - output) * target) + difference = torch.norm(output - target, p=1) + output_sum = torch.sum(output) + target_sum = torch.sum(target) + intersection = (output_sum + target_sum - difference) / 2 # TP + fp = output_sum - intersection + fn = target_sum - intersection tversky_score = (intersection + smooth) / ( intersection + alpha * fp + beta * fn + smooth From c4cbd1eca41ca8c59831a603657010ec40daa75f Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Wed, 18 Sep 2024 19:41:26 +0200 Subject: [PATCH 2/2] Modify the Tversky loss --- .../losses/_functional.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/segmentation_models_pytorch/losses/_functional.py b/segmentation_models_pytorch/losses/_functional.py index 93d266de..a26f3f48 100644 --- a/segmentation_models_pytorch/losses/_functional.py +++ b/segmentation_models_pytorch/losses/_functional.py @@ -4,6 +4,7 @@ from typing import Optional import torch +import torch.linalg as LA import torch.nn.functional as F __all__ = [ @@ -190,20 +191,14 @@ def soft_tversky_score( """ assert output.size() == target.size() - if dims is not None: - difference = torch.norm(output - target, p=1, dim=dims) - output_sum = torch.sum(output, dim=dims) - target_sum = torch.sum(target, dim=dims) - intersection = (output_sum + target_sum - difference) / 2 # TP - fp = output_sum - intersection - fn = target_sum - intersection - else: - difference = torch.norm(output - target, p=1) - output_sum = torch.sum(output) - target_sum = torch.sum(target) - intersection = (output_sum + target_sum - difference) / 2 # TP - fp = output_sum - intersection - fn = target_sum - intersection + + output_sum = torch.sum(output, dim=dims) + target_sum = torch.sum(target, dim=dims) + difference = LA.vector_norm(output - target, ord=1, dim=dims) + + intersection = (output_sum + target_sum - difference) / 2 # TP + fp = output_sum - intersection + fn = target_sum - intersection tversky_score = (intersection + smooth) / ( intersection + alpha * fp + beta * fn + smooth