File tree 1 file changed +9
-1
lines changed
1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -357,6 +357,11 @@ def parse_args(input_args=None):
357
357
action = "store_true" ,
358
358
help = "Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass." ,
359
359
)
360
+ parser .add_argument (
361
+ "--upcast_vae" ,
362
+ action = "store_true" ,
363
+ help = "Whether or not to upcast vae to fp32" ,
364
+ )
360
365
parser .add_argument (
361
366
"--learning_rate" ,
362
367
type = float ,
@@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir):
1094
1099
weight_dtype = torch .bfloat16
1095
1100
1096
1101
# Move vae, transformer and text_encoder to device and cast to weight_dtype
1097
- vae .to (accelerator .device , dtype = torch .float32 )
1102
+ if args .upcast_vae :
1103
+ vae .to (accelerator .device , dtype = torch .float32 )
1104
+ else :
1105
+ vae .to (accelerator .device , dtype = weight_dtype )
1098
1106
transformer .to (accelerator .device , dtype = weight_dtype )
1099
1107
text_encoder_one .to (accelerator .device , dtype = weight_dtype )
1100
1108
text_encoder_two .to (accelerator .device , dtype = weight_dtype )
You can’t perform that action at this time.
0 commit comments