Add Value Function and corresponding example script to Diffuser implementation #884


Merged
merged 36 commits on Oct 21, 2022
Changes from 24 commits

36 commits
f29ace4
valuefunction code
bglick13 Oct 8, 2022
1684e8b
start example scripts
bglick13 Oct 8, 2022
c757985
missing imports
bglick13 Oct 8, 2022
b315918
bug fixes and placeholder example script
bglick13 Oct 8, 2022
f01c014
add value function scheduler
bglick13 Oct 9, 2022
7b60c93
load value function from hub and get best actions in example
bglick13 Oct 9, 2022
0de435e
very close to working example
bglick13 Oct 10, 2022
a396529
larger batch size for planning
bglick13 Oct 10, 2022
713bd80
more tests
bglick13 Oct 11, 2022
e3fb50f
Merge branch 'main' into rl
bglick13 Oct 11, 2022
686069f
Merge branch 'hf_rl' into rl
bglick13 Oct 11, 2022
d9384ff
merge unet1d changes
bglick13 Oct 11, 2022
52e2668
wandb for debugging, use newer models
bglick13 Oct 11, 2022
75fe8b4
success!
bglick13 Oct 11, 2022
c7fe1dc
turns out we just need more diffusion steps
bglick13 Oct 12, 2022
a6871b1
run on modal
bglick13 Oct 12, 2022
13a443c
Merge branch 'hf_rl' into rl
bglick13 Oct 12, 2022
38616cf
merge and code cleanup
bglick13 Oct 12, 2022
d37b472
use same api for rl model
bglick13 Oct 12, 2022
aa19286
fix variance type
bglick13 Oct 13, 2022
02293e2
wrong normalization function
bglick13 Oct 13, 2022
56818e5
add tests
bglick13 Oct 17, 2022
d085725
style
bglick13 Oct 17, 2022
93fe3ef
style and quality
bglick13 Oct 17, 2022
4e378e9
edits based on comments
bglick13 Oct 18, 2022
e7e6963
style and quality
bglick13 Oct 18, 2022
4f77d89
remove unused var
bglick13 Oct 19, 2022
5de8a6a
Merge branch 'hf_rl' into rl
bglick13 Oct 19, 2022
6bd8397
hack unet1d into a value function
bglick13 Oct 20, 2022
435ad26
add pipeline
bglick13 Oct 20, 2022
5653408
fix arg order
bglick13 Oct 20, 2022
1491932
add pipeline to core library
bglick13 Oct 20, 2022
1a8098e
community pipeline
bglick13 Oct 20, 2022
0e4be75
fix couple shape bugs
bglick13 Oct 21, 2022
5ef88ef
style
bglick13 Oct 21, 2022
c6d94ce
Apply suggestions from code review
Oct 21, 2022
4 changes: 3 additions & 1 deletion .gitignore
@@ -163,4 +163,6 @@ tags
*.lock

# DS_Store (MacOS)
.DS_Store
.DS_Store
Contributor

@bglick13 to get the PR ready, let's remove things like hub/.
We can leave *.mp4 if we add a pipeline, but add a comment saying it's for RL so people know.

*.mp4
hub/*
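
A sketch of what the reviewer is asking for here; the comment wording is an assumption:

# videos written by the RL (diffuser) example pipeline
*.mp4
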
59 changes: 59 additions & 0 deletions convert_model.py
@@ -0,0 +1,59 @@

Contributor

Let's try and consolidate all of these files as much as possible.
Put the convert-model script into scripts/ with a longer, more descriptive path. (I can do the same for my RL models.)

import json
import os

import torch

from diffusers import UNet1DModel, ValueFunction

os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)


def unet(hor):
    # pick the block layout that matches the pretrained checkpoint's horizon
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
    else:
        raise ValueError(f"unsupported horizon: {hor}")

    model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
    state_dict = model.state_dict()
    config = dict(
        down_block_types=down_block_types,
        block_out_channels=block_out_channels,
        up_block_types=up_block_types,
        layers_per_block=1,
    )
    hf_unet = UNet1DModel(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of unet dict: {len(hf_unet.state_dict().keys())}")

    # rename original parameters to the diffusers names; assumes both state
    # dicts enumerate their parameters in the same order
    mapping = dict(zip(state_dict.keys(), hf_unet.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)
    hf_unet.load_state_dict(state_dict)

    torch.save(hf_unet.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
    with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
        json.dump(config, f)


def value_function():
    config = dict(
        in_channels=14,
        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
    )

    # the value-function checkpoint is saved as a raw state dict, not a module
    state_dict = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
    hf_value_function = ValueFunction(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")

    mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)
    hf_value_function.load_state_dict(state_dict)

    torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
    with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
        json.dump(config, f)


if __name__ == "__main__":
    unet(32)
    # unet(128)
    value_function()
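
Once converted, the weights can be loaded back through the standard diffusers from_pretrained API; a minimal sketch, assuming the local hub/ layout written by the script above:

from diffusers import UNet1DModel, ValueFunction

# load the converted planner and value function from the directories the
# script writes (each contains config.json + diffusion_pytorch_model.bin)
unet = UNet1DModel.from_pretrained("hub/hopper-medium-v2/unet/hor32")
value_function = ValueFunction.from_pretrained("hub/hopper-medium-v2/value_function")
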
267 changes: 267 additions & 0 deletions examples/diffuser/helpers.py
@@ -0,0 +1,267 @@
import os
Contributor

I would suggest just putting these in the training file to minimize how many files things are spread across. Also, huge points for getting rid of as many functions as you can (I did some of that in Colab).

import warnings

import numpy as np
import torch

import gym
import mediapy as media
import mujoco_py as mjc
import tqdm


DTYPE = torch.float


def normalize(x_in, data, key):
    # standardize x_in with the per-dimension mean/std of data[key]
    means = data[key].mean(axis=0)
    stds = data[key].std(axis=0)
    return (x_in - means) / stds


def de_normalize(x_in, data, key):
    # invert normalize(): map standardized values back to the original scale
    means = data[key].mean(axis=0)
    stds = data[key].std(axis=0)
    return x_in * stds + means
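
For context, a minimal usage sketch; the data dict layout is an assumption based on the D4RL dataset format:

# hypothetical stand-in for a d4rl dataset dict, e.g. env.get_dataset()
data = {"observations": np.random.randn(1000, 11)}
obs = data["observations"][0]
obs_normed = normalize(obs, data, "observations")  # zero mean, unit variance per dim
obs_back = de_normalize(obs_normed, data, "observations")  # recovers obs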


def to_torch(x_in, dtype=None, device="cuda"):
    dtype = dtype or DTYPE
    if type(x_in) is dict:
        return {k: to_torch(v, dtype, device) for k, v in x_in.items()}
    elif torch.is_tensor(x_in):
        return x_in.to(device).type(dtype)
    return torch.tensor(x_in, dtype=dtype, device=device)


def reset_x0(x_in, cond, act_dim):
    # write the conditioned observations back into the trajectory at their timesteps
    for key, val in cond.items():
        x_in[:, key, act_dim:] = val.clone()
    return x_in


def run_diffusion(x, scheduler, generator, network, unet, conditions, action_dim, config):
    # note: `generator` is currently unused
    y = None
    for i in tqdm.tqdm(scheduler.timesteps):
        # 1. create a batch of timesteps to pass into the model
        timesteps = torch.full((config["n_samples"],), i, device=config["device"], dtype=torch.long)

        # 2. guide the sample: ascend the value network's gradient for n_guide_steps
        for _ in range(config["n_guide_steps"]):
            with torch.enable_grad():
                x.requires_grad_()
                y = network(x, timesteps).sample
                grad = torch.autograd.grad([y.sum()], [x])[0]
            if config["scale_grad_by_std"]:
                posterior_variance = scheduler._get_variance(i)
                model_std = torch.exp(0.5 * posterior_variance)
                grad = model_std * grad
            # stop guiding near the end of the reverse process
            grad[timesteps < config["t_grad_cutoff"]] = 0
            x = x.detach()
            x = x + config["scale"] * grad
            x = reset_x0(x, conditions, action_dim)

        # 3. predict the previous (less noisy) sample with the diffusion unet
        prev_x = unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
        x = scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

        # 4. [optional] add posterior noise to the sample
        if config["eta"] > 0:
            noise = torch.randn(x.shape).to(x.device)
            posterior_variance = scheduler._get_variance(i)
            # no noise when t == 0
            # NOTE: original implementation missing sqrt on posterior_variance
            x = x + int(i > 0) * (0.5 * posterior_variance) * config["eta"] * noise  # MJ had as log var, exponentiated

        # 5. apply conditions to the trajectory
        x = reset_x0(x, conditions, action_dim)
        x = to_torch(x, device=config["device"])
    return x, y
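
For reference, a sketch of how run_diffusion might be driven; every config value below is an illustrative assumption, not a setting taken from this PR:

config = dict(
    n_samples=64,            # number of candidate trajectories
    device="cuda",
    n_guide_steps=2,         # value-gradient ascent steps per diffusion step
    scale_grad_by_std=True,  # rescale guidance by the posterior std
    t_grad_cutoff=2,         # skip guidance for timesteps below this
    scale=0.1,               # guidance step size
    eta=0.0,                 # 0.0 disables the optional posterior noise
)
# x starts as noise shaped [n_samples, horizon, action_dim + observation_dim];
# `network` is the value function, `unet` the diffusion model:
# trajectories, values = run_diffusion(x, scheduler, None, network, unet, conditions, action_dim, config)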


def to_np(x_in):
    if torch.is_tensor(x_in):
        x_in = x_in.detach().cpu().numpy()
    return x_in


# from MJ's Diffuser code
# https://github.com/jannerm/diffuser/blob/76ae49ae85ba1c833bf78438faffdc63b8b4d55d/diffuser/utils/colab.py#L79
def mkdir(savepath):
    """
    returns `True` iff `savepath` is created
    """
    if not os.path.exists(savepath):
        os.makedirs(savepath)
        return True
    else:
        return False


def show_sample(renderer, observations, filename="sample.mp4", savebase="videos"):
    """
    observations : [ batch_size x horizon x observation_dim ]
    """

    mkdir(savebase)
    savepath = os.path.join(savebase, filename)

    images = []
    for rollout in observations:
        # [ horizon x height x width x channels ]
        img = renderer._renders(rollout, partial=True)
        images.append(img)

    # [ horizon x height x (batch_size * width) x channels ]
    images = np.concatenate(images, axis=2)
    media.write_video(savepath, images, fps=60)
    media.show_video(images, codec="h264", fps=60)
    return images


# Code adapted from Michael Janner
# source: https://github.com/jannerm/diffuser/blob/main/diffuser/utils/rendering.py


def env_map(env_name):
    """
    map D4RL dataset names to custom fully-observed
    variants for rendering
    """
    if "halfcheetah" in env_name:
        return "HalfCheetahFullObs-v2"
    elif "hopper" in env_name:
        return "HopperFullObs-v2"
    elif "walker2d" in env_name:
        return "Walker2dFullObs-v2"
    else:
        return env_name


def get_image_mask(img):
    background = (img == 255).all(axis=-1, keepdims=True)
    mask = ~background.repeat(3, axis=-1)
    return mask


def atmost_2d(x):
    while x.ndim > 2:
        x = x.squeeze(0)
    return x


def set_state(env, state):
    qpos_dim = env.sim.data.qpos.size
    qvel_dim = env.sim.data.qvel.size
    if not state.size == qpos_dim + qvel_dim:
        warnings.warn(
            f"[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, but got state of size {state.size}"
        )
        state = state[: qpos_dim + qvel_dim]

    env.set_state(state[:qpos_dim], state[qpos_dim:])


class MuJoCoRenderer:
    """
    default mujoco renderer
    """

    def __init__(self, env):
        if type(env) is str:
            env = env_map(env)
            self.env = gym.make(env)
        else:
            self.env = env
        # - 1 because the envs in renderer are fully-observed
        # @TODO : clean up
        self.observation_dim = np.prod(self.env.observation_space.shape) - 1
        self.action_dim = np.prod(self.env.action_space.shape)
        try:
            self.viewer = mjc.MjRenderContextOffscreen(self.env.sim)
        except Exception:
            print("[ utils/rendering ] Warning: could not initialize offscreen renderer")
            self.viewer = None

    def pad_observation(self, observation):
        state = np.concatenate(
            [
                np.zeros(1),
                observation,
            ]
        )
        return state

    def pad_observations(self, observations):
        qpos_dim = self.env.sim.data.qpos.size
        # xpos is hidden
        xvel_dim = qpos_dim - 1
        xvel = observations[:, xvel_dim]
        xpos = np.cumsum(xvel) * self.env.dt
        states = np.concatenate(
            [
                xpos[:, None],
                observations,
            ],
            axis=-1,
        )
        return states

    def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None):
        if isinstance(dim, int):
            dim = (dim, dim)

        if self.viewer is None:
            return np.zeros((*dim, 3), np.uint8)

        if render_kwargs is None:
            xpos = observation[0] if not partial else 0
            render_kwargs = {"trackbodyid": 2, "distance": 3, "lookat": [xpos, -0.5, 1], "elevation": -20}

        for key, val in render_kwargs.items():
            if key == "lookat":
                self.viewer.cam.lookat[:] = val[:]
            else:
                setattr(self.viewer.cam, key, val)

        if partial:
            state = self.pad_observation(observation)
        else:
            state = observation

        qpos_dim = self.env.sim.data.qpos.size
        if not qvel or state.shape[-1] == qpos_dim:
            qvel_dim = self.env.sim.data.qvel.size
            state = np.concatenate([state, np.zeros(qvel_dim)])

        set_state(self.env, state)

        self.viewer.render(*dim)
        data = self.viewer.read_pixels(*dim, depth=False)
        data = data[::-1, :, :]
        return data

    def _renders(self, observations, **kwargs):
        images = []
        for observation in observations:
            img = self.render(observation, **kwargs)
            images.append(img)
        return np.stack(images, axis=0)

    def renders(self, samples, partial=False, **kwargs):
        if partial:
            samples = self.pad_observations(samples)
            partial = False

        sample_images = self._renders(samples, partial=partial, **kwargs)

        # start from a white canvas and composite each frame's foreground onto it
        composite = np.ones_like(sample_images[0]) * 255

        for img in sample_images:
            mask = get_image_mask(img)
            composite[mask] = img[mask]

        return composite

    def __call__(self, *args, **kwargs):
        return self.renders(*args, **kwargs)
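
Finally, a minimal usage sketch tying the renderer to show_sample; the env name relies on env_map above and assumes the fully-observed variants are registered (as in the diffuser repo):

renderer = MuJoCoRenderer("hopper-medium-v2")  # mapped to HopperFullObs-v2 by env_map
# observations: [batch_size x horizon x observation_dim], de-normalized
# show_sample(renderer, observations, filename="sample.mp4")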