Skip to content

Commit 0e2da9e

Browse files
authored
Merge pull request #937 from qubvel-org/fix-dims-none
Fix dims=None in loss
2 parents 9e07716 + ebe54a4 commit 0e2da9e

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

requirements/minimum.old

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ pillow==8.0.0
55
pretrainedmodels==0.7.1
66
six==1.5.0
77
timm==0.9.0
8-
torch==1.8.0
9-
torchvision==0.9.0
8+
torch==1.9.0
9+
torchvision==0.10.0
1010
tqdm==4.42.1

segmentation_models_pytorch/losses/_functional.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,14 @@ def soft_tversky_score(
192192
"""
193193
assert output.size() == target.size()
194194

195-
output_sum = torch.sum(output, dim=dims)
196-
target_sum = torch.sum(target, dim=dims)
197-
difference = LA.vector_norm(output - target, ord=1, dim=dims)
195+
if dims is not None:
196+
output_sum = torch.sum(output, dim=dims)
197+
target_sum = torch.sum(target, dim=dims)
198+
difference = LA.vector_norm(output - target, ord=1, dim=dims)
199+
else:
200+
output_sum = torch.sum(output)
201+
target_sum = torch.sum(target)
202+
difference = LA.vector_norm(output - target, ord=1)
198203

199204
intersection = (output_sum + target_sum - difference) / 2 # TP
200205
fp = output_sum - intersection

0 commit comments

Comments
 (0)