import torch
from typing import Any
from functools import reduce
from importlib.metadata import version
from math import gcd
import torch.nn.utils.parametrize as parametrize
import itertools
import time
import warnings
__all__ = [
"benchmark_model",
"profiler_runner",
"get_compute_capability",
"skip_if_compute_capability_less_than",
"benchmark_torch_function_in_microseconds",
"find_multiple",
"_register_custom_op",
"get_model_size_in_bytes",
"unwrap_tensor_subclass",
"TORCH_VERSION_AFTER_2_2",
"TORCH_VERSION_AFTER_2_3",
"TORCH_VERSION_AFTER_2_4",
"TORCH_VERSION_AFTER_2_5",
]
# Referenced from: https://github.com/pytorch/pytorch/blob/9105d54c6b37099575c0059ef274c86c4dc80c57/torch/ao/quantization/utils.py#L711
def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
"""
Returns the unique device for a module, or None if no device is found.
Throws an error if multiple devices are detected.
"""
devices = {p.device for p in module.parameters()} | \
{p.device for p in module.buffers()}
assert len(devices) <= 1, (
"prepare only works with cpu or single-device CUDA modules, "
f"but got devices {devices}"
)
device = next(iter(devices)) if len(devices) > 0 else None
return device
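
# Illustrative behavior (a hedged sketch, not from the original file): for a module whose
# parameters and buffers all live on one device this returns that device; mixing devices
# trips the assert above.
#
#     _assert_and_get_unique_device(torch.nn.Linear(4, 4))         # device(type='cpu')
#     _assert_and_get_unique_device(torch.nn.Linear(4, 4).cuda())  # device(type='cuda', index=0)
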
def benchmark_model(model, num_runs, input_tensor):
    """Run `model(input_tensor)` `num_runs` times and return the average time per run.

    Note: the result is reported in milliseconds on CUDA/MPS (event-based timing)
    and in seconds on CPU (wall-clock timing).
    """
    device_type = _assert_and_get_unique_device(model).type
if device_type == "cuda":
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(input_tensor)
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs
elif device_type == "mps":
torch.mps.synchronize()
start_event = torch.mps.event.Event(enable_timing=True)
end_event = torch.mps.event.Event(enable_timing=True)
start_event.record()
# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(input_tensor)
end_event.record()
torch.mps.synchronize()
return start_event.elapsed_time(end_event) / num_runs
elif device_type == "cpu":
torch.cpu.synchronize()
start_time = time.time()
# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(input_tensor)
end_time = time.time()
torch.cpu.synchronize()
average_time_per_run = (end_time - start_time) / num_runs
return average_time_per_run
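
# Illustrative usage (a hedged sketch; the model and shapes are made up): benchmark a
# small CUDA-resident model for 100 runs on the device it already lives on.
#
#     model = torch.nn.Linear(1024, 1024).cuda()
#     example_input = torch.randn(8, 1024, device="cuda")
#     avg_ms = benchmark_model(model, 100, example_input)  # milliseconds on CUDA
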
def profiler_runner(path, fn, *args, **kwargs):
    """Profile `fn(*args, **kwargs)` with the PyTorch profiler and export a
    Chrome trace to `path`."""
    with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result
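
# Illustrative usage (a hedged sketch; the path is hypothetical): profile a callable and
# dump a Chrome trace that can be inspected in chrome://tracing or Perfetto.
#
#     out = profiler_runner("/tmp/trace.json", model, example_input)
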
def get_compute_capability():
    """Return the CUDA compute capability of the current device as a float
    (e.g. 8.0), or 0.0 if CUDA is unavailable."""
    if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
return float(f"{capability[0]}.{capability[1]}")
return 0.0
def skip_if_compute_capability_less_than(min_capability):
import unittest
def decorator(test_func):
def wrapper(*args, **kwargs):
if get_compute_capability() < min_capability:
raise unittest.SkipTest(f"Compute capability is less than {min_capability}")
return test_func(*args, **kwargs)
return wrapper
return decorator
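
# Illustrative usage (a hedged sketch; `MyTest` and `test_bf16_kernel` are hypothetical
# names): skip a unittest test on GPUs with compute capability below 8.0.
#
#     class MyTest(unittest.TestCase):
#         @skip_if_compute_capability_less_than(8.0)
#         def test_bf16_kernel(self):
#             ...
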
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
import torch.utils.benchmark as benchmark # this avoids importing numpy when torchao module is loaded
# Manual warmup
f(*args, **kwargs)
f(*args, **kwargs)
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f}, # noqa: E501
)
measurement = t0.blocked_autorange()
return measurement.mean * 1e6
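
# Illustrative usage (a hedged sketch; the shapes are arbitrary): time a matmul.
#
#     a = torch.randn(512, 512)
#     b = torch.randn(512, 512)
#     us = benchmark_torch_function_in_microseconds(torch.mm, a, b)
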
def find_multiple(n: int, *args: int) -> int:
    # k is the least common multiple of all alignment values in `args`
    k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
if n % k == 0:
return n
return n + k - (n % k)
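
# Worked example (a hedged sketch): with alignments (4, 6) the least common multiple k
# is 12, so find_multiple(10, 4, 6) == 12 and find_multiple(24, 4, 6) == 24.
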
def _register_custom_op(lib):
"""This decorator is used to preserve some high level operators for torch.export.export
while still allow them to be decomposed for inductor path
requirement: make sure `fn.__name__[1:]` is the operator name you want to register
NOTE: This should be applied at the top, after all other decorators have been applied
NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input,
e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make
sense for downstream system (like executorch) to accept as well
Example:
        lib = torch.library.Library("my_namespace", "FRAGMENT")
register_custom_op = _register_custom_op(lib)
@register_custom_op
        def _the_op_that_needs_to_be_preserved(...):
...
# after this, `_the_op_that_needs_to_be_preserved` will be preserved as
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
# torch.export.export / torch._export.capture_pre_autograd_graph
"""
from torch._inductor.decomposition import register_decomposition
def decorator(fn):
if TORCH_VERSION_AFTER_2_5:
from torch._library.infer_schema import infer_schema
# expecting fn.__name__ starts with `_` and we want to take the rest
# to be the name of the custom op
            assert fn.__name__[0] == "_", f"Expected function name to start with `_`, got {fn.__name__}"
            assert not any(c in fn.__name__ for c in ".<>"), f"Expected op to be defined as a normal function, not a lambda or local function: {fn.__name__}"
op_name = fn.__name__[1:]
schema = op_name + infer_schema(fn, mutates_args={})
lib.define(schema)
lib.impl(op_name, fn, "CompositeImplicitAutograd")
lib_namespace = lib.ns
op = getattr(getattr(torch.ops, lib_namespace), op_name)
register_decomposition([op])(fn)
return op
else:
return fn
return decorator
def get_model_size_in_bytes(model, ignore_embeddings=False):
"""
Returns the model size in bytes. The option to ignore embeddings
is useful for models with disproportionately large embeddings compared
to other model parameters that get quantized/sparsified.
"""
def flat_size(tensor):
if hasattr(tensor, "__tensor_flatten__"):
size = 0
# 0th element is a list of attributes that
# hold tensors
for attr_name in tensor.__tensor_flatten__()[0]:
sub_tensor = getattr(tensor, attr_name)
size += flat_size(sub_tensor)
return size
else:
return tensor.numel() * tensor.element_size()
model_size = 0
for name, child in model.named_children():
if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings):
for p in itertools.chain(child.parameters(recurse=False), child.buffers(recurse=False)):
model_size += flat_size(p)
model_size += get_model_size_in_bytes(child, ignore_embeddings)
return model_size
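
# Illustrative usage (a hedged sketch): a fp32 Linear(1024, 1024) child with bias holds
# (1024 * 1024 + 1024) * 4 bytes, roughly 4.2 MB.
#
#     model = torch.nn.Sequential(torch.nn.Linear(1024, 1024))
#     size_bytes = get_model_size_in_bytes(model)
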
class UnwrapTensorSubclass(torch.nn.Module):
    """Parametrization module that flattens a tensor subclass into its plain inner
    tensors (`right_inverse`) and rebuilds the subclass from them (`forward`)."""

    def forward(self, *tensors):
todo = list(tensors)
for tp, meta, inner_tensors in reversed(self.rebuild_stack):
nb_tensor = len(inner_tensors)
inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
todo = todo[nb_tensor:]
rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
todo.append(rebuilt)
assert len(todo) == 1
return todo[0]
    def right_inverse(self, tensor):
        assert type(tensor) is not torch.Tensor
        # Depth-first flatten: record how to rebuild each subclass and collect the
        # underlying plain tensors, which become the parametrized parameters.
        rebuild_stack = []
        plain_tensors = []
        todo = [tensor]
while todo:
obj = todo.pop()
inner_tensors, metadata = obj.__tensor_flatten__()
rebuild_stack.append((type(obj), metadata, inner_tensors))
for attr_name in inner_tensors:
val = getattr(obj, attr_name)
if type(val) is torch.Tensor:
plain_tensors.append(val)
else:
assert isinstance(val, torch.Tensor)
todo.append(val)
self.rebuild_stack = rebuild_stack
return plain_tensors
def unwrap_tensor_subclass(model, filter_fn=None):
    """Unwraps (nested) tensor subclasses in the model to plain tensors.

    This is a workaround to make a model with tensor subclass weights work with
    `torch.export.export` and `torch.aot_compile`; we hope this can be integrated
    into the compile stack soon.

    Tracking issue: https://github.com/pytorch/ao/issues/345
    """
for name, child in model.named_children():
# make sure child.weight is a tensor subclass
if (
isinstance(child, torch.nn.Linear) and
hasattr(child, "weight") and
type(child.weight) is not torch.Tensor and
type(child.weight) is not torch.nn.Parameter and
isinstance(child.weight, torch.Tensor) and
issubclass(type(child.weight), torch.Tensor)
):
parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass())
unwrap_tensor_subclass(child)
return model
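
# Illustrative usage (a hedged sketch; the quantization step and `example_input` are
# assumptions, not APIs defined in this file): unwrap tensor subclass weights before export.
#
#     # ... some quantization API has replaced Linear weights with tensor subclasses ...
#     model = unwrap_tensor_subclass(model)
#     exported = torch.export.export(model, (example_input,))
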
def is_fbcode():
return not hasattr(torch.version, "git_version")
def torch_version_at_least(min_version):
    # In fbcode we assume a recent enough torch build; otherwise compare version strings.
    return is_fbcode() or version("torch") >= min_version
TORCH_VERSION_AFTER_2_5 = torch_version_at_least("2.5.0.dev")
TORCH_VERSION_AFTER_2_4 = torch_version_at_least("2.4.0.dev")
TORCH_VERSION_AFTER_2_3 = torch_version_at_least("2.3.0.dev")
TORCH_VERSION_AFTER_2_2 = torch_version_at_least("2.2.0.dev")
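
# Illustrative usage (a hedged sketch): gate version-dependent code paths on these flags.
#
#     if TORCH_VERSION_AFTER_2_4:
#         ...  # use an API that only exists in torch >= 2.4
#     else:
#         ...  # fallback for older torch
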