Skip to content

Commit 3acddb5

Browse files
committed
update post merge of scripts
1 parent 48a7414 commit 3acddb5

File tree

4 files changed

+2
-221
lines changed

4 files changed

+2
-221
lines changed

src/diffusers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
if is_torch_available():
2020
from .modeling_utils import ModelMixin
21-
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction, VQModel
21+
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
2222
from .optimization import (
2323
get_constant_schedule,
2424
get_constant_schedule_with_warmup,

src/diffusers/models/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from .unet_1d import UNet1DModel
2020
from .unet_2d import UNet2DModel
2121
from .unet_2d_condition import UNet2DConditionModel
22-
from .unet_rl import ValueFunction
2322
from .vae import AutoencoderKL, VQModel
2423

2524
if is_flax_available():

src/diffusers/models/unet_rl.py

-135
This file was deleted.

tests/test_models_unet.py

+1-84
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import torch
2222

23-
from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel, ValueFunction
23+
from diffusers import UNet1DModel, UNet2DConditionModel, UNet2DModel
2424
from diffusers.utils import floats_tensor, slow, torch_device
2525

2626
from .test_modeling_common import ModelTesterMixin
@@ -524,86 +524,3 @@ def test_output_pretrained(self):
524524
def test_forward_with_norm_groups(self):
525525
# Not implemented yet for this UNet
526526
pass
527-
528-
529-
class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
530-
model_class = ValueFunction
531-
532-
@property
533-
def dummy_input(self):
534-
batch_size = 4
535-
num_features = 14
536-
seq_len = 16
537-
538-
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
539-
time_step = torch.tensor([10] * batch_size).to(torch_device)
540-
541-
return {"sample": noise, "timestep": time_step}
542-
543-
@property
544-
def input_shape(self):
545-
return (4, 14, 16)
546-
547-
@property
548-
def output_shape(self):
549-
return (4, 14, 1)
550-
551-
def test_ema_training(self):
552-
pass
553-
554-
def test_training(self):
555-
pass
556-
557-
def prepare_init_args_and_inputs_for_common(self):
558-
init_dict = {
559-
"block_out_channels": (32, 64, 128, 256),
560-
"in_channels": 14,
561-
"out_channels": 14,
562-
}
563-
inputs_dict = self.dummy_input
564-
return init_dict, inputs_dict
565-
566-
def test_from_pretrained_hub(self):
567-
unet, loading_info = UNet1DModel.from_pretrained(
568-
"bglick13/hopper-medium-v2-unet-hor32", output_loading_info=True
569-
)
570-
value_function, vf_loading_info = ValueFunction.from_pretrained(
571-
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True
572-
)
573-
self.assertIsNotNone(unet)
574-
self.assertEqual(len(loading_info["missing_keys"]), 0)
575-
self.assertIsNotNone(value_function)
576-
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
577-
578-
unet.to(torch_device)
579-
value_function.to(torch_device)
580-
image = value_function(**self.dummy_input)
581-
582-
assert image is not None, "Make sure output is not None"
583-
584-
def test_output_pretrained(self):
585-
value_function, vf_loading_info = ValueFunction.from_pretrained(
586-
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True
587-
)
588-
torch.manual_seed(0)
589-
if torch.cuda.is_available():
590-
torch.cuda.manual_seed_all(0)
591-
592-
num_features = value_function.in_channels
593-
seq_len = 14
594-
noise = torch.randn((1, seq_len, num_features)).permute(
595-
0, 2, 1
596-
) # match original, we can update values and remove
597-
time_step = torch.full((num_features,), 0)
598-
599-
with torch.no_grad():
600-
output = value_function(noise, time_step).sample
601-
602-
# fmt: off
603-
expected_output_slice = torch.tensor([207.0272] * seq_len)
604-
# fmt: on
605-
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
606-
607-
def test_forward_with_norm_groups(self):
608-
# Not implemented yet for this UNet
609-
pass

0 commit comments

Comments
 (0)