Add Value Function and corresponding example script to Diffuser implementation #884
.gitignore
@@ -163,4 +163,6 @@ tags
 *.lock

 # DS_Store (MacOS)
 .DS_Store
+*.mp4
+hub/*
New file: script converting the original Diffuser checkpoints to the diffusers format
@@ -0,0 +1,59 @@
Review comment: let's try and consolidate all of these files as much as possible.
import json
import os

import torch

from diffusers import UNet1DModel, ValueFunction

os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)


def unet(hor):
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
    else:
        raise ValueError(f"Unsupported horizon: {hor}")

    # local checkpoint from the original Diffuser repo
    model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
    state_dict = model.state_dict()
    config = dict(
        down_block_types=down_block_types,
        block_out_channels=block_out_channels,
        up_block_types=up_block_types,
        layers_per_block=1,
    )
    hf_unet = UNet1DModel(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of unet dict: {len(hf_unet.state_dict().keys())}")

    # rename keys positionally: this relies on both state dicts enumerating
    # their parameters in the same order
    mapping = dict(zip(model.state_dict().keys(), hf_unet.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)
    hf_unet.load_state_dict(state_dict)

    torch.save(hf_unet.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
    with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
        json.dump(config, f)


def value_function():
    config = dict(
        in_channels=14,
        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
    )

    # this checkpoint is already a plain state dict, not a full module
    state_dict = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
    hf_value_function = ValueFunction(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")

    mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)
    hf_value_function.load_state_dict(state_dict)

    torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
    with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
        json.dump(config, f)


if __name__ == "__main__":
    unet(32)
    # unet(128)
    value_function()
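Not part of the diff: a minimal sanity check one could run after the conversion, assuming the hub/ layout this script writes (a config.json next to diffusion_pytorch_model.bin). It rebuilds the UNet from the dumped config and reloads the converted weights:

import json
import torch
from diffusers import UNet1DModel

with open("hub/hopper-medium-v2/unet/hor32/config.json") as f:
    config = json.load(f)
model = UNet1DModel(**config)  # rebuild the architecture from the saved config
model.load_state_dict(torch.load("hub/hopper-medium-v2/unet/hor32/diffusion_pytorch_model.bin"))
print(sum(p.numel() for p in model.parameters()), "parameters loaded")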
New file: sampling and rendering helpers for the Diffuser example
@@ -0,0 +1,267 @@
Review comment: I would guess just put these in the training file to minimize how many files things are spread across. Also, huge points for getting rid of as many functions as you can (I did some of that in colab).

import os
import warnings

import numpy as np
import torch

import gym
import mediapy as media
import mujoco_py as mjc
import tqdm


DTYPE = torch.float


def normalize(x_in, data, key):
    # standardize with the per-dimension mean/std of the dataset field
    means = data[key].mean(axis=0)
    stds = data[key].std(axis=0)
    return (x_in - means) / stds


def de_normalize(x_in, data, key):
    # invert normalize()
    means = data[key].mean(axis=0)
    stds = data[key].std(axis=0)
    return x_in * stds + means
||
def to_torch(x_in, dtype=None, device="cuda"): | ||
dtype = dtype or DTYPE | ||
device = device | ||
if type(x_in) is dict: | ||
return {k: to_torch(v, dtype, device) for k, v in x_in.items()} | ||
elif torch.is_tensor(x_in): | ||
return x_in.to(device).type(dtype) | ||
return torch.tensor(x_in, dtype=dtype, device=device) | ||
|
||
|
||
def reset_x0(x_in, cond, act_dim): | ||
for key, val in cond.items(): | ||
x_in[:, key, act_dim:] = val.clone() | ||
return x_in | ||
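The cond dict maps a horizon index to the observation that should be clamped there; the first act_dim channels of each timestep hold the actions and are left untouched. A toy example with made-up sizes:

batch, horizon, obs_dim, act_dim = 2, 4, 3, 2  # hypothetical dimensions
x = torch.zeros(batch, horizon, act_dim + obs_dim)
cond = {0: torch.ones(batch, obs_dim)}  # pin the observation at t=0
x = reset_x0(x, cond, act_dim)          # x[:, 0, 2:] is now all ones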
def run_diffusion(x, scheduler, generator, network, unet, conditions, action_dim, config):
    y = None
    for i in tqdm.tqdm(scheduler.timesteps):
        # 1. create a batch of timesteps to pass into the model
        timesteps = torch.full((config["n_samples"],), i, device=config["device"], dtype=torch.long)

        # 2. nudge the sample along the value-function gradient
        for _ in range(config["n_guide_steps"]):
            with torch.enable_grad():
                x.requires_grad_()
                y = network(x, timesteps).sample
                grad = torch.autograd.grad([y.sum()], [x])[0]
            if config["scale_grad_by_std"]:
                posterior_variance = scheduler._get_variance(i)
                model_std = torch.exp(0.5 * posterior_variance)
                grad = model_std * grad
            grad[timesteps < config["t_grad_cutoff"]] = 0
            x = x.detach()
            x = x + config["scale"] * grad
            x = reset_x0(x, conditions, action_dim)

        # 3. denoise one step with the unet
        prev_x = unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
        x = scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

        # 4. [optional] add posterior noise to the sample
        if config["eta"] > 0:
            noise = torch.randn(x.shape).to(x.device)
            posterior_variance = scheduler._get_variance(i)
            # no noise when t == 0
            # NOTE: original implementation missing sqrt on posterior_variance
            x = x + int(i > 0) * (0.5 * posterior_variance) * config["eta"] * noise  # MJ had as log var, exponentiated

        # 5. apply conditions to the trajectory
        x = reset_x0(x, conditions, action_dim)
        x = to_torch(x, device=config["device"])
    return x, y
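The config dict drives the guidance loop above. A sketch of plausible values follows; these are illustrative, not the settings used in the PR's experiments:

config = {
    "n_samples": 4,           # trajectories sampled in parallel
    "device": "cuda",
    "n_guide_steps": 2,       # gradient steps on the value per denoising step
    "scale_grad_by_std": True,
    "t_grad_cutoff": 2,       # stop guiding for the last few timesteps
    "scale": 0.1,             # guidance step size
    "eta": 0.0,               # 0 disables the extra posterior noise
}
# trajectories, values = run_diffusion(x, scheduler, generator, network, unet, conditions, action_dim, config)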
def to_np(x_in):
    if torch.is_tensor(x_in):
        x_in = x_in.detach().cpu().numpy()
    return x_in


# from MJ's Diffuser code
# https://github.com/jannerm/diffuser/blob/76ae49ae85ba1c833bf78438faffdc63b8b4d55d/diffuser/utils/colab.py#L79
def mkdir(savepath):
    """
    returns `True` iff `savepath` is created
    """
    if not os.path.exists(savepath):
        os.makedirs(savepath)
        return True
    else:
        return False
def show_sample(renderer, observations, filename="sample.mp4", savebase="videos"):
    """
    observations : [ batch_size x horizon x observation_dim ]
    """
    mkdir(savebase)
    savepath = os.path.join(savebase, filename)

    images = []
    for rollout in observations:
        # [ horizon x height x width x channels ]
        img = renderer._renders(rollout, partial=True)
        images.append(img)

    # [ horizon x height x (batch_size * width) x channels ]
    images = np.concatenate(images, axis=2)
    media.write_video(savepath, images, fps=60)
    media.show_video(images, codec="h264", fps=60)
    return images


# Code adapted from Michael Janner
# source: https://github.com/jannerm/diffuser/blob/main/diffuser/utils/rendering.py
def env_map(env_name):
    """
    map D4RL dataset names to custom fully-observed
    variants for rendering
    """
    if "halfcheetah" in env_name:
        return "HalfCheetahFullObs-v2"
    elif "hopper" in env_name:
        return "HopperFullObs-v2"
    elif "walker2d" in env_name:
        return "Walker2dFullObs-v2"
    else:
        return env_name


def get_image_mask(img):
    # foreground = any pixel that is not pure white
    background = (img == 255).all(axis=-1, keepdims=True)
    mask = ~background.repeat(3, axis=-1)
    return mask


def atmost_2d(x):
    while x.ndim > 2:
        x = x.squeeze(0)
    return x


def set_state(env, state):
    qpos_dim = env.sim.data.qpos.size
    qvel_dim = env.sim.data.qvel.size
    if state.size != qpos_dim + qvel_dim:
        warnings.warn(
            f"[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, but got state of size {state.size}"
        )
        state = state[: qpos_dim + qvel_dim]

    env.set_state(state[:qpos_dim], state[qpos_dim:])
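get_image_mask assumes renders have a pure-white background, which is what the compositing in renders() below relies on. A tiny self-contained check:

img = np.full((4, 4, 3), 255, dtype=np.uint8)  # all-white frame
img[1, 1] = (255, 0, 0)                        # one foreground pixel
mask = get_image_mask(img)
assert mask[1, 1].all() and not mask[0, 0].any()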
class MuJoCoRenderer:
    """
    default mujoco renderer
    """

    def __init__(self, env):
        if isinstance(env, str):
            env = env_map(env)
            self.env = gym.make(env)
        else:
            self.env = env
        # - 1 because the envs in renderer are fully-observed
        # @TODO : clean up
        self.observation_dim = np.prod(self.env.observation_space.shape) - 1
        self.action_dim = np.prod(self.env.action_space.shape)
        try:
            self.viewer = mjc.MjRenderContextOffscreen(self.env.sim)
        except Exception:
            print("[ utils/rendering ] Warning: could not initialize offscreen renderer")
            self.viewer = None

    def pad_observation(self, observation):
        state = np.concatenate(
            [
                np.zeros(1),
                observation,
            ]
        )
        return state

    def pad_observations(self, observations):
        qpos_dim = self.env.sim.data.qpos.size
        # xpos is hidden: integrate the x velocity to recover it
        xvel_dim = qpos_dim - 1
        xvel = observations[:, xvel_dim]
        xpos = np.cumsum(xvel) * self.env.dt
        states = np.concatenate(
            [
                xpos[:, None],
                observations,
            ],
            axis=-1,
        )
        return states

    def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None):
        if isinstance(dim, int):
            dim = (dim, dim)

        if self.viewer is None:
            return np.zeros((*dim, 3), np.uint8)

        if render_kwargs is None:
            xpos = observation[0] if not partial else 0
            render_kwargs = {"trackbodyid": 2, "distance": 3, "lookat": [xpos, -0.5, 1], "elevation": -20}

        for key, val in render_kwargs.items():
            if key == "lookat":
                self.viewer.cam.lookat[:] = val[:]
            else:
                setattr(self.viewer.cam, key, val)

        if partial:
            state = self.pad_observation(observation)
        else:
            state = observation

        qpos_dim = self.env.sim.data.qpos.size
        if not qvel or state.shape[-1] == qpos_dim:
            qvel_dim = self.env.sim.data.qvel.size
            state = np.concatenate([state, np.zeros(qvel_dim)])

        set_state(self.env, state)

        self.viewer.render(*dim)
        data = self.viewer.read_pixels(*dim, depth=False)
        data = data[::-1, :, :]
        return data

    def _renders(self, observations, **kwargs):
        images = []
        for observation in observations:
            img = self.render(observation, **kwargs)
            images.append(img)
        return np.stack(images, axis=0)

    def renders(self, samples, partial=False, **kwargs):
        if partial:
            samples = self.pad_observations(samples)
            partial = False

        sample_images = self._renders(samples, partial=partial, **kwargs)

        # overlay all frames on a white canvas
        composite = np.ones_like(sample_images[0]) * 255

        for img in sample_images:
            mask = get_image_mask(img)
            composite[mask] = img[mask]

        return composite

    def __call__(self, *args, **kwargs):
        return self.renders(*args, **kwargs)
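A hedged usage sketch for the renderer plus show_sample, assuming a working mujoco-py install and the fully-observed hopper registration (env_map resolves "hopper-medium-v2" to "HopperFullObs-v2"):

renderer = MuJoCoRenderer("hopper-medium-v2")
# observations: [batch_size x horizon x observation_dim], e.g. de-normalized samples
# show_sample(renderer, observations)  # writes videos/sample.mp4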
Review comment: @bglick13 to get the PR ready, let's remove things like hub. We can leave *.mp4 if we add a pipeline, but add a comment saying it's for RL so people know.