
Commit 95a8555

mikekgfb authored and malfet committed
Code formatting (pytorch#457)
* scrub & reformat code
* use full paths
* set tiktoken init to False, not None to align with new tokenizer chatting logic
1 parent 4f53f5e commit 95a8555

9 files changed: +35 −30 lines changed

GPTQ.py

Lines changed: 14 additions & 10 deletions

@@ -132,15 +132,17 @@ def cuda(self):
 
 class GenericGPTQRunner(fx.Interpreter):
     """
-    This is a generic GPTQ runner that takes an existing model and applies GPTQ.
-    It uses torch._dynamo.export to obtain a graph of the model and then hooks
-    into function calls and when it detects a linear, it applies GPTQ to the weight
-    given the calibration of inputs passed in at initialization. It puts the results
-    into the state_dict so that the quantized model weights/qparams can be loaded
-    directly into the model.
+    This is a generic GPTQ runner that takes an existing model and
+    applies GPTQ. It uses torch._dynamo.export to obtain a graph of
+    the model and then hooks into function calls and when it detects a
+    linear, it applies GPTQ to the weight given the calibration of
+    inputs passed in at initialization. It puts the results into the
+    state_dict so that the quantized model weights/qparams can be
+    loaded directly into the model.
 
     This class is expected to work in concert with a GPTQSimpleQuantizer
     class to define the specific type of quantization being done.
+
     """
 
     def __init__(

@@ -206,7 +208,7 @@ def get_quantized_state_dict(self):
             self.gptq_done
         ), "need to run GPTQRunner before you can get_quantized_state_dict"
         quantized_state_dict = self.new_state_dict
-        # Don't want to store/load the kv_cache so remove it from the state_dict
+
         del_list = []
         for param_fqn in quantized_state_dict:
             if "kv_cache" in param_fqn:

@@ -224,7 +226,8 @@ def tensors_to_cuda(args):
 
         # flatten args and kwargs together
         flat_args, spec = tree_flatten((args, kwargs))
-        # move all single tensors to cuda, will move MultiInputs to cuda one at a time
+        # move all single tensors to cuda, will move MultiInputs
+        # to cuda one at a time
        flat_args = tensors_to_cuda(flat_args)
 
        has_multi_input = MultiInput in [type(x) for x in flat_args]

@@ -421,8 +424,9 @@ def faster_quant(self, H, W):
         if all_qparams == []:
             all_qparams.append(cur_qparams)
 
-        # convert a list of qparams objects into a single one. enerally by
-        # concatenating a bunch of n,1 scale/zeros tensors into a n,num_groups tensor
+        # convert a list of qparams objects into a single
+        # one. generally by concatenating a bunch of n,1 scale/zeros
+        # tensors into a n,num_groups tensor
         all_qparams = self.combine_qparams_list_func(all_qparams)
         Q = self.quantize_func(DQ, all_qparams)
         return Q, DQ.to(orig_dtype), all_qparams
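
The docstring above describes hooking into traced function calls and applying GPTQ whenever a linear is detected. Below is a minimal, self-contained sketch of that interception pattern only, not the torchchat implementation: it traces with torch.fx.symbolic_trace instead of torch._dynamo.export, and LinearSpy and tiny_model are made-up names used purely for illustration.

import torch
import torch.fx as fx
import torch.nn.functional as F


class LinearSpy(fx.Interpreter):
    """Hypothetical example: intercept F.linear calls while interpreting a graph."""

    def call_function(self, target, args, kwargs):
        if target is F.linear:
            weight = args[1]
            # A real GPTQ runner would quantize `weight` here, using the
            # calibration inputs gathered at initialization, and record the
            # quantized weights/qparams in a state_dict.
            print("intercepted F.linear, weight shape:", tuple(weight.shape))
        return super().call_function(target, args, kwargs)


def tiny_model(x, weight, bias):
    return F.linear(x, weight, bias)


gm = fx.symbolic_trace(tiny_model)
out = LinearSpy(gm).run(torch.randn(2, 8), torch.randn(4, 8), torch.randn(4))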

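The faster_quant comment above refers to collapsing a list of per-group qparams into one object, generally by concatenating (n, 1) scale/zero tensors into an (n, num_groups) tensor. A rough sketch of just that shape manipulation, assuming each entry is a (scale, zero) pair of (n, 1) tensors (the concrete qparams layout in GPTQ.py may differ):

import torch

n, num_groups = 4, 3
# one (scale, zero) pair of shape (n, 1) per quantization group
per_group_qparams = [(torch.rand(n, 1), torch.zeros(n, 1)) for _ in range(num_groups)]

# concatenate along dim=1 so each row carries the params for every group
scales = torch.cat([s for s, _ in per_group_qparams], dim=1)  # (n, num_groups)
zeros = torch.cat([z for _, z in per_group_qparams], dim=1)   # (n, num_groups)
print(scales.shape, zeros.shape)  # torch.Size([4, 3]) torch.Size([4, 3])
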
build/builder.py

Lines changed: 5 additions & 3 deletions

@@ -180,11 +180,13 @@ def validate_model(
         if model is None:
             return
 
-        condition = False  # not (self.is_tiktoken == model.config.use_tiktoken) or not (self.is_sentencepiece == not model.config.use_tiktoken)
+        is_tiktoken = self.is_tiktoken
+        is_sentencepiece = self.is_sentencepiece
+        use_tiktoken = model.config.use_tiktoken
 
-        if condition:
+        if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
             raise RuntimeError(
-                f"model-specified tokenizer ({tokenizer_setting_to_name(model.config.use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(self.is_tiktoken)} for {model_description}"
+                f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)} for {model_description}"
             )
 
         return
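
The rewritten condition above raises whenever the tokenizer supplied by the user disagrees with the one the model config expects: with exactly one of tiktoken or sentencepiece in play, is_tiktoken must equal use_tiktoken and is_sentencepiece must be its negation. A standalone sketch of the same predicate (the function name is illustrative, not from the repo):

def tokenizer_mismatch(is_tiktoken: bool, is_sentencepiece: bool, use_tiktoken: bool) -> bool:
    # Mirrors the check in build/builder.py: the provided tokenizer must agree
    # with the model config on both counts.
    return not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken)


# model expects tiktoken, user provides tiktoken -> no mismatch
assert not tokenizer_mismatch(is_tiktoken=True, is_sentencepiece=False, use_tiktoken=True)
# model expects tiktoken, user provides sentencepiece -> mismatch
assert tokenizer_mismatch(is_tiktoken=False, is_sentencepiece=True, use_tiktoken=True)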

build/gguf_loader.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@
 from quantize import pack_scales_and_zeros, WeightOnlyInt4Linear
 
 from build.gguf_util import Q4_0, to_float
-from .model import ModelArgs, Transformer
+from build.model import ModelArgs, Transformer
 
 logger: logging.Logger = logging.getLogger(__name__)

build/model.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ class ModelArgs:
     norm_eps: float = 1e-5
     multiple_of: int = 256
     ffn_dim_multiplier: Optional[int] = None
-    use_tiktoken: Optional[bool] = None
+    use_tiktoken: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
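
Per the commit message, use_tiktoken is now initialized to False rather than None so the flag is always a real boolean. A trimmed-down, hypothetical stand-in for ModelArgs showing why that interacts better with the validate_model check above: a boolean flag never compares equal to None, so a None default makes a default-constructed config look mismatched.

from dataclasses import dataclass


@dataclass
class TinyArgs:                 # hypothetical stand-in for build.model.ModelArgs
    use_tiktoken: bool = False  # was: Optional[bool] = None


print(False == None)                     # False: the old default never matches a bool flag
print(False == TinyArgs().use_tiktoken)  # True: the new default compares meaningfully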

chat_in_browser.py

Lines changed: 2 additions & 1 deletion

@@ -38,7 +38,8 @@ def main():
 
     @app.route("/chat", methods=["GET", "POST"])
     def chat():
-        # Retrieve the HTTP POST request parameter value from 'request.form' dictionary
+        # Retrieve the HTTP POST request parameter value from
+        # 'request.form' dictionary
         _prompt = request.form.get("prompt", "")
         proc.stdin.write((_prompt + "\n").encode("utf-8"))
         proc.stdin.flush()

cli.py

Lines changed: 1 addition & 2 deletions

@@ -12,8 +12,7 @@
 from build.utils import allowable_dtype_names, allowable_params_table
 from download import download_and_convert, is_model_downloaded
 
-# CPU is always available and also exportable to ExecuTorch
-default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
+default_device = "cpu"
 
 
 # Handle CLI arguments that are common to a majority of subcommands.

download.py

Lines changed: 4 additions & 3 deletions

@@ -66,9 +66,10 @@ def download_and_convert(
     model_config = resolve_model_config(model)
     model_dir = models_dir / model_config.name
 
-    # Download into a temporary directory. We'll move to the final location once
-    # the download and conversion is complete. This allows recovery in the event
-    # that the download or conversion fails unexpectedly.
+    # Download into a temporary directory. We'll move to the final
+    # location once the download and conversion is complete. This
+    # allows recovery in the event that the download or conversion
+    # fails unexpectedly.
     temp_dir = models_dir / "downloads" / model_config.name
     if os.path.isdir(temp_dir):
         shutil.rmtree(temp_dir)
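
The comment explains the recovery strategy: download and convert into a scratch directory, and only move it to the final location once both steps succeed, so an interrupted run never leaves a half-populated model directory behind. A generic sketch of that pattern (the paths and names below are illustrative, not torchchat's):

import os
import shutil
from pathlib import Path


def fetch_into(models_dir: Path, name: str, do_download_and_convert) -> Path:
    temp_dir = models_dir / "downloads" / name  # scratch space
    final_dir = models_dir / name               # only ever holds complete models

    if temp_dir.is_dir():
        shutil.rmtree(temp_dir)                 # stale partial download: start over
    temp_dir.mkdir(parents=True)

    do_download_and_convert(temp_dir)           # may raise; final_dir stays untouched

    os.makedirs(final_dir.parent, exist_ok=True)
    shutil.move(str(temp_dir), str(final_dir))  # publish only after success
    return final_dir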

eval.py

Lines changed: 2 additions & 1 deletion

@@ -78,7 +78,8 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
     max_seq_length = min(T_new, model.config.block_size)
 
     device, dtype = prompt.device, prompt.dtype
-    # create an empty tensor of the expected final shape and fill in the current tokens
+    # create an empty tensor of the expected final shape and
+    # fill in the current tokens
     empty = torch.empty(T_new, dtype=dtype, device=device)
     empty[:T] = prompt
     seq = empty

generate.py

Lines changed: 5 additions & 8 deletions

@@ -103,7 +103,7 @@ def validate_build(
         )
 
     @classmethod
-    def from_args(cls, args):  # -> GeneratorArgs:
+    def from_args(cls, args):
         return cls(
             prompt=args.prompt,
             encoded_prompt=None,

@@ -326,7 +326,8 @@ def generate(
     is_speculative = draft_model is not None
     device, dtype = prompt.device, prompt.dtype
 
-    # create an empty tensor of the expected final shape and fill in the current tokens
+    # create an empty tensor of the expected final shape and
+    # fill in the current tokens
     T = prompt.size(0)
     max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
     T_new = T + max_new_tokens

@@ -338,7 +339,8 @@ def generate(
     if is_speculative and draft_model is not model:
         draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
 
-    # create an empty tensor of the expected final shape and fill in the current tokens
+    # create an empty tensor of the expected final shape and
+    # fill in the current tokens
     empty = torch.empty(T_new, dtype=dtype, device=device)
     empty[:T] = prompt
     seq = empty

@@ -461,8 +463,6 @@ def _main(
     is_speculative = speculative_builder_args.checkpoint_path is not None
 
     if generator_args.chat_mode and not builder_args.is_chat_model:
-        # This is not a log message, it's a dangerous condition message
-        # that we must ensure is displayed
         print(
             """
            *******************************************************

@@ -486,8 +486,6 @@ def _main(
     builder_args.setup_caches = False
     model = _initialize_model(builder_args, quantize, tokenizer)
 
-    # will add a version of _initialize_model in future
-    # (need additional args)
     if is_speculative:
         draft_model = _initialize_model(
             speculative_builder_args,

@@ -533,7 +531,6 @@ def _main(
         decode_one_token, mode="reduce-overhead", fullgraph=True
     )
 
-    # Uncomment to squeeze more perf out of prefill
     if generator_args.compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
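
The two generate() hunks above concern the same pre-allocation pattern: clamp max_new_tokens so the sequence fits within max_seq_length, then allocate a tensor of the final length and copy the prompt into its head. A small numeric walk-through (all values made up for illustration):

import torch

T = 5                 # prompt length
start_pos = 0
max_seq_length = 16
max_new_tokens = 100  # requested, but more than the cache can hold

max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)  # -> 11
T_new = T + max_new_tokens                                            # -> 16

prompt = torch.arange(T)
empty = torch.empty(T_new, dtype=prompt.dtype)
empty[:T] = prompt    # first T slots hold the prompt; the rest are filled while decoding
seq = empty
print(max_new_tokens, T_new, seq[:T].tolist())  # 11 16 [0, 1, 2, 3, 4]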
