
Add Diffusion Policy for Reinforcement Learning #9824


Merged · 34 commits · Nov 2, 2024

Changes shown below are from 18 of the 34 commits.

Commits
654f6b4
enable cpu ability
DorsaRoh Oct 31, 2024
96c62d0
model creation + comprehensive testing
DorsaRoh Oct 31, 2024
8759c12
training + tests
DorsaRoh Oct 31, 2024
2cbc3bb
all tests working
DorsaRoh Oct 31, 2024
5feb7af
remove unneeded files + clarify docs
DorsaRoh Oct 31, 2024
80fcab8
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh Oct 31, 2024
a4c7340
update train tests
DorsaRoh Oct 31, 2024
6457aec
update readme.md
DorsaRoh Oct 31, 2024
ddb7718
remove data from gitignore
DorsaRoh Oct 31, 2024
089c40a
undo cpu enabled option
DorsaRoh Oct 31, 2024
7d96254
Update README.md
DorsaRoh Oct 31, 2024
25e4638
update readme
DorsaRoh Oct 31, 2024
566f112
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh Oct 31, 2024
fbb442c
Merge branch 'main' into diffusion-policy
DorsaRoh Oct 31, 2024
4c07cea
code quality fixes
DorsaRoh Nov 1, 2024
c0e1af3
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh Nov 1, 2024
957de92
diffusion policy example
DorsaRoh Nov 1, 2024
09922df
update readme
DorsaRoh Nov 1, 2024
186b8f0
add pretrained model weights + doc
DorsaRoh Nov 1, 2024
d7ced53
add comment
DorsaRoh Nov 1, 2024
333b5cb
Merge branch 'main' into diffusion-policy
DorsaRoh Nov 1, 2024
b37b0e2
add documentation
DorsaRoh Nov 1, 2024
65dd0b7
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh Nov 1, 2024
06401cb
add docstrings
DorsaRoh Nov 1, 2024
2bf4fa0
update comments
DorsaRoh Nov 1, 2024
47f2ba5
update readme
DorsaRoh Nov 1, 2024
be98e76
fix code quality
DorsaRoh Nov 1, 2024
5bff039
Update examples/reinforcement_learning/README.md
DorsaRoh Nov 1, 2024
2bad2fc
Update examples/reinforcement_learning/diffusion_policy.py
DorsaRoh Nov 1, 2024
a7b8ef2
suggestions + safe globals for weights_only=True
DorsaRoh Nov 1, 2024
9a5cfe7
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh Nov 1, 2024
19db70e
suggestions + safe weights loading
DorsaRoh Nov 1, 2024
c8fa61a
fix code quality
DorsaRoh Nov 1, 2024
b8fe110
reformat file
DorsaRoh Nov 1, 2024
8 changes: 7 additions & 1 deletion examples/reinforcement_learning/README.md
@@ -1,4 +1,10 @@
# Overview

## Diffusion-based Policy Learning for RL

`diffusion_policy` implements [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/), a diffusion-model-based policy that predicts action sequences for reinforcement learning tasks.
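
A minimal usage sketch, mirroring the script's `__main__` block (it assumes you run it from `examples/reinforcement_learning/` so the module imports directly):

```python
import torch

from diffusion_policy import DiffusionPolicy

# defaults: 5-dim states, 2-dim actions, 16-step action sequences
policy = DiffusionPolicy()

observation = torch.randn(1, 5)        # [batch_size, state_dim]
actions = policy.predict(observation)  # [1, 16, 2]
```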


## Diffuser Locomotion

These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
The script, `run_diffuser_locomotion.py`, can be used in two ways.
69 changes: 69 additions & 0 deletions examples/reinforcement_learning/diffusion_policy.py
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn

from diffusers import DDPMScheduler, UNet1DModel


class DiffusionPolicy:
    def __init__(
        self,
        state_dim=5,
        action_dim=2,
        sequence_length=16,
        condition_dim=2,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = device
        self.sequence_length = sequence_length
        self.action_dim = action_dim
        self.condition_dim = condition_dim

        # observation encoder - output dim matches condition_dim
        self.obs_encoder = nn.Sequential(
            nn.Linear(state_dim, 32),
            nn.ReLU(),
            nn.Linear(32, condition_dim),
        ).to(device)

        # 1D UNet that denoises action sequences, conditioned on the encoded observation
        self.model = UNet1DModel(
            sample_size=sequence_length,
            in_channels=action_dim + condition_dim,
            out_channels=action_dim,
            layers_per_block=2,
            block_out_channels=(128,),
            down_block_types=("DownBlock1D",),
            up_block_types=("UpBlock1D",),
        ).to(device)

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=100,
            beta_schedule="squaredcos_cap_v2",
        )

    @torch.no_grad()
    def predict(self, observation, num_inference_steps=10):
        """Generate an action sequence from an observation."""
        batch_size = observation.shape[0]

        # encode observation
        observation = observation.to(self.device)
        cond = self.obs_encoder(observation)  # [B, condition_dim]

        # expand condition to sequence length
        cond = cond.view(batch_size, self.condition_dim, 1)
        cond = cond.expand(batch_size, self.condition_dim, self.sequence_length)

        # start from pure Gaussian noise
        action = torch.randn(
            (batch_size, self.action_dim, self.sequence_length),
            device=self.device,
        )

        # denoise
        self.noise_scheduler.set_timesteps(num_inference_steps)
        for t in self.noise_scheduler.timesteps:
            model_input = torch.cat([action, cond], dim=1)

            # predict and remove noise
            noise_pred = self.model(model_input, t).sample
            action = self.noise_scheduler.step(noise_pred, t, action).prev_sample

        return action.transpose(1, 2)  # [batch_size, sequence_length, action_dim]


if __name__ == "__main__":
    policy = DiffusionPolicy()
Member:

Should we load any pre-trained model here?

Contributor Author @DorsaRoh, Nov 1, 2024:

Thanks for the valuable thought! Diffusion policies are frequently tailored to specific use cases, and baking pretrained weights into the inference example could limit its general applicability and confuse users working on different tasks. Although I have pretrained weights for a specific task that I could add here, to keep the example general I recommend initializing the model without loading them. Users can then train their own models or plug in pretrained weights suited to their own applications.

Member:

I beg to differ. I think if we can document it sufficiently, it would make more sense to showcase this with a pre-trained model.

Contributor Author @DorsaRoh, Nov 1, 2024:

Sounds good! I have made the changes: the example now loads a pretrained model and contains comprehensive documentation.

    # a sample single observation
    observation = torch.randn(1, 5)  # [batch_size, state_dim]
    actions = policy.predict(observation)
    print("Generated action sequence shape:", actions.shape)  # should be [1, 16, 2]