-
Notifications
You must be signed in to change notification settings - Fork 6k
Add Diffusion Policy for Reinforcement Learning #9824
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 18 commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
654f6b4
enable cpu ability
DorsaRoh 96c62d0
model creation + comprehensive testing
DorsaRoh 8759c12
training + tests
DorsaRoh 2cbc3bb
all tests working
DorsaRoh 5feb7af
remove unneeded files + clarify docs
DorsaRoh 80fcab8
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh a4c7340
update train tests
DorsaRoh 6457aec
update readme.md
DorsaRoh ddb7718
remove data from gitignore
DorsaRoh 089c40a
undo cpu enabled option
DorsaRoh 7d96254
Update README.md
DorsaRoh 25e4638
update readme
DorsaRoh 566f112
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh fbb442c
Merge branch 'main' into diffusion-policy
DorsaRoh 4c07cea
code quality fixes
DorsaRoh c0e1af3
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh 957de92
diffusion policy example
DorsaRoh 09922df
update readme
DorsaRoh 186b8f0
add pretrained model weights + doc
DorsaRoh d7ced53
add comment
DorsaRoh 333b5cb
Merge branch 'main' into diffusion-policy
DorsaRoh b37b0e2
add documentation
DorsaRoh 65dd0b7
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh 06401cb
add docstrings
DorsaRoh 2bf4fa0
update comments
DorsaRoh 47f2ba5
update readme
DorsaRoh be98e76
fix code quality
DorsaRoh 5bff039
Update examples/reinforcement_learning/README.md
DorsaRoh 2bad2fc
Update examples/reinforcement_learning/diffusion_policy.py
DorsaRoh a7b8ef2
suggestions + safe globals for weights_only=True
DorsaRoh 9a5cfe7
Merge branch 'diffusion-policy' of https://github.com/DorsaRoh/diffus…
DorsaRoh 19db70e
suggestions + safe weights loading
DorsaRoh c8fa61a
fix code quality
DorsaRoh b8fe110
reformat file
DorsaRoh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import torch | ||
import torch.nn as nn | ||
from diffusers import UNet1DModel, DDPMScheduler | ||
|
||
class DiffusionPolicy: | ||
def __init__(self, state_dim=5, action_dim=2, sequence_length=16, condition_dim=2, device= "cuda" if torch.cuda.is_available() else "cpu"): | ||
self.device = device | ||
self.sequence_length = sequence_length | ||
self.action_dim = action_dim | ||
self.condition_dim = condition_dim | ||
|
||
# observation encoder - output dim matches condition_dim | ||
self.obs_encoder = nn.Sequential( | ||
nn.Linear(state_dim, 32), | ||
nn.ReLU(), | ||
nn.Linear(32, condition_dim), | ||
).to(device) | ||
|
||
self.model = UNet1DModel( | ||
sample_size=sequence_length, | ||
in_channels=action_dim + condition_dim, | ||
out_channels=action_dim, | ||
layers_per_block=2, | ||
block_out_channels=(128,), | ||
down_block_types=("DownBlock1D",), | ||
up_block_types=("UpBlock1D",), | ||
).to(device) | ||
|
||
self.noise_scheduler = DDPMScheduler( | ||
num_train_timesteps=100, | ||
beta_schedule="squaredcos_cap_v2" | ||
) | ||
|
||
@torch.no_grad() | ||
def predict(self, observation, num_inference_steps=10): | ||
"""Generate action sequence from an observation.""" | ||
batch_size = observation.shape[0] | ||
|
||
# encode observation | ||
observation = observation.to(self.device) | ||
cond = self.obs_encoder(observation) # [B, condition_dim] | ||
|
||
# expand condition to sequence length | ||
cond = cond.view(batch_size, self.condition_dim, 1) | ||
cond = cond.expand(batch_size, self.condition_dim, self.sequence_length) | ||
|
||
action = torch.randn( | ||
(batch_size, self.action_dim, self.sequence_length), | ||
device=self.device | ||
) | ||
|
||
# denoise | ||
self.noise_scheduler.set_timesteps(num_inference_steps) | ||
for t in self.noise_scheduler.timesteps: | ||
model_input = torch.cat([action, cond], dim=1) | ||
|
||
# predict and remove noise | ||
noise_pred = self.model(model_input, t).sample | ||
action = self.noise_scheduler.step(noise_pred, t, action).prev_sample | ||
|
||
return action.transpose(1, 2) # [batch_size, sequence_length, action_dim] | ||
|
||
if __name__ == "__main__": | ||
policy = DiffusionPolicy() | ||
|
||
# a sample single observation | ||
observation = torch.randn(1, 5) # [batch_size, state_dim] | ||
actions = policy.predict(observation) | ||
print("Generated action sequence shape:", actions.shape) # Should be [1, 16, 2] |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we load any pre-trained model here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the valuable thought!
Diffusion policies are frequently tailored to specific use cases, and incorporating pretrained weights into the inference example could highly limit its general applicability and confuse users working on different tasks. Although I have pretrained weights available for a specific task that I can add here, to maintain the example’s universality, I recommend initializing the model without loading them. This will allow users to train their own models or integrate relevant pretrained weights based on their own applications!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I beg to differ. I think if we can document it sufficiently it would make more sense to showcase this with a pre-trained model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good! I have made the changes. Now, the example loads from a pretrained model and contains comprehensive documentation