
Commit 977afbe

zRzRzRzRzRzRzR authored and a-r-r-o-w committed
Cogvideox-5B Model adapter change (#9203)
* draft of embedding --------- Co-authored-by: Aryan <aryan@huggingface.co>
1 parent 51317a8 commit 977afbe

File tree: 8 files changed (+536 / -32 lines)


docs/source/en/api/pipelines/cogvideox.md

Lines changed: 5 additions & 1 deletion

@@ -29,6 +29,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
 
 This pipeline was contributed by [zRzRzRzRzRzR](https://github.com/zRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
 
+There are two models available that can be used with the CogVideoX pipeline:
+- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
+- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)
+
 ## Inference
 
 Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
@@ -68,7 +72,7 @@ With torch.compile(): Average inference time: 76.27 seconds.
 
 ### Memory optimization
 
-CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
+CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
 
 - `pipe.enable_model_cpu_offload()`:
   - Without enabling cpu offloading, memory usage is `33 GB`
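For context, a minimal usage sketch of the offloading optimization referenced in this hunk. It assumes the published `THUDM/CogVideoX-5b` checkpoint and the `CogVideoXPipeline` API; the prompt and generation settings are illustrative only:

```python
import torch
from diffusers import CogVideoXPipeline

# Load the 5B checkpoint in bfloat16 (the 2B checkpoint is typically run in float16 instead).
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Offload submodules to CPU when idle instead of keeping the whole pipeline on the GPU.
pipe.enable_model_cpu_offload()

video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",  # illustrative prompt
    num_frames=49,             # 6 seconds at 8 FPS, as described in the doc above
    num_inference_steps=50,
).frames[0]
```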

scripts/convert_cogvideox_to_diffusers.py

Lines changed: 55 additions & 9 deletions

@@ -86,6 +86,9 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "key_layernorm_list": reassign_query_key_layernorm_inplace,
     "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
     "embed_tokens": remove_keys_inplace,
+    "freqs_sin": remove_keys_inplace,
+    "freqs_cos": remove_keys_inplace,
+    "position_embedding": remove_keys_inplace,
 }
 
 VAE_KEYS_RENAME_DICT = {
@@ -123,11 +126,21 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
     state_dict[new_key] = state_dict.pop(old_key)
 
 
-def convert_transformer(ckpt_path: str):
+def convert_transformer(
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    dtype: torch.dtype,
+):
     PREFIX_KEY = "model.diffusion_model."
 
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    transformer = CogVideoXTransformer3DModel()
+    transformer = CogVideoXTransformer3DModel(
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+    ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[len(PREFIX_KEY) :]
@@ -145,9 +158,9 @@ def convert_transformer(ckpt_path: str):
     return transformer
 
 
-def convert_vae(ckpt_path: str):
+def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX()
+    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -172,13 +185,26 @@ def get_args():
     )
     parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
-    parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
+    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
     parser.add_argument(
         "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
     )
     parser.add_argument(
         "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
     )
+    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
+    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
+    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
+    parser.add_argument(
+        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+    )
+    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
+    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
+    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
     return parser.parse_args()
 
 
@@ -188,18 +214,33 @@ def get_args():
     transformer = None
     vae = None
 
+    if args.fp16 and args.bf16:
+        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
+
+    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+
     if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_ckpt_path)
+        transformer = convert_transformer(
+            args.transformer_ckpt_path,
+            args.num_layers,
+            args.num_attention_heads,
+            args.use_rotary_positional_embeddings,
+            dtype,
+        )
     if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path)
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
 
     text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
 
+    # Apparently, the conversion does not work any more without this :shrug:
+    for param in text_encoder.parameters():
+        param.data = param.data.contiguous()
+
     scheduler = CogVideoXDDIMScheduler.from_config(
         {
-            "snr_shift_scale": 3.0,
+            "snr_shift_scale": args.snr_shift_scale,
             "beta_end": 0.012,
             "beta_schedule": "scaled_linear",
             "beta_start": 0.00085,
@@ -208,7 +249,7 @@ def get_args():
             "prediction_type": "v_prediction",
             "rescale_betas_zero_snr": True,
             "set_alpha_to_one": True,
-            "timestep_spacing": "linspace",
+            "timestep_spacing": "trailing",
         }
     )
 
@@ -218,5 +259,10 @@ def get_args():
 
     if args.fp16:
         pipe = pipe.to(dtype=torch.float16)
+    if args.bf16:
+        pipe = pipe.to(dtype=torch.bfloat16)
 
+    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
+    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
+    # is either fp16/bf16 here).
     pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
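As a rough illustration of what the new flags configure, the sketch below instantiates 5B-sized components with the values noted in the argument-parser comments above (42 layers, 48 attention heads, RoPE enabled, VAE scaling factor 0.7, bf16 weights). It is only a sketch of the configuration, not a substitute for running the conversion script; any constructor arguments not shown are left at their library defaults, which may differ from the real 5B checkpoint:

```python
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel

# CogVideoX-5B values taken from the comments in the argument parser above.
# The 2B defaults would be num_layers=30, num_attention_heads=30, no RoPE,
# scaling_factor=1.15258426, snr_shift_scale=3.0, and fp16 weights.
dtype = torch.bfloat16

transformer = CogVideoXTransformer3DModel(
    num_layers=42,
    num_attention_heads=48,
    use_rotary_positional_embeddings=True,
).to(dtype=dtype)

vae = AutoencoderKLCogVideoX(scaling_factor=0.7).to(dtype=dtype)

# The scheduler would additionally be built with snr_shift_scale=1.0 and
# timestep_spacing="trailing", as in the CogVideoXDDIMScheduler config above.
```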

src/diffusers/models/attention_processor.py

Lines changed: 142 additions & 0 deletions

@@ -1783,6 +1783,148 @@ def __call__(
         return hidden_states
 
 
+class CogVideoXAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class FusedCogVideoXAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
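To make the tensor layout these processors assume easier to follow, here is a self-contained toy sketch in plain PyTorch (toy sizes, not the diffusers `Attention` API): text tokens and video tokens are concatenated along the sequence axis, attention runs over the joint sequence with RoPE applied only to the video slice, and the output is split back into the two streams:

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration only.
batch, heads, head_dim = 2, 4, 64
text_seq_length, video_seq_length = 8, 32

text_tokens = torch.randn(batch, text_seq_length, heads * head_dim)
video_tokens = torch.randn(batch, video_seq_length, heads * head_dim)

# 1. Concatenate text and video tokens along the sequence dimension.
hidden_states = torch.cat([text_tokens, video_tokens], dim=1)

# 2. Reshape to (batch, heads, seq, head_dim) as in the processor above.
def to_heads(x: torch.Tensor) -> torch.Tensor:
    return x.view(batch, -1, heads, head_dim).transpose(1, 2)

query = key = value = to_heads(hidden_states)

# 3. RoPE would be applied only to the video slice, i.e. query[:, :, text_seq_length:].

# 4. Attention over the joint text + video sequence.
out = F.scaled_dot_product_attention(query, key, value)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

# 5. Split the result back into the text and video streams.
text_out, video_out = out.split([text_seq_length, out.size(1) - text_seq_length], dim=1)
print(text_out.shape, video_out.shape)  # (2, 8, 256) and (2, 32, 256)
```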

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 1 addition & 1 deletion

@@ -902,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             Tuple of block output channels.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         sample_size (`int`, *optional*, defaults to `32`): Sample input size.
-        scaling_factor (`float`, *optional*, defaults to 0.18215):
+        scaling_factor (`float`, *optional*, defaults to `1.15258426`):
             The component-wise standard deviation of the trained latent space computed using the first batch of the
             training set. This is used to scale the latent space to have unit variance when training the diffusion
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
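A tiny sketch of the scaling described in that docstring. The latent shape is hypothetical, and 0.7 is the 5B value noted in the conversion script comments (2B uses 1.15258426):

```python
import torch

scaling_factor = 0.7                      # CogVideoX-5B VAE scaling factor (2B: 1.15258426)
latents = torch.randn(1, 16, 13, 60, 90)  # hypothetical (B, C, T, H, W) latent tensor

z = latents * scaling_factor              # scaled before being passed to the diffusion model
recovered = z / scaling_factor            # scaled back before decoding with the VAE
```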

src/diffusers/models/embeddings.py

Lines changed: 84 additions & 0 deletions

@@ -374,6 +374,90 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         return embeds
 
 
+def get_3d_rotary_pos_embed(
+    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    RoPE for video tokens with 3D structure.
+
+    Args:
+    embed_dim: (`int`):
+        The embedding dimension size, corresponding to hidden_size_head.
+    crops_coords (`Tuple[int]`):
+        The top-left and bottom-right coordinates of the crop.
+    grid_size (`Tuple[int]`):
+        The grid size of the spatial positional embedding (height, width).
+    temporal_size (`int`):
+        The size of the temporal dimension.
+    theta (`float`):
+        Scaling factor for frequency computation.
+    use_real (`bool`):
+        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
+    """
+    start, stop = crops_coords
+    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 4
+    dim_h = embed_dim // 8 * 3
+    dim_w = embed_dim // 8 * 3
+
+    # Temporal frequencies
+    freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
+    grid_t = torch.from_numpy(grid_t).float()
+    freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
+    freqs_t = freqs_t.repeat_interleave(2, dim=-1)
+
+    # Spatial frequencies for height and width
+    freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
+    freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
+    grid_h = torch.from_numpy(grid_h).float()
+    grid_w = torch.from_numpy(grid_w).float()
+    freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
+    freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
+    freqs_h = freqs_h.repeat_interleave(2, dim=-1)
+    freqs_w = freqs_w.repeat_interleave(2, dim=-1)
+
+    # Broadcast and concatenate tensors along specified dimension
+    def broadcast(tensors, dim=-1):
+        num_tensors = len(tensors)
+        shape_lens = {len(t.shape) for t in tensors}
+        assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+        shape_len = list(shape_lens)[0]
+        dim = (dim + shape_len) if dim < 0 else dim
+        dims = list(zip(*(list(t.shape) for t in tensors)))
+        expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+        assert all(
+            [*(len(set(t[1])) <= 2 for t in expandable_dims)]
+        ), "invalid dimensions for broadcastable concatenation"
+        max_dims = [(t[0], max(t[1])) for t in expandable_dims]
+        expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
+        expanded_dims.insert(dim, (dim, dims[dim]))
+        expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
+        tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
+        return torch.cat(tensors, dim=dim)
+
+    freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+
+    t, h, w, d = freqs.shape
+    freqs = freqs.view(t * h * w, d)
+
+    # Generate sine and cosine components
+    sin = freqs.sin()
+    cos = freqs.cos()
+
+    if use_real:
+        return cos, sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs_cis
+
+
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.
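A short sketch of calling the new helper. The sizes below are assumptions chosen for illustration, not values read from the pipeline code: a head dimension of 64, a 30x45 spatial patch grid, and 13 latent frames:

```python
import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed

# Illustrative sizes (assumed): attention head dim 64, 30x45 patch grid, 13 latent frames.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (30, 45)),
    grid_size=(30, 45),
    temporal_size=13,
)

# One (cos, sin) row per video token: 13 * 30 * 45 = 17550 positions.
print(cos.shape, sin.shape)  # torch.Size([17550, 64]) torch.Size([17550, 64])
```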
