4
4
# - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
5
5
# - DeepSpeed (ZeRO-Offload):
6
6
# sudo apt install libopenmpi-dev
7
- # LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p
7
+ # LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
8
8
# DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
9
9
#
10
10
# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core
31
31
import torch .nn .functional as F
32
32
import wandb
33
33
from torch .utils .data import DataLoader
34
+ from torchao .utils import get_available_devices
34
35
from torchvision .transforms import v2
35
36
from tqdm import tqdm
36
37
37
38
from torchao .prototype import low_bit_optim
38
39
40
+ _DEVICE = get_available_devices ()[- 1 ]
41
+ assert _DEVICE in ["cuda" , "xpu" ], "Benchmark currently only supports CUDA & XPU(BF16)"
42
+
39
43
OPTIM_MAP = dict (
40
44
AdamW = partial (torch .optim .AdamW , fused = True ),
41
45
AdamW8bitBnb = bnb .optim .AdamW8bit ,
49
53
50
54
OPTIM_MAP .update (
51
55
AdamW4bitLpmm = partial (lpmm .optim .AdamW , fused = True ),
52
- AdamW4bitRank1Lpmm = partial (lpmm .optim .AdamW , qconfig = argparse .Namespace (scale_type = "rank1" )),
56
+ AdamW4bitRank1Lpmm = partial (
57
+ lpmm .optim .AdamW , qconfig = argparse .Namespace (scale_type = "rank1" )
58
+ ),
53
59
)
54
60
55
61
except ImportError :
@@ -67,8 +73,12 @@ def get_lr(self, step: int) -> float:
67
73
if step < self .warmup_steps :
68
74
return self .lr * step / self .warmup_steps
69
75
if step < self .total_steps :
70
- progress = (step - self .warmup_steps ) / (self .total_steps - self .warmup_steps )
71
- return self .final_lr + 0.5 * (self .lr - self .final_lr ) * (1 + math .cos (progress * math .pi ))
76
+ progress = (step - self .warmup_steps ) / (
77
+ self .total_steps - self .warmup_steps
78
+ )
79
+ return self .final_lr + 0.5 * (self .lr - self .final_lr ) * (
80
+ 1 + math .cos (progress * math .pi )
81
+ )
72
82
return self .final_lr
73
83
74
84
@@ -92,7 +102,9 @@ def get_parser():
92
102
parser .add_argument ("--weight_decay" , type = float , default = 0 )
93
103
parser .add_argument ("--optim_kwargs" , type = json .loads , default = dict ())
94
104
parser .add_argument ("--cosine_lr_scheduler" , action = "store_true" )
95
- parser .add_argument ("--optim_cpu_offload" , choices = ["ao" , "ao_offload_grads" , "deepspeed" ])
105
+ parser .add_argument (
106
+ "--optim_cpu_offload" , choices = ["ao" , "ao_offload_grads" , "deepspeed" ]
107
+ )
96
108
97
109
parser .add_argument ("--project" )
98
110
parser .add_argument ("--run_name" , default = "debug" )
@@ -110,11 +122,15 @@ def get_dloader(args, training: bool):
110
122
transforms .extend ([v2 .Resize (256 ), v2 .CenterCrop (224 )])
111
123
112
124
transforms .append (v2 .ToDtype (torch .float32 , scale = True ))
113
- transforms .append (v2 .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]))
125
+ transforms .append (
126
+ v2 .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ])
127
+ )
114
128
transforms = v2 .Compose (transforms )
115
129
116
130
# use dataset from HF so download is fast
117
- ds = datasets .load_dataset ("timm/resisc45" , split = "train" if training else "validation" )
131
+ ds = datasets .load_dataset (
132
+ "timm/resisc45" , split = "train" if training else "validation"
133
+ )
118
134
ds = ds .select_columns (["image" , "label" ])
119
135
ds .set_transform (lambda x : dict (image = transforms (x ["image" ]), label = x ["label" ]))
120
136
@@ -128,9 +144,9 @@ def get_dloader(args, training: bool):
128
144
)
129
145
130
146
131
- def get_amp_ctx (amp ):
147
+ def get_amp_ctx (amp , device ):
132
148
dtype = dict (bf16 = torch .bfloat16 , fp16 = torch .float16 , none = None )[amp ]
133
- return torch .autocast ("cuda" , dtype = dtype , enabled = amp != "none" )
149
+ return torch .autocast (device , dtype = dtype , enabled = amp != "none" )
134
150
135
151
136
152
@torch .no_grad ()
@@ -148,8 +164,8 @@ def evaluate_model(model, args):
148
164
if args .channels_last :
149
165
batch ["image" ] = batch ["image" ].to (memory_format = torch .channels_last )
150
166
151
- with get_amp_ctx (args .amp ):
152
- all_preds .append (model (batch ["image" ].cuda ( )).argmax (1 ).cpu ())
167
+ with get_amp_ctx (args .amp , _DEVICE ):
168
+ all_preds .append (model (batch ["image" ].to ( _DEVICE )).argmax (1 ).cpu ())
153
169
154
170
all_labels = torch .cat (all_labels , dim = 0 )
155
171
all_preds = torch .cat (all_preds , dim = 0 )
@@ -164,8 +180,12 @@ def evaluate_model(model, args):
164
180
if args .full_bf16 :
165
181
assert args .amp == "none" , "When --full_bf16 is set, --amp must be none"
166
182
if args .optim_cpu_offload == "deepspeed" :
167
- assert args .amp == "none" , "When using DeepSpeed ZeRO-Offload, --amp must be none"
168
- assert args .optim == "AdamW" , "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
183
+ assert (
184
+ args .amp == "none"
185
+ ), "When using DeepSpeed ZeRO-Offload, --amp must be none"
186
+ assert (
187
+ args .optim == "AdamW"
188
+ ), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
169
189
if args .profile :
170
190
args .n_epochs = 1
171
191
if args .seed is not None :
@@ -185,14 +205,16 @@ def evaluate_model(model, args):
185
205
dloader = get_dloader (args , True )
186
206
print (f"Train dataset: { len (dloader .dataset ):,} images" )
187
207
188
- model = timm .create_model (args .model , pretrained = True , num_classes = 45 , ** args .model_kwargs )
208
+ model = timm .create_model (
209
+ args .model , pretrained = True , num_classes = 45 , ** args .model_kwargs
210
+ )
189
211
if args .checkpoint_activations :
190
212
model .set_grad_checkpointing ()
191
213
if args .full_bf16 :
192
214
model .bfloat16 ()
193
215
if args .channels_last :
194
216
model .to (memory_format = torch .channels_last )
195
- model .cuda ( ) # move model to CUDA after optionally convert it to BF16
217
+ model .to ( _DEVICE ) # move model to _DEVICE after optionally converting it to BF16
196
218
if args .compile :
197
219
model .compile (fullgraph = True )
198
220
print (f"Model parameters: { sum (p .numel () for p in model .parameters ()):,} " )
@@ -227,9 +249,15 @@ def evaluate_model(model, args):
227
249
optim_cls = OPTIM_MAP [args .optim ]
228
250
229
251
if args .optim_cpu_offload == "ao" :
230
- optim_cls = partial (low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls )
252
+ optim_cls = partial (
253
+ low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls
254
+ )
231
255
elif args .optim_cpu_offload == "ao_offload_grads" :
232
- optim_cls = partial (low_bit_optim .CPUOffloadOptimizer , optimizer_class = optim_cls , offload_gradients = True )
256
+ optim_cls = partial (
257
+ low_bit_optim .CPUOffloadOptimizer ,
258
+ optimizer_class = optim_cls ,
259
+ offload_gradients = True ,
260
+ )
233
261
234
262
optim = optim_cls (
235
263
model .parameters (),
@@ -239,24 +267,30 @@ def evaluate_model(model, args):
239
267
)
240
268
241
269
lr_schedule = CosineSchedule (args .lr , len (dloader ) * args .n_epochs )
242
- grad_scaler = torch .amp .GradScaler ("cuda" , enabled = args .amp == "fp16" )
270
+ grad_scaler = torch .amp .GradScaler (_DEVICE , enabled = args .amp == "fp16" )
243
271
log_interval = 10
244
272
t0 = time .perf_counter ()
245
273
246
274
step = 0
247
275
for epoch_idx in range (args .n_epochs ):
248
276
model .train ()
249
- pbar = tqdm (dloader , dynamic_ncols = True , desc = f"Epoch { epoch_idx + 1 } /{ args .n_epochs } " )
277
+ pbar = tqdm (
278
+ dloader , dynamic_ncols = True , desc = f"Epoch { epoch_idx + 1 } /{ args .n_epochs } "
279
+ )
250
280
251
281
with torch .profiler .profile () if args .profile else nullcontext () as prof :
252
282
for batch in pbar :
253
283
if args .full_bf16 :
254
284
batch ["image" ] = batch ["image" ].bfloat16 ()
255
285
if args .channels_last :
256
- batch ["image" ] = batch ["image" ].to (memory_format = torch .channels_last )
286
+ batch ["image" ] = batch ["image" ].to (
287
+ memory_format = torch .channels_last
288
+ )
257
289
258
- with get_amp_ctx (args .amp ):
259
- loss = F .cross_entropy (model (batch ["image" ].cuda ()), batch ["label" ].cuda ())
290
+ with get_amp_ctx (args .amp , _DEVICE ):
291
+ loss = F .cross_entropy (
292
+ model (batch ["image" ].to (_DEVICE )), batch ["label" ].to (_DEVICE )
293
+ )
260
294
261
295
if args .optim_cpu_offload == "deepspeed" :
262
296
model .backward (loss )
@@ -275,7 +309,9 @@ def evaluate_model(model, args):
275
309
log_dict = dict (loss = loss .item (), lr = optim .param_groups [0 ]["lr" ])
276
310
if step > 0 :
277
311
t1 = time .perf_counter ()
278
- log_dict ["imgs_per_second" ] = args .batch_size * log_interval / (t1 - t0 )
312
+ log_dict ["imgs_per_second" ] = (
313
+ args .batch_size * log_interval / (t1 - t0 )
314
+ )
279
315
t0 = t1
280
316
logger .log (log_dict , step = step )
281
317
@@ -296,9 +332,11 @@ def evaluate_model(model, args):
296
332
297
333
else :
298
334
val_acc = evaluate_model (model , args )
299
- print (f"Epoch { epoch_idx + 1 } /{ args .n_epochs } : val_acc={ val_acc .item () * 100 :.2f} " )
335
+ print (
336
+ f"Epoch { epoch_idx + 1 } /{ args .n_epochs } : val_acc={ val_acc .item () * 100 :.2f} "
337
+ )
300
338
logger .log (dict (val_acc = val_acc ), step = step )
301
339
302
- peak_mem = torch . cuda .max_memory_allocated () / 1e9
340
+ peak_mem = getattr ( torch , _DEVICE ) .max_memory_allocated () / 1e9
303
341
print (f"Max memory used: { peak_mem :.02f} GB" )
304
342
logger .log (dict (max_memory_allocated = peak_mem ))
0 commit comments