Commit c2cf973

fix mx triton kernel after PyTorch triton pin change (#431)
Summary:

The Triton pin was updated recently: pytorch/pytorch#126098. In the new Triton version, functions can only access global variables of type `tl.constexpr`. Given the current structure of the code, and the fact that these constants are also used by non-Triton code, the best thing to do is to stop using globals in the MX Triton kernels. This PR lifts all of these constants to kernel function arguments.

Test Plan:

```
pytest test/prototype/mx_formats/test_custom_cast.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 37c348e commit c2cf973
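
For context, a minimal sketch of the pattern this change applies (the kernel, constant, and wrapper names below are illustrative, not taken from the repo): the constant stays a plain module-level Python value so non-Triton code can keep using it, and the Triton kernel receives it as a `tl.constexpr` argument instead of reading the global.

```python
import torch
import triton
import triton.language as tl

SIGN_MASK_F4 = 0x8  # module-level constant, also usable from non-Triton code


@triton.jit
def _extract_sign_kernel(
    x_ptr,
    out_ptr,
    n_elements,
    sign_mask_f4: tl.constexpr,  # constant lifted to a kernel argument
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # before the pin bump this could read the global SIGN_MASK_F4 directly;
    # newer Triton only allows globals that are tl.constexpr
    tl.store(out_ptr + offsets, x & sign_mask_f4, mask=mask)


def extract_sign(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    # the eager wrapper forwards the module-level constant at launch time
    _extract_sign_kernel[grid](x, out, n, sign_mask_f4=SIGN_MASK_F4, BLOCK_SIZE=512)
    return out
```

Passing the constants as `tl.constexpr` keeps them compile-time values for Triton, so the kernel bodies are unchanged apart from the renamed references.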

File tree

2 files changed: +128, -17

.github/workflows/regression_test.yml (+1, -1)

@@ -33,7 +33,7 @@ jobs:
   gpu-arch-version: "12.1"
 - name: CUDA Nightly
   runs-on: linux.g5.12xlarge.nvidia.gpu
-  torch-spec: '--pre torch==2.5.0.dev20240620+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
+  torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
   gpu-arch-type: "cuda"
   gpu-arch-version: "12.1"
 - name: CPU 2.2.2

torchao/prototype/mx_formats/custom_cast.py (+127, -16)

@@ -107,7 +107,19 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor):
     import triton.language as tl
 
     @triton.jit
-    def _fp4_packed_to_bf16(x_packed):
+    def _fp4_packed_to_bf16(
+        x_packed,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+    ):
         """
         Input: a tensor of packed fp4 values
         Output: a tensor of bfloat16 values
@@ -123,7 +135,7 @@ def _fp4_packed_to_bf16(x_packed):
         # output = x_unpacked.to(tl.float32)
 
         # save the sign
-        sign_f4 = x & SIGN_MASK_F4
+        sign_f4 = x & sign_mask_f4
 
         # set everything to positive, will add sign back at the end
         x_pos = x ^ sign_f4
@@ -138,25 +150,25 @@ def _fp4_packed_to_bf16(x_packed):
         denormal_mask = x_pos == 1
 
         # calculate the new exponent and shift it to bits 2:9 of the result
-        exp_biased_f4 = x_pos >> MBITS_F4_E2M1
-        exp_biased_f32 = exp_biased_f4 - F4_E2M1_EXP_BIAS + F32_EXP_BIAS
-        exp_biased_f32 = exp_biased_f32.to(tl.int32) << MBITS_F32
+        exp_biased_f4 = x_pos >> mbits_f4_e2m1
+        exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias
+        exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32
 
         # shift the mantissa to bits 10:32 of the result
-        mantissa_f4 = x_pos & MANTISSA_MASK_F4
-        mantissa_f32 = mantissa_f4.to(tl.int32) << (MBITS_F32 - MBITS_F4_E2M1)
+        mantissa_f4 = x_pos & mantissa_mask_f4
+        mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1)
         output = mantissa_f32
 
         # combine the pieces
         result = exp_biased_f32 | mantissa_f32
         # result[zero_mask] = ZERO_BITS_F32
-        result = tl.where(zero_mask, ZERO_BITS_F32, result)
+        result = tl.where(zero_mask, zero_bits_f32, result)
         # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32
-        result = tl.where(denormal_mask, ZERO_POINT_FIVE_BITS_F32, result)
+        result = tl.where(denormal_mask, zero_point_five_bits_f32, result)
 
         # add sign back
         sign_f32 = sign_f4.to(tl.int32) << (
-            MBITS_F32 - MBITS_F4_E2M1 + EBITS_F32 - EBITS_F4_E2M1
+            mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1
         )
         result = result | sign_f32
 
@@ -174,6 +186,16 @@ def triton_f4_to_bf16_kernel(
         x_ptr,
         output_ptr,
         n_elements_in,
+        sign_mask_f4: tl.constexpr,
+        mantissa_mask_f4: tl.constexpr,
+        mbits_f4_e2m1: tl.constexpr,
+        ebits_f4_e2m1: tl.constexpr,
+        f4_e2m1_exp_bias: tl.constexpr,
+        mbits_f32: tl.constexpr,
+        ebits_f32: tl.constexpr,
+        f32_exp_bias: tl.constexpr,
+        zero_bits_f32: tl.constexpr,
+        zero_point_five_bits_f32: tl.constexpr,
         BLOCK_SIZE_IN: tl.constexpr,
     ):
         pid = tl.program_id(axis=0)
@@ -187,7 +209,19 @@ def triton_f4_to_bf16_kernel(
 
         # packed uint8
         x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(x_packed)
+        output = _fp4_packed_to_bf16(
+            x_packed,
+            sign_mask_f4,
+            mantissa_mask_f4,
+            mbits_f4_e2m1,
+            ebits_f4_e2m1,
+            f4_e2m1_exp_bias,
+            mbits_f32,
+            ebits_f32,
+            f32_exp_bias,
+            zero_bits_f32,
+            zero_point_five_bits_f32,
+        )
 
         # set up output offsets
         block_start_out = pid * BLOCK_SIZE_OUT
@@ -213,6 +247,18 @@ def triton_f4_to_scaled_bf16_kernel(
         output_ptr,
         n_elements_in,
         mx_block_size: tl.constexpr,
+        sign_mask_f4: tl.constexpr,
+        mantissa_mask_f4: tl.constexpr,
+        mbits_f4_e2m1: tl.constexpr,
+        ebits_f4_e2m1: tl.constexpr,
+        f4_e2m1_exp_bias: tl.constexpr,
+        mbits_f32: tl.constexpr,
+        ebits_f32: tl.constexpr,
+        f32_exp_bias: tl.constexpr,
+        zero_bits_f32: tl.constexpr,
+        zero_point_five_bits_f32: tl.constexpr,
+        e8m0_exponent_bias: tl.constexpr,
+        e8m0_exponent_nan_val: tl.constexpr,
         BLOCK_SIZE_IN: tl.constexpr,
     ):
         pid = tl.program_id(axis=0)
@@ -227,7 +273,19 @@ def triton_f4_to_scaled_bf16_kernel(
         mask_in = offsets_in < n_elements_in
         # packed uint8
         x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-        output = _fp4_packed_to_bf16(x_packed)
+        output = _fp4_packed_to_bf16(
+            x_packed,
+            sign_mask_f4,
+            mantissa_mask_f4,
+            mbits_f4_e2m1,
+            ebits_f4_e2m1,
+            f4_e2m1_exp_bias,
+            mbits_f32,
+            ebits_f32,
+            f32_exp_bias,
+            zero_bits_f32,
+            zero_point_five_bits_f32,
+        )
 
         # load scale
         block_start_s = pid * BLOCK_SIZE_S
@@ -236,9 +294,9 @@ def triton_f4_to_scaled_bf16_kernel(
         s = tl.load(s_ptr + offsets_s, mask=mask_s)
 
         # create the scale in bf16
-        s_offset = s.to(tl.int16) - E8M0_EXPONENT_BIAS
+        s_offset = s.to(tl.int16) - e8m0_exponent_bias
         s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
-        s_fp = tl.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
+        s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))
 
         # multiply output by scale
         # TODO(later): see if manipulating the exponent instead of fp
@@ -263,6 +321,16 @@ def triton_f4_to_bf16_kernel(
         x_ptr,
         output_ptr,
         n_elements_in,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
         BLOCK_SIZE_IN,
     ):
         raise AssertionError("unsupported without triton")
@@ -273,6 +341,18 @@ def triton_f4_to_scaled_bf16_kernel(
         output_ptr,
         n_elements_in,
         mx_block_size,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+        e8m0_exponent_bias,
+        e8m0_exponent_nan_val,
         BLOCK_SIZE_IN,
     ):
         raise AssertionError("unsupported without triton")
@@ -294,7 +374,22 @@ def triton_f4_to_bf16(x: torch.Tensor):
     grid = lambda meta: ( # noqa: E731
         triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
     ) # noqa: E731,E501
-    triton_f4_to_bf16_kernel[grid](x, output, n_elements_in, BLOCK_SIZE_IN=512)
+    triton_f4_to_bf16_kernel[grid](
+        x,
+        output,
+        n_elements_in,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        BLOCK_SIZE_IN=512,
+    )
     return output
 
 
@@ -318,7 +413,23 @@ def triton_f4_to_scaled_bf16(
         triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
     )
     triton_f4_to_scaled_bf16_kernel[grid](
-        x, s_e8m0, output, n_elements_in, mx_block_size
+        x,
+        s_e8m0,
+        output,
+        n_elements_in,
+        mx_block_size,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        e8m0_exponent_bias=E8M0_EXPONENT_BIAS,
+        e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL,
    )
     return output

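The caller-facing wrappers keep their signatures, so existing call sites are unaffected. A hedged usage sketch of the unscaled path (the input shape below is an assumption for illustration; a CUDA device with Triton available is required):

```python
import torch

from torchao.prototype.mx_formats.custom_cast import triton_f4_to_bf16

# a flat tensor of packed fp4 values: each uint8 byte carries two fp4 numbers
x_packed = torch.randint(0, 256, (1024,), dtype=torch.uint8, device="cuda")

# the wrapper now forwards SIGN_MASK_F4, MBITS_F4_E2M1, etc. to the kernel
# internally; nothing changes for the caller
y = triton_f4_to_bf16(x_packed)
print(y.dtype)  # torch.bfloat16
```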