@@ -107,7 +107,19 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor):
import triton.language as tl

@triton.jit
-def _fp4_packed_to_bf16(x_packed):
+def _fp4_packed_to_bf16(
+    x_packed,
+    sign_mask_f4,
+    mantissa_mask_f4,
+    mbits_f4_e2m1,
+    ebits_f4_e2m1,
+    f4_e2m1_exp_bias,
+    mbits_f32,
+    ebits_f32,
+    f32_exp_bias,
+    zero_bits_f32,
+    zero_point_five_bits_f32,
+):
    """
    Input: a tensor of packed fp4 values
    Output: a tensor of bfloat16 values
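The lowercase parameters replace module-level constants of the same uppercase names. For reference, a sketch of the values those constants are expected to carry, inferred from the e2m1 and float32 bit layouts rather than copied from this diff:

SIGN_MASK_F4 = 0x8                      # bit 3 of the 4-bit value
MANTISSA_MASK_F4 = 0x1                  # the single mantissa bit
MBITS_F4_E2M1, EBITS_F4_E2M1 = 1, 2     # e2m1: 1 mantissa bit, 2 exponent bits
F4_E2M1_EXP_BIAS = 1
MBITS_F32, EBITS_F32 = 23, 8
F32_EXP_BIAS = 127
ZERO_BITS_F32 = 0x0                     # bit pattern of float32 0.0
ZERO_POINT_FIVE_BITS_F32 = 0x3F000000   # bit pattern of float32 0.5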
@@ -123,7 +135,7 @@ def _fp4_packed_to_bf16(x_packed):
    # output = x_unpacked.to(tl.float32)

    # save the sign
-    sign_f4 = x & SIGN_MASK_F4
+    sign_f4 = x & sign_mask_f4

    # set everything to positive, will add sign back at the end
    x_pos = x ^ sign_f4
@@ -138,25 +150,25 @@ def _fp4_packed_to_bf16(x_packed):
    denormal_mask = x_pos == 1

    # calculate the new exponent and shift it to bits 2:9 of the result
-    exp_biased_f4 = x_pos >> MBITS_F4_E2M1
-    exp_biased_f32 = exp_biased_f4 - F4_E2M1_EXP_BIAS + F32_EXP_BIAS
-    exp_biased_f32 = exp_biased_f32.to(tl.int32) << MBITS_F32
+    exp_biased_f4 = x_pos >> mbits_f4_e2m1
+    exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias
+    exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32

    # shift the mantissa to bits 10:32 of the result
-    mantissa_f4 = x_pos & MANTISSA_MASK_F4
-    mantissa_f32 = mantissa_f4.to(tl.int32) << (MBITS_F32 - MBITS_F4_E2M1)
+    mantissa_f4 = x_pos & mantissa_mask_f4
+    mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1)
    output = mantissa_f32

    # combine the pieces
    result = exp_biased_f32 | mantissa_f32
    # result[zero_mask] = ZERO_BITS_F32
-    result = tl.where(zero_mask, ZERO_BITS_F32, result)
+    result = tl.where(zero_mask, zero_bits_f32, result)
    # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32
-    result = tl.where(denormal_mask, ZERO_POINT_FIVE_BITS_F32, result)
+    result = tl.where(denormal_mask, zero_point_five_bits_f32, result)

    # add sign back
    sign_f32 = sign_f4.to(tl.int32) << (
-        MBITS_F32 - MBITS_F4_E2M1 + EBITS_F32 - EBITS_F4_E2M1
+        mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1
    )
    result = result | sign_f32
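This hunk is the core of the decode: clear the sign, rebias the 2-bit exponent into the float32 exponent field, shift the single mantissa bit to the top of the float32 mantissa field, and special-case zero and the lone denormal (0.5). A hedged plain-PyTorch sketch of the same bit manipulation on unpacked 4-bit values, useful for cross-checking the kernel; constant values are the assumed ones sketched above, and the sign is applied by negation at the end instead of shifting the sign bit into bit 31 as the kernel does:

import torch

SIGN_MASK_F4, MANTISSA_MASK_F4 = 0x8, 0x1
MBITS_F4_E2M1, EBITS_F4_E2M1, F4_E2M1_EXP_BIAS = 1, 2, 1
MBITS_F32, EBITS_F32, F32_EXP_BIAS = 23, 8, 127
ZERO_BITS_F32, ZERO_POINT_FIVE_BITS_F32 = 0x0, 0x3F000000

def f4_e2m1_unpacked_to_f32_reference(x: torch.Tensor) -> torch.Tensor:
    # x: uint8 tensor with one e2m1 value in the low 4 bits
    x = x.to(torch.int32)
    sign_f4 = x & SIGN_MASK_F4
    x_pos = x ^ sign_f4                # drop the sign, re-apply at the end
    zero_mask = x_pos == 0
    denormal_mask = x_pos == 1         # the only e2m1 denormal encodes 0.5

    # rebias the exponent and move it into the float32 exponent field
    exp_biased_f32 = ((x_pos >> MBITS_F4_E2M1) - F4_E2M1_EXP_BIAS + F32_EXP_BIAS) << MBITS_F32
    # move the single mantissa bit to the top of the float32 mantissa field
    mantissa_f32 = (x_pos & MANTISSA_MASK_F4) << (MBITS_F32 - MBITS_F4_E2M1)

    result = exp_biased_f32 | mantissa_f32
    result = torch.where(zero_mask, torch.full_like(result, ZERO_BITS_F32), result)
    result = torch.where(denormal_mask, torch.full_like(result, ZERO_POINT_FIVE_BITS_F32), result)

    # reinterpret the int32 bit pattern as float32, then apply the sign
    out = result.view(torch.float32)
    return torch.where(sign_f4 != 0, -out, out)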
@@ -174,6 +186,16 @@ def triton_f4_to_bf16_kernel(
    x_ptr,
    output_ptr,
    n_elements_in,
+    sign_mask_f4: tl.constexpr,
+    mantissa_mask_f4: tl.constexpr,
+    mbits_f4_e2m1: tl.constexpr,
+    ebits_f4_e2m1: tl.constexpr,
+    f4_e2m1_exp_bias: tl.constexpr,
+    mbits_f32: tl.constexpr,
+    ebits_f32: tl.constexpr,
+    f32_exp_bias: tl.constexpr,
+    zero_bits_f32: tl.constexpr,
+    zero_point_five_bits_f32: tl.constexpr,
    BLOCK_SIZE_IN: tl.constexpr,
):
    pid = tl.program_id(axis=0)
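Declaring these parameters as tl.constexpr makes Triton bind them at kernel specialization time and fold them into the generated code, so passing them in behaves like the module-level constants they replace; since the values never change, only one specialization is compiled. A minimal, self-contained sketch of that mechanism (kernel name and parameters are hypothetical, not from this diff):

import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, out_ptr, n, SCALE: tl.constexpr, BLOCK: tl.constexpr):
    # SCALE and BLOCK are baked into the compiled kernel, not loaded at runtime
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * SCALE, mask=mask)

# hypothetical launch: each distinct (SCALE, BLOCK) pair triggers one compilation
# x = torch.randn(4096, device="cuda"); out = torch.empty_like(x)
# _scale_kernel[(triton.cdiv(x.numel(), 1024),)](x, out, x.numel(), SCALE=2.0, BLOCK=1024)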
@@ -187,7 +209,19 @@ def triton_f4_to_bf16_kernel(

    # packed uint8
    x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-    output = _fp4_packed_to_bf16(x_packed)
+    output = _fp4_packed_to_bf16(
+        x_packed,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+    )

    # set up output offsets
    block_start_out = pid * BLOCK_SIZE_OUT
@@ -213,6 +247,18 @@ def triton_f4_to_scaled_bf16_kernel(
    output_ptr,
    n_elements_in,
    mx_block_size: tl.constexpr,
+    sign_mask_f4: tl.constexpr,
+    mantissa_mask_f4: tl.constexpr,
+    mbits_f4_e2m1: tl.constexpr,
+    ebits_f4_e2m1: tl.constexpr,
+    f4_e2m1_exp_bias: tl.constexpr,
+    mbits_f32: tl.constexpr,
+    ebits_f32: tl.constexpr,
+    f32_exp_bias: tl.constexpr,
+    zero_bits_f32: tl.constexpr,
+    zero_point_five_bits_f32: tl.constexpr,
+    e8m0_exponent_bias: tl.constexpr,
+    e8m0_exponent_nan_val: tl.constexpr,
    BLOCK_SIZE_IN: tl.constexpr,
):
    pid = tl.program_id(axis=0)
@@ -227,7 +273,19 @@ def triton_f4_to_scaled_bf16_kernel(
    mask_in = offsets_in < n_elements_in
    # packed uint8
    x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
-    output = _fp4_packed_to_bf16(x_packed)
+    output = _fp4_packed_to_bf16(
+        x_packed,
+        sign_mask_f4,
+        mantissa_mask_f4,
+        mbits_f4_e2m1,
+        ebits_f4_e2m1,
+        f4_e2m1_exp_bias,
+        mbits_f32,
+        ebits_f32,
+        f32_exp_bias,
+        zero_bits_f32,
+        zero_point_five_bits_f32,
+    )

    # load scale
    block_start_s = pid * BLOCK_SIZE_S
@@ -236,9 +294,9 @@ def triton_f4_to_scaled_bf16_kernel(
    s = tl.load(s_ptr + offsets_s, mask=mask_s)

    # create the scale in bf16
-    s_offset = s.to(tl.int16) - E8M0_EXPONENT_BIAS
+    s_offset = s.to(tl.int16) - e8m0_exponent_bias
    s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
-    s_fp = tl.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
+    s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))

    # multiply output by scale
    # TODO(later): see if manipulating the exponent instead of fp
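The scale is stored as an e8m0 exponent byte: the bf16 scale is 2**(s - bias), with one code reserved for NaN. A hedged plain-PyTorch equivalent of these three lines, assuming the usual e8m0 constants (bias 127, NaN code 255), which are the values e8m0_exponent_bias and e8m0_exponent_nan_val are expected to carry:

import torch

E8M0_EXPONENT_BIAS = 127       # assumed value
E8M0_EXPONENT_NAN_VAL = 255    # assumed value

def e8m0_to_f32_reference(s: torch.Tensor) -> torch.Tensor:
    # s: uint8 tensor of e8m0 exponent codes
    s_offset = s.to(torch.int16) - E8M0_EXPONENT_BIAS
    s_fp = torch.pow(2.0, s_offset.to(torch.float32))
    return torch.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, torch.full_like(s_fp, float("nan")))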
@@ -263,6 +321,16 @@ def triton_f4_to_bf16_kernel(
    x_ptr,
    output_ptr,
    n_elements_in,
+    sign_mask_f4,
+    mantissa_mask_f4,
+    mbits_f4_e2m1,
+    ebits_f4_e2m1,
+    f4_e2m1_exp_bias,
+    mbits_f32,
+    ebits_f32,
+    f32_exp_bias,
+    zero_bits_f32,
+    zero_point_five_bits_f32,
    BLOCK_SIZE_IN,
):
    raise AssertionError("unsupported without triton")
@@ -273,6 +341,18 @@ def triton_f4_to_scaled_bf16_kernel(
    output_ptr,
    n_elements_in,
    mx_block_size,
+    sign_mask_f4,
+    mantissa_mask_f4,
+    mbits_f4_e2m1,
+    ebits_f4_e2m1,
+    f4_e2m1_exp_bias,
+    mbits_f32,
+    ebits_f32,
+    f32_exp_bias,
+    zero_bits_f32,
+    zero_point_five_bits_f32,
+    e8m0_exponent_bias,
+    e8m0_exponent_nan_val,
    BLOCK_SIZE_IN,
):
    raise AssertionError("unsupported without triton")
@@ -294,7 +374,22 @@ def triton_f4_to_bf16(x: torch.Tensor):
    grid = lambda meta: (  # noqa: E731
        triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
    )  # noqa: E731,E501
-    triton_f4_to_bf16_kernel[grid](x, output, n_elements_in, BLOCK_SIZE_IN=512)
+    triton_f4_to_bf16_kernel[grid](
+        x,
+        output,
+        n_elements_in,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        BLOCK_SIZE_IN=512,
+    )
    return output
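With the constants now passed explicitly at launch, the wrapper is unchanged from the caller's point of view. A hedged usage sketch, assuming a CUDA device, that triton_f4_to_bf16 is imported from the module this diff modifies, and that each uint8 byte packs two e2m1 values (so the output has twice as many elements as the input):

import torch

# 16 packed bytes -> 32 bf16 values; the bit patterns here are arbitrary
x_packed = torch.randint(0, 256, (16,), dtype=torch.uint8, device="cuda")
x_bf16 = triton_f4_to_bf16(x_packed)
assert x_bf16.dtype == torch.bfloat16
assert x_bf16.numel() == 2 * x_packed.numel()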
@@ -318,7 +413,23 @@ def triton_f4_to_scaled_bf16(
        triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
    )
    triton_f4_to_scaled_bf16_kernel[grid](
-        x, s_e8m0, output, n_elements_in, mx_block_size
+        x,
+        s_e8m0,
+        output,
+        n_elements_in,
+        mx_block_size,
+        sign_mask_f4=SIGN_MASK_F4,
+        mantissa_mask_f4=MANTISSA_MASK_F4,
+        mbits_f4_e2m1=MBITS_F4_E2M1,
+        ebits_f4_e2m1=EBITS_F4_E2M1,
+        f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
+        mbits_f32=MBITS_F32,
+        ebits_f32=EBITS_F32,
+        f32_exp_bias=F32_EXP_BIAS,
+        zero_bits_f32=ZERO_BITS_F32,
+        zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
+        e8m0_exponent_bias=E8M0_EXPONENT_BIAS,
+        e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL,
    )
    return output
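Similarly for the scaled variant, a hedged usage sketch. It assumes the wrapper signature is (x_packed, s_e8m0, mx_block_size), as suggested by the launch arguments above, that each byte packs two e2m1 values, and that there is one e8m0 scale byte per mx_block_size output elements; all of these are assumptions, not confirmed by this diff:

import torch

mx_block_size = 32
# 64 packed bytes -> 128 bf16 values -> 4 scale blocks of 32 elements each
x_packed = torch.randint(0, 256, (64,), dtype=torch.uint8, device="cuda")
s_e8m0 = torch.full((4,), 127, dtype=torch.uint8, device="cuda")  # assumed bias 127 -> scale 1.0
x_bf16 = triton_f4_to_scaled_bf16(x_packed, s_e8m0, mx_block_size)
assert x_bf16.numel() == 2 * x_packed.numel()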