From 8d1a17caa5cef5d2fcf6a7c129d4c85335386a40 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Tue, 19 Jul 2022 15:43:38 -0700
Subject: [PATCH 01/63] re-add RL model code

---
 src/diffusers/models/__init__.py |   1 +
 src/diffusers/models/resnet.py   | 134 +++++
 src/diffusers/models/unet_rl.py  | 228 ++++++++
 tests/test_modeling_utils.py     | 915 +++++++++++++++++++++++++++++++
 4 files changed, 1278 insertions(+)
 create mode 100644 src/diffusers/models/unet_rl.py
 create mode 100755 tests/test_modeling_utils.py

diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 1242ad6fca7f..47f7fa71682b 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -18,6 +18,7 @@
 if is_torch_available():
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
+    from .unet_rl import TemporalUNet
     from .vae import AutoencoderKL, VQModel
 
 if is_flax_available():
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 49ff7d6bfa45..9b52681c3b99 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -5,6 +5,70 @@
 import torch.nn.functional as F
 
 
+class Upsample1D(nn.Module):
+    """
+    A 1D upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param use_conv_transpose: a bool determining if a transposed convolution is used instead of interpolation.
+    """
+
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_conv_transpose = use_conv_transpose
+        self.name = name
+
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+        self.conv = None
+        if use_conv_transpose:
+            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+        elif use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.use_conv_transpose:
+            return self.conv(x)
+
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+        if self.use_conv:
+            x = self.conv(x)
+
+        return x
+
+
+class Downsample1D(nn.Module):
+    """
+    A 1D downsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param padding: padding for the convolution, if one is used.
+    """
+
+    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.padding = padding
+        stride = 2
+        self.name = name
+
+        if use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+        else:
+            assert self.channels == self.out_channels
+            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.conv(x)
+
+
 class Upsample2D(nn.Module):
     """
     An upsampling layer with an optional convolution.
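A quick shape check of the two new 1D layers (a minimal editorial sketch, not part of the patch itself; it assumes the classes above are importable from diffusers.models.resnet):

    import torch
    from diffusers.models.resnet import Downsample1D, Upsample1D

    x = torch.randn(2, 32, 16)                    # [batch, channels, horizon]
    up = Upsample1D(32, use_conv_transpose=True)  # ConvTranspose1d(32, 32, 4, 2, 1)
    down = Downsample1D(32, use_conv=True)        # Conv1d(32, 32, 3, stride=2, padding=1)
    print(up(x).shape)                            # torch.Size([2, 32, 32]) -- horizon doubled
    print(down(x).shape)                          # torch.Size([2, 32, 8])  -- horizon halved

Both layers resample only the trailing axis, which is why the temporal UNet introduced below keeps trajectories in [batch, transition, horizon] layout while running its blocks.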
@@ -374,6 +438,76 @@ def forward(self, x):
         return x * torch.tanh(torch.nn.functional.softplus(x))
 
 
+class Conv1dBlock(nn.Module):
+    """
+    Conv1d --> GroupNorm --> Mish
+    """
+
+    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
+            RearrangeDim(),
+            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
+            nn.GroupNorm(n_groups, out_channels),
+            RearrangeDim(),
+            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
+            nn.Mish(),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class RearrangeDim(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, tensor):
+        if len(tensor.shape) == 2:
+            return tensor[:, :, None]
+        if len(tensor.shape) == 3:
+            return tensor[:, :, None, :]
+        elif len(tensor.shape) == 4:
+            return tensor[:, :, 0, :]
+        else:
+            raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
+
+
+# unet_rl.py
+class ResidualTemporalBlock(nn.Module):
+    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
+        super().__init__()
+
+        self.blocks = nn.ModuleList(
+            [
+                Conv1dBlock(inp_channels, out_channels, kernel_size),
+                Conv1dBlock(out_channels, out_channels, kernel_size),
+            ]
+        )
+
+        self.time_mlp = nn.Sequential(
+            nn.Mish(),
+            nn.Linear(embed_dim, out_channels),
+            RearrangeDim(),
+            # Rearrange("batch t -> batch t 1"),
+        )
+
+        self.residual_conv = (
+            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+        )
+
+    def forward(self, x, t):
+        """
+        x : [ batch_size x inp_channels x horizon ]
+        t : [ batch_size x embed_dim ]
+        returns: out : [ batch_size x out_channels x horizon ]
+        """
+        out = self.blocks[0](x) + self.time_mlp(t)
+        out = self.blocks[1](out)
+        return out + self.residual_conv(x)
+
+
 def upsample_2d(x, kernel=None, factor=2, gain=1):
     r"""Upsample2D a batch of 2D images with the given filter.
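ResidualTemporalBlock is the workhorse of the temporal UNet: two Conv1dBlock stages, with a per-channel time embedding added after the first, plus a 1x1 residual projection when channel counts differ. RearrangeDim stands in for einops' Rearrange (see the commented-out calls) so einops is not a dependency; in the time MLP it unsqueezes [batch, out_channels] to [batch, out_channels, 1] so the embedding broadcasts across the horizon. A minimal shape check, again an editorial sketch rather than patch content:

    import torch
    from diffusers.models.resnet import ResidualTemporalBlock

    block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=8, horizon=16)
    x = torch.randn(4, 14, 16)  # [batch, inp_channels, horizon]
    t = torch.randn(4, 8)       # [batch, embed_dim]
    print(block(x, t).shape)    # torch.Size([4, 32, 16]) -- horizon preserved, channels projected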
diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py
new file mode 100644
index 000000000000..786a80a38a6e
--- /dev/null
+++ b/src/diffusers/models/unet_rl.py
@@ -0,0 +1,228 @@
+# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
+
+import torch
+import torch.nn as nn
+
+from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock, Upsample1D
+
+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding
+
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        return get_timestep_embedding(x, self.dim)
+
+
+class RearrangeDim(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, tensor):
+        if len(tensor.shape) == 2:
+            return tensor[:, :, None]
+        if len(tensor.shape) == 3:
+            return tensor[:, :, None, :]
+        elif len(tensor.shape) == 4:
+            return tensor[:, :, 0, :]
+        else:
+            raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
+
+
+class Conv1dBlock(nn.Module):
+    """
+    Conv1d --> GroupNorm --> Mish
+    """
+
+    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
+            RearrangeDim(),
+            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
+            nn.GroupNorm(n_groups, out_channels),
+            RearrangeDim(),
+            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
+            nn.Mish(),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
+    def __init__(
+        self,
+        training_horizon=128,
+        transition_dim=14,
+        cond_dim=3,
+        predict_epsilon=False,
+        clip_denoised=True,
+        dim=32,
+        dim_mults=(1, 4, 8),
+    ):
+        super().__init__()
+
+        self.transition_dim = transition_dim
+        self.cond_dim = cond_dim
+        self.predict_epsilon = predict_epsilon
+        self.clip_denoised = clip_denoised
+
+        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
+        in_out = list(zip(dims[:-1], dims[1:]))
+
+        time_dim = dim
+        self.time_mlp = nn.Sequential(
+            SinusoidalPosEmb(dim),
+            nn.Linear(dim, dim * 4),
+            nn.Mish(),
+            nn.Linear(dim * 4, dim),
+        )
+
+        self.downs = nn.ModuleList([])
+        self.ups = nn.ModuleList([])
+        num_resolutions = len(in_out)
+
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.downs.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
+                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
+                        Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+            if not is_last:
+                training_horizon = training_horizon // 2
+
+        mid_dim = dims[-1]
+        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
+        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
+
+        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+            is_last = ind >= (num_resolutions - 1)
+
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
+                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
+                        Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
+                    ]
+                )
+            )
+
+            if not is_last:
+                training_horizon = training_horizon * 2
+
+        self.final_conv = nn.Sequential(
+            Conv1dBlock(dim, dim, kernel_size=5),
+            nn.Conv1d(dim, transition_dim, 1),
+        )
+
+    def forward(self, sample, timesteps):
+        """
+        x : [ batch x horizon x transition ]
+        """
+        x = sample
+
+        x = x.permute(0, 2, 1)
+
+        t = self.time_mlp(timesteps)
+        h = []
+
+        for resnet, resnet2, downsample in self.downs:
+            x = resnet(x, t)
+            x = resnet2(x, t)
+            h.append(x)
+            x = downsample(x)
+
+        x = self.mid_block1(x, t)
+        x = self.mid_block2(x, t)
+
+        for resnet, resnet2, upsample in self.ups:
+            x = torch.cat((x, h.pop()), dim=1)
+            x = resnet(x, t)
+            x = resnet2(x, t)
+            x = upsample(x)
+
+        x = self.final_conv(x)
+
+        x = x.permute(0, 2, 1)
+        return x
+
+
+class TemporalValue(nn.Module):
+    def __init__(
+        self,
+        horizon,
+        transition_dim,
+        cond_dim,
+        dim=32,
+        time_dim=None,
+        out_dim=1,
+        dim_mults=(1, 2, 4, 8),
+    ):
+        super().__init__()
+
+        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
+        in_out = list(zip(dims[:-1], dims[1:]))
+
+        time_dim = time_dim or dim
+        self.time_mlp = nn.Sequential(
+            SinusoidalPosEmb(dim),
+            nn.Linear(dim, dim * 4),
+            nn.Mish(),
+            nn.Linear(dim * 4, dim),
+        )
+
+        self.blocks = nn.ModuleList([])
+
+        print(in_out)
+        for dim_in, dim_out in in_out:
+            self.blocks.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
+                        Downsample1d(dim_out),
+                    ]
+                )
+            )
+
+            horizon = horizon // 2
+
+        fc_dim = dims[-1] * max(horizon, 1)
+
+        self.final_block = nn.Sequential(
+            nn.Linear(fc_dim + time_dim, fc_dim // 2),
+            nn.Mish(),
+            nn.Linear(fc_dim // 2, out_dim),
+        )
+
+    def forward(self, x, cond, time, *args):
+        """
+        x : [ batch x horizon x transition ]
+        """
+        x = x.permute(0, 2, 1)
+
+        t = self.time_mlp(time)
+
+        for resnet, resnet2, downsample in self.blocks:
+            x = resnet(x, t)
+            x = resnet2(x, t)
+            x = downsample(x)
+
+        x = x.view(len(x), -1)
+        out = self.final_block(torch.cat([x, t], dim=-1))
+        return out
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
new file mode 100755
index 000000000000..bbeb76e44503
--- /dev/null
+++ b/tests/test_modeling_utils.py
@@ -0,0 +1,915 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + +import inspect +import math +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import UNetConditionalModel # noqa: F401 TODO(Patrick) - need to write tests with it +from diffusers import ( + AutoencoderKL, + DDIMPipeline, + DDIMScheduler, + DDPMPipeline, + DDPMScheduler, + LatentDiffusionPipeline, + LatentDiffusionUncondPipeline, + PNDMPipeline, + PNDMScheduler, + ScoreSdeVePipeline, + ScoreSdeVeScheduler, + TemporalUNet, + UNetUnconditionalModel, + VQModel, +) +from diffusers.configuration_utils import ConfigMixin +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.testing_utils import floats_tensor, slow, torch_device +from diffusers.training_utils import EMAModel + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class ConfigTester(unittest.TestCase): + def test_load_not_from_mixin(self): + with self.assertRaises(ValueError): + ConfigMixin.from_config("dummy_path") + + def test_save_load(self): + class SampleObject(ConfigMixin): + config_name = "config.json" + + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 3], + ): + self.register_to_config(a=a, b=b, c=c, d=d, e=e) + + obj = SampleObject() + config = obj.config + + assert config["a"] == 2 + assert config["b"] == 5 + assert config["c"] == (2, 5) + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + new_obj = SampleObject.from_config(tmpdirname) + new_config = new_obj.config + + # unfreeze configs + config = dict(config) + new_config = dict(new_config) + + assert config.pop("c") == (2, 5) # instantiated as tuple + assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json + assert config == new_config + + +class ModelTesterMixin: + def test_from_pretrained_save_pretrained(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + with torch.no_grad(): + image = model(**inputs_dict) + if isinstance(image, dict): + image = image["sample"] + + new_image = new_model(**inputs_dict) + + if isinstance(new_image, dict): + new_image = new_image["sample"] + + max_diff = (image - new_image).abs().sum().item() + self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") + + def test_determinism(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + with torch.no_grad(): + first = model(**inputs_dict) + if isinstance(first, dict): + first = first["sample"] + + second = model(**inputs_dict) + if isinstance(second, dict): + second = second["sample"] + + out_1 = first.cpu().numpy() + out_2 = second.cpu().numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output["sample"] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape 
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_forward_signature(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["sample", "timestep"] + self.assertListEqual(arg_names[:2], expected_arg_names) + + def test_model_from_config(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # test if the model can be loaded from the config + # and has all the expected shape + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_config(tmpdirname) + new_model = self.model_class.from_config(tmpdirname) + new_model.to(torch_device) + new_model.eval() + + # check if all paramters shape are the same + for param_name in model.state_dict().keys(): + param_1 = model.state_dict()[param_name] + param_2 = new_model.state_dict()[param_name] + self.assertEqual(param_1.shape, param_2.shape) + + with torch.no_grad(): + output_1 = model(**inputs_dict) + + if isinstance(output_1, dict): + output_1 = output_1["sample"] + + output_2 = new_model(**inputs_dict) + + if isinstance(output_2, dict): + output_2 = output_2["sample"] + + self.assertEqual(output_1.shape, output_2.shape) + + def test_training(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output["sample"] + + noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + + def test_ema_training(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + ema_model = EMAModel(model, device=torch_device) + + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output["sample"] + + noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + ema_model.step(model) + + +class UnetModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNetUnconditionalModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_channels": (32, 64), + "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"), + "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"), + "num_head_channels": None, + "out_channels": 3, + "in_channels": 3, + "num_res_blocks": 2, + "image_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + +# TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints +# def test_output_pretrained(self): +# model = 
UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet") +# model.eval() +# +# torch.manual_seed(0) +# if torch.cuda.is_available(): +# torch.cuda.manual_seed_all(0) +# +# noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) +# time_step = torch.tensor([10]) +# +# with torch.no_grad(): +# output = model(noise, time_step)["sample"] +# +# output_slice = output[0, -1, -3:, -3:].flatten() +# fmt: off +# expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) +# fmt: on +# self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + +class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNetUnconditionalModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "image_size": 32, + "in_channels": 4, + "out_channels": 4, + "num_res_blocks": 2, + "block_channels": (32, 64), + "num_head_channels": 32, + "conv_resample": True, + "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"), + "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"), + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = UNetUnconditionalModel.from_pretrained( + "fusing/unet-ldm-dummy-update", output_loading_info=True + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input)["sample"] + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy-update") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) + time_step = torch.tensor([10] * noise.shape[0]) + + with torch.no_grad(): + output = model(noise, time_step)["sample"] + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + +# TODO(Patrick) - Re-add this test after having cleaned up LDM +# def test_output_pretrained_spatial_transformer(self): +# model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial") +# model.eval() +# +# torch.manual_seed(0) +# if torch.cuda.is_available(): +# torch.cuda.manual_seed_all(0) +# +# noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) +# context = torch.ones((1, 16, 64), dtype=torch.float32) +# time_step = torch.tensor([10] * noise.shape[0]) +# +# with torch.no_grad(): +# output = model(noise, time_step, context=context) +# +# output_slice = output[0, -1, -3:, -3:].flatten() +# fmt: off +# expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890]) +# fmt: on +# +# 
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +# + + +class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNetUnconditionalModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor(batch_size * [10]).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_channels": [32, 64, 64, 64], + "in_channels": 3, + "num_res_blocks": 1, + "out_channels": 3, + "time_embedding_type": "fourier", + "resnet_eps": 1e-6, + "mid_block_scale_factor": math.sqrt(2.0), + "resnet_num_groups": None, + "down_blocks": [ + "UNetResSkipDownBlock2D", + "UNetResAttnSkipDownBlock2D", + "UNetResSkipDownBlock2D", + "UNetResSkipDownBlock2D", + ], + "up_blocks": [ + "UNetResSkipUpBlock2D", + "UNetResSkipUpBlock2D", + "UNetResAttnSkipUpBlock2D", + "UNetResSkipUpBlock2D", + ], + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = UNetUnconditionalModel.from_pretrained( + "fusing/ncsnpp-ffhq-ve-dummy-update", output_loading_info=True + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained_ve_mid(self): + model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-celebahq-256") + model.to(torch_device) + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + batch_size = 4 + num_channels = 3 + sizes = (256, 256) + + noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) + + with torch.no_grad(): + output = model(noise, time_step)["sample"] + + output_slice = output[0, -3:, -3:, -1].flatten().cpu() + # fmt: off + expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + def test_output_pretrained_ve_large(self): + model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update") + model.to(torch_device) + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) + + with torch.no_grad(): + output = model(noise, time_step)["sample"] + + output_slice = output[0, -3:, -3:, -1].flatten().cpu() + # fmt: off + expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + +class VQModelTests(ModelTesterMixin, unittest.TestCase): + model_class = VQModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"sample": image} + + @property 
+ def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "out_ch": 3, + "num_res_blocks": 1, + "in_channels": 3, + "attn_resolutions": [], + "resolution": 32, + "z_channels": 3, + "n_embed": 256, + "embed_dim": 3, + "sane_index_shape": False, + "ch_mult": (1,), + "dropout": 0.0, + "double_z": False, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = VQModel.from_pretrained("fusing/vqgan-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + +class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): + model_class = TemporalUNet + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, seq_len, num_features)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timesteps": time_step} + + @property + def input_shape(self): + return (4, 16, 14) + + @property + def output_shape(self): + return (4, 16, 14) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "training_horizon": 128, + "dim": 32, + "dim_mults": [1, 4, 8], + "predict_epsilon": False, + "clip_denoised": True, + "transition_dim": 14, + "cond_dim": 3, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = TemporalUNet.from_pretrained( + "fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = model.transition_dim + seq_len = 16 + noise = torch.randn((1, seq_len, num_features)) + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = model(noise, time_step) + + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + + +class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): + model_class = AutoencoderKL + + 
@property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "ch_mult": (1,), + "embed_dim": 4, + "in_channels": 3, + "attn_resolutions": [], + "num_res_blocks": 1, + "out_ch": 3, + "resolution": 32, + "z_channels": 4, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image, sample_posterior=True) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + +class PipelineTesterMixin(unittest.TestCase): + def test_from_pretrained_save_pretrained(self): + # 1. 
Load models + model = UNetUnconditionalModel( + block_channels=(32, 64), + num_res_blocks=2, + image_size=32, + in_channels=3, + out_channels=3, + down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"), + up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"), + ) + schedular = DDPMScheduler(num_train_timesteps=10) + + ddpm = DDPMPipeline(model, schedular) + + with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) + + generator = torch.manual_seed(0) + + image = ddpm(generator=generator)["sample"] + generator = generator.manual_seed(0) + new_image = new_ddpm(generator=generator)["sample"] + + assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" + + @slow + def test_from_pretrained_hub(self): + model_path = "google/ddpm-cifar10-32" + + ddpm = DDPMPipeline.from_pretrained(model_path) + ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) + + ddpm.scheduler.num_timesteps = 10 + ddpm_from_hub.scheduler.num_timesteps = 10 + + generator = torch.manual_seed(0) + + image = ddpm(generator=generator)["sample"] + generator = generator.manual_seed(0) + new_image = ddpm_from_hub(generator=generator)["sample"] + + assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" + + @slow + def test_ddpm_cifar10(self): + model_id = "google/ddpm-cifar10-32" + + unet = UNetUnconditionalModel.from_pretrained(model_id) + scheduler = DDPMScheduler.from_config(model_id) + scheduler = scheduler.set_format("pt") + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + + generator = torch.manual_seed(0) + image = ddpm(generator=generator)["sample"] + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 32, 32) + expected_slice = torch.tensor( + [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231] + ) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_ddim_lsun(self): + model_id = "google/ddpm-ema-bedroom-256" + + unet = UNetUnconditionalModel.from_pretrained(model_id) + scheduler = DDIMScheduler.from_config(model_id) + + ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) + + generator = torch.manual_seed(0) + image = ddpm(generator=generator)["sample"] + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 256, 256) + expected_slice = torch.tensor( + [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863] + ) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_ddim_cifar10(self): + model_id = "google/ddpm-cifar10-32" + + unet = UNetUnconditionalModel.from_pretrained(model_id) + scheduler = DDIMScheduler(tensor_format="pt") + + ddim = DDIMPipeline(unet=unet, scheduler=scheduler) + + generator = torch.manual_seed(0) + image = ddim(generator=generator, eta=0.0)["sample"] + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 32, 32) + expected_slice = torch.tensor( + [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094] + ) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_pndm_cifar10(self): + model_id = "google/ddpm-cifar10-32" + + unet = UNetUnconditionalModel.from_pretrained(model_id) + scheduler = PNDMScheduler(tensor_format="pt") + + pndm = PNDMPipeline(unet=unet, scheduler=scheduler) + generator = torch.manual_seed(0) + image = pndm(generator=generator)["sample"] + + image_slice = 
image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 32, 32) + expected_slice = torch.tensor( + [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500] + ) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_ldm_text2img(self): + ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + image = ldm([prompt], generator=generator, num_inference_steps=20) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 256, 256) + expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_ldm_text2img_fast(self): + ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + image = ldm([prompt], generator=generator, num_inference_steps=1) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 256, 256) + expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_score_sde_ve_pipeline(self): + model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024") + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024") + + sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) + + torch.manual_seed(0) + image = sde_ve(num_inference_steps=2) + + if model.device.type == "cpu": + # patrick's cpu + expected_image_sum = 3384805888.0 + expected_image_mean = 1076.00085 + + # m1 mbp + # expected_image_sum = 3384805376.0 + # expected_image_mean = 1076.000610351562 + else: + expected_image_sum = 3382849024.0 + expected_image_mean = 1075.3788 + + assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 + assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + + @slow + def test_ldm_uncond(self): + ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256") + + generator = torch.manual_seed(0) + image = ldm(generator=generator, num_inference_steps=5)["sample"] + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 256, 256) + expected_slice = torch.tensor( + [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106] + ) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 From 84e94d7229ffefadef24219201752632a5e8d2bf Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 19 Jul 2022 15:57:11 -0700 Subject: [PATCH 02/63] match model forward api --- src/diffusers/models/unet_rl.py | 4 ++-- tests/test_modeling_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 786a80a38a6e..ebf6209e9382 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -130,7 +130,7 @@ def __init__( nn.Conv1d(dim, transition_dim, 1), ) - def forward(self, sample, timesteps): + def forward(self, sample, timestep): """ x : [ batch x horizon x transition ] """ @@ -138,7 +138,7 @@ def forward(self, sample, timesteps): x = x.permute(0, 2, 1) - t = self.time_mlp(timesteps) + t = 
self.time_mlp(timestep)
         h = []
 
         for resnet, resnet2, downsample in self.downs:
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index bbeb76e44503..06dd43c97aae 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -585,7 +585,7 @@ def dummy_input(self):
         noise = floats_tensor((batch_size, seq_len, num_features)).to(torch_device)
         time_step = torch.tensor([10] * batch_size).to(torch_device)
 
-        return {"sample": noise, "timesteps": time_step}
+        return {"sample": noise, "timestep": time_step}
 
     @property
     def input_shape(self):

From f67b036e862b34741282af0d9477c04326ea9cfb Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Mon, 25 Jul 2022 18:33:17 -0700
Subject: [PATCH 03/63] add register_to_config, pass training tests

---
 src/diffusers/models/unet_rl.py | 3 ++-
 tests/test_modeling_utils.py    | 6 ++++++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py
index ebf6209e9382..ea47cc58934e 100644
--- a/src/diffusers/models/unet_rl.py
+++ b/src/diffusers/models/unet_rl.py
@@ -5,7 +5,7 @@
 import torch.nn as nn
 
 from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock, Upsample1D
 
-from ..configuration_utils import ConfigMixin
+from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
 
@@ -57,6 +57,7 @@ def forward(self, x):
 
 
 class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
+    @register_to_config
     def __init__(
         self,
         training_horizon=128,
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 06dd43c97aae..f137d40fc38a 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -595,6 +595,12 @@ def input_shape(self):
     def output_shape(self):
         return (4, 16, 14)
 
+    def test_ema_training(self):
+        pass
+
+    def test_training(self):
+        pass
+
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
             "training_horizon": 128,

From e42d1c05afc41a8938cf60796b16dab53c81aea7 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Mon, 3 Oct 2022 13:49:10 -0400
Subject: [PATCH 04/63] fix tests, update forward outputs

---
 src/diffusers/__init__.py       |   2 +-
 src/diffusers/models/unet_rl.py |  73 ++-
 tests/test_modeling_utils.py    | 921 --------------------------------
 tests/test_models_unet.py       |  83 ++-
 4 files changed, 136 insertions(+), 943 deletions(-)
 delete mode 100755 tests/test_modeling_utils.py

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1cf64a4a2ebf..e6b920a31b4c 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -18,7 +18,7 @@
 
 if is_torch_available():
     from .modeling_utils import ModelMixin
-    from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+    from .models import AutoencoderKL, TemporalUNet, UNet2DConditionModel, UNet2DModel, VQModel
     from .optimization import (
         get_constant_schedule,
         get_constant_schedule_with_warmup,
diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py
index ea47cc58934e..be668c9c02a3 100644
--- a/src/diffusers/models/unet_rl.py
+++ b/src/diffusers/models/unet_rl.py
@@ -1,4 +1,6 @@
 # model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
+from dataclasses import dataclass
+from typing import Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -7,9 +9,21 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .embeddings import get_timestep_embedding
 
 
+@dataclass
+class TemporalUNetOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, horizon, transition_dim)`):
+            Hidden states output. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor
+
+
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -131,36 +145,55 @@ def __init__(
             nn.Conv1d(dim, transition_dim, 1),
         )
 
-    def forward(self, sample, timestep):
-        """
-        x : [ batch x horizon x transition ]
+    # def forward(self, sample, timestep):
+    #    """
+    #        x : [ batch x horizon x transition ] #"""
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        return_dict: bool = True,
+    ) -> Union[TemporalUNetOutput, Tuple]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): TODO verify shape (batch, horizon, transition_dim) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): TODO verify batch (batch) timesteps
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_rl.TemporalUNetOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.unet_rl.TemporalUNetOutput`] or `tuple`: [`~models.unet_rl.TemporalUNetOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+            tensor.
         """
-        x = sample
-
-        x = x.permute(0, 2, 1)
+        sample = sample.permute(0, 2, 1)
 
         t = self.time_mlp(timestep)
         h = []
 
         for resnet, resnet2, downsample in self.downs:
-            x = resnet(x, t)
-            x = resnet2(x, t)
-            h.append(x)
-            x = downsample(x)
+            sample = resnet(sample, t)
+            sample = resnet2(sample, t)
+            h.append(sample)
+            sample = downsample(sample)
 
-        x = self.mid_block1(x, t)
-        x = self.mid_block2(x, t)
+        sample = self.mid_block1(sample, t)
+        sample = self.mid_block2(sample, t)
 
         for resnet, resnet2, upsample in self.ups:
-            x = torch.cat((x, h.pop()), dim=1)
-            x = resnet(x, t)
-            x = resnet2(x, t)
-            x = upsample(x)
+            sample = torch.cat((sample, h.pop()), dim=1)
+            sample = resnet(sample, t)
+            sample = resnet2(sample, t)
+            sample = upsample(sample)
 
-        x = self.final_conv(x)
+        sample = self.final_conv(sample)
 
-        x = x.permute(0, 2, 1)
-        return x
+        sample = sample.permute(0, 2, 1)
+
+        if not return_dict:
+            return (sample,)
+
+        return TemporalUNetOutput(sample=sample)
 
 
 class TemporalValue(nn.Module):
@@ -196,7 +229,7 @@ def __init__(
                 [
                     ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                     ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
-                    Downsample1d(dim_out),
+                    Downsample1D(dim_out),
                 ]
             )
         )
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
deleted file mode 100755
index f137d40fc38a..000000000000
--- a/tests/test_modeling_utils.py
+++ /dev/null
@@ -1,921 +0,0 @@
-# coding=utf-8
-# Copyright 2022 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- - -import inspect -import math -import tempfile -import unittest - -import numpy as np -import torch - -from diffusers import UNetConditionalModel # noqa: F401 TODO(Patrick) - need to write tests with it -from diffusers import ( - AutoencoderKL, - DDIMPipeline, - DDIMScheduler, - DDPMPipeline, - DDPMScheduler, - LatentDiffusionPipeline, - LatentDiffusionUncondPipeline, - PNDMPipeline, - PNDMScheduler, - ScoreSdeVePipeline, - ScoreSdeVeScheduler, - TemporalUNet, - UNetUnconditionalModel, - VQModel, -) -from diffusers.configuration_utils import ConfigMixin -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.testing_utils import floats_tensor, slow, torch_device -from diffusers.training_utils import EMAModel - - -torch.backends.cuda.matmul.allow_tf32 = False - - -class ConfigTester(unittest.TestCase): - def test_load_not_from_mixin(self): - with self.assertRaises(ValueError): - ConfigMixin.from_config("dummy_path") - - def test_save_load(self): - class SampleObject(ConfigMixin): - config_name = "config.json" - - def __init__( - self, - a=2, - b=5, - c=(2, 5), - d="for diffusion", - e=[1, 3], - ): - self.register_to_config(a=a, b=b, c=c, d=d, e=e) - - obj = SampleObject() - config = obj.config - - assert config["a"] == 2 - assert config["b"] == 5 - assert config["c"] == (2, 5) - assert config["d"] == "for diffusion" - assert config["e"] == [1, 3] - - with tempfile.TemporaryDirectory() as tmpdirname: - obj.save_config(tmpdirname) - new_obj = SampleObject.from_config(tmpdirname) - new_config = new_obj.config - - # unfreeze configs - config = dict(config) - new_config = dict(new_config) - - assert config.pop("c") == (2, 5) # instantiated as tuple - assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json - assert config == new_config - - -class ModelTesterMixin: - def test_from_pretrained_save_pretrained(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) - - with torch.no_grad(): - image = model(**inputs_dict) - if isinstance(image, dict): - image = image["sample"] - - new_image = new_model(**inputs_dict) - - if isinstance(new_image, dict): - new_image = new_image["sample"] - - max_diff = (image - new_image).abs().sum().item() - self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") - - def test_determinism(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - with torch.no_grad(): - first = model(**inputs_dict) - if isinstance(first, dict): - first = first["sample"] - - second = model(**inputs_dict) - if isinstance(second, dict): - second = second["sample"] - - out_1 = first.cpu().numpy() - out_2 = second.cpu().numpy() - out_1 = out_1[~np.isnan(out_1)] - out_2 = out_2[~np.isnan(out_2)] - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) - - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output["sample"] - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape 
- self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_forward_signature(self): - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - signature = inspect.signature(model.forward) - # signature.parameters is an OrderedDict => so arg_names order is deterministic - arg_names = [*signature.parameters.keys()] - - expected_arg_names = ["sample", "timestep"] - self.assertListEqual(arg_names[:2], expected_arg_names) - - def test_model_from_config(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - # test if the model can be loaded from the config - # and has all the expected shape - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_config(tmpdirname) - new_model = self.model_class.from_config(tmpdirname) - new_model.to(torch_device) - new_model.eval() - - # check if all paramters shape are the same - for param_name in model.state_dict().keys(): - param_1 = model.state_dict()[param_name] - param_2 = new_model.state_dict()[param_name] - self.assertEqual(param_1.shape, param_2.shape) - - with torch.no_grad(): - output_1 = model(**inputs_dict) - - if isinstance(output_1, dict): - output_1 = output_1["sample"] - - output_2 = new_model(**inputs_dict) - - if isinstance(output_2, dict): - output_2 = output_2["sample"] - - self.assertEqual(output_1.shape, output_2.shape) - - def test_training(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.train() - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output["sample"] - - noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) - loss = torch.nn.functional.mse_loss(output, noise) - loss.backward() - - def test_ema_training(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.train() - ema_model = EMAModel(model, device=torch_device) - - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output["sample"] - - noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) - loss = torch.nn.functional.mse_loss(output, noise) - loss.backward() - ema_model.step(model) - - -class UnetModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNetUnconditionalModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - - return {"sample": noise, "timestep": time_step} - - @property - def input_shape(self): - return (3, 32, 32) - - @property - def output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_channels": (32, 64), - "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"), - "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"), - "num_head_channels": None, - "out_channels": 3, - "in_channels": 3, - "num_res_blocks": 2, - "image_size": 32, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - -# TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints -# def test_output_pretrained(self): -# model = 
UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet") -# model.eval() -# -# torch.manual_seed(0) -# if torch.cuda.is_available(): -# torch.cuda.manual_seed_all(0) -# -# noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) -# time_step = torch.tensor([10]) -# -# with torch.no_grad(): -# output = model(noise, time_step)["sample"] -# -# output_slice = output[0, -1, -3:, -3:].flatten() -# fmt: off -# expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) -# fmt: on -# self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - -class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNetUnconditionalModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - - return {"sample": noise, "timestep": time_step} - - @property - def input_shape(self): - return (4, 32, 32) - - @property - def output_shape(self): - return (4, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "image_size": 32, - "in_channels": 4, - "out_channels": 4, - "num_res_blocks": 2, - "block_channels": (32, 64), - "num_head_channels": 32, - "conv_resample": True, - "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"), - "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"), - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNetUnconditionalModel.from_pretrained( - "fusing/unet-ldm-dummy-update", output_loading_info=True - ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input)["sample"] - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy-update") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) - time_step = torch.tensor([10] * noise.shape[0]) - - with torch.no_grad(): - output = model(noise, time_step)["sample"] - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -# TODO(Patrick) - Re-add this test after having cleaned up LDM -# def test_output_pretrained_spatial_transformer(self): -# model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial") -# model.eval() -# -# torch.manual_seed(0) -# if torch.cuda.is_available(): -# torch.cuda.manual_seed_all(0) -# -# noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size) -# context = torch.ones((1, 16, 64), dtype=torch.float32) -# time_step = torch.tensor([10] * noise.shape[0]) -# -# with torch.no_grad(): -# output = model(noise, time_step, context=context) -# -# output_slice = output[0, -1, -3:, -3:].flatten() -# fmt: off -# expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890]) -# fmt: on -# -# 
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) -# - - -class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNetUnconditionalModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [10]).to(torch_device) - - return {"sample": noise, "timestep": time_step} - - @property - def input_shape(self): - return (3, 32, 32) - - @property - def output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_channels": [32, 64, 64, 64], - "in_channels": 3, - "num_res_blocks": 1, - "out_channels": 3, - "time_embedding_type": "fourier", - "resnet_eps": 1e-6, - "mid_block_scale_factor": math.sqrt(2.0), - "resnet_num_groups": None, - "down_blocks": [ - "UNetResSkipDownBlock2D", - "UNetResAttnSkipDownBlock2D", - "UNetResSkipDownBlock2D", - "UNetResSkipDownBlock2D", - ], - "up_blocks": [ - "UNetResSkipUpBlock2D", - "UNetResSkipUpBlock2D", - "UNetResAttnSkipUpBlock2D", - "UNetResSkipUpBlock2D", - ], - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNetUnconditionalModel.from_pretrained( - "fusing/ncsnpp-ffhq-ve-dummy-update", output_loading_info=True - ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained_ve_mid(self): - model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-celebahq-256") - model.to(torch_device) - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - batch_size = 4 - num_channels = 3 - sizes = (256, 256) - - noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) - - with torch.no_grad(): - output = model(noise, time_step)["sample"] - - output_slice = output[0, -3:, -3:, -1].flatten().cpu() - # fmt: off - expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - def test_output_pretrained_ve_large(self): - model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update") - model.to(torch_device) - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) - - with torch.no_grad(): - output = model(noise, time_step)["sample"] - - output_slice = output[0, -3:, -3:, -1].flatten().cpu() - # fmt: off - expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - -class VQModelTests(ModelTesterMixin, unittest.TestCase): - model_class = VQModel - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - - return {"sample": image} - - @property 
- def input_shape(self): - return (3, 32, 32) - - @property - def output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "ch": 64, - "out_ch": 3, - "num_res_blocks": 1, - "in_channels": 3, - "attn_resolutions": [], - "resolution": 32, - "z_channels": 3, - "n_embed": 256, - "embed_dim": 3, - "sane_index_shape": False, - "ch_mult": (1,), - "dropout": 0.0, - "double_z": False, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_forward_signature(self): - pass - - def test_training(self): - pass - - def test_from_pretrained_hub(self): - model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = VQModel.from_pretrained("fusing/vqgan-dummy") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) - with torch.no_grad(): - output = model(image) - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - -class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): - model_class = TemporalUNet - - @property - def dummy_input(self): - batch_size = 4 - num_features = 14 - seq_len = 16 - - noise = floats_tensor((batch_size, seq_len, num_features)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - - return {"sample": noise, "timestep": time_step} - - @property - def input_shape(self): - return (4, 16, 14) - - @property - def output_shape(self): - return (4, 16, 14) - - def test_ema_training(self): - pass - - def test_training(self): - pass - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "training_horizon": 128, - "dim": 32, - "dim_mults": [1, 4, 8], - "predict_epsilon": False, - "clip_denoised": True, - "transition_dim": 14, - "cond_dim": 3, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = TemporalUNet.from_pretrained( - "fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True - ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - num_features = model.transition_dim - seq_len = 16 - noise = torch.randn((1, seq_len, num_features)) - time_step = torch.full((num_features,), 0) - - with torch.no_grad(): - output = model(noise, time_step) - - output_slice = output[0, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) - - -class 
AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): - model_class = AutoencoderKL - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - - return {"sample": image} - - @property - def input_shape(self): - return (3, 32, 32) - - @property - def output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "ch": 64, - "ch_mult": (1,), - "embed_dim": 4, - "in_channels": 3, - "attn_resolutions": [], - "num_res_blocks": 1, - "out_ch": 3, - "resolution": 32, - "z_channels": 4, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_forward_signature(self): - pass - - def test_training(self): - pass - - def test_from_pretrained_hub(self): - model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) - with torch.no_grad(): - output = model(image, sample_posterior=True) - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - -class PipelineTesterMixin(unittest.TestCase): - def test_from_pretrained_save_pretrained(self): - # 1. 
Load models - model = UNetUnconditionalModel( - block_channels=(32, 64), - num_res_blocks=2, - image_size=32, - in_channels=3, - out_channels=3, - down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"), - up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"), - ) - schedular = DDPMScheduler(num_train_timesteps=10) - - ddpm = DDPMPipeline(model, schedular) - - with tempfile.TemporaryDirectory() as tmpdirname: - ddpm.save_pretrained(tmpdirname) - new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) - - generator = torch.manual_seed(0) - - image = ddpm(generator=generator)["sample"] - generator = generator.manual_seed(0) - new_image = new_ddpm(generator=generator)["sample"] - - assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" - - @slow - def test_from_pretrained_hub(self): - model_path = "google/ddpm-cifar10-32" - - ddpm = DDPMPipeline.from_pretrained(model_path) - ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) - - ddpm.scheduler.num_timesteps = 10 - ddpm_from_hub.scheduler.num_timesteps = 10 - - generator = torch.manual_seed(0) - - image = ddpm(generator=generator)["sample"] - generator = generator.manual_seed(0) - new_image = ddpm_from_hub(generator=generator)["sample"] - - assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" - - @slow - def test_ddpm_cifar10(self): - model_id = "google/ddpm-cifar10-32" - - unet = UNetUnconditionalModel.from_pretrained(model_id) - scheduler = DDPMScheduler.from_config(model_id) - scheduler = scheduler.set_format("pt") - - ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) - - generator = torch.manual_seed(0) - image = ddpm(generator=generator)["sample"] - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_ddim_lsun(self): - model_id = "google/ddpm-ema-bedroom-256" - - unet = UNetUnconditionalModel.from_pretrained(model_id) - scheduler = DDIMScheduler.from_config(model_id) - - ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) - - generator = torch.manual_seed(0) - image = ddpm(generator=generator)["sample"] - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 256, 256) - expected_slice = torch.tensor( - [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_ddim_cifar10(self): - model_id = "google/ddpm-cifar10-32" - - unet = UNetUnconditionalModel.from_pretrained(model_id) - scheduler = DDIMScheduler(tensor_format="pt") - - ddim = DDIMPipeline(unet=unet, scheduler=scheduler) - - generator = torch.manual_seed(0) - image = ddim(generator=generator, eta=0.0)["sample"] - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_pndm_cifar10(self): - model_id = "google/ddpm-cifar10-32" - - unet = UNetUnconditionalModel.from_pretrained(model_id) - scheduler = PNDMScheduler(tensor_format="pt") - - pndm = PNDMPipeline(unet=unet, scheduler=scheduler) - generator = torch.manual_seed(0) - image = pndm(generator=generator)["sample"] - - image_slice = 
image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 32, 32) - expected_slice = torch.tensor( - [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_ldm_text2img(self): - ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") - - prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) - image = ldm([prompt], generator=generator, num_inference_steps=20) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 256, 256) - expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_ldm_text2img_fast(self): - ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") - - prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) - image = ldm([prompt], generator=generator, num_inference_steps=1) - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 256, 256) - expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 - - @slow - def test_score_sde_ve_pipeline(self): - model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024") - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024") - - sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) - - torch.manual_seed(0) - image = sde_ve(num_inference_steps=2) - - if model.device.type == "cpu": - # patrick's cpu - expected_image_sum = 3384805888.0 - expected_image_mean = 1076.00085 - - # m1 mbp - # expected_image_sum = 3384805376.0 - # expected_image_mean = 1076.000610351562 - else: - expected_image_sum = 3382849024.0 - expected_image_mean = 1075.3788 - - assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 - assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 - - @slow - def test_ldm_uncond(self): - ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256") - - generator = torch.manual_seed(0) - image = ldm(generator=generator, num_inference_steps=5)["sample"] - - image_slice = image[0, -1, -3:, -3:].cpu() - - assert image.shape == (1, 3, 256, 256) - expected_slice = torch.tensor( - [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106] - ) - assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 94a186d1c06a..12f38ab4e557 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -18,7 +18,7 @@ import torch -from diffusers import UNet2DConditionModel, UNet2DModel +from diffusers import TemporalUNet, UNet2DConditionModel, UNet2DModel from diffusers.testing_utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin @@ -375,3 +375,84 @@ def test_output_pretrained_ve_large(self): def test_forward_with_norm_groups(self): # not required for this model pass + + +class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): + model_class = TemporalUNet + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = 
floats_tensor((batch_size, seq_len, num_features)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 16, 14) + + @property + def output_shape(self): + return (4, 16, 14) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "training_horizon": 128, + "dim": 32, + "dim_mults": [1, 4, 8], + "predict_epsilon": False, + "clip_denoised": True, + "transition_dim": 14, + "cond_dim": 3, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = TemporalUNet.from_pretrained( + "fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = model.transition_dim + seq_len = 16 + noise = torch.randn((1, seq_len, num_features)) + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = model(noise, time_step).sample + + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass From 2dd514ea536468664c8ca07e9e3505da691507e4 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 3 Oct 2022 14:16:57 -0400 Subject: [PATCH 05/63] remove unused code, some comments --- src/diffusers/models/unet_rl.py | 75 ++------------------------------- 1 file changed, 3 insertions(+), 72 deletions(-) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index be668c9c02a3..fd04fa7d9fb9 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -17,7 +17,7 @@ class TemporalUNetOutput(BaseOutput): """ Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + sample (`torch.FloatTensor` of shape `(batch, horizon, obs_dimension)`): Hidden states output. Output of last layer of model. 
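    Example (a minimal sketch; the constructor arguments simply mirror the
    dummy test config above and the call signature shown in `forward`):

        import torch
        from diffusers import TemporalUNet

        model = TemporalUNet(transition_dim=14, cond_dim=3, dim=32, dim_mults=[1, 4, 8])
        sample = torch.randn(4, 16, 14)   # (batch, horizon, obs_dimension)
        timestep = torch.tensor([10] * 4)
        output = model(sample, timestep)  # a TemporalUNetOutput when return_dict=True
        assert output.sample.shape == sample.shape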
""" @@ -59,10 +59,8 @@ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): self.block = nn.Sequential( nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), RearrangeDim(), - # Rearrange("batch channels horizon -> batch channels 1 horizon"), nn.GroupNorm(n_groups, out_channels), RearrangeDim(), - # Rearrange("batch channels 1 horizon -> batch channels horizon"), nn.Mish(), ) @@ -156,8 +154,8 @@ def forward( ) -> Union[TemporalUNetOutput, Tuple]: """r Args: - sample (`torch.FloatTensor`): TODO verify shape (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): TODO verify batch (batch) timesteps + sample (`torch.FloatTensor`): (batch, horizon, obs_dimension) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): batch (batch) timesteps return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. @@ -165,7 +163,6 @@ def forward( [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ - # x = sample sample = sample.permute(0, 2, 1) t = self.time_mlp(timestep) @@ -194,69 +191,3 @@ def forward( return (sample,) return TemporalUNetOutput(sample=sample) - - -class TemporalValue(nn.Module): - def __init__( - self, - horizon, - transition_dim, - cond_dim, - dim=32, - time_dim=None, - out_dim=1, - dim_mults=(1, 2, 4, 8), - ): - super().__init__() - - dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] - in_out = list(zip(dims[:-1], dims[1:])) - - time_dim = time_dim or dim - self.time_mlp = nn.Sequential( - SinusoidalPosEmb(dim), - nn.Linear(dim, dim * 4), - nn.Mish(), - nn.Linear(dim * 4, dim), - ) - - self.blocks = nn.ModuleList([]) - - print(in_out) - for dim_in, dim_out in in_out: - self.blocks.append( - nn.ModuleList( - [ - ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon), - ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon), - Downsample1D(dim_out), - ] - ) - ) - - horizon = horizon // 2 - - fc_dim = dims[-1] * max(horizon, 1) - - self.final_block = nn.Sequential( - nn.Linear(fc_dim + time_dim, fc_dim // 2), - nn.Mish(), - nn.Linear(fc_dim // 2, out_dim), - ) - - def forward(self, x, cond, time, *args): - """ - x : [ batch x horizon x transition ] - """ - x = x.permute(0, 2, 1) - - t = self.time_mlp(time) - - for resnet, resnet2, downsample in self.blocks: - x = resnet(x, t) - x = resnet2(x, t) - x = downsample(x) - - x = x.view(len(x), -1) - out = self.final_block(torch.cat([x, t], dim=-1)) - return out From b4c6188998773ca0563461521acbbb880427b50c Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 3 Oct 2022 14:17:54 -0400 Subject: [PATCH 06/63] add to docs --- docs/source/api/models.mdx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index c92fdccb8333..98687b5e7038 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -34,6 +34,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## DecoderOutput [[autodoc]] models.vae.DecoderOutput +## TemporalUNet +[[autodoc]] TemporalUNet + ## VQEncoderOutput [[autodoc]] models.vae.VQEncoderOutput From c53bba903626691c143ebb8e3a2a65ac65f5c129 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Thu, 6 Oct 2022 15:56:15 
-0700 Subject: [PATCH 07/63] remove extra embedding code --- src/diffusers/models/unet_rl.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index fd04fa7d9fb9..a8354cdb64ec 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -10,8 +10,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput -from .embeddings import get_timestep_embedding - +from .embeddings import get_timestep_embedding, Timesteps @dataclass class TemporalUNetOutput(BaseOutput): @@ -24,14 +23,6 @@ class TemporalUNetOutput(BaseOutput): sample: torch.FloatTensor -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - return get_timestep_embedding(x, self.dim) - class RearrangeDim(nn.Module): def __init__(self): @@ -92,7 +83,7 @@ def __init__( time_dim = dim self.time_mlp = nn.Sequential( - SinusoidalPosEmb(dim), + Timesteps(num_channels=dim, flip_sin_to_cos=False, downscale_freq_shift=1), nn.Linear(dim, dim * 4), nn.Mish(), nn.Linear(dim * 4, dim), From effcbdbe95182b7f414786e42db5c6e192e3c2f0 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Fri, 7 Oct 2022 16:47:56 -0700 Subject: [PATCH 08/63] unify time embedding --- src/diffusers/models/embeddings.py | 11 +++++-- src/diffusers/models/resnet.py | 27 +++++++++++----- src/diffusers/models/unet_rl.py | 49 ++++++++++++++++++------------ 3 files changed, 58 insertions(+), 29 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 06b814e2bbcd..7d2e1b677a9f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -62,14 +62,21 @@ def get_timestep_embedding( class TimestepEmbedding(nn.Module): - def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): super().__init__() self.linear_1 = nn.Linear(channel, time_embed_dim) self.act = None if act_fn == "silu": self.act = nn.SiLU() - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + if act_fn == "mish": + self.act = nn.Mish() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) def forward(self, sample): sample = self.linear_1(sample) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 9b52681c3b99..7fd0f1db3d36 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -474,6 +474,16 @@ def forward(self, tensor): else: raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") +def rearrange_dims(tensor): + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] + else: + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + # unet_rl.py class ResidualTemporalBlock(nn.Module): @@ -486,13 +496,14 @@ def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5 Conv1dBlock(out_channels, out_channels, kernel_size), ] ) + self.time_emb_act = nn.Mish() + self.time_emb = nn.Linear(embed_dim, out_channels) - self.time_mlp = nn.Sequential( - nn.Mish(), - nn.Linear(embed_dim, out_channels), - RearrangeDim(), - 
# Rearrange("batch t -> batch t 1"), - ) + # self.time_mlp = nn.Sequential( + # nn.Mish(), + # nn.Linear(embed_dim, out_channels), + # RearrangeDim(), + # ) self.residual_conv = ( nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() @@ -503,7 +514,9 @@ def forward(self, x, t): x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x out_channels x horizon ] """ - out = self.blocks[0](x) + self.time_mlp(t) + t = self.time_emb_act(t) + t = self.time_emb(t) + out = self.blocks[0](x) + rearrange_dims(t) out = self.blocks[1](out) return out + self.residual_conv(x) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index a8354cdb64ec..e9fcc4c8535b 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -10,7 +10,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput -from .embeddings import get_timestep_embedding, Timesteps +from .embeddings import get_timestep_embedding, Timesteps, TimestepEmbedding @dataclass class TemporalUNetOutput(BaseOutput): @@ -78,17 +78,12 @@ def __init__( self.predict_epsilon = predict_epsilon self.clip_denoised = clip_denoised + self.time_proj = Timesteps(num_channels=dim, flip_sin_to_cos=False, downscale_freq_shift=1) + self.time_mlp = TimestepEmbedding(channel=dim, time_embed_dim=4*dim, act_fn="mish", out_dim=dim) + dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - time_dim = dim - self.time_mlp = nn.Sequential( - Timesteps(num_channels=dim, flip_sin_to_cos=False, downscale_freq_shift=1), - nn.Linear(dim, dim * 4), - nn.Mish(), - nn.Linear(dim * 4, dim), - ) - self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) @@ -99,8 +94,8 @@ def __init__( self.downs.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon), - ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon), + ResidualTemporalBlock(dim_in, dim_out, embed_dim=dim, horizon=training_horizon), + ResidualTemporalBlock(dim_out, dim_out, embed_dim=dim, horizon=training_horizon), Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(), ] ) @@ -110,8 +105,8 @@ def __init__( training_horizon = training_horizon // 2 mid_dim = dims[-1] - self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon) - self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon) + self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim, horizon=training_horizon) + self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim, horizon=training_horizon) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) @@ -119,8 +114,8 @@ def __init__( self.ups.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon), - ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon), + ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=dim, horizon=training_horizon), + ResidualTemporalBlock(dim_in, dim_in, embed_dim=dim, horizon=training_horizon), Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(), ] ) @@ -134,9 +129,6 @@ def __init__( nn.Conv1d(dim, transition_dim, 1), ) - # 
def forward(self, sample, timestep): - # """ - # x : [ batch x horizon x transition ] #""" def forward( self, sample: torch.FloatTensor, @@ -145,7 +137,7 @@ def forward( ) -> Union[TemporalUNetOutput, Tuple]: """r Args: - sample (`torch.FloatTensor`): (batch, horizon, obs_dimension) noisy inputs tensor + sample (`torch.FloatTensor`): (batch, horizon, obs_dimension + action_dimension) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int): batch (batch) timesteps return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. @@ -156,24 +148,41 @@ def forward( """ sample = sample.permute(0, 2, 1) - t = self.time_mlp(timestep) + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # t = self.time_mlp(timesteps) + t = self.time_proj(timesteps) + t = self.time_mlp(t) + # t = self.time_embedding(timesteps) + # t = self.time_emb_lin1(t) + # t = self.time_emb_act(t) + # t = self.time_emb_lin2(t) h = [] + # 2. down for resnet, resnet2, downsample in self.downs: sample = resnet(sample, t) sample = resnet2(sample, t) h.append(sample) sample = downsample(sample) + # 3. mid sample = self.mid_block1(sample, t) sample = self.mid_block2(sample, t) + # 4. up for resnet, resnet2, upsample in self.ups: sample = torch.cat((sample, h.pop()), dim=1) sample = resnet(sample, t) sample = resnet2(sample, t) sample = upsample(sample) + # 5. post-process sample = self.final_conv(sample) sample = sample.permute(0, 2, 1) From 78652313cff7f59f02997182f3f282c563d129bf Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Fri, 7 Oct 2022 17:16:27 -0700 Subject: [PATCH 09/63] remove conv1d output sequential --- src/diffusers/models/resnet.py | 40 ++++++++++++++++----------------- src/diffusers/models/unet_rl.py | 23 ++++++++++--------- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 7fd0f1db3d36..728c662ffdf6 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -438,27 +438,6 @@ def forward(self, x): return x * torch.tanh(torch.nn.functional.softplus(x)) -class Conv1dBlock(nn.Module): - """ - Conv1d --> GroupNorm --> Mish - """ - - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): - super().__init__() - - self.block = nn.Sequential( - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), - RearrangeDim(), - # Rearrange("batch channels horizon -> batch channels 1 horizon"), - nn.GroupNorm(n_groups, out_channels), - RearrangeDim(), - # Rearrange("batch channels 1 horizon -> batch channels horizon"), - nn.Mish(), - ) - - def forward(self, x): - return self.block(x) - class RearrangeDim(nn.Module): def __init__(self): @@ -484,6 +463,25 @@ def rearrange_dims(tensor): else: raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + RearrangeDim(), + nn.GroupNorm(n_groups, out_channels), + RearrangeDim(), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + # 
unet_rl.py class ResidualTemporalBlock(nn.Module): diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index e9fcc4c8535b..cf6079af35df 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -11,6 +11,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import get_timestep_embedding, Timesteps, TimestepEmbedding +from .resnet import rearrange_dims @dataclass class TemporalUNetOutput(BaseOutput): @@ -59,7 +60,7 @@ def forward(self, x): return self.block(x) -class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): +class TemporalUNet(ModelMixin, ConfigMixin): @register_to_config def __init__( self, @@ -124,10 +125,10 @@ def __init__( if not is_last: training_horizon = training_horizon * 2 - self.final_conv = nn.Sequential( - Conv1dBlock(dim, dim, kernel_size=5), - nn.Conv1d(dim, transition_dim, 1), - ) + self.final_conv1d_1 = nn.Conv1d(dim, dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(8, dim) + self.final_conv1d_act = nn.Mish() + self.final_conv1d_2 = nn.Conv1d(dim, transition_dim, 1) def forward( self, @@ -155,13 +156,8 @@ def forward( elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) - # t = self.time_mlp(timesteps) t = self.time_proj(timesteps) t = self.time_mlp(t) - # t = self.time_embedding(timesteps) - # t = self.time_emb_lin1(t) - # t = self.time_emb_act(t) - # t = self.time_emb_lin2(t) h = [] # 2. down @@ -183,7 +179,12 @@ def forward( sample = upsample(sample) # 5. post-process - sample = self.final_conv(sample) + sample = self.final_conv1d_1(sample) + sample = rearrange_dims(sample) + sample = self.final_conv1d_gn(sample) + sample = rearrange_dims(sample) + sample = self.final_conv1d_act(sample) + sample = self.final_conv1d_2(sample) sample = sample.permute(0, 2, 1) From 35b0a43c6ff3a90bebefb7a85a0f3e28a09ff4ca Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Fri, 7 Oct 2022 17:43:13 -0700 Subject: [PATCH 10/63] remove sequential from conv1dblock --- src/diffusers/models/resnet.py | 26 +++++++++++--------------- src/diffusers/models/unet_rl.py | 17 ----------------- 2 files changed, 11 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 728c662ffdf6..ed54635fc398 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -471,17 +471,19 @@ class Conv1dBlock(nn.Module): def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): super().__init__() - self.block = nn.Sequential( - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), - RearrangeDim(), - nn.GroupNorm(n_groups, out_channels), - RearrangeDim(), - nn.Mish(), - ) - def forward(self, x): - return self.block(x) + self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.mish = nn.Mish() + + def forward(self, x): + x = self.conv1d(x) + x = rearrange_dims(x) + x = self.group_norm(x) + x = rearrange_dims(x) + x = self.mish(x) + return x # unet_rl.py class ResidualTemporalBlock(nn.Module): @@ -497,12 +499,6 @@ def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5 self.time_emb_act = nn.Mish() self.time_emb = nn.Linear(embed_dim, out_channels) - # self.time_mlp = nn.Sequential( - # nn.Mish(), - # nn.Linear(embed_dim, out_channels), - # RearrangeDim(), - # ) - self.residual_conv = ( nn.Conv1d(inp_channels, 
out_channels, 1) if inp_channels != out_channels else nn.Identity() ) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index cf6079af35df..85e459fe6e5d 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -40,24 +40,7 @@ def forward(self, tensor): raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") -class Conv1dBlock(nn.Module): - """ - Conv1d --> GroupNorm --> Mish - """ - - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): - super().__init__() - - self.block = nn.Sequential( - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), - RearrangeDim(), - nn.GroupNorm(n_groups, out_channels), - RearrangeDim(), - nn.Mish(), - ) - def forward(self, x): - return self.block(x) class TemporalUNet(ModelMixin, ConfigMixin): From 9b1379d40f91c5626aca91a27a62cfa7428282a5 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Fri, 7 Oct 2022 17:45:27 -0700 Subject: [PATCH 11/63] style and deleting duplicated code --- src/diffusers/models/resnet.py | 19 +++---------------- src/diffusers/models/unet_rl.py | 23 +++-------------------- 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ed54635fc398..831bb02eb566 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -439,20 +439,7 @@ def forward(self, x): -class RearrangeDim(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, tensor): - if len(tensor.shape) == 2: - return tensor[:, :, None] - if len(tensor.shape) == 3: - return tensor[:, :, None, :] - elif len(tensor.shape) == 4: - return tensor[:, :, 0, :] - else: - raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") - +# unet_rl.py def rearrange_dims(tensor): if len(tensor.shape) == 2: return tensor[:, :, None] @@ -463,6 +450,7 @@ def rearrange_dims(tensor): else: raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + class Conv1dBlock(nn.Module): """ Conv1d --> GroupNorm --> Mish @@ -471,12 +459,10 @@ class Conv1dBlock(nn.Module): def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): super().__init__() - self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) self.group_norm = nn.GroupNorm(n_groups, out_channels) self.mish = nn.Mish() - def forward(self, x): x = self.conv1d(x) x = rearrange_dims(x) @@ -485,6 +471,7 @@ def forward(self, x): x = self.mish(x) return x + # unet_rl.py class ResidualTemporalBlock(nn.Module): def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 85e459fe6e5d..69f4c4cd37ed 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -10,9 +10,10 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput -from .embeddings import get_timestep_embedding, Timesteps, TimestepEmbedding +from .embeddings import TimestepEmbedding, Timesteps from .resnet import rearrange_dims + @dataclass class TemporalUNetOutput(BaseOutput): """ @@ -25,24 +26,6 @@ class TemporalUNetOutput(BaseOutput): -class RearrangeDim(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, tensor): - if len(tensor.shape) == 2: - return tensor[:, :, None] - if len(tensor.shape) == 3: - return tensor[:, :, None, :] - elif 
len(tensor.shape) == 4: - return tensor[:, :, 0, :] - else: - raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") - - - - - class TemporalUNet(ModelMixin, ConfigMixin): @register_to_config def __init__( @@ -63,7 +46,7 @@ def __init__( self.clip_denoised = clip_denoised self.time_proj = Timesteps(num_channels=dim, flip_sin_to_cos=False, downscale_freq_shift=1) - self.time_mlp = TimestepEmbedding(channel=dim, time_embed_dim=4*dim, act_fn="mish", out_dim=dim) + self.time_mlp = TimestepEmbedding(channel=dim, time_embed_dim=4 * dim, act_fn="mish", out_dim=dim) dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) From e97a61066d4093a22c741aa2ca8ed7fd5300d8cc Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Sat, 8 Oct 2022 09:09:48 -0700 Subject: [PATCH 12/63] clean files --- src/diffusers/models/resnet.py | 11 +++++++---- src/diffusers/models/unet_rl.py | 31 ++++++++++++++++++++++++------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 831bb02eb566..c4649647dd41 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -438,7 +438,6 @@ def forward(self, x): return x * torch.tanh(torch.nn.functional.softplus(x)) - # unet_rl.py def rearrange_dims(tensor): if len(tensor.shape) == 2: @@ -474,7 +473,7 @@ def forward(self, x): # unet_rl.py class ResidualTemporalBlock(nn.Module): - def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): + def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): super().__init__() self.blocks = nn.ModuleList( @@ -492,8 +491,12 @@ def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5 def forward(self, x, t): """ - x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x - out_channels x horizon ] + Args: + x : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + + returns: + out : [ batch_size x out_channels x horizon ] """ t = self.time_emb_act(t) t = self.time_emb(t) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 69f4c4cd37ed..420a1661d526 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -25,8 +25,20 @@ class TemporalUNetOutput(BaseOutput): sample: torch.FloatTensor - class TemporalUNet(ModelMixin, ConfigMixin): + """ + A UNet for multi-dimensional temporal data. This model takes the batch over the `training_horizon`. + + Parameters: + training_horizon: horizon of training samples used for diffusion process. + transition_dim: state-dimension of samples to predict over + cond_dim: held dimension in input (e.g. 
for actions) -- TODO remove from pretrained + predict_epsilon: TODO remove from pretrained + clip_denoised: TODO remove from pretrained + dim: embedding dimension of model + dim_mults: dimension multiples of the up/down blocks + """ + @register_to_config def __init__( self, @@ -45,6 +57,7 @@ def __init__( self.predict_epsilon = predict_epsilon self.clip_denoised = clip_denoised + # time self.time_proj = Timesteps(num_channels=dim, flip_sin_to_cos=False, downscale_freq_shift=1) self.time_mlp = TimestepEmbedding(channel=dim, time_embed_dim=4 * dim, act_fn="mish", out_dim=dim) @@ -55,14 +68,15 @@ def __init__( self.ups = nn.ModuleList([]) num_resolutions = len(in_out) + # down for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) self.downs.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_in, dim_out, embed_dim=dim, horizon=training_horizon), - ResidualTemporalBlock(dim_out, dim_out, embed_dim=dim, horizon=training_horizon), + ResidualTemporalBlock(dim_in, dim_out, embed_dim=dim), + ResidualTemporalBlock(dim_out, dim_out, embed_dim=dim), Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(), ] ) @@ -71,18 +85,20 @@ def __init__( if not is_last: training_horizon = training_horizon // 2 + # mid mid_dim = dims[-1] - self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim, horizon=training_horizon) - self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim, horizon=training_horizon) + self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim) + self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim) + # up for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) self.ups.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=dim, horizon=training_horizon), - ResidualTemporalBlock(dim_in, dim_in, embed_dim=dim, horizon=training_horizon), + ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=dim), + ResidualTemporalBlock(dim_in, dim_in, embed_dim=dim), Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(), ] ) @@ -91,6 +107,7 @@ def __init__( if not is_last: training_horizon = training_horizon * 2 + # out self.final_conv1d_1 = nn.Conv1d(dim, dim, 5, padding=2) self.final_conv1d_gn = nn.GroupNorm(8, dim) self.final_conv1d_act = nn.Mish() From 8642560db0223e1a42d270d584e94e2cadbdf4ed Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 10 Oct 2022 11:10:37 -0700 Subject: [PATCH 13/63] remove unused variables --- src/diffusers/models/unet_rl.py | 11 +---------- tests/test_models_unet.py | 4 ---- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 420a1661d526..c926ef3a67c5 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -32,9 +32,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): Parameters: training_horizon: horizon of training samples used for diffusion process. transition_dim: state-dimension of samples to predict over - cond_dim: held dimension in input (e.g. 
for actions) -- TODO remove from pretrained - predict_epsilon: TODO remove from pretrained - clip_denoised: TODO remove from pretrained dim: embedding dimension of model dim_mults: dimension multiples of the up/down blocks """ @@ -44,18 +41,12 @@ def __init__( self, training_horizon=128, transition_dim=14, - cond_dim=3, - predict_epsilon=False, - clip_denoised=True, dim=32, dim_mults=(1, 4, 8), ): super().__init__() self.transition_dim = transition_dim - self.cond_dim = cond_dim - self.predict_epsilon = predict_epsilon - self.clip_denoised = clip_denoised # time self.time_proj = Timesteps(num_channels=dim, flip_sin_to_cos=False, downscale_freq_shift=1) @@ -119,7 +110,7 @@ def forward( timestep: Union[torch.Tensor, float, int], return_dict: bool = True, ) -> Union[TemporalUNetOutput, Tuple]: - """r + r""" Args: sample (`torch.FloatTensor`): (batch, horizon, obs_dimension + action_dimension) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int): batch (batch) timesteps diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 12f38ab4e557..f9390bcbdc33 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -407,13 +407,9 @@ def test_training(self): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "training_horizon": 128, "dim": 32, "dim_mults": [1, 4, 8], - "predict_epsilon": False, - "clip_denoised": True, "transition_dim": 14, - "cond_dim": 3, } inputs_dict = self.dummy_input return init_dict, inputs_dict From f58c91529c431946c53b82aa0a67366f2f3ddc2c Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 10 Oct 2022 11:26:19 -0700 Subject: [PATCH 14/63] clean variables --- src/diffusers/models/unet_rl.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index c926ef3a67c5..5560bd90371e 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -30,7 +30,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): A UNet for multi-dimensional temporal data. This model takes the batch over the `training_horizon`. Parameters: - training_horizon: horizon of training samples used for diffusion process. 
transition_dim: state-dimension of samples to predict over dim: embedding dimension of model dim_mults: dimension multiples of the up/down blocks @@ -39,7 +38,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - training_horizon=128, transition_dim=14, dim=32, dim_mults=(1, 4, 8), @@ -55,15 +53,15 @@ def __init__( dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) num_resolutions = len(in_out) # down for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) - self.downs.append( + self.down_blocks.append( nn.ModuleList( [ ResidualTemporalBlock(dim_in, dim_out, embed_dim=dim), @@ -73,9 +71,6 @@ def __init__( ) ) - if not is_last: - training_horizon = training_horizon // 2 - # mid mid_dim = dims[-1] self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim) @@ -85,7 +80,7 @@ def __init__( for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) - self.ups.append( + self.up_blocks.append( nn.ModuleList( [ ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=dim), @@ -95,9 +90,6 @@ def __init__( ) ) - if not is_last: - training_horizon = training_horizon * 2 - # out self.final_conv1d_1 = nn.Conv1d(dim, dim, 5, padding=2) self.final_conv1d_gn = nn.GroupNorm(8, dim) @@ -135,7 +127,7 @@ def forward( h = [] # 2. down - for resnet, resnet2, downsample in self.downs: + for resnet, resnet2, downsample in self.down_blocks: sample = resnet(sample, t) sample = resnet2(sample, t) h.append(sample) @@ -146,7 +138,7 @@ def forward( sample = self.mid_block2(sample, t) # 4. up - for resnet, resnet2, upsample in self.ups: + for resnet, resnet2, upsample in self.up_blocks: sample = torch.cat((sample, h.pop()), dim=1) sample = resnet(sample, t) sample = resnet2(sample, t) From 3b08bea39ff8a754842aed6c6f729ea1b9710a9c Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 10 Oct 2022 15:14:55 -0700 Subject: [PATCH 15/63] add 1d resnet block structure for downsample --- src/diffusers/models/resnet.py | 29 ++++++++----- src/diffusers/models/unet_blocks.py | 66 ++++++++++++++++++++++++++++- src/diffusers/models/unet_rl.py | 33 +++++++++------ 3 files changed, 104 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index b83c4367e7bf..363fed372cbc 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -7,10 +7,13 @@ class Upsample1D(nn.Module): """ - An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param - use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. - If 3D, then - upsampling occurs in the inner-two dimensions. + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
+ use_conv_transpose: + out_channels: """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -21,7 +24,6 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann self.use_conv_transpose = use_conv_transpose self.name = name - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed self.conv = None if use_conv_transpose: self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) @@ -43,10 +45,13 @@ def forward(self, x): class Downsample1D(nn.Module): """ - A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param - use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. - If 3D, then - downsampling occurs in the inner-two dimensions. + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -76,7 +81,8 @@ class Upsample2D(nn.Module): Parameters: channels: channels in the inputs and outputs. use_conv: a bool determining if a convolution is applied. - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. + use_conv_transpose: + out_channels: """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -129,7 +135,8 @@ class Downsample2D(nn.Module): Parameters: channels: channels in the inputs and outputs. use_conv: a bool determining if a convolution is applied. - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. + out_channels: + padding: """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index a17b1d2a5333..a5d3fa499dc7 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -15,10 +15,19 @@ # limitations under the License. 
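# (Usage sketch for the DownResnetBlock1D introduced below; it assumes this
# patched checkout is importable, and the channel/horizon sizes are arbitrary
# example values.)
import torch

from diffusers.models.unet_blocks import DownResnetBlock1D

block = DownResnetBlock1D(in_channels=14, out_channels=32, temb_channels=32)
hidden_states = torch.randn(4, 14, 16)  # (batch, channels, horizon)
temb = torch.randn(4, 32)  # one time-embedding vector per batch element
hidden_states, res_states = block(hidden_states, temb)
print(hidden_states.shape)  # (4, 32, 8): horizon halved by the strided conv
print(res_states[0].shape)  # (4, 32, 16): skip connection saved before downsampling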
import torch +import torch.nn.functional as F from torch import nn from .attention import AttentionBlock, SpatialTransformer -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D +from .resnet import ( + Downsample1D, + Downsample2D, + FirDownsample2D, + FirUpsample2D, + ResidualTemporalBlock, + ResnetBlock2D, + Upsample2D, +) def get_down_block( @@ -460,6 +469,61 @@ def forward(self, hidden_states, temb=None): return hidden_states, output_states +class DownResnetBlock1D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.resnet1 = ResidualTemporalBlock(in_channels, out_channels, embed_dim=temb_channels) + self.resnet2 = ResidualTemporalBlock(out_channels, out_channels, embed_dim=temb_channels) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states, temb=None): + output_states = () + + hidden_states = self.resnet1(hidden_states, temb) + hidden_states = self.resnet2(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + class CrossAttnDownBlock2D(nn.Module): def __init__( self, diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 5560bd90371e..d4432d2f562f 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -5,7 +5,8 @@ import torch import torch.nn as nn -from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock, Upsample1D +from diffusers.models.resnet import ResidualTemporalBlock, Upsample1D +from diffusers.models.unet_blocks import DownResnetBlock1D from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin @@ -61,13 +62,18 @@ def __init__( for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) + # self.down_blocks.append( + # nn.ModuleList( + # [ + # ResidualTemporalBlock(dim_in, dim_out, embed_dim=dim), + # ResidualTemporalBlock(dim_out, dim_out, embed_dim=dim), + # Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(), + # ] + # ) + # ) self.down_blocks.append( - nn.ModuleList( - [ - ResidualTemporalBlock(dim_in, dim_out, embed_dim=dim), - ResidualTemporalBlock(dim_out, dim_out, embed_dim=dim), - Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(), - ] + DownResnetBlock1D( + in_channels=dim_in, out_channels=dim_out, temb_channels=dim, add_downsample=(not is_last) ) ) @@ -127,11 +133,14 @@ def forward( h = [] # 2. 
down - for resnet, resnet2, downsample in self.down_blocks: - sample = resnet(sample, t) - sample = resnet2(sample, t) - h.append(sample) - sample = downsample(sample) + # for resnet, resnet2, downsample in self.down_blocks: + # sample = resnet(sample, t) + # sample = resnet2(sample, t) + # h.append(sample) + # sample = downsample(sample) + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=t) + h.append(res_samples[0]) # 3. mid sample = self.mid_block1(sample, t) From aae2a9a69f329e9f54496c2f2b472ae1597a2fc8 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 10 Oct 2022 15:38:52 -0700 Subject: [PATCH 16/63] rename as unet1d --- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/unet_blocks.py | 53 ++++++++++++++++++++++++++++- src/diffusers/models/unet_rl.py | 51 +++++++-------------------- tests/test_models_unet.py | 8 ++--- 5 files changed, 71 insertions(+), 45 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8267aaff73fe..fa97effaaf0a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, TemporalUNet, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 47f7fa71682b..dc0946cf4d54 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel - from .unet_rl import TemporalUNet + from .unet_rl import UNet1DModel from .vae import AutoencoderKL, VQModel if is_flax_available(): diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index a5d3fa499dc7..64dcac75c232 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -26,6 +26,7 @@ FirUpsample2D, ResidualTemporalBlock, ResnetBlock2D, + Upsample1D, Upsample2D, ) @@ -488,7 +489,6 @@ def __init__( self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels - self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.time_embedding_norm = time_embedding_norm self.add_downsample = add_downsample @@ -523,6 +523,57 @@ def forward(self, hidden_states, temb=None): return hidden_states, output_states +class UpResnetBlock1D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.resnet1 = ResidualTemporalBlock(in_channels, out_channels, embed_dim=temb_channels) + self.resnet2 = ResidualTemporalBlock(out_channels, out_channels, embed_dim=temb_channels) + + if non_linearity == "swish": + 
self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward(self, hidden_states, res_hidden_states=None, temb=None): + if res_hidden_states is not None: + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + + hidden_states = self.resnet1(hidden_states, temb) + hidden_states = self.resnet2(hidden_states, temb) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states + class CrossAttnDownBlock2D(nn.Module): def __init__( diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index d4432d2f562f..b62a76ff27c1 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -6,7 +6,7 @@ import torch.nn as nn from diffusers.models.resnet import ResidualTemporalBlock, Upsample1D -from diffusers.models.unet_blocks import DownResnetBlock1D +from diffusers.models.unet_blocks import DownResnetBlock1D, UpResnetBlock1D from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin @@ -16,7 +16,7 @@ @dataclass -class TemporalUNetOutput(BaseOutput): +class UNet1DOutput(BaseOutput): """ Args: sample (`torch.FloatTensor` of shape `(batch, horizon, obs_dimension)`): @@ -26,7 +26,7 @@ class TemporalUNetOutput(BaseOutput): sample: torch.FloatTensor -class TemporalUNet(ModelMixin, ConfigMixin): +class UNet1DModel(ModelMixin, ConfigMixin): """ A UNet for multi-dimensional temporal data. This model takes the batch over the `training_horizon`. @@ -62,15 +62,6 @@ def __init__( for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) - # self.down_blocks.append( - # nn.ModuleList( - # [ - # ResidualTemporalBlock(dim_in, dim_out, embed_dim=dim), - # ResidualTemporalBlock(dim_out, dim_out, embed_dim=dim), - # Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(), - # ] - # ) - # ) self.down_blocks.append( DownResnetBlock1D( in_channels=dim_in, out_channels=dim_out, temb_channels=dim, add_downsample=(not is_last) @@ -86,15 +77,7 @@ def __init__( for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) - self.up_blocks.append( - nn.ModuleList( - [ - ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=dim), - ResidualTemporalBlock(dim_in, dim_in, embed_dim=dim), - Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(), - ] - ) - ) + self.up_blocks.append(UpResnetBlock1D(in_channels=dim_out*2, out_channels=dim_in, temb_channels=dim, add_upsample=(not is_last))) # out self.final_conv1d_1 = nn.Conv1d(dim, dim, 5, padding=2) @@ -128,30 +111,22 @@ def forward( elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) - t = self.time_proj(timesteps) - t = self.time_mlp(t) - h = [] + temb = self.time_proj(timesteps) + temb = self.time_mlp(temb) + down_block_res_samples = [] # 2. 
down - # for resnet, resnet2, downsample in self.down_blocks: - # sample = resnet(sample, t) - # sample = resnet2(sample, t) - # h.append(sample) - # sample = downsample(sample) for downsample_block in self.down_blocks: - sample, res_samples = downsample_block(hidden_states=sample, temb=t) - h.append(res_samples[0]) + sample, res_samples = downsample_block(hidden_states=sample, temb=temb) + down_block_res_samples.append(res_samples[0]) # 3. mid - sample = self.mid_block1(sample, t) - sample = self.mid_block2(sample, t) + sample = self.mid_block1(sample, temb) + sample = self.mid_block2(sample, temb) # 4. up - for resnet, resnet2, upsample in self.up_blocks: - sample = torch.cat((sample, h.pop()), dim=1) - sample = resnet(sample, t) - sample = resnet2(sample, t) - sample = upsample(sample) + for up_block in self.up_blocks: + sample = up_block(hidden_states=sample, res_hidden_states=down_block_res_samples.pop(), temb=temb) # 5. post-process sample = self.final_conv1d_1(sample) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index e78cca2b7537..4ff1ebc6d241 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -20,7 +20,7 @@ import torch -from diffusers import TemporalUNet, UNet2DConditionModel, UNet2DModel +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin @@ -451,7 +451,7 @@ def test_forward_with_norm_groups(self): class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): - model_class = TemporalUNet + model_class = UNet1DModel @property def dummy_input(self): @@ -488,7 +488,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_from_pretrained_hub(self): - model, loading_info = TemporalUNet.from_pretrained( + model, loading_info = UNet1DModel.from_pretrained( "fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True ) self.assertIsNotNone(model) @@ -500,7 +500,7 @@ def test_from_pretrained_hub(self): assert image is not None, "Make sure output is not None" def test_output_pretrained(self): - model = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128") + model = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128") model.eval() torch.manual_seed(0) From dd872aff4d5d1980c8f7c4d29202d4299b36c05d Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 10 Oct 2022 15:41:49 -0700 Subject: [PATCH 17/63] fix renaming --- src/diffusers/models/unet_blocks.py | 1 + src/diffusers/models/unet_rl.py | 12 ++++++++---- tests/test_models_unet.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 64dcac75c232..413500338d48 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -523,6 +523,7 @@ def forward(self, hidden_states, temb=None): return hidden_states, output_states + class UpResnetBlock1D(nn.Module): def __init__( self, diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index b62a76ff27c1..aeaabc1b3719 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn -from diffusers.models.resnet import ResidualTemporalBlock, Upsample1D +from diffusers.models.resnet import ResidualTemporalBlock from diffusers.models.unet_blocks import DownResnetBlock1D, UpResnetBlock1D from ..configuration_utils import ConfigMixin, 
register_to_config @@ -77,7 +77,11 @@ def __init__( for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) - self.up_blocks.append(UpResnetBlock1D(in_channels=dim_out*2, out_channels=dim_in, temb_channels=dim, add_upsample=(not is_last))) + self.up_blocks.append( + UpResnetBlock1D( + in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, add_upsample=(not is_last) + ) + ) # out self.final_conv1d_1 = nn.Conv1d(dim, dim, 5, padding=2) @@ -90,7 +94,7 @@ def forward( sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], return_dict: bool = True, - ) -> Union[TemporalUNetOutput, Tuple]: + ) -> Union[UNet1DOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, horizon, obs_dimension + action_dimension) noisy inputs tensor @@ -141,4 +145,4 @@ def forward( if not return_dict: return (sample,) - return TemporalUNetOutput(sample=sample) + return UNet1DOutput(sample=sample) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 4ff1ebc6d241..d822153cd44c 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -450,7 +450,7 @@ def test_forward_with_norm_groups(self): pass -class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): +class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet1DModel @property From 9b67bb77375431742aaa3fca86cdae531816e223 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 11 Oct 2022 19:18:24 -0700 Subject: [PATCH 18/63] rename files --- docs/source/api/models.mdx | 7 +- src/diffusers/models/__init__.py | 2 +- .../models/{unet_rl.py => unet_1d.py} | 21 ++- src/diffusers/models/unet_1d_blocks.py | 126 ++++++++++++++++++ src/diffusers/models/unet_2d.py | 2 +- .../{unet_blocks.py => unet_2d_blocks.py} | 121 +---------------- src/diffusers/models/unet_2d_condition.py | 2 +- src/diffusers/models/vae.py | 2 +- 8 files changed, 154 insertions(+), 129 deletions(-) rename src/diffusers/models/{unet_rl.py => unet_1d.py} (85%) create mode 100644 src/diffusers/models/unet_1d_blocks.py rename src/diffusers/models/{unet_blocks.py => unet_2d_blocks.py} (93%) diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index 98687b5e7038..b944a1d13089 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -34,8 +34,11 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## DecoderOutput [[autodoc]] models.vae.DecoderOutput -## TemporalUNet -[[autodoc]] TemporalUNet +## UNet1DModel +[[autodoc]] UNet1DModel + +## UNet1DOutput +[[autodoc]] UNet1DOutput ## VQEncoderOutput [[autodoc]] models.vae.VQEncoderOutput diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index dc0946cf4d54..c5d53b2feb4b 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,9 +16,9 @@ if is_torch_available(): + from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel - from .unet_rl import UNet1DModel from .vae import AutoencoderKL, VQModel if is_flax_available(): diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_1d.py similarity index 85% rename from src/diffusers/models/unet_rl.py rename to src/diffusers/models/unet_1d.py index aeaabc1b3719..c32f2de62091 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_1d.py @@ -1,4 +1,17 @@ -# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py +# 
Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import Tuple, Union @@ -6,7 +19,7 @@ import torch.nn as nn from diffusers.models.resnet import ResidualTemporalBlock -from diffusers.models.unet_blocks import DownResnetBlock1D, UpResnetBlock1D +from diffusers.models.unet_1d_blocks import DownResnetBlock1D, UpResnetBlock1D from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin @@ -120,8 +133,8 @@ def forward( down_block_res_samples = [] # 2. down - for downsample_block in self.down_blocks: - sample, res_samples = downsample_block(hidden_states=sample, temb=temb) + for down_block in self.down_blocks: + sample, res_samples = down_block(hidden_states=sample, temb=temb) down_block_res_samples.append(res_samples[0]) # 3. mid diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py new file mode 100644 index 000000000000..8df7e81b5fd2 --- /dev/null +++ b/src/diffusers/models/unet_1d_blocks.py @@ -0,0 +1,126 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
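The 1D blocks that now live in this file keep the same skip-connection contract as the 2D blocks they were split out from: a down block returns its hidden states together with a tuple of residuals captured before downsampling, and an up block concatenates the popped residual on the channel dimension ahead of its first resnet, which is why that resnet ends up consuming twice the incoming channels (the factor of two moves from the model config into the block itself a few commits later). A minimal torch-only sketch of the wiring, with toy convolutions standing in for ResidualTemporalBlock1D and no time embedding:

import torch
import torch.nn as nn


class ToyDown(nn.Module):
    # shape-only stand-in for DownResnetBlock1D
    def __init__(self, in_ch, out_ch, add_downsample=True):
        super().__init__()
        self.res = nn.Conv1d(in_ch, out_ch, 3, padding=1)
        self.down = nn.Conv1d(out_ch, out_ch, 3, stride=2, padding=1) if add_downsample else nn.Identity()

    def forward(self, x):
        x = self.res(x)
        res_samples = (x,)  # the residual is captured before downsampling
        return self.down(x), res_samples


class ToyUp(nn.Module):
    # shape-only stand-in for UpResnetBlock1D
    def __init__(self, in_ch, out_ch, add_upsample=True):
        super().__init__()
        # the first layer sees 2 * in_ch because the skip is concatenated on dim 1
        self.res = nn.Conv1d(2 * in_ch, out_ch, 3, padding=1)
        self.up = nn.ConvTranspose1d(out_ch, out_ch, 4, 2, 1) if add_upsample else nn.Identity()

    def forward(self, x, res_hidden_states):
        x = torch.cat((x, res_hidden_states), dim=1)
        return self.up(self.res(x))


sample = torch.randn(4, 14, 16)  # (batch, channels, horizon), channels-first like Conv1d
down1, down2 = ToyDown(14, 32), ToyDown(32, 64, add_downsample=False)
up1, up2 = ToyUp(64, 32), ToyUp(32, 14, add_upsample=False)

skips = []
x, res = down1(sample)  # x: (4, 32, 8); the stored residual stays at length 16
skips.append(res[0])
x, res = down2(x)  # the final down block keeps the resolution, as in UNet1DModel
skips.append(res[0])
x = up1(x, skips.pop())  # cat on dim 1, then resnet, then upsample back to length 16
x = up2(x, skips.pop())
assert x.shape == sample.shape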
+ + +import torch +import torch.nn.functional as F +from torch import nn + +from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.resnet1 = ResidualTemporalBlock(in_channels, out_channels, embed_dim=temb_channels) + self.resnet2 = ResidualTemporalBlock(out_channels, out_channels, embed_dim=temb_channels) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states, temb=None): + output_states = () + + hidden_states = self.resnet1(hidden_states, temb) + hidden_states = self.resnet2(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.resnet1 = ResidualTemporalBlock(in_channels, out_channels, embed_dim=temb_channels) + self.resnet2 = ResidualTemporalBlock(out_channels, out_channels, embed_dim=temb_channels) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward(self, hidden_states, res_hidden_states=None, temb=None): + if res_hidden_states is not None: + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + + hidden_states = self.resnet1(hidden_states, temb) + hidden_states = self.resnet2(hidden_states, temb) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 2415bf4ee78d..d423cbc02fae 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -8,7 +8,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .unet_blocks import 
UNetMidBlock2D, get_down_block, get_up_block +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @dataclass diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_2d_blocks.py similarity index 93% rename from src/diffusers/models/unet_blocks.py rename to src/diffusers/models/unet_2d_blocks.py index 413500338d48..4aec1ede0275 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -10,25 +10,14 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and +# limitations under the License. import numpy as np - -# limitations under the License. import torch -import torch.nn.functional as F from torch import nn from .attention import AttentionBlock, SpatialTransformer -from .resnet import ( - Downsample1D, - Downsample2D, - FirDownsample2D, - FirUpsample2D, - ResidualTemporalBlock, - ResnetBlock2D, - Upsample1D, - Upsample2D, -) +from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D def get_down_block( @@ -470,112 +459,6 @@ def forward(self, hidden_states, temb=None): return hidden_states, output_states -class DownResnetBlock1D(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - temb_channels=32, - groups=32, - groups_out=None, - non_linearity=None, - time_embedding_norm="default", - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.add_downsample = add_downsample - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - self.resnet1 = ResidualTemporalBlock(in_channels, out_channels, embed_dim=temb_channels) - self.resnet2 = ResidualTemporalBlock(out_channels, out_channels, embed_dim=temb_channels) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - - self.downsample = None - if add_downsample: - self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) - - def forward(self, hidden_states, temb=None): - output_states = () - - hidden_states = self.resnet1(hidden_states, temb) - hidden_states = self.resnet2(hidden_states, temb) - output_states += (hidden_states,) - - if self.downsample is not None: - hidden_states = self.downsample(hidden_states) - - return hidden_states, output_states - - -class UpResnetBlock1D(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - temb_channels=32, - groups=32, - groups_out=None, - non_linearity=None, - time_embedding_norm="default", - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.time_embedding_norm = time_embedding_norm - self.add_upsample = add_upsample - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - self.resnet1 = ResidualTemporalBlock(in_channels, out_channels, embed_dim=temb_channels) - self.resnet2 = 
ResidualTemporalBlock(out_channels, out_channels, embed_dim=temb_channels) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - - self.upsample = None - if add_upsample: - self.upsample = Upsample1D(out_channels, use_conv_transpose=True) - - def forward(self, hidden_states, res_hidden_states=None, temb=None): - if res_hidden_states is not None: - hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) - - hidden_states = self.resnet1(hidden_states, temb) - hidden_states = self.resnet2(hidden_states, temb) - - if self.upsample is not None: - hidden_states = self.upsample(hidden_states) - - return hidden_states - - class CrossAttnDownBlock2D(nn.Module): def __init__( self, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 4e4eaddf5dfe..7f7d6d541f9e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -9,7 +9,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput, logging from .embeddings import TimestepEmbedding, Timesteps -from .unet_blocks import ( +from .unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 7ce2f98eee27..515c9f17a4c9 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -8,7 +8,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput -from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @dataclass From db012ebe9f065d31a3e3033502cdbcf43e64a39b Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 11 Oct 2022 20:15:08 -0700 Subject: [PATCH 19/63] add get_block(...) api --- src/diffusers/models/unet_1d.py | 60 ++++++++++++++++---------- src/diffusers/models/unet_1d_blocks.py | 47 +++++++++++++++++++- 2 files changed, 82 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index c32f2de62091..c5d6c8afbe3d 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
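The get_block(...) API introduced here mirrors the 2D UNet convention: block classes are referenced by name in the model config and resolved through small get_down_block / get_up_block factories, so UNet1DModel can be driven entirely from its config. As a hedged sketch of where this series lands (it assumes the later commits below that unify the constructor arguments and drop the input/output permutes; the values mirror the test defaults), construction and a forward pass look roughly like:

import torch

from diffusers import UNet1DModel

model = UNet1DModel(
    in_channels=14,
    out_channels=14,
    block_out_channels=(32, 128, 256),
    down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
    up_block_types=("UpResnetBlock1D", "UpResnetBlock1D"),
)

sample = torch.randn(4, 14, 16)  # (batch, transition_dim, horizon)
timestep = torch.tensor([10] * 4)
out = model(sample, timestep).sample  # same (4, 14, 16) shape as the input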
- from dataclasses import dataclass from typing import Tuple, Union @@ -19,7 +18,7 @@ import torch.nn as nn from diffusers.models.resnet import ResidualTemporalBlock -from diffusers.models.unet_1d_blocks import DownResnetBlock1D, UpResnetBlock1D +from diffusers.models.unet_1d_blocks import get_down_block, get_up_block from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin @@ -55,52 +54,67 @@ def __init__( transition_dim=14, dim=32, dim_mults=(1, 4, 8), + in_channels: int = 14, + out_channels: int = 14, + down_block_types: Tuple[str] = ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"], + up_block_types: Tuple[str] = ["UpResnetBlock1D", "UpResnetBlock1D"], + block_out_channels: Tuple[int] = [32, 128, 256], ): super().__init__() - self.transition_dim = transition_dim + self.transition_dim = in_channels # time self.time_proj = Timesteps(num_channels=dim, flip_sin_to_cos=False, downscale_freq_shift=1) self.time_mlp = TimestepEmbedding(channel=dim, time_embed_dim=4 * dim, act_fn="mish", out_dim=dim) - dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] - in_out = list(zip(dims[:-1], dims[1:])) - self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) - num_resolutions = len(in_out) + mid_dim = block_out_channels[-1] # down - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (num_resolutions - 1) - - self.down_blocks.append( - DownResnetBlock1D( - in_channels=dim_in, out_channels=dim_out, temb_channels=dim, add_downsample=(not is_last) - ) + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block_type = down_block_types[i] + down_block = get_down_block( + down_block_type, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=dim, + add_downsample=not is_final_block, ) + self.down_blocks.append(down_block) # mid - mid_dim = dims[-1] self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim) self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim) # up - for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): - is_last = ind >= (num_resolutions - 1) - - self.up_blocks.append( - UpResnetBlock1D( - in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, add_upsample=(not is_last) - ) + reversed_block_out_channels = list(reversed(block_out_channels)) + for i, up_block_type in enumerate(up_block_types): + input_channel = reversed_block_out_channels[i] + output_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + in_channels=input_channel * 2, + out_channels=output_channel, + temb_channels=dim, + add_upsample=not is_final_block, ) + self.up_blocks.append(up_block) # out self.final_conv1d_1 = nn.Conv1d(dim, dim, 5, padding=2) self.final_conv1d_gn = nn.GroupNorm(8, dim) self.final_conv1d_act = nn.Mish() - self.final_conv1d_2 = nn.Conv1d(dim, transition_dim, 1) + self.final_conv1d_2 = nn.Conv1d(dim, out_channels, 1) def forward( self, diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 8df7e81b5fd2..cfdc1f762d47 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -23,7 +23,6 @@ class DownResnetBlock1D(nn.Module): def __init__( self, - *, 
in_channels, out_channels=None, conv_shortcut=False, @@ -77,7 +76,6 @@ def forward(self, hidden_states, temb=None): class UpResnetBlock1D(nn.Module): def __init__( self, - *, in_channels, out_channels=None, temb_channels=32, @@ -124,3 +122,48 @@ def forward(self, hidden_states, res_hidden_states=None, temb=None): hidden_states = self.upsample(hidden_states) return hidden_states + + +class DownBlock1D(nn.Module): + pass + + +class AttnDownBlock1D(nn.Module): + pass + + +class DownBlock1DNoSkip(nn.Module): + pass + + +class UpBlock1D(nn.Module): + pass + + +class AttnUpBlock1D(nn.Module): + pass + + +class UpBlock1DNoSkip(nn.Module): + pass + + +def get_down_block(down_block_type, in_channels, out_channels, temb_channels, add_downsample): + if down_block_type == "DownResnetBlock1D": + return DownResnetBlock1D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block(up_block_type, in_channels, out_channels, temb_channels, add_upsample): + if up_block_type == "UpResnetBlock1D": + return UpResnetBlock1D( + in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_upsample=add_upsample + ) + + raise ValueError(f"{up_block_type} does not exist.") From 4db6e0b552d158e5258fc2e3d1eb01dee680ca69 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 09:59:23 -0700 Subject: [PATCH 20/63] unify args for model1d like model2d --- src/diffusers/models/unet_1d.py | 34 ++++++++++++++------------ src/diffusers/models/unet_1d_blocks.py | 10 ++++++++ 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index c5d6c8afbe3d..96de7a101d84 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -40,7 +40,10 @@ class UNet1DOutput(BaseOutput): class UNet1DModel(ModelMixin, ConfigMixin): """ - A UNet for multi-dimensional temporal data. This model takes the batch over the `training_horizon`. + UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) 
Parameters: transition_dim: state-dimension of samples to predict over @@ -51,22 +54,21 @@ class UNet1DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - transition_dim=14, - dim=32, - dim_mults=(1, 4, 8), in_channels: int = 14, out_channels: int = 14, - down_block_types: Tuple[str] = ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"], - up_block_types: Tuple[str] = ["UpResnetBlock1D", "UpResnetBlock1D"], - block_out_channels: Tuple[int] = [32, 128, 256], + down_block_types: Tuple[str] = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + up_block_types: Tuple[str] = ("UpResnetBlock1D", "UpResnetBlock1D"), + block_out_channels: Tuple[int] = (32, 128, 256), + act_fn: str = "mish", ): super().__init__() self.transition_dim = in_channels + time_embed_dim = block_out_channels[0] * 4 # time - self.time_proj = Timesteps(num_channels=dim, flip_sin_to_cos=False, downscale_freq_shift=1) - self.time_mlp = TimestepEmbedding(channel=dim, time_embed_dim=4 * dim, act_fn="mish", out_dim=dim) + self.time_proj = Timesteps(num_channels=block_out_channels[0], flip_sin_to_cos=False, downscale_freq_shift=1) + self.time_mlp = TimestepEmbedding(channel=block_out_channels[0], time_embed_dim=time_embed_dim, act_fn=act_fn, out_dim=block_out_channels[0]) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -84,14 +86,14 @@ def __init__( down_block_type, in_channels=input_channel, out_channels=output_channel, - temb_channels=dim, + temb_channels=block_out_channels[0], add_downsample=not is_final_block, ) self.down_blocks.append(down_block) # mid - self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim) - self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=dim) + self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=block_out_channels[0]) + self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=block_out_channels[0]) # up reversed_block_out_channels = list(reversed(block_out_channels)) @@ -105,16 +107,16 @@ def __init__( up_block_type, in_channels=input_channel * 2, out_channels=output_channel, - temb_channels=dim, + temb_channels=block_out_channels[0], add_upsample=not is_final_block, ) self.up_blocks.append(up_block) # out - self.final_conv1d_1 = nn.Conv1d(dim, dim, 5, padding=2) - self.final_conv1d_gn = nn.GroupNorm(8, dim) + self.final_conv1d_1 = nn.Conv1d(block_out_channels[0], block_out_channels[0], 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(8, block_out_channels[0]) self.final_conv1d_act = nn.Mish() - self.final_conv1d_2 = nn.Conv1d(dim, out_channels, 1) + self.final_conv1d_2 = nn.Conv1d(block_out_channels[0], out_channels, 1) def forward( self, diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index cfdc1f762d47..c65ce221f3dc 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -55,6 +55,8 @@ def __init__( self.nonlinearity = nn.Mish() elif non_linearity == "silu": self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None self.downsample = None if add_downsample: @@ -67,6 +69,9 @@ def forward(self, hidden_states, temb=None): hidden_states = self.resnet2(hidden_states, temb) output_states += (hidden_states,) + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + if self.downsample is not None: hidden_states = self.downsample(hidden_states) @@ -106,6 +111,8 @@ def __init__( self.nonlinearity = nn.Mish() elif non_linearity == "silu": 
self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None self.upsample = None if add_upsample: @@ -118,6 +125,9 @@ def forward(self, hidden_states, res_hidden_states=None, temb=None): hidden_states = self.resnet1(hidden_states, temb) hidden_states = self.resnet2(hidden_states, temb) + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + if self.upsample is not None: hidden_states = self.upsample(hidden_states) From 634a526ff7e230194cfff7c74e25d745dcc2509f Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 10:16:34 -0700 Subject: [PATCH 21/63] minor cleaning --- src/diffusers/models/resnet.py | 2 +- src/diffusers/models/unet_1d.py | 27 ++++++++++++++++---------- src/diffusers/models/unet_1d_blocks.py | 10 +++++----- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 363fed372cbc..4b1be8c13c51 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -501,7 +501,7 @@ def forward(self, x): # unet_rl.py -class ResidualTemporalBlock(nn.Module): +class ResidualTemporalBlock1D(nn.Module): def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): super().__init__() diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 96de7a101d84..c8b14ad8fc95 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn -from diffusers.models.resnet import ResidualTemporalBlock +from diffusers.models.resnet import ResidualTemporalBlock1D from diffusers.models.unet_1d_blocks import get_down_block, get_up_block from ..configuration_utils import ConfigMixin, register_to_config @@ -46,9 +46,13 @@ class UNet1DModel(ModelMixin, ConfigMixin): implements for all the model (such as downloading or saving, etc.) 
Parameters: - transition_dim: state-dimension of samples to predict over - dim: embedding dimension of model - dim_mults: dimension multiples of the up/down blocks + in_channels: + out_channels: + down_block_types: + up_block_types: + block_out_channels: + act_fn: + norm_num_groups: """ @register_to_config @@ -60,15 +64,17 @@ def __init__( up_block_types: Tuple[str] = ("UpResnetBlock1D", "UpResnetBlock1D"), block_out_channels: Tuple[int] = (32, 128, 256), act_fn: str = "mish", + norm_num_groups: int = 8, ): super().__init__() - self.transition_dim = in_channels time_embed_dim = block_out_channels[0] * 4 # time self.time_proj = Timesteps(num_channels=block_out_channels[0], flip_sin_to_cos=False, downscale_freq_shift=1) - self.time_mlp = TimestepEmbedding(channel=block_out_channels[0], time_embed_dim=time_embed_dim, act_fn=act_fn, out_dim=block_out_channels[0]) + self.time_mlp = TimestepEmbedding( + channel=block_out_channels[0], time_embed_dim=time_embed_dim, act_fn=act_fn, out_dim=block_out_channels[0] + ) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -92,8 +98,8 @@ def __init__( self.down_blocks.append(down_block) # mid - self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=block_out_channels[0]) - self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=block_out_channels[0]) + self.mid_block1 = ResidualTemporalBlock1D(mid_dim, mid_dim, embed_dim=block_out_channels[0]) + self.mid_block2 = ResidualTemporalBlock1D(mid_dim, mid_dim, embed_dim=block_out_channels[0]) # up reversed_block_out_channels = list(reversed(block_out_channels)) @@ -105,7 +111,7 @@ def __init__( up_block = get_up_block( up_block_type, - in_channels=input_channel * 2, + in_channels=input_channel, out_channels=output_channel, temb_channels=block_out_channels[0], add_upsample=not is_final_block, @@ -113,8 +119,9 @@ def __init__( self.up_blocks.append(up_block) # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) self.final_conv1d_1 = nn.Conv1d(block_out_channels[0], block_out_channels[0], 5, padding=2) - self.final_conv1d_gn = nn.GroupNorm(8, block_out_channels[0]) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, block_out_channels[0]) self.final_conv1d_act = nn.Mish() self.final_conv1d_2 = nn.Conv1d(block_out_channels[0], out_channels, 1) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index c65ce221f3dc..8ed258ee4f2a 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from torch import nn -from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D +from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D class DownResnetBlock1D(nn.Module): @@ -46,8 +46,8 @@ def __init__( if groups_out is None: groups_out = groups - self.resnet1 = ResidualTemporalBlock(in_channels, out_channels, embed_dim=temb_channels) - self.resnet2 = ResidualTemporalBlock(out_channels, out_channels, embed_dim=temb_channels) + self.resnet1 = ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels) + self.resnet2 = ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels) if non_linearity == "swish": self.nonlinearity = lambda x: F.silu(x) @@ -102,8 +102,8 @@ def __init__( if groups_out is None: groups_out = groups - self.resnet1 = ResidualTemporalBlock(in_channels, out_channels, embed_dim=temb_channels) - self.resnet2 = 
ResidualTemporalBlock(out_channels, out_channels, embed_dim=temb_channels) + self.resnet1 = ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels) + self.resnet2 = ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels) if non_linearity == "swish": self.nonlinearity = lambda x: F.silu(x) From aebf547329546aa0ab72f2b96f208ffceef5ec58 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 10:17:52 -0700 Subject: [PATCH 22/63] fix docs --- docs/source/api/models.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index b944a1d13089..b8fa056dc5be 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -38,7 +38,7 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module [[autodoc]] UNet1DModel ## UNet1DOutput -[[autodoc]] UNet1DOutput +[[autodoc]] models.unet_1d.UNet1DOutput ## VQEncoderOutput [[autodoc]] models.vae.VQEncoderOutput From 305ecd891be91329fe2ee984c947702d8f7a3d18 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 10:33:21 -0700 Subject: [PATCH 23/63] improve 1d resnet blocks --- src/diffusers/models/unet_1d.py | 3 ++ src/diffusers/models/unet_1d_blocks.py | 42 +++++++++++++++++++------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index c8b14ad8fc95..9fa1b145a7d5 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -65,6 +65,7 @@ def __init__( block_out_channels: Tuple[int] = (32, 128, 256), act_fn: str = "mish", norm_num_groups: int = 8, + layers_per_block: int = 1, ): super().__init__() @@ -90,6 +91,7 @@ def __init__( down_block_type = down_block_types[i] down_block = get_down_block( down_block_type, + num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=block_out_channels[0], @@ -111,6 +113,7 @@ def __init__( up_block = get_up_block( up_block_type, + num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=block_out_channels[0], diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 8ed258ee4f2a..40e25fb43afb 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -25,6 +25,7 @@ def __init__( self, in_channels, out_channels=None, + num_layers=1, conv_shortcut=False, temb_channels=32, groups=32, @@ -46,8 +47,13 @@ def __init__( if groups_out is None: groups_out = groups - self.resnet1 = ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels) - self.resnet2 = ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels) + # there will always be at least one resenet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) if non_linearity == "swish": self.nonlinearity = lambda x: F.silu(x) @@ -65,8 +71,10 @@ def __init__( def forward(self, hidden_states, temb=None): output_states = () - hidden_states = self.resnet1(hidden_states, temb) - hidden_states = self.resnet2(hidden_states, temb) + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) if 
self.nonlinearity is not None: @@ -83,6 +91,7 @@ def __init__( self, in_channels, out_channels=None, + num_layers=1, temb_channels=32, groups=32, groups_out=None, @@ -102,8 +111,13 @@ def __init__( if groups_out is None: groups_out = groups - self.resnet1 = ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels) - self.resnet2 = ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels) + # there will always be at least one resenet + resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) if non_linearity == "swish": self.nonlinearity = lambda x: F.silu(x) @@ -122,8 +136,9 @@ def forward(self, hidden_states, res_hidden_states=None, temb=None): if res_hidden_states is not None: hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) - hidden_states = self.resnet1(hidden_states, temb) - hidden_states = self.resnet2(hidden_states, temb) + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) if self.nonlinearity is not None: hidden_states = self.nonlinearity(hidden_states) @@ -158,10 +173,11 @@ class UpBlock1DNoSkip(nn.Module): pass -def get_down_block(down_block_type, in_channels, out_channels, temb_channels, add_downsample): +def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): if down_block_type == "DownResnetBlock1D": return DownResnetBlock1D( in_channels=in_channels, + num_layers=num_layers, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, @@ -170,10 +186,14 @@ def get_down_block(down_block_type, in_channels, out_channels, temb_channels, ad raise ValueError(f"{down_block_type} does not exist.") -def get_up_block(up_block_type, in_channels, out_channels, temb_channels, add_upsample): +def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample): if up_block_type == "UpResnetBlock1D": return UpResnetBlock1D( - in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_upsample=add_upsample + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, ) raise ValueError(f"{up_block_type} does not exist.") From 95d3a1c267fff39e7202c2b0fc1fd46ea404c3a1 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 10:59:47 -0700 Subject: [PATCH 24/63] fix tests, remove permuts --- src/diffusers/models/unet_1d.py | 2 -- tests/test_models_unet.py | 22 ++++++++++------------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 9fa1b145a7d5..3f58a682d449 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -145,7 +145,6 @@ def forward( [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ - sample = sample.permute(0, 2, 1) # 1. 
time timesteps = timestep @@ -179,7 +178,6 @@ def forward( sample = self.final_conv1d_act(sample) sample = self.final_conv1d_2(sample) - sample = sample.permute(0, 2, 1) if not return_dict: return (sample,) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index d822153cd44c..6683ff97099e 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -15,6 +15,7 @@ import gc import math +import pdb import tracemalloc import unittest @@ -459,18 +460,18 @@ def dummy_input(self): num_features = 14 seq_len = 16 - noise = floats_tensor((batch_size, seq_len, num_features)).to(torch_device) + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) time_step = torch.tensor([10] * batch_size).to(torch_device) return {"sample": noise, "timestep": time_step} @property def input_shape(self): - return (4, 16, 14) + return (4, 14, 16) @property def output_shape(self): - return (4, 16, 14) + return (4, 14, 16) def test_ema_training(self): pass @@ -480,9 +481,9 @@ def test_training(self): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "dim": 32, - "dim_mults": [1, 4, 8], - "transition_dim": 14, + "block_out_channels": (32, 128, 256), + "in_channels": 14, + "out_channels": 14, } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -501,25 +502,22 @@ def test_from_pretrained_hub(self): def test_output_pretrained(self): model = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128") - model.eval() - torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - num_features = model.transition_dim + num_features = model.in_channels seq_len = 16 - noise = torch.randn((1, seq_len, num_features)) + noise = torch.randn((1, seq_len, num_features)).permute(0, 2, 1) # match original, we can update values and remove time_step = torch.full((num_features,), 0) with torch.no_grad(): - output = model(noise, time_step).sample + output = model(noise, time_step).sample.permute(0, 2, 1) output_slice = output[0, -3:, -3:].flatten() # fmt: off expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584]) # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) def test_forward_with_norm_groups(self): From 6cbb73b57e6b679d0de8ae47e80d50874ad5fb14 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 12 Oct 2022 11:00:38 -0700 Subject: [PATCH 25/63] fix style --- src/diffusers/models/unet_1d.py | 1 - tests/test_models_unet.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 3f58a682d449..2e20cacb64f1 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -178,7 +178,6 @@ def forward( sample = self.final_conv1d_act(sample) sample = self.final_conv1d_2(sample) - if not return_dict: return (sample,) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 6683ff97099e..e1dbdfaa4611 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -15,7 +15,6 @@ import gc import math -import pdb import tracemalloc import unittest @@ -508,7 +507,9 @@ def test_output_pretrained(self): num_features = model.in_channels seq_len = 16 - noise = torch.randn((1, seq_len, num_features)).permute(0, 2, 1) # match original, we can update values and remove + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove time_step = 
torch.full((num_features,), 0) with torch.no_grad(): From ffb73552a39280b00c16672f0d53e903de5c2b46 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 18 Oct 2022 13:47:33 -0700 Subject: [PATCH 26/63] add output activation --- src/diffusers/models/unet_1d.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 2e20cacb64f1..3ede756c9b3d 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -125,7 +125,10 @@ def __init__( num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) self.final_conv1d_1 = nn.Conv1d(block_out_channels[0], block_out_channels[0], 5, padding=2) self.final_conv1d_gn = nn.GroupNorm(num_groups_out, block_out_channels[0]) - self.final_conv1d_act = nn.Mish() + if act_fn == "silu": + self.final_conv1d_act = nn.SiLU() + if act_fn == "mish": + self.final_conv1d_act = nn.Mish() self.final_conv1d_2 = nn.Conv1d(block_out_channels[0], out_channels, 1) def forward( From a6314f67962ca95c40a5b8a41e51ed8dbeebf666 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 18 Oct 2022 14:38:51 -0700 Subject: [PATCH 27/63] rename flax blocks file --- .../models/{unet_blocks_flax.py => unet_2d_blocks_flax.py} | 0 src/diffusers/models/unet_2d_condition_flax.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/diffusers/models/{unet_blocks_flax.py => unet_2d_blocks_flax.py} (100%) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py similarity index 100% rename from src/diffusers/models/unet_blocks_flax.py rename to src/diffusers/models/unet_2d_blocks_flax.py diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 5411855f79d5..f0e721826bd7 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -23,7 +23,7 @@ from ..modeling_flax_utils import FlaxModelMixin from ..utils import BaseOutput from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps -from .unet_blocks_flax import ( +from .unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxCrossAttnUpBlock2D, FlaxDownBlock2D, From 48a74147233ca48aa6c77530ed244de4b5d59dfe Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Thu, 20 Oct 2022 23:02:00 -0400 Subject: [PATCH 28/63] Add Value Function and corresponding example script to Diffuser implementation (#884) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! 
* turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review Co-authored-by: Nathan Lambert --- .gitignore | 4 +- examples/community/pipeline.py | 99 ++++++ examples/community/value_guided_diffuser.py | 99 ++++++ examples/diffuser/run_diffuser.py | 122 +++++++ .../diffuser/run_diffuser_value_guided.py | 94 ++++++ examples/diffuser/train_diffuser.py | 312 ++++++++++++++++++ .../convert_models_diffuser_to_diffusers.py | 77 +++++ src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 1 + src/diffusers/models/unet_1d.py | 57 ++-- src/diffusers/models/unet_1d_blocks.py | 78 ++++- src/diffusers/models/unet_rl.py | 135 ++++++++ src/diffusers/schedulers/scheduling_ddpm.py | 6 +- tests/test_models_unet.py | 85 ++++- 14 files changed, 1143 insertions(+), 28 deletions(-) create mode 100644 examples/community/pipeline.py create mode 100644 examples/community/value_guided_diffuser.py create mode 100644 examples/diffuser/run_diffuser.py create mode 100644 examples/diffuser/run_diffuser_value_guided.py create mode 100644 examples/diffuser/train_diffuser.py create mode 100644 scripts/convert_models_diffuser_to_diffusers.py create mode 100644 src/diffusers/models/unet_rl.py diff --git a/.gitignore b/.gitignore index cf8183463613..f018a111ea33 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,6 @@ tags *.lock # DS_Store (MacOS) -.DS_Store \ No newline at end of file +.DS_Store +# RL pipelines may produce mp4 outputs +*.mp4 \ No newline at end of file diff --git a/examples/community/pipeline.py b/examples/community/pipeline.py new file mode 100644 index 000000000000..7e3f2b832b1f --- /dev/null +++ b/examples/community/pipeline.py @@ -0,0 +1,99 @@ +import torch + +import tqdm +from diffusers import DiffusionPipeline +from diffusers.models.unet_1d import UNet1DModel +from diffusers.utils.dummy_pt_objects import DDPMScheduler + + +class ValueGuidedDiffuserPipeline(DiffusionPipeline): + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + self.value_function = value_function + self.unet = unet + self.scheduler = scheduler + self.env = env + self.data = env.get_dataset() + self.means = dict() + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: + pass + self.stds = dict() + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if type(x_in) is dict: + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + 
for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + # 3. call the sample function + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] + + # 4. apply conditions to the trajectory + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + x1 = torch.randn(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + denorm_actions = denorm_actions[0, 0] + return denorm_actions diff --git a/examples/community/value_guided_diffuser.py b/examples/community/value_guided_diffuser.py new file mode 100644 index 000000000000..7e3f2b832b1f --- /dev/null +++ b/examples/community/value_guided_diffuser.py @@ -0,0 +1,99 @@ +import torch + +import tqdm +from diffusers import DiffusionPipeline +from diffusers.models.unet_1d import UNet1DModel +from diffusers.utils.dummy_pt_objects import DDPMScheduler + + +class ValueGuidedDiffuserPipeline(DiffusionPipeline): + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + self.value_function = value_function + self.unet = unet + self.scheduler = scheduler + self.env = env + self.data = env.get_dataset() + self.means = dict() + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: + pass + self.stds = dict() + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if type(x_in) is dict: + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model 
+            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
+            # 3. call the sample function
+            for _ in range(n_guide_steps):
+                with torch.enable_grad():
+                    x.requires_grad_()
+                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
+                    grad = torch.autograd.grad([y.sum()], [x])[0]
+
+                    posterior_variance = self.scheduler._get_variance(i)
+                    model_std = torch.exp(0.5 * posterior_variance)
+                    grad = model_std * grad
+                grad[timesteps < 2] = 0
+                x = x.detach()
+                x = x + scale * grad
+                x = self.reset_x0(x, conditions, self.action_dim)
+            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
+            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
+
+            # 4. apply conditions to the trajectory
+            x = self.reset_x0(x, conditions, self.action_dim)
+            x = self.to_torch(x)
+        return x, y
+
+    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
+        obs = self.normalize(obs, "observations")
+        obs = obs[None].repeat(batch_size, axis=0)
+        conditions = {0: self.to_torch(obs)}
+        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
+        x1 = torch.randn(shape, device=self.unet.device)
+        x = self.reset_x0(x1, conditions, self.action_dim)
+        x = self.to_torch(x)
+        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
+        sorted_idx = y.argsort(0, descending=True).squeeze()
+        sorted_values = x[sorted_idx]
+        actions = sorted_values[:, :, : self.action_dim]
+        actions = actions.detach().cpu().numpy()
+        denorm_actions = self.de_normalize(actions, key="actions")
+        denorm_actions = denorm_actions[0, 0]
+        return denorm_actions
diff --git a/examples/diffuser/run_diffuser.py b/examples/diffuser/run_diffuser.py
new file mode 100644
index 000000000000..b29d89992dfc
--- /dev/null
+++ b/examples/diffuser/run_diffuser.py
@@ -0,0 +1,122 @@
+import numpy as np
+import torch
+
+import d4rl  # noqa
+import gym
+import tqdm
+import train_diffuser
+from diffusers import DDPMScheduler, UNet1DModel
+
+
+env_name = "hopper-medium-expert-v2"
+env = gym.make(env_name)
+data = env.get_dataset()  # dataset is only used for normalization in this colab
+
+DEVICE = "cpu"
+DTYPE = torch.float
+
+# diffusion model settings
+n_samples = 4  # number of trajectories planned via diffusion
+horizon = 128  # length of sampled trajectories
+state_dim = env.observation_space.shape[0]
+action_dim = env.action_space.shape[0]
+num_inference_steps = 100  # number of diffusion steps
+
+
+# generator for the diffusion loop (kept on the CPU so the notebook also runs in colab)
+generator_cpu = torch.Generator(device="cpu")
+
+scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")
+
+# 3 different pretrained models are available for this task.
+# The horizon represents the length of trajectories used in training.
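+# (note: the `hor128` checkpoint below should match the `horizon = 128` setting above;
+# the commented hor256/hor512 variants would require updating `horizon` accordingly)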
+network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) +# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor256").to(device=DEVICE) +# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE) + + +# network specific constants for inference +clip_denoised = network.clip_denoised +predict_epsilon = network.predict_epsilon + +# [ observation_dim ] --> [ n_samples x observation_dim ] +obs = env.reset() +total_reward = 0 +done = False +T = 300 +rollout = [obs.copy()] + +try: + for t in tqdm.tqdm(range(T)): + obs_raw = obs + + # normalize observations for forward passes + obs = train_diffuser.normalize(obs, data, "observations") + obs = obs[None].repeat(n_samples, axis=0) + conditions = {0: train_diffuser.to_torch(obs, device=DEVICE)} + + # constants for inference + batch_size = len(conditions[0]) + shape = (batch_size, horizon, state_dim + action_dim) + + # sample random initial noise vector + x1 = torch.randn(shape, device=DEVICE, generator=generator_cpu) + + # this model is conditioned from an initial state, so you will see this function + # multiple times to change the initial state of generated data to the state + # generated via env.reset() above or env.step() below + x = train_diffuser.reset_x0(x1, conditions, action_dim) + + # convert a np observation to torch for model forward pass + x = train_diffuser.to_torch(x) + + eta = 1.0 # noise factor for sampling reconstructed state + + # run the diffusion process + # for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): + for i in tqdm.tqdm(scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long) + + # 1. generate prediction from model + with torch.no_grad(): + residual = network(x, timesteps).sample + + # 2. use the model prediction to reconstruct an observation (de-noise) + obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=predict_epsilon)["prev_sample"] + + # 3. [optional] add posterior noise to the sample + if eta > 0: + noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device) + posterior_variance = scheduler._get_variance(i) # * noise + # no noise when t == 0 + # NOTE: original implementation missing sqrt on posterior_variance + obs_reconstruct = ( + obs_reconstruct + int(i > 0) * (0.5 * posterior_variance) * eta * noise + ) # MJ had as log var, exponentiated + + # 4. 
apply conditions to the trajectory
+            obs_reconstruct_postcond = train_diffuser.reset_x0(obs_reconstruct, conditions, action_dim)
+            x = train_diffuser.to_torch(obs_reconstruct_postcond)
+        plans = train_diffuser.to_np(x[:, :, :action_dim])
+        # select random plan
+        idx = np.random.randint(plans.shape[0])
+        # select action at correct time
+        action = plans[idx, 0, :]
+        actions = train_diffuser.de_normalize(action, data, "actions")
+        # execute the de-normalized action in the environment
+        next_observation, reward, terminal, _ = env.step(actions)
+
+        # update return
+        total_reward += reward
+        print(f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}")
+
+        # save observations for rendering
+        rollout.append(next_observation.copy())
+        obs = next_observation
+except KeyboardInterrupt:
+    pass
+
+print(f"Total reward: {total_reward}")
+render = train_diffuser.MuJoCoRenderer(env)
+train_diffuser.show_sample(render, np.expand_dims(np.stack(rollout), axis=0))
diff --git a/examples/diffuser/run_diffuser_value_guided.py b/examples/diffuser/run_diffuser_value_guided.py
new file mode 100644
index 000000000000..4272ec2c3106
--- /dev/null
+++ b/examples/diffuser/run_diffuser_value_guided.py
@@ -0,0 +1,94 @@
+import d4rl  # noqa
+import gym
+import tqdm
+
+# import train_diffuser
+from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel
+
+
+config = dict(
+    n_samples=64,
+    horizon=32,
+    num_inference_steps=20,
+    n_guide_steps=2,
+    scale_grad_by_std=True,
+    scale=0.1,
+    eta=0.0,
+    t_grad_cutoff=2,
+    device="cpu",
+)
+
+
+def _run():
+    env_name = "hopper-medium-v2"
+    env = gym.make(env_name)
+
+    # Cuda settings for colab
+    # torch.cuda.get_device_name(0)
+    DEVICE = config["device"]
+
+    # scheduler for the value-guided diffusion loop
+    scheduler = DDPMScheduler(
+        num_train_timesteps=config["num_inference_steps"],
+        beta_schedule="squaredcos_cap_v2",
+        clip_sample=False,
+        variance_type="fixed_small_log",
+    )
+
+    # 3 different pretrained models are available for this task.
+    # The horizon represents the length of trajectories used in training.
+    # network = ValueFunction(training_horizon=horizon, dim=32, dim_mults=(1, 2, 4, 8), transition_dim=14, cond_dim=11)
+
+    network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval()
+    unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval()
+    pipeline = DiffusionPipeline.from_pretrained(
+        "bglick13/hopper-medium-v2-value-function-hor32",
+        value_function=network,
+        unet=unet,
+        scheduler=scheduler,
+        env=env,
+        custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community",
+    )
+    # unet = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE)
+    # network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE)
+
+    # add a batch dimension and repeat for multiple samples
+    # [ observation_dim ] --> [ n_samples x observation_dim ]
+    env.seed(0)
+    obs = env.reset()
+    total_reward = 0
+    total_score = 0
+    T = 1000
+    rollout = [obs.copy()]
+    try:
+        for t in tqdm.tqdm(range(T)):
+            # 1.
Call the policy + # normalize observations for forward passes + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") + + +def run(): + _run() + + +if __name__ == "__main__": + run() diff --git a/examples/diffuser/train_diffuser.py b/examples/diffuser/train_diffuser.py new file mode 100644 index 000000000000..b063a0456d97 --- /dev/null +++ b/examples/diffuser/train_diffuser.py @@ -0,0 +1,312 @@ +import os +import warnings + +import numpy as np +import torch + +import d4rl # noqa +import gym +import mediapy as media +import mujoco_py as mjc +import tqdm +from diffusers import DDPMScheduler, UNet1DModel + + +# Define some helper functions + + +DTYPE = torch.float + + +def normalize(x_in, data, key): + means = data[key].mean(axis=0) + stds = data[key].std(axis=0) + return (x_in - means) / stds + + +def de_normalize(x_in, data, key): + means = data[key].mean(axis=0) + stds = data[key].std(axis=0) + return x_in * stds + means + + +def to_torch(x_in, dtype=None, device="cuda"): + dtype = dtype or DTYPE + device = device + if type(x_in) is dict: + return {k: to_torch(v, dtype, device) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(device).type(dtype) + return torch.tensor(x_in, dtype=dtype, device=device) + + +def reset_x0(x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + +def run_diffusion(x, scheduler, network, unet, conditions, action_dim, config): + y = None + for i in tqdm.tqdm(scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((config["n_samples"],), i, device=config["device"], dtype=torch.long) + # 3. call the sample function + for _ in range(config["n_guide_steps"]): + with torch.enable_grad(): + x.requires_grad_() + y = network(x, timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + if config["scale_grad_by_std"]: + posterior_variance = scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + grad[timesteps < config["t_grad_cutoff"]] = 0 + x = x.detach() + x = x + config["scale"] * grad + x = reset_x0(x, conditions, action_dim) + # with torch.no_grad(): + prev_x = unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + x = scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] + + # 3. [optional] add posterior noise to the sample + if config["eta"] > 0: + noise = torch.randn(x.shape).to(x.device) + posterior_variance = scheduler._get_variance(i) # * noise + # no noise when t == 0 + # NOTE: original implementation missing sqrt on posterior_variance + x = x + int(i > 0) * (0.5 * posterior_variance) * config["eta"] * noise # MJ had as log var, exponentiated + + # 4. 
apply conditions to the trajectory + x = reset_x0(x, conditions, action_dim) + x = to_torch(x, device=config["device"]) + # y = network(x, timesteps).sample + return x, y + + +def to_np(x_in): + if torch.is_tensor(x_in): + x_in = x_in.detach().cpu().numpy() + return x_in + + +# from MJ's Diffuser code +# https://github.com/jannerm/diffuser/blob/76ae49ae85ba1c833bf78438faffdc63b8b4d55d/diffuser/utils/colab.py#L79 +def mkdir(savepath): + """ + returns `True` iff `savepath` is created + """ + if not os.path.exists(savepath): + os.makedirs(savepath) + return True + else: + return False + + +def show_sample(renderer, observations, filename="sample.mp4", savebase="videos"): + """ + observations : [ batch_size x horizon x observation_dim ] + """ + + mkdir(savebase) + savepath = os.path.join(savebase, filename) + + images = [] + for rollout in observations: + # [ horizon x height x width x channels ] + img = renderer._renders(rollout, partial=True) + images.append(img) + + # [ horizon x height x (batch_size * width) x channels ] + images = np.concatenate(images, axis=2) + media.write_video(savepath, images, fps=60) + media.show_video(images, codec="h264", fps=60) + return images + + +# Code adapted from Michael Janner +# source: https://github.com/jannerm/diffuser/blob/main/diffuser/utils/rendering.py + + +def env_map(env_name): + """ + map D4RL dataset names to custom fully-observed + variants for rendering + """ + if "halfcheetah" in env_name: + return "HalfCheetahFullObs-v2" + elif "hopper" in env_name: + return "HopperFullObs-v2" + elif "walker2d" in env_name: + return "Walker2dFullObs-v2" + else: + return env_name + + +def get_image_mask(img): + background = (img == 255).all(axis=-1, keepdims=True) + mask = ~background.repeat(3, axis=-1) + return mask + + +def atmost_2d(x): + while x.ndim > 2: + x = x.squeeze(0) + return x + + +def set_state(env, state): + qpos_dim = env.sim.data.qpos.size + qvel_dim = env.sim.data.qvel.size + if not state.size == qpos_dim + qvel_dim: + warnings.warn( + f"[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, but got state of size {state.size}" + ) + state = state[: qpos_dim + qvel_dim] + + env.set_state(state[:qpos_dim], state[qpos_dim:]) + + +class MuJoCoRenderer: + """ + default mujoco renderer + """ + + def __init__(self, env): + if type(env) is str: + env = env_map(env) + self.env = gym.make(env) + else: + self.env = env + # - 1 because the envs in renderer are fully-observed + # @TODO : clean up + self.observation_dim = np.prod(self.env.observation_space.shape) - 1 + self.action_dim = np.prod(self.env.action_space.shape) + try: + self.viewer = mjc.MjRenderContextOffscreen(self.env.sim) + except: + print("[ utils/rendering ] Warning: could not initialize offscreen renderer") + self.viewer = None + + def pad_observation(self, observation): + state = np.concatenate( + [ + np.zeros(1), + observation, + ] + ) + return state + + def pad_observations(self, observations): + qpos_dim = self.env.sim.data.qpos.size + # xpos is hidden + xvel_dim = qpos_dim - 1 + xvel = observations[:, xvel_dim] + xpos = np.cumsum(xvel) * self.env.dt + states = np.concatenate( + [ + xpos[:, None], + observations, + ], + axis=-1, + ) + return states + + def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None): + if type(dim) == int: + dim = (dim, dim) + + if self.viewer is None: + return np.zeros((*dim, 3), np.uint8) + + if render_kwargs is None: + xpos = observation[0] if not partial else 0 + render_kwargs = {"trackbodyid": 2, 
"distance": 3, "lookat": [xpos, -0.5, 1], "elevation": -20} + + for key, val in render_kwargs.items(): + if key == "lookat": + self.viewer.cam.lookat[:] = val[:] + else: + setattr(self.viewer.cam, key, val) + + if partial: + state = self.pad_observation(observation) + else: + state = observation + + qpos_dim = self.env.sim.data.qpos.size + if not qvel or state.shape[-1] == qpos_dim: + qvel_dim = self.env.sim.data.qvel.size + state = np.concatenate([state, np.zeros(qvel_dim)]) + + set_state(self.env, state) + + self.viewer.render(*dim) + data = self.viewer.read_pixels(*dim, depth=False) + data = data[::-1, :, :] + return data + + def _renders(self, observations, **kwargs): + images = [] + for observation in observations: + img = self.render(observation, **kwargs) + images.append(img) + return np.stack(images, axis=0) + + def renders(self, samples, partial=False, **kwargs): + if partial: + samples = self.pad_observations(samples) + partial = False + + sample_images = self._renders(samples, partial=partial, **kwargs) + + composite = np.ones_like(sample_images[0]) * 255 + + for img in sample_images: + mask = get_image_mask(img) + composite[mask] = img[mask] + + return composite + + def __call__(self, *args, **kwargs): + return self.renders(*args, **kwargs) + + +env_name = "hopper-medium-expert-v2" +env = gym.make(env_name) +data = env.get_dataset() # dataset is only used for normalization in this colab + +# Cuda settings for colab +# torch.cuda.get_device_name(0) +DEVICE = "cpu" +DTYPE = torch.float + +# diffusion model settings +n_samples = 4 # number of trajectories planned via diffusion +horizon = 128 # length of sampled trajectories +state_dim = env.observation_space.shape[0] +action_dim = env.action_space.shape[0] +num_inference_steps = 100 # number of difusion steps + +obs = env.reset() +obs_raw = obs + +# normalize observations for forward passes +obs = normalize(obs, data, "observations") + + +# Two generators for different parts of the diffusion loop to work in colab +generator = torch.Generator(device="cuda") +generator_cpu = torch.Generator(device="cpu") +network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) + +scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2") +optimizer = torch.optim.AdamW( + network.parameters(), + lr=0.001, + betas=(0.95, 0.99), + weight_decay=1e-6, + eps=1e-8, +) + +# TODO: Flesh this out using accelerate library (a la other examples) diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py new file mode 100644 index 000000000000..b154295e9726 --- /dev/null +++ b/scripts/convert_models_diffuser_to_diffusers.py @@ -0,0 +1,77 @@ +import json +import os + +import torch + +from diffusers import UNet1DModel + + +os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True) +os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True) + +os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True) + + +def unet(hor): + if hor == 128: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D") + + elif hor == 32: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 64, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D") + model = 
torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch") + state_dict = model.state_dict() + config = dict( + down_block_types=down_block_types, + block_out_channels=block_out_channels, + up_block_types=up_block_types, + layers_per_block=1, + ) + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin") + with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f: + json.dump(config, f) + + +def value_function(): + config = dict( + in_channels=14, + down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + up_block_types=(), + out_block_type="ValueFunction", + block_out_channels=(32, 64, 128, 256), + layers_per_block=1, + always_downsample=True, + ) + + model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch") + state_dict = model + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + + mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin") + with open("hub/hopper-medium-v2/value_function/config.json", "w") as f: + json.dump(config, f) + + +if __name__ == "__main__": + unet(32) + # unet(128) + value_function() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fa97effaaf0a..7088e560dd66 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index c5d53b2feb4b..b771aaac8467 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -19,6 +19,7 @@ from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel + from .unet_rl import ValueFunction from .vae import AutoencoderKL, VQModel if is_flax_available(): diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 3ede756c9b3d..b720c78b8833 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -17,14 +17,12 @@ import torch import torch.nn as nn -from diffusers.models.resnet import ResidualTemporalBlock1D -from diffusers.models.unet_1d_blocks import get_down_block, get_up_block +from diffusers.models.unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block from ..configuration_utils 
import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import TimestepEmbedding, Timesteps -from .resnet import rearrange_dims @dataclass @@ -62,10 +60,13 @@ def __init__( out_channels: int = 14, down_block_types: Tuple[str] = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), up_block_types: Tuple[str] = ("UpResnetBlock1D", "UpResnetBlock1D"), + mid_block_types: Tuple[str] = ("MidResTemporalBlock1D", "MidResTemporalBlock1D"), + out_block_type: str = "OutConv1DBlock", block_out_channels: Tuple[int] = (32, 128, 256), act_fn: str = "mish", norm_num_groups: int = 8, layers_per_block: int = 1, + always_downsample: bool = False, ): super().__init__() @@ -95,14 +96,30 @@ def __init__( in_channels=input_channel, out_channels=output_channel, temb_channels=block_out_channels[0], - add_downsample=not is_final_block, + add_downsample=not is_final_block or always_downsample, ) self.down_blocks.append(down_block) # mid - self.mid_block1 = ResidualTemporalBlock1D(mid_dim, mid_dim, embed_dim=block_out_channels[0]) - self.mid_block2 = ResidualTemporalBlock1D(mid_dim, mid_dim, embed_dim=block_out_channels[0]) - + self.mid_blocks = nn.ModuleList([]) + for i, mid_block_type in enumerate(mid_block_types): + if always_downsample: + mid_block = get_mid_block( + mid_block_type, + in_channels=mid_dim // (i + 1), + out_channels=mid_dim // ((i + 1) * 2), + embed_dim=block_out_channels[0], + add_downsample=True, + ) + else: + mid_block = get_mid_block( + mid_block_type, + in_channels=mid_dim, + out_channels=mid_dim, + embed_dim=block_out_channels[0], + add_downsample=False, + ) + self.mid_blocks.append(mid_block) # up reversed_block_out_channels = list(reversed(block_out_channels)) for i, up_block_type in enumerate(up_block_types): @@ -123,13 +140,14 @@ def __init__( # out num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) - self.final_conv1d_1 = nn.Conv1d(block_out_channels[0], block_out_channels[0], 5, padding=2) - self.final_conv1d_gn = nn.GroupNorm(num_groups_out, block_out_channels[0]) - if act_fn == "silu": - self.final_conv1d_act = nn.SiLU() - if act_fn == "mish": - self.final_conv1d_act = nn.Mish() - self.final_conv1d_2 = nn.Conv1d(block_out_channels[0], out_channels, 1) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=mid_dim // 4, + ) def forward( self, @@ -166,20 +184,15 @@ def forward( down_block_res_samples.append(res_samples[0]) # 3. mid - sample = self.mid_block1(sample, temb) - sample = self.mid_block2(sample, temb) + for mid_block in self.mid_blocks: + sample = mid_block(sample, temb) # 4. up for up_block in self.up_blocks: sample = up_block(hidden_states=sample, res_hidden_states=down_block_res_samples.pop(), temb=temb) # 5. 
post-process - sample = self.final_conv1d_1(sample) - sample = rearrange_dims(sample) - sample = self.final_conv1d_gn(sample) - sample = rearrange_dims(sample) - sample = self.final_conv1d_act(sample) - sample = self.final_conv1d_2(sample) + sample = self.out_block(sample, temb) if not return_dict: return (sample,) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 40e25fb43afb..a00372faf7d9 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from torch import nn -from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D +from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims class DownResnetBlock1D(nn.Module): @@ -173,6 +173,66 @@ class UpBlock1DNoSkip(nn.Module): pass +class MidResTemporalBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, embed_dim, add_downsample): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + self.resnet = ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim) + + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + else: + self.downsample = nn.Identity() + + def forward(self, sample, temb): + sample = self.resnet(sample, temb) + sample = self.downsample(sample) + return sample + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + if act_fn == "silu": + self.final_conv1d_act = nn.SiLU() + if act_fn == "mish": + self.final_conv1d_act = nn.Mish() + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, sample, t): + sample = self.final_conv1d_1(sample) + sample = rearrange_dims(sample) + sample = self.final_conv1d_gn(sample) + sample = rearrange_dims(sample) + sample = self.final_conv1d_act(sample) + sample = self.final_conv1d_2(sample) + return sample + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim, embed_dim): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + nn.Mish(), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, sample, t): + sample = sample.view(sample.shape[0], -1) + sample = torch.cat((sample, t), dim=-1) + for layer in self.final_block: + sample = layer(sample) + + return sample + + def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): if down_block_type == "DownResnetBlock1D": return DownResnetBlock1D( @@ -195,5 +255,19 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan temb_channels=temb_channels, add_upsample=add_upsample, ) - + elif up_block_type == "Identity": + return nn.Identity() raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block(mid_block_type, in_channels, out_channels, embed_dim, add_downsample): + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D(in_channels, out_channels, embed_dim, add_downsample) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim): + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, 
act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py new file mode 100644 index 000000000000..66822f99b198 --- /dev/null +++ b/src/diffusers/models/unet_rl.py @@ -0,0 +1,135 @@ +# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py +from dataclasses import dataclass +from typing import Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock1D +from diffusers.models.unet_1d_blocks import get_down_block + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .embeddings import TimestepEmbedding, Timesteps + + +@dataclass +class ValueFunctionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch, horizon, 1)`): + Hidden states output. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class ValueFunction(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels=14, + down_block_types: Tuple[str] = ( + "DownResnetBlock1D", + "DownResnetBlock1D", + "DownResnetBlock1D", + "DownResnetBlock1D", + ), + block_out_channels: Tuple[int] = (32, 64, 128, 256), + act_fn: str = "mish", + norm_num_groups: int = 8, + layers_per_block: int = 1, + ): + super().__init__() + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(num_channels=block_out_channels[0], flip_sin_to_cos=False, downscale_freq_shift=1) + self.time_mlp = TimestepEmbedding( + channel=block_out_channels[0], time_embed_dim=time_embed_dim, act_fn="mish", out_dim=block_out_channels[0] + ) + + self.blocks = nn.ModuleList([]) + mid_dim = block_out_channels[-1] + + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + + down_block_type = down_block_types[i] + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=True, + ) + self.blocks.append(down_block) + + ## + self.mid_block1 = ResidualTemporalBlock1D(mid_dim, mid_dim // 2, embed_dim=block_out_channels[0]) + self.mid_down1 = Downsample1D(mid_dim // 2, use_conv=True) + ## + self.mid_block2 = ResidualTemporalBlock1D(mid_dim // 2, mid_dim // 4, embed_dim=block_out_channels[0]) + self.mid_down2 = Downsample1D(mid_dim // 4, use_conv=True) + ## + fc_dim = mid_dim // 4 + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + block_out_channels[0], fc_dim // 2), + nn.Mish(), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[ValueFunctionOutput, Tuple]: + """r + Args: + sample (`torch.FloatTensor`): (batch, horizon, obs_dimension + action_dimension) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): batch (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_rl.ValueFunctionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_rl.ValueFunctionOutput`] or `tuple`: [`~models.unet_rl.ValueFunctionOutput`] if + `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor. + """ + sample = sample.permute(0, 2, 1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + t = self.time_proj(timesteps) + t = self.time_mlp(t) + down_block_res_samples = [] + + # 2. down + for downsample_block in self.blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=t) + down_block_res_samples.append(res_samples[0]) + + # 3. mid + sample = self.mid_block1(sample, t) + sample = self.mid_down1(sample) + sample = self.mid_block2(sample, t) + sample = self.mid_down2(sample) + + sample = sample.view(sample.shape[0], -1) + sample = torch.cat((sample, t), dim=-1) + for layer in self.final_block: + sample = layer(sample) + + if not return_dict: + return (sample,) + + return ValueFunctionOutput(sample=sample) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 04c92904a660..06596bd6091f 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -200,6 +200,7 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = torch.exp(0.5 * variance) elif variance_type == "fixed_large": variance = self.betas[t] elif variance_type == "fixed_large_log": @@ -283,7 +284,10 @@ def step( noise = torch.randn( model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator ).to(model_output.device) - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * noise + else: + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index e1dbdfaa4611..55f373af8a9b 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -20,7 +20,7 @@ import torch -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin @@ -524,3 +524,86 @@ def test_output_pretrained(self): def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass + + +class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): + model_class = ValueFunction + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 1) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 128, 256), + "in_channels": 14, + "out_channels": 14, + } + inputs_dict = 
self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + unet, loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True + ) + value_function, vf_loading_info = ValueFunction.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + ) + self.assertIsNotNone(unet) + self.assertEqual(len(loading_info["missing_keys"]), 0) + self.assertIsNotNone(value_function) + self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + + unet.to(torch_device) + value_function.to(torch_device) + image = value_function(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + value_function, vf_loading_info = ValueFunction.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = value_function.in_channels + seq_len = 14 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = value_function(noise, time_step).sample + + # fmt: off + expected_output_slice = torch.tensor([207.0272] * seq_len) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass From 3acddb56e67e1fbcbac3e243a651daad93df3239 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Thu, 20 Oct 2022 20:05:20 -0700 Subject: [PATCH 29/63] update post merge of scripts --- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 1 - src/diffusers/models/unet_rl.py | 135 ------------------------------- tests/test_models_unet.py | 85 +------------------ 4 files changed, 2 insertions(+), 221 deletions(-) delete mode 100644 src/diffusers/models/unet_rl.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7088e560dd66..fa97effaaf0a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction, VQModel + from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b771aaac8467..c5d53b2feb4b 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -19,7 +19,6 @@ from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel - from .unet_rl import ValueFunction from .vae import AutoencoderKL, VQModel if is_flax_available(): diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py deleted file mode 100644 index 66822f99b198..000000000000 --- a/src/diffusers/models/unet_rl.py +++ /dev/null @@ -1,135 +0,0 @@ -# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py -from dataclasses import dataclass -from typing import Tuple, Union - -import torch -import torch.nn as nn - -from diffusers.models.resnet import Downsample1D, ResidualTemporalBlock1D -from diffusers.models.unet_1d_blocks import 
get_down_block - -from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from ..utils import BaseOutput -from .embeddings import TimestepEmbedding, Timesteps - - -@dataclass -class ValueFunctionOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch, horizon, 1)`): - Hidden states output. Output of last layer of model. - """ - - sample: torch.FloatTensor - - -class ValueFunction(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - in_channels=14, - down_block_types: Tuple[str] = ( - "DownResnetBlock1D", - "DownResnetBlock1D", - "DownResnetBlock1D", - "DownResnetBlock1D", - ), - block_out_channels: Tuple[int] = (32, 64, 128, 256), - act_fn: str = "mish", - norm_num_groups: int = 8, - layers_per_block: int = 1, - ): - super().__init__() - time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps(num_channels=block_out_channels[0], flip_sin_to_cos=False, downscale_freq_shift=1) - self.time_mlp = TimestepEmbedding( - channel=block_out_channels[0], time_embed_dim=time_embed_dim, act_fn="mish", out_dim=block_out_channels[0] - ) - - self.blocks = nn.ModuleList([]) - mid_dim = block_out_channels[-1] - - output_channel = in_channels - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - - down_block_type = down_block_types[i] - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=block_out_channels[0], - add_downsample=True, - ) - self.blocks.append(down_block) - - ## - self.mid_block1 = ResidualTemporalBlock1D(mid_dim, mid_dim // 2, embed_dim=block_out_channels[0]) - self.mid_down1 = Downsample1D(mid_dim // 2, use_conv=True) - ## - self.mid_block2 = ResidualTemporalBlock1D(mid_dim // 2, mid_dim // 4, embed_dim=block_out_channels[0]) - self.mid_down2 = Downsample1D(mid_dim // 4, use_conv=True) - ## - fc_dim = mid_dim // 4 - self.final_block = nn.ModuleList( - [ - nn.Linear(fc_dim + block_out_channels[0], fc_dim // 2), - nn.Mish(), - nn.Linear(fc_dim // 2, 1), - ] - ) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - return_dict: bool = True, - ) -> Union[ValueFunctionOutput, Tuple]: - """r - Args: - sample (`torch.FloatTensor`): (batch, horizon, obs_dimension + action_dimension) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): batch (batch) timesteps - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_rl.ValueFunctionOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_rl.ValueFunctionOutput`] or `tuple`: [`~models.unet_rl.ValueFunctionOutput`] if - `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - sample = sample.permute(0, 2, 1) - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - t = self.time_proj(timesteps) - t = self.time_mlp(t) - down_block_res_samples = [] - - # 2. down - for downsample_block in self.blocks: - sample, res_samples = downsample_block(hidden_states=sample, temb=t) - down_block_res_samples.append(res_samples[0]) - - # 3. 
mid - sample = self.mid_block1(sample, t) - sample = self.mid_down1(sample) - sample = self.mid_block2(sample, t) - sample = self.mid_down2(sample) - - sample = sample.view(sample.shape[0], -1) - sample = torch.cat((sample, t), dim=-1) - for layer in self.final_block: - sample = layer(sample) - - if not return_dict: - return (sample,) - - return ValueFunctionOutput(sample=sample) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 55f373af8a9b..e1dbdfaa4611 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -20,7 +20,7 @@ import torch -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin @@ -524,86 +524,3 @@ def test_output_pretrained(self): def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass - - -class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): - model_class = ValueFunction - - @property - def dummy_input(self): - batch_size = 4 - num_features = 14 - seq_len = 16 - - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - - return {"sample": noise, "timestep": time_step} - - @property - def input_shape(self): - return (4, 14, 16) - - @property - def output_shape(self): - return (4, 14, 1) - - def test_ema_training(self): - pass - - def test_training(self): - pass - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": (32, 64, 128, 256), - "in_channels": 14, - "out_channels": 14, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - unet, loading_info = UNet1DModel.from_pretrained( - "bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True - ) - value_function, vf_loading_info = ValueFunction.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True - ) - self.assertIsNotNone(unet) - self.assertEqual(len(loading_info["missing_keys"]), 0) - self.assertIsNotNone(value_function) - self.assertEqual(len(vf_loading_info["missing_keys"]), 0) - - unet.to(torch_device) - value_function.to(torch_device) - image = value_function(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - value_function, vf_loading_info = ValueFunction.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True - ) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - num_features = value_function.in_channels - seq_len = 14 - noise = torch.randn((1, seq_len, num_features)).permute( - 0, 2, 1 - ) # match original, we can update values and remove - time_step = torch.full((num_features,), 0) - - with torch.no_grad(): - output = value_function(noise, time_step).sample - - # fmt: off - expected_output_slice = torch.tensor([207.0272] * seq_len) - # fmt: on - self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) - - def test_forward_with_norm_groups(self): - # Not implemented yet for this UNet - pass From 713e8f27172fe708703342a4b3f67802172845dd Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 24 Oct 2022 09:47:55 -0700 Subject: [PATCH 30/63] add mdiblock / outblock architecture --- src/diffusers/models/unet_1d.py | 39 +++++------ 
src/diffusers/models/unet_1d_blocks.py | 94 +++++++++++++++++++------- 2 files changed, 85 insertions(+), 48 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index b720c78b8833..8f74926da505 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -60,7 +60,7 @@ def __init__( out_channels: int = 14, down_block_types: Tuple[str] = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), up_block_types: Tuple[str] = ("UpResnetBlock1D", "UpResnetBlock1D"), - mid_block_types: Tuple[str] = ("MidResTemporalBlock1D", "MidResTemporalBlock1D"), + mid_block_type: Tuple[str] = "MidResTemporalBlock1D", out_block_type: str = "OutConv1DBlock", block_out_channels: Tuple[int] = (32, 128, 256), act_fn: str = "mish", @@ -79,7 +79,9 @@ def __init__( ) self.down_blocks = nn.ModuleList([]) + self.mid_block = None self.up_blocks = nn.ModuleList([]) + self.out_block = None mid_dim = block_out_channels[-1] # down @@ -101,25 +103,15 @@ def __init__( self.down_blocks.append(down_block) # mid - self.mid_blocks = nn.ModuleList([]) - for i, mid_block_type in enumerate(mid_block_types): - if always_downsample: - mid_block = get_mid_block( - mid_block_type, - in_channels=mid_dim // (i + 1), - out_channels=mid_dim // ((i + 1) * 2), - embed_dim=block_out_channels[0], - add_downsample=True, - ) - else: - mid_block = get_mid_block( - mid_block_type, - in_channels=mid_dim, - out_channels=mid_dim, - embed_dim=block_out_channels[0], - add_downsample=False, - ) - self.mid_blocks.append(mid_block) + self.mid_block = get_mid_block( + mid_block_type, + in_channels=mid_dim, + out_channels=mid_dim, + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=always_downsample, + ) + # up reversed_block_out_channels = list(reversed(block_out_channels)) for i, up_block_type in enumerate(up_block_types): @@ -184,15 +176,16 @@ def forward( down_block_res_samples.append(res_samples[0]) # 3. mid - for mid_block in self.mid_blocks: - sample = mid_block(sample, temb) + if self.mid_block: + sample = self.mid_block(sample, temb) # 4. up for up_block in self.up_blocks: sample = up_block(hidden_states=sample, res_hidden_states=down_block_res_samples.pop(), temb=temb) # 5. 
post-process
-        sample = self.out_block(sample, temb)
+        if self.out_block:
+            sample = self.out_block(sample, temb)
 
         if not return_dict:
             return (sample,)
diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py
index a00372faf7d9..e1a1ac4a8f0c 100644
--- a/src/diffusers/models/unet_1d_blocks.py
+++ b/src/diffusers/models/unet_1d_blocks.py
@@ -47,7 +47,7 @@ def __init__(
         if groups_out is None:
             groups_out = groups
 
-        # there will always be at least one resenet
+        # there will always be at least one resnet
         resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
 
         for _ in range(num_layers):
@@ -111,7 +111,7 @@ def __init__(
         if groups_out is None:
             groups_out = groups
 
-        # there will always be at least one resenet
+        # there will always be at least one resnet
         resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
 
         for _ in range(num_layers):
@@ -174,22 +174,60 @@ class UpBlock1DNoSkip(nn.Module):
 
 
 class MidResTemporalBlock1D(nn.Module):
-    def __init__(self, in_channels, out_channels, embed_dim, add_downsample):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        embed_dim,
+        num_layers: int = 1,
+        add_downsample: bool = False,
+        add_upsample: bool = False,
+        non_linearity=None,
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.add_downsample = add_downsample
 
-        self.resnet = ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)
+        # there will always be at least one resnet
+        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
+
+        for _ in range(num_layers):
+            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if non_linearity == "swish":
+            self.nonlinearity = lambda x: F.silu(x)
+        elif non_linearity == "mish":
+            self.nonlinearity = nn.Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()
+        else:
+            self.nonlinearity = None
+
+        self.upsample = None
+        if add_upsample:
+            self.upsample = Upsample1D(out_channels, use_conv=True)
+
+        self.downsample = None
         if add_downsample:
             self.downsample = Downsample1D(out_channels, use_conv=True)
-        else:
-            self.downsample = nn.Identity()
 
-    def forward(self, sample, temb):
-        sample = self.resnet(sample, temb)
-        sample = self.downsample(sample)
-        return sample
+        if self.upsample and self.downsample:
+            raise ValueError("Block cannot downsample and upsample")
+
+    def forward(self, hidden_states, temb):
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for resnet in self.resnets[1:]:
+            hidden_states = resnet(hidden_states, temb)
+
+        if self.upsample:
+            hidden_states = self.upsample(hidden_states)
+        if self.downsample:
+            hidden_states = self.downsample(hidden_states)
+
+        return hidden_states
 
 
 class OutConv1DBlock(nn.Module):
@@ -203,14 +241,14 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
             self.final_conv1d_act = nn.Mish()
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
 
-    def forward(self, sample, t):
-        sample = self.final_conv1d_1(sample)
-        sample = rearrange_dims(sample)
-        sample = self.final_conv1d_gn(sample)
-        sample = rearrange_dims(sample)
-        sample = self.final_conv1d_act(sample)
-        sample = self.final_conv1d_2(sample)
-        return sample
+    def forward(self, hidden_states, temb=None):
+        hidden_states = self.final_conv1d_1(hidden_states)
+        hidden_states = rearrange_dims(hidden_states)
+        hidden_states = self.final_conv1d_gn(hidden_states)
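+        # GroupNorm above ran on the (batch, channels, 1, horizon) view produced by
+        # rearrange_dims; the next call restores the (batch, channels, horizon) layout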
+ hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states class OutValueFunctionBlock(nn.Module): @@ -224,13 +262,13 @@ def __init__(self, fc_dim, embed_dim): ] ) - def forward(self, sample, t): - sample = sample.view(sample.shape[0], -1) - sample = torch.cat((sample, t), dim=-1) + def forward(self, hidden_states, temb): + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) for layer in self.final_block: - sample = layer(sample) + hidden_states = layer(hidden_states) - return sample + return hidden_states def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): @@ -260,9 +298,15 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan raise ValueError(f"{up_block_type} does not exist.") -def get_mid_block(mid_block_type, in_channels, out_channels, embed_dim, add_downsample): +def get_mid_block(mid_block_type, num_layers, in_channels, out_channels, embed_dim, add_downsample): if mid_block_type == "MidResTemporalBlock1D": - return MidResTemporalBlock1D(in_channels, out_channels, embed_dim, add_downsample) + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) raise ValueError(f"{mid_block_type} does not exist.") From 268ebdf4d84ec6b7432aa77e025827c25a628f7b Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Mon, 24 Oct 2022 13:17:41 -0400 Subject: [PATCH 31/63] Pipeline cleanup (#947) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! 
* turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review * clean up comments * convert older script to using pipeline and add readme * rename scripts * style, update tests * delete unet rl model file * remove imports in src Co-authored-by: Nathan Lambert --- examples/community/pipeline.py | 21 ++++- examples/community/value_guided_diffuser.py | 13 ++- examples/diffuser/README.md | 16 ++++ .../diffuser/run_diffuser_gen_trajectories.py | 79 +++++++++++++++++ examples/diffuser/run_diffuser_locomotion.py | 83 ++++++++++++++++++ tests/test_models_unet.py | 85 ++++++++++++++++++- 6 files changed, 291 insertions(+), 6 deletions(-) create mode 100644 examples/diffuser/README.md create mode 100644 examples/diffuser/run_diffuser_gen_trajectories.py create mode 100644 examples/diffuser/run_diffuser_locomotion.py diff --git a/examples/community/pipeline.py b/examples/community/pipeline.py index 7e3f2b832b1f..85e359c5c4c9 100644 --- a/examples/community/pipeline.py +++ b/examples/community/pipeline.py @@ -1,3 +1,4 @@ +import numpy as np import torch import tqdm @@ -59,7 +60,6 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale): for i in tqdm.tqdm(self.scheduler.timesteps): # create batch of timesteps to pass into model timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) - # 3. call the sample function for _ in range(n_guide_steps): with torch.enable_grad(): x.requires_grad_() @@ -76,24 +76,39 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale): prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] - # 4. 
apply conditions to the trajectory + # apply conditions to the trajectory x = self.reset_x0(x, conditions, self.action_dim) x = self.to_torch(x) return x, y def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension obs = self.normalize(obs, "observations") obs = obs[None].repeat(batch_size, axis=0) + conditions = {0: self.to_torch(obs)} shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) x1 = torch.randn(shape, device=self.unet.device) x = self.reset_x0(x1, conditions, self.action_dim) x = self.to_torch(x) + + # run the diffusion process x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value sorted_idx = y.argsort(0, descending=True).squeeze() sorted_values = x[sorted_idx] actions = sorted_values[:, :, : self.action_dim] actions = actions.detach().cpu().numpy() denorm_actions = self.de_normalize(actions, key="actions") - denorm_actions = denorm_actions[0, 0] + + # select the action with the highest value + if y is not None: + selected_index = 0 + else: + # if we didn't run value guiding, select a random action + selected_index = np.random.randint(0, batch_size) + denorm_actions = denorm_actions[selected_index, 0] return denorm_actions diff --git a/examples/community/value_guided_diffuser.py b/examples/community/value_guided_diffuser.py index 7e3f2b832b1f..6b28e868eddd 100644 --- a/examples/community/value_guided_diffuser.py +++ b/examples/community/value_guided_diffuser.py @@ -59,7 +59,6 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale): for i in tqdm.tqdm(self.scheduler.timesteps): # create batch of timesteps to pass into model timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) - # 3. call the sample function for _ in range(n_guide_steps): with torch.enable_grad(): x.requires_grad_() @@ -76,24 +75,34 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale): prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] - # 4. 
apply conditions to the trajectory + # apply conditions to the trajectory x = self.reset_x0(x, conditions, self.action_dim) x = self.to_torch(x) return x, y def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension obs = self.normalize(obs, "observations") obs = obs[None].repeat(batch_size, axis=0) + conditions = {0: self.to_torch(obs)} shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) x1 = torch.randn(shape, device=self.unet.device) x = self.reset_x0(x1, conditions, self.action_dim) x = self.to_torch(x) + + # run the diffusion process x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value sorted_idx = y.argsort(0, descending=True).squeeze() sorted_values = x[sorted_idx] actions = sorted_values[:, :, : self.action_dim] actions = actions.detach().cpu().numpy() denorm_actions = self.de_normalize(actions, key="actions") + + # select the action with the highest value denorm_actions = denorm_actions[0, 0] return denorm_actions diff --git a/examples/diffuser/README.md b/examples/diffuser/README.md new file mode 100644 index 000000000000..464ccd57af85 --- /dev/null +++ b/examples/diffuser/README.md @@ -0,0 +1,16 @@ +# Overview + +These examples show how to run [Diffuser](https://arxiv.org/pdf/2205.09991.pdf) in Diffusers. There are two scripts, `run_diffuser_value_guided.py` and `run_diffuser.py`. + +You will need some RL-specific requirements to run the examples: + +``` +pip install -f https://download.pytorch.org/whl/torch_stable.html \ + free-mujoco-py \ + einops \ + gym \ + protobuf==3.20.1 \ + git+https://github.com/rail-berkeley/d4rl.git \ + mediapy \ + Pillow==9.0.0 +``` diff --git a/examples/diffuser/run_diffuser_gen_trajectories.py b/examples/diffuser/run_diffuser_gen_trajectories.py new file mode 100644 index 000000000000..f4c86635c652 --- /dev/null +++ b/examples/diffuser/run_diffuser_gen_trajectories.py @@ -0,0 +1,79 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=0, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +def _run(): + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + DEVICE = config["device"] + + scheduler = DDPMScheduler( + num_train_timesteps=config["num_inference_steps"], + beta_schedule="squaredcos_cap_v2", + clip_sample=False, + variance_type="fixed_small_log", + ) + network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval() + unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval() + pipeline = DiffusionPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + value_function=network, + unet=unet, + scheduler=scheduler, + env=env, + custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # Call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + 
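The guidance loop in `run_diffusion` above is the core of this PR: at each denoising step it takes `n_guide_steps` gradient steps on the value model before applying the regular scheduler update. Below is a standalone sketch of that inner step; the names mirror the pipeline attributes, and config details such as `scale_grad_by_std` and `t_grad_cutoff` are omitted, so treat it as an illustration rather than the shipped implementation.

```python
import torch

def guided_step(x, i, timesteps, unet, value_function, scheduler, n_guide_steps=2, scale=0.1):
    # nudge the trajectory batch (batch, horizon, transition_dim) toward higher predicted return
    for _ in range(n_guide_steps):
        with torch.enable_grad():
            x.requires_grad_()
            # the value model scores channels-first trajectories, hence the permute
            y = value_function(x.permute(0, 2, 1), timesteps).sample
            grad = torch.autograd.grad([y.sum()], [x])[0]
        x = x.detach() + scale * grad
    # one ordinary denoising step on the guided sample
    prev_x = unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
    return scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"], y
```
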
# update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") + + +def run(): + _run() + + +if __name__ == "__main__": + run() diff --git a/examples/diffuser/run_diffuser_locomotion.py b/examples/diffuser/run_diffuser_locomotion.py new file mode 100644 index 000000000000..1b4351095d3b --- /dev/null +++ b/examples/diffuser/run_diffuser_locomotion.py @@ -0,0 +1,83 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=2, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +def _run(): + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + # Cuda settings for colab + # torch.cuda.get_device_name(0) + DEVICE = config["device"] + + # Two generators for different parts of the diffusion loop to work in colab + scheduler = DDPMScheduler( + num_train_timesteps=config["num_inference_steps"], + beta_schedule="squaredcos_cap_v2", + clip_sample=False, + variance_type="fixed_small_log", + ) + + network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval() + unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval() + pipeline = DiffusionPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + value_function=network, + unet=unet, + scheduler=scheduler, + env=env, + custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") + + +def run(): + _run() + + +if __name__ == "__main__": + run() diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index e1dbdfaa4611..1ff092b3ce78 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -20,7 +20,7 @@ import torch -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin @@ -524,3 +524,86 @@ def test_output_pretrained(self): def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass + + +class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, 
seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 1) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 128, 256), + "in_channels": 14, + "out_channels": 14, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + unet, loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True + ) + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + ) + self.assertIsNotNone(unet) + self.assertEqual(len(loading_info["missing_keys"]), 0) + self.assertIsNotNone(value_function) + self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + + unet.to(torch_device) + value_function.to(torch_device) + image = value_function(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = value_function.in_channels + seq_len = 14 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = value_function(noise, time_step).sample + + # fmt: off + expected_output_slice = torch.tensor([207.0272] * seq_len) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass From daa05fb66f70f63f6e8b34fecb6f48f80a8f995c Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 24 Oct 2022 10:25:39 -0700 Subject: [PATCH 32/63] Update src/diffusers/models/unet_1d_blocks.py --- src/diffusers/models/unet_1d_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index e1a1ac4a8f0c..1beea2b123ac 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -207,7 +207,7 @@ def __init__( self.nonlinearity = None self.upsample = None - if add_downsample: + if add_upsample: self.upsample = Downsample1D(out_channels, use_conv=True) self.downsample = None From ea5f2310c74a32a43d4ac564ccf1532de7baa970 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 24 Oct 2022 10:27:29 -0700 Subject: [PATCH 33/63] Update tests/test_models_unet.py --- tests/test_models_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 1ff092b3ce78..d6578955f295 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -20,7 +20,7 @@ import torch -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin From 
4f7a3a43e49ddeefa392ff7e84e9480f89f57891 Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Mon, 24 Oct 2022 13:52:41 -0400 Subject: [PATCH 34/63] RL Cleanup v2 (#965) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! * turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review * clean up comments * convert older script to using pipeline and add readme * rename scripts * style, update tests * delete unet rl model file * remove imports in src * add specific vf block and update tests * style * Update tests/test_models_unet.py Co-authored-by: Nathan Lambert --- .../convert_models_diffuser_to_diffusers.py | 1 + src/diffusers/models/unet_1d_blocks.py | 22 +++++++++++++++++++ tests/test_models_unet.py | 16 ++++++++------ 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py index b154295e9726..4b4608358c17 100644 --- a/scripts/convert_models_diffuser_to_diffusers.py +++ b/scripts/convert_models_diffuser_to_diffusers.py @@ -49,6 +49,7 @@ def value_function(): down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), up_block_types=(), out_block_type="ValueFunction", + mid_block_type="ValueFunctionMidBlock1D", block_out_channels=(32, 64, 128, 256), layers_per_block=1, always_downsample=True, diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 1beea2b123ac..0788cc1e76e5 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -173,6 +173,26 @@ class UpBlock1DNoSkip(nn.Module): pass +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, embed_dim): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + + self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x, temb=None): + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + class MidResTemporalBlock1D(nn.Module): def __init__( self, @@ -307,6 +327,8 @@ def get_mid_block(mid_block_type, num_layers, in_channels, out_channels, embed_d embed_dim=embed_dim, add_downsample=add_downsample, ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) raise ValueError(f"{mid_block_type} does not exist.") diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index d6578955f295..03ed28f5d442 100644 --- 
a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -22,6 +22,7 @@ from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel from diffusers.utils import floats_tensor, slow, torch_device +from regex import subf from .test_modeling_common import ModelTesterMixin @@ -489,7 +490,7 @@ def prepare_init_args_and_inputs_for_common(self): def test_from_pretrained_hub(self): model, loading_info = UNet1DModel.from_pretrained( - "fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" ) self.assertIsNotNone(model) self.assertEqual(len(loading_info["missing_keys"]), 0) @@ -500,7 +501,7 @@ def test_from_pretrained_hub(self): assert image is not None, "Make sure output is not None" def test_output_pretrained(self): - model = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128") + model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet") torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) @@ -517,7 +518,8 @@ def test_output_pretrained(self): output_slice = output[0, -3:, -3:].flatten() # fmt: off - expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584]) + expected_output_slice = torch.tensor([-2.137172 , 1.1426016 , 0.3688687 , -0.766922 , 0.7303146 , + 0.11038864, -0.4760633 , 0.13270172, 0.02591348]) # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) @@ -565,10 +567,10 @@ def prepare_init_args_and_inputs_for_common(self): def test_from_pretrained_hub(self): unet, loading_info = UNet1DModel.from_pretrained( - "bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" ) value_function, vf_loading_info = UNet1DModel.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" ) self.assertIsNotNone(unet) self.assertEqual(len(loading_info["missing_keys"]), 0) @@ -583,7 +585,7 @@ def test_from_pretrained_hub(self): def test_output_pretrained(self): value_function, vf_loading_info = UNet1DModel.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" ) torch.manual_seed(0) if torch.cuda.is_available(): @@ -600,7 +602,7 @@ def test_output_pretrained(self): output = value_function(noise, time_step).sample # fmt: off - expected_output_slice = torch.tensor([207.0272] * seq_len) + expected_output_slice = torch.tensor([165.25] * seq_len) # fmt: on self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) From d90b8b1bd60fbf5c858746849bde62e95f428ac2 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 24 Oct 2022 10:56:04 -0700 Subject: [PATCH 35/63] fix quality in tests --- tests/test_models_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 03ed28f5d442..8608eec166a4 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -518,8 +518,8 @@ def test_output_pretrained(self): output_slice = output[0, -3:, -3:].flatten() # fmt: off - expected_output_slice = torch.tensor([-2.137172 , 1.1426016 , 0.3688687 , -0.766922 , 
0.7303146 , - 0.11038864, -0.4760633 , 0.13270172, 0.02591348]) + expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, + 0.11038864, -0.4760633, 0.13270172, 0.02591348]) # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) From ad8b6cf1123ca6b5bebe7ab4af866ea54d00539a Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 24 Oct 2022 11:00:29 -0700 Subject: [PATCH 36/63] fix quality style, split test file --- tests/test_models_unet_1d.py | 185 ++++++++++++++++++ ..._models_unet.py => test_models_unet_2d.py} | 163 +-------------- 2 files changed, 186 insertions(+), 162 deletions(-) create mode 100644 tests/test_models_unet_1d.py rename tests/{test_models_unet.py => test_models_unet_2d.py} (74%) diff --git a/tests/test_models_unet_1d.py b/tests/test_models_unet_1d.py new file mode 100644 index 000000000000..f50bb8785eae --- /dev/null +++ b/tests/test_models_unet_1d.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import UNet1DModel +from diffusers.utils import floats_tensor, torch_device + +from .test_modeling_common import ModelTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 16) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 128, 256), + "in_channels": 14, + "out_channels": 14, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet") + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = model.in_channels + seq_len = 16 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = model(noise, time_step).sample.permute(0, 2, 1) 
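Stepping back from the test reshuffling for a moment: the `ValueFunctionMidBlock1D` these value-function checkpoints depend on (added two patches up) halves channels and horizon twice on the way to the scalar head. A toy, shape-only sketch that swaps plain convolutions in for `ResidualTemporalBlock1D` and `Downsample1D`:

```python
import torch
import torch.nn as nn

class ToyValueMid(nn.Module):
    # stand-ins only: the real block uses ResidualTemporalBlock1D + Downsample1D
    def __init__(self, in_channels):
        super().__init__()
        self.res1 = nn.Conv1d(in_channels, in_channels // 2, 3, padding=1)
        self.down1 = nn.Conv1d(in_channels // 2, in_channels // 2, 3, stride=2, padding=1)
        self.res2 = nn.Conv1d(in_channels // 2, in_channels // 4, 3, padding=1)
        self.down2 = nn.Conv1d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1)

    def forward(self, x):
        return self.down2(self.res2(self.down1(self.res1(x))))

x = torch.randn(1, 256, 8)        # (batch, channels, horizon), 256 = block_out_channels[-1]
print(ToyValueMid(256)(x).shape)  # torch.Size([1, 64, 2])
```
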
+ + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass + + +class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 1) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 128, 256), + "in_channels": 14, + "out_channels": 14, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + unet, loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" + ) + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" + ) + self.assertIsNotNone(unet) + self.assertEqual(len(loading_info["missing_keys"]), 0) + self.assertIsNotNone(value_function) + self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + + unet.to(torch_device) + value_function.to(torch_device) + image = value_function(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" + ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = value_function.in_channels + seq_len = 14 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = value_function(noise, time_step).sample + + # fmt: off + expected_output_slice = torch.tensor([165.25] * seq_len) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass diff --git a/tests/test_models_unet.py b/tests/test_models_unet_2d.py similarity index 74% rename from tests/test_models_unet.py rename to tests/test_models_unet_2d.py index 8608eec166a4..b2f16aef5825 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet_2d.py @@ -20,9 +20,8 @@ import torch -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel +from diffusers import UNet2DConditionModel, UNet2DModel from diffusers.utils import floats_tensor, slow, torch_device -from regex import subf from .test_modeling_common import ModelTesterMixin @@ -449,163 +448,3 @@ def test_output_pretrained_ve_large(self): def test_forward_with_norm_groups(self): # not required for this model pass - - -class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): - 
model_class = UNet1DModel - - @property - def dummy_input(self): - batch_size = 4 - num_features = 14 - seq_len = 16 - - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - - return {"sample": noise, "timestep": time_step} - - @property - def input_shape(self): - return (4, 14, 16) - - @property - def output_shape(self): - return (4, 14, 16) - - def test_ema_training(self): - pass - - def test_training(self): - pass - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": (32, 128, 256), - "in_channels": 14, - "out_channels": 14, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNet1DModel.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" - ) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet") - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - num_features = model.in_channels - seq_len = 16 - noise = torch.randn((1, seq_len, num_features)).permute( - 0, 2, 1 - ) # match original, we can update values and remove - time_step = torch.full((num_features,), 0) - - with torch.no_grad(): - output = model(noise, time_step).sample.permute(0, 2, 1) - - output_slice = output[0, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, - 0.11038864, -0.4760633, 0.13270172, 0.02591348]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) - - def test_forward_with_norm_groups(self): - # Not implemented yet for this UNet - pass - - -class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNet1DModel - - @property - def dummy_input(self): - batch_size = 4 - num_features = 14 - seq_len = 16 - - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) - time_step = torch.tensor([10] * batch_size).to(torch_device) - - return {"sample": noise, "timestep": time_step} - - @property - def input_shape(self): - return (4, 14, 16) - - @property - def output_shape(self): - return (4, 14, 1) - - def test_ema_training(self): - pass - - def test_training(self): - pass - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": (32, 64, 128, 256), - "in_channels": 14, - "out_channels": 14, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - unet, loading_info = UNet1DModel.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" - ) - value_function, vf_loading_info = UNet1DModel.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" - ) - self.assertIsNotNone(unet) - self.assertEqual(len(loading_info["missing_keys"]), 0) - self.assertIsNotNone(value_function) - self.assertEqual(len(vf_loading_info["missing_keys"]), 0) - - unet.to(torch_device) - value_function.to(torch_device) - image = value_function(**self.dummy_input) - - assert image is not None, 
"Make sure output is not None" - - def test_output_pretrained(self): - value_function, vf_loading_info = UNet1DModel.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" - ) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - num_features = value_function.in_channels - seq_len = 14 - noise = torch.randn((1, seq_len, num_features)).permute( - 0, 2, 1 - ) # match original, we can update values and remove - time_step = torch.full((num_features,), 0) - - with torch.no_grad(): - output = value_function(noise, time_step).sample - - # fmt: off - expected_output_slice = torch.tensor([165.25] * seq_len) - # fmt: on - self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) - - def test_forward_with_norm_groups(self): - # Not implemented yet for this UNet - pass From 99b2c815216860ca0be8752caaad0a9f17eb75df Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 24 Oct 2022 11:37:00 -0700 Subject: [PATCH 37/63] fix checks / tests --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++ tests/test_models_unet_1d.py | 33 +++++++++++++++++++------ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index ee748f5b1d6a..ed3750e29991 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -34,6 +34,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UNet1DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/test_models_unet_1d.py b/tests/test_models_unet_1d.py index f50bb8785eae..2ff9315e0bf5 100644 --- a/tests/test_models_unet_1d.py +++ b/tests/test_models_unet_1d.py @@ -124,6 +124,23 @@ def input_shape(self): def output_shape(self): return (4, 14, 1) + def test_output(self): + # UNetRL is a value-function is different output shape + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_ema_training(self): pass @@ -132,26 +149,26 @@ def test_training(self): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "block_out_channels": (32, 64, 128, 256), "in_channels": 14, - "out_channels": 14, + "out_channels": 1, + "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"], + "up_block_types": [], + "out_block_type": "ValueFunction", + "mid_block_type": "ValueFunctionMidBlock1D", + "block_out_channels": [32, 64, 128, 256], + "layers_per_block": 1, + "always_downsample": True, } inputs_dict = self.dummy_input return init_dict, inputs_dict def test_from_pretrained_hub(self): - unet, loading_info = UNet1DModel.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", 
output_loading_info=True, subfolder="unet" - ) value_function, vf_loading_info = UNet1DModel.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" ) - self.assertIsNotNone(unet) - self.assertEqual(len(loading_info["missing_keys"]), 0) self.assertIsNotNone(value_function) self.assertEqual(len(vf_loading_info["missing_keys"]), 0) - unet.to(torch_device) value_function.to(torch_device) image = value_function(**self.dummy_input) From de4b6e4672cbadd321351f191b6fa4219da361c1 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 25 Oct 2022 15:49:50 -0700 Subject: [PATCH 38/63] make timesteps closer to main --- src/diffusers/models/embeddings.py | 4 +-- src/diffusers/models/unet_1d.py | 56 +++++++++++++++++++++++------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 7d2e1b677a9f..cbf7ce31bded 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -62,10 +62,10 @@ def get_timestep_embedding( class TimestepEmbedding(nn.Module): - def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): super().__init__() - self.linear_1 = nn.Linear(channel, time_embed_dim) + self.linear_1 = nn.Linear(in_channels, time_embed_dim) self.act = None if act_fn == "silu": self.act = nn.SiLU() diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 8f74926da505..bcc0b4636e14 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput -from .embeddings import TimestepEmbedding, Timesteps +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps @dataclass @@ -44,11 +44,21 @@ class UNet1DModel(ModelMixin, ConfigMixin): implements for all the model (such as downloading or saving, etc.) Parameters: - in_channels: - out_channels: - down_block_types: - up_block_types: - block_out_channels: + sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime. + in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`False`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(32, 32, 64)`): Tuple of block output channels. 
+ mid_block_type: + out_block_type: act_fn: norm_num_groups: """ @@ -56,8 +66,15 @@ class UNet1DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, + sample_size: int = 65536, + sample_rate: Optional[int] = None, in_channels: int = 14, out_channels: int = 14, + extra_in_channels: int = 0, + time_embedding_type: str = "positional", + flip_sin_to_cos: bool = False, + use_timestep_embedding: bool = True, + downscale_freq_shift: float = 1.0, down_block_types: Tuple[str] = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), up_block_types: Tuple[str] = ("UpResnetBlock1D", "UpResnetBlock1D"), mid_block_type: Tuple[str] = "MidResTemporalBlock1D", @@ -70,13 +87,28 @@ def __init__( ): super().__init__() - time_embed_dim = block_out_channels[0] * 4 + self.sample_size = sample_size # time - self.time_proj = Timesteps(num_channels=block_out_channels[0], flip_sin_to_cos=False, downscale_freq_shift=1) - self.time_mlp = TimestepEmbedding( - channel=block_out_channels[0], time_embed_dim=time_embed_dim, act_fn=act_fn, out_dim=block_out_channels[0] - ) + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift + ) + timestep_input_dim = block_out_channels[0] + + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) self.down_blocks = nn.ModuleList([]) self.mid_block = None From ef6ca1ff320c8870ab865b9539695692e7166df2 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 25 Oct 2022 16:06:37 -0700 Subject: [PATCH 39/63] unify block API --- src/diffusers/models/unet_1d.py | 29 +++++++++++++++++++------- src/diffusers/models/unet_1d_blocks.py | 4 +++- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index bcc0b4636e14..a54fab2afaf9 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
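The `time_embedding_type` switch introduced just above selects between a fixed random Fourier projection and the usual sinusoidal positional embedding. The following are simplified, shape-only stand-ins; the real implementations are `GaussianFourierProjection` and `Timesteps` in `embeddings.py`, and the constants here are illustrative:

```python
import math
import torch

def positional_embedding(timesteps, dim, downscale_freq_shift=1.0):
    # classic sinusoidal embedding: (batch,) -> (batch, dim)
    half = dim // 2
    exponent = -math.log(10000) * torch.arange(half, dtype=torch.float32) / (half - downscale_freq_shift)
    angles = timesteps[:, None].float() * exponent.exp()[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

def fourier_embedding(timesteps, embedding_size=8):
    # fixed random frequencies: (batch,) -> (batch, 2 * embedding_size)
    weight = torch.randn(embedding_size)
    angles = 2 * math.pi * timesteps[:, None].float() * weight[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

t = torch.tensor([0, 10, 100, 1000])
print(positional_embedding(t, 32).shape)  # torch.Size([4, 32])
print(fourier_embedding(t).shape)         # torch.Size([4, 16])
```
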
+ from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -114,16 +115,18 @@ def __init__( self.mid_block = None self.up_blocks = nn.ModuleList([]) self.out_block = None - mid_dim = block_out_channels[-1] # down output_channel = in_channels for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] + + if i == 0: + input_channel += extra_in_channels + is_final_block = i == len(block_out_channels) - 1 - down_block_type = down_block_types[i] down_block = get_down_block( down_block_type, num_layers=layers_per_block, @@ -137,8 +140,9 @@ def __init__( # mid self.mid_block = get_mid_block( mid_block_type, - in_channels=mid_dim, - out_channels=mid_dim, + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], embed_dim=block_out_channels[0], num_layers=layers_per_block, add_downsample=always_downsample, @@ -146,21 +150,30 @@ def __init__( # up reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): - input_channel = reversed_block_out_channels[i] - output_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + prev_output_channel = output_channel + output_channel = ( + reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels + ) is_final_block = i == len(block_out_channels) - 1 up_block = get_up_block( up_block_type, num_layers=layers_per_block, - in_channels=input_channel, + in_channels=prev_output_channel, out_channels=output_channel, temb_channels=block_out_channels[0], add_upsample=not is_final_block, ) self.up_blocks.append(up_block) + prev_output_channel = output_channel # out num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) @@ -170,7 +183,7 @@ def __init__( embed_dim=block_out_channels[0], out_channels=out_channels, act_fn=act_fn, - fc_dim=mid_dim // 4, + fc_dim=block_out_channels[-1] // 4, ) def forward( diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 0788cc1e76e5..d2197ae5116f 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -318,7 +318,7 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan raise ValueError(f"{up_block_type} does not exist.") -def get_mid_block(mid_block_type, num_layers, in_channels, out_channels, embed_dim, add_downsample): +def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample): if mid_block_type == "MidResTemporalBlock1D": return MidResTemporalBlock1D( num_layers=num_layers, @@ -337,3 +337,5 @@ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, ac return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) elif out_block_type == "ValueFunction": return OutValueFunctionBlock(fc_dim, embed_dim) + else: + return None From e6f1a83a1d78aed2aaad78571e813097441c6673 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 25 Oct 2022 16:35:01 -0700 Subject: [PATCH 40/63] unify forward api --- src/diffusers/models/unet_1d.py | 48 +++++++++----------------- src/diffusers/models/unet_1d_blocks.py | 30 ++-------------- tests/test_models_unet_1d.py | 1 + 3 
files changed, 20 insertions(+), 59 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 42043064645a..2d0c1e2c6804 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -211,48 +211,32 @@ def forward( elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) - temb = self.time_proj(timesteps) - temb = self.time_mlp(temb) - down_block_res_samples = [] + timestep_embed = self.time_proj(timesteps) + if self.time_mlp: + timestep_embed = self.time_mlp(timestep_embed) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) # 2. down - for down_block in self.down_blocks: - sample, res_samples = down_block(hidden_states=sample, temb=temb) - down_block_res_samples.append(res_samples[0]) + down_block_res_samples = () + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) + down_block_res_samples += res_samples # 3. mid if self.mid_block: - sample = self.mid_block(sample, temb) + sample = self.mid_block(sample, timestep_embed) # 4. up - for up_block in self.up_blocks: - sample = up_block(hidden_states=sample, res_hidden_states=down_block_res_samples.pop(), temb=temb) + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) # 5. post-process if self.out_block: - sample = self.out_block(sample, temb) - - # # 1. time - # if len(timestep.shape) == 0: - # timestep = timestep[None] - # - # timestep_embed = self.time_proj(timestep)[..., None] - # timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) - # - # # 2. down - # down_block_res_samples = () - # for downsample_block in self.down_blocks: - # sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) - # down_block_res_samples += res_samples - # - # # 3. mid - # sample = self.mid_block(sample) - # - # # 4. 
up - # for i, upsample_block in enumerate(self.up_blocks): - # res_samples = down_block_res_samples[-1:] - # down_block_res_samples = down_block_res_samples[:-1] - # sample = upsample_block(sample, res_samples) + sample = self.out_block(sample, timestep_embed) if not return_dict: return (sample,) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 367a6871a087..98ecc07d0873 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -132,8 +132,9 @@ def __init__( if add_upsample: self.upsample = Upsample1D(out_channels, use_conv_transpose=True) - def forward(self, hidden_states, res_hidden_states=None, temb=None): - if res_hidden_states is not None: + def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) hidden_states = self.resnets[0](hidden_states, temb) @@ -148,31 +149,6 @@ def forward(self, hidden_states, res_hidden_states=None, temb=None): return hidden_states - -class DownBlock1D(nn.Module): - pass - - -class AttnDownBlock1D(nn.Module): - pass - - -class DownBlock1DNoSkip(nn.Module): - pass - - -class UpBlock1D(nn.Module): - pass - - -class AttnUpBlock1D(nn.Module): - pass - - -class UpBlock1DNoSkip(nn.Module): - pass - - class ValueFunctionMidBlock1D(nn.Module): def __init__(self, in_channels, out_channels, embed_dim): super().__init__() diff --git a/tests/test_models_unet_1d.py b/tests/test_models_unet_1d.py index 364e9193f53f..4e5e8d666413 100644 --- a/tests/test_models_unet_1d.py +++ b/tests/test_models_unet_1d.py @@ -59,6 +59,7 @@ def prepare_init_args_and_inputs_for_common(self): "block_out_channels": (32, 128, 256), "in_channels": 14, "out_channels": 14, + "time_embedding_type": "positional", } inputs_dict = self.dummy_input return init_dict, inputs_dict From c35a925747630d230c1a3048a64c381fb47aae87 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 25 Oct 2022 16:36:42 -0700 Subject: [PATCH 41/63] delete lines in examples --- .../diffuser/run_diffuser_gen_trajectories.py | 13 ----------- examples/diffuser/run_diffuser_locomotion.py | 17 -------------- .../diffuser/run_diffuser_value_guided.py | 23 ------------------- 3 files changed, 53 deletions(-) diff --git a/examples/diffuser/run_diffuser_gen_trajectories.py b/examples/diffuser/run_diffuser_gen_trajectories.py index f4c86635c652..097222462dc3 100644 --- a/examples/diffuser/run_diffuser_gen_trajectories.py +++ b/examples/diffuser/run_diffuser_gen_trajectories.py @@ -21,21 +21,8 @@ def _run(): env_name = "hopper-medium-v2" env = gym.make(env_name) - DEVICE = config["device"] - - scheduler = DDPMScheduler( - num_train_timesteps=config["num_inference_steps"], - beta_schedule="squaredcos_cap_v2", - clip_sample=False, - variance_type="fixed_small_log", - ) - network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval() - unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval() pipeline = DiffusionPipeline.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", - value_function=network, - unet=unet, - scheduler=scheduler, env=env, custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", ) diff --git a/examples/diffuser/run_diffuser_locomotion.py b/examples/diffuser/run_diffuser_locomotion.py index 1b4351095d3b..0714769cdc2e 100644 --- 
a/examples/diffuser/run_diffuser_locomotion.py +++ b/examples/diffuser/run_diffuser_locomotion.py @@ -21,25 +21,8 @@ def _run(): env_name = "hopper-medium-v2" env = gym.make(env_name) - # Cuda settings for colab - # torch.cuda.get_device_name(0) - DEVICE = config["device"] - - # Two generators for different parts of the diffusion loop to work in colab - scheduler = DDPMScheduler( - num_train_timesteps=config["num_inference_steps"], - beta_schedule="squaredcos_cap_v2", - clip_sample=False, - variance_type="fixed_small_log", - ) - - network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval() - unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval() pipeline = DiffusionPipeline.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", - value_function=network, - unet=unet, - scheduler=scheduler, env=env, custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", ) diff --git a/examples/diffuser/run_diffuser_value_guided.py b/examples/diffuser/run_diffuser_value_guided.py index 4272ec2c3106..8c0ec54fcf97 100644 --- a/examples/diffuser/run_diffuser_value_guided.py +++ b/examples/diffuser/run_diffuser_value_guided.py @@ -23,34 +23,11 @@ def _run(): env_name = "hopper-medium-v2" env = gym.make(env_name) - # Cuda settings for colab - # torch.cuda.get_device_name(0) - DEVICE = config["device"] - - # Two generators for different parts of the diffusion loop to work in colab - scheduler = DDPMScheduler( - num_train_timesteps=config["num_inference_steps"], - beta_schedule="squaredcos_cap_v2", - clip_sample=False, - variance_type="fixed_small_log", - ) - - # 3 different pretrained models are available for this task. - # The horizion represents the length of trajectories used in training. 
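With the manual scheduler and model construction deleted above, each example script reduces to a single `DiffusionPipeline.from_pretrained` call. A condensed sketch of the remaining flow, assuming the RL requirements from the README are installed and using a relative `custom_pipeline` path in place of the hardcoded one:

```python
import d4rl  # noqa
import gym
from diffusers import DiffusionPipeline

env = gym.make("hopper-medium-v2")
pipeline = DiffusionPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32",
    env=env,
    custom_pipeline="examples/community",  # assumed relative path to the community pipeline
)

obs = env.reset()
for _ in range(10):
    # plan a trajectory and take its first action
    action = pipeline(obs, planning_horizon=32)
    obs, reward, terminal, _ = env.step(action)
    if terminal:
        break
```
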
- # network = ValueFunction(training_horizon=horizon, dim=32, dim_mults=(1, 2, 4, 8), transition_dim=14, cond_dim=11) - - network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval() - unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval() pipeline = DiffusionPipeline.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", - value_function=network, - unet=unet, - scheduler=scheduler, env=env, custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", ) - # unet = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) - # network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE) # add a batch dimension and repeat for multiple samples # [ observation_dim ] --> [ n_samples x observation_dim ] From 949b93a17d14288d547e98016888b55eb22e7c8c Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 25 Oct 2022 16:37:55 -0700 Subject: [PATCH 42/63] style --- src/diffusers/models/unet_1d.py | 3 +-- src/diffusers/models/unet_1d_blocks.py | 5 ++--- tests/test_models_unet_1d.py | 3 ++- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 2d0c1e2c6804..da61fd1d145b 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -18,12 +18,11 @@ import torch import torch.nn as nn - from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block, get_out_block +from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block @dataclass diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 98ecc07d0873..697de5184cfb 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -149,6 +149,7 @@ def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): return hidden_states + class ValueFunctionMidBlock1D(nn.Module): def __init__(self, in_channels, out_channels, embed_dim): super().__init__() @@ -267,8 +268,6 @@ def forward(self, hidden_states, temb): return hidden_states - - _kernels = { "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], @@ -666,4 +665,4 @@ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, ac return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) elif out_block_type == "ValueFunction": return OutValueFunctionBlock(fc_dim, embed_dim) - return None \ No newline at end of file + return None diff --git a/tests/test_models_unet_1d.py b/tests/test_models_unet_1d.py index 4e5e8d666413..edda27a7ad6a 100644 --- a/tests/test_models_unet_1d.py +++ b/tests/test_models_unet_1d.py @@ -18,9 +18,9 @@ import torch from diffusers import UNet1DModel +from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin -from diffusers.utils import floats_tensor, slow, torch_device torch.backends.cuda.matmul.allow_tf32 = False @@ -202,6 +202,7 @@ def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass + class UnetModel1DTests(unittest.TestCase): @slow def 
test_unet_1d_maestro(self): From 2f6462b2f685c9f67b6d4c9d090ea78e3c15a9dd Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 25 Oct 2022 16:43:30 -0700 Subject: [PATCH 43/63] examples style --- examples/diffuser/run_diffuser_gen_trajectories.py | 2 +- examples/diffuser/run_diffuser_locomotion.py | 2 +- examples/diffuser/run_diffuser_value_guided.py | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/diffuser/run_diffuser_gen_trajectories.py b/examples/diffuser/run_diffuser_gen_trajectories.py index 097222462dc3..3de8521343e3 100644 --- a/examples/diffuser/run_diffuser_gen_trajectories.py +++ b/examples/diffuser/run_diffuser_gen_trajectories.py @@ -1,7 +1,7 @@ import d4rl # noqa import gym import tqdm -from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel +from diffusers import DiffusionPipeline config = dict( diff --git a/examples/diffuser/run_diffuser_locomotion.py b/examples/diffuser/run_diffuser_locomotion.py index 0714769cdc2e..9ac9df28db81 100644 --- a/examples/diffuser/run_diffuser_locomotion.py +++ b/examples/diffuser/run_diffuser_locomotion.py @@ -1,7 +1,7 @@ import d4rl # noqa import gym import tqdm -from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel +from diffusers import DiffusionPipeline config = dict( diff --git a/examples/diffuser/run_diffuser_value_guided.py b/examples/diffuser/run_diffuser_value_guided.py index 8c0ec54fcf97..707663abb3bf 100644 --- a/examples/diffuser/run_diffuser_value_guided.py +++ b/examples/diffuser/run_diffuser_value_guided.py @@ -1,9 +1,7 @@ import d4rl # noqa import gym import tqdm - -# import train_diffuser -from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel +from diffusers import DiffusionPipeline config = dict( From a2dd559e12ddf9255dfea39d8fa77f92a5cc4f91 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 25 Oct 2022 17:10:00 -0700 Subject: [PATCH 44/63] all tests pass --- src/diffusers/models/unet_1d.py | 6 +++--- src/diffusers/models/unet_1d_blocks.py | 10 +++++----- tests/test_models_unet_1d.py | 3 +++ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index da61fd1d145b..7fdee9ea84ba 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -44,7 +44,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): implements for all the model (such as downloading or saving, etc.) Parameters: - sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime. + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. 
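One detail worth isolating before the next hunk: when `use_timestep_embedding` is `False`, the forward pass from the earlier refactor skips the time MLP and simply broadcasts the raw projection along the horizon axis. The tensor manipulation on its own:

```python
import torch

sample = torch.randn(4, 14, 16)      # (batch, channels, horizon)
timestep_embed = torch.randn(4, 32)  # raw time_proj output

# (batch, dim) -> (batch, dim, 1) -> (batch, dim, horizon), cast to the sample's dtype
timestep_embed = timestep_embed[..., None].repeat([1, 1, sample.shape[2]]).to(sample.dtype)
print(timestep_embed.shape)          # torch.Size([4, 32, 16])
```
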
@@ -78,7 +78,7 @@ def __init__( down_block_types: Tuple[str] = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), up_block_types: Tuple[str] = ("UpResnetBlock1D", "UpResnetBlock1D"), mid_block_type: Tuple[str] = "MidResTemporalBlock1D", - out_block_type: str = "OutConv1DBlock", + out_block_type: str = None, block_out_channels: Tuple[int] = (32, 128, 256), act_fn: str = "mish", norm_num_groups: int = 8, @@ -211,7 +211,7 @@ def forward( timesteps = timesteps[None].to(sample.device) timestep_embed = self.time_proj(timesteps) - if self.time_mlp: + if self.config.use_timestep_embedding: timestep_embed = self.time_mlp(timestep_embed) else: timestep_embed = timestep_embed[..., None] diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 697de5184cfb..fc758ebbb044 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -312,7 +312,7 @@ def __init__(self, kernel="linear", pad_mode="reflect"): self.pad = kernel_1d.shape[0] // 2 - 1 self.register_buffer("kernel", kernel_1d) - def forward(self, hidden_states): + def forward(self, hidden_states, temb=None): hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) @@ -441,7 +441,7 @@ def __init__(self, mid_channels, in_channels, out_channels=None): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states): + def forward(self, hidden_states, temb=None): hidden_states = self.down(hidden_states) for attn, resnet in zip(self.attentions, self.resnets): hidden_states = resnet(hidden_states) @@ -546,7 +546,7 @@ def __init__(self, in_channels, out_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) self.up = Upsample1d(kernel="cubic") - def forward(self, hidden_states, res_hidden_states_tuple): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -573,7 +573,7 @@ def __init__(self, in_channels, out_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) self.up = Upsample1d(kernel="cubic") - def forward(self, hidden_states, res_hidden_states_tuple): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -598,7 +598,7 @@ def __init__(self, in_channels, out_channels, mid_channels=None): self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, res_hidden_states_tuple): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) diff --git a/tests/test_models_unet_1d.py b/tests/test_models_unet_1d.py index edda27a7ad6a..26ef5b419345 100644 --- a/tests/test_models_unet_1d.py +++ b/tests/test_models_unet_1d.py @@ -60,6 +60,8 @@ def prepare_init_args_and_inputs_for_common(self): "in_channels": 14, "out_channels": 14, "time_embedding_type": "positional", + "use_timestep_embedding": True, + "out_block_type": "OutConv1DBlock", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -159,6 +161,7 @@ def prepare_init_args_and_inputs_for_common(self): 
"block_out_channels": [32, 64, 128, 256], "layers_per_block": 1, "always_downsample": True, + "use_timestep_embedding": True } inputs_dict = self.dummy_input return init_dict, inputs_dict From 39dff7331bc50bdceffd616dd6b866eedc1b3c7a Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 25 Oct 2022 17:11:31 -0700 Subject: [PATCH 45/63] make style --- tests/test_models_unet_1d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models_unet_1d.py b/tests/test_models_unet_1d.py index 26ef5b419345..ab86b5b6f202 100644 --- a/tests/test_models_unet_1d.py +++ b/tests/test_models_unet_1d.py @@ -161,7 +161,7 @@ def prepare_init_args_and_inputs_for_common(self): "block_out_channels": [32, 64, 128, 256], "layers_per_block": 1, "always_downsample": True, - "use_timestep_embedding": True + "use_timestep_embedding": True, } inputs_dict = self.dummy_input return init_dict, inputs_dict From d5eedff900f9823b2f0ad0a22ad1adc676441288 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 26 Oct 2022 15:08:43 -0700 Subject: [PATCH 46/63] make dance_diff test pass --- tests/pipelines/dance_diffusion/test_dance_diffusion.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py index 72e67e4479d2..a63ef84c63f5 100644 --- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -44,6 +44,10 @@ def dummy_unet(self): sample_rate=16_000, in_channels=2, out_channels=2, + flip_sin_to_cos=True, + use_timestep_embedding=False, + time_embedding_type="fourier", + mid_block_type="UNetMidBlock1D", down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"], up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"], ) From faeacd56e50dcaf8625851fa182b28c55968864a Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 8 Nov 2022 14:32:09 -0800 Subject: [PATCH 47/63] Refactoring RL PR (#1200) * init file changes * add import utils * finish cleaning files, imports * remove import flags * clean examples * fix imports, tests for merge * update readmes --- docs/source/api/models.mdx | 15 +- examples/README.md | 5 +- examples/community/value_guided_diffuser.py | 108 ------ examples/diffuser/run_diffuser.py | 122 ------- .../diffuser/run_diffuser_value_guided.py | 69 ---- examples/diffuser/train_diffuser.py | 312 ------------------ examples/{diffuser => rl}/README.md | 5 +- .../run_diffuser_gen_trajectories.py | 15 +- .../run_diffuser_locomotion.py | 15 +- src/diffusers/__init__.py | 1 + src/diffusers/experimental/README.md | 5 + src/diffusers/experimental/__init__.py | 1 + src/diffusers/experimental/rl/__init__.py | 1 + .../experimental/rl/value_guided_sampling.py | 23 +- tests/test_models_unet_1d.py | 40 ++- 15 files changed, 65 insertions(+), 672 deletions(-) delete mode 100644 examples/community/value_guided_diffuser.py delete mode 100644 examples/diffuser/run_diffuser.py delete mode 100644 examples/diffuser/run_diffuser_value_guided.py delete mode 100644 examples/diffuser/train_diffuser.py rename examples/{diffuser => rl}/README.md (56%) rename examples/{diffuser => rl}/run_diffuser_gen_trajectories.py (85%) rename examples/{diffuser => rl}/run_diffuser_locomotion.py (85%) create mode 100644 src/diffusers/experimental/README.md create mode 100644 src/diffusers/experimental/__init__.py create mode 100644 src/diffusers/experimental/rl/__init__.py rename examples/community/pipeline.py => 
src/diffusers/experimental/rl/value_guided_sampling.py (83%)

diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx
index a6d342f575a9..893fc6bea0ca 100644
--- a/docs/source/api/models.mdx
+++ b/docs/source/api/models.mdx
@@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## UNet2DOutput
 [[autodoc]] models.unet_2d.UNet2DOutput
 
-## UNet1DModel
-[[autodoc]] UNet1DModel
-
 ## UNet2DModel
 [[autodoc]] UNet2DModel
 
+## UNet1DOutput
+[[autodoc]] models.unet_1d.UNet1DOutput
+
+## UNet1DModel
+[[autodoc]] UNet1DModel
+
 ## UNet2DConditionOutput
 [[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
 
@@ -37,12 +40,6 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## DecoderOutput
 [[autodoc]] models.vae.DecoderOutput
 
-## UNet1DModel
-[[autodoc]] UNet1DModel
-
-## UNet1DOutput
-[[autodoc]] models.unet_1d.UNet1DOutput
-
 ## VQEncoderOutput
 [[autodoc]] models.vae.VQEncoderOutput
 
diff --git a/examples/README.md b/examples/README.md
index 2680b638d585..a50fbfc3a713 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -36,9 +36,10 @@ If you feel like another important example should exist, we are more than happy
 Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support:
 
-| Task | 🤗 Accelerate | 🤗 Datasets | Colab
-|---|---|:---:|:---:|
+| Task | 🤗 Accelerate | 🤗 Datasets | Colab
+|---------------------------------------------------------------------------------------------------------------------------------------------------------|---|:---:|:---:|
 | [**Unconditional Image Generation**](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffuser_locomotion.py) | | | coming soon.
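The rename list above shows where the community pipeline lands: it becomes an experimental, in-library class. A sketch of the resulting import layout (the corresponding `__init__.py` hunks appear further down in this patch; a later patch in this series, PATCH 54, removes the top-level `diffusers` re-export in favor of the explicit path):

```python
# The class itself now lives inside the library:
from diffusers.experimental.rl.value_guided_sampling import ValueGuidedRLPipeline

# ...and is re-exported one level up via diffusers/experimental/__init__.py
# and diffusers/experimental/rl/__init__.py, both added below:
from diffusers.experimental import ValueGuidedRLPipeline
```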
## Community diff --git a/examples/community/value_guided_diffuser.py b/examples/community/value_guided_diffuser.py deleted file mode 100644 index 6b28e868eddd..000000000000 --- a/examples/community/value_guided_diffuser.py +++ /dev/null @@ -1,108 +0,0 @@ -import torch - -import tqdm -from diffusers import DiffusionPipeline -from diffusers.models.unet_1d import UNet1DModel -from diffusers.utils.dummy_pt_objects import DDPMScheduler - - -class ValueGuidedDiffuserPipeline(DiffusionPipeline): - def __init__( - self, - value_function: UNet1DModel, - unet: UNet1DModel, - scheduler: DDPMScheduler, - env, - ): - super().__init__() - self.value_function = value_function - self.unet = unet - self.scheduler = scheduler - self.env = env - self.data = env.get_dataset() - self.means = dict() - for key in self.data.keys(): - try: - self.means[key] = self.data[key].mean() - except: - pass - self.stds = dict() - for key in self.data.keys(): - try: - self.stds[key] = self.data[key].std() - except: - pass - self.state_dim = env.observation_space.shape[0] - self.action_dim = env.action_space.shape[0] - - def normalize(self, x_in, key): - return (x_in - self.means[key]) / self.stds[key] - - def de_normalize(self, x_in, key): - return x_in * self.stds[key] + self.means[key] - - def to_torch(self, x_in): - if type(x_in) is dict: - return {k: self.to_torch(v) for k, v in x_in.items()} - elif torch.is_tensor(x_in): - return x_in.to(self.unet.device) - return torch.tensor(x_in, device=self.unet.device) - - def reset_x0(self, x_in, cond, act_dim): - for key, val in cond.items(): - x_in[:, key, act_dim:] = val.clone() - return x_in - - def run_diffusion(self, x, conditions, n_guide_steps, scale): - batch_size = x.shape[0] - y = None - for i in tqdm.tqdm(self.scheduler.timesteps): - # create batch of timesteps to pass into model - timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) - for _ in range(n_guide_steps): - with torch.enable_grad(): - x.requires_grad_() - y = self.value_function(x.permute(0, 2, 1), timesteps).sample - grad = torch.autograd.grad([y.sum()], [x])[0] - - posterior_variance = self.scheduler._get_variance(i) - model_std = torch.exp(0.5 * posterior_variance) - grad = model_std * grad - grad[timesteps < 2] = 0 - x = x.detach() - x = x + scale * grad - x = self.reset_x0(x, conditions, self.action_dim) - prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) - x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] - - # apply conditions to the trajectory - x = self.reset_x0(x, conditions, self.action_dim) - x = self.to_torch(x) - return x, y - - def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): - # normalize the observations and create batch dimension - obs = self.normalize(obs, "observations") - obs = obs[None].repeat(batch_size, axis=0) - - conditions = {0: self.to_torch(obs)} - shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) - - # generate initial noise and apply our conditions (to make the trajectories start at current state) - x1 = torch.randn(shape, device=self.unet.device) - x = self.reset_x0(x1, conditions, self.action_dim) - x = self.to_torch(x) - - # run the diffusion process - x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) - - # sort output trajectories by value - sorted_idx = y.argsort(0, descending=True).squeeze() - sorted_values = x[sorted_idx] - actions = sorted_values[:, :, : self.action_dim] - actions = actions.detach().cpu().numpy() - 
denorm_actions = self.de_normalize(actions, key="actions") - - # select the action with the highest value - denorm_actions = denorm_actions[0, 0] - return denorm_actions diff --git a/examples/diffuser/run_diffuser.py b/examples/diffuser/run_diffuser.py deleted file mode 100644 index b29d89992dfc..000000000000 --- a/examples/diffuser/run_diffuser.py +++ /dev/null @@ -1,122 +0,0 @@ -import numpy as np -import torch - -import d4rl # noqa -import gym -import tqdm -import train_diffuser -from diffusers import DDPMScheduler, UNet1DModel - - -env_name = "hopper-medium-expert-v2" -env = gym.make(env_name) -data = env.get_dataset() # dataset is only used for normalization in this colab - -DEVICE = "cpu" -DTYPE = torch.float - -# diffusion model settings -n_samples = 4 # number of trajectories planned via diffusion -horizon = 128 # length of sampled trajectories -state_dim = env.observation_space.shape[0] -action_dim = env.action_space.shape[0] -num_inference_steps = 100 # number of difusion steps - - -# Two generators for different parts of the diffusion loop to work in colab -generator_cpu = torch.Generator(device="cpu") - -scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2") - -# 3 different pretrained models are available for this task. -# The horizion represents the length of trajectories used in training. -network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) -# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor256").to(device=DEVICE) -# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE) - - -# network specific constants for inference -clip_denoised = network.clip_denoised -predict_epsilon = network.predict_epsilon - -# [ observation_dim ] --> [ n_samples x observation_dim ] -obs = env.reset() -total_reward = 0 -done = False -T = 300 -rollout = [obs.copy()] - -try: - for t in tqdm.tqdm(range(T)): - obs_raw = obs - - # normalize observations for forward passes - obs = train_diffuser.normalize(obs, data, "observations") - obs = obs[None].repeat(n_samples, axis=0) - conditions = {0: train_diffuser.to_torch(obs, device=DEVICE)} - - # constants for inference - batch_size = len(conditions[0]) - shape = (batch_size, horizon, state_dim + action_dim) - - # sample random initial noise vector - x1 = torch.randn(shape, device=DEVICE, generator=generator_cpu) - - # this model is conditioned from an initial state, so you will see this function - # multiple times to change the initial state of generated data to the state - # generated via env.reset() above or env.step() below - x = train_diffuser.reset_x0(x1, conditions, action_dim) - - # convert a np observation to torch for model forward pass - x = train_diffuser.to_torch(x) - - eta = 1.0 # noise factor for sampling reconstructed state - - # run the diffusion process - # for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): - for i in tqdm.tqdm(scheduler.timesteps): - # create batch of timesteps to pass into model - timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long) - - # 1. generate prediction from model - with torch.no_grad(): - residual = network(x, timesteps).sample - - # 2. use the model prediction to reconstruct an observation (de-noise) - obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=predict_epsilon)["prev_sample"] - - # 3. 
[optional] add posterior noise to the sample - if eta > 0: - noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device) - posterior_variance = scheduler._get_variance(i) # * noise - # no noise when t == 0 - # NOTE: original implementation missing sqrt on posterior_variance - obs_reconstruct = ( - obs_reconstruct + int(i > 0) * (0.5 * posterior_variance) * eta * noise - ) # MJ had as log var, exponentiated - - # 4. apply conditions to the trajectory - obs_reconstruct_postcond = train_diffuser.reset_x0(obs_reconstruct, conditions, action_dim) - x = train_diffuser.to_torch(obs_reconstruct_postcond) - plans = train_diffuser.helpers.to_np(x[:, :, :action_dim]) - # select random plan - idx = np.random.randint(plans.shape[0]) - # select action at correct time - action = plans[idx, 0, :] - actions = train_diffuser.de_normalize(action, data, "actions") - # execute action in environment - next_observation, reward, terminal, _ = env.step(action) - - # update return - total_reward += reward - print(f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}") - - # save observations for rendering - rollout.append(next_observation.copy()) - obs = next_observation -except KeyboardInterrupt: - pass - -print(f"Total reward: {total_reward}") -render = train_diffuser.MuJoCoRenderer(env) -train_diffuser.show_sample(render, np.expand_dims(np.stack(rollout), axis=0)) diff --git a/examples/diffuser/run_diffuser_value_guided.py b/examples/diffuser/run_diffuser_value_guided.py deleted file mode 100644 index 707663abb3bf..000000000000 --- a/examples/diffuser/run_diffuser_value_guided.py +++ /dev/null @@ -1,69 +0,0 @@ -import d4rl # noqa -import gym -import tqdm -from diffusers import DiffusionPipeline - - -config = dict( - n_samples=64, - horizon=32, - num_inference_steps=20, - n_guide_steps=2, - scale_grad_by_std=True, - scale=0.1, - eta=0.0, - t_grad_cutoff=2, - device="cpu", -) - - -def _run(): - env_name = "hopper-medium-v2" - env = gym.make(env_name) - - pipeline = DiffusionPipeline.from_pretrained( - "bglick13/hopper-medium-v2-value-function-hor32", - env=env, - custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", - ) - - # add a batch dimension and repeat for multiple samples - # [ observation_dim ] --> [ n_samples x observation_dim ] - env.seed(0) - obs = env.reset() - total_reward = 0 - total_score = 0 - T = 1000 - rollout = [obs.copy()] - try: - for t in tqdm.tqdm(range(T)): - # 1. 
Call the policy - # normalize observations for forward passes - denorm_actions = pipeline(obs, planning_horizon=32) - - # execute action in environment - next_observation, reward, terminal, _ = env.step(denorm_actions) - score = env.get_normalized_score(total_reward) - # update return - total_reward += reward - total_score += score - print( - f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" - f" {total_score}" - ) - # save observations for rendering - rollout.append(next_observation.copy()) - - obs = next_observation - except KeyboardInterrupt: - pass - - print(f"Total reward: {total_reward}") - - -def run(): - _run() - - -if __name__ == "__main__": - run() diff --git a/examples/diffuser/train_diffuser.py b/examples/diffuser/train_diffuser.py deleted file mode 100644 index b063a0456d97..000000000000 --- a/examples/diffuser/train_diffuser.py +++ /dev/null @@ -1,312 +0,0 @@ -import os -import warnings - -import numpy as np -import torch - -import d4rl # noqa -import gym -import mediapy as media -import mujoco_py as mjc -import tqdm -from diffusers import DDPMScheduler, UNet1DModel - - -# Define some helper functions - - -DTYPE = torch.float - - -def normalize(x_in, data, key): - means = data[key].mean(axis=0) - stds = data[key].std(axis=0) - return (x_in - means) / stds - - -def de_normalize(x_in, data, key): - means = data[key].mean(axis=0) - stds = data[key].std(axis=0) - return x_in * stds + means - - -def to_torch(x_in, dtype=None, device="cuda"): - dtype = dtype or DTYPE - device = device - if type(x_in) is dict: - return {k: to_torch(v, dtype, device) for k, v in x_in.items()} - elif torch.is_tensor(x_in): - return x_in.to(device).type(dtype) - return torch.tensor(x_in, dtype=dtype, device=device) - - -def reset_x0(x_in, cond, act_dim): - for key, val in cond.items(): - x_in[:, key, act_dim:] = val.clone() - return x_in - - -def run_diffusion(x, scheduler, network, unet, conditions, action_dim, config): - y = None - for i in tqdm.tqdm(scheduler.timesteps): - # create batch of timesteps to pass into model - timesteps = torch.full((config["n_samples"],), i, device=config["device"], dtype=torch.long) - # 3. call the sample function - for _ in range(config["n_guide_steps"]): - with torch.enable_grad(): - x.requires_grad_() - y = network(x, timesteps).sample - grad = torch.autograd.grad([y.sum()], [x])[0] - if config["scale_grad_by_std"]: - posterior_variance = scheduler._get_variance(i) - model_std = torch.exp(0.5 * posterior_variance) - grad = model_std * grad - grad[timesteps < config["t_grad_cutoff"]] = 0 - x = x.detach() - x = x + config["scale"] * grad - x = reset_x0(x, conditions, action_dim) - # with torch.no_grad(): - prev_x = unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) - x = scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] - - # 3. [optional] add posterior noise to the sample - if config["eta"] > 0: - noise = torch.randn(x.shape).to(x.device) - posterior_variance = scheduler._get_variance(i) # * noise - # no noise when t == 0 - # NOTE: original implementation missing sqrt on posterior_variance - x = x + int(i > 0) * (0.5 * posterior_variance) * config["eta"] * noise # MJ had as log var, exponentiated - - # 4. 
apply conditions to the trajectory - x = reset_x0(x, conditions, action_dim) - x = to_torch(x, device=config["device"]) - # y = network(x, timesteps).sample - return x, y - - -def to_np(x_in): - if torch.is_tensor(x_in): - x_in = x_in.detach().cpu().numpy() - return x_in - - -# from MJ's Diffuser code -# https://github.com/jannerm/diffuser/blob/76ae49ae85ba1c833bf78438faffdc63b8b4d55d/diffuser/utils/colab.py#L79 -def mkdir(savepath): - """ - returns `True` iff `savepath` is created - """ - if not os.path.exists(savepath): - os.makedirs(savepath) - return True - else: - return False - - -def show_sample(renderer, observations, filename="sample.mp4", savebase="videos"): - """ - observations : [ batch_size x horizon x observation_dim ] - """ - - mkdir(savebase) - savepath = os.path.join(savebase, filename) - - images = [] - for rollout in observations: - # [ horizon x height x width x channels ] - img = renderer._renders(rollout, partial=True) - images.append(img) - - # [ horizon x height x (batch_size * width) x channels ] - images = np.concatenate(images, axis=2) - media.write_video(savepath, images, fps=60) - media.show_video(images, codec="h264", fps=60) - return images - - -# Code adapted from Michael Janner -# source: https://github.com/jannerm/diffuser/blob/main/diffuser/utils/rendering.py - - -def env_map(env_name): - """ - map D4RL dataset names to custom fully-observed - variants for rendering - """ - if "halfcheetah" in env_name: - return "HalfCheetahFullObs-v2" - elif "hopper" in env_name: - return "HopperFullObs-v2" - elif "walker2d" in env_name: - return "Walker2dFullObs-v2" - else: - return env_name - - -def get_image_mask(img): - background = (img == 255).all(axis=-1, keepdims=True) - mask = ~background.repeat(3, axis=-1) - return mask - - -def atmost_2d(x): - while x.ndim > 2: - x = x.squeeze(0) - return x - - -def set_state(env, state): - qpos_dim = env.sim.data.qpos.size - qvel_dim = env.sim.data.qvel.size - if not state.size == qpos_dim + qvel_dim: - warnings.warn( - f"[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, but got state of size {state.size}" - ) - state = state[: qpos_dim + qvel_dim] - - env.set_state(state[:qpos_dim], state[qpos_dim:]) - - -class MuJoCoRenderer: - """ - default mujoco renderer - """ - - def __init__(self, env): - if type(env) is str: - env = env_map(env) - self.env = gym.make(env) - else: - self.env = env - # - 1 because the envs in renderer are fully-observed - # @TODO : clean up - self.observation_dim = np.prod(self.env.observation_space.shape) - 1 - self.action_dim = np.prod(self.env.action_space.shape) - try: - self.viewer = mjc.MjRenderContextOffscreen(self.env.sim) - except: - print("[ utils/rendering ] Warning: could not initialize offscreen renderer") - self.viewer = None - - def pad_observation(self, observation): - state = np.concatenate( - [ - np.zeros(1), - observation, - ] - ) - return state - - def pad_observations(self, observations): - qpos_dim = self.env.sim.data.qpos.size - # xpos is hidden - xvel_dim = qpos_dim - 1 - xvel = observations[:, xvel_dim] - xpos = np.cumsum(xvel) * self.env.dt - states = np.concatenate( - [ - xpos[:, None], - observations, - ], - axis=-1, - ) - return states - - def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None): - if type(dim) == int: - dim = (dim, dim) - - if self.viewer is None: - return np.zeros((*dim, 3), np.uint8) - - if render_kwargs is None: - xpos = observation[0] if not partial else 0 - render_kwargs = {"trackbodyid": 2, 
"distance": 3, "lookat": [xpos, -0.5, 1], "elevation": -20} - - for key, val in render_kwargs.items(): - if key == "lookat": - self.viewer.cam.lookat[:] = val[:] - else: - setattr(self.viewer.cam, key, val) - - if partial: - state = self.pad_observation(observation) - else: - state = observation - - qpos_dim = self.env.sim.data.qpos.size - if not qvel or state.shape[-1] == qpos_dim: - qvel_dim = self.env.sim.data.qvel.size - state = np.concatenate([state, np.zeros(qvel_dim)]) - - set_state(self.env, state) - - self.viewer.render(*dim) - data = self.viewer.read_pixels(*dim, depth=False) - data = data[::-1, :, :] - return data - - def _renders(self, observations, **kwargs): - images = [] - for observation in observations: - img = self.render(observation, **kwargs) - images.append(img) - return np.stack(images, axis=0) - - def renders(self, samples, partial=False, **kwargs): - if partial: - samples = self.pad_observations(samples) - partial = False - - sample_images = self._renders(samples, partial=partial, **kwargs) - - composite = np.ones_like(sample_images[0]) * 255 - - for img in sample_images: - mask = get_image_mask(img) - composite[mask] = img[mask] - - return composite - - def __call__(self, *args, **kwargs): - return self.renders(*args, **kwargs) - - -env_name = "hopper-medium-expert-v2" -env = gym.make(env_name) -data = env.get_dataset() # dataset is only used for normalization in this colab - -# Cuda settings for colab -# torch.cuda.get_device_name(0) -DEVICE = "cpu" -DTYPE = torch.float - -# diffusion model settings -n_samples = 4 # number of trajectories planned via diffusion -horizon = 128 # length of sampled trajectories -state_dim = env.observation_space.shape[0] -action_dim = env.action_space.shape[0] -num_inference_steps = 100 # number of difusion steps - -obs = env.reset() -obs_raw = obs - -# normalize observations for forward passes -obs = normalize(obs, data, "observations") - - -# Two generators for different parts of the diffusion loop to work in colab -generator = torch.Generator(device="cuda") -generator_cpu = torch.Generator(device="cpu") -network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE) - -scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2") -optimizer = torch.optim.AdamW( - network.parameters(), - lr=0.001, - betas=(0.95, 0.99), - weight_decay=1e-6, - eps=1e-8, -) - -# TODO: Flesh this out using accelerate library (a la other examples) diff --git a/examples/diffuser/README.md b/examples/rl/README.md similarity index 56% rename from examples/diffuser/README.md rename to examples/rl/README.md index 464ccd57af85..dd8add8aa4ea 100644 --- a/examples/diffuser/README.md +++ b/examples/rl/README.md @@ -1,6 +1,9 @@ # Overview -These examples show how to run (Diffuser)[https://arxiv.org/pdf/2205.09991.pdf] in Diffusers. There are two scripts, `run_diffuser_value_guided.py` and `run_diffuser.py`. +These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers. +There are four scripts, +1. `run_diffuser_locomotion.py` to sample actions and run them in the environment, +2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model. 
You will need some RL specific requirements to run the examples: diff --git a/examples/diffuser/run_diffuser_gen_trajectories.py b/examples/rl/run_diffuser_gen_trajectories.py similarity index 85% rename from examples/diffuser/run_diffuser_gen_trajectories.py rename to examples/rl/run_diffuser_gen_trajectories.py index 3de8521343e3..4f04d3acd704 100644 --- a/examples/diffuser/run_diffuser_gen_trajectories.py +++ b/examples/rl/run_diffuser_gen_trajectories.py @@ -1,7 +1,7 @@ import d4rl # noqa import gym import tqdm -from diffusers import DiffusionPipeline +from diffusers import ValueGuidedRLPipeline config = dict( @@ -17,14 +17,13 @@ ) -def _run(): +if __name__ == "__main__": env_name = "hopper-medium-v2" env = gym.make(env_name) - pipeline = DiffusionPipeline.from_pretrained( + pipeline = ValueGuidedRLPipeline.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", env=env, - custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", ) env.seed(0) @@ -56,11 +55,3 @@ def _run(): pass print(f"Total reward: {total_reward}") - - -def run(): - _run() - - -if __name__ == "__main__": - run() diff --git a/examples/diffuser/run_diffuser_locomotion.py b/examples/rl/run_diffuser_locomotion.py similarity index 85% rename from examples/diffuser/run_diffuser_locomotion.py rename to examples/rl/run_diffuser_locomotion.py index 9ac9df28db81..ad2fc8785f15 100644 --- a/examples/diffuser/run_diffuser_locomotion.py +++ b/examples/rl/run_diffuser_locomotion.py @@ -1,7 +1,7 @@ import d4rl # noqa import gym import tqdm -from diffusers import DiffusionPipeline +from diffusers import ValueGuidedRLPipeline config = dict( @@ -17,14 +17,13 @@ ) -def _run(): +if __name__ == "__main__": env_name = "hopper-medium-v2" env = gym.make(env_name) - pipeline = DiffusionPipeline.from_pretrained( + pipeline = ValueGuidedRLPipeline.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", env=env, - custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", ) env.seed(0) @@ -56,11 +55,3 @@ def _run(): pass print(f"Total reward: {total_reward}") - - -def run(): - _run() - - -if __name__ == "__main__": - run() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2c531cf8cee0..2a16132d3d8c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -17,6 +17,7 @@ if is_torch_available(): + from .experimental import ValueGuidedRLPipeline from .modeling_utils import ModelMixin from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( diff --git a/src/diffusers/experimental/README.md b/src/diffusers/experimental/README.md new file mode 100644 index 000000000000..81a9de81c737 --- /dev/null +++ b/src/diffusers/experimental/README.md @@ -0,0 +1,5 @@ +# ๐Ÿงจ Diffusers Experimental + +We are adding experimental code to support novel applications and usages of the Diffusers library. +Currently, the following experiments are supported: +* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. 
\ No newline at end of file diff --git a/src/diffusers/experimental/__init__.py b/src/diffusers/experimental/__init__.py new file mode 100644 index 000000000000..ebc815540301 --- /dev/null +++ b/src/diffusers/experimental/__init__.py @@ -0,0 +1 @@ +from .rl import ValueGuidedRLPipeline diff --git a/src/diffusers/experimental/rl/__init__.py b/src/diffusers/experimental/rl/__init__.py new file mode 100644 index 000000000000..7b338d3173e1 --- /dev/null +++ b/src/diffusers/experimental/rl/__init__.py @@ -0,0 +1 @@ +from .value_guided_sampling import ValueGuidedRLPipeline diff --git a/examples/community/pipeline.py b/src/diffusers/experimental/rl/value_guided_sampling.py similarity index 83% rename from examples/community/pipeline.py rename to src/diffusers/experimental/rl/value_guided_sampling.py index 85e359c5c4c9..8d5062e3d4c5 100644 --- a/examples/community/pipeline.py +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -1,13 +1,28 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np import torch import tqdm -from diffusers import DiffusionPipeline -from diffusers.models.unet_1d import UNet1DModel -from diffusers.utils.dummy_pt_objects import DDPMScheduler + +from ...models.unet_1d import UNet1DModel +from ...pipeline_utils import DiffusionPipeline +from ...utils.dummy_pt_objects import DDPMScheduler -class ValueGuidedDiffuserPipeline(DiffusionPipeline): +class ValueGuidedRLPipeline(DiffusionPipeline): def __init__( self, value_function: UNet1DModel, diff --git a/tests/test_models_unet_1d.py b/tests/test_models_unet_1d.py index ab86b5b6f202..dd320e8bd655 100644 --- a/tests/test_models_unet_1d.py +++ b/tests/test_models_unet_1d.py @@ -104,6 +104,25 @@ def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass + @slow + def test_unet_1d_maestro(self): + model_id = "harmonai/maestro-150k" + model = UNet1DModel.from_pretrained(model_id, subfolder="unet") + model.to(torch_device) + + sample_size = 65536 + noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device) + timestep = torch.tensor([1]).to(torch_device) + + with torch.no_grad(): + output = model(noise, timestep).sample + + output_sum = output.abs().sum() + output_max = output.abs().max() + + assert (output_sum - 224.0896).abs() < 4e-2 + assert (output_max - 0.0607).abs() < 4e-4 + class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet1DModel @@ -204,24 +223,3 @@ def test_output_pretrained(self): def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass - - -class UnetModel1DTests(unittest.TestCase): - @slow - def test_unet_1d_maestro(self): - model_id = "harmonai/maestro-150k" - model = UNet1DModel.from_pretrained(model_id, subfolder="unet") - model.to(torch_device) - - sample_size = 65536 - noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device) - timestep = torch.tensor([1]).to(torch_device) - - with 
torch.no_grad(): - output = model(noise, timestep).sample - - output_sum = output.abs().sum() - output_max = output.abs().max() - - assert (output_sum - 224.0896).abs() < 4e-2 - assert (output_max - 0.0607).abs() < 4e-4 From 72b7ee841a004b7efe68ccd91bcfc7e6fd18da4a Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 8 Nov 2022 14:41:33 -0800 Subject: [PATCH 48/63] hotfix for tests --- tests/models/test_models_unet_1d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index dd320e8bd655..487a711a51af 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -20,7 +20,7 @@ from diffusers import UNet1DModel from diffusers.utils import floats_tensor, slow, torch_device -from .test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin torch.backends.cuda.matmul.allow_tf32 = False From cf76a2d7762b81907ef21db8322a5184c6d83fd6 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 8 Nov 2022 14:42:16 -0800 Subject: [PATCH 49/63] quality --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9d296d29977d..ee9937d4aab9 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -4,6 +4,21 @@ from ..utils import DummyObject, requires_backends +class ValueGuidedRLPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ModelMixin(metaclass=DummyObject): _backends = ["torch"] From 229035633723cc717331f45b5838443c5eaa8979 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 9 Nov 2022 12:05:04 -0800 Subject: [PATCH 50/63] fix some tests --- tests/models/test_models_unet_1d.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 487a711a51af..e791958b178f 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -54,6 +54,10 @@ def test_ema_training(self): def test_training(self): pass + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_determinism(self): + super().test_determinism(self) + def prepare_init_args_and_inputs_for_common(self): init_dict = { "block_out_channels": (32, 128, 256), @@ -66,6 +70,7 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_from_pretrained_hub(self): model, loading_info = UNet1DModel.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" @@ -78,6 +83,7 @@ def test_from_pretrained_hub(self): assert image is not None, "Make sure output is not None" + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_output_pretrained(self): model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet") torch.manual_seed(0) @@ -185,6 +191,7 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return 
init_dict, inputs_dict + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_from_pretrained_hub(self): value_function, vf_loading_info = UNet1DModel.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" @@ -197,6 +204,7 @@ def test_from_pretrained_hub(self): assert image is not None, "Make sure output is not None" + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_output_pretrained(self): value_function, vf_loading_info = UNet1DModel.from_pretrained( "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" From a061f7e015f5ef2fe51f351fefc8d12c6cfaaca2 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 9 Nov 2022 12:11:54 -0800 Subject: [PATCH 51/63] change defaults --- src/diffusers/models/unet_1d.py | 30 +++++++++++++++-------------- tests/models/test_models_unet_1d.py | 2 +- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 7fdee9ea84ba..87f2485972d7 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -48,7 +48,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. - freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + downscale_freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to : obj:`False`): Whether to flip sin to cos for fourier time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to : @@ -57,10 +57,12 @@ class UNet1DModel(ModelMixin, ConfigMixin): obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to : obj:`(32, 32, 64)`): Tuple of block output channels. - mid_block_type: - out_block_type: - act_fn: - norm_num_groups: + mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet. + act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. + layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. + always_downsample (`int`, *optional*, defaults to False: experimental feature for using a UNet without upsampling. 
""" @register_to_config @@ -68,19 +70,19 @@ def __init__( self, sample_size: int = 65536, sample_rate: Optional[int] = None, - in_channels: int = 14, - out_channels: int = 14, + in_channels: int = 2, + out_channels: int = 2, extra_in_channels: int = 0, - time_embedding_type: str = "positional", - flip_sin_to_cos: bool = False, - use_timestep_embedding: bool = True, + time_embedding_type: str = "fourier", + flip_sin_to_cos: bool = True, + use_timestep_embedding: bool = False, downscale_freq_shift: float = 1.0, - down_block_types: Tuple[str] = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), - up_block_types: Tuple[str] = ("UpResnetBlock1D", "UpResnetBlock1D"), - mid_block_type: Tuple[str] = "MidResTemporalBlock1D", + down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), + up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + mid_block_type: Tuple[str] = "UNetMidBlock1D", out_block_type: str = None, block_out_channels: Tuple[int] = (32, 128, 256), - act_fn: str = "mish", + act_fn: str = None, norm_num_groups: int = 8, layers_per_block: int = 1, always_downsample: bool = False, diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index e791958b178f..46055db88ca0 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -56,7 +56,7 @@ def test_training(self): @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_determinism(self): - super().test_determinism(self) + super().test_determinism() def prepare_init_args_and_inputs_for_common(self): init_dict = { From 0c58758e455d9ec500cc5531daccdab4eed48616 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 9 Nov 2022 12:24:42 -0800 Subject: [PATCH 52/63] more mps test fixes --- src/diffusers/models/unet_1d.py | 3 ++- tests/models/test_models_unet_1d.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 87f2485972d7..e387d4738be8 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -62,7 +62,8 @@ class UNet1DModel(ModelMixin, ConfigMixin): act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. - always_downsample (`int`, *optional*, defaults to False: experimental feature for using a UNet without upsampling. + always_downsample (`int`, *optional*, defaults to False: + experimental feature for using a UNet without upsampling. 
""" @register_to_config diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 46055db88ca0..6ae9cfa3b4a1 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -58,6 +58,10 @@ def test_training(self): def test_determinism(self): super().test_determinism() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + def prepare_init_args_and_inputs_for_common(self): init_dict = { "block_out_channels": (32, 128, 256), @@ -152,6 +156,15 @@ def input_shape(self): def output_shape(self): return (4, 14, 1) + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_determinism(self): + super().test_determinism() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_output(self): # UNetRL is a value-function is different output shape init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() From 691ddee8bc18c3692a249dfe37e3b240595750e6 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 9 Nov 2022 12:27:39 -0800 Subject: [PATCH 53/63] unet1d defaults --- src/diffusers/models/unet_1d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index e387d4738be8..eee0d8b06f0a 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -77,12 +77,12 @@ def __init__( time_embedding_type: str = "fourier", flip_sin_to_cos: bool = True, use_timestep_embedding: bool = False, - downscale_freq_shift: float = 1.0, + downscale_freq_shift: float = 0.0, down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), mid_block_type: Tuple[str] = "UNetMidBlock1D", out_block_type: str = None, - block_out_channels: Tuple[int] = (32, 128, 256), + block_out_channels: Tuple[int] = (32, 32, 64), act_fn: str = None, norm_num_groups: int = 8, layers_per_block: int = 1, From 4948ca71dccb9dbf15726a55271e517b68c05533 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 9 Nov 2022 12:29:31 -0800 Subject: [PATCH 54/63] do not default import experimental --- examples/rl/run_diffuser_gen_trajectories.py | 2 +- examples/rl/run_diffuser_locomotion.py | 2 +- src/diffusers/__init__.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/rl/run_diffuser_gen_trajectories.py b/examples/rl/run_diffuser_gen_trajectories.py index 4f04d3acd704..5bb068cc9fc7 100644 --- a/examples/rl/run_diffuser_gen_trajectories.py +++ b/examples/rl/run_diffuser_gen_trajectories.py @@ -1,7 +1,7 @@ import d4rl # noqa import gym import tqdm -from diffusers import ValueGuidedRLPipeline +from diffusers.experimental import ValueGuidedRLPipeline config = dict( diff --git a/examples/rl/run_diffuser_locomotion.py b/examples/rl/run_diffuser_locomotion.py index ad2fc8785f15..e89181610b33 100644 --- a/examples/rl/run_diffuser_locomotion.py +++ b/examples/rl/run_diffuser_locomotion.py @@ -1,7 +1,7 @@ import d4rl # noqa import gym import tqdm -from diffusers import ValueGuidedRLPipeline +from diffusers.experimental import ValueGuidedRLPipeline config = dict( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 
6a35011dfa2b..da56dc888138 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -17,7 +17,6 @@ if is_torch_available(): - from .experimental import ValueGuidedRLPipeline from .modeling_utils import ModelMixin from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( From ac886772216db7b4c8e6ff62c024c9b9e9d6c960 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 9 Nov 2022 12:38:30 -0800 Subject: [PATCH 55/63] defaults for tests --- tests/models/test_models_unet_1d.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 6ae9cfa3b4a1..0ad6e876c0c3 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -69,7 +69,13 @@ def prepare_init_args_and_inputs_for_common(self): "out_channels": 14, "time_embedding_type": "positional", "use_timestep_embedding": True, + "flip_sin_to_cos": False, + "downscale_freq_shift": 1.0, "out_block_type": "OutConv1DBlock", + "mid_block_type": "MidResTemporalBlock1D", + "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D"), + "act_fn": "mish", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -200,6 +206,8 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 1, "always_downsample": True, "use_timestep_embedding": True, + "downscale_freq_shift": 1.0, + "flip_sin_to_cos": False, } inputs_dict = self.dummy_input return init_dict, inputs_dict From ba204db137061d3621d1a450d9c438b5cf7edaca Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 9 Nov 2022 13:11:15 -0800 Subject: [PATCH 56/63] fix tests --- src/diffusers/models/unet_1d.py | 1 - tests/models/test_models_unet_1d.py | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index eee0d8b06f0a..c4fa92275fe6 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -89,7 +89,6 @@ def __init__( always_downsample: bool = False, ): super().__init__() - self.sample_size = sample_size # time diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 0ad6e876c0c3..bcbf6926da29 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -64,7 +64,7 @@ def test_outputs_equivalence(self): def prepare_init_args_and_inputs_for_common(self): init_dict = { - "block_out_channels": (32, 128, 256), + "block_out_channels": (32, 64, 128, 256), "in_channels": 14, "out_channels": 14, "time_embedding_type": "positional", @@ -73,8 +73,8 @@ def prepare_init_args_and_inputs_for_common(self): "downscale_freq_shift": 1.0, "out_block_type": "OutConv1DBlock", "mid_block_type": "MidResTemporalBlock1D", - "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), - "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D"), + "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"), "act_fn": "mish", } inputs_dict = self.dummy_input @@ -197,7 +197,7 @@ def test_training(self): def prepare_init_args_and_inputs_for_common(self): init_dict = { "in_channels": 14, - "out_channels": 1, + "out_channels": 14, "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", 
"DownResnetBlock1D", "DownResnetBlock1D"], "up_block_types": [], "out_block_type": "ValueFunction", @@ -208,6 +208,8 @@ def prepare_init_args_and_inputs_for_common(self): "use_timestep_embedding": True, "downscale_freq_shift": 1.0, "flip_sin_to_cos": False, + "time_embedding_type": "positional", + "act_fn": "mish", } inputs_dict = self.dummy_input return init_dict, inputs_dict From 915c41e450d254ecc4dedf3ef7fe169ac7621cf8 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 9 Nov 2022 13:13:39 -0800 Subject: [PATCH 57/63] fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index ee9937d4aab9..9d296d29977d 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -4,21 +4,6 @@ from ..utils import DummyObject, requires_backends -class ValueGuidedRLPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class ModelMixin(metaclass=DummyObject): _backends = ["torch"] From becc803ab3804da95ce460f44d11722ea556baa4 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 14 Nov 2022 09:31:15 -0800 Subject: [PATCH 58/63] fix --- src/diffusers/schedulers/scheduling_ddpm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 60045d9ca918..c3e373d2bdca 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -303,9 +303,9 @@ def step( model_output.shape, generator=generator, device=device, dtype=model_output.dtype ) if self.variance_type == "fixed_small_log": - variance = self._get_variance(t, predicted_variance=predicted_variance) * noise + variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise else: - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise pred_prev_sample = pred_prev_sample + variance From 9b8e5eef2d0053bb69248251be2e9a11e0692faf Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Mon, 14 Nov 2022 16:19:22 -0500 Subject: [PATCH 59/63] changes per Patrik's comments (#1285) * changes per Patrik's comments * update conversion script --- .../convert_models_diffuser_to_diffusers.py | 24 ++++++++++++++++++- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/resnet.py | 12 ++++------ src/diffusers/models/unet_1d.py | 10 ++++---- 4 files changed, 33 insertions(+), 15 deletions(-) diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py index 4b4608358c17..9475f7da93fb 100644 --- a/scripts/convert_models_diffuser_to_diffusers.py +++ b/scripts/convert_models_diffuser_to_diffusers.py @@ -29,6 +29,19 @@ def unet(hor): block_out_channels=block_out_channels, up_block_types=up_block_types, layers_per_block=1, + use_timestep_embedding=True, + out_block_type="OutConv1DBlock", + norm_num_groups=8, + downsample_each_block=False, + in_channels=14, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + 
flip_sin_to_cos=False, + freq_shift=1, + sample_size=65536, + mid_block_type="MidResTemporalBlock1D", + act_fn="mish", ) hf_value_function = UNet1DModel(**config) print(f"length of state dict: {len(state_dict.keys())}") @@ -52,7 +65,16 @@ def value_function(): mid_block_type="ValueFunctionMidBlock1D", block_out_channels=(32, 64, 128, 256), layers_per_block=1, - always_downsample=True, + downsample_each_block=True, + sample_size=65536, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + use_timestep_embedding=True, + flip_sin_to_cos=False, + freq_shift=1, + norm_num_groups=8, + act_fn="mish", ) model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch") diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b5356a2bc94e..0221d891f171 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -69,7 +69,7 @@ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", self.act = None if act_fn == "silu": self.act = nn.SiLU() - if act_fn == "mish": + elif act_fn == "mish": self.act = nn.Mish() if out_dim is not None: diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 99b6092aed06..52d056ae96fb 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -523,13 +523,9 @@ def forward(self, x): class ResidualTemporalBlock1D(nn.Module): def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) - self.blocks = nn.ModuleList( - [ - Conv1dBlock(inp_channels, out_channels, kernel_size), - Conv1dBlock(out_channels, out_channels, kernel_size), - ] - ) self.time_emb_act = nn.Mish() self.time_emb = nn.Linear(embed_dim, out_channels) @@ -548,8 +544,8 @@ def forward(self, x, t): """ t = self.time_emb_act(t) t = self.time_emb(t) - out = self.blocks[0](x) + rearrange_dims(t) - out = self.blocks[1](out) + out = self.conv_in(x) + rearrange_dims(t) + out = self.conv_out(out) return out + self.residual_conv(x) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index c4fa92275fe6..c974d0a82cb6 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -77,7 +77,7 @@ def __init__( time_embedding_type: str = "fourier", flip_sin_to_cos: bool = True, use_timestep_embedding: bool = False, - downscale_freq_shift: float = 0.0, + freq_shift: float = 0.0, down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), mid_block_type: Tuple[str] = "UNetMidBlock1D", @@ -86,7 +86,7 @@ def __init__( act_fn: str = None, norm_num_groups: int = 8, layers_per_block: int = 1, - always_downsample: bool = False, + downsample_each_block: bool = False, ): super().__init__() self.sample_size = sample_size @@ -99,7 +99,7 @@ def __init__( timestep_input_dim = 2 * block_out_channels[0] elif time_embedding_type == "positional": self.time_proj = Timesteps( - block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift ) timestep_input_dim = block_out_channels[0] @@ -134,7 +134,7 @@ def __init__( in_channels=input_channel, out_channels=output_channel, 
temb_channels=block_out_channels[0], - add_downsample=not is_final_block or always_downsample, + add_downsample=not is_final_block or downsample_each_block, ) self.down_blocks.append(down_block) @@ -146,7 +146,7 @@ def __init__( out_channels=block_out_channels[-1], embed_dim=block_out_channels[0], num_layers=layers_per_block, - add_downsample=always_downsample, + add_downsample=downsample_each_block, ) # up From 3684a8c4662d3f8f45232792d0f73d9a0337be04 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 14 Nov 2022 13:27:34 -0800 Subject: [PATCH 60/63] fix renaming --- src/diffusers/models/unet_1d.py | 4 ++-- tests/models/test_models_unet_1d.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index c974d0a82cb6..29d1d707f55a 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -48,7 +48,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. - downscale_freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to : obj:`False`): Whether to flip sin to cos for fourier time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to : @@ -62,7 +62,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. - always_downsample (`int`, *optional*, defaults to False: + downsample_each_block (`int`, *optional*, defaults to False: experimental feature for using a UNet without upsampling. 
""" diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index bcbf6926da29..606fe0812cad 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -70,7 +70,7 @@ def prepare_init_args_and_inputs_for_common(self): "time_embedding_type": "positional", "use_timestep_embedding": True, "flip_sin_to_cos": False, - "downscale_freq_shift": 1.0, + "freq_shift": 1.0, "out_block_type": "OutConv1DBlock", "mid_block_type": "MidResTemporalBlock1D", "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), @@ -204,9 +204,9 @@ def prepare_init_args_and_inputs_for_common(self): "mid_block_type": "ValueFunctionMidBlock1D", "block_out_channels": [32, 64, 128, 256], "layers_per_block": 1, - "always_downsample": True, + "downsample_each_block": True, "use_timestep_embedding": True, - "downscale_freq_shift": 1.0, + "freq_shift": 1.0, "flip_sin_to_cos": False, "time_embedding_type": "positional", "act_fn": "mish", From ebdef1628ae24b7246a11a683bf2bdb32bdd439b Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 14 Nov 2022 13:35:44 -0800 Subject: [PATCH 61/63] skip more mps tests --- tests/models/test_models_unet_1d.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 606fe0812cad..7464b8dadaa4 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -62,6 +62,18 @@ def test_determinism(self): def test_outputs_equivalence(self): super().test_outputs_equivalence() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_save_pretrained(self): + super().test_from_pretrained_save_pretrained() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_model_from_config(self): + super().test_model_from_config() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output(self): + super().test_output() + def prepare_init_args_and_inputs_for_common(self): init_dict = { "block_out_channels": (32, 64, 128, 256), @@ -170,6 +182,10 @@ def test_determinism(self): def test_outputs_equivalence(self): super().test_outputs_equivalence() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_save_pretrained(self): + super().test_from_pretrained_save_pretrained() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_output(self): # UNetRL is a value-function is different output shape From a259aaeaf9ba402853ed7a7f4bf304e01c476cb3 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 14 Nov 2022 13:40:26 -0800 Subject: [PATCH 62/63] last test fix --- tests/models/test_models_unet_1d.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 7464b8dadaa4..41c4fdecfa0a 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -186,6 +186,10 @@ def test_outputs_equivalence(self): def test_from_pretrained_save_pretrained(self): super().test_from_pretrained_save_pretrained() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_model_from_config(self): + super().test_model_from_config() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_output(self): # UNetRL is a value-function is different output shape From 
1f7702c1d5c63300688a46f403c6b3ba9e813702 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 14 Nov 2022 13:44:05 -0800 Subject: [PATCH 63/63] Update examples/rl/README.md --- examples/rl/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/rl/README.md b/examples/rl/README.md index dd8add8aa4ea..d68f2bf780e3 100644 --- a/examples/rl/README.md +++ b/examples/rl/README.md @@ -11,7 +11,7 @@ You will need some RL specific requirements to run the examples: pip install -f https://download.pytorch.org/whl/torch_stable.html \ free-mujoco-py \ einops \ - gym \ + gym==0.24.1 \ protobuf==3.20.1 \ git+https://github.com/rail-berkeley/d4rl.git \ mediapy \
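
For reference, below is a minimal sketch of instantiating `UNet1DModel` with the renamed config keys these patches settle on: `freq_shift` (formerly `downscale_freq_shift`) and `downsample_each_block` (formerly `always_downsample`). The values mirror the planner config in `tests/models/test_models_unet_1d.py` as shown in the hunks above; `in_channels`, the `up_block_types` tuple, and the horizon of the dummy input are illustrative assumptions not pinned down by this diff.

```python
import torch

from diffusers import UNet1DModel

# Planner-style UNet1D config using the renamed keys from this patch series.
model = UNet1DModel(
    in_channels=14,  # assumed; matches the out_channels used in the value-function config above
    out_channels=14,
    block_out_channels=(32, 64, 128, 256),
    layers_per_block=1,
    time_embedding_type="positional",
    use_timestep_embedding=True,
    flip_sin_to_cos=False,
    freq_shift=1.0,  # formerly `downscale_freq_shift`
    norm_num_groups=8,
    act_fn="mish",  # note: skipped on mps in the tests above
    out_block_type="OutConv1DBlock",
    mid_block_type="MidResTemporalBlock1D",
    down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
    up_block_types=("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),  # assumed to mirror the down path
    downsample_each_block=False,  # formerly `always_downsample`
)

# Dummy trajectory batch: [batch, channels, horizon]; the horizon must survive three downsamples.
sample = torch.randn(1, 14, 32)
timestep = torch.tensor([10])
out = model(sample, timestep).sample  # UNet1DModel wraps its prediction in an output object
print(out.shape)  # expected: torch.Size([1, 14, 32]) for this planner config
```

The value-function variant in the tests flips `downsample_each_block` to True and swaps in `ValueFunctionMidBlock1D`, which is why its output shape differs from the planner's, as the test comments note.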