Pipeline cleanup #947

Merged (44 commits, Oct 24, 2022)

Commits
f29ace4
valuefunction code
bglick13 Oct 8, 2022
1684e8b
start example scripts
bglick13 Oct 8, 2022
c757985
missing imports
bglick13 Oct 8, 2022
b315918
bug fixes and placeholder example script
bglick13 Oct 8, 2022
f01c014
add value function scheduler
bglick13 Oct 9, 2022
7b60c93
load value function from hub and get best actions in example
bglick13 Oct 9, 2022
0de435e
very close to working example
bglick13 Oct 10, 2022
a396529
larger batch size for planning
bglick13 Oct 10, 2022
713bd80
more tests
bglick13 Oct 11, 2022
e3fb50f
Merge branch 'main' into rl
bglick13 Oct 11, 2022
686069f
Merge branch 'hf_rl' into rl
bglick13 Oct 11, 2022
d9384ff
merge unet1d changes
bglick13 Oct 11, 2022
52e2668
wandb for debugging, use newer models
bglick13 Oct 11, 2022
75fe8b4
success!
bglick13 Oct 11, 2022
c7fe1dc
turns out we just need more diffusion steps
bglick13 Oct 12, 2022
a6871b1
run on modal
bglick13 Oct 12, 2022
13a443c
Merge branch 'hf_rl' into rl
bglick13 Oct 12, 2022
38616cf
merge and code cleanup
bglick13 Oct 12, 2022
d37b472
use same api for rl model
bglick13 Oct 12, 2022
aa19286
fix variance type
bglick13 Oct 13, 2022
02293e2
wrong normalization function
bglick13 Oct 13, 2022
56818e5
add tests
bglick13 Oct 17, 2022
d085725
style
bglick13 Oct 17, 2022
93fe3ef
style and quality
bglick13 Oct 17, 2022
4e378e9
edits based on comments
bglick13 Oct 18, 2022
e7e6963
style and quality
bglick13 Oct 18, 2022
4f77d89
remove unused var
bglick13 Oct 19, 2022
5de8a6a
Merge branch 'hf_rl' into rl
bglick13 Oct 19, 2022
6bd8397
hack unet1d into a value function
bglick13 Oct 20, 2022
435ad26
add pipeline
bglick13 Oct 20, 2022
5653408
fix arg order
bglick13 Oct 20, 2022
1491932
add pipeline to core library
bglick13 Oct 20, 2022
1a8098e
community pipeline
bglick13 Oct 20, 2022
0e4be75
fix couple shape bugs
bglick13 Oct 21, 2022
5ef88ef
style
bglick13 Oct 21, 2022
c6d94ce
Apply suggestions from code review
Oct 21, 2022
a9cee78
clean up comments
bglick13 Oct 21, 2022
5c8cfc2
Merge remote-tracking branch 'bglick13/rl' into rl
bglick13 Oct 21, 2022
b7fac18
convert older script to using pipeline and add readme
bglick13 Oct 21, 2022
b3edd7b
rename scripts
bglick13 Oct 21, 2022
8b01b93
style, update tests
bglick13 Oct 21, 2022
b0b8b0b
Merge branch 'hf_rl' into rl
bglick13 Oct 22, 2022
3c668a7
delete unet rl model file
bglick13 Oct 22, 2022
af26faa
remove imports in src
Oct 24, 2022
21 changes: 18 additions & 3 deletions examples/community/pipeline.py
@@ -1,3 +1,4 @@
import numpy as np
import torch

import tqdm
@@ -59,7 +60,6 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale):
for i in tqdm.tqdm(self.scheduler.timesteps):
# create batch of timesteps to pass into model
timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
for _ in range(n_guide_steps):
with torch.enable_grad():
x.requires_grad_()
@@ -76,24 +76,39 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale):
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

# apply conditions to the trajectory
x = self.reset_x0(x, conditions, self.action_dim)
x = self.to_torch(x)
return x, y

def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
# normalize the observations and create batch dimension
obs = self.normalize(obs, "observations")
obs = obs[None].repeat(batch_size, axis=0)

conditions = {0: self.to_torch(obs)}
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

# generate initial noise and apply our conditions (to make the trajectories start at current state)
x1 = torch.randn(shape, device=self.unet.device)
x = self.reset_x0(x1, conditions, self.action_dim)
x = self.to_torch(x)

# run the diffusion process
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

# sort output trajectories by value
sorted_idx = y.argsort(0, descending=True).squeeze()
sorted_values = x[sorted_idx]
actions = sorted_values[:, :, : self.action_dim]
actions = actions.detach().cpu().numpy()
denorm_actions = self.de_normalize(actions, key="actions")

# select the action with the highest value
if y is not None:
selected_index = 0
else:
# if we didn't run value guiding, select a random action
selected_index = np.random.randint(0, batch_size)
denorm_actions = denorm_actions[selected_index, 0]
return denorm_actions
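
The collapsed portion of `run_diffusion` above is where the value guidance itself happens: at each denoising step the noisy trajectories are nudged along the gradient of the value function before the UNet predicts the previous sample. Below is a minimal sketch of one such guidance step, assuming the `value_function`, the permute convention, and the `scale` argument used elsewhere in this file; the collapsed code may differ in details such as scaling the gradient by the posterior variance or cutting guidance off at small timesteps.

```python
import torch


def guide_step(x, timesteps, value_function, scale=0.1):
    """One value-guidance step on a batch of noisy trajectories (sketch)."""
    # x: (batch, horizon, state_dim + action_dim)
    with torch.enable_grad():
        x.requires_grad_()
        # the value model expects (batch, channels, horizon)
        y = value_function(x.permute(0, 2, 1), timesteps).sample
        # gradient of the summed values with respect to the trajectories
        grad = torch.autograd.grad([y.sum()], [x])[0]
    # ascend the value gradient so higher-return trajectories are favored
    x = x.detach() + scale * grad
    return x, y
```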
13 changes: 11 additions & 2 deletions examples/community/value_guided_diffuser.py
@@ -59,7 +59,6 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale):
for i in tqdm.tqdm(self.scheduler.timesteps):
# create batch of timesteps to pass into model
timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
for _ in range(n_guide_steps):
with torch.enable_grad():
x.requires_grad_()
@@ -76,24 +75,34 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale):
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

# apply conditions to the trajectory
x = self.reset_x0(x, conditions, self.action_dim)
x = self.to_torch(x)
return x, y

def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
# normalize the observations and create batch dimension
obs = self.normalize(obs, "observations")
obs = obs[None].repeat(batch_size, axis=0)

conditions = {0: self.to_torch(obs)}
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

# generate initial noise and apply our conditions (to make the trajectories start at current state)
x1 = torch.randn(shape, device=self.unet.device)
x = self.reset_x0(x1, conditions, self.action_dim)
x = self.to_torch(x)

# run the diffusion process
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

# sort output trajectories by value
sorted_idx = y.argsort(0, descending=True).squeeze()
sorted_values = x[sorted_idx]
actions = sorted_values[:, :, : self.action_dim]
actions = actions.detach().cpu().numpy()
denorm_actions = self.de_normalize(actions, key="actions")

# select the action with the highest value
denorm_actions = denorm_actions[0, 0]
return denorm_actions
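
In both community files the action-selection tail works the same way: the denoised trajectories have shape `(batch_size, planning_horizon, state_dim + action_dim)`, the value head returns one score per trajectory, and the first action of the best-scoring plan is what gets executed. A shape-level sketch with hypothetical stand-in tensors (Hopper's `action_dim=3`, `state_dim=11`):

```python
import torch

batch_size, horizon, action_dim, state_dim = 64, 32, 3, 11
x = torch.randn(batch_size, horizon, state_dim + action_dim)  # denoised trajectories
y = torch.randn(batch_size, 1)  # predicted value per trajectory

sorted_idx = y.argsort(0, descending=True).squeeze()  # best trajectory first
actions = x[sorted_idx][:, :, :action_dim]  # keep only the action channels
best_first_action = actions[0, 0]  # shape (action_dim,): the action sent to env.step
print(best_first_action.shape)  # torch.Size([3])
```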
16 changes: 16 additions & 0 deletions examples/diffuser/README.md
@@ -0,0 +1,16 @@
# Overview

These examples show how to run [Diffuser](https://arxiv.org/pdf/2205.09991.pdf) in Diffusers. There are two scripts, `run_diffuser_gen_trajectories.py` and `run_diffuser_locomotion.py`.

You will need some RL-specific requirements to run the examples:

```
pip install -f https://download.pytorch.org/whl/torch_stable.html \
free-mujoco-py \
einops \
gym \
protobuf==3.20.1 \
git+https://github.com/rail-berkeley/d4rl.git \
mediapy \
Pillow==9.0.0
```
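
With the requirements installed, the two scripts below run full environment rollouts. As a quick orientation, here is a minimal sketch of just loading the value-guided pipeline and planning a single action; the repo ids mirror the scripts, and the `custom_pipeline` path is an assumption that should point at `examples/community` in your checkout.

```python
import d4rl  # noqa
import gym

from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel

env = gym.make("hopper-medium-v2")
scheduler = DDPMScheduler(
    num_train_timesteps=20,
    beta_schedule="squaredcos_cap_v2",
    clip_sample=False,
    variance_type="fixed_small_log",
)
value_function = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").eval()
unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").eval()
pipeline = DiffusionPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32",
    value_function=value_function,
    unet=unet,
    scheduler=scheduler,
    env=env,
    custom_pipeline="examples/community",  # assumed path to the community pipeline in this repo
)

obs = env.reset()
action = pipeline(obs, planning_horizon=32)  # denormalized action for the current observation
```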
79 changes: 79 additions & 0 deletions examples/diffuser/run_diffuser_gen_trajectories.py
@@ -0,0 +1,79 @@
import d4rl # noqa
import gym
import tqdm
from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel


config = dict(
n_samples=64,
horizon=32,
num_inference_steps=20,
n_guide_steps=0,
scale_grad_by_std=True,
scale=0.1,
eta=0.0,
t_grad_cutoff=2,
device="cpu",
)


def _run():
env_name = "hopper-medium-v2"
env = gym.make(env_name)

DEVICE = config["device"]

scheduler = DDPMScheduler(
num_train_timesteps=config["num_inference_steps"],
beta_schedule="squaredcos_cap_v2",
clip_sample=False,
variance_type="fixed_small_log",
)
network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval()
unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval()
pipeline = DiffusionPipeline.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32",
value_function=network,
unet=unet,
scheduler=scheduler,
env=env,
custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community",
)

env.seed(0)
obs = env.reset()
total_reward = 0
total_score = 0
T = 1000
rollout = [obs.copy()]
try:
for t in tqdm.tqdm(range(T)):
# Call the policy
denorm_actions = pipeline(obs, planning_horizon=32)

# execute action in environment
next_observation, reward, terminal, _ = env.step(denorm_actions)
score = env.get_normalized_score(total_reward)
# update return
total_reward += reward
total_score += score
print(
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
f" {total_score}"
)
# save observations for rendering
rollout.append(next_observation.copy())

obs = next_observation
except KeyboardInterrupt:
pass

print(f"Total reward: {total_reward}")


def run():
_run()


if __name__ == "__main__":
run()
83 changes: 83 additions & 0 deletions examples/diffuser/run_diffuser_locomotion.py
@@ -0,0 +1,83 @@
import d4rl # noqa
import gym
import tqdm
from diffusers import DDPMScheduler, DiffusionPipeline, UNet1DModel


config = dict(
n_samples=64,
horizon=32,
num_inference_steps=20,
n_guide_steps=2,
scale_grad_by_std=True,
scale=0.1,
eta=0.0,
t_grad_cutoff=2,
device="cpu",
)


def _run():
env_name = "hopper-medium-v2"
env = gym.make(env_name)

# Cuda settings for colab
# torch.cuda.get_device_name(0)
DEVICE = config["device"]

# Two generators for different parts of the diffusion loop to work in colab
scheduler = DDPMScheduler(
num_train_timesteps=config["num_inference_steps"],
beta_schedule="squaredcos_cap_v2",
clip_sample=False,
variance_type="fixed_small_log",
)

network = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32").to(device=DEVICE).eval()
unet = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-unet-hor32").to(device=DEVICE).eval()
pipeline = DiffusionPipeline.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32",
value_function=network,
unet=unet,
scheduler=scheduler,
env=env,
custom_pipeline="/Users/bglickenhaus/Documents/diffusers/examples/community",
)

env.seed(0)
obs = env.reset()
total_reward = 0
total_score = 0
T = 1000
rollout = [obs.copy()]
try:
for t in tqdm.tqdm(range(T)):
# call the policy
denorm_actions = pipeline(obs, planning_horizon=32)

# execute action in environment
next_observation, reward, terminal, _ = env.step(denorm_actions)
score = env.get_normalized_score(total_reward)
# update return
total_reward += reward
total_score += score
print(
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
f" {total_score}"
)
# save observations for rendering
rollout.append(next_observation.copy())

obs = next_observation
except KeyboardInterrupt:
pass

print(f"Total reward: {total_reward}")


def run():
_run()


if __name__ == "__main__":
run()
85 changes: 84 additions & 1 deletion tests/test_models_unet.py
@@ -20,7 +20,7 @@

import torch

from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel
from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction
from diffusers.utils import floats_tensor, slow, torch_device

from .test_modeling_common import ModelTesterMixin
@@ -524,3 +524,86 @@ def test_output_pretrained(self):
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass


class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet1DModel

@property
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16

noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)

return {"sample": noise, "timestep": time_step}

@property
def input_shape(self):
return (4, 14, 16)

@property
def output_shape(self):
return (4, 14, 1)

def test_ema_training(self):
pass

def test_training(self):
pass

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64, 128, 256),
"in_channels": 14,
"out_channels": 14,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_from_pretrained_hub(self):
unet, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True
)
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True
)
self.assertIsNotNone(unet)
self.assertEqual(len(loading_info["missing_keys"]), 0)
self.assertIsNotNone(value_function)
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)

unet.to(torch_device)
value_function.to(torch_device)
image = value_function(**self.dummy_input)

assert image is not None, "Make sure output is not None"

def test_output_pretrained(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True
)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)

num_features = value_function.in_channels
seq_len = 14
noise = torch.randn((1, seq_len, num_features)).permute(
0, 2, 1
) # match original, we can update values and remove
time_step = torch.full((num_features,), 0)

with torch.no_grad():
output = value_function(noise, time_step).sample

# fmt: off
expected_output_slice = torch.tensor([207.0272] * seq_len)
# fmt: on
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))

def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass