add sana-sprint #11074


Merged: 22 commits into main, Mar 21, 2025

Conversation

@yiyixuxu (Collaborator) commented Mar 17, 2025

# test sana sprint
# first convert the original SANA-Sprint checkpoint to diffusers format:
"""
python scripts/convert_sana_to_diffusers.py \
  --orig_ckpt_path /raid/.cache/huggingface/yiyi/models--JunsongChen--Sana_Sprint_1600M_1024px/snapshots/8ecfdc7e6269e5b065f5b2cf3fac9a2a1a778c6a/Sana_Sprint_1600M_1024px_36K.pth \
  --model_type SanaSprint_1600M_P1_D20 \
  --image_size 1024 \
  --dump_path /raid/yiyi/Sana-Sprint-yiyi \
  --save_full_pipeline \
  --scheduler_type scm
"""

from diffusers import SanaSprintPipeline
import torch

device = "cuda:0"
dtype = torch.bfloat16

repo = "/raid/yiyi/Sana-Sprint-yiyi"

pipeline = SanaSprintPipeline.from_pretrained(repo, torch_dtype=dtype)
pipeline.to(device)


prompt = "a tiny astronaut hatching from an egg on the moon"

image = pipeline(prompt=prompt, num_inference_steps=2).images[0]
image.save("test_out.png")

[output image: yiyi_test_8_out]

Vibe tests with different timestep settings:

# test sana sprint
# (pipeline)
# candidate max_timesteps values sit just below pi/2 (~1.5708); the sCM default arctan(80/0.5) = 1.56454 is among them
test_max_timesteps = [1.57080, 1.56830, 1.56580, 1.56454, 1.56246, 1.55830, 1.55413, 1.55080, 1.54580]
test_intermediate_timesteps = [None, 1.0, 1.1, 1.2, 1.3, 1.4]
test_num_inference_steps = [1, 2, 4]

# test_max_timesteps = [1.57080]
# test_intermediate_timesteps = [None]
# test_num_inference_steps = [1]

from diffusers import SanaSprintPipeline
import torch

device = "cuda:0"
dtype = torch.bfloat16
repo = "/raid/yiyi/Sana-Sprint-yiyi"

def run_sana(pipeline, num_inference_steps, max_timesteps, intermediate_timesteps):
    prompt = "a tiny astronaut hatching from an egg on the moon"
    generator = torch.Generator(device=device).manual_seed(123)
    test_name = f"num_inference_steps_{num_inference_steps}_max_timesteps_{max_timesteps}_intermediate_timesteps_{intermediate_timesteps}"
    print(f"--------------------------------")
    print(f"Running test:")
    print(f"num_inference_steps: {num_inference_steps}")
    print(f"max_timesteps: {max_timesteps}")
    print(f"intermediate_timesteps: {intermediate_timesteps}")
    try:
        image = pipeline(prompt=prompt, num_inference_steps=num_inference_steps, max_timesteps=max_timesteps, intermediate_timesteps=intermediate_timesteps, generator=generator).images[0]
        image.save(f"yiyi_test_10_1_out_{test_name}.png")
    except Exception as e:
        print(e)
    print(f"--------------------------------")


pipeline = SanaSprintPipeline.from_pretrained(repo, torch_dtype=dtype)
pipeline.to(device)

for num_inference_steps in test_num_inference_steps:
    for max_timesteps in test_max_timesteps:
        for intermediate_timesteps in test_intermediate_timesteps:
            run_sana(pipeline, num_inference_steps, max_timesteps, intermediate_timesteps)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ishan-modi (Contributor)

Nice work! Just a heads up: this PR might conflict with #11040 if that one is merged first.

@lawrence-cj (Contributor) commented Mar 17, 2025

Wonderful work. Since SANA-Sprint and SANA-1.5 share the same architecture, this PR should make SANA-1.5 work as well.
@yiyixuxu @sayakpaul

lawrence-cj and others added 10 commits March 16, 2025 20:39
* 1. update conversion script for sana1.5;
2. add conversion script for sana-sprint;

* separate sana and sana-sprint conversion scripts;

* update for upstream

* fix the } bug

* add a doc for SanaSprintPipeline;

* minor update;

* make style && make quality
@yiyixuxu requested review from a-r-r-o-w and hlky, March 20, 2025 10:47
@yiyixuxu (Collaborator, Author)

@bot /style

Contributor: Style fixes have been applied. View the workflow run here.

@yiyixuxu (Collaborator, Author)

cc @lawrence-cj, can you do a review?

@a-r-r-o-w (Member) left a comment

Really amazing work! Can't wait for the release ❤️

Comment on lines 114 to 116
>>> from diffusers import SanaPipeline

>>> pipe = SanaPipeline.from_pretrained(
Member: Example to be updated to SanaSprintPipeline.
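
A minimal sketch of what the corrected docstring example could look like, following the snippet style above; the hub repo id is a placeholder, since the final checkpoint name is not given in this thread:

>>> import torch
>>> from diffusers import SanaSprintPipeline

>>> # repo id below is hypothetical; substitute the released SANA-Sprint checkpoint
>>> pipe = SanaSprintPipeline.from_pretrained(
...     "org/Sana-Sprint-1600M-1024px", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon", num_inference_steps=2).images[0]
>>> image.save("sana_sprint.png")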

Copy link
Contributor

@hlky hlky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @yiyixuxu

Comment on lines +149 to +152
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
Contributor: In other recent models we found that an attention mask with shape [B, 1, 1, N] is faster, since the total size is smaller and PyTorch's broadcasting handles the expansion. Something to look into; if we see a benefit, all occurrences of this code can be updated.
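
A minimal standalone sketch of the idea (illustrative shapes and names, not the PR's code): a boolean key-padding mask of shape [B, 1, 1, N] broadcasts across heads and query positions inside scaled_dot_product_attention, so it never has to be expanded to [B, heads, L, N]:

import torch
import torch.nn.functional as F

B, H, L, N, D = 2, 16, 300, 300, 64  # batch, heads, query len, key len, head dim
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)

# Key-padding mask: True where a key token is valid. Shape [B, N] -> [B, 1, 1, N].
valid = torch.ones(B, N, dtype=torch.bool)
valid[:, N // 2 :] = False  # pretend the second half of the sequence is padding
mask = valid[:, None, None, :]  # broadcasts over heads and query positions

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([2, 16, 300, 64])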

latents = latents.to(self.vae.dtype)
try:
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
Member: +1 on this.

@lawrence-cj (Contributor) commented Mar 21, 2025

try:
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
    warnings.warn(
Member: Should this be logger.warning()?
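
If so, a minimal sketch of the diffusers logging pattern the question points at; the message text is illustrative:

from diffusers.utils import logging

logger = logging.get_logger(__name__)

# In library code this replaces warnings.warn(...) and respects diffusers' verbosity settings.
logger.warning("VAE decode ran out of CUDA memory; falling back.")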

else:
    # max_timesteps = arctan(80/0.5) = 1.56454 is the default from the sCM paper; we choose a different value here
    self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
    print(f"Set timesteps: {self.timesteps}")
Member: Suggested change, removing the debug print:

    print(f"Set timesteps: {self.timesteps}")

lawrence-cj and others added 2 commits March 21, 2025 05:57
* change sample prompt;

* only 1024px is supported;
@yiyixuxu merged commit 8a63aa5 into main, Mar 21, 2025
14 of 15 checks passed