
fix: CogVideox train dataset _preprocess_data crop video #9574


Merged: 9 commits, merged on Oct 8, 2024
Changes from 1 commit
1 change: 1 addition & 0 deletions examples/cogvideo/README.md
@@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen

> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
> You can pass `--video_reshape_mode` to control how input videos are cropped to the target resolution. Supported options: ['center', 'random', 'none'].

> [!IMPORTANT]
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
74 changes: 64 additions & 10 deletions examples/cogvideo/train_cogvideox_lora.py
@@ -42,6 +42,10 @@
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
import torchvision.transforms as TT
import numpy as np


if is_wandb_available():
@@ -214,6 +218,12 @@ def get_args():
default=720,
help="All input videos are resized to this width.",
)
parser.add_argument(
"--video_reshape_mode",
type=str,
default="center",
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
)
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
parser.add_argument(
"--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -413,6 +423,7 @@ def __init__(
video_column: str = "video",
height: int = 480,
width: int = 720,
video_reshape_mode: str = "center",
fps: int = 8,
max_num_frames: int = 49,
skip_frames_start: int = 0,
@@ -429,6 +440,7 @@
self.video_column = video_column
self.height = height
self.width = width
self.video_reshape_mode = video_reshape_mode
self.fps = fps
self.max_num_frames = max_num_frames
self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):

return instance_prompts, instance_videos

def _resize_for_rectangle_crop(self, arr):
    image_size = self.height, self.width
    reshape_mode = self.video_reshape_mode
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr
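
For reference, here is a rough standalone sketch (not part of the PR diff) of what this helper does: the video is resized so the target rectangle fits while the aspect ratio is preserved, and only then cropped, instead of being squashed directly to `height x width`. The free function, dummy tensor, and example sizes below are assumptions made for illustration.

```python
import torch
import numpy as np
from torchvision.transforms.functional import crop, resize
from torchvision.transforms import InterpolationMode


def resize_for_rectangle_crop(frames, height, width, reshape_mode="center"):
    # frames: [F, C, H, W]. Resize so the target height x width fits, keeping aspect ratio.
    if frames.shape[3] / frames.shape[2] > width / height:
        new_size = [height, int(frames.shape[3] * height / frames.shape[2])]
    else:
        new_size = [int(frames.shape[2] * width / frames.shape[3]), width]
    frames = resize(frames, size=new_size, interpolation=InterpolationMode.BICUBIC)

    delta_h = frames.shape[2] - height
    delta_w = frames.shape[3] - width
    if reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:  # "random" (the PR maps "none" to a random crop as well)
        top, left = np.random.randint(0, delta_h + 1), np.random.randint(0, delta_w + 1)
    return crop(frames, top=top, left=left, height=height, width=width)


video = torch.rand(49, 3, 960, 720)              # e.g. a 49-frame 960x720 clip
out = resize_for_rectangle_crop(video, 480, 720)
print(out.shape)                                 # torch.Size([49, 3, 480, 720])
```

With the old behavior (resizing straight to the target size), the same 960x720 clip would be distorted, which matches the issue the PR author describes in the discussion below.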

def _preprocess_data(self):
    try:
        import decord
@@ -542,14 +586,14 @@ def _preprocess_data(self):

decord.bridge.set_bridge("torch")

    videos = []
    train_transforms = transforms.Compose(
        [
            transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
        ]
    progress_dataset_bar = tqdm(
        range(0, len(self.instance_video_paths)),
        desc="Loading progress resize and crop videos",
    )
    videos = []

    for filename in self.instance_video_paths:
        progress_dataset_bar.update(1)
        video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
        video_num_frames = len(video_reader)
@@ -576,10 +620,12 @@
        assert (selected_num_frames - 1) % 4 == 0

        # Training transforms
        frames = frames.float()
        frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
        videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
        tensor = frames.float() / 255.0
@a-r-r-o-w (Member) commented on Oct 3, 2024:

I am not sure why we are scaling the tensors to the range [0, 1] instead of [-1, 1]. In the original codebase, we convert to [-1, 1] here as well, if I understand correctly, yes?

@glide-the (Contributor, Author) replied:

> I am not sure why we are scaling the tensors to the range [0, 1] instead of [-1, 1]. In the original codebase, we convert to [-1, 1] here as well, if I understand correctly, yes?

You're right, it should be in the [-1, 1] range. In fact, this matters for the matrix calculations during fine-tuning, and the [-1, 1] range is easier to compute with. I had forgotten that this step is handled in the latent-to-image (decode) process, so the image is in the [0, 1] range while the latent space is in the [-1, 1] range.

I've already verified the cause of the blank-screen training results: I fed a 960x720 video into the dataset, and it was compressed directly to 460x720 for training.

@glide-the (Contributor, Author) commented on Oct 3, 2024:

https://github.com/THUDM/CogVideo/blob/111756a6a68a8df375ef9c31f9f325818699dfaa/sat/data_video.py#L437
Dividing by 127.5 may introduce precision loss.

encode: images / 255.0 * 2.0 - 1.0
decode: (images / 2 + 0.5).clamp(0, 1)
(screenshot of the decoded result)

encode: (frames - 127.5) / 127.5
decode: (images / 2 + 0.5).clamp(0, 1)
(screenshot of the decoded result)
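
For concreteness, a small sketch of my own (not from the thread) comparing the two encode formulas numerically. Both map [0, 255] onto [-1, 1], so mathematically they are the same transform; any difference in practice comes from floating-point rounding:

```python
import torch

frames = torch.arange(0, 256, dtype=torch.float32)

a = frames / 255.0 * 2.0 - 1.0   # encode used in this example script
b = (frames - 127.5) / 127.5     # encode used in the CogVideo SAT reference code

print((a - b).abs().max())             # on the order of float32 rounding error (~1e-7)
decoded = (a / 2 + 0.5).clamp(0, 1) * 255
print((decoded - frames).abs().max())  # round-trip error, far below one 8-bit step
```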

        frames = tensor.permute(0, 3, 1, 2)
        frames = self._resize_for_rectangle_crop(frames)
        videos.append(frames.contiguous())  # [F, C, H, W]

    progress_dataset_bar.close()
    return videos


@@ -1171,6 +1217,7 @@ def load_model_hook(models, input_dir):
video_column=args.video_column,
height=args.height,
width=args.width,
video_reshape_mode=args.video_reshape_mode,
fps=args.fps,
max_num_frames=args.max_num_frames,
skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1226,20 @@
id_token=args.id_token,
)

def encode_video(video):

def encode_video(video, bar):
    bar.update(1)
    video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
    video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    latent_dist = vae.encode(video).latent_dist
    return latent_dist

train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
progress_encode_bar = tqdm(
    range(0, len(train_dataset.instance_videos)),
    desc="Loading Encode videos",
)
train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos]
progress_encode_bar.close()

def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
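
For context on the surrounding training flow, here is a rough sketch of my own (not code from the PR): each video is pushed through the VAE encoder once up front and only the latent distribution is cached, so `collate_fn` can draw a fresh latent sample (scaled by `vae.config.scaling_factor`) for every batch. The helper name and the `cached_dists` variable are placeholders for illustration.

```python
import torch

# Sketch of the cache-then-resample pattern used by the script (names are placeholders).
def precompute_latent_dists(videos, vae, device):
    dists = []
    for video in videos:                                           # video: [F, C, H, W], normalized frames
        video = video.to(device, dtype=vae.dtype).unsqueeze(0)     # [1, F, C, H, W]
        video = video.permute(0, 2, 1, 3, 4)                       # [1, C, F, H, W]
        dists.append(vae.encode(video).latent_dist)                # cache the distribution, not a sample
    return dists

# Later, per batch (as collate_fn does above):
# latents = [dist.sample() * vae.config.scaling_factor for dist in cached_dists]
```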
681 changes: 681 additions & 0 deletions examples/cogvideo/video_fix_rgb_float_and_crop.ipynb

Large diffs are not rendered by default.
