From a7941c581da72815212c925d8466399f742a1714 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Wed, 14 Aug 2024 17:31:15 +0800 Subject: [PATCH 01/19] draft of embedding --- src/diffusers/models/embeddings.py | 281 +++++++++++------- .../transformers/cogvideox_transformer_3d.py | 18 +- 2 files changed, 187 insertions(+), 112 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1258964385da..fad2e89a9566 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -17,6 +17,7 @@ import numpy as np import torch import torch.nn.functional as F +from einops import rearrange from torch import nn from ..utils import deprecate @@ -25,12 +26,12 @@ def get_timestep_embedding( - timesteps: torch.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, ): """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. @@ -77,13 +78,75 @@ def get_timestep_embedding( emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb - -def get_3d_sincos_pos_embed( +def get_3d_rotary_pos_embed( embed_dim: int, spatial_size: Union[int, Tuple[int, int]], temporal_size: int, - spatial_interpolation_scale: float = 1.0, - temporal_interpolation_scale: float = 1.0, + hidden_size_head: int, + theta=10000, + spatial_interpolation_scale=1.0, + temporal_interpolation_scale=1.0, +): + height, width = spatial_size + + # Compute dimensions for each axis + dim_t = hidden_size_head // 4 + dim_h = hidden_size_head // 8 * 3 + dim_w = hidden_size_head // 8 * 3 + + # Temporal frequencies + freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t)) + grid_t = torch.arange(temporal_size, dtype=torch.float32) / temporal_interpolation_scale + freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t) + freqs_t = freqs_t.repeat_interleave(2, dim=-1) + + # Spatial frequencies for height and width + freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h)) + freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w)) + grid_h = torch.arange(height, dtype=torch.float32) / spatial_interpolation_scale + grid_w = torch.arange(width, dtype=torch.float32) / spatial_interpolation_scale + freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h) + freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w) + freqs_h = freqs_h.repeat_interleave(2, dim=-1) + freqs_w = freqs_w.repeat_interleave(2, dim=-1) + + # Broadcast and concatenate tensors along specified dimension + def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatenation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + + # Flatten and reorder dimensions manually without rearrange + t, h, w, d = freqs.shape + freqs = freqs.view(t * h * w, d) + + # Generate sine and cosine components + freqs_sin = freqs.sin() + freqs_cos = freqs.cos() + + return freqs_sin, freqs_cos + +def get_3d_sincos_pos_embed( + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, ) -> np.ndarray: r""" Args: @@ -126,7 +189,7 @@ def get_3d_sincos_pos_embed( def get_2d_sincos_pos_embed( - embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 ): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or @@ -168,7 +231,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) + omega = 1.0 / 10000 ** omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product @@ -184,18 +247,18 @@ class PatchEmbed(nn.Module): """2D Image to Patch Embedding with support for SD3 cropping.""" def __init__( - self, - height=224, - width=224, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - interpolation_scale=1, - pos_embed_type="sincos", - pos_embed_max_size=None, # For SD3 cropping + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + pos_embed_type="sincos", + pos_embed_max_size=None, # For SD3 cropping ): super().__init__() @@ -221,7 +284,7 @@ def __init__( if pos_embed_max_size: grid_size = pos_embed_max_size else: - grid_size = int(num_patches**0.5) + grid_size = int(num_patches ** 0.5) if pos_embed_type is None: self.pos_embed = None @@ -253,7 +316,7 @@ def cropped_pos_embed(self, height, width): top = (self.pos_embed_max_size - height) // 2 left = (self.pos_embed_max_size - width) // 2 spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) - spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] + spatial_pos_embed = spatial_pos_embed[:, top: top + height, left: left + width, :] spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) return spatial_pos_embed @@ -336,12 +399,12 @@ def forward(self, x, freqs_cis): class CogVideoXPatchEmbed(nn.Module): def __init__( - self, - patch_size: int = 2, - in_channels: int = 16, - embed_dim: int = 1920, - text_embed_dim: int = 4096, - bias: bool = True, + self, + patch_size: int = 2, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + bias: bool = True, ) -> None: super().__init__() self.patch_size = patch_size @@ -439,13 +502,13 @@ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, n def get_1d_rotary_pos_embed( - dim: int, - pos: Union[np.ndarray, int], - theta: float = 10000.0, - use_real=False, - linear_factor=1.0, - ntk_factor=1.0, - repeat_interleave_real=True, + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -493,10 +556,10 @@ def get_1d_rotary_pos_embed( def apply_rotary_emb( - x: torch.Tensor, - freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], - use_real: bool = True, - use_real_unbind_dim: int = -1, + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings @@ -542,14 +605,14 @@ def apply_rotary_emb( class TimestepEmbedding(nn.Module): def __init__( - self, - in_channels: int, - time_embed_dim: int, - act_fn: str = "silu", - out_dim: int = None, - post_act_fn: Optional[str] = None, - cond_proj_dim=None, - sample_proj_bias=True, + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, ): super().__init__() @@ -611,7 +674,7 @@ class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" def __init__( - self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False ): super().__init__() self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) @@ -690,11 +753,11 @@ class ImagePositionalEmbeddings(nn.Module): """ def __init__( - self, - num_embed: int, - height: int, - width: int, - embed_dim: int, + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, ): super().__init__() @@ -768,11 +831,11 @@ def forward(self, labels: torch.LongTensor, force_drop_ids=None): class TextImageProjection(nn.Module): def __init__( - self, - text_embed_dim: int = 1024, - image_embed_dim: int = 768, - cross_attention_dim: int = 768, - num_image_text_embeds: int = 10, + self, + text_embed_dim: int = 1024, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 10, ): super().__init__() @@ -795,10 +858,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): class ImageProjection(nn.Module): def __init__( - self, - image_embed_dim: int = 768, - cross_attention_dim: int = 768, - num_image_text_embeds: int = 32, + self, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 32, ): super().__init__() @@ -911,7 +974,7 @@ class HunyuanDiTAttentionPool(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) @@ -948,12 +1011,12 @@ def forward(self, x): class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): def __init__( - self, - embedding_dim, - pooled_projection_dim=1024, - seq_len=256, - cross_attention_dim=2048, - use_style_cond_and_image_meta_size=True, + self, + embedding_dim, + pooled_projection_dim=1024, + seq_len=256, + cross_attention_dim=2048, + use_style_cond_and_image_meta_size=True, ): super().__init__() @@ -1125,7 +1188,7 @@ class AttentionPooling(nn.Module): def __init__(self, num_heads, embed_dim, dtype=None): super().__init__() self.dtype = dtype - self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim ** 0.5) self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) @@ -1233,14 +1296,14 @@ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freq self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) def forward( - self, - boxes, - masks, - positive_embeddings=None, - phrases_masks=None, - image_masks=None, - phrases_embeddings=None, - image_embeddings=None, + self, + boxes, + masks, + positive_embeddings=None, + phrases_masks=None, + image_masks=None, + phrases_embeddings=None, + image_embeddings=None, ): masks = masks.unsqueeze(-1) @@ -1351,11 +1414,11 @@ def forward(self, caption): class IPAdapterPlusImageProjectionBlock(nn.Module): def __init__( - self, - embed_dims: int = 768, - dim_head: int = 64, - heads: int = 16, - ffn_ratio: float = 4, + self, + embed_dims: int = 768, + dim_head: int = 64, + heads: int = 16, + ffn_ratio: float = 4, ) -> None: super().__init__() from .attention import FeedForward @@ -1399,18 +1462,18 @@ class IPAdapterPlusImageProjection(nn.Module): """ def __init__( - self, - embed_dims: int = 768, - output_dims: int = 1024, - hidden_dims: int = 1280, - depth: int = 4, - dim_head: int = 64, - heads: int = 16, - num_queries: int = 8, - ffn_ratio: float = 4, + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, ) -> None: super().__init__() - self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims ** 0.5) self.proj_in = nn.Linear(embed_dims, hidden_dims) @@ -1459,18 +1522,18 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module): """ def __init__( - self, - embed_dims: int = 768, - output_dims: int = 768, - hidden_dims: int = 1280, - id_embeddings_dim: int = 512, - depth: int = 4, - dim_head: int = 64, - heads: int = 16, - num_tokens: int = 4, - num_queries: int = 8, - ffn_ratio: float = 4, - ffproj_ratio: int = 2, + self, + embed_dims: int = 768, + output_dims: int = 768, + hidden_dims: int = 1280, + id_embeddings_dim: int = 512, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_tokens: int = 4, + num_queries: int = 8, + ffn_ratio: float = 4, + ffproj_ratio: int = 2, ) -> None: super().__init__() from .attention import FeedForward diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 1030b0df04ff..65ef3b88a532 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -22,7 +22,8 @@ from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, \ + get_3d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -208,7 +209,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - num_attention_heads: int = 30, + num_attention_heads: int = 48, # CogVideoX-2B is 30 attention_head_dim: int = 64, in_channels: int = 16, out_channels: Optional[int] = 16, @@ -216,7 +217,7 @@ def __init__( freq_shift: int = 0, time_embed_dim: int = 512, text_embed_dim: int = 4096, - num_layers: int = 30, + num_layers: int = 42, # CogVideoX-2B is 30 dropout: float = 0.0, attention_bias: bool = True, sample_width: int = 90, @@ -252,6 +253,17 @@ def __init__( spatial_interpolation_scale, temporal_interpolation_scale, ) + + # spatial_pos_embedding = get_3d_rotary_pos_embed( + # inner_dim, + # (post_patch_width, post_patch_height), + # post_time_compression_frames, + # attention_head_dim, + # 10000, + # spatial_interpolation_scale, + # temporal_interpolation_scale, + # ) + spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) From 13c9f044f457f839b399c047f7b6a065dd81eb96 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Fri, 16 Aug 2024 21:44:43 +0800 Subject: [PATCH 02/19] For 5B --- scripts/convert_cogvideox_to_diffusers.py | 14 +- src/diffusers/models/attention_processor.py | 89 ++++++++++ src/diffusers/models/embeddings.py | 145 +++++++++------- .../transformers/cogvideox_transformer_3d.py | 40 +++-- .../pipelines/cogvideo/pipeline_cogvideox.py | 162 +++++++++++------- 5 files changed, 303 insertions(+), 147 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index c03013a7fff9..5f63e3166079 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -141,6 +141,18 @@ def convert_transformer(ckpt_path: str): continue handler_fn_inplace(key, original_state_dict) + # Remove keys that are not used in final modelc + keys_to_remove = [ + "mixins.pos_embed.freqs_sin", # Not use + "mixins.pos_embed.freqs_cos", # Not use + "transformer_blocks.position_embeddings.weight" # Not use + ] + + for key in keys_to_remove: + if key in original_state_dict: + print(f"Removing key: {key}") + del original_state_dict[key] + transformer.load_state_dict(original_state_dict, strict=True) return transformer @@ -172,7 +184,7 @@ def get_args(): ) parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") - parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16") + parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") parser.add_argument( "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" ) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e2ab1606b345..10f440fe1c8d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1769,6 +1769,95 @@ def __call__( return hidden_states +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding + on query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index fad2e89a9566..d9f71906445a 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -78,69 +78,6 @@ def get_timestep_embedding( emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb -def get_3d_rotary_pos_embed( - embed_dim: int, - spatial_size: Union[int, Tuple[int, int]], - temporal_size: int, - hidden_size_head: int, - theta=10000, - spatial_interpolation_scale=1.0, - temporal_interpolation_scale=1.0, -): - height, width = spatial_size - - # Compute dimensions for each axis - dim_t = hidden_size_head // 4 - dim_h = hidden_size_head // 8 * 3 - dim_w = hidden_size_head // 8 * 3 - - # Temporal frequencies - freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t)) - grid_t = torch.arange(temporal_size, dtype=torch.float32) / temporal_interpolation_scale - freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t) - freqs_t = freqs_t.repeat_interleave(2, dim=-1) - - # Spatial frequencies for height and width - freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h)) - freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w)) - grid_h = torch.arange(height, dtype=torch.float32) / spatial_interpolation_scale - grid_w = torch.arange(width, dtype=torch.float32) / spatial_interpolation_scale - freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h) - freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w) - freqs_h = freqs_h.repeat_interleave(2, dim=-1) - freqs_w = freqs_w.repeat_interleave(2, dim=-1) - - # Broadcast and concatenate tensors along specified dimension - def broadcat(tensors, dim=-1): - num_tensors = len(tensors) - shape_lens = set(list(map(lambda t: len(t.shape), tensors))) - assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" - shape_len = list(shape_lens)[0] - dim = (dim + shape_len) if dim < 0 else dim - dims = list(zip(*map(lambda t: list(t.shape), tensors))) - expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] - assert all( - [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] - ), "invalid dimensions for broadcastable concatenation" - max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) - expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) - expanded_dims.insert(dim, (dim, dims[dim])) - expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) - tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) - return torch.cat(tensors, dim=dim) - - freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) - - # Flatten and reorder dimensions manually without rearrange - t, h, w, d = freqs.shape - freqs = freqs.view(t * h * w, d) - - # Generate sine and cosine components - freqs_sin = freqs.sin() - freqs_cos = freqs.cos() - - return freqs_sin, freqs_cos - def get_3d_sincos_pos_embed( embed_dim: int, spatial_size: Union[int, Tuple[int, int]], @@ -437,6 +374,88 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): return embeds +def get_3d_rotary_pos_embed(embed_dim, crops_coords, grid_size, temporal_size=13, theta=10000, use_real=True): + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t)) + grid_t = torch.from_numpy(grid_t).float() + freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t) + freqs_t = freqs_t.repeat_interleave(2, dim=-1) + + # Spatial frequencies for height and width + freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h)) + freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w)) + grid_h = torch.from_numpy(grid_h).float() + grid_w = torch.from_numpy(grid_w).float() + freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h) + freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w) + freqs_h = freqs_h.repeat_interleave(2, dim=-1) + freqs_w = freqs_w.repeat_interleave(2, dim=-1) + + # Broadcast and concatenate tensors along specified dimension + def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatenation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + + # Flatten and reorder dimensions manually without rearrange + t, h, w, d = freqs.shape + freqs = freqs.view(t * h * w, d) + + # Generate sine and cosine components + sin = freqs.sin() + cos = freqs.cos() + + if use_real: + return cos, sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 65ef3b88a532..d0a7865d4da6 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -18,11 +18,12 @@ import torch from torch import nn +from ..attention_processor import CogVideoXAttnProcessor2_0 from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, \ +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps,\ get_3d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -89,15 +90,15 @@ def __init__( # 1. Self Attention self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - self.attn1 = Attention( query_dim=dim, - dim_head=attention_head_dim, + dim_head=dim // num_attention_heads, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), ) # 2. Feed Forward @@ -115,8 +116,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb=None, ) -> torch.Tensor: norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb @@ -127,12 +129,19 @@ def forward( # CogVideoX uses concatenated text + video embeddings with self-attention instead of using # them in cross-attention individually + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + cos_emb, sin_emb = image_rotary_emb + pad_tensor = torch.zeros((text_length, cos_emb.shape[1]), device=cos_emb.device, dtype=cos_emb.dtype) + cos_emb_padded = torch.cat([pad_tensor, cos_emb], dim=0) + sin_emb_padded = torch.cat([pad_tensor, sin_emb], dim=0) + + image_rotary_emb = (cos_emb_padded, sin_emb_padded) attn_output = self.attn1( hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, encoder_hidden_states=None, ) - hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] @@ -147,6 +156,7 @@ def forward( hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] + return hidden_states, encoder_hidden_states @@ -246,27 +256,17 @@ def __init__( self.embedding_dropout = nn.Dropout(dropout) # 2. 3D positional embeddings - spatial_pos_embedding = get_3d_sincos_pos_embed( - inner_dim, - (post_patch_width, post_patch_height), - post_time_compression_frames, - spatial_interpolation_scale, - temporal_interpolation_scale, - ) - - # spatial_pos_embedding = get_3d_rotary_pos_embed( + # spatial_pos_embedding = get_3d_sincos_pos_embed( # inner_dim, # (post_patch_width, post_patch_height), # post_time_compression_frames, - # attention_head_dim, - # 10000, # spatial_interpolation_scale, # temporal_interpolation_scale, # ) - spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) + # spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) - pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) + # pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) self.register_buffer("pos_embedding", pos_embedding, persistent=False) # 3. Time embeddings @@ -313,6 +313,7 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb=None, return_dict: bool = True, ): batch_size, num_frames, channels, height, width = hidden_states.shape @@ -362,6 +363,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, temb=emb, ) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index f43edab987fe..0da6be0cc66e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -28,11 +28,29 @@ from ...utils import BaseOutput, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor - +from ...models.embeddings import get_3d_rotary_pos_embed logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + EXAMPLE_DOC_STRING = """ Examples: ```python @@ -57,12 +75,12 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, ): """ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles @@ -163,12 +181,12 @@ class CogVideoXPipeline(DiffusionPipeline): ] def __init__( - self, - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - vae: AutoencoderKLCogVideoX, - transformer: CogVideoXTransformer3DModel, - scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], ): super().__init__() @@ -185,12 +203,12 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -210,7 +228,7 @@ def _get_t5_prompt_embeds( untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1: -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" @@ -227,16 +245,16 @@ def _get_t5_prompt_embeds( return prompt_embeds def encode_prompt( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - do_classifier_free_guidance: bool = True, - num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -308,7 +326,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): shape = ( batch_size, @@ -359,20 +377,20 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs def check_inputs( - self, - prompt, - height, - width, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds=None, - negative_prompt_embeds=None, + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -424,29 +442,29 @@ def interrupt(self): @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, - num_frames: int = 49, - num_inference_steps: int = 50, - timesteps: Optional[List[int]] = None, - guidance_scale: float = 6, - use_dynamic_cfg: bool = False, - num_videos_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "pil", - return_dict: bool = True, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 226, + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, ) -> Union[CogVideoXPipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -600,6 +618,20 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop + + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size_width = 720 // 8 // self.transformer.config.patch_size + base_size_height = 480 // 8 // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, + base_size_height) + image_rotary_emb = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + use_real=True + ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -620,6 +652,7 @@ def __call__( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, + image_rotary_emb=image_rotary_emb, return_dict=False, )[0] noise_pred = noise_pred.float() @@ -627,7 +660,8 @@ def __call__( # perform guidance if use_dynamic_cfg: self._guidance_scale = 1 + guidance_scale * ( - (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + (1 - math.cos( + math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 ) if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) From 553aaedf2478a60e6f723363626cfe3cebcfbb0b Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Fri, 16 Aug 2024 23:18:22 +0800 Subject: [PATCH 03/19] Update cogvideox_transformer_3d.py --- .../models/transformers/cogvideox_transformer_3d.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index d0a7865d4da6..689b6adadc4e 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -130,11 +130,13 @@ def forward( # CogVideoX uses concatenated text + video embeddings with self-attention instead of using # them in cross-attention individually + # Padding the sin/cos embeddings for the text with zeros, sin is zero and cos is one norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) cos_emb, sin_emb = image_rotary_emb - pad_tensor = torch.zeros((text_length, cos_emb.shape[1]), device=cos_emb.device, dtype=cos_emb.dtype) - cos_emb_padded = torch.cat([pad_tensor, cos_emb], dim=0) - sin_emb_padded = torch.cat([pad_tensor, sin_emb], dim=0) + pad_tensor_sin = torch.zeros((text_length, sin_emb.shape[1]), device=sin_emb.device, dtype=sin_emb.dtype) + pad_tensor_cos = torch.ones((text_length, cos_emb.shape[1]), device=cos_emb.device, dtype=cos_emb.dtype) + cos_emb_padded = torch.cat([pad_tensor_cos, cos_emb], dim=0) + sin_emb_padded = torch.cat([pad_tensor_sin, sin_emb], dim=0) image_rotary_emb = (cos_emb_padded, sin_emb_padded) attn_output = self.attn1( From 431793ac6a7a8802c53d77f806d601a7240188b4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 19 Aug 2024 05:53:44 +0200 Subject: [PATCH 04/19] revert tab spacing changes --- scripts/convert_cogvideox_to_diffusers.py | 14 +- .../transformers/cogvideox_transformer_3d.py | 50 ++---- .../pipelines/cogvideo/pipeline_cogvideox.py | 162 +++++++----------- 3 files changed, 82 insertions(+), 144 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 5f63e3166079..c03013a7fff9 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -141,18 +141,6 @@ def convert_transformer(ckpt_path: str): continue handler_fn_inplace(key, original_state_dict) - # Remove keys that are not used in final modelc - keys_to_remove = [ - "mixins.pos_embed.freqs_sin", # Not use - "mixins.pos_embed.freqs_cos", # Not use - "transformer_blocks.position_embeddings.weight" # Not use - ] - - for key in keys_to_remove: - if key in original_state_dict: - print(f"Removing key: {key}") - del original_state_dict[key] - transformer.load_state_dict(original_state_dict, strict=True) return transformer @@ -184,7 +172,7 @@ def get_args(): ) parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") - parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") + parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16") parser.add_argument( "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" ) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 689b6adadc4e..1030b0df04ff 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -18,13 +18,11 @@ import torch from torch import nn -from ..attention_processor import CogVideoXAttnProcessor2_0 from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps,\ - get_3d_rotary_pos_embed +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -90,15 +88,15 @@ def __init__( # 1. Self Attention self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + self.attn1 = Attention( query_dim=dim, - dim_head=dim // num_attention_heads, + dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, - processor=CogVideoXAttnProcessor2_0(), ) # 2. Feed Forward @@ -116,9 +114,8 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb=None, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, ) -> torch.Tensor: norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb @@ -129,21 +126,12 @@ def forward( # CogVideoX uses concatenated text + video embeddings with self-attention instead of using # them in cross-attention individually - - # Padding the sin/cos embeddings for the text with zeros, sin is zero and cos is one norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - cos_emb, sin_emb = image_rotary_emb - pad_tensor_sin = torch.zeros((text_length, sin_emb.shape[1]), device=sin_emb.device, dtype=sin_emb.dtype) - pad_tensor_cos = torch.ones((text_length, cos_emb.shape[1]), device=cos_emb.device, dtype=cos_emb.dtype) - cos_emb_padded = torch.cat([pad_tensor_cos, cos_emb], dim=0) - sin_emb_padded = torch.cat([pad_tensor_sin, sin_emb], dim=0) - - image_rotary_emb = (cos_emb_padded, sin_emb_padded) attn_output = self.attn1( hidden_states=norm_hidden_states, - image_rotary_emb=image_rotary_emb, encoder_hidden_states=None, ) + hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] @@ -158,7 +146,6 @@ def forward( hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] - return hidden_states, encoder_hidden_states @@ -221,7 +208,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - num_attention_heads: int = 48, # CogVideoX-2B is 30 + num_attention_heads: int = 30, attention_head_dim: int = 64, in_channels: int = 16, out_channels: Optional[int] = 16, @@ -229,7 +216,7 @@ def __init__( freq_shift: int = 0, time_embed_dim: int = 512, text_embed_dim: int = 4096, - num_layers: int = 42, # CogVideoX-2B is 30 + num_layers: int = 30, dropout: float = 0.0, attention_bias: bool = True, sample_width: int = 90, @@ -258,17 +245,16 @@ def __init__( self.embedding_dropout = nn.Dropout(dropout) # 2. 3D positional embeddings - # spatial_pos_embedding = get_3d_sincos_pos_embed( - # inner_dim, - # (post_patch_width, post_patch_height), - # post_time_compression_frames, - # spatial_interpolation_scale, - # temporal_interpolation_scale, - # ) - - # spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) + spatial_pos_embedding = get_3d_sincos_pos_embed( + inner_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + spatial_interpolation_scale, + temporal_interpolation_scale, + ) + spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) - # pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) + pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) self.register_buffer("pos_embedding", pos_embedding, persistent=False) # 3. Time embeddings @@ -315,7 +301,6 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, - image_rotary_emb=None, return_dict: bool = True, ): batch_size, num_frames, channels, height, width = hidden_states.shape @@ -365,7 +350,6 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - image_rotary_emb=image_rotary_emb, temb=emb, ) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 0da6be0cc66e..f43edab987fe 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -28,27 +28,9 @@ from ...utils import BaseOutput, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor -from ...models.embeddings import get_3d_rotary_pos_embed -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): - tw = tgt_width - th = tgt_height - h, w = src - r = h / w - if r > (th / tw): - resize_height = th - resize_width = int(round(th / h * w)) - else: - resize_width = tw - resize_height = int(round(tw / w * h)) - crop_top = int(round((th - resize_height) / 2.0)) - crop_left = int(round((tw - resize_width) / 2.0)) - - return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ @@ -75,12 +57,12 @@ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, ): """ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles @@ -181,12 +163,12 @@ class CogVideoXPipeline(DiffusionPipeline): ] def __init__( - self, - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - vae: AutoencoderKLCogVideoX, - transformer: CogVideoXTransformer3DModel, - scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], ): super().__init__() @@ -203,12 +185,12 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -228,7 +210,7 @@ def _get_t5_prompt_embeds( untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1: -1]) + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" @@ -245,16 +227,16 @@ def _get_t5_prompt_embeds( return prompt_embeds def encode_prompt( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - do_classifier_free_guidance: bool = True, - num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -326,7 +308,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): shape = ( batch_size, @@ -377,20 +359,20 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs def check_inputs( - self, - prompt, - height, - width, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds=None, - negative_prompt_embeds=None, + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -442,29 +424,29 @@ def interrupt(self): @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, - num_frames: int = 49, - num_inference_steps: int = 50, - timesteps: Optional[List[int]] = None, - guidance_scale: float = 6, - use_dynamic_cfg: bool = False, - num_videos_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "pil", - return_dict: bool = True, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 226, + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, ) -> Union[CogVideoXPipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -618,20 +600,6 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop - - grid_height = height // 8 // self.transformer.config.patch_size - grid_width = width // 8 // self.transformer.config.patch_size - base_size_width = 720 // 8 // self.transformer.config.patch_size - base_size_height = 480 // 8 // self.transformer.config.patch_size - - grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, - base_size_height) - image_rotary_emb = get_3d_rotary_pos_embed( - embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - use_real=True - ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -652,7 +620,6 @@ def __call__( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, - image_rotary_emb=image_rotary_emb, return_dict=False, )[0] noise_pred = noise_pred.float() @@ -660,8 +627,7 @@ def __call__( # perform guidance if use_dynamic_cfg: self._guidance_scale = 1 + guidance_scale * ( - (1 - math.cos( - math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 ) if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) From 4cdc271c8d881b50e66766d99b923e6660495249 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 19 Aug 2024 06:50:08 +0200 Subject: [PATCH 05/19] unrevert pipeline changes except tab spacing --- .../pipelines/cogvideo/pipeline_cogvideox.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index f43edab987fe..ff3eeef0b0f4 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -23,6 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import BaseOutput, logging, replace_example_docstring @@ -55,6 +56,25 @@ """ +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -599,6 +619,21 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 7. Create rotary embeds + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size_width = 720 // 8 // self.transformer.config.patch_size + base_size_height = 480 // 8 // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, + base_size_height) + image_rotary_emb = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + use_real=True + ) + # 7. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -620,6 +655,7 @@ def __call__( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, + image_rotary_emb=image_rotary_emb, return_dict=False, )[0] noise_pred = noise_pred.float() From dbc0b2e8812b6ee852092b4d8bfc12ae3ec9bcd8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 19 Aug 2024 15:34:55 +0200 Subject: [PATCH 06/19] update conversion script --- scripts/convert_cogvideox_to_diffusers.py | 31 +++++++++++++++++++---- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index c03013a7fff9..b639c12de304 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -86,6 +86,9 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): "key_layernorm_list": reassign_query_key_layernorm_inplace, "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, "embed_tokens": remove_keys_inplace, + "freqs_sin": remove_keys_inplace, + "freqs_cos": remove_keys_inplace, + "position_embedding": remove_keys_inplace, } VAE_KEYS_RENAME_DICT = { @@ -123,11 +126,11 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: state_dict[new_key] = state_dict.pop(old_key) -def convert_transformer(ckpt_path: str): +def convert_transformer(ckpt_path: str, num_layers: int, num_attention_heads: int, use_rotary_positional_embeddings: bool): PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - transformer = CogVideoXTransformer3DModel() + transformer = CogVideoXTransformer3DModel(num_layers=num_layers, num_attention_heads=num_attention_heads, use_rotary_positional_embeddings=use_rotary_positional_embeddings) for key in list(original_state_dict.keys()): new_key = key[len(PREFIX_KEY) :] @@ -172,13 +175,20 @@ def get_args(): ) parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") - parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16") + parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") + parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16") parser.add_argument( "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" ) parser.add_argument( "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" ) + # For CogVideoX-2B, num_layers is 30. For 5B, it is 42 + parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks") + # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48 + parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads") + # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True + parser.add_argument("--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not") return parser.parse_args() @@ -188,8 +198,10 @@ def get_args(): transformer = None vae = None + if args.fp16 and args.bf16: + raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.") if args.transformer_ckpt_path is not None: - transformer = convert_transformer(args.transformer_ckpt_path) + transformer = convert_transformer(args.transformer_ckpt_path, args.num_layers, args.num_attention_heads, args.use_rotary_positional_embeddings) if args.vae_ckpt_path is not None: vae = convert_vae(args.vae_ckpt_path) @@ -197,6 +209,10 @@ def get_args(): tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + # Apparently, the conversion does not work any more without this :shrug: + for param in text_encoder.parameters(): + param.data = param.data.contiguous() + scheduler = CogVideoXDDIMScheduler.from_config( { "snr_shift_scale": 3.0, @@ -208,7 +224,7 @@ def get_args(): "prediction_type": "v_prediction", "rescale_betas_zero_snr": True, "set_alpha_to_one": True, - "timestep_spacing": "linspace", + "timestep_spacing": "trailing", } ) @@ -218,5 +234,10 @@ def get_args(): if args.fp16: pipe = pipe.to(dtype=torch.float16) + if args.bf16: + pipe = pipe.to(dtype=torch.bfloat16) + # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird + # for users to specify variant when the default is not fp32 and they want to run with the correct default (which + # is either fp16/bf16 here). pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub) From 991d0580869e0c91de79ba1762788423060d55f0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 19 Aug 2024 15:35:18 +0200 Subject: [PATCH 07/19] refactor and cleanup; make style --- src/diffusers/models/attention_processor.py | 6 +- src/diffusers/models/embeddings.py | 237 +++++++++--------- .../transformers/cogvideox_transformer_3d.py | 20 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 45 ++-- 4 files changed, 169 insertions(+), 139 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 10f440fe1c8d..0e53dee95acb 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1769,10 +1769,11 @@ def __call__( return hidden_states + class CogVideoXAttnProcessor2_0: r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding - on query and key vectors, but does not include spatial normalization. + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. """ def __init__(self): @@ -1859,6 +1860,7 @@ def __call__( return hidden_states + class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d9f71906445a..c2392b029d5f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -17,7 +17,6 @@ import numpy as np import torch import torch.nn.functional as F -from einops import rearrange from torch import nn from ..utils import deprecate @@ -26,12 +25,12 @@ def get_timestep_embedding( - timesteps: torch.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, ): """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. @@ -78,12 +77,13 @@ def get_timestep_embedding( emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb + def get_3d_sincos_pos_embed( - embed_dim: int, - spatial_size: Union[int, Tuple[int, int]], - temporal_size: int, - spatial_interpolation_scale: float = 1.0, - temporal_interpolation_scale: float = 1.0, + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, ) -> np.ndarray: r""" Args: @@ -126,7 +126,7 @@ def get_3d_sincos_pos_embed( def get_2d_sincos_pos_embed( - embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 ): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or @@ -168,7 +168,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2.0 - omega = 1.0 / 10000 ** omega # (D/2,) + omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product @@ -184,18 +184,18 @@ class PatchEmbed(nn.Module): """2D Image to Patch Embedding with support for SD3 cropping.""" def __init__( - self, - height=224, - width=224, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - interpolation_scale=1, - pos_embed_type="sincos", - pos_embed_max_size=None, # For SD3 cropping + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + pos_embed_type="sincos", + pos_embed_max_size=None, # For SD3 cropping ): super().__init__() @@ -221,7 +221,7 @@ def __init__( if pos_embed_max_size: grid_size = pos_embed_max_size else: - grid_size = int(num_patches ** 0.5) + grid_size = int(num_patches**0.5) if pos_embed_type is None: self.pos_embed = None @@ -253,7 +253,7 @@ def cropped_pos_embed(self, height, width): top = (self.pos_embed_max_size - height) // 2 left = (self.pos_embed_max_size - width) // 2 spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) - spatial_pos_embed = spatial_pos_embed[:, top: top + height, left: left + width, :] + spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) return spatial_pos_embed @@ -336,12 +336,12 @@ def forward(self, x, freqs_cis): class CogVideoXPatchEmbed(nn.Module): def __init__( - self, - patch_size: int = 2, - in_channels: int = 16, - embed_dim: int = 1920, - text_embed_dim: int = 4096, - bias: bool = True, + self, + patch_size: int = 2, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + bias: bool = True, ) -> None: super().__init__() self.patch_size = patch_size @@ -424,20 +424,20 @@ def get_3d_rotary_pos_embed(embed_dim, crops_coords, grid_size, temporal_size=13 # Broadcast and concatenate tensors along specified dimension def broadcat(tensors, dim=-1): num_tensors = len(tensors) - shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + shape_lens = {len(t.shape) for t in tensors} assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" shape_len = list(shape_lens)[0] dim = (dim + shape_len) if dim < 0 else dim - dims = list(zip(*map(lambda t: list(t.shape), tensors))) + dims = list(zip(*(list(t.shape) for t in tensors))) expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] assert all( - [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + [*(len(set(t[1])) <= 2 for t in expandable_dims)] ), "invalid dimensions for broadcastable concatenation" - max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) - expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] expanded_dims.insert(dim, (dim, dims[dim])) - expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) - tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] return torch.cat(tensors, dim=dim) freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) @@ -456,6 +456,7 @@ def broadcat(tensors, dim=-1): freqs_cis = torch.polar(torch.ones_like(freqs), freqs) return freqs_cis + def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. @@ -521,13 +522,13 @@ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, n def get_1d_rotary_pos_embed( - dim: int, - pos: Union[np.ndarray, int], - theta: float = 10000.0, - use_real=False, - linear_factor=1.0, - ntk_factor=1.0, - repeat_interleave_real=True, + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -575,10 +576,10 @@ def get_1d_rotary_pos_embed( def apply_rotary_emb( - x: torch.Tensor, - freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], - use_real: bool = True, - use_real_unbind_dim: int = -1, + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings @@ -624,14 +625,14 @@ def apply_rotary_emb( class TimestepEmbedding(nn.Module): def __init__( - self, - in_channels: int, - time_embed_dim: int, - act_fn: str = "silu", - out_dim: int = None, - post_act_fn: Optional[str] = None, - cond_proj_dim=None, - sample_proj_bias=True, + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, ): super().__init__() @@ -693,7 +694,7 @@ class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" def __init__( - self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False ): super().__init__() self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) @@ -772,11 +773,11 @@ class ImagePositionalEmbeddings(nn.Module): """ def __init__( - self, - num_embed: int, - height: int, - width: int, - embed_dim: int, + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, ): super().__init__() @@ -850,11 +851,11 @@ def forward(self, labels: torch.LongTensor, force_drop_ids=None): class TextImageProjection(nn.Module): def __init__( - self, - text_embed_dim: int = 1024, - image_embed_dim: int = 768, - cross_attention_dim: int = 768, - num_image_text_embeds: int = 10, + self, + text_embed_dim: int = 1024, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 10, ): super().__init__() @@ -877,10 +878,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): class ImageProjection(nn.Module): def __init__( - self, - image_embed_dim: int = 768, - cross_attention_dim: int = 768, - num_image_text_embeds: int = 32, + self, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 32, ): super().__init__() @@ -993,7 +994,7 @@ class HunyuanDiTAttentionPool(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) @@ -1030,12 +1031,12 @@ def forward(self, x): class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): def __init__( - self, - embedding_dim, - pooled_projection_dim=1024, - seq_len=256, - cross_attention_dim=2048, - use_style_cond_and_image_meta_size=True, + self, + embedding_dim, + pooled_projection_dim=1024, + seq_len=256, + cross_attention_dim=2048, + use_style_cond_and_image_meta_size=True, ): super().__init__() @@ -1207,7 +1208,7 @@ class AttentionPooling(nn.Module): def __init__(self, num_heads, embed_dim, dtype=None): super().__init__() self.dtype = dtype - self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) @@ -1315,14 +1316,14 @@ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freq self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) def forward( - self, - boxes, - masks, - positive_embeddings=None, - phrases_masks=None, - image_masks=None, - phrases_embeddings=None, - image_embeddings=None, + self, + boxes, + masks, + positive_embeddings=None, + phrases_masks=None, + image_masks=None, + phrases_embeddings=None, + image_embeddings=None, ): masks = masks.unsqueeze(-1) @@ -1433,11 +1434,11 @@ def forward(self, caption): class IPAdapterPlusImageProjectionBlock(nn.Module): def __init__( - self, - embed_dims: int = 768, - dim_head: int = 64, - heads: int = 16, - ffn_ratio: float = 4, + self, + embed_dims: int = 768, + dim_head: int = 64, + heads: int = 16, + ffn_ratio: float = 4, ) -> None: super().__init__() from .attention import FeedForward @@ -1481,18 +1482,18 @@ class IPAdapterPlusImageProjection(nn.Module): """ def __init__( - self, - embed_dims: int = 768, - output_dims: int = 1024, - hidden_dims: int = 1280, - depth: int = 4, - dim_head: int = 64, - heads: int = 16, - num_queries: int = 8, - ffn_ratio: float = 4, + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, ) -> None: super().__init__() - self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims ** 0.5) + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) self.proj_in = nn.Linear(embed_dims, hidden_dims) @@ -1541,18 +1542,18 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module): """ def __init__( - self, - embed_dims: int = 768, - output_dims: int = 768, - hidden_dims: int = 1280, - id_embeddings_dim: int = 512, - depth: int = 4, - dim_head: int = 64, - heads: int = 16, - num_tokens: int = 4, - num_queries: int = 8, - ffn_ratio: float = 4, - ffproj_ratio: int = 2, + self, + embed_dims: int = 768, + output_dims: int = 768, + hidden_dims: int = 1280, + id_embeddings_dim: int = 512, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_tokens: int = 4, + num_queries: int = 8, + ffn_ratio: float = 4, + ffproj_ratio: int = 2, ) -> None: super().__init__() from .attention import FeedForward diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 1030b0df04ff..7b0c01c6e60b 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import torch from torch import nn @@ -22,6 +22,7 @@ from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward +from ..attention_processor import HunyuanAttnProcessor2_0, CogVideoXAttnProcessor2_0 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -97,6 +98,7 @@ def __init__( eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), ) # 2. Feed Forward @@ -116,7 +118,9 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: + breakpoint() norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb ) @@ -130,6 +134,7 @@ def forward( attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, + image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] @@ -231,6 +236,7 @@ def __init__( norm_eps: float = 1e-5, spatial_interpolation_scale: float = 1.875, temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -301,6 +307,7 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ): batch_size, num_frames, channels, height, width = hidden_states.shape @@ -319,11 +326,12 @@ def forward( hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # 3. Position embedding - seq_length = height * width * num_frames // (self.config.patch_size**2) + if not self.config.use_rotary_positional_embeddings: + seq_length = height * width * num_frames // (self.config.patch_size**2) - pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] - hidden_states = hidden_states + pos_embeds - hidden_states = self.embedding_dropout(hidden_states) + pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] + hidden_states = hidden_states + pos_embeds + hidden_states = self.embedding_dropout(hidden_states) encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] hidden_states = hidden_states[:, self.config.max_text_seq_length :] @@ -344,6 +352,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states, emb, + image_rotary_emb, **ckpt_kwargs, ) else: @@ -351,6 +360,7 @@ def custom_forward(*inputs): hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=emb, + image_rotary_emb=image_rotary_emb, ) hidden_states = self.norm_final(hidden_states) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index ff3eeef0b0f4..1bb9351d4107 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -429,6 +429,31 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + def _prepare_rotary_embeddings( + self, height: int, width: int, max_sequence_length: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + use_real=True, + ) + + pad_tensor_cos = freqs_cos.new_zeros((max_sequence_length, freqs_cos.size(1))) + pad_tensor_sin = freqs_sin.new_zeros((max_sequence_length, freqs_sin.size(1))) + freqs_cos = torch.cat([pad_tensor_cos, freqs_cos]) + freqs_sin = torch.cat([pad_tensor_sin, freqs_sin]) + + return freqs_cos, freqs_sin + @property def guidance_scale(self): return self._guidance_scale @@ -619,22 +644,14 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7. Create rotary embeds - grid_height = height // 8 // self.transformer.config.patch_size - grid_width = width // 8 // self.transformer.config.patch_size - base_size_width = 720 // 8 // self.transformer.config.patch_size - base_size_height = 480 // 8 // self.transformer.config.patch_size - - grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, - base_size_height) - image_rotary_emb = get_3d_rotary_pos_embed( - embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - use_real=True + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_embeddings(height, width, max_sequence_length) + if self.transformer.config.use_rotary_positional_embeddings + else None ) - # 7. Denoising loop + # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) with self.progress_bar(total=num_inference_steps) as progress_bar: From 01037830d751f2e91951704886f157142e4721d9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 20 Aug 2024 00:26:09 +0200 Subject: [PATCH 08/19] make style; fix autoencoder scaling factor --- scripts/convert_cogvideox_to_diffusers.py | 29 ++++++++++++++----- src/diffusers/models/attention_processor.py | 1 - .../autoencoders/autoencoder_kl_cogvideox.py | 2 +- src/diffusers/models/embeddings.py | 26 ++++++++++++++--- .../transformers/cogvideox_transformer_3d.py | 3 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 19 ++++++++---- 6 files changed, 60 insertions(+), 20 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index b639c12de304..9062e894cb58 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -126,11 +126,17 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: state_dict[new_key] = state_dict.pop(old_key) -def convert_transformer(ckpt_path: str, num_layers: int, num_attention_heads: int, use_rotary_positional_embeddings: bool): +def convert_transformer( + ckpt_path: str, num_layers: int, num_attention_heads: int, use_rotary_positional_embeddings: bool +): PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - transformer = CogVideoXTransformer3DModel(num_layers=num_layers, num_attention_heads=num_attention_heads, use_rotary_positional_embeddings=use_rotary_positional_embeddings) + transformer = CogVideoXTransformer3DModel( + num_layers=num_layers, + num_attention_heads=num_attention_heads, + use_rotary_positional_embeddings=use_rotary_positional_embeddings, + ) for key in list(original_state_dict.keys()): new_key = key[len(PREFIX_KEY) :] @@ -148,9 +154,9 @@ def convert_transformer(ckpt_path: str, num_layers: int, num_attention_heads: in return transformer -def convert_vae(ckpt_path: str): +def convert_vae(ckpt_path: str, scaling_factor: float): original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - vae = AutoencoderKLCogVideoX() + vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor) for key in list(original_state_dict.keys()): new_key = key[:] @@ -188,7 +194,11 @@ def get_args(): # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48 parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads") # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True - parser.add_argument("--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not") + parser.add_argument( + "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not" + ) + # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7 + parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") return parser.parse_args() @@ -201,9 +211,14 @@ def get_args(): if args.fp16 and args.bf16: raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.") if args.transformer_ckpt_path is not None: - transformer = convert_transformer(args.transformer_ckpt_path, args.num_layers, args.num_attention_heads, args.use_rotary_positional_embeddings) + transformer = convert_transformer( + args.transformer_ckpt_path, + args.num_layers, + args.num_attention_heads, + args.use_rotary_positional_embeddings, + ) if args.vae_ckpt_path is not None: - vae = convert_vae(args.vae_ckpt_path) + vae = convert_vae(args.vae_ckpt_path, args.scaling_factor) text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0e53dee95acb..4b6a9efb9410 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1786,7 +1786,6 @@ def __call__( hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 3bf6e68d2628..17fa2bbf40f6 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -902,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. sample_size (`int`, *optional*, defaults to `32`): Sample input size. - scaling_factor (`float`, *optional*, defaults to 0.18215): + scaling_factor (`float`, *optional*, defaults to `1.15258426`): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c2392b029d5f..7cd9e8c08856 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -374,7 +374,9 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): return embeds -def get_3d_rotary_pos_embed(embed_dim, crops_coords, grid_size, temporal_size=13, theta=10000, use_real=True): +def get_3d_rotary_pos_embed( + embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ RoPE for video tokens with 3D structure. @@ -395,6 +397,23 @@ def get_3d_rotary_pos_embed(embed_dim, crops_coords, grid_size, temporal_size=13 Returns: `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. """ + # start, stop = crops_coords + # grid_t = torch.arange(temporal_size, dtype=torch.float32) + # grid_h = torch.arange(stop[0] - start[0], dtype=torch.float32) + # grid_w = torch.arange(stop[1] - start[1], dtype=torch.float32) + + # dim_t = embed_dim // 4 + # dim_h = embed_dim // 8 * 3 + # dim_w = embed_dim // 8 * 3 + + # freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2, dtype=torch.float32)[: dim_t // 2] / dim_t)) + # freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: dim_h // 2] / dim_h)) + # freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: dim_w // 2] / dim_w)) + + # freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t).repeat_interleave(2, dim=-1) + # freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h).repeat_interleave(2, dim=-1) + # freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w).repeat_interleave(2, dim=-1) + start, stop = crops_coords grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) @@ -422,7 +441,7 @@ def get_3d_rotary_pos_embed(embed_dim, crops_coords, grid_size, temporal_size=13 freqs_w = freqs_w.repeat_interleave(2, dim=-1) # Broadcast and concatenate tensors along specified dimension - def broadcat(tensors, dim=-1): + def broadcast(tensors, dim=-1): num_tensors = len(tensors) shape_lens = {len(t.shape) for t in tensors} assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" @@ -440,9 +459,8 @@ def broadcat(tensors, dim=-1): tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] return torch.cat(tensors, dim=dim) - freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) - # Flatten and reorder dimensions manually without rearrange t, h, w, d = freqs.shape freqs = freqs.view(t * h * w, d) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 7b0c01c6e60b..e756c3ad3f44 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -22,7 +22,7 @@ from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward -from ..attention_processor import HunyuanAttnProcessor2_0, CogVideoXAttnProcessor2_0 +from ..attention_processor import CogVideoXAttnProcessor2_0 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -120,7 +120,6 @@ def forward( temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: - breakpoint() norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb ) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 1bb9351d4107..b7a23ebb5786 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -429,8 +429,14 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - def _prepare_rotary_embeddings( - self, height: int, width: int, max_sequence_length: int + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + max_sequence_length: int, + device: torch.device, + dtype: torch.dtype, ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) @@ -444,13 +450,14 @@ def _prepare_rotary_embeddings( embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), + temporal_size=num_frames, use_real=True, ) pad_tensor_cos = freqs_cos.new_zeros((max_sequence_length, freqs_cos.size(1))) pad_tensor_sin = freqs_sin.new_zeros((max_sequence_length, freqs_sin.size(1))) - freqs_cos = torch.cat([pad_tensor_cos, freqs_cos]) - freqs_sin = torch.cat([pad_tensor_sin, freqs_sin]) + freqs_cos = torch.cat([pad_tensor_cos, freqs_cos]).to(device=device, dtype=dtype) + freqs_sin = torch.cat([pad_tensor_sin, freqs_sin]).to(device=device, dtype=dtype) return freqs_cos, freqs_sin @@ -646,7 +653,9 @@ def __call__( # 7. Create rotary embeds if required image_rotary_emb = ( - self._prepare_rotary_embeddings(height, width, max_sequence_length) + self._prepare_rotary_positional_embeddings( + height, width, latents.size(1), max_sequence_length, device, prompt_embeds.dtype + ) if self.transformer.config.use_rotary_positional_embeddings else None ) From 2f178ee60a2af5a7aff53f92bee9c3d2247e7e2d Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 22 Aug 2024 02:48:46 +0200 Subject: [PATCH 09/19] fix bugs --- scripts/convert_cogvideox_to_diffusers.py | 18 +++++++++---- src/diffusers/models/attention_processor.py | 13 +++++---- src/diffusers/models/embeddings.py | 6 ++++- .../transformers/cogvideox_transformer_3d.py | 27 ++++++++++++------- .../pipelines/cogvideo/pipeline_cogvideox.py | 12 +++------ 5 files changed, 47 insertions(+), 29 deletions(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 9062e894cb58..5854c4460e98 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -127,7 +127,11 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: def convert_transformer( - ckpt_path: str, num_layers: int, num_attention_heads: int, use_rotary_positional_embeddings: bool + ckpt_path: str, + num_layers: int, + num_attention_heads: int, + use_rotary_positional_embeddings: bool, + dtype: torch.dtype, ): PREFIX_KEY = "model.diffusion_model." @@ -136,7 +140,7 @@ def convert_transformer( num_layers=num_layers, num_attention_heads=num_attention_heads, use_rotary_positional_embeddings=use_rotary_positional_embeddings, - ) + ).to(dtype=dtype) for key in list(original_state_dict.keys()): new_key = key[len(PREFIX_KEY) :] @@ -154,9 +158,9 @@ def convert_transformer( return transformer -def convert_vae(ckpt_path: str, scaling_factor: float): +def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor) + vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype) for key in list(original_state_dict.keys()): new_key = key[:] @@ -210,15 +214,19 @@ def get_args(): if args.fp16 and args.bf16: raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.") + + dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32 + if args.transformer_ckpt_path is not None: transformer = convert_transformer( args.transformer_ckpt_path, args.num_layers, args.num_attention_heads, args.use_rotary_positional_embeddings, + dtype, ) if args.vae_ckpt_path is not None: - vae = convert_vae(args.vae_ckpt_path, args.scaling_factor) + vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4b6a9efb9410..6b6ba61952c3 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1787,6 +1787,7 @@ def __call__( encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + text_seq_length: int = 226, ) -> torch.Tensor: from .embeddings import apply_rotary_emb @@ -1809,13 +1810,12 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -1833,16 +1833,19 @@ def __call__( # Apply RoPE if needed if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) + query[:, :, text_seq_length:] = apply_rotary_emb( + query[:, :, text_seq_length:], image_rotary_emb, upcast=False + ) if not attn.is_cross_attention: - key = apply_rotary_emb(key, image_rotary_emb) + key[:, :, text_seq_length:] = apply_rotary_emb( + key[:, :, text_seq_length:], image_rotary_emb, upcast=False + ) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 7cd9e8c08856..e37c0921fb52 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -598,6 +598,7 @@ def apply_rotary_emb( freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, use_real_unbind_dim: int = -1, + upcast: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings @@ -630,7 +631,10 @@ def apply_rotary_emb( else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + if upcast: + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + else: + out = x * cos + x_rotated * sin return out else: diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index e756c3ad3f44..c464595ab7fb 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -125,7 +125,7 @@ def forward( ) # attention - text_length = norm_encoder_hidden_states.size(1) + text_seq_length = norm_encoder_hidden_states.size(1) # CogVideoX uses concatenated text + video embeddings with self-attention instead of using # them in cross-attention individually @@ -134,10 +134,11 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=None, image_rotary_emb=image_rotary_emb, + text_seq_length=text_seq_length, ) - hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] - encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] + hidden_states = hidden_states + gate_msa * attn_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_seq_length] # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( @@ -148,8 +149,8 @@ def forward( norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] - encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] return hidden_states, encoder_hidden_states @@ -325,15 +326,16 @@ def forward( hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # 3. Position embedding + text_seq_length = encoder_hidden_states.shape[1] if not self.config.use_rotary_positional_embeddings: seq_length = height * width * num_frames // (self.config.patch_size**2) - pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] + pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] hidden_states = hidden_states + pos_embeds hidden_states = self.embedding_dropout(hidden_states) - encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] - hidden_states = hidden_states[:, self.config.max_text_seq_length :] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] # 4. Transformer blocks for i, block in enumerate(self.transformer_blocks): @@ -362,7 +364,14 @@ def custom_forward(*inputs): image_rotary_emb=image_rotary_emb, ) - hidden_states = self.norm_final(hidden_states) + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] # 5. Final block hidden_states = self.norm_out(hidden_states, temb=emb) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index b7a23ebb5786..0ae0cbcd40ec 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -434,7 +434,6 @@ def _prepare_rotary_positional_embeddings( height: int, width: int, num_frames: int, - max_sequence_length: int, device: torch.device, dtype: torch.dtype, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -454,11 +453,8 @@ def _prepare_rotary_positional_embeddings( use_real=True, ) - pad_tensor_cos = freqs_cos.new_zeros((max_sequence_length, freqs_cos.size(1))) - pad_tensor_sin = freqs_sin.new_zeros((max_sequence_length, freqs_sin.size(1))) - freqs_cos = torch.cat([pad_tensor_cos, freqs_cos]).to(device=device, dtype=dtype) - freqs_sin = torch.cat([pad_tensor_sin, freqs_sin]).to(device=device, dtype=dtype) - + freqs_cos = freqs_cos.to(device=device, dtype=dtype) + freqs_sin = freqs_sin.to(device=device, dtype=dtype) return freqs_cos, freqs_sin @property @@ -653,9 +649,7 @@ def __call__( # 7. Create rotary embeds if required image_rotary_emb = ( - self._prepare_rotary_positional_embeddings( - height, width, latents.size(1), max_sequence_length, device, prompt_embeds.dtype - ) + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, prompt_embeds.dtype) if self.transformer.config.use_rotary_positional_embeddings else None ) From 4a22392fd155fef8da5ff6b7c9f4d6f42a5030f0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 22 Aug 2024 07:11:36 +0200 Subject: [PATCH 10/19] rebase --- src/diffusers/models/attention_processor.py | 75 --------------------- 1 file changed, 75 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 33bf9dde2e32..f2d33c30ca82 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1695,81 +1695,6 @@ def __call__( return hidden_states -# YiYi to-do: refactor rope related functions/classes -def apply_rope(xq, xk, freqs_cis): - xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) - xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] - xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - - -class FluxSingleAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - # YiYi to-do: update uising apply_rotary_emb - # from ..embeddings import apply_rotary_emb - # query = apply_rotary_emb(query, image_rotary_emb) - # key = apply_rotary_emb(key, image_rotary_emb) - query, key = apply_rope(query, key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - return hidden_states - - class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" From e1a51a727f36a1c25c8a188980a6a349f8791c76 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 22 Aug 2024 08:03:19 +0200 Subject: [PATCH 11/19] add qkv fusion support --- src/diffusers/models/attention_processor.py | 109 ++++++++++++---- .../transformers/cogvideox_transformer_3d.py | 117 ++++++++++++++++-- .../pipelines/cogvideo/pipeline_cogvideox.py | 13 ++ 3 files changed, 202 insertions(+), 37 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f2d33c30ca82..914d5e20fd7c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1797,20 +1797,13 @@ def __call__( self, attn: Attention, hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - text_seq_length: int = 226, ) -> torch.Tensor: - from .embeddings import apply_rotary_emb + text_seq_length = encoder_hidden_states.size(1) - residual = hidden_states - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape @@ -1820,17 +1813,9 @@ def __call__( attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -1846,6 +1831,8 @@ def __call__( # Apply RoPE if needed if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + query[:, :, text_seq_length:] = apply_rotary_emb( query[:, :, text_seq_length:], image_rotary_emb, upcast=False ) @@ -1865,15 +1852,85 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states - if attn.residual_connection: - hidden_states = hidden_states + residual - hidden_states = hidden_states / attn.rescale_output_factor +class FusedCogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ - return hidden_states + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query[:, :, text_seq_length:] = apply_rotary_emb( + query[:, :, text_seq_length:], image_rotary_emb, upcast=False + ) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb( + key[:, :, text_seq_length:], image_rotary_emb, upcast=False + ) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states class XFormersAttnAddedKVProcessor: diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index c464595ab7fb..83a6e2d7ad99 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -22,7 +22,7 @@ from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward -from ..attention_processor import CogVideoXAttnProcessor2_0 +from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -120,25 +120,22 @@ def forward( temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb ) # attention - text_seq_length = norm_encoder_hidden_states.size(1) - - # CogVideoX uses concatenated text + video embeddings with self-attention instead of using - # them in cross-attention individually - norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - attn_output = self.attn1( + attn_hidden_states, attn_encoder_hidden_states = self.attn1( hidden_states=norm_hidden_states, - encoder_hidden_states=None, + encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, - text_seq_length=text_seq_length, ) - hidden_states = hidden_states + gate_msa * attn_output[:, text_seq_length:] - encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_seq_length] + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( @@ -151,6 +148,7 @@ def forward( hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + return hidden_states, encoder_hidden_states @@ -301,6 +299,103 @@ def __init__( def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the + corresponding cross attention processor. This is strongly recommended when setting trainable attention + processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + def forward( self, hidden_states: torch.Tensor, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 0ae0cbcd40ec..927960c2235f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -429,6 +429,19 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + def _prepare_rotary_positional_embeddings( self, height: int, From 43c4edb1a58bb1ff96d521534627e229fd484223 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 22 Aug 2024 08:05:16 +0200 Subject: [PATCH 12/19] make fix-copies --- .../models/transformers/cogvideox_transformer_3d.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 83a6e2d7ad99..c8d4b1896346 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -332,9 +332,11 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the - corresponding cross attention processor. This is strongly recommended when setting trainable attention - processors. + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ count = len(self.attn_processors.keys()) @@ -357,7 +359,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) @@ -392,6 +394,7 @@ def unfuse_qkv_projections(self): This API is 🧪 experimental. + """ if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) From e5c686118caf6aaca1d2d004173bab27f3cd74e8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 22 Aug 2024 08:13:30 +0200 Subject: [PATCH 13/19] remove commented changes --- src/diffusers/models/embeddings.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 5e4a21085295..2991a3cc7e1b 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -397,23 +397,6 @@ def get_3d_rotary_pos_embed( Returns: `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. """ - # start, stop = crops_coords - # grid_t = torch.arange(temporal_size, dtype=torch.float32) - # grid_h = torch.arange(stop[0] - start[0], dtype=torch.float32) - # grid_w = torch.arange(stop[1] - start[1], dtype=torch.float32) - - # dim_t = embed_dim // 4 - # dim_h = embed_dim // 8 * 3 - # dim_w = embed_dim // 8 * 3 - - # freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2, dtype=torch.float32)[: dim_t // 2] / dim_t)) - # freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: dim_h // 2] / dim_h)) - # freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: dim_w // 2] / dim_w)) - - # freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t).repeat_interleave(2, dim=-1) - # freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h).repeat_interleave(2, dim=-1) - # freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w).repeat_interleave(2, dim=-1) - start, stop = crops_coords grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) From 6ee1e28c175004bca85b833820e643a71878675f Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Fri, 23 Aug 2024 00:56:23 +0800 Subject: [PATCH 14/19] Update convert_cogvideox_to_diffusers.py --- scripts/convert_cogvideox_to_diffusers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 5854c4460e98..6448da7f1131 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -203,6 +203,8 @@ def get_args(): ) # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7 parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") + # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0 + parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") return parser.parse_args() @@ -238,7 +240,7 @@ def get_args(): scheduler = CogVideoXDDIMScheduler.from_config( { - "snr_shift_scale": 3.0, + "snr_shift_scale": args.snr_shift_scale, "beta_end": 0.012, "beta_schedule": "scaled_linear", "beta_start": 0.00085, From 2badcc53ac0cf8fdf1c10890323ab842eea344a3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 23 Aug 2024 02:20:21 +0200 Subject: [PATCH 15/19] revert upcast changes --- src/diffusers/models/attention_processor.py | 8 ++------ src/diffusers/models/embeddings.py | 6 +----- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 914d5e20fd7c..4dcfa4270cf9 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1908,13 +1908,9 @@ def __call__( if image_rotary_emb is not None: from .embeddings import apply_rotary_emb - query[:, :, text_seq_length:] = apply_rotary_emb( - query[:, :, text_seq_length:], image_rotary_emb, upcast=False - ) + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) if not attn.is_cross_attention: - key[:, :, text_seq_length:] = apply_rotary_emb( - key[:, :, text_seq_length:], image_rotary_emb, upcast=False - ) + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 2991a3cc7e1b..d1366654c448 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -584,7 +584,6 @@ def apply_rotary_emb( freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, use_real_unbind_dim: int = -1, - upcast: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings @@ -617,10 +616,7 @@ def apply_rotary_emb( else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - if upcast: - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - else: - out = x * cos + x_rotated * sin + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out else: From ae43411bbb1cb4f457d81e218ec9f0ea684b0225 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 23 Aug 2024 02:45:02 +0200 Subject: [PATCH 16/19] fix rope call --- src/diffusers/models/attention_processor.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4dcfa4270cf9..75b4f164eb25 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1833,13 +1833,9 @@ def __call__( if image_rotary_emb is not None: from .embeddings import apply_rotary_emb - query[:, :, text_seq_length:] = apply_rotary_emb( - query[:, :, text_seq_length:], image_rotary_emb, upcast=False - ) + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) if not attn.is_cross_attention: - key[:, :, text_seq_length:] = apply_rotary_emb( - key[:, :, text_seq_length:], image_rotary_emb, upcast=False - ) + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From a8f5ce06410d2c60167866c00d7c95ad57719c8d Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 23 Aug 2024 03:18:16 +0200 Subject: [PATCH 17/19] add qkv fusion test --- tests/pipelines/cogvideox/test_cogvideox.py | 45 ++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py index 17d0d8f21d5c..c69dcfda93c5 100644 --- a/tests/pipelines/cogvideox/test_cogvideox.py +++ b/tests/pipelines/cogvideox/test_cogvideox.py @@ -30,7 +30,12 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, + to_np, +) enable_full_determinism() @@ -279,6 +284,44 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2): def test_xformers_attention_forwardGenerator_pass(self): pass + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames # [B, F, C, H, W] + original_image_slice = frames[0, -2:, -1, -3:, -3:] + + pipe.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_fused = frames[0, -2:, -1, -3:, -3:] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_disabled = frames[0, -2:, -1, -3:, -3:] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." + @slow @require_torch_gpu From 5e9924788e7939181eb6a4f09b2c495be3a90c3f Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 23 Aug 2024 03:18:48 +0200 Subject: [PATCH 18/19] update --- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 927960c2235f..0e0e907ff4a7 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -448,7 +448,6 @@ def _prepare_rotary_positional_embeddings( width: int, num_frames: int, device: torch.device, - dtype: torch.dtype, ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) @@ -466,8 +465,9 @@ def _prepare_rotary_positional_embeddings( use_real=True, ) - freqs_cos = freqs_cos.to(device=device, dtype=dtype) - freqs_sin = freqs_sin.to(device=device, dtype=dtype) + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin @property @@ -662,7 +662,7 @@ def __call__( # 7. Create rotary embeds if required image_rotary_emb = ( - self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, prompt_embeds.dtype) + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) if self.transformer.config.use_rotary_positional_embeddings else None ) From e7cd7a9c0b94a1a4502468dfe29a690bd04601c0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 23 Aug 2024 03:29:27 +0200 Subject: [PATCH 19/19] update docs --- docs/source/en/api/pipelines/cogvideox.md | 6 +++++- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 549666e60ebc..c7340eff40c4 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -29,6 +29,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM). +There are two models available that can be used with the CogVideoX pipeline: +- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) +- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) + ## Inference Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. @@ -68,7 +72,7 @@ With torch.compile(): Average inference time: 76.27 seconds. ### Memory optimization -CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script. +CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script. - `pipe.enable_model_cpu_offload()`: - Without enabling cpu offloading, memory usage is `33 GB` diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 0e0e907ff4a7..e100c1f11e20 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -41,6 +41,7 @@ >>> from diffusers import CogVideoXPipeline >>> from diffusers.utils import export_to_video + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") >>> prompt = ( ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " @@ -467,7 +468,6 @@ def _prepare_rotary_positional_embeddings( freqs_cos = freqs_cos.to(device=device) freqs_sin = freqs_sin.to(device=device) - return freqs_cos, freqs_sin @property