|
20 | 20 |
|
21 | 21 | import torch
|
22 | 22 |
|
23 |
| -from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction |
| 23 | +from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel |
24 | 24 | from diffusers.utils import floats_tensor, slow, torch_device
|
25 | 25 |
|
26 | 26 | from .test_modeling_common import ModelTesterMixin
|
@@ -524,86 +524,3 @@ def test_output_pretrained(self):
|
524 | 524 | def test_forward_with_norm_groups(self):
|
525 | 525 | # Not implemented yet for this UNet
|
526 | 526 | pass
|
527 |
| - |
528 |
| - |
529 |
| -class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): |
530 |
| - model_class = ValueFunction |
531 |
| - |
532 |
| - @property |
533 |
| - def dummy_input(self): |
534 |
| - batch_size = 4 |
535 |
| - num_features = 14 |
536 |
| - seq_len = 16 |
537 |
| - |
538 |
| - noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) |
539 |
| - time_step = torch.tensor([10] * batch_size).to(torch_device) |
540 |
| - |
541 |
| - return {"sample": noise, "timestep": time_step} |
542 |
| - |
543 |
| - @property |
544 |
| - def input_shape(self): |
545 |
| - return (4, 14, 16) |
546 |
| - |
547 |
| - @property |
548 |
| - def output_shape(self): |
549 |
| - return (4, 14, 1) |
550 |
| - |
551 |
| - def test_ema_training(self): |
552 |
| - pass |
553 |
| - |
554 |
| - def test_training(self): |
555 |
| - pass |
556 |
| - |
557 |
| - def prepare_init_args_and_inputs_for_common(self): |
558 |
| - init_dict = { |
559 |
| - "block_out_channels": (32, 64, 128, 256), |
560 |
| - "in_channels": 14, |
561 |
| - "out_channels": 14, |
562 |
| - } |
563 |
| - inputs_dict = self.dummy_input |
564 |
| - return init_dict, inputs_dict |
565 |
| - |
566 |
| - def test_from_pretrained_hub(self): |
567 |
| - unet, loading_info = UNet1DModel.from_pretrained( |
568 |
| - "bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True |
569 |
| - ) |
570 |
| - value_function, vf_loading_info = ValueFunction.from_pretrained( |
571 |
| - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True |
572 |
| - ) |
573 |
| - self.assertIsNotNone(unet) |
574 |
| - self.assertEqual(len(loading_info["missing_keys"]), 0) |
575 |
| - self.assertIsNotNone(value_function) |
576 |
| - self.assertEqual(len(vf_loading_info["missing_keys"]), 0) |
577 |
| - |
578 |
| - unet.to(torch_device) |
579 |
| - value_function.to(torch_device) |
580 |
| - image = value_function(**self.dummy_input) |
581 |
| - |
582 |
| - assert image is not None, "Make sure output is not None" |
583 |
| - |
584 |
| - def test_output_pretrained(self): |
585 |
| - value_function, vf_loading_info = ValueFunction.from_pretrained( |
586 |
| - "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True |
587 |
| - ) |
588 |
| - torch.manual_seed(0) |
589 |
| - if torch.cuda.is_available(): |
590 |
| - torch.cuda.manual_seed_all(0) |
591 |
| - |
592 |
| - num_features = value_function.in_channels |
593 |
| - seq_len = 14 |
594 |
| - noise = torch.randn((1, seq_len, num_features)).permute( |
595 |
| - 0, 2, 1 |
596 |
| - ) # match original, we can update values and remove |
597 |
| - time_step = torch.full((num_features,), 0) |
598 |
| - |
599 |
| - with torch.no_grad(): |
600 |
| - output = value_function(noise, time_step).sample |
601 |
| - |
602 |
| - # fmt: off |
603 |
| - expected_output_slice = torch.tensor([207.0272] * seq_len) |
604 |
| - # fmt: on |
605 |
| - self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) |
606 |
| - |
607 |
| - def test_forward_with_norm_groups(self): |
608 |
| - # Not implemented yet for this UNet |
609 |
| - pass |
0 commit comments