
Commit 324404e

Q4 cache: Add groupwise Hadamard transform
1 parent 740a19a

2 files changed: +54 -15

doc/qcache_eval.md (+19 -14)
@@ -15,20 +15,23 @@ The tl;dr:
Token-level perplexity tests for various full-precision and quantized models using FP16, FP8 and Q4 cache
modes. Dataset is The Pile, 10 rows of 512 tokens per test.

-Model | Precision | FP16 cache | FP8 cache | Q4 cache
---------|-----------|---------------|-----------|---------
-Mistral 7B Instruct | 3.0 bpw | 13.33 | 13.43 | 13.41
--- | 3.5 bpw | 13.07 | 13.14 | 13.12
--- | 4.0 bpw | 12.90 | 12.90 | 12.90
--- | 5.0 bpw | 12.73 | 12.73 | 12.75
--- | 6.0 bpw | 12.73 | 12.75 | 12.74
--- | FP16 | 12.69 | 12.71 | 12.72
-Mixtral 8x7B | 3.5 bpw | 10.27 | 10.41 | 10.39
--- | 4.0 bpw | 10.09 | 10.26 | 10.23
--- | 5.0 bpw | 10.02 | 10.16 | 10.15
-Llama2 7B | 4.0 bpw | 11.43 | 11.92 | 11.74
--- | 5.0 bpw | 11.13 | 11.40 | 11.31
--- | FP16 | 10.91 | 11.24 | 11.16
+Results are updated for the new method, which uses Hadamard rotations on the keys/values. Old results for
+version 0.0.18 and prior are kept for reference.
+
+Model | Precision | FP16 cache | FP8 cache | Q4 cache (old) | Q4 cache
+--------|---------|-------------|-----------|-------|----------
+Mistral 7B Instruct | 3.0 bpw | **13.33** | 13.43 | 13.41 | **13.37**
+-- | 3.5 bpw | **13.07** | 13.14 | 13.12 | **13.09**
+-- | 4.0 bpw | **12.90** | 12.90 | 12.90 | **12.90**
+-- | 5.0 bpw | **12.73** | 12.73 | 12.75 | **12.75**
+-- | 6.0 bpw | **12.73** | 12.75 | 12.74 | **12.74**
+-- | FP16 | **12.69** | 12.71 | 12.72 | **12.69**
+Mixtral 8x7B | 3.5 bpw | **10.27** | 10.41 | 10.39 | **10.32**
+-- | 4.0 bpw | **10.09** | 10.26 | 10.23 | **10.19**
+-- | 5.0 bpw | **10.02** | 10.16 | 10.15 | **10.04**
+Llama2 7B | 4.0 bpw | **11.43** | 11.92 | 11.74 | **11.60**
+-- | 5.0 bpw | **11.13** | 11.40 | 11.31 | **11.19**
+-- | FP16 | **10.91** | 11.24 | 11.16 | **11.05**
### HumanEval
@@ -37,6 +40,8 @@ The following are HumanEval tests on various full-precision and quantized models
respectively. Number of samples per task is limited to 10 (still giving 39360 completions in total, produced
over about 24 hours.)

+The following tests were done prior to the improvements in 0.0.18-dev.
+
#### pass@1

Model | Precision | FP16 cache | Q4 cache | diff
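The doc change states the idea but not the motivation: the Q4 cache quantizes in groups against a per-group absmax (visible in the kernel below), and a single outlier inflates that absmax, so the small values in the group lose most of their 4-bit precision. Rotating the group by a Hadamard matrix first spreads the outlier's energy across all elements. A toy host-side sketch (illustrative only, not code from this repo; a 4-point transform instead of the kernels' 32-point one):

```cuda
// Toy demo (not from this repo): rotating a group by an unnormalized
// Hadamard transform spreads one outlier across all elements, so the
// per-group absmax that sets the 4-bit step size is no longer
// dominated by a single value.
#include <cstdio>

int main()
{
    float v[4] = { 8.0f, 0.1f, -0.2f, 0.1f };   // absmax 8.0, set by one outlier

    // Two butterfly stages = 4-point Walsh-Hadamard transform
    for (int i = 1; i < 4; i <<= 1)
        for (int t = 0; t < 4; t++)
            if (!(t & i))
            {
                float a = v[t], b = v[t + i];
                v[t] = a + b;
                v[t + i] = a - b;
            }

    // Normalized output: every element is now ~4.0 in magnitude, so
    // quantization error is shared across the group instead of being
    // concentrated in the small values.
    for (int t = 0; t < 4; t++) printf("%6.2f\n", v[t] / 2.0f);   // 1/sqrt(4)
    return 0;
}
```

The kernels in cache.cu below do the same over two interleaved 32-element groups per warp, using `__shfl_xor_sync` instead of memory accesses.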

exllamav2/exllamav2_ext/cuda/cache.cu (+35 -1)
@@ -7,6 +7,7 @@
#define THREADS 32
#define BLOCKSIZE_Q 256
#define THREADS_Q (BLOCKSIZE_Q / 2)
+#define HADAMARD_Q4

// The upper 8 bits of FP16 are equivalent to FP8 E5M2.
//
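An aside on the context comment above: FP16 is laid out as 1 sign, 5 exponent and 10 mantissa bits, and FP8 E5M2 as 1-5-2, so keeping only the top byte of an FP16 value is a truncating conversion to E5M2. A minimal sketch of that observation (the helper name is hypothetical, not from this file):

```cuda
#include <cstdint>

// FP16 (1-5-10) and FP8 E5M2 (1-5-2) share sign and exponent layout,
// so the top byte of an FP16 bit pattern is its E5M2 truncation
// (rounding toward zero in the mantissa).
__host__ __device__ inline uint8_t fp16_bits_to_fp8_e5m2(uint16_t h16)
{
    return static_cast<uint8_t>(h16 >> 8);
}
```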
@@ -164,6 +165,22 @@ __global__ void fp16_to_q4_kv_kernel
    half2 w2 = in2[t];
    half2 o = w2;

+    // Perform Hadamard transform on two interleaved 32-element groups. Don't scale the output by 1/sqrt(32)
+    // here; instead, scale by 1/32 when dequantizing
+
+    #ifdef HADAMARD_Q4
+
+    for (int i = 1; i < 32; i <<= 1)
+    {
+        half2 pw2 = __shfl_xor_sync(0xffffffff, w2, i, 32);
+        uint32_t* w2i = reinterpret_cast<uint32_t*>(&w2);
+        int32_t sfm = -static_cast<int32_t>(t & i) >> 31;
+        *w2i ^= (sfm & 0x80008000);
+        w2 = __hadd2(w2, pw2);
+    }
+
+    #endif
+
    // Max abs value for lane_id 0..15, 16..31

    half2 absmax2 = __habs2(w2);
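For reference, the warp-shuffle loop above is a five-stage butterfly network: `__shfl_xor_sync(..., i, 32)` fetches the value held by partner lane `t ^ i`, and `sfm` arithmetic-shifts bit `i` of the lane index into an all-ones mask, so lanes with that bit set flip the signs of both packed halves (XOR with 0x80008000) before adding. The effect, written as a plain host-side sketch (not repo code), is the unnormalized 32-point Walsh-Hadamard transform:

```cuda
// Reference: unnormalized 32-point Walsh-Hadamard transform, in place.
// Five butterfly stages, matching the kernel's loop over i = 1, 2, 4, 8, 16.
void hadamard32_ref(float* v)
{
    for (int i = 1; i < 32; i <<= 1)
        for (int t = 0; t < 32; t++)
            if (!(t & i))
            {
                float a = v[t];
                float b = v[t + i];
                v[t]     = a + b;   // lane with bit i clear gets the sum
                v[t + i] = a - b;   // lane with bit i set gets the difference
            }
}
```

Because each stage computes (a + b, a - b) with coefficients of magnitude 1, applying the transform twice multiplies the input by 32; the kernel defers all of that scaling to dequantization.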
@@ -176,7 +193,7 @@ __global__ void fp16_to_q4_kv_kernel

    // Normalize

-    half2 c_8 = __half2half2(__int2half_rn(8));
+    half2 c_8 = __half2half2(__float2half_rn(8));
    half c_i = __float2half_rn(1.0f / 8.0f);

    w2 = __h2div(w2, absmax2);
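The one-line change here swaps the integer conversion intrinsic `__int2half_rn(8)` for the float one; the constant's value is unchanged. Reading off the visible constants (a reconstruction, since the rounding and packing code falls outside this hunk): dividing by the per-group absmax m puts values in [-1, 1], `c_8 = 8` scales them to the 4-bit range, and `c_i = 1/8` is the inverse factor used on the way back:

```latex
\[
  w \;\longmapsto\; \frac{8\,w}{m} \in [-8,\,8],
  \qquad
  \hat{w} \;=\; \frac{m}{8} \cdot q
  \quad\text{where } q \text{ is the stored 4-bit value.}
\]
```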
@@ -255,6 +272,23 @@ __global__ void q4_to_fp16_kv_kernel
    half2 w2 = __halves2half2(w0, w1);
    w2 = __hmul2(w2, scale2);

+    // Perform Hadamard transform on two interleaved 32-element groups. Scaling was skipped when quantizing,
+    // so the result is scaled by 1/32 here
+
+    #ifdef HADAMARD_Q4
+
+    for (int i = 1; i < 32; i <<= 1)
+    {
+        half2 pw2 = __shfl_xor_sync(0xffffffff, w2, i, 32);
+        uint32_t* w2i = reinterpret_cast<uint32_t*>(&w2);
+        int32_t sfm = -static_cast<int32_t>(t & i) >> 31;
+        *w2i ^= (sfm & 0x80008000);
+        w2 = __hadd2(w2, pw2);
+    }
+    w2 = __hmul2(w2, __float2half2_rn(1.0f / 32.0f));
+
+    #endif
+
    // Store

    half2* out2 = (half2*) (out + block_offset);
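The single multiply by 1/32 at the end of the inverse transform follows from the standard identity for the unnormalized (entries of magnitude 1, symmetric Sylvester-construction) Hadamard matrix:

```latex
\[
  H_{32} H_{32} \;=\; 32\, I_{32}
  \quad\Longrightarrow\quad
  \tfrac{1}{32}\, H_{32}\bigl(H_{32}\, x\bigr) \;=\; x ,
\]
```

so skipping the 1/sqrt(32) normalization in both kernels and applying one 1/32 factor here recovers the original keys/values up to quantization error.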
