Skip to content

Commit 26cfc08

Browse files
committed
feat: add checks
1 parent 94d4e15 commit 26cfc08

File tree

2 files changed

+61
-3
lines changed

2 files changed

+61
-3
lines changed

torchao/dtypes/affine_quantized_tensor.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1053,11 +1053,14 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b
10531053
original_shape = weight_tensor.layout_tensor.original_shape
10541054
num_bits = weight_tensor.layout_tensor.num_bits
10551055

1056+
size_m = input_tensor.shape[0]
1057+
size_n = original_shape[0]
1058+
size_k = input_tensor.shape[1]
10561059
workspace_24 = marlin_24_workspace(original_shape[1])
10571060

10581061
out = torchao.ops.marlin_24_gemm(
1059-
input_tensor, sparse_w_int4, meta, scale, workspace_24,
1060-
num_bits, input_tensor.shape[0], original_shape[1], input_tensor.shape[1],
1062+
input_tensor, sparse_w_int4, meta, scale,
1063+
workspace_24, num_bits, size_m, size_n, size_k
10611064
)
10621065
torch.cuda.synchronize()
10631066

torchao/ops.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -204,5 +204,60 @@ def _(
204204
size_n: int,
205205
size_k: int,
206206
) -> Tensor:
207-
# NOTE: Checks in kernel
207+
TILE_SIZE = 16
208+
MIN_THREAD_N = 128
209+
MAX_PARALLELISM = 64
210+
211+
# Verify num_bits
212+
torch._check(bits == 4 or bits == 8, lambda: f"num_bits must be 4 or 8. Got = {bits}")
213+
pack_factor = 32 // bits
214+
215+
# Verify M
216+
torch._check(size_m == x.size(0), lambda: f"Shape mismatch: x.size(0) = {x.size(0)}, size_m = {size_m}")
217+
218+
# Verify K
219+
torch._check(size_k == x.size(1), lambda: f"Shape mismatch: x.size(1) = {x.size(1)}, size_k = {size_k}")
220+
torch._check(size_k % TILE_SIZE == 0, lambda: f"size_k = {size_k} is not divisible by tile_size = {TILE_SIZE}")
221+
torch._check((size_k // TILE_SIZE // 2) == weight_marlin.size(0), lambda: f"Shape mismatch: weight_marlin.size(0) = {weight_marlin.size(0)}, size_k = {size_k}, tile_size = {TILE_SIZE}")
222+
223+
# Verify N
224+
torch._check(s.size(1) == size_n, lambda: f"s.size(1) = {s.size(1)}, size_n = {size_n}")
225+
torch._check(weight_marlin.size(1) % TILE_SIZE == 0, lambda: f"weight_marlin.size(1) = {weight_marlin.size(1)} is not divisible by tile_size = {TILE_SIZE}")
226+
227+
actual_size_n = (weight_marlin.size(1) // TILE_SIZE) * pack_factor
228+
torch._check(size_n == actual_size_n, lambda: f"size_n = {size_n}, actual_size_n = {actual_size_n}")
229+
230+
# Verify meta
231+
torch._check(meta.size(0) == size_k // 8 // 2 // 2, lambda: f"meta.size(0) = {meta.size(0)} is not size_k / 8 / 2 / 2 = {size_k // 8 // 2 // 2}")
232+
torch._check(meta.size(1) == size_n * 2, lambda: f"meta.size(1) = {meta.size(1)} is not size_n * 2 = {size_n * 2}")
233+
234+
# Verify A device and strides
235+
torch._check(x.is_cuda, lambda: "x is not on GPU")
236+
torch._check(x.is_contiguous(), lambda: "x is not contiguous")
237+
238+
# Verify B device and strides
239+
torch._check(weight_marlin.is_cuda, lambda: "weight_marlin is not on GPU")
240+
torch._check(weight_marlin.is_contiguous(), lambda: "weight_marlin is not contiguous")
241+
242+
# Verify meta device and strides
243+
torch._check(meta.is_cuda, lambda: "meta is not on GPU")
244+
torch._check(meta.is_contiguous(), lambda: "meta is not contiguous")
245+
246+
# Verify scales device and strides
247+
torch._check(s.is_cuda, lambda: "s is not on GPU")
248+
torch._check(s.is_contiguous(), lambda: "s is not contiguous")
249+
250+
# Verify groupsize
251+
groupsize = -1
252+
if s.size(0) > 1:
253+
torch._check(size_k % s.size(0) == 0, lambda: f"size_k = {size_k} is not divisible by s.size(0) = {s.size(0)}")
254+
groupsize = size_k // s.size(0)
255+
groupsize //= 2 # Because of 24
256+
torch._check(groupsize == -1 or groupsize == 64, lambda: f"Unexpected groupsize = {groupsize}")
257+
258+
# Verify workspace size
259+
torch._check(size_n % MIN_THREAD_N == 0, lambda: f"size_n = {size_n} is not divisible by min_thread_n = {MIN_THREAD_N}")
260+
min_workspace_size = (size_n // MIN_THREAD_N) * MAX_PARALLELISM
261+
torch._check(workspace.numel() >= min_workspace_size, lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}")
262+
208263
return torch.empty((x.size(0), s.size(1)), dtype=x.dtype, device=x.device)

0 commit comments

Comments (0)