-
Notifications
You must be signed in to change notification settings - Fork 6k
Add UNet 1d for RL model for planning + colab #105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 68 commits
Commits
Show all changes
69 commits
Select commit
Hold shift + click to select a range
8d1a17c
re-add RL model code
natolambert 84e94d7
match model forward api
natolambert f67b036
add register_to_config, pass training tests
natolambert e42d1c0
fix tests, update forward outputs
natolambert 2dd514e
remove unused code, some comments
natolambert b4c6188
add to docs
natolambert c53bba9
remove extra embedding code
natolambert effcbdb
unify time embedding
natolambert 7865231
remove conv1d output sequential
natolambert 35b0a43
remove sequential from conv1dblock
natolambert 9b1379d
style and deleting duplicated code
natolambert e97a610
clean files
natolambert 8642560
remove unused variables
natolambert f58c915
clean variables
natolambert ad8376d
Merge branch 'main' into rl
natolambert 3b08bea
add 1d resnet block structure for downsample
natolambert aae2a9a
rename as unet1d
natolambert dd872af
fix renaming
natolambert 9b67bb7
rename files
natolambert db012eb
add get_block(...) api
natolambert 4db6e0b
unify args for model1d like model2d
natolambert 634a526
minor cleaning
natolambert aebf547
fix docs
natolambert 305ecd8
improve 1d resnet blocks
natolambert 42855b9
Merge branch 'main' into rl
natolambert 95d3a1c
fix tests, remove permuts
natolambert 6cbb73b
fix style
natolambert ffb7355
add output activation
natolambert a6314f6
rename flax blocks file
natolambert 48a7414
Add Value Function and corresponding example script to Diffuser imple…
bglick13 3acddb5
update post merge of scripts
natolambert 713e8f2
add mdiblock / outblock architecture
natolambert 268ebdf
Pipeline cleanup (#947)
bglick13 daa05fb
Update src/diffusers/models/unet_1d_blocks.py
ea5f231
Update tests/test_models_unet.py
4f7a3a4
RL Cleanup v2 (#965)
bglick13 d90b8b1
fix quality in tests
natolambert ad8b6cf
fix quality style, split test file
natolambert e06a4a4
Merge branch 'main' into rl
natolambert 99b2c81
fix checks / tests
natolambert de4b6e4
make timesteps closer to main
natolambert ef6ca1f
unify block API
natolambert 6e3485c
Merge branch 'main' into rl
natolambert e6f1a83
unify forward api
natolambert c35a925
delete lines in examples
natolambert 949b93a
style
natolambert 2f6462b
examples style
natolambert a2dd559
all tests pass
natolambert 39dff73
make style
natolambert d5eedff
make dance_diff test pass
natolambert faeacd5
Refactoring RL PR (#1200)
be25030
Merge branch 'main' into rl
natolambert 72b7ee8
hotfix for tests
natolambert cf76a2d
quality
natolambert 2290356
fix some tests
natolambert a061f7e
change defaults
natolambert 0c58758
more mps test fixes
natolambert 691ddee
unet1d defaults
natolambert 4948ca7
do not default import experimental
natolambert ac88677
defaults for tests
natolambert ba204db
fix tests
natolambert 915c41e
fix-copies
natolambert c901889
Merge branch 'main' into rl
natolambert becc803
fix
natolambert 9b8e5ee
changes per Patrik's comments (#1285)
bglick13 3684a8c
fix renaming
natolambert ebdef16
skip more mps tests
natolambert a259aae
last test fix
natolambert 1f7702c
Update examples/rl/README.md
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -163,4 +163,6 @@ tags | |
*.lock | ||
|
||
# DS_Store (MacOS) | ||
.DS_Store | ||
.DS_Store | ||
# RL pipelines may produce mp4 outputs | ||
*.mp4 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Overview | ||
|
||
These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers. | ||
There are four scripts, | ||
1. `run_diffuser_locomotion.py` to sample actions and run them in the environment, | ||
2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model. | ||
|
||
You will need some RL specific requirements to run the examples: | ||
|
||
``` | ||
pip install -f https://download.pytorch.org/whl/torch_stable.html \ | ||
free-mujoco-py \ | ||
einops \ | ||
gym \ | ||
natolambert marked this conversation as resolved.
Show resolved
Hide resolved
|
||
protobuf==3.20.1 \ | ||
git+https://github.com/rail-berkeley/d4rl.git \ | ||
mediapy \ | ||
Pillow==9.0.0 | ||
``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import d4rl # noqa | ||
import gym | ||
import tqdm | ||
from diffusers.experimental import ValueGuidedRLPipeline | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perfect :-) |
||
|
||
|
||
config = dict( | ||
n_samples=64, | ||
horizon=32, | ||
num_inference_steps=20, | ||
n_guide_steps=0, | ||
scale_grad_by_std=True, | ||
scale=0.1, | ||
eta=0.0, | ||
t_grad_cutoff=2, | ||
device="cpu", | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
env_name = "hopper-medium-v2" | ||
env = gym.make(env_name) | ||
|
||
pipeline = ValueGuidedRLPipeline.from_pretrained( | ||
"bglick13/hopper-medium-v2-value-function-hor32", | ||
env=env, | ||
) | ||
|
||
env.seed(0) | ||
obs = env.reset() | ||
total_reward = 0 | ||
total_score = 0 | ||
T = 1000 | ||
rollout = [obs.copy()] | ||
try: | ||
for t in tqdm.tqdm(range(T)): | ||
# Call the policy | ||
denorm_actions = pipeline(obs, planning_horizon=32) | ||
|
||
# execute action in environment | ||
next_observation, reward, terminal, _ = env.step(denorm_actions) | ||
score = env.get_normalized_score(total_reward) | ||
# update return | ||
total_reward += reward | ||
total_score += score | ||
print( | ||
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" | ||
f" {total_score}" | ||
) | ||
# save observations for rendering | ||
rollout.append(next_observation.copy()) | ||
|
||
obs = next_observation | ||
except KeyboardInterrupt: | ||
pass | ||
|
||
print(f"Total reward: {total_reward}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import d4rl # noqa | ||
import gym | ||
import tqdm | ||
from diffusers.experimental import ValueGuidedRLPipeline | ||
|
||
|
||
config = dict( | ||
n_samples=64, | ||
horizon=32, | ||
num_inference_steps=20, | ||
n_guide_steps=2, | ||
scale_grad_by_std=True, | ||
scale=0.1, | ||
eta=0.0, | ||
t_grad_cutoff=2, | ||
device="cpu", | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
env_name = "hopper-medium-v2" | ||
env = gym.make(env_name) | ||
|
||
pipeline = ValueGuidedRLPipeline.from_pretrained( | ||
"bglick13/hopper-medium-v2-value-function-hor32", | ||
env=env, | ||
) | ||
|
||
env.seed(0) | ||
obs = env.reset() | ||
total_reward = 0 | ||
total_score = 0 | ||
T = 1000 | ||
rollout = [obs.copy()] | ||
try: | ||
for t in tqdm.tqdm(range(T)): | ||
# call the policy | ||
denorm_actions = pipeline(obs, planning_horizon=32) | ||
|
||
# execute action in environment | ||
next_observation, reward, terminal, _ = env.step(denorm_actions) | ||
score = env.get_normalized_score(total_reward) | ||
# update return | ||
total_reward += reward | ||
total_score += score | ||
print( | ||
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" | ||
f" {total_score}" | ||
) | ||
# save observations for rendering | ||
rollout.append(next_observation.copy()) | ||
|
||
obs = next_observation | ||
except KeyboardInterrupt: | ||
pass | ||
|
||
print(f"Total reward: {total_reward}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import json | ||
import os | ||
|
||
import torch | ||
|
||
from diffusers import UNet1DModel | ||
|
||
|
||
os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True) | ||
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True) | ||
|
||
os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True) | ||
|
||
|
||
def unet(hor): | ||
if hor == 128: | ||
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") | ||
block_out_channels = (32, 128, 256) | ||
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D") | ||
|
||
elif hor == 32: | ||
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") | ||
block_out_channels = (32, 64, 128, 256) | ||
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D") | ||
model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch") | ||
state_dict = model.state_dict() | ||
config = dict( | ||
down_block_types=down_block_types, | ||
block_out_channels=block_out_channels, | ||
up_block_types=up_block_types, | ||
layers_per_block=1, | ||
use_timestep_embedding=True, | ||
out_block_type="OutConv1DBlock", | ||
norm_num_groups=8, | ||
downsample_each_block=False, | ||
in_channels=14, | ||
out_channels=14, | ||
extra_in_channels=0, | ||
time_embedding_type="positional", | ||
flip_sin_to_cos=False, | ||
freq_shift=1, | ||
sample_size=65536, | ||
mid_block_type="MidResTemporalBlock1D", | ||
act_fn="mish", | ||
) | ||
hf_value_function = UNet1DModel(**config) | ||
print(f"length of state dict: {len(state_dict.keys())}") | ||
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") | ||
mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys())) | ||
for k, v in mapping.items(): | ||
state_dict[v] = state_dict.pop(k) | ||
hf_value_function.load_state_dict(state_dict) | ||
|
||
torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin") | ||
with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f: | ||
json.dump(config, f) | ||
|
||
|
||
def value_function(): | ||
config = dict( | ||
in_channels=14, | ||
down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), | ||
up_block_types=(), | ||
out_block_type="ValueFunction", | ||
mid_block_type="ValueFunctionMidBlock1D", | ||
block_out_channels=(32, 64, 128, 256), | ||
layers_per_block=1, | ||
downsample_each_block=True, | ||
sample_size=65536, | ||
out_channels=14, | ||
extra_in_channels=0, | ||
time_embedding_type="positional", | ||
use_timestep_embedding=True, | ||
flip_sin_to_cos=False, | ||
freq_shift=1, | ||
norm_num_groups=8, | ||
act_fn="mish", | ||
) | ||
|
||
model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch") | ||
state_dict = model | ||
hf_value_function = UNet1DModel(**config) | ||
print(f"length of state dict: {len(state_dict.keys())}") | ||
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") | ||
|
||
mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys())) | ||
for k, v in mapping.items(): | ||
state_dict[v] = state_dict.pop(k) | ||
|
||
hf_value_function.load_state_dict(state_dict) | ||
|
||
torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin") | ||
with open("hub/hopper-medium-v2/value_function/config.json", "w") as f: | ||
json.dump(config, f) | ||
|
||
|
||
if __name__ == "__main__": | ||
unet(32) | ||
# unet(128) | ||
value_function() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# 🧨 Diffusers Experimental | ||
|
||
We are adding experimental code to support novel applications and usages of the Diffusers library. | ||
Currently, the following experiments are supported: | ||
* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .rl import ValueGuidedRLPipeline |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .value_guided_sampling import ValueGuidedRLPipeline |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool!