-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add Value Function and corresponding example script to Diffuser implementation #884
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f29ace4
1684e8b
c757985
b315918
f01c014
7b60c93
0de435e
a396529
713bd80
e3fb50f
686069f
d9384ff
52e2668
75fe8b4
c7fe1dc
a6871b1
13a443c
38616cf
d37b472
aa19286
02293e2
56818e5
d085725
93fe3ef
4e378e9
e7e6963
4f77d89
5de8a6a
6bd8397
435ad26
5653408
1491932
1a8098e
0e4be75
5ef88ef
c6d94ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -163,4 +163,6 @@ tags | |
*.lock | ||
|
||
# DS_Store (MacOS) | ||
.DS_Store | ||
.DS_Store | ||
# RL pipelines may produce mp4 outputs | ||
*.mp4 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import torch | ||
|
||
import tqdm | ||
from diffusers import DiffusionPipeline | ||
from diffusers.models.unet_1d import UNet1DModel | ||
from diffusers.utils.dummy_pt_objects import DDPMScheduler | ||
|
||
|
||
class ValueGuidedDiffuserPipeline(DiffusionPipeline):
    """Plan an action by sampling trajectories with value-guided diffusion.

    A batch of trajectories of shape ``(batch_size, planning_horizon,
    action_dim + state_dim)`` is denoised by ``unet`` while being pushed along
    the gradient of ``value_function``. The last axis stores the action
    features first and the observation features after them. ``__call__``
    returns the first action of the highest-valued trajectory.

    Args:
        value_function: model called as ``value_function(x, timesteps).sample``
            to score trajectories.
        unet: model called as ``unet(x, timesteps).sample`` inside the
            denoising loop.
        scheduler: object providing ``timesteps``, ``step`` and
            ``_get_variance``.
        env: environment providing ``get_dataset()``, ``observation_space``
            and ``action_space``.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env

        # The dataset is only used to derive normalization statistics.
        self.data = env.get_dataset()
        self.means = self._dataset_statistic("mean")
        self.stds = self._dataset_statistic("std")

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def _dataset_statistic(self, name):
        """Return ``{key: data[key].<name>()}`` for every entry that supports it."""
        stats = {}
        for key in self.data.keys():
            try:
                stats[key] = getattr(self.data[key], name)()
            except (AttributeError, TypeError, ValueError):
                # Best effort: entries without a usable statistic are skipped.
                continue
        return stats

    def normalize(self, x_in, key):
        """Standardize ``x_in`` with the dataset mean and std stored under ``key``."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Invert ``normalize`` for the dataset entry ``key``."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Convert ``x_in`` (dict, tensor or array-like) to tensors on the unet's device."""
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        if torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Write the conditioning observations into ``x_in`` in place and return it.

        ``cond`` maps a step index to the observation values that are written
        into the features after ``act_dim`` at that step.
        """
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """Denoise ``x`` over all scheduler timesteps with value guidance.

        Args:
            x: trajectory tensor indexed as ``(batch, step, feature)``.
            conditions: step index -> observation tensor, re-applied after
                every update of ``x``.
            n_guide_steps: number of gradient-ascent steps on the value
                estimate per timestep.
            scale: multiplier applied to the guidance gradient.

        Returns:
            ``(x, y)`` where ``y`` is the value-function output of the last
            guidance step, or ``None`` when no guidance step was run.
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)

            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # the models are called with the step and feature axes swapped
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    # NOTE(review): the scheduler variance is exponentiated as if it
                    # were a log-variance -- confirm against the scheduler.
                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # no guidance when the timestep is 0 or 1
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            # The denoising step needs no autograd graph; only the guidance does.
            with torch.no_grad():
                prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
                x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """Return the de-normalized action to take from observation ``obs``.

        Args:
            obs: current observation; it is indexed with ``obs[None]`` and
                repeated with ``repeat(batch_size, axis=0)``.
            batch_size: number of trajectories sampled in parallel.
            planning_horizon: number of steps in each sampled trajectory.
            n_guide_steps: value-guidance steps per diffusion timestep.
            scale: multiplier applied to the guidance gradient.

        Returns:
            The first action of the highest-valued trajectory. When no
            guidance step ran, the first sampled trajectory is used.
        """
        # normalize the observation and add a batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # start from noise, with step 0 fixed to the current observation
        x1 = torch.randn(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # Sort trajectories by value, best first. Without guidance there are no
        # values to rank by, so the sampled order is kept.
        if y is not None:
            # reshape(-1) keeps the index 1-D even when batch_size == 1
            sorted_idx = y.argsort(0, descending=True).reshape(-1)
            x = x[sorted_idx]

        actions = x[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # first action of the best trajectory
        return denorm_actions[0, 0]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import torch | ||
|
||
import tqdm | ||
from diffusers import DiffusionPipeline | ||
from diffusers.models.unet_1d import UNet1DModel | ||
from diffusers.utils.dummy_pt_objects import DDPMScheduler | ||
|
||
|
||
class ValueGuidedDiffuserPipeline(DiffusionPipeline):
    """Plan an action by sampling trajectories with value-guided diffusion.

    A batch of trajectories of shape ``(batch_size, planning_horizon,
    action_dim + state_dim)`` is denoised by ``unet`` while being pushed along
    the gradient of ``value_function``. The last axis stores the action
    features first and the observation features after them. ``__call__``
    returns the first action of the highest-valued trajectory.

    Args:
        value_function: model called as ``value_function(x, timesteps).sample``
            to score trajectories.
        unet: model called as ``unet(x, timesteps).sample`` inside the
            denoising loop.
        scheduler: object providing ``timesteps``, ``step`` and
            ``_get_variance``.
        env: environment providing ``get_dataset()``, ``observation_space``
            and ``action_space``.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env

        # The dataset is only used to derive normalization statistics.
        self.data = env.get_dataset()
        self.means = self._dataset_statistic("mean")
        self.stds = self._dataset_statistic("std")

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def _dataset_statistic(self, name):
        """Return ``{key: data[key].<name>()}`` for every entry that supports it."""
        stats = {}
        for key in self.data.keys():
            try:
                stats[key] = getattr(self.data[key], name)()
            except (AttributeError, TypeError, ValueError):
                # Best effort: entries without a usable statistic are skipped.
                continue
        return stats

    def normalize(self, x_in, key):
        """Standardize ``x_in`` with the dataset mean and std stored under ``key``."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Invert ``normalize`` for the dataset entry ``key``."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Convert ``x_in`` (dict, tensor or array-like) to tensors on the unet's device."""
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        if torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Write the conditioning observations into ``x_in`` in place and return it.

        ``cond`` maps a step index to the observation values that are written
        into the features after ``act_dim`` at that step.
        """
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """Denoise ``x`` over all scheduler timesteps with value guidance.

        Args:
            x: trajectory tensor indexed as ``(batch, step, feature)``.
            conditions: step index -> observation tensor, re-applied after
                every update of ``x``.
            n_guide_steps: number of gradient-ascent steps on the value
                estimate per timestep.
            scale: multiplier applied to the guidance gradient.

        Returns:
            ``(x, y)`` where ``y`` is the value-function output of the last
            guidance step, or ``None`` when no guidance step was run.
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)

            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # the models are called with the step and feature axes swapped
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    # NOTE(review): the scheduler variance is exponentiated as if it
                    # were a log-variance -- confirm against the scheduler.
                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # no guidance when the timestep is 0 or 1
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            # The denoising step needs no autograd graph; only the guidance does.
            with torch.no_grad():
                prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
                x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """Return the de-normalized action to take from observation ``obs``.

        Args:
            obs: current observation; it is indexed with ``obs[None]`` and
                repeated with ``repeat(batch_size, axis=0)``.
            batch_size: number of trajectories sampled in parallel.
            planning_horizon: number of steps in each sampled trajectory.
            n_guide_steps: value-guidance steps per diffusion timestep.
            scale: multiplier applied to the guidance gradient.

        Returns:
            The first action of the highest-valued trajectory. When no
            guidance step ran, the first sampled trajectory is used.
        """
        # normalize the observation and add a batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # start from noise, with step 0 fixed to the current observation
        x1 = torch.randn(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # Sort trajectories by value, best first. Without guidance there are no
        # values to rank by, so the sampled order is kept.
        if y is not None:
            # reshape(-1) keeps the index 1-D even when batch_size == 1
            sorted_idx = y.argsort(0, descending=True).reshape(-1)
            x = x[sorted_idx]

        actions = x[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # first action of the best trajectory
        return denorm_actions[0, 0]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import numpy as np | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Probably the core of a pipeline or example. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you want a pipeline to be part of this PR? Or something we add in a small follow-up PR? |
||
import torch | ||
|
||
import d4rl # noqa | ||
import gym | ||
import tqdm | ||
import train_diffuser | ||
from diffusers import DDPMScheduler, UNet1DModel | ||
|
||
|
||
# Roll out a pretrained diffusion planner in a D4RL Hopper environment:
# at every environment step a batch of trajectories is sampled by reverse
# diffusion, one of them is picked at random and its first action is executed.

env_name = "hopper-medium-expert-v2"
env = gym.make(env_name)
data = env.get_dataset()  # dataset is only used for normalization in this script

DEVICE = "cpu"
DTYPE = torch.float

# diffusion model settings
n_samples = 4  # number of trajectories planned via diffusion
horizon = 128  # length of sampled trajectories
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
num_inference_steps = 100  # number of diffusion steps
eta = 1.0  # noise factor for sampling reconstructed state

# CPU generator used for every noise draw in the diffusion loop
generator_cpu = torch.Generator(device="cpu")

scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")

# The horizon in the checkpoint name is the length of the trajectories used in
# training. Other checkpoints referenced for this task:
#   "fusing/ddpm-unet-rl-hopper-hor256", "fusing/ddpm-unet-rl-hopper-hor512"
network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE)

# network specific constants for inference
clip_denoised = network.clip_denoised
predict_epsilon = network.predict_epsilon

obs = env.reset()
total_reward = 0
done = False
T = 300  # maximum number of environment steps
rollout = [obs.copy()]

try:
    for t in tqdm.tqdm(range(T)):
        # normalize observations for forward passes
        # [ observation_dim ] --> [ n_samples x observation_dim ]
        obs = train_diffuser.normalize(obs, data, "observations")
        obs = obs[None].repeat(n_samples, axis=0)
        conditions = {0: train_diffuser.to_torch(obs, device=DEVICE)}

        # constants for inference
        batch_size = len(conditions[0])
        shape = (batch_size, horizon, state_dim + action_dim)

        # sample random initial noise vector
        x1 = torch.randn(shape, device=DEVICE, generator=generator_cpu)

        # this model is conditioned from an initial state, so you will see this function
        # multiple times to change the initial state of generated data to the state
        # generated via env.reset() above or env.step() below
        x = train_diffuser.reset_x0(x1, conditions, action_dim)

        # make sure the trajectory is a torch tensor for the model forward pass
        x = train_diffuser.to_torch(x)

        # run the diffusion process
        for i in tqdm.tqdm(scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long)

            # 1. generate prediction from model
            with torch.no_grad():
                residual = network(x, timesteps).sample

            # 2. use the model prediction to reconstruct an observation (de-noise)
            obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=predict_epsilon)["prev_sample"]

            # 3. [optional] add posterior noise to the sample
            if eta > 0:
                noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device)
                posterior_variance = scheduler._get_variance(i)
                # no noise when t == 0
                # NOTE: original implementation missing sqrt on posterior_variance
                # (MJ had it as a log variance, exponentiated)
                obs_reconstruct = obs_reconstruct + int(i > 0) * (0.5 * posterior_variance) * eta * noise

            # 4. apply conditions to the trajectory
            obs_reconstruct_postcond = train_diffuser.reset_x0(obs_reconstruct, conditions, action_dim)
            x = train_diffuser.to_torch(obs_reconstruct_postcond)

        plans = train_diffuser.helpers.to_np(x[:, :, :action_dim])

        # select random plan
        idx = np.random.randint(plans.shape[0])

        # Take the first action of that plan and map it back to environment
        # units; the planner works in normalized space.
        action = train_diffuser.de_normalize(plans[idx, 0, :], data, "actions")

        # execute action in environment
        next_observation, reward, terminal, _ = env.step(action)
        done = terminal

        # update return
        total_reward += reward
        print(f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}")

        # save observations for rendering
        rollout.append(next_observation.copy())
        obs = next_observation

        # stepping a finished episode is not meaningful, so stop here
        if done:
            break
except KeyboardInterrupt:
    # allow the rollout to be cut short and still be rendered
    pass

print(f"Total reward: {total_reward}")
render = train_diffuser.MuJoCoRenderer(env)
train_diffuser.show_sample(render, np.expand_dims(np.stack(rollout), axis=0))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@bglick13 to get the PR ready, let's remove things like `hub`. We can leave `*.mp4` if we add a pipeline, but add a comment saying it's for RL so people know.