[Flux ControlNet] ControlNet initialization from transformer seems to be broken #9540

Closed
sayakpaul opened this issue Sep 27, 2024 · 5 comments

@sayakpaul
Member

Originally caught in #9324.

Reproduction:

from diffusers import FluxTransformer2DModel, FluxControlNetModel

transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe", subfolder="transformer"
)
controlnet = FluxControlNetModel.from_transformer(
    transformer=transformer, num_layers=1, num_single_layers=1, attention_head_dim=16, num_attention_heads=1
)

Leads to:

RuntimeError: Error(s) in loading state_dict for CombinedTimestepTextProjEmbeddings:
        size mismatch for timestep_embedder.linear_1.weight: copying a param with shape torch.Size([32, 256]) from checkpoint, the shape in current model is torch.Size([16, 256]).
        size mismatch for timestep_embedder.linear_1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for timestep_embedder.linear_2.weight: copying a param with shape torch.Size([32, 32]) from checkpoint, the shape in current model is torch.Size([16, 16]).
        size mismatch for timestep_embedder.linear_2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for text_embedder.linear_1.weight: copying a param with shape torch.Size([32, 32]) from checkpoint, the shape in current model is torch.Size([16, 32]).
        size mismatch for text_embedder.linear_1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
        size mismatch for text_embedder.linear_2.weight: copying a param with shape torch.Size([32, 32]) from checkpoint, the shape in current model is torch.Size([16, 16]).
        size mismatch for text_embedder.linear_2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([16]).
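
For context, the shapes in the traceback line up with the embedding dimension being computed from the arguments passed to from_transformer rather than taken from the source transformer. A minimal sketch of the arithmetic (assuming the tiny test transformer uses num_attention_heads=2 and attention_head_dim=16, which matches the working call shown further down):

# Flux builds its inner/embedding dimension as num_attention_heads * attention_head_dim.
checkpoint_dim = 2 * 16  # tiny-flux-pipe transformer -> 32, the "checkpoint" shape above
controlnet_dim = 1 * 16  # values passed to from_transformer -> 16, the "current model" shape
# The time/text embedders are created with controlnet_dim, so copying the 32-wide
# transformer weights into them fails with the size mismatches shown above.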

Cc: @PromeAIpro

I think it makes sense to make this more robust and have dedicated testing for it.

@yiyixuxu possible to look into it?

@PromeAIpro
Contributor

I explicitly pass them in, and it works:

flux_controlnet = FluxControlNetModel.from_transformer(
            flux_transformer,
+            attention_head_dim=flux_transformer.config["attention_head_dim"],
+            num_attention_heads=flux_transformer.config["num_attention_heads"],
            num_layers=args.num_double_layers,
            num_single_layers=args.num_single_layers,
        )
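
A quick way to double-check that the shapes now line up (a sketch; flux_transformer and flux_controlnet as in the snippet above):

# Both models should share the same joint embedding width after the fix.
assert (
    flux_controlnet.config.num_attention_heads * flux_controlnet.config.attention_head_dim
    == flux_transformer.config.num_attention_heads * flux_transformer.config.attention_head_dim
)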

@PromeAIpro
Contributor

Why do we need to update the parameters here? Shouldn't they be passed in from the transformer?
[screenshot of the current from_transformer implementation]

@sayakpaul
Member Author

Yeah I would assume so as well.

So, if we match the parameters of the base transformer exactly, then it works:

controlnet = FluxControlNetModel.from_transformer(
    transformer=transformer, num_layers=1, num_single_layers=1, attention_head_dim=16, num_attention_heads=2
)

In this case, num_layers, num_single_layers, attention_head_dim, and num_attention_heads have been set to the values used by the transformer. But exposing these arguments can lead users to believe they are freely configurable.

For now, I am going to close this issue but we can revisit this later.
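
For the dedicated testing mentioned above, a minimal regression test could pin down the path that works today (a sketch only; the test name and placement are hypothetical):

import torch
from diffusers import FluxTransformer2DModel, FluxControlNetModel


def test_flux_controlnet_from_transformer_matching_config():
    transformer = FluxTransformer2DModel.from_pretrained(
        "hf-internal-testing/tiny-flux-pipe", subfolder="transformer"
    )
    controlnet = FluxControlNetModel.from_transformer(
        transformer=transformer,
        num_layers=1,
        num_single_layers=1,
        attention_head_dim=transformer.config.attention_head_dim,
        num_attention_heads=transformer.config.num_attention_heads,
    )
    # The shared time/text embedder should carry over the transformer's weights unchanged.
    for p_c, p_t in zip(
        controlnet.time_text_embed.parameters(), transformer.time_text_embed.parameters()
    ):
        assert torch.equal(p_c, p_t)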

@PromeAIpro
Contributor

PromeAIpro commented Sep 27, 2024

I think the defaults should come from the transformer configuration; only when the user explicitly specifies them do we need to change attention_head_dim and num_attention_heads.

@PromeAIpro
Contributor

PromeAIpro commented Sep 27, 2024

Something like this:

    @classmethod
    def from_transformer(
        cls,
        transformer,
        num_layers: int = 4,
        num_single_layers: int = 10,
        attention_head_dim=None,
        num_attention_heads=None,
        load_weights_from_transformer=True,
    ):
        config = transformer.config
        config["num_layers"] = num_layers
        config["num_single_layers"] = num_single_layers
        # Fall back to the transformer's head configuration unless the caller overrides it.
        config["attention_head_dim"] = attention_head_dim if attention_head_dim is not None else config["attention_head_dim"]
        config["num_attention_heads"] = num_attention_heads if num_attention_heads is not None else config["num_attention_heads"]

        controlnet = cls(**config)

        if load_weights_from_transformer:
            # Copy the shared embedders and the weights of the blocks the ControlNet keeps.
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
            controlnet.single_transformer_blocks.load_state_dict(
                transformer.single_transformer_blocks.state_dict(), strict=False
            )

            # The ControlNet-specific conditioning embedder starts zero-initialized.
            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)

        return controlnet
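
With that change, the call from the reproduction would work without restating the head configuration (a sketch of the intended usage, assuming the None defaults fall back to the transformer's config as above):

# Head configuration is inherited from the transformer unless explicitly overridden.
controlnet = FluxControlNetModel.from_transformer(
    transformer=transformer, num_layers=1, num_single_layers=1
)
assert controlnet.config.attention_head_dim == transformer.config.attention_head_dim
assert controlnet.config.num_attention_heads == transformer.config.num_attention_heads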
