
Commit 87e067d

[Model] use AutoWeightsLoader for BigCode, GPT-J (#16823)
Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>
1 parent 26507f8 commit 87e067d

2 files changed (+91, -24 lines added, -79 lines removed across both files: +91, -79)

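In both files the per-checkpoint weight-loading loop moves from the *ForCausalLM wrapper into the inner model class (GPTBigCodeModel, GPTJModel), and the wrapper now delegates to AutoWeightsLoader from .utils, passing skip_prefixes=["lm_head."] for BigCode to preserve the old explicit skip of lm_head.weight. The sketch below only illustrates the delegation idea; it is not vLLM's AutoWeightsLoader implementation, and the class name and internals are invented for illustration.

from typing import Iterable, List, Optional, Set, Tuple

import torch
import torch.nn as nn


class SketchAutoWeightsLoader:
    """Illustrative stand-in for vLLM's AutoWeightsLoader (not its real code):
    drop weights under skip_prefixes, then hand each remaining weight to the
    child module that owns it, letting that child's own load_weights run."""

    def __init__(self,
                 module: nn.Module,
                 skip_prefixes: Optional[List[str]] = None):
        self.module = module
        self.skip_prefixes = tuple(skip_prefixes or ())

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loaded: Set[str] = set()
        children = dict(self.module.named_children())
        params = dict(self.module.named_parameters())
        for name, tensor in weights:
            if self.skip_prefixes and name.startswith(self.skip_prefixes):
                continue  # e.g. "lm_head." for GPTBigCodeForCausalLM below
            child_name, _, rest = name.partition(".")
            child = children.get(child_name)
            if child is not None and hasattr(child, "load_weights"):
                # Delegate to the submodule's loader, e.g. GPTJModel.load_weights.
                for sub_name in child.load_weights([(rest, tensor)]):
                    loaded.add(f"{child_name}.{sub_name}")
            else:
                # Plain parameter owned by this module: copy it directly.
                params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded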

vllm/model_executor/models/gpt_bigcode.py (+30, -24)

@@ -43,7 +43,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 
@@ -244,6 +244,30 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if ".attn.bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
+            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                weight_loader(param, loaded_weight, 'q')
+                weight_loader(param, loaded_weight, 'k')
+                weight_loader(param, loaded_weight, 'v')
+            else:
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}
@@ -315,26 +339,8 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
-                continue
-            if ".attn.bias" in name:
-                # Skip attention mask.
-                # NOTE: "c_attn.bias" should not be skipped.
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
-            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
-                weight_loader(param, loaded_weight, 'q')
-                weight_loader(param, loaded_weight, 'k')
-                weight_loader(param, loaded_weight, 'v')
-            else:
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]),
+        )
+        return loader.load_weights(weights)

vllm/model_executor/models/gpt_j.py (+61, -55)

@@ -43,7 +43,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -188,6 +188,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
 
         self.config = config
+        self.quant_config = quant_config
         self.embed_dim = config.n_embd
         self.wte = VocabParallelEmbedding(
             config.vocab_size,
@@ -228,6 +229,63 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "attn.bias" in name or "attn.masked_bias" in name:
+                continue
+
+            if (self.quant_config is not None and
+                    (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GPTJForCausalLM(nn.Module, SupportsPP):
 
@@ -285,57 +343,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "attn.bias" in name or "attn.masked_bias" in name:
-                continue
-
-            if (self.quant_config is not None and
-                    (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
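For reference, the load_weights contract itself is unchanged by the refactor: callers still pass an iterable of (name, tensor) pairs and get back the set of parameter names that were loaded, as the signatures above show. Below is a toy, self-contained illustration of that contract; the module and weight names are invented for this example and do not come from the commit.

from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Toy stand-in for an inner model such as GPTJModel: one embedding
    parameter and a load_weights method with the same signature as above."""

    def __init__(self) -> None:
        super().__init__()
        self.wte = nn.Embedding(4, 8)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Copy the checkpoint tensor into the matching parameter.
            params_dict[name].data.copy_(loaded_weight)
            loaded_params.add(name)
        return loaded_params


model = ToyModel()
loaded = model.load_weights([("wte.weight", torch.ones(4, 8))])
assert loaded == {"wte.weight"}
assert torch.equal(model.wte.weight.data, torch.ones(4, 8))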
