Add CUTLASS-based row-wise scaled sparse FP8 kernel #1671
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1671
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 75a6195 with merge base 711fa08.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The kernel is ready and passes the smoke test. Remaining tasks:
Force-pushed from 5bbcc49 to 6d34b7e (Compare).
Force-pushed from bd7288a to f11fae4 (Compare).
Force-pushed from bf65c83 to c0368e3 (Compare).
Testing this PR revealed that the sparse compressor in CUTLASS is not treating -0.0 values as zeros. The upstream fix is proposed here.
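For background (not from the PR thread itself): -0.0 compares equal to 0.0 numerically, but its bit pattern differs, so a compressor that detects zeros by inspecting raw bits would presumably miss it. A minimal Python sketch of the distinction:

```python
import struct

import torch

x = torch.tensor([0.0, -0.0, 1.0], dtype=torch.float32)

# Numeric comparison treats -0.0 as zero...
print(torch.eq(x, 0.0))               # tensor([ True,  True, False])

# ...but the raw bit patterns of 0.0 and -0.0 differ (sign bit set),
# which is what a purely bitwise zero check would trip over.
print(struct.pack("<f", 0.0).hex())   # 00000000
print(struct.pack("<f", -0.0).hex())  # 00000080
```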
Force-pushed from c0368e3 to 4c63c65 (Compare).
This PR is ready for review. It contains:

I'll address the performance tuning (through CUTLASS run-time config selection) that is mentioned as a remaining task above in a separate PR.

@drisspg The

@jcaip If you think there is a need, we may discuss eventually exposing the mentioned new operators through

@gau-nernst With this PR, it's possible to try the CUTLASS-based W4A4 operator from the Llama generator - run with
Force-pushed from d1d96f7 to 4eaece8 (Compare).
cc @alexsamardzic Took a first pass, but looks good so far :)
Did you have any numbers for generate.py? I can try grabbing some if not.
"-DCUTLASS_DEBUG_TRACE_LEVEL=0", | ||
"--use_fast_math", | ||
"--ftemplate-backtrace-limit=0", | ||
# "--keep", |
What do these lines do?
The NDEBUG should always be there for a non-debug build; I'm moving it up, among the other general nvcc flags. The rest of the -D defines are what CUTLASS itself uses for compilation; these won't affect non-CUTLASS .cu files. The --ftemplate-backtrace-limit=0 will print the full list of template instantiations in case of a compile error (the default is to print the first 5 and the last 5); as CUTLASS is a heavily templated library, this is really needed to understand and fix these errors. The --use_fast_math is not really needed and should not be there; I'm removing it now, as it affects non-CUTLASS .cu files too. The options under comments are useful for developers to activate sometimes.

(I hope #1659 eventually gets merged, as all of these flags would be easier to handle with CMake. With CUDAExtension, it seems there should be a new extension whenever build flags need to differ, and for this reason I have a separate extension for CUTLASS-based SM90+ kernels. It would be good to have another one for other CUTLASS-based kernels, to apply the flags discussed to those only; but having more extensions slows down the build.)
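A rough sketch of the pattern being described here - one CUDAExtension per distinct set of nvcc flags, so CUTLASS-specific options only reach the CUTLASS sources. The file names, module names, and flag lists below are hypothetical, not the repo's actual setup.py:

```python
# Hypothetical sketch: separate CUDAExtension objects so that
# CUTLASS-specific nvcc flags apply only to the CUTLASS .cu files.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

common_nvcc_flags = ["-O3", "-DNDEBUG"]
cutlass_nvcc_flags = common_nvcc_flags + [
    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
    "--ftemplate-backtrace-limit=0",
]

setup(
    name="example_kernels",
    ext_modules=[
        # Non-CUTLASS kernels: only the general flags.
        CUDAExtension(
            name="example_kernels._C",
            sources=["csrc/other_kernels.cu"],
            extra_compile_args={"cxx": ["-O3"], "nvcc": common_nvcc_flags},
        ),
        # CUTLASS-based kernels: general flags plus CUTLASS-specific ones.
        CUDAExtension(
            name="example_kernels._C_cutlass",
            sources=["csrc/cutlass_kernels.cu"],
            include_dirs=["third_party/cutlass/include"],
            extra_compile_args={"cxx": ["-O3"], "nvcc": cutlass_nvcc_flags},
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```

The drawback mentioned above applies to this sketch as well: every additional extension is a separate build target, so the overall build gets slower.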
// performance is really bad; on the other side, using
// KernelTmaWarpSpecializedPingpongFP8FastAccum doesn't seem to
// affect the precision much - thus, sticking with it.
using KernelSchedule = |
How much of a performance degradation is slow accumulation? Is it about the same 2x we see for the non-sparse version?
We have some recipes where this makes a difference for the final accuracy.
For Llama generate.py, it's indeed close to 2x; for the benchmark of the operator only (benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py), it's even worse - up to 3x for some shapes.

I was thinking about making use_fast_accum an optional argument for the operator. However, that would mean eventually having it as a template argument with two options, which would double the number of templates instantiated during compilation and thus roughly double the compile time. So I think it's better to keep it as is for now; it could be added if a need arises. (For the same reason, the number of options for scales/bias/output data types in the .cu file is kept to a minimum - this is stated explicitly in a comment in the .cuh file.)
@@ -110,8 +108,7 @@ def from_plain(
        _layout: Layout,
    ):
        assert zero_point is None or torch.all(zero_point == 0)

-       int_data_s4 = ((int_data[:, 1::2] & 0xF) << 4) | (int_data[:, 0::2] & 0xF)
+       int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
Just curious here, what's the difference between ... and : ?
One needs to look at 3D and higher-dimensional tensors to spot the difference: for example, for a 3D tensor x, it holds that x[..., -1] is x[:, :, -1], while x[:, -1] is x[:, -1, :].
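A quick runnable illustration of this point (not from the PR), showing both the general indexing difference and why the ellipsis matters for the packing line in the diff above:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)  # a 3D tensor

# Ellipsis expands to "all remaining leading dimensions":
assert torch.equal(x[..., -1], x[:, :, -1])   # last element along the LAST dim
assert torch.equal(x[:, -1], x[:, -1, :])     # last element along the SECOND dim

# For the int4 packing above, the ellipsis makes the same line work for
# both 2D and batched (3D) int_data, always pairing along the last dim.
int_data = torch.randint(-8, 8, (2, 5, 6), dtype=torch.int8)
packed = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
assert packed.shape == (2, 5, 3)
```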
Force-pushed from 4eaece8 to e55367e (Compare).
The numbers are nothing to advertise at the moment: baseline
Force-pushed from e276d04 to 284fc37 (Compare).
lgtm, feel free to merge
Force-pushed from 284fc37 to 75a6195 (Compare).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
No description provided.