
Commit 2cb383f

fix vae dtype when accelerate config using --mixed_precision="fp16" (#9601)
* fix vae dtype when accelerate config using --mixed_precision="fp16"
* Add param for upcast vae
1 parent 31010ec commit 2cb383f

File tree: 1 file changed (+9 −1 lines)


examples/controlnet/train_controlnet_sd3.py

Lines changed: 9 additions & 1 deletion
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
+    parser.add_argument(
+        "--upcast_vae",
+        action="store_true",
+        help="Whether or not to upcast vae to fp32",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir):
         weight_dtype = torch.bfloat16

     # Move vae, transformer and text_encoder to device and cast to weight_dtype
-    vae.to(accelerator.device, dtype=torch.float32)
+    if args.upcast_vae:
+        vae.to(accelerator.device, dtype=torch.float32)
+    else:
+        vae.to(accelerator.device, dtype=weight_dtype)
     transformer.to(accelerator.device, dtype=weight_dtype)
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
