|
20 | 20 |
|
21 | 21 | import torch
|
22 | 22 |
|
23 |
| -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel |
| 23 | +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction |
24 | 24 | from diffusers.utils import floats_tensor, slow, torch_device
|
25 | 25 |
|
26 | 26 | from .test_modeling_common import ModelTesterMixin
|
@@ -524,3 +524,86 @@ def test_output_pretrained(self):
|
524 | 524 | def test_forward_with_norm_groups(self):
|
525 | 525 | # Not implemented yet for this UNet
|
526 | 526 | pass
|
| 527 | + |
| 528 | + |
| 529 | +class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): |
| 530 | + model_class = UNet1DModel |
| 531 | + |
| 532 | + @property |
| 533 | + def dummy_input(self): |
| 534 | + batch_size = 4 |
| 535 | + num_features = 14 |
| 536 | + seq_len = 16 |
| 537 | + |
| 538 | + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) |
| 539 | + time_step = torch.tensor([10] * batch_size).to(torch_device) |
| 540 | + |
| 541 | + return {"sample": noise, "timestep": time_step} |
| 542 | + |
| 543 | + @property |
| 544 | + def input_shape(self): |
| 545 | + return (4, 14, 16) |
| 546 | + |
| 547 | + @property |
| 548 | + def output_shape(self): |
| 549 | + return (4, 14, 1) |
| 550 | + |
| 551 | + def test_ema_training(self): |
| 552 | + pass |
| 553 | + |
| 554 | + def test_training(self): |
| 555 | + pass |
| 556 | + |
| 557 | + def prepare_init_args_and_inputs_for_common(self): |
| 558 | + init_dict = { |
| 559 | + "block_out_channels": (32, 64, 128, 256), |
| 560 | + "in_channels": 14, |
| 561 | + "out_channels": 14, |
| 562 | + } |
| 563 | + inputs_dict = self.dummy_input |
| 564 | + return init_dict, inputs_dict |
| 565 | + |
| 566 | + def test_from_pretrained_hub(self): |
| 567 | + unet, loading_info = UNet1DModel.from_pretrained( |
| 568 | + "bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True |
| 569 | + ) |
| 570 | + value_function, vf_loading_info = UNet1DModel.from_pretrained( |
| 571 | + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True |
| 572 | + ) |
| 573 | + self.assertIsNotNone(unet) |
| 574 | + self.assertEqual(len(loading_info["missing_keys"]), 0) |
| 575 | + self.assertIsNotNone(value_function) |
| 576 | + self.assertEqual(len(vf_loading_info["missing_keys"]), 0) |
| 577 | + |
| 578 | + unet.to(torch_device) |
| 579 | + value_function.to(torch_device) |
| 580 | + image = value_function(**self.dummy_input) |
| 581 | + |
| 582 | + assert image is not None, "Make sure output is not None" |
| 583 | + |
| 584 | + def test_output_pretrained(self): |
| 585 | + value_function, vf_loading_info = UNet1DModel.from_pretrained( |
| 586 | + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True |
| 587 | + ) |
| 588 | + torch.manual_seed(0) |
| 589 | + if torch.cuda.is_available(): |
| 590 | + torch.cuda.manual_seed_all(0) |
| 591 | + |
| 592 | + num_features = value_function.in_channels |
| 593 | + seq_len = 14 |
| 594 | + noise = torch.randn((1, seq_len, num_features)).permute( |
| 595 | + 0, 2, 1 |
| 596 | + ) # match original, we can update values and remove |
| 597 | + time_step = torch.full((num_features,), 0) |
| 598 | + |
| 599 | + with torch.no_grad(): |
| 600 | + output = value_function(noise, time_step).sample |
| 601 | + |
| 602 | + # fmt: off |
| 603 | + expected_output_slice = torch.tensor([207.0272] * seq_len) |
| 604 | + # fmt: on |
| 605 | + self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) |
| 606 | + |
| 607 | + def test_forward_with_norm_groups(self): |
| 608 | + # Not implemented yet for this UNet |
| 609 | + pass |
0 commit comments