Skip to content

Commit df073ba

Browse files
authored
[research_projects] add flux training script with quantization (#9754)
* add flux training script with quantization * remove exclamation
1 parent 94643fa commit df073ba

File tree

5 files changed

+1496
-0
lines changed

5 files changed

+1496
-0
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
## LoRA fine-tuning Flux.1 Dev with quantization
2+
3+
> [!NOTE]
4+
> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further.
5+
6+
This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow:
7+
8+
* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
9+
* `train_dreambooth_lora_flux_miniature.py` takes care of training:
10+
* Since we already precomputed the text embeddings, we don't load the text encoders.
11+
* We load the VAE and use it to precompute the image latents and we then delete it.
12+
* Load the Flux transformer, quantize it with the [NF4 datatype](https://arxiv.org/abs/2305.14314) through `bitsandbytes`, prepare it for 4bit training.
13+
* Add LoRA adapter layers to it and then ensure they are kept in FP32 precision.
14+
* Train!
15+
16+
To run training in a memory-optimized manner, we additionally use:
17+
18+
* 8Bit Adam
19+
* Gradient checkpointing
20+
21+
We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow.
22+
23+
## Training
24+
25+
Ensure you have installed the required libraries:
26+
27+
```bash
28+
pip install -U transformers accelerate bitsandbytes peft datasets
29+
pip install git+https://github.com/huggingface/diffusers -U
30+
```
31+
32+
Now, compute the text embeddings:
33+
34+
```bash
35+
python compute_embeddings.py
36+
```
37+
38+
It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:
39+
40+
```bash
41+
huggingface-cli
42+
```
43+
44+
Then launch:
45+
46+
```bash
47+
accelerate launch --config_file=accelerate.yaml \
48+
train_dreambooth_lora_flux_miniature.py \
49+
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
50+
--data_df_path="embeddings.parquet" \
51+
--output_dir="yarn_art_lora_flux_nf4" \
52+
--mixed_precision="fp16" \
53+
--use_8bit_adam \
54+
--weighting_scheme="none" \
55+
--resolution=1024 \
56+
--train_batch_size=1 \
57+
--repeats=1 \
58+
--learning_rate=1e-4 \
59+
--guidance_scale=1 \
60+
--report_to="wandb" \
61+
--gradient_accumulation_steps=4 \
62+
--gradient_checkpointing \
63+
--lr_scheduler="constant" \
64+
--lr_warmup_steps=0 \
65+
--cache_latents \
66+
--rank=4 \
67+
--max_train_steps=700 \
68+
--seed="0"
69+
```
70+
71+
We can direcly pass a quantized checkpoint path, too:
72+
73+
```diff
74+
+ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg"
75+
```
76+
77+
Depending on the machine, training time will vary but for our case, it was 1.5 hours. It maybe possible to speed this up by using `torch.bfloat16`.
78+
79+
We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed:
80+
81+
```bash
82+
pip install -Uq deepspeed
83+
```
84+
85+
And then launch:
86+
87+
```bash
88+
accelerate launch --config_file=ds2.yaml \
89+
train_dreambooth_lora_flux_miniature.py \
90+
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
91+
--data_df_path="embeddings.parquet" \
92+
--output_dir="yarn_art_lora_flux_nf4" \
93+
--mixed_precision="no" \
94+
--use_8bit_adam \
95+
--weighting_scheme="none" \
96+
--resolution=1024 \
97+
--train_batch_size=1 \
98+
--repeats=1 \
99+
--learning_rate=1e-4 \
100+
--guidance_scale=1 \
101+
--report_to="wandb" \
102+
--gradient_accumulation_steps=4 \
103+
--gradient_checkpointing \
104+
--lr_scheduler="constant" \
105+
--lr_warmup_steps=0 \
106+
--cache_latents \
107+
--rank=4 \
108+
--max_train_steps=700 \
109+
--seed="0"
110+
```
111+
112+
## Inference
113+
114+
When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example:
115+
116+
1. First, load the original model and merge the LoRA params into it:
117+
118+
```py
119+
from diffusers import FluxPipeline
120+
import torch
121+
122+
ckpt_id = "black-forest-labs/FLUX.1-dev"
123+
pipeline = FluxPipeline.from_pretrained(
124+
ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
125+
)
126+
pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors")
127+
pipeline.fuse_lora()
128+
pipeline.unload_lora_weights()
129+
130+
pipeline.transformer.save_pretrained("fused_transformer")
131+
```
132+
133+
2. Quantize the model and run inference
134+
135+
```py
136+
from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
137+
import torch
138+
139+
ckpt_id = "black-forest-labs/FLUX.1-dev"
140+
bnb_4bit_compute_dtype = torch.float16
141+
nf4_config = BitsAndBytesConfig(
142+
load_in_4bit=True,
143+
bnb_4bit_quant_type="nf4",
144+
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
145+
)
146+
transformer = FluxTransformer2DModel.from_pretrained(
147+
"fused_transformer",
148+
quantization_config=nf4_config,
149+
torch_dtype=bnb_4bit_compute_dtype,
150+
)
151+
pipeline = AutoPipelineForText2Image.from_pretrained(
152+
ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
153+
)
154+
pipeline.enable_model_cpu_offload()
155+
156+
image = pipeline(
157+
"a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
158+
).images[0]
159+
image.save("yarn_merged.png")
160+
```
161+
162+
| Dequantize, merge, quantize | Merging directly into quantized model |
163+
|-------|-------|
164+
| ![Image A](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/merged.png) | ![Image B](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/unmerged.png) |
165+
166+
As we can notice the first column result follows the style more closely.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
compute_environment: LOCAL_MACHINE
2+
debug: false
3+
distributed_type: NO
4+
downcast_bf16: 'no'
5+
enable_cpu_affinity: true
6+
gpu_ids: all
7+
machine_rank: 0
8+
main_training_function: main
9+
mixed_precision: bf16
10+
num_machines: 1
11+
num_processes: 1
12+
rdzv_backend: static
13+
same_network: true
14+
tpu_env: []
15+
tpu_use_cluster: false
16+
tpu_use_sudo: false
17+
use_cpu: false
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import argparse
18+
19+
import pandas as pd
20+
import torch
21+
from datasets import load_dataset
22+
from huggingface_hub.utils import insecure_hashlib
23+
from tqdm.auto import tqdm
24+
from transformers import T5EncoderModel
25+
26+
from diffusers import FluxPipeline
27+
28+
29+
MAX_SEQ_LENGTH = 77
30+
OUTPUT_PATH = "embeddings.parquet"
31+
32+
33+
def generate_image_hash(image):
34+
return insecure_hashlib.sha256(image.tobytes()).hexdigest()
35+
36+
37+
def load_flux_dev_pipeline():
38+
id = "black-forest-labs/FLUX.1-dev"
39+
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
40+
pipeline = FluxPipeline.from_pretrained(
41+
id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
42+
)
43+
return pipeline
44+
45+
46+
@torch.no_grad()
47+
def compute_embeddings(pipeline, prompts, max_sequence_length):
48+
all_prompt_embeds = []
49+
all_pooled_prompt_embeds = []
50+
all_text_ids = []
51+
for prompt in tqdm(prompts, desc="Encoding prompts."):
52+
(
53+
prompt_embeds,
54+
pooled_prompt_embeds,
55+
text_ids,
56+
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
57+
all_prompt_embeds.append(prompt_embeds)
58+
all_pooled_prompt_embeds.append(pooled_prompt_embeds)
59+
all_text_ids.append(text_ids)
60+
61+
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
62+
print(f"Max memory allocated: {max_memory:.3f} GB")
63+
return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids
64+
65+
66+
def run(args):
67+
dataset = load_dataset("Norod78/Yarn-art-style", split="train")
68+
image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
69+
all_prompts = list(image_prompts.values())
70+
print(f"{len(all_prompts)=}")
71+
72+
pipeline = load_flux_dev_pipeline()
73+
all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
74+
pipeline, all_prompts, args.max_sequence_length
75+
)
76+
77+
data = []
78+
for i, (image_hash, _) in enumerate(image_prompts.items()):
79+
data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))
80+
print(f"{len(data)=}")
81+
82+
# Create a DataFrame
83+
embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
84+
df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
85+
print(f"{len(df)=}")
86+
87+
# Convert embedding lists to arrays (for proper storage in parquet)
88+
for col in embedding_cols:
89+
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
90+
91+
# Save the dataframe to a parquet file
92+
df.to_parquet(args.output_path)
93+
print(f"Data successfully serialized to {args.output_path}")
94+
95+
96+
if __name__ == "__main__":
97+
parser = argparse.ArgumentParser()
98+
parser.add_argument(
99+
"--max_sequence_length",
100+
type=int,
101+
default=MAX_SEQ_LENGTH,
102+
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
103+
)
104+
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
105+
args = parser.parse_args()
106+
107+
run(args)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
compute_environment: LOCAL_MACHINE
2+
debug: false
3+
deepspeed_config:
4+
gradient_accumulation_steps: 1
5+
gradient_clipping: 1.0
6+
offload_optimizer_device: cpu
7+
offload_param_device: cpu
8+
zero3_init_flag: false
9+
zero_stage: 2
10+
distributed_type: DEEPSPEED
11+
downcast_bf16: 'no'
12+
enable_cpu_affinity: false
13+
machine_rank: 0
14+
main_training_function: main
15+
mixed_precision: 'no'
16+
num_machines: 1
17+
num_processes: 1
18+
rdzv_backend: static
19+
same_network: true
20+
tpu_env: []
21+
tpu_use_cluster: false
22+
tpu_use_sudo: false
23+
use_cpu: false

0 commit comments

Comments
 (0)