From c952370cb4a59aa5202ba19c51ec38b0930c0382 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 17 Mar 2025 02:14:55 +0100 Subject: [PATCH 01/21] first commit --- src/diffusers/__init__.py | 4 + src/diffusers/models/attention_processor.py | 5 + .../models/transformers/sana_transformer.py | 127 ++- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/sana/__init__.py | 2 + .../pipelines/sana/pipeline_sana_scm.py | 991 ++++++++++++++++++ src/diffusers/schedulers/__init__.py | 3 +- src/diffusers/schedulers/scheduling_scm.py | 237 +++++ 8 files changed, 1364 insertions(+), 9 deletions(-) create mode 100644 src/diffusers/pipelines/sana/pipeline_sana_scm.py create mode 100644 src/diffusers/schedulers/scheduling_scm.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 913816ec9a93..e848d1efea42 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -276,6 +276,7 @@ "UnCLIPScheduler", "UniPCMultistepScheduler", "VQDiffusionScheduler", + "SCMScheduler", ] ) _import_structure["training_utils"] = ["EMAModel"] @@ -421,6 +422,7 @@ "ReduxImageEncoder", "SanaPAGPipeline", "SanaPipeline", + "SanaSCMPipeline", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", @@ -839,6 +841,7 @@ UnCLIPScheduler, UniPCMultistepScheduler, VQDiffusionScheduler, + SCMScheduler, ) from .training_utils import EMAModel @@ -965,6 +968,7 @@ ReduxImageEncoder, SanaPAGPipeline, SanaPipeline, + SanaSCMPipeline, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 21d17d6acdab..34276a544160 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -6020,6 +6020,11 @@ def __call__( key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + query = query.transpose(1, 2).unflatten(1, (attn.heads, -1)) key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3) value = value.transpose(1, 2).unflatten(1, (attn.heads, -1)) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index b8cc96d3532c..bbd95d0d0d2e 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -30,7 +30,9 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle, RMSNorm +from ..embeddings import TimestepEmbedding, Timesteps +import torch.nn.functional as F logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -96,6 +98,102 @@ def forward( return hidden_states +class SanaCombinedTimestepGuidanceEmbeddings(nn.Module): + """ + For Sana. 
+ + Reference: + """ + + def __init__(self, embedding_dim): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + guidance_proj = self.guidance_condition_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype)) + conditioning = timesteps_emb + guidance_emb + + return self.linear(self.silu(conditioning)), conditioning + + + +class SanaAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + class SanaTransformerBlock(nn.Module): r""" Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629). 
@@ -115,6 +213,7 @@ def __init__( norm_eps: float = 1e-6, attention_out_bias: bool = True, mlp_ratio: float = 2.5, + qk_norm: Optional[str] = None, ) -> None: super().__init__() @@ -124,6 +223,8 @@ def __init__( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, + kv_heads=num_attention_heads if qk_norm is not None else None, + qk_norm=qk_norm, dropout=dropout, bias=attention_bias, cross_attention_dim=None, @@ -135,13 +236,15 @@ def __init__( self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.attn2 = Attention( query_dim=dim, + qk_norm=qk_norm, + kv_heads=num_cross_attention_heads if qk_norm is not None else None, cross_attention_dim=cross_attention_dim, heads=num_cross_attention_heads, dim_head=cross_attention_head_dim, dropout=dropout, bias=True, out_bias=attention_out_bias, - processor=AttnProcessor2_0(), + processor=SanaAttnProcessor2_0(), ) # 3. Feed-forward @@ -258,6 +361,8 @@ def __init__( norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, interpolation_scale: Optional[int] = None, + guidance_embeds: bool = False, + qk_norm: Optional[str] = None, ) -> None: super().__init__() @@ -276,7 +381,10 @@ def __init__( ) # 2. Additional condition embeddings - self.time_embed = AdaLayerNormSingle(inner_dim) + if guidance_embeds: + self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim) + else: + self.time_embed = AdaLayerNormSingle(inner_dim) self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True) @@ -296,6 +404,7 @@ def __init__( norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, mlp_ratio=mlp_ratio, + qk_norm=qk_norm, ) for _ in range(num_layers) ] @@ -372,7 +481,8 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - timestep: torch.LongTensor, + timestep: torch.Tensor, + guidance: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -423,9 +533,14 @@ def forward( hidden_states = self.patch_embed(hidden_states) - timestep, embedded_timestep = self.time_embed( - timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype - ) + if guidance is not None: + timestep, embedded_timestep = self.time_embed( + timestep, guidance=guidance, hidden_dtype=hidden_states.dtype + ) + else: + timestep, embedded_timestep = self.time_embed( + timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 541d1a743bcb..4504e6049143 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -280,7 +280,7 @@ _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] - _import_structure["sana"] = ["SanaPipeline"] + _import_structure["sana"] = ["SanaPipeline", "SanaSCMPipeline"] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] _import_structure["stable_audio"] = [ @@ -651,7 +651,7 @@ from 
.paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline - from .sana import SanaPipeline + from .sana import SanaPipeline, SanaSCMPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py index 53b6ba762466..72f2402658ac 100644 --- a/src/diffusers/pipelines/sana/__init__.py +++ b/src/diffusers/pipelines/sana/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_sana"] = ["SanaPipeline"] + _import_structure["pipeline_sana_scm"] = ["SanaSCMPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_sana import SanaPipeline + from .pipeline_sana_scm import SanaSCMPipeline else: import sys diff --git a/src/diffusers/pipelines/sana/pipeline_sana_scm.py b/src/diffusers/pipelines/sana/pipeline_sana_scm.py new file mode 100644 index 000000000000..cb5056598815 --- /dev/null +++ b/src/diffusers/pipelines/sana/pipeline_sana_scm.py @@ -0,0 +1,991 @@ +# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
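+
+# This module adds `SanaSCMPipeline`: a few-step text-to-image pipeline that pairs the Sana
+# transformer (extended with guidance embeddings earlier in this patch) with the TrigFlow-based
+# `SCMScheduler` introduced later in this patch.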
+ +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_4096_BIN = { + "0.25": [2048.0, 8192.0], + "0.26": [2048.0, 7936.0], + "0.27": [2048.0, 7680.0], + "0.28": [2048.0, 7424.0], + "0.32": [2304.0, 7168.0], + "0.33": [2304.0, 6912.0], + "0.35": [2304.0, 6656.0], + "0.4": [2560.0, 6400.0], + "0.42": [2560.0, 6144.0], + "0.48": [2816.0, 5888.0], + "0.5": [2816.0, 5632.0], + "0.52": [2816.0, 5376.0], + "0.57": [3072.0, 5376.0], + "0.6": [3072.0, 5120.0], + "0.68": [3328.0, 4864.0], + "0.72": [3328.0, 4608.0], + "0.78": [3584.0, 4608.0], + "0.82": [3584.0, 4352.0], + "0.88": [3840.0, 4352.0], + "0.94": [3840.0, 4096.0], + "1.0": [4096.0, 4096.0], + "1.07": [4096.0, 3840.0], + "1.13": [4352.0, 3840.0], + "1.21": [4352.0, 3584.0], + "1.29": [4608.0, 3584.0], + "1.38": [4608.0, 3328.0], + "1.46": [4864.0, 3328.0], + "1.67": [5120.0, 3072.0], + "1.75": [5376.0, 3072.0], + "2.0": [5632.0, 2816.0], + "2.09": [5888.0, 2816.0], + "2.4": [6144.0, 2560.0], + "2.5": [6400.0, 2560.0], + "2.89": [6656.0, 2304.0], + "3.0": [6912.0, 2304.0], + "3.11": [7168.0, 2304.0], + "3.62": [7424.0, 2048.0], + "3.75": [7680.0, 2048.0], + "3.88": [7936.0, 2048.0], + "4.0": [8192.0, 2048.0], +} + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaPipeline + + >>> pipe = SanaPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32 + ... ) + >>> pipe.to("cuda") + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) + + >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaSCMPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. 
When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: str = "",
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        clean_caption: bool = False,
+        max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
+        lora_scale: Optional[float] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+                Sana, this should be "".
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                whether to use classifier free guidance or not
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                number of images that should be generated per prompt
+            device: (`torch.device`, *optional*):
+                torch device to place the resulting embeddings on
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. For Sana, these should be the embeddings of the "" string.
+            clean_caption (`bool`, defaults to `False`):
+                If `True`, the function will preprocess and clean the provided caption before encoding.
+            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+            complex_human_instruction (`List[str]`, *optional*):
+                If provided, these instructions are prepended to the prompt as a "complex human instruction" before
+                encoding.
+ """ + + if device is None: + device = self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0][:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + if self.transformer is not None: + dtype = self.transformer.dtype + elif self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = 
negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", 
caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: List[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> Union[SanaPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used.
+            guidance_scale (`float`, *optional*, defaults to 4.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality. In this pipeline the value is consumed through the
+                transformer's guidance embedding rather than through classifier-free guidance.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            height (`int`, *optional*, defaults to 1024):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 1024):
+                The width in pixels of the generated image.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                not provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. For Sana, this negative prompt should be "". If not
+                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] instead of a plain
+                tuple.
+            attention_kwargs:
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            clean_caption (`bool`, *optional*, defaults to `False`):
+                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy`
+                to be installed. If the dependencies are not installed, the embeddings will be created from the raw
+                prompt.
+            use_resolution_binning (`bool`, *optional*, defaults to `True`):
+                If set to `True`, the requested height and width are first mapped to the closest resolutions using
+                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+                the requested resolution. Useful for generating non-square images.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to `300`):
+                Maximum sequence length to use with the `prompt`.
+            complex_human_instruction (`List[str]`, *optional*):
+                Instructions for complex human attention:
+                https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list with the generated images.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        # 1. Check inputs. Raise error if not correct
+        if use_resolution_binning:
+            if self.transformer.config.sample_size == 128:
+                aspect_ratio_bin = ASPECT_RATIO_4096_BIN
+            elif self.transformer.config.sample_size == 64:
+                aspect_ratio_bin = ASPECT_RATIO_2048_BIN
+            elif self.transformer.config.sample_size == 32:
+                aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+            elif self.transformer.config.sample_size == 16:
+                aspect_ratio_bin = ASPECT_RATIO_512_BIN
+            else:
+                raise ValueError("Invalid sample size")
+            orig_height, orig_width = height, width
+            height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            callback_on_step_end_tensor_inputs,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_attention_mask,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._interrupt = False
+
+        # 2. Determine batch size
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
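+
+        # Note: unlike `SanaPipeline`, this pipeline does not apply classifier-free guidance at
+        # sampling time (`do_classifier_free_guidance=False` is passed below). `guidance_scale`
+        # is instead injected into the transformer through its guidance embedding, prepared
+        # alongside the latents below, which is how the guidance-distilled SCM model consumes it.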
+        # 3. Encode input prompt
+        (
+            prompt_embeds,
+            prompt_attention_mask,
+            _,
+            _,
+        ) = self.encode_prompt(
+            prompt,
+            False,
+            negative_prompt=negative_prompt,
+            num_images_per_prompt=num_images_per_prompt,
+            device=device,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            clean_caption=clean_caption,
+            max_sequence_length=max_sequence_length,
+            complex_human_instruction=complex_human_instruction,
+            lora_scale=lora_scale,
+        )
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, sigmas
+        )
+
+        # 5. Prepare latents.
+        latent_channels = self.transformer.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            latent_channels,
+            height,
+            width,
+            torch.float32,
+            device,
+            generator,
+            latents,
+        )
+
+        latents = latents * self.scheduler.config.sigma_data
+
+        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+        guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
+        # YiYi TODO: cfg_embed_scale = 0.1 (refactor this out)
+        guidance = guidance * 0.1
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # YiYi TODO: refactor this
+        # `SCMScheduler.set_timesteps` returns `num_inference_steps + 1` boundary values ending
+        # at 0; the trailing 0 is dropped here so the loop runs once per inference step.
+        timesteps = timesteps[:-1]
+
+        # 7. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
+
+                # YiYi TODO: self.scheduler.scale_model_input?
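+                # Rough intuition for the block below: `latents` are kept at scale `sigma_data`,
+                # so they are first rescaled to unit variance; the raw timestep t in
+                # [0, max_timesteps] is remapped to s = sin(t) / (sin(t) + cos(t)), and the model
+                # input is scaled by sqrt(s**2 + (1 - s)**2), matching the TrigFlow-style
+                # parameterization the distilled model was trained with.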
+                latents_model_input = latents / self.scheduler.config.sigma_data
+
+                # YiYi TODO: refactor this out
+                scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
+                latent_model_input = latents_model_input * torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2)
+                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+                # predict noise model_output
+                noise_pred = self.transformer(
+                    latent_model_input,
+                    encoder_hidden_states=prompt_embeds,
+                    encoder_attention_mask=prompt_attention_mask,
+                    guidance=guidance,
+                    timestep=scm_timestep,
+                    return_dict=False,
+                    attention_kwargs=self.attention_kwargs,
+                )[0]
+
+                # YiYi TODO: refactor this out
+                noise_pred = (
+                    (1 - 2 * scm_timestep) * latent_model_input
+                    + (1 - 2 * scm_timestep + 2 * scm_timestep**2) * noise_pred
+                ) / torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2)
+                # YiYi TODO: check if this can be refactored into the scheduler
+                noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
+
+                # compute previous image: x_t -> x_t-1
+                latents, denoised = self.scheduler.step(
+                    noise_pred, i, timestep, latents, **extra_step_kwargs, return_dict=False
+                )
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
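+
+        # After the final step, `denoised` holds the consistency model's clean-sample prediction
+        # at scale `sigma_data`, so it (rather than the re-noised `latents`) is what gets
+        # rescaled and decoded below.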
+        # YiYi TODO: refactor this out
+        latents = denoised / self.scheduler.config.sigma_data
+        if output_type == "latent":
+            image = latents
+        else:
+            latents = latents.to(self.vae.dtype)
+            try:
+                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            except torch.cuda.OutOfMemoryError as e:
+                warnings.warn(
+                    f"{e}. \n"
+                    f"Try to use VAE tiling for large images. For example: \n"
+                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+                )
+            if use_resolution_binning:
+                image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+
+        if output_type != "latent":
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return SanaPipelineOutput(images=image)
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index bb9088538653..905ac9f95141 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -74,6 +74,7 @@
     _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
     _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"]
     _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]
+    _import_structure["scheduling_scm"] = ["SCMScheduler"]
 
 try:
     if not is_flax_available():
@@ -174,7 +175,7 @@
     from .scheduling_unipc_multistep import UniPCMultistepScheduler
     from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler
-
+    from .scheduling_scm import SCMScheduler
 try:
     if not is_flax_available():
         raise OptionalDependencyNotAvailable()
diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py
new file mode 100644
index 000000000000..eb131b7cf658
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_scm.py
@@ -0,0 +1,237 @@
+# Copyright 2024 Sana-Sprint Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..schedulers.scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, logging
+from ..utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
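+
+# A minimal sketch of driving the scheduler by hand (here `model` stands in for a
+# TrigFlow-parameterized denoiser whose output has already been scaled by `sigma_data`,
+# as done in the Sana SCM pipeline above):
+#
+#     scheduler = SCMScheduler(sigma_data=0.5)
+#     scheduler.set_timesteps(num_inference_steps=2, device="cuda")
+#     sample = torch.randn(batch_shape, device="cuda") * scheduler.config.sigma_data
+#     for i, t in enumerate(scheduler.timesteps[:-1]):
+#         model_output = model(sample / scheduler.config.sigma_data, t)
+#         sample, denoised = scheduler.step(model_output, i, t, sample, return_dict=False)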
+
+@dataclass
+# Adapted from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput
+class SCMSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `denoised` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    denoised: Optional[torch.FloatTensor] = None
+
+
+class SCMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    `SCMScheduler` implements the multistep consistency sampling procedure of sCM under the TrigFlow
+    parameterization, as used by Sana-Sprint.
+
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
+    generic methods the library implements for all schedulers such as loading and saving.
+
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        beta_start (`float`, defaults to 0.0001):
+            The starting `beta` value of inference.
+        beta_end (`float`, defaults to 0.02):
+            The final `beta` value.
+        beta_schedule (`str`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, *optional*):
+            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+        clip_sample (`bool`, defaults to `True`):
+            Clip the predicted sample for numerical stability.
+        clip_sample_range (`float`, defaults to 1.0):
+            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+        set_alpha_to_one (`bool`, defaults to `True`):
+            Each diffusion step uses the alphas product value at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the alpha value at step 0.
+        steps_offset (`int`, defaults to 0):
+            An offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+            Diffusion.
+        prediction_type (`str`, defaults to `trigflow`):
+            Prediction type of the scheduler function; currently only `trigflow` is supported, where the model
+            predicts the noise under the TrigFlow parameterization.
+        thresholding (`bool`, defaults to `False`):
+            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models
+            such as Stable Diffusion.
+        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+        sample_max_value (`float`, defaults to 1.0):
+            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, defaults to `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright
+            and dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+        max_timesteps (`float`, defaults to 1.57080):
+            The largest (initial) timestep of the sampling schedule, in TrigFlow time (close to pi/2).
+        intermediate_timesteps (`float`, *optional*, defaults to 1.3):
+            The intermediate timestep used when `num_inference_steps == 2`.
+        sigma_data (`float`, defaults to 0.5):
+            The standard deviation of the data distribution assumed by the TrigFlow parameterization.
+ """ + + # _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "trigflow", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + max_timesteps: float = 1.57080, + intermediate_timesteps: Optional[int] = 1.3, + sigma_data: float = 0.5, + ): + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def set_timesteps( + self, + num_inference_steps: int, + timesteps: torch.Tensor = None, + device: Union[str, torch.device] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + if timesteps is not None and len(timesteps) == num_inference_steps + 1: + if isinstance(timesteps, list): + self.timesteps = torch.tensor(timesteps, device=device).float() + elif isinstance(timesteps, torch.Tensor): + self.timesteps = timesteps.to(device).float() + else: + raise ValueError(f"Unsupported timesteps type: {type(timesteps)}") + elif self.config.intermediate_timesteps and num_inference_steps == 2: + self.timesteps = torch.tensor([self.config.max_timesteps, self.config.intermediate_timesteps, 0], device=device).float() + elif self.config.intermediate_timesteps: + self.timesteps = torch.linspace(self.config.max_timesteps, 0, num_inference_steps + 1, device=device).float() + warnings.warn( + f"Intermediate timesteps for SCM is not supported when num_inference_steps != 2. " + f"Reset timesteps to {self.timesteps} default max_timesteps" + ) + else: + # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here + self.timesteps = torch.linspace(self.config.max_timesteps, 0, num_inference_steps + 1, device=device).float() + + print(f"Set timesteps: {self.timesteps}") + + def step( + self, + model_output: torch.FloatTensor, + timeindex: int, + timestep: float, + sample: torch.FloatTensor, + generator: torch.Generator = None, + return_dict: bool = True, + ) -> Union[SCMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. 
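+            timeindex (`int`):
+                The index of the current timestep in the `self.timesteps` schedule; the step uses
+                `self.timesteps[timeindex]` as the current time and `self.timesteps[timeindex + 1]` as the next one.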
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            generator (`torch.Generator`, *optional*):
+                A random number generator for the noise injected during multi-step sampling.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`.
+        Returns:
+            [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_scm.SCMSchedulerOutput`] is returned, otherwise a
+                tuple is returned where the first element is the sample tensor.
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        # 2. compute the current (s) and next (t) timesteps from the precomputed schedule
+        t = self.timesteps[timeindex + 1]
+        s = self.timesteps[timeindex]
+
+        # 4. Different Parameterization:
+        parameterization = self.config.prediction_type
+
+        if parameterization == "trigflow":
+            pred_x0 = torch.cos(s) * sample - torch.sin(s) * model_output
+        else:
+            raise ValueError(f"Unsupported parameterization: {parameterization}")
+
+        # 5. Sample z ~ N(0, I), For MultiStep Inference
+        # Noise is not used for one-step sampling.
+        if len(self.timesteps) > 1:
+            noise = torch.randn(model_output.shape, device=model_output.device, generator=generator) * self.config.sigma_data
+            prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise
+        else:
+            prev_sample = pred_x0
+
+        if not return_dict:
+            return (prev_sample, pred_x0)
+
+        return SCMSchedulerOutput(prev_sample=prev_sample, denoised=pred_x0)
+
+    def __len__(self):
+        return self.config.num_train_timesteps
+

From 9714187c30913113132a41864e1fe3aa31f9fb22 Mon Sep 17 00:00:00 2001
From: Junsong Chen
Date: Mon, 17 Mar 2025 14:39:59 +0800
Subject: [PATCH 02/21] change name from SanaSCMPipeline to SanaSprintPipeline.
(#11076) --- src/diffusers/__init__.py | 4 ++-- src/diffusers/pipelines/__init__.py | 4 ++-- src/diffusers/pipelines/sana/__init__.py | 4 ++-- src/diffusers/pipelines/sana/pipeline_sana_scm.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e848d1efea42..8924ea7b0df7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -422,7 +422,7 @@ "ReduxImageEncoder", "SanaPAGPipeline", "SanaPipeline", - "SanaSCMPipeline", + "SanaSprintPipeline", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", @@ -968,7 +968,7 @@ ReduxImageEncoder, SanaPAGPipeline, SanaPipeline, - SanaSCMPipeline, + SanaSprintPipeline, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 4504e6049143..50eb4672839f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -280,7 +280,7 @@ _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] - _import_structure["sana"] = ["SanaPipeline", "SanaSCMPipeline"] + _import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline"] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] _import_structure["stable_audio"] = [ @@ -651,7 +651,7 @@ from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline - from .sana import SanaPipeline, SanaSCMPipeline + from .sana import SanaPipeline, SanaSprintPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py index 72f2402658ac..2d7dfde54f8b 100644 --- a/src/diffusers/pipelines/sana/__init__.py +++ b/src/diffusers/pipelines/sana/__init__.py @@ -23,7 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_sana"] = ["SanaPipeline"] - _import_structure["pipeline_sana_scm"] = ["SanaSCMPipeline"] + _import_structure["pipeline_sana_scm"] = ["SanaSprintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -34,7 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_sana import SanaPipeline - from .pipeline_sana_scm import SanaSCMPipeline + from .pipeline_sana_scm import SanaSprintPipeline else: import sys diff --git a/src/diffusers/pipelines/sana/pipeline_sana_scm.py b/src/diffusers/pipelines/sana/pipeline_sana_scm.py index cb5056598815..c45e8258204d 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_scm.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_scm.py @@ -186,9 +186,9 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class SanaSCMPipeline(DiffusionPipeline, SanaLoraLoaderMixin): +class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin): r""" - Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). + Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641). 
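+
+    A minimal usage sketch (the checkpoint id mirrors the converted weights referenced later in this PR, and
+    `num_inference_steps=2` matches the SCM scheduler's default two-step schedule; exact argument defaults may
+    differ):
+
+    ```py
+    >>> import torch
+    >>> from diffusers import SanaSprintPipeline
+
+    >>> pipe = SanaSprintPipeline.from_pretrained(
+    ...     "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
+    ... )
+    >>> pipe.to("cuda")
+    >>> image = pipe("a tiny astronaut hatching from an egg on the moon", num_inference_steps=2).images[0]
+    ```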
""" # fmt: off From ae4c3fda10f559387c15d8405411edecdd92160c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 19 Mar 2025 21:32:52 +0100 Subject: [PATCH 03/21] add conversion sript --- scripts/convert_sana_to_diffusers.py | 207 ++++++++++++++++++++------- 1 file changed, 158 insertions(+), 49 deletions(-) diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index 99a9ff322251..a8c61cccbb58 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -16,7 +16,9 @@ DPMSolverMultistepScheduler, FlowMatchEulerDiscreteScheduler, SanaPipeline, + SanaSprintPipeline, SanaTransformer2DModel, + SCMScheduler, ) from diffusers.models.modeling_utils import load_model_dict_into_meta from diffusers.utils.import_utils import is_accelerate_available @@ -72,15 +74,41 @@ def main(args): converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") - # AdaLN-single LN - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") + # Handle different time embedding structure based on model type + if args.model_type == "SanaSprint_1600M_P1_D20": + # For Sana Sprint, the time embedding structure is different + converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") + + # Guidance embedder for Sana Sprint + converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop( + "cfg_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias") + converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop( + "cfg_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias") + else: + # Original Sana time embedding structure + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop( + "t_embedder.mlp.0.bias" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop( + "t_embedder.mlp.2.bias" + ) # Shared norm. 
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight") @@ -96,7 +124,7 @@ def main(args): flow_shift = 3.0 # model config - if args.model_type == "SanaMS_1600M_P1_D20": + if args.model_type == "SanaMS_1600M_P1_D20" or args.model_type == "SanaSprint_1600M_P1_D20": layer_num = 20 elif args.model_type == "SanaMS_600M_P1_D28": layer_num = 28 @@ -125,6 +153,15 @@ def main(args): f"blocks.{depth}.attn.proj.bias" ) + # Add Q/K normalization for self-attention (attn1) - needed for Sana Sprint + if args.model_type == "SanaSprint_1600M_P1_D20": + converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop( + f"blocks.{depth}.attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop( + f"blocks.{depth}.attn.k_norm.weight" + ) + # Feed-forward. converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop( f"blocks.{depth}.mlp.inverted_conv.conv.weight" @@ -155,6 +192,15 @@ def main(args): converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias + # Add Q/K normalization for cross-attention (attn2) - needed for Sana Sprint + if args.model_type == "SanaSprint_1600M_P1_D20": + converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.k_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop( f"blocks.{depth}.cross_attn.proj.weight" ) @@ -169,24 +215,31 @@ def main(args): # Transformer with CTX(): - transformer = SanaTransformer2DModel( - in_channels=32, - out_channels=32, - num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"], - attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"], - num_layers=model_kwargs[args.model_type]["num_layers"], - num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"], - cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"], - cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"], - caption_channels=2304, - mlp_ratio=2.5, - attention_bias=False, - sample_size=args.image_size // 32, - patch_size=1, - norm_elementwise_affine=False, - norm_eps=1e-6, - interpolation_scale=interpolation_scale[args.image_size], - ) + transformer_kwargs = { + "in_channels": 32, + "out_channels": 32, + "num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"], + "attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"], + "num_layers": model_kwargs[args.model_type]["num_layers"], + "num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"], + "cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"], + "cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"], + "caption_channels": 2304, + "mlp_ratio": 2.5, + "attention_bias": False, + "sample_size": args.image_size // 32, + "patch_size": 1, + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "interpolation_scale": interpolation_scale[args.image_size], + } + + # Add qk_norm parameter for Sana Sprint + if args.model_type == "SanaSprint_1600M_P1_D20": + transformer_kwargs["qk_norm"] = "rms_norm_across_heads" + 
transformer_kwargs["guidance_embeds"] = True + + transformer = SanaTransformer2DModel(**transformer_kwargs) if is_accelerate_available(): load_model_dict_into_meta(transformer, converted_state_dict) @@ -196,6 +249,8 @@ def main(args): try: state_dict.pop("y_embedder.y_embedding") state_dict.pop("pos_embed") + state_dict.pop("logvar_linear.weight") + state_dict.pop("logvar_linear.bias") except KeyError: print("y_embedder.y_embedding or pos_embed not found in the state_dict") @@ -210,7 +265,7 @@ def main(args): print( colored( f"Only saving transformer model of {args.model_type}. " - f"Set --save_full_pipeline to save the whole SanaPipeline", + f"Set --save_full_pipeline to save the whole Pipeline", "green", attrs=["bold"], ) @@ -219,7 +274,7 @@ def main(args): os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant ) else: - print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"])) + print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) # VAE ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32) @@ -231,25 +286,64 @@ def main(args): text_encoder_model_path, torch_dtype=torch.bfloat16 ).get_decoder() - # Scheduler - if args.scheduler_type == "flow-dpm_solver": - scheduler = DPMSolverMultistepScheduler( - flow_shift=flow_shift, - use_flow_sigmas=True, - prediction_type="flow_prediction", + # Choose the appropriate pipeline and scheduler based on model type + if args.model_type == "SanaSprint_1600M_P1_D20": + # Force SCM Scheduler for Sana Sprint regardless of scheduler_type + if args.scheduler_type != "scm": + print( + colored( + f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model", + "yellow", + attrs=["bold"], + ) + ) + + # SCM Scheduler for Sana Sprint + scheduler_config = { + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "clip_sample": True, + "clip_sample_range": 1.0, + "dynamic_thresholding_ratio": 0.995, + "num_train_timesteps": 1000, + "prediction_type": "trigflow", + "rescale_betas_zero_snr": False, + "sample_max_value": 1.0, + "set_alpha_to_one": True, + "steps_offset": 0, + "thresholding": False, + "timestep_spacing": "leading", + } + scheduler = SCMScheduler(**scheduler_config) + pipe = SanaSprintPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=ae, + scheduler=scheduler, ) - elif args.scheduler_type == "flow-euler": - scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) else: - raise ValueError(f"Scheduler type {args.scheduler_type} is not supported") - - pipe = SanaPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - transformer=transformer, - vae=ae, - scheduler=scheduler, - ) + # Original Sana scheduler + if args.scheduler_type == "flow-dpm_solver": + scheduler = DPMSolverMultistepScheduler( + flow_shift=flow_shift, + use_flow_sigmas=True, + prediction_type="flow_prediction", + ) + elif args.scheduler_type == "flow-euler": + scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) + else: + raise ValueError(f"Scheduler type {args.scheduler_type} is not supported") + + pipe = SanaPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=ae, + scheduler=scheduler, + ) + pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant) @@ -281,10 +375,17 @@ def main(args): 
help="Image size of pretrained model, 512, 1024, 2048 or 4096.", ) parser.add_argument( - "--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"] + "--model_type", + default="SanaMS_1600M_P1_D20", + type=str, + choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaSprint_1600M_P1_D20"], ) parser.add_argument( - "--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"] + "--scheduler_type", + default="flow-dpm_solver", + type=str, + choices=["flow-dpm_solver", "flow-euler", "scm"], + help="Scheduler type to use. Use 'scm' for Sana Sprint models.", ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.") @@ -309,6 +410,14 @@ def main(args): "cross_attention_dim": 1152, "num_layers": 28, }, + "SanaSprint_1600M_P1_D20": { + "num_attention_heads": 70, + "attention_head_dim": 32, + "num_cross_attention_heads": 20, + "cross_attention_head_dim": 112, + "cross_attention_dim": 2240, + "num_layers": 20, + }, } device = "cuda" if torch.cuda.is_available() else "cpu" From 0d6309ae003d08114c1f553e58b34f24d0eabcd1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 19 Mar 2025 21:33:06 +0100 Subject: [PATCH 04/21] style --- src/diffusers/__init__.py | 4 +-- .../models/transformers/sana_transformer.py | 9 ++---- .../pipelines/sana/pipeline_sana_scm.py | 13 +++++--- src/diffusers/schedulers/__init__.py | 4 +-- src/diffusers/schedulers/scheduling_scm.py | 32 ++++++++++++------- 5 files changed, 35 insertions(+), 27 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8924ea7b0df7..f6bdf1b7506a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -271,12 +271,12 @@ "RePaintScheduler", "SASolverScheduler", "SchedulerMixin", + "SCMScheduler", "ScoreSdeVeScheduler", "TCDScheduler", "UnCLIPScheduler", "UniPCMultistepScheduler", "VQDiffusionScheduler", - "SCMScheduler", ] ) _import_structure["training_utils"] = ["EMAModel"] @@ -836,12 +836,12 @@ RePaintScheduler, SASolverScheduler, SchedulerMixin, + SCMScheduler, ScoreSdeVeScheduler, TCDScheduler, UnCLIPScheduler, UniPCMultistepScheduler, VQDiffusionScheduler, - SCMScheduler, ) from .training_utils import EMAModel diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index bbd95d0d0d2e..0ad02cffdd2b 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Tuple, Union import torch +import torch.nn.functional as F from torch import nn from ...configuration_utils import ConfigMixin, register_to_config @@ -23,16 +24,13 @@ from ..attention_processor import ( Attention, AttentionProcessor, - AttnProcessor2_0, SanaLinearAttnProcessor2_0, ) -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle, RMSNorm -from ..embeddings import TimestepEmbedding, Timesteps -import torch.nn.functional as F logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -127,7 +125,6 @@ def forward(self, timestep: torch.Tensor, 
guidance: torch.Tensor = None, hidden_ return self.linear(self.silu(conditioning)), conditioning - class SanaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). @@ -144,7 +141,6 @@ def __call__( encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -194,6 +190,7 @@ def __call__( return hidden_states + class SanaTransformerBlock(nn.Module): r""" Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629). diff --git a/src/diffusers/pipelines/sana/pipeline_sana_scm.py b/src/diffusers/pipelines/sana/pipeline_sana_scm.py index c45e8258204d..33e62c09bdf2 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_scm.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_scm.py @@ -936,14 +936,17 @@ def __call__( )[0] # YiYi TODO: refator this out - noise_pred = ((1 - 2 * scm_timestep) * latent_model_input + (1 - 2 * scm_timestep + 2 * scm_timestep**2) * noise_pred) / torch.sqrt( - scm_timestep**2 + (1 - scm_timestep) ** 2 - ) - # YiYi TODO: check if this can be refatored into scheduler + noise_pred = ( + (1 - 2 * scm_timestep) * latent_model_input + + (1 - 2 * scm_timestep + 2 * scm_timestep**2) * noise_pred + ) / torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2) + # YiYi TODO: check if this can be refatored into scheduler noise_pred = noise_pred.float() * self.scheduler.config.sigma_data # compute previous image: x_t -> x_t-1 - latents, denoised = self.scheduler.step(noise_pred, i, timestep, latents, **extra_step_kwargs, return_dict=False) + latents, denoised = self.scheduler.step( + noise_pred, i, timestep, latents, **extra_step_kwargs, return_dict=False + ) if callback_on_step_end is not None: callback_kwargs = {} diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 905ac9f95141..05cd21cd0034 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -68,13 +68,13 @@ _import_structure["scheduling_pndm"] = ["PNDMScheduler"] _import_structure["scheduling_repaint"] = ["RePaintScheduler"] _import_structure["scheduling_sasolver"] = ["SASolverScheduler"] + _import_structure["scheduling_scm"] = ["SCMScheduler"] _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"] _import_structure["scheduling_tcd"] = ["TCDScheduler"] _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"] _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"] - _import_structure["scheduling_scm"] = ["SCMScheduler"] try: if not is_flax_available(): @@ -169,13 +169,13 @@ from .scheduling_pndm import PNDMScheduler from .scheduling_repaint import RePaintScheduler from .scheduling_sasolver import SASolverScheduler + from .scheduling_scm import SCMScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_tcd import TCDScheduler from .scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler - from .scheduling_scm import SCMScheduler try: if not 
is_flax_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py index eb131b7cf658..02bc57a8393f 100644 --- a/src/diffusers/schedulers/scheduling_scm.py +++ b/src/diffusers/schedulers/scheduling_scm.py @@ -15,7 +15,6 @@ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion -import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -25,18 +24,17 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..schedulers.scheduling_utils import SchedulerMixin from ..utils import BaseOutput, logging -from ..utils.torch_utils import randn_tensor logger = logging.get_logger(__name__) # pylint: disable=invalid-name - @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class SCMSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. + Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the @@ -53,9 +51,9 @@ class SCMSchedulerOutput(BaseOutput): class SCMScheduler(SchedulerMixin, ConfigMixin): """ `SCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with - non-Markovian guidance. - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. + non-Markovian guidance. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass + documentation for the generic methods the library implements for all schedulers such as loading and saving. + Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. @@ -140,6 +138,7 @@ def set_timesteps( ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. @@ -161,16 +160,22 @@ def set_timesteps( else: raise ValueError(f"Unsupported timesteps type: {type(timesteps)}") elif self.config.intermediate_timesteps and num_inference_steps == 2: - self.timesteps = torch.tensor([self.config.max_timesteps, self.config.intermediate_timesteps, 0], device=device).float() + self.timesteps = torch.tensor( + [self.config.max_timesteps, self.config.intermediate_timesteps, 0], device=device + ).float() elif self.config.intermediate_timesteps: - self.timesteps = torch.linspace(self.config.max_timesteps, 0, num_inference_steps + 1, device=device).float() - warnings.warn( + self.timesteps = torch.linspace( + self.config.max_timesteps, 0, num_inference_steps + 1, device=device + ).float() + logger.warning( f"Intermediate timesteps for SCM is not supported when num_inference_steps != 2. 
" f"Reset timesteps to {self.timesteps} default max_timesteps" ) else: # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here - self.timesteps = torch.linspace(self.config.max_timesteps, 0, num_inference_steps + 1, device=device).float() + self.timesteps = torch.linspace( + self.config.max_timesteps, 0, num_inference_steps + 1, device=device + ).float() print(f"Set timesteps: {self.timesteps}") @@ -186,6 +191,7 @@ def step( """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). + Args: model_output (`torch.FloatTensor`): The direct output from learned diffusion model. @@ -222,7 +228,10 @@ def step( # 5. Sample z ~ N(0, I), For MultiStep Inference # Noise is not used for one-step sampling. if len(self.timesteps) > 1: - noise = torch.randn(model_output.shape, device=model_output.device, generator=generator) * self.config.sigma_data + noise = ( + torch.randn(model_output.shape, device=model_output.device, generator=generator) + * self.config.sigma_data + ) prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise else: prev_sample = pred_x0 @@ -234,4 +243,3 @@ def step( def __len__(self): return self.config.num_train_timesteps - From 5b19b22685ed2c6e1df8c21fd6bbad1904a8fe23 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 19 Mar 2025 21:37:57 +0100 Subject: [PATCH 05/21] copies --- src/diffusers/schedulers/scheduling_scm.py | 12 ++++++------ src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py index 02bc57a8393f..da27666babff 100644 --- a/src/diffusers/schedulers/scheduling_scm.py +++ b/src/diffusers/schedulers/scheduling_scm.py @@ -30,22 +30,22 @@ @dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->SCM class SCMSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. 
""" - prev_sample: torch.FloatTensor - denoised: Optional[torch.FloatTensor] = None + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None class SCMScheduler(SchedulerMixin, ConfigMixin): @@ -239,7 +239,7 @@ def step( if not return_dict: return (prev_sample, pred_x0) - return SCMSchedulerOutput(prev_sample=prev_sample, denoised=pred_x0) + return SCMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0) def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 31d2e1e2d78d..d62179951fbf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1834,6 +1834,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SCMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 841ffbdafa52..30aef8955092 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1502,6 +1502,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SanaSprintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 4eef82b2c952d4ad289f4e2c17ada5c3ed4f32a3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Mar 2025 02:44:06 +0100 Subject: [PATCH 06/21] pipeline_sana_scm -> pipeline_sana_sprint --- src/diffusers/pipelines/sana/__init__.py | 4 ++-- .../sana/{pipeline_sana_scm.py => pipeline_sana_sprint.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename src/diffusers/pipelines/sana/{pipeline_sana_scm.py => pipeline_sana_sprint.py} (100%) diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py index 2d7dfde54f8b..1393b37e2d3a 100644 --- a/src/diffusers/pipelines/sana/__init__.py +++ b/src/diffusers/pipelines/sana/__init__.py @@ -23,7 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_sana"] = ["SanaPipeline"] - _import_structure["pipeline_sana_scm"] = ["SanaSprintPipeline"] + _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -34,7 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_sana import SanaPipeline - from .pipeline_sana_scm import SanaSprintPipeline + from .pipeline_sana_sprint import SanaSprintPipeline else: import sys diff --git a/src/diffusers/pipelines/sana/pipeline_sana_scm.py 
b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py similarity index 100% rename from src/diffusers/pipelines/sana/pipeline_sana_scm.py rename to src/diffusers/pipelines/sana/pipeline_sana_sprint.py From 398ca0c938228e403c5f17e3a0dadd411e93cc97 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Mar 2025 02:55:00 +0100 Subject: [PATCH 07/21] remove unused __init__ arg for scm scheduler --- scripts/convert_sana_to_diffusers.py | 15 +--- src/diffusers/schedulers/scheduling_scm.py | 79 +++++++--------------- 2 files changed, 28 insertions(+), 66 deletions(-) diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index a8c61cccbb58..a8bc1a51c13a 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -300,20 +300,11 @@ def main(args): # SCM Scheduler for Sana Sprint scheduler_config = { - "beta_end": 0.02, - "beta_schedule": "linear", - "beta_start": 0.0001, - "clip_sample": True, - "clip_sample_range": 1.0, - "dynamic_thresholding_ratio": 0.995, "num_train_timesteps": 1000, "prediction_type": "trigflow", - "rescale_betas_zero_snr": False, - "sample_max_value": 1.0, - "set_alpha_to_one": True, - "steps_offset": 0, - "thresholding": False, - "timestep_spacing": "leading", + "max_timesteps": 1.57080, + "intermediate_timesteps": 1.3, + "sigma_data": 0.5, } scheduler = SCMScheduler(**scheduler_config) pipe = SanaSprintPipeline( diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py index da27666babff..8a824ae4f1fb 100644 --- a/src/diffusers/schedulers/scheduling_scm.py +++ b/src/diffusers/schedulers/scheduling_scm.py @@ -16,7 +16,7 @@ # and https://github.com/hojonathanho/diffusion from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -57,45 +57,14 @@ class SCMScheduler(SchedulerMixin, ConfigMixin): Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): - The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): - The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, *optional*): - Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - clip_sample (`bool`, defaults to `True`): - Clip the predicted sample for numerical stability. - clip_sample_range (`float`, defaults to 1.0): - The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. - set_alpha_to_one (`bool`, defaults to `True`): - Each diffusion step uses the alphas product value at that step and at the previous one. For the final step - there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the alpha value at step 0. - steps_offset (`int`, defaults to 0): - An offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable - Diffusion. 
- prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper). - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such - as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True`. - timestep_spacing (`str`, defaults to `"leading"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - rescale_betas_zero_snr (`bool`, defaults to `False`): - Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and - dark samples instead of limiting it to samples with medium brightness. Loosely related to - [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + prediction_type (`str`, defaults to `trigflow`): + Prediction type of the scheduler function. Currently only supports "trigflow". + max_timesteps (`float`, defaults to 1.57080): + The maximum timestep value used in the diffusion process. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used when num_inference_steps=2. + sigma_data (`float`, defaults to 0.5): + The standard deviation of the noise added during multi-step inference. """ # _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -105,24 +74,26 @@ class SCMScheduler(SchedulerMixin, ConfigMixin): def __init__( self, num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - clip_sample: bool = True, - set_alpha_to_one: bool = True, - steps_offset: int = 0, prediction_type: str = "trigflow", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - clip_sample_range: float = 1.0, - sample_max_value: float = 1.0, - timestep_spacing: str = "leading", - rescale_betas_zero_snr: bool = False, max_timesteps: float = 1.57080, - intermediate_timesteps: Optional[int] = 1.3, + intermediate_timesteps: Optional[float] = 1.3, sigma_data: float = 0.5, ): + """ + Initialize the SCM scheduler. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + prediction_type (`str`, defaults to `trigflow`): + Prediction type of the scheduler function. Currently only supports "trigflow". + max_timesteps (`float`, defaults to 1.57080): + The maximum timestep value used in the diffusion process. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used when num_inference_steps=2. + sigma_data (`float`, defaults to 0.5): + The standard deviation of the noise added during multi-step inference. 
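+
+        A small illustrative sketch (the printed schedule follows this file's `set_timesteps` logic for the
+        default `max_timesteps`/`intermediate_timesteps` values):
+
+        ```py
+        >>> from diffusers import SCMScheduler
+
+        >>> scheduler = SCMScheduler(prediction_type="trigflow", sigma_data=0.5)
+        >>> scheduler.set_timesteps(num_inference_steps=2)
+        >>> scheduler.timesteps
+        tensor([1.5708, 1.3000, 0.0000])
+        ```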
+        """

From 4e5a9efdc29828babc114d2da1f987fc7da3a61c Mon Sep 17 00:00:00 2001
From: Junsong Chen
Date: Thu, 20 Mar 2025 16:25:46 +0800
Subject: [PATCH 08/21] update conversion script for SANA-1.5 and SANA-Sprint
 (#11082)

* 1. update conversion script for sana1.5;
2. add conversion script for sana-sprint;

* separate sana and sana-sprint conversion scripts;

* update for upstream

* fix the } bug

* add a doc for SanaSprintPipeline;

* minor update;

* make style && make quality

---
 docs/source/en/api/pipelines/sana_sprint.md | 96 +++++++++++++++++++++
 scripts/convert_sana_to_diffusers.py        | 75 ++++++++++++----
 2 files changed, 155 insertions(+), 16 deletions(-)
 create mode 100644 docs/source/en/api/pipelines/sana_sprint.md

diff --git a/docs/source/en/api/pipelines/sana_sprint.md b/docs/source/en/api/pipelines/sana_sprint.md
new file mode 100644
index 000000000000..482ef3c2c99d
--- /dev/null
+++ b/docs/source/en/api/pipelines/sana_sprint.md
@@ -0,0 +1,96 @@
+
+
+# SanaSprintPipeline
+
+ LoRA +
+ +[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA and MIT HAN Lab, by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han + +The abstract from the paper is: + +*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/). + +Available models: + +| Model | Recommended dtype | +|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:| +| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` | +| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` | + +Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information. + +Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. 
Please refer to the inference example below to see how to load the model with the recommended dtype.
+
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on image quality depending on the model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = AutoModel.from_pretrained(
+    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+    subfolder="text_encoder",
+    quantization_config=quant_config,
+    torch_dtype=torch.bfloat16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = SanaTransformer2DModel.from_pretrained(
+    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+    subfolder="transformer",
+    quantization_config=quant_config,
+    torch_dtype=torch.bfloat16,
+)
+
+pipeline = SanaSprintPipeline.from_pretrained(
+    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+    text_encoder=text_encoder_8bit,
+    transformer=transformer_8bit,
+    torch_dtype=torch.bfloat16,
+    device_map="balanced",
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt).images[0]
+image.save("sana.png")
+```
+
+## SanaSprintPipeline
+
+[[autodoc]] SanaSprintPipeline
+  - all
+  - __call__
+
+
+## SanaPipelineOutput
+
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py
index a8bc1a51c13a..47e932ba5070 100644
--- a/scripts/convert_sana_to_diffusers.py
+++ b/scripts/convert_sana_to_diffusers.py
@@ -27,6 +27,7 @@
 CTX = init_empty_weights if is_accelerate_available else nullcontext
 
 ckpt_ids = [
+    "Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
     "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
     "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
     "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
@@ -75,7 +76,8 @@ def main(args):
     converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
 
     # Handle different time embedding structure based on model type
-    if args.model_type == "SanaSprint_1600M_P1_D20":
+
+    if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
         # For Sana Sprint, the time embedding structure is different
@@ -128,10 +130,18 @@
     layer_num = 20
 elif args.model_type == "SanaMS_600M_P1_D28":
     layer_num = 28
+elif args.model_type == "SanaMS_4800M_P1_D60":
+    layer_num = 60
 else:
     raise ValueError(f"{args.model_type} is not supported.")
 
 # Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0} + qk_norm = "rms_norm_across_heads" if args.model_type in [ + "SanaMS1.5_1600M_P1_D20", + "SanaMS1.5_4800M_P1_D60", + "SanaSprint_600M_P1_D28", + "SanaSprint_1600M_P1_D20" + ] else None for depth in range(layer_num): # Transformer blocks. @@ -145,6 +155,14 @@ def main(args): converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v + if qk_norm is not None: + # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5 + converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop( + f"blocks.{depth}.attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop( + f"blocks.{depth}.attn.k_norm.weight" + ) # Projection. converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop( f"blocks.{depth}.attn.proj.weight" @@ -191,6 +209,14 @@ def main(args): converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias + if qk_norm is not None: + # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5 + converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.k_norm.weight" + ) # Add Q/K normalization for cross-attention (attn2) - needed for Sana Sprint if args.model_type == "SanaSprint_1600M_P1_D20": @@ -235,8 +261,7 @@ def main(args): } # Add qk_norm parameter for Sana Sprint - if args.model_type == "SanaSprint_1600M_P1_D20": - transformer_kwargs["qk_norm"] = "rms_norm_across_heads" + if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]: transformer_kwargs["guidance_embeds"] = True transformer = SanaTransformer2DModel(**transformer_kwargs) @@ -271,15 +296,15 @@ def main(args): ) ) transformer.save_pretrained( - os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant + os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB" ) else: print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) # VAE - ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32) + ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32) # Text Encoder - text_encoder_model_path = "google/gemma-2-2b-it" + text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it" tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path) tokenizer.padding_side = "right" text_encoder = AutoModelForCausalLM.from_pretrained( @@ -287,7 +312,8 @@ def main(args): ).get_decoder() # Choose the appropriate pipeline and scheduler based on model type - if args.model_type == "SanaSprint_1600M_P1_D20": + if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]: + # Force SCM Scheduler for Sana Sprint regardless of scheduler_type if args.scheduler_type != "scm": print( @@ -335,7 +361,7 @@ def main(args): 
scheduler=scheduler, ) - pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant) + pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB") DTYPE_MAPPING = { @@ -344,12 +370,6 @@ def main(args): "bf16": torch.bfloat16, } -VARIANT_MAPPING = { - "fp32": None, - "fp16": "fp16", - "bf16": "bf16", -} - if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -369,7 +389,7 @@ def main(args): "--model_type", default="SanaMS_1600M_P1_D20", type=str, - choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaSprint_1600M_P1_D20"], + choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaMS_4800M_P1_D60", "SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"], ) parser.add_argument( "--scheduler_type", @@ -400,6 +420,30 @@ def main(args): "cross_attention_head_dim": 72, "cross_attention_dim": 1152, "num_layers": 28, + }, + "SanaMS1.5_1600M_P1_D20": { + "num_attention_heads": 70, + "attention_head_dim": 32, + "num_cross_attention_heads": 20, + "cross_attention_head_dim": 112, + "cross_attention_dim": 2240, + "num_layers": 20, + }, + "SanaMS1.5__4800M_P1_D60": { + "num_attention_heads": 70, + "attention_head_dim": 32, + "num_cross_attention_heads": 20, + "cross_attention_head_dim": 112, + "cross_attention_dim": 2240, + "num_layers": 60, + }, + "SanaSprint_600M_P1_D28": { + "num_attention_heads": 36, + "attention_head_dim": 32, + "num_cross_attention_heads": 16, + "cross_attention_head_dim": 72, + "cross_attention_dim": 1152, + "num_layers": 28, }, "SanaSprint_1600M_P1_D20": { "num_attention_heads": 70, @@ -413,6 +457,5 @@ def main(args): device = "cuda" if torch.cuda.is_available() else "cpu" weight_dtype = DTYPE_MAPPING[args.dtype] - variant = VARIANT_MAPPING[args.dtype] main(args) From 8070495df13f215986dbf3bd3d6a1d2850bfc3d9 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Mar 2025 11:40:13 +0100 Subject: [PATCH 09/21] up --- scripts/convert_sana_to_diffusers.py | 2 - .../models/transformers/sana_transformer.py | 1 + src/diffusers/pipelines/sana/pipeline_sana.py | 133 +++++---- .../pipelines/sana/pipeline_sana_sprint.py | 268 ++++++++---------- src/diffusers/schedulers/scheduling_scm.py | 104 +++++-- 5 files changed, 279 insertions(+), 229 deletions(-) diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index a8bc1a51c13a..53bba6f29c2a 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -302,8 +302,6 @@ def main(args): scheduler_config = { "num_train_timesteps": 1000, "prediction_type": "trigflow", - "max_timesteps": 1.57080, - "intermediate_timesteps": 1.3, "sigma_data": 0.5, } scheduler = SCMScheduler(**scheduler_config) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 0ad02cffdd2b..61b56be96651 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -359,6 +359,7 @@ def __init__( norm_eps: float = 1e-6, interpolation_scale: Optional[int] = None, guidance_embeds: bool = False, + guidance_embeds_scale: float = 0.1, qk_norm: Optional[str] = None, ) -> None: super().__init__() diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 460e7e2a237a..6474b13d0491 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -248,6 +248,65 @@ def disable_vae_tiling(self): """ 
self.vae.disable_tiling() + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( self, prompt: Union[str, List[str]], @@ -296,6 +355,13 @@ def encode_prompt( if device is None: device = self._execution_device + if self.transformer is not None: + dtype = self.transformer.dtype + elif self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): @@ -320,43 +386,18 @@ def encode_prompt( select_index = [0] + list(range(-max_length + 1, 0)) if prompt_embeds is None: - prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) - - # prepare complex human instruction - if not complex_human_instruction: - max_length_all = max_length - else: - chi_prompt = "\n".join(complex_human_instruction) - prompt = [chi_prompt + p for p in prompt] - num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) - max_length_all = num_chi_prompt_tokens + max_length - 2 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length_all, - truncation=True, - add_special_tokens=True, - return_tensors="pt", + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, ) - text_input_ids = text_inputs.input_ids - prompt_attention_mask = 
text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.to(device) - - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) - prompt_embeds = prompt_embeds[0][:, select_index] + prompt_embeds = prompt_embeds[:, select_index] prompt_attention_mask = prompt_attention_mask[:, select_index] - if self.transformer is not None: - dtype = self.transformer.dtype - elif self.text_encoder is not None: - dtype = self.text_encoder.dtype - else: - dtype = None - - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -366,25 +407,15 @@ def encode_prompt( # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt - uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - negative_prompt_attention_mask = uncond_input.attention_mask - negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, ) - negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 33e62c09bdf2..5323ef99ae99 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -248,17 +248,73 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. 
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( self, prompt: Union[str, List[str]], - do_classifier_free_guidance: bool = True, - negative_prompt: str = "", num_images_per_prompt: int = 1, device: Optional[torch.device] = None, prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, clean_caption: bool = False, max_sequence_length: int = 300, complex_human_instruction: Optional[List[str]] = None, @@ -270,12 +326,7 @@ def encode_prompt( Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` - instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For - PixArt-Alpha, this should be "". - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt device: (`torch.device`, *optional*): @@ -283,8 +334,6 @@ def encode_prompt( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. clean_caption (`bool`, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. 
@@ -296,6 +345,13 @@ def encode_prompt( if device is None: device = self._execution_device + if self.transformer is not None: + dtype = self.transformer.dtype + elif self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): @@ -320,43 +376,18 @@ def encode_prompt( select_index = [0] + list(range(-max_length + 1, 0)) if prompt_embeds is None: - prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) - - # prepare complex human instruction - if not complex_human_instruction: - max_length_all = max_length - else: - chi_prompt = "\n".join(complex_human_instruction) - prompt = [chi_prompt + p for p in prompt] - num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) - max_length_all = num_chi_prompt_tokens + max_length - 2 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length_all, - truncation=True, - add_special_tokens=True, - return_tensors="pt", + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, ) - text_input_ids = text_inputs.input_ids - - prompt_attention_mask = text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.to(device) - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) - prompt_embeds = prompt_embeds[0][:, select_index] + prompt_embeds = prompt_embeds[:, select_index] prompt_attention_mask = prompt_attention_mask[:, select_index] - if self.transformer is not None: - dtype = self.transformer.dtype - elif self.text_encoder is not None: - dtype = self.text_encoder.dtype - else: - dtype = None - - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -364,49 +395,13 @@ def encode_prompt( prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt - uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - negative_prompt_attention_mask = uncond_input.attention_mask - negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, 
device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) - else: - negative_prompt_embeds = None - negative_prompt_attention_mask = None if self.text_encoder is not None: if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) - return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + return prompt_embeds, prompt_attention_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -431,12 +426,13 @@ def check_inputs( prompt, height, width, + num_inference_steps, + timesteps, + max_timesteps, + intermediate_timesteps, callback_on_step_end_tensor_inputs=None, - negative_prompt=None, prompt_embeds=None, - negative_prompt_embeds=None, prompt_attention_mask=None, - negative_prompt_attention_mask=None, ): if height % 32 != 0 or width % 32 != 0: raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") @@ -460,37 +456,21 @@ def check_inputs( elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
- ) - if prompt_embeds is not None and prompt_attention_mask is None: raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") - if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: - raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + if timesteps is not None and len(timesteps) != num_inference_steps + 1: + raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") + + if timesteps is not None and max_timesteps is not None: + raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") + + if timesteps is None and max_timesteps is None: + raise ValueError("Should provide either `timesteps` or `max_timesteps`.") + + if intermediate_timesteps is not None and num_inference_steps != 2: + raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: - raise ValueError( - "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" - f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" - f" {negative_prompt_attention_mask.shape}." - ) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): @@ -632,6 +612,7 @@ def _clean_caption(self, caption): return caption.strip() + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): if latents is not None: return latents.to(device=device, dtype=dtype) @@ -659,10 +640,6 @@ def guidance_scale(self): def attention_kwargs(self): return self._attention_kwargs - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 - @property def num_timesteps(self): return self._num_timesteps @@ -676,10 +653,10 @@ def interrupt(self): def __call__( self, prompt: Union[str, List[str]] = None, - negative_prompt: str = "", - num_inference_steps: int = 20, + num_inference_steps: int = 2, timesteps: List[int] = None, - sigmas: List[float] = None, + max_timesteps: float = 1.57080, + intermediate_timesteps: float = 1.3, guidance_scale: float = 4.5, num_images_per_prompt: Optional[int] = 1, height: int = 1024, @@ -689,8 +666,6 @@ def __call__( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, clean_caption: bool = False, @@ -724,14 +699,14 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 20): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
+ max_timesteps (`float`, *optional*, defaults to 1.57080): + The maximum timestep value used in the SCM scheduler. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2). timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. guidance_scale (`float`, *optional*, defaults to 4.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -822,15 +797,16 @@ def __call__( height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) self.check_inputs( - prompt, - height, - width, - callback_on_step_end_tensor_inputs, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - prompt_attention_mask, - negative_prompt_attention_mask, + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, ) self._guidance_scale = guidance_scale @@ -852,29 +828,24 @@ def __call__( ( prompt_embeds, prompt_attention_mask, - _, - _, ) = self.encode_prompt( prompt, - False, - negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, device=device, prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, clean_caption=clean_caption, max_sequence_length=max_sequence_length, complex_human_instruction=complex_human_instruction, lora_scale=lora_scale, ) - # prompt_embeds = torch.load("/raid/yiyi/Sana-Sprint-diffusers/y.pt").to(device, dtype=prompt_embeds.dtype) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, device, timesteps, sigmas=None, max_timesteps=max_timesteps, intermediate_timesteps=intermediate_timesteps ) + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels @@ -889,14 +860,11 @@ def __call__( latents, ) - # latents = torch.load("/raid/yiyi/Sana-Sprint-diffusers/latents.pt").to(device, dtype=latents.dtype) - latents = latents * self.scheduler.config.sigma_data guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype) - # YiYi TODO: cfg_embed_scale = 0.1 (refactor this out) - guidance = guidance * 0.1 + guidance = guidance * self.transformer.config.guidance_embeds_scale # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -915,11 +883,8 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype) - - # YiYi TODO: self.scheduler.scale_model_input? latents_model_input = latents / self.scheduler.config.sigma_data - # YiYi TODO: refator this out scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep)) latent_model_input = latents_model_input * torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2) latent_model_input = latent_model_input.to(prompt_embeds.dtype) @@ -935,17 +900,15 @@ def __call__( attention_kwargs=self.attention_kwargs, )[0] - # YiYi TODO: refator this out noise_pred = ( (1 - 2 * scm_timestep) * latent_model_input + (1 - 2 * scm_timestep + 2 * scm_timestep**2) * noise_pred ) / torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2) - # YiYi TODO: check if this can be refatored into scheduler noise_pred = noise_pred.float() * self.scheduler.config.sigma_data # compute previous image: x_t -> x_t-1 latents, denoised = self.scheduler.step( - noise_pred, i, timestep, latents, **extra_step_kwargs, return_dict=False + noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False ) if callback_on_step_end is not None: @@ -965,7 +928,6 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - # YiYi TODO: refator this out latents = denoised / self.scheduler.config.sigma_data if output_type == "latent": image = latents diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py index 8a824ae4f1fb..59dfeec114b2 100644 --- a/src/diffusers/schedulers/scheduling_scm.py +++ b/src/diffusers/schedulers/scheduling_scm.py @@ -75,8 +75,6 @@ def __init__( self, num_train_timesteps: int = 1000, prediction_type: str = "trigflow", - max_timesteps: float = 1.57080, - intermediate_timesteps: Optional[float] = 1.3, sigma_data: float = 0.5, ): """ @@ -87,10 +85,6 @@ def __init__( The number of diffusion steps to train the model. prediction_type (`str`, defaults to `trigflow`): Prediction type of the scheduler function. Currently only supports "trigflow". - max_timesteps (`float`, defaults to 1.57080): - The maximum timestep value used in the diffusion process. - intermediate_timesteps (`float`, *optional*, defaults to 1.3): - The intermediate timestep value used when num_inference_steps=2. sigma_data (`float`, defaults to 0.5): The standard deviation of the noise added during multi-step inference. """ @@ -101,11 +95,35 @@ def __init__( self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self._step_index = None + self._begin_index = None + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. 
+ """ + self._begin_index = begin_index + def set_timesteps( self, num_inference_steps: int, timesteps: torch.Tensor = None, device: Union[str, torch.device] = None, + max_timesteps: float = 1.57080, + intermediate_timesteps: float = 1.3, ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -113,6 +131,12 @@ def set_timesteps( Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + timesteps (`torch.Tensor`, *optional*): + Custom timesteps to use for the denoising process. + max_timesteps (`float`, defaults to 1.57080): + The maximum timestep value used in the SCM scheduler. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2). """ if num_inference_steps > self.config.num_train_timesteps: raise ValueError( @@ -121,39 +145,68 @@ def set_timesteps( f" maximal {self.config.num_train_timesteps} timesteps." ) + if timesteps is not None and len(timesteps) != num_inference_steps + 1: + raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") + + if timesteps is not None and max_timesteps is not None: + raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") + + if timesteps is None and max_timesteps is None: + raise ValueError("Should provide either `timesteps` or `max_timesteps`.") + + if intermediate_timesteps is not None and num_inference_steps != 2: + raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") + self.num_inference_steps = num_inference_steps - if timesteps is not None and len(timesteps) == num_inference_steps + 1: + if timesteps is not None: if isinstance(timesteps, list): self.timesteps = torch.tensor(timesteps, device=device).float() elif isinstance(timesteps, torch.Tensor): self.timesteps = timesteps.to(device).float() else: raise ValueError(f"Unsupported timesteps type: {type(timesteps)}") - elif self.config.intermediate_timesteps and num_inference_steps == 2: + elif intermediate_timesteps is not None: self.timesteps = torch.tensor( - [self.config.max_timesteps, self.config.intermediate_timesteps, 0], device=device + [max_timesteps, intermediate_timesteps, 0], device=device ).float() - elif self.config.intermediate_timesteps: - self.timesteps = torch.linspace( - self.config.max_timesteps, 0, num_inference_steps + 1, device=device - ).float() - logger.warning( - f"Intermediate timesteps for SCM is not supported when num_inference_steps != 2. 
" - f"Reset timesteps to {self.timesteps} default max_timesteps" - ) else: # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here self.timesteps = torch.linspace( - self.config.max_timesteps, 0, num_inference_steps + 1, device=device + max_timesteps, 0, num_inference_steps + 1, device=device ).float() - print(f"Set timesteps: {self.timesteps}") + + self._step_index = None + self._begin_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() def step( self, model_output: torch.FloatTensor, - timeindex: int, timestep: float, sample: torch.FloatTensor, generator: torch.Generator = None, @@ -183,10 +236,13 @@ def step( raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - + + if self.step_index is None: + self._init_step_index(timestep) + # 2. compute alphas, betas - t = self.timesteps[timeindex + 1] - s = self.timesteps[timeindex] + t = self.timesteps[self.step_index + 1] + s = self.timesteps[self.step_index] # 4. 
Different Parameterization: parameterization = self.config.prediction_type @@ -206,6 +262,8 @@ def step( prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise else: prev_sample = pred_x0 + + self._step_index += 1 if not return_dict: return (prev_sample, pred_x0) From be73b5960cd9692cfe219ff499e7dd0073cbf8e7 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Mar 2025 11:44:45 +0100 Subject: [PATCH 10/21] up upp --- src/diffusers/pipelines/sana/pipeline_sana.py | 1 - .../pipelines/sana/pipeline_sana_sprint.py | 36 ++++++------------- src/diffusers/schedulers/scheduling_scm.py | 28 +++++++-------- 3 files changed, 23 insertions(+), 42 deletions(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 6474b13d0491..76934d055c56 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -306,7 +306,6 @@ def _get_gemma_prompt_embeds( return prompt_embeds, prompt_attention_mask - def encode_prompt( self, prompt: Union[str, List[str]], diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 5323ef99ae99..c231876bbf0e 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -196,7 +196,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin): # fmt: on model_cpu_offload_seq = "text_encoder->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, @@ -307,7 +307,6 @@ def _get_gemma_prompt_embeds( return prompt_embeds, prompt_attention_mask - def encode_prompt( self, prompt: Union[str, List[str]], @@ -361,13 +360,6 @@ def encode_prompt( if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - if getattr(self, "tokenizer", None) is not None: self.tokenizer.padding_side = "right" @@ -395,7 +387,6 @@ def encode_prompt( prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) - if self.text_encoder is not None: if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -461,17 +452,16 @@ def check_inputs( if timesteps is not None and len(timesteps) != num_inference_steps + 1: raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") - + if timesteps is not None and max_timesteps is not None: raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") - + if timesteps is None and max_timesteps is None: raise ValueError("Should provide either `timesteps` or `max_timesteps`.") - + if intermediate_timesteps is not None and num_inference_steps != 2: raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): if clean_caption and not is_bs4_available(): @@ -692,10 +682,6 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to 
guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). num_inference_steps (`int`, *optional*, defaults to 20): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -733,11 +719,6 @@ def __call__( Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not - provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_prompt_attention_mask (`torch.Tensor`, *optional*): - Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -842,7 +823,13 @@ def __call__( # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas=None, max_timesteps=max_timesteps, intermediate_timesteps=intermediate_timesteps + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=None, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, ) if hasattr(self.scheduler, "set_begin_index"): self.scheduler.set_begin_index(0) @@ -919,7 +906,6 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py index 59dfeec114b2..06212c9e7739 100644 --- a/src/diffusers/schedulers/scheduling_scm.py +++ b/src/diffusers/schedulers/scheduling_scm.py @@ -101,11 +101,11 @@ def __init__( @property def step_index(self): return self._step_index - + @property def begin_index(self): return self._begin_index - + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index def set_begin_index(self, begin_index: int = 0): """ @@ -147,16 +147,16 @@ def set_timesteps( if timesteps is not None and len(timesteps) != num_inference_steps + 1: raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") - + if timesteps is not None and max_timesteps is not None: raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") - + if timesteps is None and max_timesteps is None: raise ValueError("Should provide either `timesteps` or `max_timesteps`.") - + if intermediate_timesteps is not None and num_inference_steps != 2: raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") - + 
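        # A minimal sketch, assuming the default values above, of the three kinds of
        # schedule these checks admit (illustrative, not part of the diff):
        #   scheduler.set_timesteps(2)                               # -> [1.57080, 1.3, 0.0]
        #   scheduler.set_timesteps(4, intermediate_timesteps=None)  # -> torch.linspace(1.57080, 0, 5)
        #   scheduler.set_timesteps(
        #       2, timesteps=[1.5, 1.0, 0.0], max_timesteps=None, intermediate_timesteps=None
        #   )                                                        # -> [1.5, 1.0, 0.0]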
self.num_inference_steps = num_inference_steps if timesteps is not None: @@ -167,16 +167,12 @@ def set_timesteps( else: raise ValueError(f"Unsupported timesteps type: {type(timesteps)}") elif intermediate_timesteps is not None: - self.timesteps = torch.tensor( - [max_timesteps, intermediate_timesteps, 0], device=device - ).float() + self.timesteps = torch.tensor([max_timesteps, intermediate_timesteps, 0], device=device).float() else: # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here - self.timesteps = torch.linspace( - max_timesteps, 0, num_inference_steps + 1, device=device - ).float() + self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float() print(f"Set timesteps: {self.timesteps}") - + self._step_index = None self._begin_index = None @@ -236,10 +232,10 @@ def step( raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - + if self.step_index is None: self._init_step_index(timestep) - + # 2. compute alphas, betas t = self.timesteps[self.step_index + 1] s = self.timesteps[self.step_index] @@ -262,7 +258,7 @@ def step( prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise else: prev_sample = pred_x0 - + self._step_index += 1 if not return_dict: From 9cd5f1e66d7c3509d9b1eb80a2286ba5ae174687 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 20 Mar 2025 10:49:03 +0000 Subject: [PATCH 11/21] Apply style fixes --- scripts/convert_sana_to_diffusers.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index 535a7280d291..fca3842141ec 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -136,12 +136,12 @@ def main(args): raise ValueError(f"{args.model_type} is not supported.") # Positional embedding interpolation scale. interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0} - qk_norm = "rms_norm_across_heads" if args.model_type in [ - "SanaMS1.5_1600M_P1_D20", - "SanaMS1.5_4800M_P1_D60", - "SanaSprint_600M_P1_D28", - "SanaSprint_1600M_P1_D20" - ] else None + qk_norm = ( + "rms_norm_across_heads" + if args.model_type + in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"] + else None + ) for depth in range(layer_num): # Transformer blocks. 
@@ -313,7 +313,6 @@ def main(args): # Choose the appropriate pipeline and scheduler based on model type if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]: - # Force SCM Scheduler for Sana Sprint regardless of scheduler_type if args.scheduler_type != "scm": print( @@ -387,7 +386,13 @@ def main(args): "--model_type", default="SanaMS_1600M_P1_D20", type=str, - choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28", "SanaMS_4800M_P1_D60", "SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"], + choices=[ + "SanaMS_1600M_P1_D20", + "SanaMS_600M_P1_D28", + "SanaMS_4800M_P1_D60", + "SanaSprint_1600M_P1_D20", + "SanaSprint_600M_P1_D28", + ], ) parser.add_argument( "--scheduler_type", @@ -419,7 +424,7 @@ def main(args): "cross_attention_dim": 1152, "num_layers": 28, }, - "SanaMS1.5_1600M_P1_D20": { + "SanaMS1.5_1600M_P1_D20": { "num_attention_heads": 70, "attention_head_dim": 32, "num_cross_attention_heads": 20, From 8e4f71177e181723a38dbea84e1821126f8119e5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 20 Mar 2025 22:08:52 +0530 Subject: [PATCH 12/21] [docs] add a note about max_timesteps (#11122) add a note about max_timesteps --- docs/source/en/api/pipelines/sana_sprint.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/sana_sprint.md b/docs/source/en/api/pipelines/sana_sprint.md index 482ef3c2c99d..8db4576cf579 100644 --- a/docs/source/en/api/pipelines/sana_sprint.md +++ b/docs/source/en/api/pipelines/sana_sprint.md @@ -18,7 +18,7 @@ LoRA -[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA and MIT HAN Lab, by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han +[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han The abstract from the paper is: @@ -84,6 +84,10 @@ image = pipeline(prompt).images[0] image.save("sana.png") ``` +## Setting `max_timesteps` + +Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper. 
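+A minimal sketch of such a sweep, assuming the Sprint checkpoint from the example above (the prompt and the candidate values are illustrative placeholders):
+
+```py
+import torch
+from diffusers import SanaSprintPipeline
+
+pipeline = SanaSprintPipeline.from_pretrained(
+    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+prompt = "a tiny astronaut hatching from an egg on the moon"  # placeholder prompt
+
+# Sweep `max_timesteps` around its default (1.57080) and compare the outputs side by side.
+for max_timesteps in (1.50, 1.55, 1.57080):
+    image = pipeline(prompt, num_inference_steps=2, max_timesteps=max_timesteps).images[0]
+    image.save(f"sana_sprint_mt_{max_timesteps:.5f}.png")
+```
+
+Note that the default of 1.57080 is approximately π/2, slightly above the sCM paper's arctan(80/0.5) ≈ 1.56454 referenced in the scheduler source.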
+ ## SanaSprintPipeline [[autodoc]] SanaSprintPipeline From 1de087e16f19203ee0c6d5ac1aafeb580babd760 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Mar 2025 23:54:04 +0100 Subject: [PATCH 13/21] add to torctree --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d1805ff605d8..d39b5a52d2fe 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -496,6 +496,8 @@ title: PixArt-Σ - local: api/pipelines/sana title: Sana + - local: api/pipelines/sana_sprint + title: Sana Sprint - local: api/pipelines/self_attention_guidance title: Self-Attention Guidance - local: api/pipelines/semantic_stable_diffusion From 3734af8eac1eceff5842130660d82ed784a52880 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 20 Mar 2025 12:56:21 -1000 Subject: [PATCH 14/21] Apply suggestions from code review Co-authored-by: Aryan --- .../models/transformers/sana_transformer.py | 6 ------ .../pipelines/sana/pipeline_sana_sprint.py | 21 ------------------- src/diffusers/schedulers/scheduling_scm.py | 8 +------ 3 files changed, 1 insertion(+), 34 deletions(-) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 61b56be96651..f7c73231725d 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -97,12 +97,6 @@ def forward( class SanaCombinedTimestepGuidanceEmbeddings(nn.Module): - """ - For Sana. - - Reference: - """ - def __init__(self, embedding_dim): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index c231876bbf0e..88d66c55cfc3 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -394,24 +394,6 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - def check_inputs( self, prompt, @@ -853,9 +835,6 @@ def __call__( guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype) guidance = guidance * self.transformer.config.guidance_embeds_scale - # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # YiYi TODO: refactor this timesteps = timesteps[:-1] diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py index 06212c9e7739..727ad1a30646 100644 --- a/src/diffusers/schedulers/scheduling_scm.py +++ b/src/diffusers/schedulers/scheduling_scm.py @@ -59,10 +59,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin): The number of diffusion steps to train the model. prediction_type (`str`, defaults to `trigflow`): Prediction type of the scheduler function. Currently only supports "trigflow". - max_timesteps (`float`, defaults to 1.57080): - The maximum timestep value used in the diffusion process. - intermediate_timesteps (`float`, *optional*, defaults to 1.3): - The intermediate timestep value used when num_inference_steps=2. sigma_data (`float`, defaults to 0.5): The standard deviation of the noise added during multi-step inference. """ @@ -220,9 +216,7 @@ def step( sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. return_dict (`bool`, *optional*, defaults to `True`): - itself. Useful for methods such as [`CycleDiffusion`]. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. + Whether or not to return a [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_utils.SCMSchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_scm.SCMSchedulerOutput`] is returned, otherwise a From 8c07fccb6d87d1c14af60be20bfc1d2e04c9ac52 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 21 Mar 2025 00:05:45 +0100 Subject: [PATCH 15/21] up --- .../pipelines/sana/pipeline_sana_sprint.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 88d66c55cfc3..229282610fb8 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -394,6 +394,24 @@ def encode_prompt( return prompt_embeds, prompt_attention_mask + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + def check_inputs( self, prompt, @@ -835,10 +853,11 @@ def __call__( guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype) guidance = guidance * self.transformer.config.guidance_embeds_scale - # YiYi TODO: refactor this - timesteps = timesteps[:-1] + # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop + timesteps = timesteps[:-1] num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) From eae8ed71a2995327350e4f6883373218541d2803 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 21 Mar 2025 00:10:54 +0100 Subject: [PATCH 16/21] update docstring example --- src/diffusers/pipelines/sana/pipeline_sana_sprint.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 229282610fb8..4d7cc79fc4dc 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -111,14 +111,12 @@ Examples: ```py >>> import torch - >>> from diffusers import SanaPipeline + >>> from diffusers import SanaSprintPipeline - >>> pipe = SanaPipeline.from_pretrained( - ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32 + >>> pipe = SanaSprintPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16 ... ) >>> pipe.to("cuda") - >>> pipe.text_encoder.to(torch.bfloat16) - >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0] >>> image[0].save("output.png") From c4d049c054d5ec2ebd18cfc81fa7ac350e7a85a6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 21 Mar 2025 03:06:56 +0100 Subject: [PATCH 17/21] add tests --- .../pipelines/sana/pipeline_sana_sprint.py | 12 +- src/diffusers/schedulers/scheduling_scm.py | 3 +- tests/pipelines/sana/test_sana_sprint.py | 302 ++++++++++++++++++ 3 files changed, 312 insertions(+), 5 deletions(-) create mode 100644 tests/pipelines/sana/test_sana_sprint.py diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 4d7cc79fc4dc..5a0bc81619f7 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -869,7 +869,11 @@ def __call__( latents_model_input = latents / self.scheduler.config.sigma_data scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep)) - latent_model_input = latents_model_input * torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2) + + scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1) + latent_model_input = latents_model_input * torch.sqrt( + scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2 + ) latent_model_input = latent_model_input.to(prompt_embeds.dtype) # predict noise model_output @@ -884,9 +888,9 @@ def __call__( )[0] noise_pred = ( - (1 - 2 * scm_timestep) * latent_model_input - + (1 - 2 * scm_timestep + 2 * scm_timestep**2) * noise_pred - ) / torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2) + (1 - 2 * scm_timestep_expanded) * latent_model_input + + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred + ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2) noise_pred = noise_pred.float() * self.scheduler.config.sigma_data # compute previous image: x_t -> x_t-1 diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py index 727ad1a30646..23f47f42a302 100644 --- a/src/diffusers/schedulers/scheduling_scm.py +++ 
b/src/diffusers/schedulers/scheduling_scm.py @@ -24,6 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..schedulers.scheduling_utils import SchedulerMixin from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -246,7 +247,7 @@ def step( # Noise is not used for one-step sampling. if len(self.timesteps) > 1: noise = ( - torch.randn(model_output.shape, device=model_output.device, generator=generator) + randn_tensor(model_output.shape, device=model_output.device, generator=generator) * self.config.sigma_data ) prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise diff --git a/tests/pipelines/sana/test_sana_sprint.py b/tests/pipelines/sana/test_sana_sprint.py new file mode 100644 index 000000000000..d006c2b986ca --- /dev/null +++ b/tests/pipelines/sana/test_sana_sprint.py @@ -0,0 +1,302 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer + +from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SanaSprintPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt", "negative_prompt_embeds"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"} + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"} + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = SanaTransformer2DModel( + patch_size=1, + in_channels=4, + out_channels=4, + num_layers=1, + num_attention_heads=2, + attention_head_dim=4, + num_cross_attention_heads=2, + cross_attention_head_dim=4, + cross_attention_dim=8, + caption_channels=8, + sample_size=32, + qk_norm="rms_norm_across_heads", + guidance_embeds=True, + ) + + torch.manual_seed(0) + vae = AutoencoderDC( + in_channels=3, + latent_channels=4, + attention_head_dim=2, + encoder_block_types=( + "ResBlock", + "EfficientViTBlock", + ), + decoder_block_types=( + "ResBlock", + "EfficientViTBlock", + ), + encoder_block_out_channels=(8, 8), + decoder_block_out_channels=(8, 8), + encoder_qkv_multiscales=((), (5,)), + decoder_qkv_multiscales=((), (5,)), + 
encoder_layers_per_block=(1, 1), + decoder_layers_per_block=[1, 1], + downsample_block_type="conv", + upsample_block_type="interpolate", + decoder_norm_types="rms_norm", + decoder_act_fns="silu", + scaling_factor=0.41407, + ) + + torch.manual_seed(0) + scheduler = SCMScheduler() + + torch.manual_seed(0) + text_encoder_config = Gemma2Config( + head_dim=16, + hidden_size=8, + initializer_range=0.02, + intermediate_size=64, + max_position_embeddings=8192, + model_type="gemma2", + num_attention_heads=2, + num_hidden_layers=1, + num_key_value_heads=2, + vocab_size=8, + attn_implementation="eager", + ) + text_encoder = Gemma2Model(text_encoder_config) + tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + "complex_human_instruction": None, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs)[0] + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 32, 32)) + expected_image = torch.randn(3, 32, 32) + max_diff = np.abs(generated_image - expected_image).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + 
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." 
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)

From c3e107f0a9c4e9094240ffc7bd3219149c8e7fcc Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Fri, 21 Mar 2025 03:54:33 +0100
Subject: [PATCH 18/21] up

---
 scripts/convert_sana_to_diffusers.py | 25 +++++++------------------
 1 file changed, 7 insertions(+), 18 deletions(-)

diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py
index fca3842141ec..27b4c1b9d117 100644
--- a/scripts/convert_sana_to_diffusers.py
+++ b/scripts/convert_sana_to_diffusers.py
@@ -171,15 +171,6 @@ def main(args):
 f"blocks.{depth}.attn.proj.bias"
 )

- # Add Q/K normalization for self-attention (attn1) - needed for Sana Sprint
- if args.model_type == "SanaSprint_1600M_P1_D20":
- converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
- f"blocks.{depth}.attn.q_norm.weight"
- )
- converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
- f"blocks.{depth}.attn.k_norm.weight"
- )
-
 # Feed-forward.
 converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
 f"blocks.{depth}.mlp.inverted_conv.conv.weight"
@@ -218,15 +209,6 @@
 f"blocks.{depth}.cross_attn.k_norm.weight"
 )

- # Add Q/K normalization for cross-attention (attn2) - needed for Sana Sprint
- if args.model_type == "SanaSprint_1600M_P1_D20":
- converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
- f"blocks.{depth}.cross_attn.q_norm.weight"
- )
- converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
- f"blocks.{depth}.cross_attn.k_norm.weight"
- )
-
 converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
 f"blocks.{depth}.cross_attn.proj.weight"
 )
@@ -261,6 +243,13 @@
 }

 # Add qk_norm parameter for Sana Sprint
+ if args.model_type in [
+ "SanaMS1.5_1600M_P1_D20",
+ "SanaMS1.5_4800M_P1_D60",
+ "SanaSprint_600M_P1_D28",
+ "SanaSprint_1600M_P1_D20",
+ ]:
+ transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
 if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
 transformer_kwargs["guidance_embeds"] = True

From 94d87d518678c972723dd01e229354a1d4f1f279 Mon Sep 17 00:00:00 2001
From: YiYi Xu
Date: Thu, 20 Mar 2025 17:04:17 -1000
Subject: [PATCH 19/21] Apply suggestions from code review

---
 scripts/convert_sana_to_diffusers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py
index 27b4c1b9d117..3d7568388cc0 100644
--- a/scripts/convert_sana_to_diffusers.py
+++ b/scripts/convert_sana_to_diffusers.py
@@ -126,9 +126,9 @@ def main(args):
 flow_shift = 3.0

 # model config
- if args.model_type == "SanaMS_1600M_P1_D20" or args.model_type == "SanaSprint_1600M_P1_D20":
+ if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
 layer_num = 20
- elif args.model_type == "SanaMS_600M_P1_D28":
+ elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
 layer_num = 28
 elif args.model_type == "SanaMS_4800M_P1_D60":
 layer_num = 60

From a220997e11a99a81b5201dcc69d91e7bd8bfb1ea Mon Sep 17 00:00:00 2001
From: Junsong Chen
Date: Fri, 21 Mar 2025 23:57:29 +0800
Subject: [PATCH 20/21] [SANA-Sprint] remove unused multi-scale bin (#11131)
* change sample prompt;
* only 1024px is supported;
---
 .../pipelines/sana/pipeline_sana_sprint.py | 59 +------------------
 1 file changed, 3 insertions(+), 56 deletions(-)

diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index 5a0bc81619f7..7c17c2708216 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -40,11 +40,7 @@
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
-from ..pixart_alpha.pipeline_pixart_alpha import (
- ASPECT_RATIO_512_BIN,
- ASPECT_RATIO_1024_BIN,
-)
-from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
+from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
 from .pipeline_output import SanaPipelineOutput
@@ -64,49 +60,6 @@
 import ftfy

-ASPECT_RATIO_4096_BIN = {
- "0.25": [2048.0, 8192.0],
- "0.26": [2048.0, 7936.0],
- "0.27": [2048.0, 7680.0],
- "0.28": [2048.0, 7424.0],
- "0.32": [2304.0, 7168.0],
- "0.33": [2304.0, 6912.0],
- "0.35": [2304.0, 6656.0],
- "0.4": [2560.0, 6400.0],
- "0.42": [2560.0, 6144.0],
- "0.48": [2816.0, 5888.0],
- "0.5": [2816.0, 5632.0],
- "0.52": [2816.0, 5376.0],
- "0.57": [3072.0, 5376.0],
- "0.6": [3072.0, 5120.0],
- "0.68": [3328.0, 4864.0],
- "0.72": [3328.0, 4608.0],
- "0.78": [3584.0, 4608.0],
- "0.82": [3584.0, 4352.0],
- "0.88": [3840.0, 4352.0],
- "0.94": [3840.0, 4096.0],
- "1.0": [4096.0, 4096.0],
- "1.07": [4096.0, 3840.0],
- "1.13": [4352.0, 3840.0],
- "1.21": [4352.0, 3584.0],
- "1.29": [4608.0, 3584.0],
- "1.38": [4608.0, 3328.0],
- "1.46": [4864.0, 3328.0],
- "1.67": [5120.0, 3072.0],
- "1.75": [5376.0, 3072.0],
- "2.0": [5632.0, 2816.0],
- "2.09": [5888.0, 2816.0],
- "2.4": [6144.0, 2560.0],
- "2.5": [6400.0, 2560.0],
- "2.89": [6656.0, 2304.0],
- "3.0": [6912.0, 2304.0],
- "3.11": [7168.0, 2304.0],
- "3.62": [7424.0, 2048.0],
- "3.75": [7680.0, 2048.0],
- "3.88": [7936.0, 2048.0],
- "4.0": [8192.0, 2048.0],
-}
-
 EXAMPLE_DOC_STRING = """
 Examples:
 ```py
@@ -118,7 +71,7 @@
 ... )
 >>> pipe.to("cuda")

- >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
+ >>> image = pipe(prompt='a tiny astronaut hatching from an egg on the moon')[0]
 >>> image[0].save("output.png")
 ```
 """
@@ -762,14 +715,8 @@ def __call__(
 # 1. Check inputs. Raise error if not correct
 if use_resolution_binning:
- if self.transformer.config.sample_size == 128:
- aspect_ratio_bin = ASPECT_RATIO_4096_BIN
- elif self.transformer.config.sample_size == 64:
- aspect_ratio_bin = ASPECT_RATIO_2048_BIN
- elif self.transformer.config.sample_size == 32:
+ if self.transformer.config.sample_size == 32:
 aspect_ratio_bin = ASPECT_RATIO_1024_BIN
- elif self.transformer.config.sample_size == 16:
- aspect_ratio_bin = ASPECT_RATIO_512_BIN
 else:
 raise ValueError("Invalid sample size")
 orig_height, orig_width = height, width

From 7a0460455387f55bc48c6079c7a63bb356b45001 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Fri, 21 Mar 2025 17:01:04 +0100
Subject: [PATCH 21/21] style

---
 src/diffusers/pipelines/sana/pipeline_sana_sprint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index 7c17c2708216..9b3acbb1cb22 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -71,7 +71,7 @@
 ...
) >>> pipe.to("cuda") - >>> image = pipe(prompt='a tiny astronaut hatching from an egg on the moon')[0] + >>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0] >>> image[0].save("output.png") ``` """
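
Note on the sampler math in the `scheduling_scm.py` hunk above: `SCMScheduler.step` re-noises the model's x0 prediction with `prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise`, where the noise is drawn by `randn_tensor` and scaled by `config.sigma_data`, and the noise term is skipped entirely when only one inference step is requested. A minimal, self-contained sketch of that update on toy tensors follows; the latent shape, the timestep value, and `sigma_data = 0.5` are illustrative assumptions for this sketch, not values defined in the patch:

```py
import torch

# Assumed demo values: the shape, timestep, and sigma_data below are
# placeholders for illustration, not taken from this patch.
sigma_data = 0.5
pred_x0 = torch.randn(1, 4, 32, 32)  # stand-in for the model's denoised prediction
t = torch.tensor(0.7)                # stand-in for the next (earlier) timestep

# Multi-step sampling re-noises the prediction along the trigonometric schedule.
noise = torch.randn(pred_x0.shape) * sigma_data
prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise

# At t == 0 the update returns pred_x0 unchanged; as t approaches pi / 2 the
# sample is dominated by fresh noise, matching the cos/sin weighting above.
print(prev_sample.shape)  # torch.Size([1, 4, 32, 32])
```

This cheap re-noising between model calls is also why the fast tests above can run with `num_inference_steps=2`: the scheduler needs only a single noise draw between the two transformer evaluations.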