Commit 4c63c65

Add CUTLASS-based row-wise scaled sparse FP8 kernel
1 parent ceceea5 commit 4c63c65

21 files changed: +1336 -76 lines
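
The new op, rowwise_scaled_linear_sparse_cutlass_f8f8(A, A_scale, B_sp, B_meta, B_scale, bias), multiplies an FP8 activation A by a 2:4-compressed FP8 weight (B_sp plus metadata B_meta, produced by to_sparse_semi_structured_cutlass_sm9x_f8) and applies a per-row scale to each operand, as exercised by the benchmark and test below. The following standalone Python sketch is illustrative only and not part of the commit; it restates the dense reference math the new test checks the kernel against.

import torch

# Dense reference for the row-wise scaled sparse FP8 matmul (illustrative only;
# requires a PyTorch build with float8 dtypes):
#   out = (A @ B.T) * A_scale[:, None] * B_scale[None, :] (+ bias)
# where A is the FP8 activation, B the 2:4-sparse FP8 weight, and both scales
# are applied per row.
m, n, k = 16, 32, 64
A = torch.randn(m, k, dtype=torch.half).to(torch.float8_e5m2)
A_scale = torch.randn(m, dtype=torch.half)
B = torch.randn(n, k, dtype=torch.half).to(torch.float8_e4m3fn)
B_scale = torch.randn(n, dtype=torch.half)

out_ref = (A.float() @ B.float().T) * A_scale.view(m, 1) * B_scale.view(1, n)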
New file (benchmark script):

@@ -0,0 +1,62 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)


def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem(m: int, n: int, k: int):
    dev = torch.device("cuda")

    A = torch.randn((m, k), dtype=torch.half, device=dev).to(torch.float8_e5m2)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randn((n, k), dtype=torch.half, device=dev).to(torch.float8_e4m3fn)
    B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A, A_scale, B_sp, B_meta, B_scale, C


def benchmark(m: int, k: int, n: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B_sp, B_meta, B_scale, C = get_problem(m, n, k)
    rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
        rowwise_scaled_linear_sparse_cutlass_f8f8, A, A_scale, B_sp, B_meta, B_scale, C
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (ms)": fp16_time,
        "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
        "f8f8 speedup (d/s)": fp16_time
        / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))
docs/source/api_ref_dtypes.rst

+1

@@ -28,6 +28,7 @@ Layouts and Tensor Subclasses
     MarlinQQQLayout
     Int4CPULayout
     CutlassInt4PackedLayout
+    CutlassSemiSparseLayout

 Quantization techniques
 -----------------------

setup.py

+12

@@ -265,6 +265,18 @@ def get_extensions():
                "-I" + cutlass_include_dir,
                "-I" + cutlass_tools_include_dir,
                "-I" + cutlass_extensions_include_dir,
+               "-DNDEBUG" if not debug_mode else "",
+               "-DCUTE_USE_PACKED_TUPLE=1",
+               "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
+               "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
+               "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
+               "--use_fast_math",
+               "--ftemplate-backtrace-limit=0",
+               # "--keep",
+               # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",
+               # "--resource-usage",
+               # "-lineinfo",
+               # "-DCUTLASS_ENABLE_GDC_FOR_SM90", # https://github.com/NVIDIA/cutlass/blob/main/media/docs/dependent_kernel_launch.md
            ]
        )
    else:

(A comma is added after "-DCUTE_USE_PACKED_TUPLE=1"; without it the two string literals concatenate into a single malformed nvcc flag.)

test/test_rowwise_scaled_linear_cutlass.py

+1 -1

@@ -57,7 +57,7 @@ def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias)
     )
     assert torch.all(wq_zeros == 0)
     if wq_bits == 4:
-        wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF)
+        wq = (wq_s8[..., 1::2] << 4) | (wq_s8[..., 0::2] & 0xF)
     else:
         wq = wq_s8
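
The switch from [:, ...] to [..., ...] generalizes the int4 nibble packing from 2-D weights to weights with a leading batch dimension. A standalone illustration, not part of the diff:

import torch

# Pack pairs of int4 values (stored in int8) into single bytes; "..." keeps the
# expression valid for both plain (n, k) weights and batched (b, n, k) weights.
wq_s8 = torch.randint(-8, 8, (3, 4, 8), dtype=torch.int8)
packed = (wq_s8[..., 1::2] << 4) | (wq_s8[..., 0::2] & 0xF)
print(packed.shape)  # last dimension halved: torch.Size([3, 4, 4])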

New file (test script):

@@ -0,0 +1,125 @@
import itertools

import pytest
import torch
from torch.testing._internal.common_cuda import SM90OrLater

from torchao.dtypes import (
    Float8Layout,
    to_affine_quantized_floatx,
)
from torchao.ops import (
    rowwise_scaled_linear_sparse_cutlass_f8f8,
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.quantization.utils import _get_per_token_block_size
from torchao.sparsity.utils import create_semi_structured_tensor

X_W_DTYPES = [(torch.float16, torch.float16), (torch.bfloat16, torch.bfloat16)]
XQ_WQ_DTYPES = [
    (torch.float8_e5m2, torch.float8_e4m3fn),
    (torch.float8_e4m3fn, torch.float8_e4m3fn),
]
BATCH_SIZE = [1, 4]
SIZE_MNK = [
    (2, 128, 256),
    (3, 128, 256),
    (13, 128, 256),
    (27, 128, 128),
    (33, 128, 64),
    (65, 128, 32),
]
USE_BIAS = [False, True]
BIAS_DTYPE = [torch.float16]
TEST_PARAMS = list(
    itertools.product(
        X_W_DTYPES,
        XQ_WQ_DTYPES,
        BATCH_SIZE,
        SIZE_MNK,
        USE_BIAS,
        BIAS_DTYPE,
    )
)


def run_test_for_op(
    op,
    x_dtype,
    w_dtype,
    xq_dtype,
    wq_dtype,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    size_m, size_n, size_k = size_mnk

    x = torch.randn((batch_size, size_m, size_k), dtype=x_dtype, device="cuda")
    w = create_semi_structured_tensor(size_n, size_k, dtype=w_dtype)
    bias = torch.rand((size_n,), dtype=bias_dtype, device="cuda") if use_bias else None

    x_aqt = to_affine_quantized_floatx(
        input_float=x,
        target_dtype=xq_dtype,
        block_size=_get_per_token_block_size(x),
        _layout=Float8Layout(mm_config=None),
    )
    xq, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
    assert zero_points is None

    w_aqt = to_affine_quantized_floatx(
        input_float=w,
        target_dtype=wq_dtype,
        block_size=_get_per_token_block_size(w),
        _layout=Float8Layout(mm_config=None),
    )
    wq, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
    assert zero_points is None
    wq_sp, wq_sp_meta = to_sparse_semi_structured_cutlass_sm9x_f8(wq)
    wq_sp_scales = wq_scales

    xq_2d = xq.view(-1, xq.shape[-1])
    size_m_2d = xq_2d.shape[0]
    output_ref = (
        (xq_2d.float() @ wq.float().T)
        * xq_scales.view(size_m_2d, 1)
        * wq_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.to(x.dtype).reshape(x.shape[:-1] + (size_n,))

    fn_inputs = (xq, xq_scales, wq_sp, wq_sp_meta, wq_sp_scales, bias)
    try:
        output = op(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("operator not implemented")

    torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=5e-3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not SM90OrLater, reason="FP8 is only supported on H100+ devices")
@pytest.mark.parametrize(
    "x_w_dtypes, xq_wq_dtypes, batch_size, size_mnk, use_bias, bias_dtype",
    TEST_PARAMS,
)
def test_rowwise_scaled_linear_sparse_cutlass_f8f8(
    x_w_dtypes,
    xq_wq_dtypes,
    batch_size,
    size_mnk,
    use_bias,
    bias_dtype,
):
    run_test_for_op(
        rowwise_scaled_linear_sparse_cutlass_f8f8,
        *x_w_dtypes,
        *xq_wq_dtypes,
        batch_size,
        size_mnk,
        use_bias,
        bias_dtype,
    )

torchao/_models/llama/generate.py

+42 -20

@@ -334,11 +334,13 @@ def ffn_or_attn_only(mod, fqn):

     if quantization:
         from torchao.quantization import (
+            Float8DynamicActivationFloat8SemiSparseWeightConfig,
             autoquant,
             float8_dynamic_activation_float8_weight,
             float8_weight_only,
             fpx_weight_only,
             gemlite_uintx_weight_only,
+            int4_dynamic_activation_int4_weight,
             int4_weight_only,
             int8_dynamic_activation_int4_weight,
             int8_dynamic_activation_int8_weight,

@@ -434,18 +436,30 @@ def ffn_or_attn_only(mod, fqn):
            ]
        ), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
        quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
-    elif "int8adq-int4w-symm" in quantization:
+    elif "int4dq-" in quantization:
         from torchao.dtypes import CutlassInt4PackedLayout

-        quantize_(
-            model,
-            int8_dynamic_activation_int4_weight(
-                group_size=None,
-                mapping_type=MappingType.SYMMETRIC,
-                act_mapping_type=MappingType.SYMMETRIC,
-                layout=CutlassInt4PackedLayout(),
-            ),
-        )
+        nbits = int(quantization.removeprefix("int4dq-"))
+        assert nbits == 4 or nbits == 8
+        if nbits == 4:
+            quantize_(
+                model,
+                int4_dynamic_activation_int4_weight(
+                    mapping_type=MappingType.SYMMETRIC,
+                    act_mapping_type=MappingType.SYMMETRIC,
+                    layout=CutlassInt4PackedLayout(),
+                ),
+            )
+        elif nbits == 8:
+            quantize_(
+                model,
+                int8_dynamic_activation_int4_weight(
+                    group_size=None,
+                    mapping_type=MappingType.SYMMETRIC,
+                    act_mapping_type=MappingType.SYMMETRIC,
+                    layout=CutlassInt4PackedLayout(),
+                ),
+            )
     if "marlin" in quantization:
         if "qqq" in quantization:
             from torchao.dtypes import MarlinQQQLayout

@@ -564,16 +578,24 @@ def ffn_or_attn_only(mod, fqn):
     elif "float8wo" in quantization:
         quantize_(model, float8_weight_only())
     elif "float8dq" in quantization:
-        granularity = str(quantization.split("-")[-1])
-        if granularity == "tensor":
-            granularity = PerTensor()
-        elif granularity == "row":
-            granularity = PerRow()
+        if sparsity and "semi" in sparsity:
+            quantize_(
+                model,
+                Float8DynamicActivationFloat8SemiSparseWeightConfig(),
+                filter_fn=ffn_only
+            )
         else:
-            granularity = PerTensor()
-        quantize_(
-            model, float8_dynamic_activation_float8_weight(granularity=granularity)
-        )
+            granularity = str(quantization.split("-")[-1])
+            if granularity == "tensor":
+                granularity = PerTensor()
+            elif granularity == "row":
+                granularity = PerRow()
+            else:
+                granularity = PerTensor()
+            quantize_(
+                model,
+                float8_dynamic_activation_float8_weight(granularity=granularity),
+            )
     elif "autoquant_v2" in quantization:
         from torchao._models._eval import InputRecorder
         from torchao._models.llama.model import prepare_inputs_for_model

@@ -1130,7 +1152,7 @@ def callback(x):
        help=(
            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
            + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
-           + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, int8adq-int4w-symm"
+           + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>"
        ),
    )
    parser.add_argument(

torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh

+2

@@ -11,6 +11,8 @@
 #endif

 #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+#include <cuda_runtime.h>
+#include <cutlass/cutlass.h>
 #include <cutlass/gemm/device/gemm_universal.h>
 #include <cutlass/gemm/device/gemm_universal_adapter.h>
 #include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>

torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu

+9 -8

@@ -13,15 +13,16 @@ rowwise_scaled_linear_cutlass_s4s4(
       __func__, " : The input datatypes combination ", xq.dtype(),
       " for xq and ", wq.dtype(), " for wq is not supported");

+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
   // Dispatch to appropriate kernel template.
-#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
-  // We get ElementA/ElementB types from the header
-  return rowwise_scaled_linear_cutlass<cutlass::int4b_t, cutlass::int4b_t>(
-      xq, x_scale, wq, w_scale, bias);
-#else
-  TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s4s4 not available");
-  return at::Tensor{};
-#endif
+  using ElementA = cutlass::int4b_t;
+  using ElementB = cutlass::int4b_t;
+  return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+      xq, x_scale, wq, w_scale, bias);
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME);
+  return at::Tensor{};
+#endif
 }

 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {

torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu

+6 -5

@@ -1,4 +1,5 @@
 #include <torch/library.h>
+
 #include "rowwise_scaled_linear_cutlass.cuh"

 namespace torchao {

@@ -13,13 +14,13 @@ rowwise_scaled_linear_cutlass_s8s4(
       " for xq and ", wq.dtype(), " for wq is not supported");

 #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
-  // Define ElementA as int8_t since it's a standard type
+  // Dispatch to appropriate kernel template.
   using ElementA = int8_t;
-  // ElementB comes from cutlass header
-  return rowwise_scaled_linear_cutlass<ElementA, cutlass::int4b_t>(
-      xq, x_scale, wq, w_scale, bias);
+  using ElementB = cutlass::int4b_t;
+  return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+      xq, x_scale, wq, w_scale, bias);
 #else
-  TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s8s4 not available");
+  TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME);
   return at::Tensor{};
 #endif
 }