Commit 6ffe236
Fix LR scheduler issue with CPU offload optimizer (#1649)
* synchronize param H2D
* let CPU offload inherit Optimizer
* add scheduler to test
1 parent 122eb73 commit 6ffe236
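
In practice, the fix lets CPUOffloadOptimizer be driven by PyTorch's built-in LR schedulers. A minimal usage sketch follows; the import path and constructor arguments mirror the test diff below, while the model, learning rate, and training loop are illustrative assumptions rather than part of this commit (CPU offload generally expects parameters to live on a CUDA device).

    import torch
    from torchao.prototype.low_bit_optim import CPUOffloadOptimizer  # assumed import path

    model = torch.nn.Linear(32, 32).cuda()      # hypothetical model; params live on GPU
    optim = CPUOffloadOptimizer(
        model.parameters(),
        torch.optim.AdamW,                      # base optimizer whose states are offloaded to CPU
        offload_gradients=False,
        lr=1e-3,                                # extra kwargs are assumed to be forwarded to AdamW
    )
    # Before this commit, LRScheduler rejected CPUOffloadOptimizer because it was
    # not an instance of torch.optim.Optimizer; with the fix this just works.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=100)

    for _ in range(100):
        model(torch.randn(4, 32, device="cuda")).sum().backward()
        optim.step()
        optim.zero_grad()
        scheduler.step()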

2 files changed: +10 -1 lines changed


test/prototype/test_low_bit_optim.py (+5)
@@ -287,6 +287,9 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
             offload_gradients=offload_grad,
         )
 
+        scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR(optim1, 100)
+        scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, 100)
+
         rng = torch.Generator(device=device)
         rng.manual_seed(42)
 
@@ -299,6 +302,7 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
 
             optim1.step()
             optim1.zero_grad()
+            scheduler1.step()
 
         # reset the rng
         rng.manual_seed(42)
@@ -309,6 +313,7 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
 
             optim2.step()
             optim2.zero_grad()
+            scheduler2.step()
 
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)

torchao/prototype/low_bit_optim/cpu_offload.py (+5, -1)
@@ -6,7 +6,11 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, get_available_devices
 
 
-class CPUOffloadOptimizer:
+# NOTE: We make this inherit Optimizer so it works with PyTorch's built-in LR
+# schedulers (those schedulers specifically check for instances of Optimizer).
+# However, it won't behave exactly like Optimizer, e.g. we don't call
+# Optimizer.__init__(), there is no self.defaults.
+class CPUOffloadOptimizer(Optimizer):
     def __init__(
         self,
         params: ParamsT,
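To illustrate the pattern the NOTE above describes, here is a minimal, self-contained sketch (not torchao code, and every name in it is hypothetical): a wrapper that subclasses Optimizer solely so the isinstance check inside PyTorch's LR schedulers passes, while skipping Optimizer.__init__() and delegating the real work to an inner optimizer.

    import torch
    from torch.optim import Optimizer

    class OffloadStyleWrapper(Optimizer):   # hypothetical name, mirrors the approach above
        def __init__(self, params, base_cls=torch.optim.SGD, **kwargs):
            # Deliberately do NOT call Optimizer.__init__(): there is no self.defaults,
            # so this is not a full-fledged Optimizer. LR schedulers only need the
            # isinstance check to pass plus a param_groups list to read/write "lr" in.
            self.base = base_cls(list(params), **kwargs)
            self.param_groups = self.base.param_groups

        def step(self, closure=None):
            return self.base.step(closure)

        def zero_grad(self, set_to_none=True):
            self.base.zero_grad(set_to_none=set_to_none)

    model = torch.nn.Linear(8, 8)
    opt = OffloadStyleWrapper(model.parameters(), lr=0.1)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)  # accepted: isinstance check passes
    for _ in range(3):
        model(torch.randn(2, 8)).sum().backward()
        opt.step()
        opt.zero_grad()
        sched.step()
    print(opt.param_groups[0]["lr"])  # the scheduler has annealed the learning rate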