Skip to content

Commit b87df91

Browse files
committed
more conversions to torch
1 parent 1cf59fa commit b87df91

4 files changed

+89
-71
lines changed

tests/ignite/metrics/regression/test_median_absolute_error.py

+15-21
Original file line numberDiff line numberDiff line change
@@ -41,33 +41,29 @@ def test_median_absolute_error(available_device):
4141
# Size of dataset will be odd for these tests
4242

4343
size = 51
44-
np_y_pred = np.random.rand(size)
45-
np_y = np.random.rand(size)
46-
np_median_absolute_error = np.median(np.abs(np_y - np_y_pred))
44+
y_pred = torch.rand(size)
45+
y = torch.rand(size)
46+
expected_median_absolute_error = torch.median(torch.abs(y - y_pred).cpu()).item()
4747

4848
m = MedianAbsoluteError(device=available_device)
4949
assert m._device == torch.device(available_device)
50-
y_pred = torch.from_numpy(np_y_pred)
51-
y = torch.from_numpy(np_y)
5250

5351
m.reset()
5452
m.update((y_pred, y))
5553

56-
assert np_median_absolute_error == pytest.approx(m.compute())
54+
assert expected_median_absolute_error == pytest.approx(m.compute())
5755

5856

5957
def test_median_absolute_error_2(available_device):
6058
np.random.seed(1)
6159
size = 105
62-
np_y_pred = np.random.rand(size, 1)
63-
np_y = np.random.rand(size, 1)
64-
np.random.shuffle(np_y)
65-
np_median_absolute_error = np.median(np.abs(np_y - np_y_pred))
60+
y_pred = torch.rand(size, 1)
61+
y = torch.rand(size, 1)
62+
y = y[torch.randperm(size)]
63+
expected_median_absolute_error = torch.median(torch.abs(y - y_pred).cpu()).item()
6664

6765
m = MedianAbsoluteError(device=available_device)
6866
assert m._device == torch.device(available_device)
69-
y_pred = torch.from_numpy(np_y_pred)
70-
y = torch.from_numpy(np_y)
7167

7268
m.reset()
7369
batch_size = 16
@@ -76,24 +72,22 @@ def test_median_absolute_error_2(available_device):
7672
idx = i * batch_size
7773
m.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
7874

79-
assert np_median_absolute_error == pytest.approx(m.compute())
75+
assert expected_median_absolute_error == pytest.approx(m.compute())
8076

8177

8278
def test_integration_median_absolute_error(available_device):
8379
np.random.seed(1)
8480
size = 105
85-
np_y_pred = np.random.rand(size, 1)
86-
np_y = np.random.rand(size, 1)
87-
np.random.shuffle(np_y)
88-
np_median_absolute_error = np.median(np.abs(np_y - np_y_pred))
81+
y_pred = torch.rand(size, 1)
82+
y = torch.rand(size, 1)
83+
y = y[torch.randperm(size)]
8984

85+
expected = torch.median(torch.abs(y - y_pred).cpu()).item()
9086
batch_size = 15
9187

9288
def update_fn(engine, batch):
9389
idx = (engine.state.iteration - 1) * batch_size
94-
y_true_batch = np_y[idx : idx + batch_size]
95-
y_pred_batch = np_y_pred[idx : idx + batch_size]
96-
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
90+
return y_pred[idx : idx + batch_size], y[idx : idx + batch_size]
9791

9892
engine = Engine(update_fn)
9993

@@ -104,7 +98,7 @@ def update_fn(engine, batch):
10498
data = list(range(size // batch_size))
10599
median_absolute_error = engine.run(data, max_epochs=1).metrics["median_absolute_error"]
106100

107-
assert np_median_absolute_error == pytest.approx(median_absolute_error)
101+
assert expected == pytest.approx(median_absolute_error)
108102

109103

110104
def _test_distrib_compute(device):

tests/ignite/metrics/regression/test_median_absolute_percentage_error.py

+28-21
Original file line numberDiff line numberDiff line change
@@ -41,33 +41,37 @@ def test_median_absolute_percentage_error(available_device):
4141
# Size of dataset will be odd for these tests
4242

4343
size = 51
44-
np_y_pred = np.random.rand(size)
45-
np_y = np.random.rand(size)
46-
np_median_absolute_percentage_error = 100.0 * np.median(np.abs(np_y - np_y_pred) / np.abs(np_y))
44+
y_pred = torch.rand(size)
45+
y = torch.rand(size)
46+
47+
epsilon = 1e-8
48+
safe_y = torch.where(y == 0, torch.full_like(y, epsilon), y)
49+
expected = torch.median(torch.abs((y - y_pred) / safe_y).cpu()).item() * 100.0
4750

4851
m = MedianAbsolutePercentageError(device=available_device)
4952
assert m._device == torch.device(available_device)
50-
y_pred = torch.from_numpy(np_y_pred)
51-
y = torch.from_numpy(np_y)
5253

5354
m.reset()
5455
m.update((y_pred, y))
5556

56-
assert np_median_absolute_percentage_error == pytest.approx(m.compute())
57+
assert expected == pytest.approx(m.compute())
5758

5859

5960
def test_median_absolute_percentage_error_2(available_device):
6061
np.random.seed(1)
6162
size = 105
62-
np_y_pred = np.random.rand(size, 1)
63-
np_y = np.random.rand(size, 1)
64-
np.random.shuffle(np_y)
65-
np_median_absolute_percentage_error = 100.0 * np.median(np.abs(np_y - np_y_pred) / np.abs(np_y))
63+
y_pred = torch.rand(size, 1)
64+
y = torch.rand(size, 1)
65+
66+
indices = torch.randperm(size)
67+
y = y[indices]
68+
69+
epsilon = 1e-8
70+
safe_y = torch.where(y == 0, torch.full_like(y, epsilon), y)
71+
expected = torch.median(torch.abs((y - y_pred) / safe_y).cpu()).item() * 100.0
6672

6773
m = MedianAbsolutePercentageError(device=available_device)
6874
assert m._device == torch.device(available_device)
69-
y_pred = torch.from_numpy(np_y_pred)
70-
y = torch.from_numpy(np_y)
7175

7276
m.reset()
7377
batch_size = 16
@@ -76,24 +80,27 @@ def test_median_absolute_percentage_error_2(available_device):
7680
idx = i * batch_size
7781
m.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
7882

79-
assert np_median_absolute_percentage_error == pytest.approx(m.compute())
83+
assert expected == pytest.approx(m.compute())
8084

8185

8286
def test_integration_median_absolute_percentage_error(available_device):
8387
np.random.seed(1)
8488
size = 105
85-
np_y_pred = np.random.rand(size, 1)
86-
np_y = np.random.rand(size, 1)
87-
np.random.shuffle(np_y)
88-
np_median_absolute_percentage_error = 100.0 * np.median(np.abs(np_y - np_y_pred) / np.abs(np_y))
89+
y_pred = torch.rand(size, 1)
90+
y = torch.rand(size, 1)
91+
92+
indices = torch.randperm(size)
93+
y = y[indices]
94+
95+
epsilon = 1e-8
96+
safe_y = torch.where(y == 0, torch.full_like(y, epsilon), y)
97+
expected = torch.median(torch.abs((y - y_pred) / safe_y).cpu()).item() * 100.0
8998

9099
batch_size = 15
91100

92101
def update_fn(engine, batch):
93102
idx = (engine.state.iteration - 1) * batch_size
94-
y_true_batch = np_y[idx : idx + batch_size]
95-
y_pred_batch = np_y_pred[idx : idx + batch_size]
96-
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
103+
return y_pred[idx : idx + batch_size], y[idx : idx + batch_size]
97104

98105
engine = Engine(update_fn)
99106

@@ -104,7 +111,7 @@ def update_fn(engine, batch):
104111
data = list(range(size // batch_size))
105112
median_absolute_percentage_error = engine.run(data, max_epochs=1).metrics["median_absolute_percentage_error"]
106113

107-
assert np_median_absolute_percentage_error == pytest.approx(median_absolute_percentage_error)
114+
assert expected == pytest.approx(median_absolute_percentage_error)
108115

109116

110117
def _test_distrib_compute(device):

tests/ignite/metrics/regression/test_median_relative_absolute_error.py

+23-21
Original file line numberDiff line numberDiff line change
@@ -41,33 +41,33 @@ def test_median_relative_absolute_error(available_device):
4141
# Size of dataset will be odd for these tests
4242

4343
size = 51
44-
np_y_pred = np.random.rand(size)
45-
np_y = np.random.rand(size)
46-
np_median_absolute_relative_error = np.median(np.abs(np_y - np_y_pred) / np.abs(np_y - np_y.mean()))
44+
y_pred = torch.rand(size)
45+
y = torch.rand(size)
46+
47+
baseline = torch.abs(y - y.mean())
48+
expected = torch.median((torch.abs(y - y_pred) / baseline).cpu()).item()
4749

4850
m = MedianRelativeAbsoluteError(device=available_device)
4951
assert m._device == torch.device(available_device)
50-
y_pred = torch.from_numpy(np_y_pred)
51-
y = torch.from_numpy(np_y)
5252

5353
m.reset()
5454
m.update((y_pred, y))
5555

56-
assert np_median_absolute_relative_error == pytest.approx(m.compute())
56+
assert expected == pytest.approx(m.compute())
5757

5858

5959
def test_median_relative_absolute_error_2(available_device):
6060
np.random.seed(1)
6161
size = 105
62-
np_y_pred = np.random.rand(size, 1)
63-
np_y = np.random.rand(size, 1)
64-
np.random.shuffle(np_y)
65-
np_median_absolute_relative_error = np.median(np.abs(np_y - np_y_pred) / np.abs(np_y - np_y.mean()))
62+
y_pred = torch.rand(size, 1)
63+
y = torch.rand(size, 1)
64+
y = y[torch.randperm(size)]
65+
66+
baseline = torch.abs(y - y.mean())
67+
expected = torch.median((torch.abs(y - y_pred) / baseline).cpu()).item()
6668

6769
m = MedianRelativeAbsoluteError(device=available_device)
6870
assert m._device == torch.device(available_device)
69-
y_pred = torch.from_numpy(np_y_pred)
70-
y = torch.from_numpy(np_y)
7171

7272
m.reset()
7373
batch_size = 16
@@ -76,24 +76,26 @@ def test_median_relative_absolute_error_2(available_device):
7676
idx = i * batch_size
7777
m.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
7878

79-
assert np_median_absolute_relative_error == pytest.approx(m.compute())
79+
assert expected == pytest.approx(m.compute())
8080

8181

8282
def test_integration_median_relative_absolute_error_with_output_transform(available_device):
8383
np.random.seed(1)
8484
size = 105
85-
np_y_pred = np.random.rand(size, 1)
86-
np_y = np.random.rand(size, 1)
87-
np.random.shuffle(np_y)
88-
np_median_absolute_relative_error = np.median(np.abs(np_y - np_y_pred) / np.abs(np_y - np_y.mean()))
85+
y_pred = torch.rand(size, 1)
86+
y = torch.rand(size, 1)
87+
y = y[torch.randperm(size)] # shuffle y
88+
89+
baseline = torch.abs(y - y.mean())
90+
expected = torch.median((torch.abs(y - y_pred) / baseline.cpu()).cpu()).item()
8991

9092
batch_size = 15
9193

9294
def update_fn(engine, batch):
9395
idx = (engine.state.iteration - 1) * batch_size
94-
y_true_batch = np_y[idx : idx + batch_size]
95-
y_pred_batch = np_y_pred[idx : idx + batch_size]
96-
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
96+
y_true_batch = y[idx : idx + batch_size]
97+
y_pred_batch = y_pred[idx : idx + batch_size]
98+
return y_pred_batch, y_true_batch
9799

98100
engine = Engine(update_fn)
99101

@@ -104,7 +106,7 @@ def update_fn(engine, batch):
104106
data = list(range(size // batch_size))
105107
median_absolute_relative_error = engine.run(data, max_epochs=1).metrics["median_absolute_relative_error"]
106108

107-
assert np_median_absolute_relative_error == pytest.approx(median_absolute_relative_error)
109+
assert expected == pytest.approx(median_absolute_relative_error)
108110

109111

110112
def _test_distrib_compute(device):

tests/ignite/metrics/regression/test_pearson_correlation.py

+23-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,25 @@ def np_corr_eps(np_y_pred: np.ndarray, np_y: np.ndarray, eps: float = 1e-8):
2020
return corr
2121

2222

23+
def torch_corr_eps(y_pred: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
24+
y_pred = y_pred.to(dtype=torch.float32)
25+
y = y.to(dtype=torch.float32)
26+
27+
y_pred_mean = y_pred.mean()
28+
y_mean = y.mean()
29+
30+
pred_centered = y_pred - y_pred_mean
31+
y_centered = y - y_mean
32+
33+
cov = torch.mean(pred_centered * y_centered)
34+
std_pred = pred_centered.std(unbiased=False)
35+
std_y = y_centered.std(unbiased=False)
36+
37+
denom = torch.clamp(std_pred * std_y, min=eps)
38+
corr = cov / denom
39+
return corr
40+
41+
2342
def scipy_corr(np_y_pred: np.ndarray, np_y: np.ndarray):
2443
corr = pearsonr(np_y_pred, np_y)
2544
return corr.statistic
@@ -51,21 +70,17 @@ def test_degenerated_sample(available_device):
5170
y = torch.tensor([1.0])
5271
m.update((y_pred, y))
5372

54-
np_y_pred = y_pred.numpy()
55-
np_y = y_pred.numpy()
56-
np_res = np_corr_eps(np_y_pred, np_y)
57-
assert pytest.approx(np_res) == m.compute()
73+
res = torch_corr_eps(y_pred, y)
74+
assert pytest.approx(res) == m.compute()
5875

5976
# constant samples
6077
m.reset()
6178
y_pred = torch.ones(10).float()
6279
y = torch.zeros(10).float()
6380
m.update((y_pred, y))
6481

65-
np_y_pred = y_pred.numpy()
66-
np_y = y_pred.numpy()
67-
np_res = np_corr_eps(np_y_pred, np_y)
68-
assert pytest.approx(np_res) == m.compute()
82+
res = torch_corr_eps(y_pred, y)
83+
assert pytest.approx(res) == m.compute()
6984

7085

7186
def test_pearson_correlation(available_device):

0 commit comments

Comments (0)