diff --git a/requirements/minimum.old b/requirements/minimum.old index ee655eb0..3f687871 100644 --- a/requirements/minimum.old +++ b/requirements/minimum.old @@ -5,6 +5,6 @@ pillow==8.0.0 pretrainedmodels==0.7.1 six==1.5.0 timm==0.9.0 -torch==1.8.0 -torchvision==0.9.0 +torch==1.9.0 +torchvision==0.10.0 tqdm==4.42.1 diff --git a/segmentation_models_pytorch/losses/_functional.py b/segmentation_models_pytorch/losses/_functional.py index a26f3f48..791901f0 100644 --- a/segmentation_models_pytorch/losses/_functional.py +++ b/segmentation_models_pytorch/losses/_functional.py @@ -192,9 +192,14 @@ def soft_tversky_score( """ assert output.size() == target.size() - output_sum = torch.sum(output, dim=dims) - target_sum = torch.sum(target, dim=dims) - difference = LA.vector_norm(output - target, ord=1, dim=dims) + if dims is not None: + output_sum = torch.sum(output, dim=dims) + target_sum = torch.sum(target, dim=dims) + difference = LA.vector_norm(output - target, ord=1, dim=dims) + else: + output_sum = torch.sum(output) + target_sum = torch.sum(target) + difference = LA.vector_norm(output - target, ord=1) intersection = (output_sum + target_sum - difference) / 2 # TP fp = output_sum - intersection