Commit 7c5fef8

Nathan Lambert and bglick13 authored
Add UNet 1d for RL model for planning + colab (#105)
* re-add RL model code
* match model forward api
* add register_to_config, pass training tests
* fix tests, update forward outputs
* remove unused code, some comments
* add to docs
* remove extra embedding code
* unify time embedding
* remove conv1d output sequential
* remove sequential from conv1dblock
* style and deleting duplicated code
* clean files
* remove unused variables
* clean variables
* add 1d resnet block structure for downsample
* rename as unet1d
* fix renaming
* rename files
* add get_block(...) api
* unify args for model1d like model2d
* minor cleaning
* fix docs
* improve 1d resnet blocks
* fix tests, remove permuts
* fix style
* add output activation
* rename flax blocks file
* Add Value Function and corresponding example script to Diffuser implementation (#884)
* valuefunction code
* start example scripts
* missing imports
* bug fixes and placeholder example script
* add value function scheduler
* load value function from hub and get best actions in example
* very close to working example
* larger batch size for planning
* more tests
* merge unet1d changes
* wandb for debugging, use newer models
* success!
* turns out we just need more diffusion steps
* run on modal
* merge and code cleanup
* use same api for rl model
* fix variance type
* wrong normalization function
* add tests
* style
* style and quality
* edits based on comments
* style and quality
* remove unused var
* hack unet1d into a value function
* add pipeline
* fix arg order
* add pipeline to core library
* community pipeline
* fix couple shape bugs
* style
* Apply suggestions from code review

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* update post merge of scripts
* add mdiblock / outblock architecture
* Pipeline cleanup (#947)
* valuefunction code
* start example scripts
* missing imports
* bug fixes and placeholder example script
* add value function scheduler
* load value function from hub and get best actions in example
* very close to working example
* larger batch size for planning
* more tests
* merge unet1d changes
* wandb for debugging, use newer models
* success!
* turns out we just need more diffusion steps
* run on modal
* merge and code cleanup
* use same api for rl model
* fix variance type
* wrong normalization function
* add tests
* style
* style and quality
* edits based on comments
* style and quality
* remove unused var
* hack unet1d into a value function
* add pipeline
* fix arg order
* add pipeline to core library
* community pipeline
* fix couple shape bugs
* style
* Apply suggestions from code review
* clean up comments
* convert older script to using pipeline and add readme
* rename scripts
* style, update tests
* delete unet rl model file
* remove imports in src

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* Update src/diffusers/models/unet_1d_blocks.py
* Update tests/test_models_unet.py
* RL Cleanup v2 (#965)
* valuefunction code
* start example scripts
* missing imports
* bug fixes and placeholder example script
* add value function scheduler
* load value function from hub and get best actions in example
* very close to working example
* larger batch size for planning
* more tests
* merge unet1d changes
* wandb for debugging, use newer models
* success!
* turns out we just need more diffusion steps
* run on modal
* merge and code cleanup
* use same api for rl model
* fix variance type
* wrong normalization function
* add tests
* style
* style and quality
* edits based on comments
* style and quality
* remove unused var
* hack unet1d into a value function
* add pipeline
* fix arg order
* add pipeline to core library
* community pipeline
* fix couple shape bugs
* style
* Apply suggestions from code review
* clean up comments
* convert older script to using pipeline and add readme
* rename scripts
* style, update tests
* delete unet rl model file
* remove imports in src
* add specific vf block and update tests
* style
* Update tests/test_models_unet.py

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* fix quality in tests
* fix quality style, split test file
* fix checks / tests
* make timesteps closer to main
* unify block API
* unify forward api
* delete lines in examples
* style
* examples style
* all tests pass
* make style
* make dance_diff test pass
* Refactoring RL PR (#1200)
* init file changes
* add import utils
* finish cleaning files, imports
* remove import flags
* clean examples
* fix imports, tests for merge
* update readmes
* hotfix for tests
* quality
* fix some tests
* change defaults
* more mps test fixes
* unet1d defaults
* do not default import experimental
* defaults for tests
* fix tests
* fix-copies
* fix
* changes per Patrik's comments (#1285)
* changes per Patrik's comments
* update conversion script
* fix renaming
* skip more mps tests
* last test fix
* Update examples/rl/README.md

Co-authored-by: Ben Glickenhaus <benglickenhaus@gmail.com>
1 parent a8d0977 commit 7c5fef8

18 files changed (+1176 −65 lines)

.gitignore (+3 −1)

```diff
@@ -163,4 +163,6 @@ tags
 *.lock
 
 # DS_Store (MacOS)
-.DS_Store
+.DS_Store
+# RL pipelines may produce mp4 outputs
+*.mp4
```

docs/source/api/models.mdx (+6 −3)

```diff
@@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## UNet2DOutput
 [[autodoc]] models.unet_2d.UNet2DOutput
 
-## UNet1DModel
-[[autodoc]] UNet1DModel
-
 ## UNet2DModel
 [[autodoc]] UNet2DModel
 
+## UNet1DOutput
+[[autodoc]] models.unet_1d.UNet1DOutput
+
+## UNet1DModel
+[[autodoc]] UNet1DModel
+
 ## UNet2DConditionOutput
 [[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
```

examples/README.md (+1 −1)

```diff
@@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
 | [**Text-to-Image fine-tuning**](./text_to_image) |||
 | [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
 | [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
-
+| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon.
 
 ## Community
```

examples/rl/README.md (new file, +19)

````md
# Overview

These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
There are two scripts:
1. `run_diffuser_locomotion.py` to sample actions and run them in the environment,
2. `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model.

You will need some RL-specific requirements to run the examples:

```
pip install -f https://download.pytorch.org/whl/torch_stable.html \
free-mujoco-py \
einops \
gym==0.24.1 \
protobuf==3.20.1 \
git+https://github.com/rail-berkeley/d4rl.git \
mediapy \
Pillow==9.0.0
```
````
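Both scripts share the same skeleton; as a quick orientation, here is a minimal sketch of that loop, condensed from the two files below (the model id, pipeline API, and planning horizon are the ones this commit uses):

```python
# registers the offline MuJoCo envs with gym
import d4rl  # noqa
import gym

from diffusers.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")
env.seed(0)

# value-function-guided planner from the Hub, wrapped around the env
pipeline = ValueGuidedRLPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32", env=env
)

obs = env.reset()
for _ in range(10):  # the full scripts roll out 1000 steps
    # sample an action from the planned trajectory
    denorm_actions = pipeline(obs, planning_horizon=32)
    obs, reward, terminal, _ = env.step(denorm_actions)
    if terminal:
        break
```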
examples/rl/run_diffuser_gen_trajectories.py (new file, +57)

```python
import d4rl  # noqa
import gym
import tqdm
from diffusers.experimental import ValueGuidedRLPipeline


config = dict(
    n_samples=64,
    horizon=32,
    num_inference_steps=20,
    n_guide_steps=0,
    scale_grad_by_std=True,
    scale=0.1,
    eta=0.0,
    t_grad_cutoff=2,
    device="cpu",
)


if __name__ == "__main__":
    env_name = "hopper-medium-v2"
    env = gym.make(env_name)

    pipeline = ValueGuidedRLPipeline.from_pretrained(
        "bglick13/hopper-medium-v2-value-function-hor32",
        env=env,
    )

    env.seed(0)
    obs = env.reset()
    total_reward = 0
    total_score = 0
    T = 1000
    rollout = [obs.copy()]
    try:
        for t in tqdm.tqdm(range(T)):
            # call the policy
            denorm_actions = pipeline(obs, planning_horizon=32)

            # execute action in environment
            next_observation, reward, terminal, _ = env.step(denorm_actions)
            score = env.get_normalized_score(total_reward)
            # update return
            total_reward += reward
            total_score += score
            print(
                f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
                f" {total_score}"
            )
            # save observations for rendering
            rollout.append(next_observation.copy())

            obs = next_observation
    except KeyboardInterrupt:
        pass

    print(f"Total reward: {total_reward}")
```
examples/rl/run_diffuser_locomotion.py (new file, +57)

```python
import d4rl  # noqa
import gym
import tqdm
from diffusers.experimental import ValueGuidedRLPipeline


config = dict(
    n_samples=64,
    horizon=32,
    num_inference_steps=20,
    n_guide_steps=2,
    scale_grad_by_std=True,
    scale=0.1,
    eta=0.0,
    t_grad_cutoff=2,
    device="cpu",
)


if __name__ == "__main__":
    env_name = "hopper-medium-v2"
    env = gym.make(env_name)

    pipeline = ValueGuidedRLPipeline.from_pretrained(
        "bglick13/hopper-medium-v2-value-function-hor32",
        env=env,
    )

    env.seed(0)
    obs = env.reset()
    total_reward = 0
    total_score = 0
    T = 1000
    rollout = [obs.copy()]
    try:
        for t in tqdm.tqdm(range(T)):
            # call the policy
            denorm_actions = pipeline(obs, planning_horizon=32)

            # execute action in environment
            next_observation, reward, terminal, _ = env.step(denorm_actions)
            score = env.get_normalized_score(total_reward)
            # update return
            total_reward += reward
            total_score += score
            print(
                f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
                f" {total_score}"
            )
            # save observations for rendering
            rollout.append(next_observation.copy())

            obs = next_observation
    except KeyboardInterrupt:
        pass

    print(f"Total reward: {total_reward}")
```
scripts/convert_models_diffuser_to_diffusers.py (new file, +100)

```python
import json
import os

import torch

from diffusers import UNet1DModel


os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)

os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)


def unet(hor):
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")

    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
    model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
    state_dict = model.state_dict()
    config = dict(
        down_block_types=down_block_types,
        block_out_channels=block_out_channels,
        up_block_types=up_block_types,
        layers_per_block=1,
        use_timestep_embedding=True,
        out_block_type="OutConv1DBlock",
        norm_num_groups=8,
        downsample_each_block=False,
        in_channels=14,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        flip_sin_to_cos=False,
        freq_shift=1,
        sample_size=65536,
        mid_block_type="MidResTemporalBlock1D",
        act_fn="mish",
    )
    hf_value_function = UNet1DModel(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
    mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)
    hf_value_function.load_state_dict(state_dict)

    torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
    with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
        json.dump(config, f)


def value_function():
    config = dict(
        in_channels=14,
        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        up_block_types=(),
        out_block_type="ValueFunction",
        mid_block_type="ValueFunctionMidBlock1D",
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
        downsample_each_block=True,
        sample_size=65536,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        use_timestep_embedding=True,
        flip_sin_to_cos=False,
        freq_shift=1,
        norm_num_groups=8,
        act_fn="mish",
    )

    model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
    state_dict = model
    hf_value_function = UNet1DModel(**config)
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")

    mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys()))
    for k, v in mapping.items():
        state_dict[v] = state_dict.pop(k)

    hf_value_function.load_state_dict(state_dict)

    torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
    with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
        json.dump(config, f)


if __name__ == "__main__":
    unet(32)
    # unet(128)
    value_function()
```
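One caveat worth flagging: `mapping` pairs the original and `UNet1DModel` parameter names purely by iteration order, so any divergence in module ordering would silently load weights into the wrong tensors. A defensive variant, a sketch rather than part of this commit, checks tensor shapes before renaming:

```python
def remap_state_dict(src_state_dict, dst_model):
    """Rename src_state_dict keys to dst_model's keys by position,
    asserting that each paired tensor has the same shape."""
    dst_state_dict = dst_model.state_dict()
    src_keys, dst_keys = list(src_state_dict.keys()), list(dst_state_dict.keys())
    assert len(src_keys) == len(dst_keys), "parameter counts differ"
    remapped = {}
    for src_k, dst_k in zip(src_keys, dst_keys):
        assert src_state_dict[src_k].shape == dst_state_dict[dst_k].shape, (
            f"shape mismatch mapping {src_k} -> {dst_k}"
        )
        remapped[dst_k] = src_state_dict[src_k]
    return remapped


# usage inside value_function(), replacing the manual mapping loop:
# hf_value_function.load_state_dict(remap_state_dict(state_dict, hf_value_function))
```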

src/diffusers/experimental/README.md (new file, +5)

```md
# 🧨 Diffusers Experimental

We are adding experimental code to support novel applications and usages of the Diffusers library.
Currently, the following experiments are supported:
* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
```
src/diffusers/experimental/__init__.py (new file, +1)

```python
from .rl import ValueGuidedRLPipeline
```
src/diffusers/experimental/rl/__init__.py (new file, +1)

```python
from .value_guided_sampling import ValueGuidedRLPipeline
```
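Together these two one-line modules wire the pipeline into the experimental namespace, so user code can import it either way:

```python
# both imports resolve to the same class after this commit
from diffusers.experimental import ValueGuidedRLPipeline
from diffusers.experimental.rl import ValueGuidedRLPipeline  # noqa: F811
```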
