add sana-sprint #11074
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice work! Just a heads up: this PR might have conflicts with #11040 if that one is merged first.
Wonderful work. Since SANA-Sprint and SANA-1.5 follow the same architecture, this PR should make SANA-1.5 work as well.
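As a rough illustration of that point, a minimal sketch (not from this PR; the repo id below is a placeholder, not a confirmed checkpoint name) of loading a converted SANA-1.5 checkpoint through the existing SanaPipeline:

```python
# Hypothetical sketch: if SANA-1.5 shares the SANA-Sprint/SANA architecture,
# the existing SanaPipeline should load a converted SANA-1.5 checkpoint.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "org/sana-1.5-placeholder",  # placeholder checkpoint id, not a real repo
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
image = pipe(prompt="a photo of an astronaut riding a horse").images[0]
```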
* update conversion script for SANA-1.5
* add conversion script for SANA-Sprint
* separate SANA and SANA-Sprint conversion scripts
* update for upstream
* fix the } bug
* add a doc for SanaSprintPipeline
* minor update
* make style && make quality
@bot /style
Style fixes have been applied. View the workflow run here.
cc @lawrence-cj, can you do a review?
Really amazing work! Can't wait for the release ❤️
```python
>>> from diffusers import SanaPipeline

>>> pipe = SanaPipeline.from_pretrained(
```
The example should be updated to use SanaSprintPipeline.
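A sketch of what the updated docstring example might look like (the checkpoint id below is a placeholder, not a confirmed repo name):

```python
# Hypothetical sketch of the updated example; the repo id is a placeholder.
import torch
from diffusers import SanaSprintPipeline

pipe = SanaSprintPipeline.from_pretrained(
    "org/sana-sprint-placeholder",  # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# SANA-Sprint is a few-step distilled model, so a very small step count is typical.
image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    num_inference_steps=2,
).images[0]
image.save("sana_sprint.png")
```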
Thanks @yiyixuxu
```python
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
```
In other recent models we found that an attention mask with shape [B, 1, 1, N] is faster, since the total size is smaller and PyTorch's broadcasting handles the expansion. Something to look into; if we see a benefit, all occurrences of this code can be updated.
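A minimal standalone sketch of the broadcastable-mask idea (not this PR's code): scaled_dot_product_attention accepts an attn_mask of shape [B, 1, 1, N] and broadcasts it over heads and query positions, so the full [B, heads, Q, N] mask never has to be materialized.

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, dim = 2, 8, 16, 16, 64
query = torch.randn(batch, heads, q_len, dim)
key = torch.randn(batch, heads, kv_len, dim)
value = torch.randn(batch, heads, kv_len, dim)

# Boolean padding mask: True = attend, False = ignore.
attention_mask = torch.ones(batch, kv_len, dtype=torch.bool)
attention_mask[:, -4:] = False  # pretend the last 4 tokens are padding

# [B, N] -> [B, 1, 1, N]; broadcasting expands it to [B, heads, Q, N] lazily.
attn_mask = attention_mask[:, None, None, :]
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```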
```python
latents = latents.to(self.vae.dtype)
try:
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
```
Nice!
For XPU we need to use torch.OutOfMemoryError, which also looks like it works on CUDA:
https://github.com/pytorch/pytorch/blob/00a2c68f67adbd38847845016fd1ab9275cefbab/test/test_xpu.py#L446
https://github.com/pytorch/pytorch/blob/00a2c68f67adbd38847845016fd1ab9275cefbab/test/test_cuda.py#L3950
https://github.com/pytorch/pytorch/blob/00a2c68f67adbd38847845016fd1ab9275cefbab/test/test_cuda.py#L4154
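A sketch (not the PR's code) of a device-agnostic variant: on recent PyTorch releases torch.OutOfMemoryError covers both the CUDA and XPU OOM exceptions, so catching it avoids the CUDA-only torch.cuda.OutOfMemoryError. The decode_latents helper is hypothetical and exists only to keep the example self-contained.

```python
import warnings

import torch


def decode_latents(vae, latents):
    latents = latents.to(vae.dtype)
    try:
        return vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    except torch.OutOfMemoryError:
        # Device-agnostic OOM (requires a recent PyTorch); warn and re-raise.
        warnings.warn(
            "VAE decoding ran out of memory; consider enabling VAE tiling/slicing "
            "or decoding on CPU."
        )
        raise
```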
+1 on this.
add a note about max_timesteps
Co-authored-by: Aryan <aryan@huggingface.co>
@yiyixuxu
```python
try:
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
    warnings.warn(
```
Should this be logger.warning()?
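For reference, a sketch of the logger-based alternative, using the logging helper other diffusers pipelines use (the message text here is only illustrative):

```python
from diffusers.utils import logging

logger = logging.get_logger(__name__)

# Emits through the diffusers logging machinery instead of the warnings module.
logger.warning("VAE decoding ran out of memory; falling back and continuing.")
```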
```python
else:
    # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
    self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
    print(f"Set timesteps: {self.timesteps}")
```
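A small sketch reproducing the sCM default quoted in the comment above, assuming it comes from arctan(sigma_max / sigma_data) with sigma_max = 80 and sigma_data = 0.5:

```python
import math

# Reproduce the default max_timesteps quoted above: arctan(80 / 0.5).
sigma_max, sigma_data = 80.0, 0.5
max_timesteps = math.atan(sigma_max / sigma_data)
print(max_timesteps)  # ~1.5645, the sCM-paper default mentioned in the comment
```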
```python
print(f"Set timesteps: {self.timesteps}")
```
Vibe tests with different timestep settings.
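A standalone sketch of what such a sweep could look like at the scheduler level, mirroring the linspace schedule in the snippet above (the max_timesteps values here are chosen only for illustration):

```python
import torch

# Mirror the schedule from the snippet above for a few max_timesteps values,
# to see what a "vibe test" over different timestep settings would sweep over.
num_inference_steps = 2
for max_timesteps in (1.57080, 1.56454, 1.30000):
    timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1).float()
    print(f"max_timesteps={max_timesteps}: {timesteps.tolist()}")
```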