[Flux ControlNet] ControlNet initialization from transformer seems to be broken #9540
If I explicitly pass them in, it works:

```diff
  flux_controlnet = FluxControlNetModel.from_transformer(
      flux_transformer,
+     attention_head_dim=flux_transformer.config["attention_head_dim"],
+     num_attention_heads=flux_transformer.config["num_attention_heads"],
      num_layers=args.num_double_layers,
      num_single_layers=args.num_single_layers,
  )
```
Yeah, I would assume so as well. So if we match the parameters of the base transformer exactly, then it works:

```python
controlnet = FluxControlNetModel.from_transformer(
    transformer, num_layers=1, num_single_layers=1, attention_head_dim=16, num_attention_heads=2
)
```

For now, I am going to close this issue, but we can revisit this later.
I think the default should be the transformer's configuration unless the user explicitly specifies it, so we need to change `from_transformer`. Something like this:

```python
@classmethod
def from_transformer(
    cls,
    transformer,
    num_layers: int = 4,
    num_single_layers: int = 10,
    attention_head_dim=None,
    num_attention_heads=None,
    load_weights_from_transformer=True,
):
    # transformer.config is a FrozenDict, so copy it into a mutable dict
    # before overriding entries.
    config = dict(transformer.config)
    config["num_layers"] = num_layers
    config["num_single_layers"] = num_single_layers
    # Default to the base transformer's values unless explicitly overridden.
    if attention_head_dim is not None:
        config["attention_head_dim"] = attention_head_dim
    if num_attention_heads is not None:
        config["num_attention_heads"] = num_attention_heads

    controlnet = cls(**config)

    if load_weights_from_transformer:
        controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
        controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
        controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
        controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
        # The ControlNet usually has fewer blocks than the base model,
        # hence strict=False.
        controlnet.transformer_blocks.load_state_dict(
            transformer.transformer_blocks.state_dict(), strict=False
        )
        controlnet.single_transformer_blocks.load_state_dict(
            transformer.single_transformer_blocks.state_dict(), strict=False
        )
        # Zero-initialize the conditioning embedder, as is standard for ControlNets.
        controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)

    return controlnet
```
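With that change, a call that only overrides the depth would inherit the head configuration from the base model. A minimal usage sketch (the depth values here are just illustrative):

```python
# attention_head_dim / num_attention_heads now default to transformer.config,
# so only the depth overrides need to be passed.
controlnet = FluxControlNetModel.from_transformer(
    transformer,
    num_layers=2,
    num_single_layers=4,
)
```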
Originally caught in #9324.
Reproduction: call `FluxControlNetModel.from_transformer()` on a transformer whose `attention_head_dim` / `num_attention_heads` differ from the signature defaults, without passing those arguments explicitly.
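A minimal sketch of the failure, assuming a small test-sized transformer (the head configuration matches the comment above; the other dimensions are illustrative):

```python
from diffusers import FluxControlNetModel, FluxTransformer2DModel

# Tiny transformer whose head configuration differs from the from_transformer
# defaults; dimensions here are illustrative.
transformer = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=16,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=32,
    axes_dims_rope=(4, 4, 8),
)

# Without explicit attention_head_dim / num_attention_heads, the ControlNet is
# built from the signature defaults rather than the transformer's config.
controlnet = FluxControlNetModel.from_transformer(
    transformer, num_layers=1, num_single_layers=1
)
```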
Leads to: size-mismatch errors from the `load_state_dict()` calls that copy the transformer weights into the ControlNet.
Cc: @PromeAIpro
I think it makes sense to make this more robust and have dedicated testing for it.
@yiyixuxu possible to look into it?