Skip to content

Commit 713e8f2

Browse files
committed
add midblock / outblock architecture
1 parent 3acddb5 commit 713e8f2

File tree

2 files changed

+85
-48
lines changed

2 files changed

+85
-48
lines changed

src/diffusers/models/unet_1d.py

+16-23
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
out_channels: int = 14,
6161
down_block_types: Tuple[str] = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
6262
up_block_types: Tuple[str] = ("UpResnetBlock1D", "UpResnetBlock1D"),
63-
mid_block_types: Tuple[str] = ("MidResTemporalBlock1D", "MidResTemporalBlock1D"),
63+
mid_block_type: Tuple[str] = "MidResTemporalBlock1D",
6464
out_block_type: str = "OutConv1DBlock",
6565
block_out_channels: Tuple[int] = (32, 128, 256),
6666
act_fn: str = "mish",
@@ -79,7 +79,9 @@ def __init__(
7979
)
8080

8181
self.down_blocks = nn.ModuleList([])
82+
self.mid_block = None
8283
self.up_blocks = nn.ModuleList([])
84+
self.out_block = None
8385
mid_dim = block_out_channels[-1]
8486

8587
# down
@@ -101,25 +103,15 @@ def __init__(
101103
self.down_blocks.append(down_block)
102104

103105
# mid
104-
self.mid_blocks = nn.ModuleList([])
105-
for i, mid_block_type in enumerate(mid_block_types):
106-
if always_downsample:
107-
mid_block = get_mid_block(
108-
mid_block_type,
109-
in_channels=mid_dim // (i + 1),
110-
out_channels=mid_dim // ((i + 1) * 2),
111-
embed_dim=block_out_channels[0],
112-
add_downsample=True,
113-
)
114-
else:
115-
mid_block = get_mid_block(
116-
mid_block_type,
117-
in_channels=mid_dim,
118-
out_channels=mid_dim,
119-
embed_dim=block_out_channels[0],
120-
add_downsample=False,
121-
)
122-
self.mid_blocks.append(mid_block)
106+
self.mid_block = get_mid_block(
107+
mid_block_type,
108+
in_channels=mid_dim,
109+
out_channels=mid_dim,
110+
embed_dim=block_out_channels[0],
111+
num_layers=layers_per_block,
112+
add_downsample=always_downsample,
113+
)
114+
123115
# up
124116
reversed_block_out_channels = list(reversed(block_out_channels))
125117
for i, up_block_type in enumerate(up_block_types):
@@ -184,15 +176,16 @@ def forward(
184176
down_block_res_samples.append(res_samples[0])
185177

186178
# 3. mid
187-
for mid_block in self.mid_blocks:
188-
sample = mid_block(sample, temb)
179+
if self.mid_block:
180+
sample = self.mid_block(sample, temb)
189181

190182
# 4. up
191183
for up_block in self.up_blocks:
192184
sample = up_block(hidden_states=sample, res_hidden_states=down_block_res_samples.pop(), temb=temb)
193185

194186
# 5. post-process
195-
sample = self.out_block(sample, temb)
187+
if self.out_block:
188+
sample = self.out_block(sample, temb)
196189

197190
if not return_dict:
198191
return (sample,)

src/diffusers/models/unet_1d_blocks.py

+69-25
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(
4747
if groups_out is None:
4848
groups_out = groups
4949

50-
# there will always be at least one resenet
50+
# there will always be at least one resnet
5151
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
5252

5353
for _ in range(num_layers):
@@ -111,7 +111,7 @@ def __init__(
111111
if groups_out is None:
112112
groups_out = groups
113113

114-
# there will always be at least one resenet
114+
# there will always be at least one resnet
115115
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
116116

117117
for _ in range(num_layers):
@@ -174,22 +174,60 @@ class UpBlock1DNoSkip(nn.Module):
174174

175175

176176
class MidResTemporalBlock1D(nn.Module):
    """Middle block: a stack of ``ResidualTemporalBlock1D`` layers plus an optional resampler.

    The first resnet maps ``in_channels`` to ``out_channels``; ``num_layers``
    further resnets map ``out_channels`` to ``out_channels``, so the block
    always holds ``num_layers + 1`` resnets.

    Args:
        in_channels: channel count passed to the first resnet as its input size.
        out_channels: channel count produced by every resnet and used by the
            optional resampler.
        embed_dim: forwarded to every resnet as ``embed_dim``.
        num_layers: number of resnets appended after the first one.
        add_downsample: when true, a ``Downsample1D`` is applied after the resnets.
        add_upsample: when true, an ``Upsample1D`` is applied after the resnets.
        non_linearity: one of ``"swish"``, ``"mish"``, ``"silu"`` or anything
            else for ``None``. The result is stored on ``self.nonlinearity``
            and is not used by ``forward``.

    Raises:
        ValueError: if both ``add_downsample`` and ``add_upsample`` are set.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dim,
        num_layers: int = 1,
        add_downsample: bool = False,
        add_upsample: bool = False,
        non_linearity=None,
    ):
        super().__init__()

        # Reject the contradictory configuration before building any layers.
        if add_downsample and add_upsample:
            raise ValueError("Block cannot downsample and upsample")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
        resnets.extend(
            ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim) for _ in range(num_layers)
        )
        self.resnets = nn.ModuleList(resnets)

        if non_linearity == "swish":
            self.nonlinearity = F.silu
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            self.nonlinearity = None

        # Previously this branch was gated on `add_downsample` and built a
        # Downsample1D, so `add_upsample` was ignored and `add_downsample=True`
        # always tripped the "cannot downsample and upsample" check.
        # NOTE(review): assumes Upsample1D is imported next to Downsample1D and
        # accepts `use_conv` - confirm against the file's import block.
        self.upsample = None
        if add_upsample:
            self.upsample = Upsample1D(out_channels, use_conv=True)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True)

    def forward(self, hidden_states, temb):
        """Run every resnet with ``temb``, then the optional resampler.

        Args:
            hidden_states: input passed to the first resnet.
            temb: embedding passed unchanged to every resnet.

        Returns:
            The output of the last resnet, up- or downsampled if the block
            was configured to do so.
        """
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)

        # Compare against None explicitly rather than relying on module truthiness.
        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)
        if self.downsample is not None:
            # The result must flow into hidden_states; assigning it to
            # self.downsample would clobber the layer and drop the output.
            hidden_states = self.downsample(hidden_states)

        return hidden_states
193231

194232

195233
class OutConv1DBlock(nn.Module):
@@ -203,14 +241,14 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
203241
self.final_conv1d_act = nn.Mish()
204242
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
205243

206-
def forward(self, hidden_states, temb=None):
    """Apply the output head to ``hidden_states``.

    The stages run in this order: ``final_conv1d_1``, ``rearrange_dims``,
    ``final_conv1d_gn``, ``rearrange_dims``, ``final_conv1d_act``,
    ``final_conv1d_2``. ``temb`` is accepted so the block can be called like
    the other output blocks, but it is not used here.
    """
    projected = self.final_conv1d_1(hidden_states)
    # The group norm is wrapped between two rearrange_dims calls.
    normalized = rearrange_dims(self.final_conv1d_gn(rearrange_dims(projected)))
    activated = self.final_conv1d_act(normalized)
    return self.final_conv1d_2(activated)
214252

215253

216254
class OutValueFunctionBlock(nn.Module):
@@ -224,13 +262,13 @@ def __init__(self, fc_dim, embed_dim):
224262
]
225263
)
226264

227-
def forward(self, hidden_states, temb):
    """Flatten ``hidden_states``, append ``temb`` and run ``self.final_block``.

    ``hidden_states`` is viewed as ``(hidden_states.shape[0], -1)``,
    concatenated with ``temb`` along the last dimension, and then passed
    through each layer of ``self.final_block`` in order.
    """
    batch_size = hidden_states.shape[0]
    flattened = hidden_states.view(batch_size, -1)
    features = torch.cat((flattened, temb), dim=-1)

    for layer in self.final_block:
        features = layer(features)
    return features
234272

235273

236274
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
@@ -260,9 +298,15 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan
260298
raise ValueError(f"{up_block_type} does not exist.")
261299

262300

263-
def get_mid_block(mid_block_type, num_layers, in_channels, out_channels, embed_dim, add_downsample):
    """Build the mid block named by ``mid_block_type``.

    Only ``"MidResTemporalBlock1D"`` is recognised; the remaining arguments
    are forwarded to its constructor by keyword.

    Raises:
        ValueError: if ``mid_block_type`` is not a known block name.
    """
    # Guard clause: reject unknown names before constructing anything.
    if mid_block_type != "MidResTemporalBlock1D":
        raise ValueError(f"{mid_block_type} does not exist.")

    block_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        embed_dim=embed_dim,
        add_downsample=add_downsample,
    )
    return MidResTemporalBlock1D(**block_kwargs)
267311

268312

0 commit comments

Comments
 (0)