 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -188,6 +188,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
 
         self.config = config
+        self.quant_config = quant_config
         self.embed_dim = config.n_embd
         self.wte = VocabParallelEmbedding(
             config.vocab_size,
@@ -228,6 +229,63 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "attn.bias" in name or "attn.masked_bias" in name:
+                continue
+
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GPTJForCausalLM(nn.Module, SupportsPP):
 
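Aside on the stacked_params_mapping in the GPTJModel.load_weights added above: checkpoints store q_proj, k_proj and v_proj as separate tensors, while the runtime module keeps a single fused qkv_proj parameter, so each mapping entry carries a shard id that tells the fused parameter's weight_loader which slice to fill. A minimal standalone sketch of that idea, using a toy checkpoint and a hypothetical load_fused_shard helper (none of this is vLLM code):

import torch

hidden = 4
fused_qkv = torch.zeros(3 * hidden, hidden)            # one fused runtime parameter
shard_offsets = {"q": 0, "k": hidden, "v": 2 * hidden}

def load_fused_shard(fused: torch.Tensor, shard: torch.Tensor, shard_id: str) -> None:
    # Copy one per-projection checkpoint tensor into its slice of the fused weight.
    start = shard_offsets[shard_id]
    fused[start:start + hidden].copy_(shard)

checkpoint = {
    "attn.q_proj.weight": torch.full((hidden, hidden), 1.0),
    "attn.k_proj.weight": torch.full((hidden, hidden), 2.0),
    "attn.v_proj.weight": torch.full((hidden, hidden), 3.0),
}
mapping = [("qkv_proj", "q_proj", "q"),
           ("qkv_proj", "k_proj", "k"),
           ("qkv_proj", "v_proj", "v")]

for name, tensor in checkpoint.items():
    for fused_name, ckpt_name, shard_id in mapping:
        if ckpt_name in name:                           # same substring match as above
            load_fused_shard(fused_qkv, tensor, shard_id)
            break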
@@ -285,57 +343,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "attn.bias" in name or "attn.masked_bias" in name:
-                continue
-
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
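Net effect of the two load_weights hunks: the name-mapping logic now lives on GPTJModel, and GPTJForCausalLM.load_weights simply delegates to AutoWeightsLoader, which (roughly) routes each checkpoint tensor to the child module whose name prefix matches (here "transformer.") and calls that child's own load_weights. A simplified, self-contained sketch of that delegation pattern, with hypothetical ToyInner, ToyOuter and SimpleAutoWeightsLoader classes standing in for the real ones (not vLLM's actual implementation):

from typing import Dict, Iterable, List, Set, Tuple


class ToyInner:
    # Plays the role of GPTJModel: it knows how to map checkpoint names
    # onto its own parameters (here we only record the names).
    def load_weights(self, weights: Iterable[Tuple[str, object]]) -> Set[str]:
        return {name for name, _tensor in weights}


class SimpleAutoWeightsLoader:
    # Illustrative stand-in for AutoWeightsLoader: group weights by the first
    # name component, then delegate to the matching child module.
    def __init__(self, model: object) -> None:
        self.model = model

    def load_weights(self, weights: Iterable[Tuple[str, object]]) -> Set[str]:
        per_child: Dict[str, List[Tuple[str, object]]] = {}
        for name, tensor in weights:
            prefix, _, rest = name.partition(".")
            per_child.setdefault(prefix, []).append((rest, tensor))
        loaded: Set[str] = set()
        for prefix, child_weights in per_child.items():
            child = getattr(self.model, prefix)
            loaded |= {f"{prefix}.{n}" for n in child.load_weights(child_weights)}
        return loaded


class ToyOuter:
    # Plays the role of GPTJForCausalLM after this change.
    def __init__(self) -> None:
        self.transformer = ToyInner()

    def load_weights(self, weights: Iterable[Tuple[str, object]]) -> Set[str]:
        loader = SimpleAutoWeightsLoader(self)
        return loader.load_weights(weights)


print(ToyOuter().load_weights([("transformer.wte.weight", None)]))
# -> {'transformer.wte.weight'}

Keeping the per-architecture mapping next to the module that owns the parameters leaves the top-level class as a thin wrapper around the generic loader, which is the point of this change.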