Add CUTLASS-based row-wise scaled sparse FP8 kernel #1671


Conversation

alexsamardzic
Collaborator

No description provided.


pytorch-bot bot commented Feb 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1671

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 75a6195 with merge base 711fa08:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Feb 5, 2025
@alexsamardzic
Collaborator Author

alexsamardzic commented Feb 5, 2025

The kernel is ready and passes a smoke test.

Remaining tasks:

  • Write a converter to SM90 sparse semi-structured format
  • Validate the kernel on proper test inputs
  • Write the benchmark
  • Write Python-side code: sparsify/quantize method, Llama generator extension, etc.
  • Ensure that the kernel is built with SM90a flags when torchao detects an H100 card as SM90
  • Further unify CUDA code with rowwise_scaled_linear_cutlass code
  • Implement meaningful config selection

@cpuhrsch @drisspg

@cpuhrsch cpuhrsch requested a review from jcaip February 5, 2025 21:00
@alexsamardzic alexsamardzic added the float8, sparsity, and topic: new feature labels Feb 6, 2025
@alexsamardzic alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 2 times, most recently from 5bbcc49 to 6d34b7e Compare February 6, 2025 23:41
@alexsamardzic alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 8 times, most recently from bd7288a to f11fae4 Compare February 13, 2025 22:38
@alexsamardzic alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 10 times, most recently from bf65c83 to c0368e3 Compare February 19, 2025 23:33
@alexsamardzic
Collaborator Author

Testing this PR revealed that the sparse compressor in CUTLASS is not treating -0.0 values as zeros. The upstream fix is proposed here.
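For context, a minimal sketch of the kind of workaround used in the meantime (the names below are illustrative, not the actual torchao code): negative zeros are rewritten to positive zeros before the dense tensor is handed to the CUTLASS sparse compressor, applied to the high-precision tensor before the FP8 cast.

import torch

def normalize_negative_zeros(dense: torch.Tensor) -> torch.Tensor:
    # -0.0 == 0.0 compares equal under IEEE 754, so this rewrites both to +0.0
    # while leaving all non-zero elements untouched.
    return torch.where(dense == 0, torch.zeros_like(dense), dense)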

@alexsamardzic alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch from c0368e3 to 4c63c65 Compare February 20, 2025 19:10
@alexsamardzic
Collaborator Author

alexsamardzic commented Feb 24, 2025

This PR is ready for review. It contains:

  1. An implementation of two new CUTLASS-based operators:
    • Converter to sparse format for FP8 data and SM9x arch, in torchao/csrc/cuda/to_sparse_semi_structured_cutlass_sm9x.
    • Row-wise scaled linear operator implementation for sparse FP8 weight and FP8 activation in torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass. For parallel compilation, each operator template instantiation is in a separate .cu file.
  2. The test for the latter operator in test/test_rowwise_scaled_linear_sparse_cutlass.py (not all tests will pass at the moment because of [QST] About NaNs generated during FP16->FP8 quantization #1766), and the micro-benchmark in benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py.
  3. The corresponding layout and TensorImpl class implementations in torchao/dtypes/floatx/cutlass_semi_sparse_layout.py. Because of a CUTLASS issue with handling negative zero values when compressing a dense tensor to sparse, the from_plain() method here contains a temporary workaround (the fix for this CUTLASS issue is in the works: Treat negative zero as equivalent to positive zero in sm90_sparse_gemm_compressor.hpp NVIDIA/cutlass#2110).
  4. The remaining glue code on the Python side in torchao/ops.py, torchao/dtypes/affine_quantized_tensor.py and torchao/quantization/quant_api.py, including the definition of the new Float8DynamicActivationFloat8SemiSparseWeightConfig config for the quantize_() method (a usage sketch follows this list).
  5. An update to torchao/_models/llama/generate.py script, to make it possible to test the new quantization and linear operator within the context of Llama - run with python generate.py --compile --sparsity semi -q float8dq.
  6. Some minor updates for CUTLASS-based integer W4A4/W4A8 stuff.
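For illustration, a minimal usage sketch for item 4, assuming the API shape described above (the exact import paths and constructor arguments may differ from the final code):

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
)

# A toy model with a single linear layer; the config sparsifies/quantizes the
# weight to 2:4-sparse FP8 with row-wise scales, while activations are
# dynamically quantized to FP8 at runtime.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).cuda().bfloat16()
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())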

I'll address performance tuning (through CUTLASS run-time config selection), which is mentioned as a remaining task above, in a separate PR.

@drisspg The setup.py changes are about activating gencode flags for SM90a when the build is for SM90 - it's clumsy, but it works, so hopefully we can use this approach until we eventually switch to CMake builds for the extensions. I'm adding you as a reviewer because of this; also, please add whoever may be the most appropriate reviewer(s) for the Python side of the code.
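For reference, a rough sketch of the idea (illustrative only - the extension name and source list below are placeholders, not the actual setup.py): the CUTLASS SM90+ sources are built as their own extension so that the SM90a gencode flag can be applied to them alone.

from torch.utils.cpp_extension import CUDAExtension

cutlass_90a_sources = []  # the SM90+ CUTLASS .cu files, elided here

cutlass_90a_extension = CUDAExtension(
    name="torchao._C_cutlass_90a",  # hypothetical extension name
    sources=cutlass_90a_sources,
    extra_compile_args={
        "cxx": ["-O3"],
        # compute_90a exposes the SM90a-only features (e.g. the wgmma/TMA paths)
        # that the CUTLASS kernels rely on; plain sm_90 is not enough.
        "nvcc": ["-O3", "-gencode=arch=compute_90a,code=sm_90a"],
    },
)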

@jcaip If you think there is a need, we may discuss eventually exposing the new operators mentioned above through SparseSemiStructuredTensor.

@gau-nernst With this PR, it's possible to try the CUTLASS-based W4A4 operator from the Llama generator - run with python generate.py --compile --sparsity semi -q int4dq-4 (be sure to fetch the model beforehand - instructions are here). The output is not meaningful - maybe the quantization is just too tight - but we may want to investigate it further.

@alexsamardzic alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 5 times, most recently from d1d96f7 to 4eaece8 Compare February 27, 2025 16:02
Contributor

@jcaip jcaip left a comment

cc @alexsamardzic Took a first pass, but looks good so far :)

Did you have any numbers for generate.py? I can try grabbing some if not.

"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
"--use_fast_math",
"--ftemplate-backtrace-limit=0",
# "--keep",
Contributor

What do these lines do?

Collaborator Author

The NDEBUG define should always be there for a non-debug build; I'm moving it up, among the other general nvcc flags. The rest of the -D defines are what CUTLASS itself uses for compilation; these won't affect non-CUTLASS .cu files. The --ftemplate-backtrace-limit=0 will print the full list of template instantiations in case of a compile error (the default is to print the first 5 and the last 5); as CUTLASS is a heavily templated library, this is really needed to understand and fix these errors. The --use_fast_math flag is not really needed and should not be there; I'm removing it now, as it affects non-CUTLASS .cu files too. The commented-out options are sometimes useful for developers to activate.

(I hope #1659 eventually gets merged, as all of these flags would be easier to handle with CMake. With CUDAExtension, it seems there needs to be a new extension whenever build flags differ, and for this reason I have a separate extension for CUTLASS-based SM90+ kernels. It would be good to have another one for the other CUTLASS-based kernels, to apply the flags discussed here to these only; but having more extensions slows down the build.)

// performance is really bad; on the other side, using
// KernelTmaWarpSpecializedPingpongFP8FastAccum doesn't seem to
// affect the precision much - thus, sticking with it.
using KernelSchedule =
Contributor

How much of a performance degradation is slow accumulation? Is it about the same 2x we see for the non-sparse version?

We have some recipes where this makes a difference for the final accuracy.

Collaborator Author

For Llama generate.py, it's indeed close to 2x; for the benchmark (benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py) of the operator only, it's even worse - up to 3x for some shapes.

I was thinking about making use_fast_accum an optional argument for the operator. However, that would mean eventually having it as a template argument with two options, which would double the number of templates instantiated during compilation and thus roughly double the compilation time. So I think it's better to keep it as is for now; it could be added if a need arises. (For the same reason, the number of options for scales/bias/output data types in the .cu files is kept to a minimum - this is stated explicitly in a comment in the .cuh file.)

@@ -110,8 +108,7 @@ def from_plain(
_layout: Layout,
):
assert zero_point is None or torch.all(zero_point == 0)

int_data_s4 = ((int_data[:, 1::2] & 0xF) << 4) | (int_data[:, 0::2] & 0xF)
int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
Contributor

just curious here, what's the difference between ... and :?

Collaborator Author

You need to look at 3D and higher-dimensional tensors to spot the difference: for example, for a 3D tensor x, x[..., -1] is x[:, :, -1], while x[:, -1] is x[:, -1, :].
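A quick demonstration (just an illustrative snippet, not code from this PR):

import torch

x = torch.arange(24).reshape(2, 3, 4)

# "..." expands to all remaining leading dimensions, so the index always
# applies to the last dimension, regardless of the tensor's rank:
assert torch.equal(x[..., -1], x[:, :, -1])  # shape (2, 3)

# A single ":" skips exactly one dimension, so here -1 indexes dimension 1:
assert torch.equal(x[:, -1], x[:, -1, :])    # shape (2, 4)

# This is presumably why the packing code above uses "...": it keeps the
# 4-bit packing along the last dimension correct for tensors of any rank.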

@alexsamardzic alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch from 4eaece8 to e55367e Compare February 28, 2025 11:48
@alexsamardzic
Collaborator Author

Did you have any numbers for generate.py? I can try grabbing some if not.

The numbers are nothing to advertise at the moment: the baseline generate.py --compile gives around 180 tokens/sec on the H100 that I used for testing, while with this kernel generate.py --compile --sparsity semi -q float8dq produces just around 120 tokens/sec. But the run-time configs are not tuned at all for small batch sizes - that's the task I mentioned in other comments that I intend to take on next (note also that CUTLASS-based kernels are not a good fit for dynamic quantization at the moment, as they cannot be fused). On the other hand, the benchmark script shows a speedup over BF16/BF16 MM of up to 2x for larger shapes.

@alexsamardzic alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch 3 times, most recently from e276d04 to 284fc37 Compare March 3, 2025 17:56
Contributor

@jcaip jcaip left a comment

lgtm, feel free to merge

@alexsamardzic alexsamardzic force-pushed the rowwise-scaled-sparse-fp8-cutlass branch from 284fc37 to 75a6195 Compare March 11, 2025 21:22
@alexsamardzic
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


vkuzo added a commit that referenced this pull request Mar 17, 2025
Summary:

This test was added by #1671.

This test doesn't pass on ROCm, skip it to unbreak CI and we can fix it
later

Test Plan: CI

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c102e19
ghstack-comment-id: 2729651455
Pull Request resolved: #1906
@alexsamardzic alexsamardzic deleted the rowwise-scaled-sparse-fp8-cutlass branch March 18, 2025 21:18