Skip to content

fix vae dtype when accelerate config using --mixed_precision="fp16" #9601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/controlnet/train_controlnet_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ def parse_args(input_args=None):
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--upcast_vae",
action="store_true",
help="Whether or not to upcast vae to fp32",
)
parser.add_argument(
"--learning_rate",
type=float,
Expand Down Expand Up @@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir):
weight_dtype = torch.bfloat16

# Move vae, transformer and text_encoder to device and cast to weight_dtype
vae.to(accelerator.device, dtype=torch.float32)
if args.upcast_vae:
vae.to(accelerator.device, dtype=torch.float32)
else:
vae.to(accelerator.device, dtype=weight_dtype)
transformer.to(accelerator.device, dtype=weight_dtype)
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
Expand Down
Loading