diff --git a/segmentation_models_pytorch/losses/_functional.py b/segmentation_models_pytorch/losses/_functional.py index 74301e6d..a26f3f48 100644 --- a/segmentation_models_pytorch/losses/_functional.py +++ b/segmentation_models_pytorch/losses/_functional.py @@ -4,6 +4,7 @@ from typing import Optional import torch +import torch.linalg as LA import torch.nn.functional as F __all__ = [ @@ -157,15 +158,7 @@ def soft_jaccard_score( dims=None, ) -> torch.Tensor: assert output.size() == target.size() - if dims is not None: - intersection = torch.sum(output * target, dim=dims) - cardinality = torch.sum(output + target, dim=dims) - else: - intersection = torch.sum(output * target) - cardinality = torch.sum(output + target) - - union = cardinality - intersection - jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps) + jaccard_score = soft_tversky_score(output, target, 1.0, 1.0, smooth, eps, dims) return jaccard_score @@ -177,13 +170,7 @@ def soft_dice_score( dims=None, ) -> torch.Tensor: assert output.size() == target.size() - if dims is not None: - intersection = torch.sum(output * target, dim=dims) - cardinality = torch.sum(output + target, dim=dims) - else: - intersection = torch.sum(output * target) - cardinality = torch.sum(output + target) - dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps) + dice_score = soft_tversky_score(output, target, 0.5, 0.5, smooth, eps, dims) return dice_score @@ -196,15 +183,22 @@ def soft_tversky_score( eps: float = 1e-7, dims=None, ) -> torch.Tensor: + """Tversky loss + + References: + https://arxiv.org/pdf/2302.05666 + https://arxiv.org/pdf/2303.16296 + + """ assert output.size() == target.size() - if dims is not None: - intersection = torch.sum(output * target, dim=dims) # TP - fp = torch.sum(output * (1.0 - target), dim=dims) - fn = torch.sum((1 - output) * target, dim=dims) - else: - intersection = torch.sum(output * target) # TP - fp = torch.sum(output * (1.0 - target)) - fn = torch.sum((1 - output) * target) + + output_sum = torch.sum(output, dim=dims) + target_sum = torch.sum(target, dim=dims) + difference = LA.vector_norm(output - target, ord=1, dim=dims) + + intersection = (output_sum + target_sum - difference) / 2 # TP + fp = output_sum - intersection + fn = target_sum - intersection tversky_score = (intersection + smooth) / ( intersection + alpha * fp + beta * fn + smooth