[float8] Re-enable slow-accum in the bwd of axis-wise scaling schemes #1325
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1325
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 6f4615b with merge base 1a0dbf1:
BROKEN TRUNK - The following job failed but was also present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
    and b_scale.shape == (1, b_data.shape[1])
    and not use_fast_accum
):
    # The rowwise CUTLASS-based kernel is so slow without fast-accum that
```
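For context, here is a minimal sketch of the kind of fallback this branch guards. This is not the actual torchao code: the helper name is hypothetical, and the variable names `a_data`, `a_scale`, `b_data`, `b_scale` simply follow the snippet above. The idea is that when both operands carry axis-wise scales and fast-accum is off, you can run the tensorwise cuBLAS kernel with unit scales and apply the row/column scales manually afterwards.

```python
# Minimal sketch, not the actual torchao implementation.
import torch


def scaled_mm_with_rowwise_fallback(a_data, b_data, a_scale, b_scale,
                                    out_dtype, use_fast_accum):
    # b_data is assumed to already be in the layout torch._scaled_mm requires
    # for its second operand (column-major).
    is_axiswise = (
        a_scale.shape == (a_data.shape[0], 1)
        and b_scale.shape == (1, b_data.shape[1])
    )
    if is_axiswise and not use_fast_accum:
        # The rowwise CUTLASS-based kernel is so slow without fast-accum that
        # the tensorwise cuBLAS kernel plus manual scaling is faster overall.
        one = torch.ones((), device=a_data.device, dtype=torch.float32)
        out = torch._scaled_mm(
            a_data, b_data, scale_a=one, scale_b=one,
            out_dtype=torch.float32, use_fast_accum=False,
        )
        # (scale_a * A) @ (B * scale_b) == diag(scale_a) @ (A @ B) @ diag(scale_b)
        return (out * a_scale * b_scale).to(out_dtype)
    return torch._scaled_mm(
        a_data, b_data, scale_a=a_scale, scale_b=b_scale,
        out_dtype=out_dtype, use_fast_accum=use_fast_accum,
    )
```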
Just curious: do we have any OSS-shareable evidence (perf/accuracy) on doing this versus rowwise with fast-accum off that we can add here?

I ran a quick benchmark on my H100 with a recent-ish version of PyTorch (nightly from Nov 12). I samples all MxNxK matmul shapes where each of M, N and K is a power of two between 512 and 16384. Here I'm plotting the slowdowns observed when activating slow-accum for the rowwise (CUTLASS-based) and tensorwise (cuBLAS-based) modes
In summary: in tensorwise we get a max slowdown of 50% (usually much less), with rowwise we typically are 2x as slow, with peaks of 4.5x as slow as fast-accum.
(I suspect that for very small shapes the benchmark was CPU-bound hence slow-accum looks as fast as fast-accum, but that's probably misleading)
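For reference, a minimal plotting sketch for this kind of comparison, under the assumption that the per-shape timings have been collected into a list of `(m, n, k, ratio_rw, ratio_tw)` tuples (e.g. with a sweep like the benchmark script quoted later in this thread); the function name and data layout are hypothetical.

```python
# Sketch: scatter the slow-accum/fast-accum slowdown ratios against matmul size.
import matplotlib.pyplot as plt


def plot_slowdowns(results):
    # results: list of (m, n, k, ratio_rw, ratio_tw) tuples
    flops = [2 * m * n * k for m, n, k, _, _ in results]
    ratio_rw = [r for _, _, _, r, _ in results]
    ratio_tw = [r for _, _, _, _, r in results]
    plt.scatter(flops, ratio_rw, s=8, label="rowwise (CUTLASS)")
    plt.scatter(flops, ratio_tw, s=8, label="tensorwise (cuBLAS)")
    plt.xscale("log")
    plt.xlabel("matmul FLOPs (2*M*N*K)")
    plt.ylabel("slow-accum time / fast-accum time")
    plt.legend()
    plt.show()
```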
FWIW, in CUDA 12.6.2+ the perf of row-wise slow-accum kernels is significantly better (the slowdown is around 50%, instead of 2-3x), but separate scaling might still come out ahead.
@ngimel Do you know where those improvements come from? Is it just NVCC becoming better/smarter? Those rowwise kernels are built as part of PyTorch using CUTLASS, so I expected CUTLASS upgrades to be the more likely source of perf improvements...
Now, rerunning the benchmarks, I see bad perf even on the new version; I must have messed something up last time :-(.
Landing since Ruff is already broken on main.
Superseded by #1377.
This improves perf for large matrices by more than 2x; a more detailed benchmark is coming.

On master

On this branch
<img width="601" alt="image" src="https://github.com/user-attachments/assets/7f55152b-1110-45e4-b2ea-6f274d543869" />

A plot similar to pytorch/ao#1325 (comment)

<details>
<summary>Benchmarking code:</summary>

```python
import itertools

import torch
from triton.testing import do_bench


def fn_aten_scales(a, b, scale_a, scale_b, use_fast_accum=False):
    # Rowwise-scaled matmul: per-row scale for a, per-column scale for b.t().
    return torch._scaled_mm(
        a, b.t(), scale_a.view(-1, 1), scale_b.view(1, -1),
        use_fast_accum=use_fast_accum, out_dtype=torch.bfloat16,
    )


def fn_aten(a, b, scale, use_fast_accum=False):
    # Tensorwise-scaled matmul: a single scalar scale for each operand.
    return torch._scaled_mm(
        a, b.t(), scale, scale,
        use_fast_accum=use_fast_accum, out_dtype=torch.bfloat16,
    )


for i, j, k_exp in itertools.product(range(9, 15), range(9, 15), range(9, 15)):
    m = 2**i
    n = 2**j
    k = 2**k_exp
    a = torch.randn(m, k, device="cuda").to(dtype=torch.float8_e4m3fn)
    b = torch.randn(n, k, device="cuda").to(dtype=torch.float8_e4m3fn)
    scale_a = torch.randint(1, 11, (a.shape[0],), device="cuda", dtype=torch.float32)
    scale_b = torch.randint(1, 11, (b.shape[0],), device="cuda", dtype=torch.float32)
    scale_0 = torch.randn((), device="cuda", dtype=torch.float32)
    ms_rowwise_fast = do_bench(lambda: fn_aten_scales(a, b, scale_a, scale_b, use_fast_accum=True), warmup=25, rep=50)
    ms_rowwise_slow = do_bench(lambda: fn_aten_scales(a, b, scale_a, scale_b, use_fast_accum=False), warmup=25, rep=50)
    ms_tensor_fast = do_bench(lambda: fn_aten(a, b, scale_0, use_fast_accum=True), warmup=25, rep=50)
    ms_tensor_slow = do_bench(lambda: fn_aten(a, b, scale_0, use_fast_accum=False), warmup=25, rep=50)
    print(
        f"m={m}, n={n}, k={k}, fast={ms_rowwise_fast}, slow={ms_rowwise_slow}, "
        f"ratio_tw={ms_tensor_slow / ms_tensor_fast}, ratio_rw={ms_rowwise_slow / ms_rowwise_fast}"
    )
```

</details>

Higher N/K values still have about a 40% penalty; perhaps some additional heuristics tweaks would be useful.

Pull Request resolved: #144809
Approved by: https://github.com/drisspg
Stack from ghstack (oldest at bottom):
And circumvent the issue with the slow CUTLASS kernel by using the cuBLAS kernel plus manual scaling.
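As a rough illustration of that decomposition, here is a minimal sketch (assuming an H100-class GPU and a PyTorch build where `torch._scaled_mm` supports rowwise scales) comparing the rowwise slow-accum path against the cuBLAS-plus-manual-scaling path; the shapes and scale values are arbitrary.

```python
# Sketch: the rowwise-scaled matmul and the "tensorwise cuBLAS + manual scaling"
# decomposition should agree up to accumulation-order and bf16 rounding differences.
import torch

M, N, K = 1024, 2048, 4096
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
a_scale = torch.rand(M, 1, device="cuda", dtype=torch.float32) + 0.5
b_scale = torch.rand(1, N, device="cuda", dtype=torch.float32) + 0.5

# Rowwise path (CUTLASS kernel when fast-accum is off).
ref = torch._scaled_mm(a, b.t(), scale_a=a_scale, scale_b=b_scale,
                       out_dtype=torch.bfloat16, use_fast_accum=False)

# Decomposed path: tensorwise cuBLAS kernel with unit scales, then manual scaling.
one = torch.ones((), device="cuda", dtype=torch.float32)
raw = torch._scaled_mm(a, b.t(), scale_a=one, scale_b=one,
                       out_dtype=torch.float32, use_fast_accum=False)
out = (raw * a_scale * b_scale).to(torch.bfloat16)

print(torch.max((out.float() - ref.float()).abs()))
```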