diff --git a/examples/community/pipeline.py b/examples/community/pipeline.py index 7e3f2b832b1f..85e359c5c4c9 100644 --- a/examples/community/pipeline.py +++ b/examples/community/pipeline.py @@ -1,3 +1,4 @@ +import numpy as np import torch import tqdm @@ -59,7 +60,6 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale): for i in tqdm.tqdm(self.scheduler.timesteps): # create batch of timesteps to pass into model timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) - # 3. call the sample function for _ in range(n_guide_steps): with torch.enable_grad(): x.requires_grad_() @@ -76,24 +76,39 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale): prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] - # 4. apply conditions to the trajectory + # apply conditions to the trajectory x = self.reset_x0(x, conditions, self.action_dim) x = self.to_torch(x) return x, y def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension obs = self.normalize(obs, "observations") obs = obs[None].repeat(batch_size, axis=0) + conditions = {0: self.to_torch(obs)} shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) x1 = torch.randn(shape, device=self.unet.device) x = self.reset_x0(x1, conditions, self.action_dim) x = self.to_torch(x) + + # run the diffusion process x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value sorted_idx = y.argsort(0, descending=True).squeeze() sorted_values = x[sorted_idx] actions = sorted_values[:, :, : self.action_dim] actions = actions.detach().cpu().numpy() denorm_actions = self.de_normalize(actions, key="actions") - denorm_actions = denorm_actions[0, 0] + + # 
select the action with the highest value + if y is not None: + selected_index = 0 + else: + # if we didn't run value guiding, select a random action + selected_index = np.random.randint(0, batch_size) + denorm_actions = denorm_actions[selected_index, 0] return denorm_actions diff --git a/examples/community/value_guided_diffuser.py b/examples/community/value_guided_diffuser.py index 7e3f2b832b1f..6b28e868eddd 100644 --- a/examples/community/value_guided_diffuser.py +++ b/examples/community/value_guided_diffuser.py @@ -59,7 +59,6 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale): for i in tqdm.tqdm(self.scheduler.timesteps): # create batch of timesteps to pass into model timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) - # 3. call the sample function for _ in range(n_guide_steps): with torch.enable_grad(): x.requires_grad_() @@ -76,24 +75,34 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale): prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] - # 4. 
apply conditions to the trajectory + # apply conditions to the trajectory x = self.reset_x0(x, conditions, self.action_dim) x = self.to_torch(x) return x, y def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension obs = self.normalize(obs, "observations") obs = obs[None].repeat(batch_size, axis=0) + conditions = {0: self.to_torch(obs)} shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) x1 = torch.randn(shape, device=self.unet.device) x = self.reset_x0(x1, conditions, self.action_dim) x = self.to_torch(x) + + # run the diffusion process x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value sorted_idx = y.argsort(0, descending=True).squeeze() sorted_values = x[sorted_idx] actions = sorted_values[:, :, : self.action_dim] actions = actions.detach().cpu().numpy() denorm_actions = self.de_normalize(actions, key="actions") + + # select the action with the highest value denorm_actions = denorm_actions[0, 0] return denorm_actions diff --git a/examples/diffuser/README.md b/examples/diffuser/README.md new file mode 100644 index 000000000000..464ccd57af85 --- /dev/null +++ b/examples/diffuser/README.md @@ -0,0 +1,16 @@ +# Overview + +These examples show how to run [Diffuser](https://arxiv.org/pdf/2205.09991.pdf) in Diffusers. There are two scripts, `run_diffuser_value_guided.py` and `run_diffuser.py`. 
+ +You will need some RL specific requirements to run the examples: + +``` +pip install -f https://download.pytorch.org/whl/torch_stable.html \ + free-mujoco-py \ + einops \ + gym \ + protobuf==3.20.1 \ + git+https://github.com/rail-berkeley/d4rl.git \ + mediapy \ + Pillow==9.0.0 +``` diff --git a/examples/diffuser/run_diffuser_gen_trajectories.py b/examples/diffuser/run_diffuser_gen_trajectories.py new file mode 100644 index 000000000000..f4c86635c652 --- /dev/null +++ b/examples/diffuser/run_diffuser_gen_trajectories.py @@ -0,0 +1,79 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=0, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +def _run(): + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + DEVICE = config["device"] + + scheduler = DDPMScheduler( + num_train_timesteps=config["num_inference_steps"], + beta_schedule="squaredcos_cap_v2", + clip_sample=False, + variance_type="fixed_small_log", + ) + network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval() + unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval() + pipeline = DiffusionPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + value_function=network, + unet=unet, + scheduler=scheduler, + env=env, + custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # Call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + 
total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") + + +def run(): + _run() + + +if __name__ == "__main__": + run() diff --git a/examples/diffuser/run_diffuser_locomotion.py b/examples/diffuser/run_diffuser_locomotion.py new file mode 100644 index 000000000000..1b4351095d3b --- /dev/null +++ b/examples/diffuser/run_diffuser_locomotion.py @@ -0,0 +1,83 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=2, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +def _run(): + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + # Cuda settings for colab + # torch.cuda.get_device_name(0) + DEVICE = config["device"] + + # Two generators for different parts of the diffusion loop to work in colab + scheduler = DDPMScheduler( + num_train_timesteps=config["num_inference_steps"], + beta_schedule="squaredcos_cap_v2", + clip_sample=False, + variance_type="fixed_small_log", + ) + + network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval() + unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval() + pipeline = DiffusionPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + value_function=network, + unet=unet, + scheduler=scheduler, + env=env, + custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community", + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in 
tqdm.tqdm(range(T)): + # call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") + + +def run(): + _run() + + +if __name__ == "__main__": + run() diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index e1dbdfaa4611..1ff092b3ce78 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -20,7 +20,7 @@ import torch -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin @@ -524,3 +524,86 @@ def test_output_pretrained(self): def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass + + +class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 1) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 128, 256), + 
"in_channels": 14, + "out_channels": 14, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + unet, loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True + ) + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + ) + self.assertIsNotNone(unet) + self.assertEqual(len(loading_info["missing_keys"]), 0) + self.assertIsNotNone(value_function) + self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + + unet.to(torch_device) + value_function.to(torch_device) + image = value_function(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True + ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = value_function.in_channels + seq_len = 14 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = value_function(noise, time_step).sample + + # fmt: off + expected_output_slice = torch.tensor([207.0272] * seq_len) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass