diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d1805ff605d8..d39b5a52d2fe 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -496,6 +496,8 @@ title: PixArt-Σ - local: api/pipelines/sana title: Sana + - local: api/pipelines/sana_sprint + title: Sana Sprint - local: api/pipelines/self_attention_guidance title: Self-Attention Guidance - local: api/pipelines/semantic_stable_diffusion diff --git a/docs/source/en/api/pipelines/sana_sprint.md b/docs/source/en/api/pipelines/sana_sprint.md new file mode 100644 index 000000000000..8db4576cf579 --- /dev/null +++ b/docs/source/en/api/pipelines/sana_sprint.md @@ -0,0 +1,100 @@ + + +# SanaSprintPipeline + +
+ LoRA +
+ +[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han + +The abstract from the paper is: + +*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/). + +Available models: + +| Model | Recommended dtype | +|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:| +| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` | +| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` | + +Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information. + +Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype. + + +## Quantization + +Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model. + +Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes. + +```py +import torch +from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline +from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel + +quant_config = BitsAndBytesConfig(load_in_8bit=True) +text_encoder_8bit = AutoModel.from_pretrained( + "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", + subfolder="text_encoder", + quantization_config=quant_config, + torch_dtype=torch.bfloat16, +) + +quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True) +transformer_8bit = SanaTransformer2DModel.from_pretrained( + "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", + subfolder="transformer", + quantization_config=quant_config, + torch_dtype=torch.bfloat16, +) + +pipeline = SanaSprintPipeline.from_pretrained( + "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", + text_encoder=text_encoder_8bit, + transformer=transformer_8bit, + torch_dtype=torch.bfloat16, + device_map="balanced", +) + +prompt = "a tiny astronaut hatching from an egg on the moon" +image = pipeline(prompt).images[0] +image.save("sana.png") +``` + +## Setting `max_timesteps` + +Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper. + +## SanaSprintPipeline + +[[autodoc]] SanaSprintPipeline + - all + - __call__ + + +## SanaPipelineOutput + +[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index 99a9ff322251..3d7568388cc0 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -16,7 +16,9 @@ DPMSolverMultistepScheduler, FlowMatchEulerDiscreteScheduler, SanaPipeline, + SanaSprintPipeline, SanaTransformer2DModel, + SCMScheduler, ) from diffusers.models.modeling_utils import load_model_dict_into_meta from diffusers.utils.import_utils import is_accelerate_available @@ -25,6 +27,7 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext ckpt_ids = [ + "Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth", "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth", "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth", "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth", @@ -72,15 +75,42 @@ def main(args): converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") - # AdaLN-single LN - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") + # Handle different time embedding structure based on model type + + if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]: + # For Sana Sprint, the time embedding structure is different + converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") + + # Guidance embedder for Sana Sprint + converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop( + "cfg_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias") + converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop( + "cfg_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias") + else: + # Original Sana time embedding structure + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop( + "t_embedder.mlp.0.bias" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop( + "t_embedder.mlp.2.bias" + ) # Shared norm. converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight") @@ -96,14 +126,22 @@ def main(args): flow_shift = 3.0 # model config - if args.model_type == "SanaMS_1600M_P1_D20": + if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]: layer_num = 20 - elif args.model_type == "SanaMS_600M_P1_D28": + elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]: layer_num = 28 + elif args.model_type == "SanaMS_4800M_P1_D60": + layer_num = 60 else: raise ValueError(f"{args.model_type} is not supported.") # Positional embedding interpolation scale. interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0} + qk_norm = ( + "rms_norm_across_heads" + if args.model_type + in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"] + else None + ) for depth in range(layer_num): # Transformer blocks. @@ -117,6 +155,14 @@ def main(args): converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v + if qk_norm is not None: + # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5 + converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop( + f"blocks.{depth}.attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop( + f"blocks.{depth}.attn.k_norm.weight" + ) # Projection. converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop( f"blocks.{depth}.attn.proj.weight" @@ -154,6 +200,14 @@ def main(args): converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias + if qk_norm is not None: + # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5 + converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.q_norm.weight" + ) + converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop( + f"blocks.{depth}.cross_attn.k_norm.weight" + ) converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop( f"blocks.{depth}.cross_attn.proj.weight" @@ -169,24 +223,37 @@ def main(args): # Transformer with CTX(): - transformer = SanaTransformer2DModel( - in_channels=32, - out_channels=32, - num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"], - attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"], - num_layers=model_kwargs[args.model_type]["num_layers"], - num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"], - cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"], - cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"], - caption_channels=2304, - mlp_ratio=2.5, - attention_bias=False, - sample_size=args.image_size // 32, - patch_size=1, - norm_elementwise_affine=False, - norm_eps=1e-6, - interpolation_scale=interpolation_scale[args.image_size], - ) + transformer_kwargs = { + "in_channels": 32, + "out_channels": 32, + "num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"], + "attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"], + "num_layers": model_kwargs[args.model_type]["num_layers"], + "num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"], + "cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"], + "cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"], + "caption_channels": 2304, + "mlp_ratio": 2.5, + "attention_bias": False, + "sample_size": args.image_size // 32, + "patch_size": 1, + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "interpolation_scale": interpolation_scale[args.image_size], + } + + # Add qk_norm parameter for Sana Sprint + if args.model_type in [ + "SanaMS1.5_1600M_P1_D20", + "SanaMS1.5_4800M_P1_D60", + "SanaSprint_600M_P1_D28", + "SanaSprint_1600M_P1_D20", + ]: + transformer_kwargs["qk_norm"] = "rms_norm_across_heads" + if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]: + transformer_kwargs["guidance_embeds"] = True + + transformer = SanaTransformer2DModel(**transformer_kwargs) if is_accelerate_available(): load_model_dict_into_meta(transformer, converted_state_dict) @@ -196,6 +263,8 @@ def main(args): try: state_dict.pop("y_embedder.y_embedding") state_dict.pop("pos_embed") + state_dict.pop("logvar_linear.weight") + state_dict.pop("logvar_linear.bias") except KeyError: print("y_embedder.y_embedding or pos_embed not found in the state_dict") @@ -210,47 +279,75 @@ def main(args): print( colored( f"Only saving transformer model of {args.model_type}. " - f"Set --save_full_pipeline to save the whole SanaPipeline", + f"Set --save_full_pipeline to save the whole Pipeline", "green", attrs=["bold"], ) ) transformer.save_pretrained( - os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant + os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB" ) else: - print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"])) + print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) # VAE - ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32) + ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32) # Text Encoder - text_encoder_model_path = "google/gemma-2-2b-it" + text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it" tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path) tokenizer.padding_side = "right" text_encoder = AutoModelForCausalLM.from_pretrained( text_encoder_model_path, torch_dtype=torch.bfloat16 ).get_decoder() - # Scheduler - if args.scheduler_type == "flow-dpm_solver": - scheduler = DPMSolverMultistepScheduler( - flow_shift=flow_shift, - use_flow_sigmas=True, - prediction_type="flow_prediction", + # Choose the appropriate pipeline and scheduler based on model type + if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]: + # Force SCM Scheduler for Sana Sprint regardless of scheduler_type + if args.scheduler_type != "scm": + print( + colored( + f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model", + "yellow", + attrs=["bold"], + ) + ) + + # SCM Scheduler for Sana Sprint + scheduler_config = { + "num_train_timesteps": 1000, + "prediction_type": "trigflow", + "sigma_data": 0.5, + } + scheduler = SCMScheduler(**scheduler_config) + pipe = SanaSprintPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=ae, + scheduler=scheduler, ) - elif args.scheduler_type == "flow-euler": - scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) else: - raise ValueError(f"Scheduler type {args.scheduler_type} is not supported") - - pipe = SanaPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - transformer=transformer, - vae=ae, - scheduler=scheduler, - ) - pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant) + # Original Sana scheduler + if args.scheduler_type == "flow-dpm_solver": + scheduler = DPMSolverMultistepScheduler( + flow_shift=flow_shift, + use_flow_sigmas=True, + prediction_type="flow_prediction", + ) + elif args.scheduler_type == "flow-euler": + scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) + else: + raise ValueError(f"Scheduler type {args.scheduler_type} is not supported") + + pipe = SanaPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=ae, + scheduler=scheduler, + ) + + pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB") DTYPE_MAPPING = { @@ -259,12 +356,6 @@ def main(args): "bf16": torch.bfloat16, } -VARIANT_MAPPING = { - "fp32": None, - "fp16": "fp16", - "bf16": "bf16", -} - if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -281,10 +372,23 @@ def main(args): help="Image size of pretrained model, 512, 1024, 2048 or 4096.", ) parser.add_argument( - "--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"] + "--model_type", + default="SanaMS_1600M_P1_D20", + type=str, + choices=[ + "SanaMS_1600M_P1_D20", + "SanaMS_600M_P1_D28", + "SanaMS_4800M_P1_D60", + "SanaSprint_1600M_P1_D20", + "SanaSprint_600M_P1_D28", + ], ) parser.add_argument( - "--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"] + "--scheduler_type", + default="flow-dpm_solver", + type=str, + choices=["flow-dpm_solver", "flow-euler", "scm"], + help="Scheduler type to use. Use 'scm' for Sana Sprint models.", ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.") @@ -309,10 +413,41 @@ def main(args): "cross_attention_dim": 1152, "num_layers": 28, }, + "SanaMS1.5_1600M_P1_D20": { + "num_attention_heads": 70, + "attention_head_dim": 32, + "num_cross_attention_heads": 20, + "cross_attention_head_dim": 112, + "cross_attention_dim": 2240, + "num_layers": 20, + }, + "SanaMS1.5__4800M_P1_D60": { + "num_attention_heads": 70, + "attention_head_dim": 32, + "num_cross_attention_heads": 20, + "cross_attention_head_dim": 112, + "cross_attention_dim": 2240, + "num_layers": 60, + }, + "SanaSprint_600M_P1_D28": { + "num_attention_heads": 36, + "attention_head_dim": 32, + "num_cross_attention_heads": 16, + "cross_attention_head_dim": 72, + "cross_attention_dim": 1152, + "num_layers": 28, + }, + "SanaSprint_1600M_P1_D20": { + "num_attention_heads": 70, + "attention_head_dim": 32, + "num_cross_attention_heads": 20, + "cross_attention_head_dim": 112, + "cross_attention_dim": 2240, + "num_layers": 20, + }, } device = "cuda" if torch.cuda.is_available() else "cpu" weight_dtype = DTYPE_MAPPING[args.dtype] - variant = VARIANT_MAPPING[args.dtype] main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 913816ec9a93..f6bdf1b7506a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -271,6 +271,7 @@ "RePaintScheduler", "SASolverScheduler", "SchedulerMixin", + "SCMScheduler", "ScoreSdeVeScheduler", "TCDScheduler", "UnCLIPScheduler", @@ -421,6 +422,7 @@ "ReduxImageEncoder", "SanaPAGPipeline", "SanaPipeline", + "SanaSprintPipeline", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", @@ -834,6 +836,7 @@ RePaintScheduler, SASolverScheduler, SchedulerMixin, + SCMScheduler, ScoreSdeVeScheduler, TCDScheduler, UnCLIPScheduler, @@ -965,6 +968,7 @@ ReduxImageEncoder, SanaPAGPipeline, SanaPipeline, + SanaSprintPipeline, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 21d17d6acdab..34276a544160 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -6020,6 +6020,11 @@ def __call__( key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + query = query.transpose(1, 2).unflatten(1, (attn.heads, -1)) key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3) value = value.transpose(1, 2).unflatten(1, (attn.heads, -1)) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index b8cc96d3532c..f7c73231725d 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Tuple, Union import torch +import torch.nn.functional as F from torch import nn from ...configuration_utils import ConfigMixin, register_to_config @@ -23,10 +24,9 @@ from ..attention_processor import ( Attention, AttentionProcessor, - AttnProcessor2_0, SanaLinearAttnProcessor2_0, ) -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle, RMSNorm @@ -96,6 +96,95 @@ def forward( return hidden_states +class SanaCombinedTimestepGuidanceEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + guidance_proj = self.guidance_condition_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype)) + conditioning = timesteps_emb + guidance_emb + + return self.linear(self.silu(conditioning)), conditioning + + +class SanaAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class SanaTransformerBlock(nn.Module): r""" Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629). @@ -115,6 +204,7 @@ def __init__( norm_eps: float = 1e-6, attention_out_bias: bool = True, mlp_ratio: float = 2.5, + qk_norm: Optional[str] = None, ) -> None: super().__init__() @@ -124,6 +214,8 @@ def __init__( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, + kv_heads=num_attention_heads if qk_norm is not None else None, + qk_norm=qk_norm, dropout=dropout, bias=attention_bias, cross_attention_dim=None, @@ -135,13 +227,15 @@ def __init__( self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.attn2 = Attention( query_dim=dim, + qk_norm=qk_norm, + kv_heads=num_cross_attention_heads if qk_norm is not None else None, cross_attention_dim=cross_attention_dim, heads=num_cross_attention_heads, dim_head=cross_attention_head_dim, dropout=dropout, bias=True, out_bias=attention_out_bias, - processor=AttnProcessor2_0(), + processor=SanaAttnProcessor2_0(), ) # 3. Feed-forward @@ -258,6 +352,9 @@ def __init__( norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, interpolation_scale: Optional[int] = None, + guidance_embeds: bool = False, + guidance_embeds_scale: float = 0.1, + qk_norm: Optional[str] = None, ) -> None: super().__init__() @@ -276,7 +373,10 @@ def __init__( ) # 2. Additional condition embeddings - self.time_embed = AdaLayerNormSingle(inner_dim) + if guidance_embeds: + self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim) + else: + self.time_embed = AdaLayerNormSingle(inner_dim) self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True) @@ -296,6 +396,7 @@ def __init__( norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, mlp_ratio=mlp_ratio, + qk_norm=qk_norm, ) for _ in range(num_layers) ] @@ -372,7 +473,8 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - timestep: torch.LongTensor, + timestep: torch.Tensor, + guidance: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -423,9 +525,14 @@ def forward( hidden_states = self.patch_embed(hidden_states) - timestep, embedded_timestep = self.time_embed( - timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype - ) + if guidance is not None: + timestep, embedded_timestep = self.time_embed( + timestep, guidance=guidance, hidden_dtype=hidden_states.dtype + ) + else: + timestep, embedded_timestep = self.time_embed( + timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 541d1a743bcb..50eb4672839f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -280,7 +280,7 @@ _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] - _import_structure["sana"] = ["SanaPipeline"] + _import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline"] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] _import_structure["stable_audio"] = [ @@ -651,7 +651,7 @@ from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline - from .sana import SanaPipeline + from .sana import SanaPipeline, SanaSprintPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py index 53b6ba762466..1393b37e2d3a 100644 --- a/src/diffusers/pipelines/sana/__init__.py +++ b/src/diffusers/pipelines/sana/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_sana"] = ["SanaPipeline"] + _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_sana import SanaPipeline + from .pipeline_sana_sprint import SanaSprintPipeline else: import sys diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 460e7e2a237a..76934d055c56 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -248,6 +248,64 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + def encode_prompt( self, prompt: Union[str, List[str]], @@ -296,6 +354,13 @@ def encode_prompt( if device is None: device = self._execution_device + if self.transformer is not None: + dtype = self.transformer.dtype + elif self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): @@ -320,43 +385,18 @@ def encode_prompt( select_index = [0] + list(range(-max_length + 1, 0)) if prompt_embeds is None: - prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) - - # prepare complex human instruction - if not complex_human_instruction: - max_length_all = max_length - else: - chi_prompt = "\n".join(complex_human_instruction) - prompt = [chi_prompt + p for p in prompt] - num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) - max_length_all = num_chi_prompt_tokens + max_length - 2 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length_all, - truncation=True, - add_special_tokens=True, - return_tensors="pt", + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, ) - text_input_ids = text_inputs.input_ids - prompt_attention_mask = text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.to(device) - - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) - prompt_embeds = prompt_embeds[0][:, select_index] + prompt_embeds = prompt_embeds[:, select_index] prompt_attention_mask = prompt_attention_mask[:, select_index] - if self.transformer is not None: - dtype = self.transformer.dtype - elif self.text_encoder is not None: - dtype = self.text_encoder.dtype - else: - dtype = None - - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -366,25 +406,15 @@ def encode_prompt( # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt - uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - negative_prompt_attention_mask = uncond_input.attention_mask - negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, ) - negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py new file mode 100644 index 000000000000..9b3acbb1cb22 --- /dev/null +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -0,0 +1,889 @@ +# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaSprintPipeline + + >>> pipe = SanaSprintPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.transformer is not None: + dtype = self.transformer.dtype + elif self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + num_inference_steps, + timesteps, + max_timesteps, + intermediate_timesteps, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if timesteps is not None and len(timesteps) != num_inference_steps + 1: + raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") + + if timesteps is not None and max_timesteps is not None: + raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") + + if timesteps is None and max_timesteps is None: + raise ValueError("Should provide either `timesteps` or `max_timesteps`.") + + if intermediate_timesteps is not None and num_inference_steps != 2: + raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 2, + timesteps: List[int] = None, + max_timesteps: float = 1.57080, + intermediate_timesteps: float = 1.3, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: List[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> Union[SanaPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + max_timesteps (`float`, *optional*, defaults to 1.57080): + The maximum timestep value used in the SCM scheduler. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2). + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`List[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + ) = self.encode_prompt( + prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=None, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + ) + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + latents = latents * self.scheduler.config.sigma_data + + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype) + guidance = guidance * self.transformer.config.guidance_embeds_scale + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + timesteps = timesteps[:-1] + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype) + latents_model_input = latents / self.scheduler.config.sigma_data + + scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep)) + + scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1) + latent_model_input = latents_model_input * torch.sqrt( + scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2 + ) + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + guidance=guidance, + timestep=scm_timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + + noise_pred = ( + (1 - 2 * scm_timestep_expanded) * latent_model_input + + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred + ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2) + noise_pred = noise_pred.float() * self.scheduler.config.sigma_data + + # compute previous image: x_t -> x_t-1 + latents, denoised = self.scheduler.step( + noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = denoised / self.scheduler.config.sigma_data + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except torch.cuda.OutOfMemoryError as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index bb9088538653..05cd21cd0034 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -68,6 +68,7 @@ _import_structure["scheduling_pndm"] = ["PNDMScheduler"] _import_structure["scheduling_repaint"] = ["RePaintScheduler"] _import_structure["scheduling_sasolver"] = ["SASolverScheduler"] + _import_structure["scheduling_scm"] = ["SCMScheduler"] _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"] _import_structure["scheduling_tcd"] = ["TCDScheduler"] _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] @@ -168,13 +169,13 @@ from .scheduling_pndm import PNDMScheduler from .scheduling_repaint import RePaintScheduler from .scheduling_sasolver import SASolverScheduler + from .scheduling_scm import SCMScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_tcd import TCDScheduler from .scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler - try: if not is_flax_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py new file mode 100644 index 000000000000..23f47f42a302 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_scm.py @@ -0,0 +1,265 @@ +# # Copyright 2024 Sana-Sprint Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..schedulers.scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->SCM +class SCMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +class SCMScheduler(SchedulerMixin, ConfigMixin): + """ + `SCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass + documentation for the generic methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + prediction_type (`str`, defaults to `trigflow`): + Prediction type of the scheduler function. Currently only supports "trigflow". + sigma_data (`float`, defaults to 0.5): + The standard deviation of the noise added during multi-step inference. + """ + + # _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + prediction_type: str = "trigflow", + sigma_data: float = 0.5, + ): + """ + Initialize the SCM scheduler. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + prediction_type (`str`, defaults to `trigflow`): + Prediction type of the scheduler function. Currently only supports "trigflow". + sigma_data (`float`, defaults to 0.5): + The standard deviation of the noise added during multi-step inference. + """ + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + self._step_index = None + self._begin_index = None + + @property + def step_index(self): + return self._step_index + + @property + def begin_index(self): + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int, + timesteps: torch.Tensor = None, + device: Union[str, torch.device] = None, + max_timesteps: float = 1.57080, + intermediate_timesteps: float = 1.3, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + timesteps (`torch.Tensor`, *optional*): + Custom timesteps to use for the denoising process. + max_timesteps (`float`, defaults to 1.57080): + The maximum timestep value used in the SCM scheduler. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2). + """ + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + if timesteps is not None and len(timesteps) != num_inference_steps + 1: + raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") + + if timesteps is not None and max_timesteps is not None: + raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") + + if timesteps is None and max_timesteps is None: + raise ValueError("Should provide either `timesteps` or `max_timesteps`.") + + if intermediate_timesteps is not None and num_inference_steps != 2: + raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") + + self.num_inference_steps = num_inference_steps + + if timesteps is not None: + if isinstance(timesteps, list): + self.timesteps = torch.tensor(timesteps, device=device).float() + elif isinstance(timesteps, torch.Tensor): + self.timesteps = timesteps.to(device).float() + else: + raise ValueError(f"Unsupported timesteps type: {type(timesteps)}") + elif intermediate_timesteps is not None: + self.timesteps = torch.tensor([max_timesteps, intermediate_timesteps, 0], device=device).float() + else: + # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here + self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float() + print(f"Set timesteps: {self.timesteps}") + + self._step_index = None + self._begin_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def step( + self, + model_output: torch.FloatTensor, + timestep: float, + sample: torch.FloatTensor, + generator: torch.Generator = None, + return_dict: bool = True, + ) -> Union[SCMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_scm.SCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # 2. compute alphas, betas + t = self.timesteps[self.step_index + 1] + s = self.timesteps[self.step_index] + + # 4. Different Parameterization: + parameterization = self.config.prediction_type + + if parameterization == "trigflow": + pred_x0 = torch.cos(s) * sample - torch.sin(s) * model_output + else: + raise ValueError(f"Unsupported parameterization: {parameterization}") + + # 5. Sample z ~ N(0, I), For MultiStep Inference + # Noise is not used for one-step sampling. + if len(self.timesteps) > 1: + noise = ( + randn_tensor(model_output.shape, device=model_output.device, generator=generator) + * self.config.sigma_data + ) + prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise + else: + prev_sample = pred_x0 + + self._step_index += 1 + + if not return_dict: + return (prev_sample, pred_x0) + + return SCMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 31d2e1e2d78d..d62179951fbf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1834,6 +1834,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SCMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 841ffbdafa52..30aef8955092 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1502,6 +1502,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SanaSprintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/sana/test_sana_sprint.py b/tests/pipelines/sana/test_sana_sprint.py new file mode 100644 index 000000000000..d006c2b986ca --- /dev/null +++ b/tests/pipelines/sana/test_sana_sprint.py @@ -0,0 +1,302 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer + +from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SanaSprintPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt", "negative_prompt_embeds"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"} + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"} + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = SanaTransformer2DModel( + patch_size=1, + in_channels=4, + out_channels=4, + num_layers=1, + num_attention_heads=2, + attention_head_dim=4, + num_cross_attention_heads=2, + cross_attention_head_dim=4, + cross_attention_dim=8, + caption_channels=8, + sample_size=32, + qk_norm="rms_norm_across_heads", + guidance_embeds=True, + ) + + torch.manual_seed(0) + vae = AutoencoderDC( + in_channels=3, + latent_channels=4, + attention_head_dim=2, + encoder_block_types=( + "ResBlock", + "EfficientViTBlock", + ), + decoder_block_types=( + "ResBlock", + "EfficientViTBlock", + ), + encoder_block_out_channels=(8, 8), + decoder_block_out_channels=(8, 8), + encoder_qkv_multiscales=((), (5,)), + decoder_qkv_multiscales=((), (5,)), + encoder_layers_per_block=(1, 1), + decoder_layers_per_block=[1, 1], + downsample_block_type="conv", + upsample_block_type="interpolate", + decoder_norm_types="rms_norm", + decoder_act_fns="silu", + scaling_factor=0.41407, + ) + + torch.manual_seed(0) + scheduler = SCMScheduler() + + torch.manual_seed(0) + text_encoder_config = Gemma2Config( + head_dim=16, + hidden_size=8, + initializer_range=0.02, + intermediate_size=64, + max_position_embeddings=8192, + model_type="gemma2", + num_attention_heads=2, + num_hidden_layers=1, + num_key_value_heads=2, + vocab_size=8, + attn_implementation="eager", + ) + text_encoder = Gemma2Model(text_encoder_config) + tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + "complex_human_instruction": None, + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs)[0] + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 32, 32)) + expected_image = torch.randn(3, 32, 32) + max_diff = np.abs(generated_image - expected_image).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_single_identical(self): + pass + + def test_float16_inference(self): + # Requires higher tolerance as model seems very sensitive to dtype + super().test_float16_inference(expected_max_diff=0.08)