@@ -47,7 +47,7 @@ def __init__(
        if groups_out is None:
            groups_out = groups

-        # there will always be at least one resenet
+        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]

        for _ in range(num_layers):
@@ -111,7 +111,7 @@ def __init__(
        if groups_out is None:
            groups_out = groups

-        # there will always be at least one resenet
+        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]

        for _ in range(num_layers):
@@ -174,22 +174,60 @@ class UpBlock1DNoSkip(nn.Module):


class MidResTemporalBlock1D(nn.Module):
-    def __init__(self, in_channels, out_channels, embed_dim, add_downsample):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        embed_dim,
+        num_layers: int = 1,
+        add_downsample: bool = False,
+        add_upsample: bool = False,
+        non_linearity=None,
+    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample
-        self.resnet = ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)

+        # there will always be at least one resnet
+        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
+
+        for _ in range(num_layers):
+            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if non_linearity == "swish":
+            self.nonlinearity = lambda x: F.silu(x)
+        elif non_linearity == "mish":
+            self.nonlinearity = nn.Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()
+        else:
+            self.nonlinearity = None
+
+        self.upsample = None
+        if add_upsample:
+            self.upsample = Upsample1D(out_channels, use_conv=True)
+
+        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True)
-        else:
-            self.downsample = nn.Identity()

-    def forward(self, sample, temb):
-        sample = self.resnet(sample, temb)
-        sample = self.downsample(sample)
-        return sample
+        if self.upsample and self.downsample:
+            raise ValueError("Block cannot downsample and upsample")
+
+    def forward(self, hidden_states, temb):
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for resnet in self.resnets[1:]:
+            hidden_states = resnet(hidden_states, temb)
+
+        if self.upsample:
+            hidden_states = self.upsample(hidden_states)
+        if self.downsample:
+            hidden_states = self.downsample(hidden_states)
+
+        return hidden_states


class OutConv1DBlock(nn.Module):
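
For reference, a minimal usage sketch of the reworked mid block; the import path, channel sizes, and tensor shapes are assumptions for illustration and are not part of this diff.

import torch
from diffusers.models.unet_1d_blocks import MidResTemporalBlock1D  # assumed import path

# Arbitrary example sizes: 32 channels, 128-dim time embedding, horizon of 16.
block = MidResTemporalBlock1D(
    in_channels=32,
    out_channels=32,
    embed_dim=128,
    num_layers=2,          # new argument: extra resnets appended after the first one
    add_downsample=False,  # downsampling and upsampling are mutually exclusive
)

hidden_states = torch.randn(4, 32, 16)  # (batch, channels, horizon)
temb = torch.randn(4, 128)              # time embedding, (batch, embed_dim)

out = block(hidden_states, temb)
print(out.shape)  # torch.Size([4, 32, 16]) when no resampling is applied
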
@@ -203,14 +241,14 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
        self.final_conv1d_act = nn.Mish()
        self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

-    def forward(self, sample, t):
-        sample = self.final_conv1d_1(sample)
-        sample = rearrange_dims(sample)
-        sample = self.final_conv1d_gn(sample)
-        sample = rearrange_dims(sample)
-        sample = self.final_conv1d_act(sample)
-        sample = self.final_conv1d_2(sample)
-        return sample
+    def forward(self, hidden_states, temb=None):
+        hidden_states = self.final_conv1d_1(hidden_states)
+        hidden_states = rearrange_dims(hidden_states)
+        hidden_states = self.final_conv1d_gn(hidden_states)
+        hidden_states = rearrange_dims(hidden_states)
+        hidden_states = self.final_conv1d_act(hidden_states)
+        hidden_states = self.final_conv1d_2(hidden_states)
+        return hidden_states


class OutValueFunctionBlock(nn.Module):
@@ -224,13 +262,13 @@ def __init__(self, fc_dim, embed_dim):
            ]
        )

-    def forward(self, sample, t):
-        sample = sample.view(sample.shape[0], -1)
-        sample = torch.cat((sample, t), dim=-1)
+    def forward(self, hidden_states, temb):
+        hidden_states = hidden_states.view(hidden_states.shape[0], -1)
+        hidden_states = torch.cat((hidden_states, temb), dim=-1)
        for layer in self.final_block:
-            sample = layer(sample)
+            hidden_states = layer(hidden_states)

-        return sample
+        return hidden_states


def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
@@ -260,9 +298,15 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan
    raise ValueError(f"{up_block_type} does not exist.")


-def get_mid_block(mid_block_type, in_channels, out_channels, embed_dim, add_downsample):
+def get_mid_block(mid_block_type, num_layers, in_channels, out_channels, embed_dim, add_downsample):
    if mid_block_type == "MidResTemporalBlock1D":
-        return MidResTemporalBlock1D(in_channels, out_channels, embed_dim, add_downsample)
+        return MidResTemporalBlock1D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            embed_dim=embed_dim,
+            add_downsample=add_downsample,
+        )
    raise ValueError(f"{mid_block_type} does not exist.")
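
A hedged sketch of a call site for the updated helper, matching the new signature above; the argument values are placeholders, not taken from this diff.

mid_block = get_mid_block(
    "MidResTemporalBlock1D",
    num_layers=1,
    in_channels=64,
    out_channels=64,
    embed_dim=128,
    add_downsample=False,
)
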