[core] Allegro T2V #9736

Merged: 40 commits, merged on Oct 29, 2024.
The diff shown below reflects changes from 13 of the 40 commits.

Commits
199c240  update (a-r-r-o-w, Oct 21, 2024)
901d10e  refactor transformer part 1 (a-r-r-o-w, Oct 21, 2024)
ec05bbd  refactor part 2 (a-r-r-o-w, Oct 21, 2024)
892b70d  refactor part 3 (a-r-r-o-w, Oct 22, 2024)
fd18f9a  make style (a-r-r-o-w, Oct 22, 2024)
4f1653c  refactor part 4; modeling tests (a-r-r-o-w, Oct 22, 2024)
412cd7c  make style (a-r-r-o-w, Oct 22, 2024)
bcba858  Merge branch 'main' into allegro-impl (a-r-r-o-w, Oct 22, 2024)
8f9ffa8  refactor part 5 (a-r-r-o-w, Oct 22, 2024)
c76dc5a  refactor part 6 (a-r-r-o-w, Oct 22, 2024)
015cc78  gradient checkpointing (a-r-r-o-w, Oct 22, 2024)
6b53b85  pipeline tests (broken atm) (a-r-r-o-w, Oct 22, 2024)
f64f2d0  update (a-r-r-o-w, Oct 22, 2024)
2ef6a9e  add coauthor (a-r-r-o-w, Oct 22, 2024)
e53dac2  refactor part 7 (a-r-r-o-w, Oct 22, 2024)
f702af0  add docs (a-r-r-o-w, Oct 22, 2024)
4f59d56  Merge branch 'main' into allegro-impl (a-r-r-o-w, Oct 22, 2024)
3d41281  make style (a-r-r-o-w, Oct 22, 2024)
37e8a95  add coauthor (a-r-r-o-w, Oct 22, 2024)
2c4645c  make fix-copies (a-r-r-o-w, Oct 22, 2024)
e26604c  undo unrelated change (a-r-r-o-w, Oct 22, 2024)
bb321e7  revert changes to embeddings, normalization, transformer (a-r-r-o-w, Oct 23, 2024)
174621f  refactor part 8 (a-r-r-o-w, Oct 23, 2024)
2a82064  make style (a-r-r-o-w, Oct 23, 2024)
762ccd5  refactor part 9 (a-r-r-o-w, Oct 23, 2024)
cf5dec1  make style (a-r-r-o-w, Oct 23, 2024)
31544d4  Merge branch 'main' into allegro-impl (a-r-r-o-w, Oct 23, 2024)
d9eabf8  fix (a-r-r-o-w, Oct 23, 2024)
cf010fc  apply suggestions from review (a-r-r-o-w, Oct 23, 2024)
d44a5c8  Apply suggestions from code review (a-r-r-o-w, Oct 23, 2024)
ceb7678  Merge branch 'main' into allegro-impl (a-r-r-o-w, Oct 23, 2024)
b036386  update example (a-r-r-o-w, Oct 24, 2024)
0fe8c51  Merge branch 'main' into allegro-impl (a-r-r-o-w, Oct 24, 2024)
2065adc  Merge branch 'main' into allegro-impl (a-r-r-o-w, Oct 26, 2024)
9214f4a  remove attention mask for self-attention (a-r-r-o-w, Oct 28, 2024)
723e5b5  Merge branch 'main' into allegro-impl (a-r-r-o-w, Oct 29, 2024)
3354ee1  update (a-r-r-o-w, Oct 29, 2024)
28e5758  copied from (a-r-r-o-w, Oct 29, 2024)
1ec17d5  update (a-r-r-o-w, Oct 29, 2024)
4d6d4e4  update (a-r-r-o-w, Oct 29, 2024)
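
The commits above land a text-to-video Allegro integration: an AllegroPipeline, an AllegroTransformer3DModel, an AutoencoderKLAllegro, and a dedicated attention processor. A minimal usage sketch of the new pipeline follows; the checkpoint id rhymes-ai/Allegro, the dtype, and the generation arguments are illustrative assumptions rather than values taken from this diff.

```python
# Hedged sketch: text-to-video generation with the AllegroPipeline added in this PR.
# The checkpoint id and generation arguments are assumptions, not values from the diff.
import torch
from diffusers import AllegroPipeline
from diffusers.utils import export_to_video

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A sea turtle gliding over a coral reef, sunlight filtering through the water."
video = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=7.5).frames[0]
export_to_video(video, "allegro_t2v.mp4", fps=15)
```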
src/diffusers/__init__.py (6 additions, 0 deletions)
@@ -77,9 +77,11 @@
else:
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
@@ -237,6 +239,7 @@
else:
_import_structure["pipelines"].extend(
[
"AllegroPipeline",
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
"AmusedImg2ImgPipeline",
@@ -556,9 +559,11 @@
from .utils.dummy_pt_objects import * # noqa F403
else:
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
@@ -697,6 +702,7 @@
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipelines import (
AllegroPipeline,
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AmusedImg2ImgPipeline,
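
With these registrations in place, the new classes are importable straight from the top-level diffusers namespace. A quick sanity-check sketch, assuming a build of this branch is installed:

```python
# Verify that the classes registered above are exported at the top level.
from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro

print(AllegroPipeline.__name__, AllegroTransformer3DModel.__name__, AutoencoderKLAllegro.__name__)
```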
src/diffusers/models/__init__.py (4 additions, 0 deletions)
@@ -28,6 +28,7 @@
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
@@ -54,6 +55,7 @@
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
@@ -81,6 +83,7 @@
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
@@ -97,6 +100,7 @@
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
src/diffusers/models/attention_processor.py (94 additions, 0 deletions)
@@ -1521,6 +1521,100 @@ def __call__(
        return hidden_states, encoder_hidden_states


class AllegroAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AllegroAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Apply RoPE if needed
        if image_rotary_emb is not None and not attn.is_cross_attention:
            from .embeddings import apply_rotary_emb_allegro

            query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
            key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AuraFlowAttnProcessor2_0:
    """Attention processor used typically in processing Aura Flow."""

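
AllegroAttnProcessor2_0 follows the usual diffusers processor pattern: project to query/key/value, reshape to (batch, heads, seq_len, head_dim), optionally apply rotary embeddings to query and key for self-attention, run F.scaled_dot_product_attention, then reshape back and apply the output projection. A hedged sketch of attaching it to a standalone Attention block is below; the dimensions are placeholder assumptions, not Allegro's actual configuration.

```python
# Illustrative only: wire the new processor into a diffusers Attention block.
# query_dim/heads/dim_head are placeholder assumptions, not Allegro's real config.
import torch
from diffusers.models.attention_processor import AllegroAttnProcessor2_0, Attention

attn = Attention(query_dim=128, heads=8, dim_head=16, processor=AllegroAttnProcessor2_0())

hidden_states = torch.randn(2, 64, 128)  # (batch, seq_len, query_dim)
out = attn(hidden_states)  # self-attention; RoPE is only applied when image_rotary_emb is passed
print(out.shape)  # torch.Size([2, 64, 128])
```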
src/diffusers/models/autoencoders/__init__.py (1 addition, 0 deletions)
@@ -1,5 +1,6 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
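
Registering the VAE alongside the other KL autoencoders means it can also be loaded on its own, for example to enable tiled decoding before handing it to the pipeline. A sketch under the assumption that the checkpoint stores the VAE in a vae subfolder and that the class exposes the usual from_pretrained/enable_tiling API:

```python
# Hedged sketch: load the Allegro VAE separately and turn on tiled decoding.
# The repo id and subfolder are assumptions; enable_tiling() is assumed to be
# available on AutoencoderKLAllegro as it is on the other video VAEs.
import torch
from diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro.from_pretrained(
    "rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32
)
vae.enable_tiling()  # decode large video latents tile by tile to reduce peak memory
```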