Skip to content

Commit d91456a

Browse files
ardfork authored and SlyEcho committed
fix half2 decomposition
1 parent c1cb70d commit d91456a

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

ggml-cuda.cu

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -472,8 +472,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
472472
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
473473
const block_q4_1 * x = (const block_q4_1 *) vx;
474474

475-
const dfloat d = x[ib].dm.x;
476-
const dfloat m = x[ib].dm.y;
475+
const dfloat d = __low2half(x[ib].dm);
476+
const dfloat m = __high2half(x[ib].dm);
477477

478478
const int vui = x[ib].qs[iqs];
479479

@@ -515,8 +515,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
515515
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
516516
const block_q5_1 * x = (const block_q5_1 *) vx;
517517

518-
const dfloat d = x[ib].dm.x;
519-
const dfloat m = x[ib].dm.y;
518+
const dfloat d = __low2half(x[ib].dm);
519+
const dfloat m = __high2half(x[ib].dm);
520520

521521
uint32_t qh;
522522
memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -568,8 +568,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
568568
const uint8_t q = x[i].qs[32*n + l];
569569
float * y = yy + i*QK_K + 128*n;
570570

571-
float dall = x[i].dm.x;
572-
float dmin = x[i].dm.y;
571+
float dall = __low2half(x[i].dm);
572+
float dmin = __high2half(x[i].dm);
573573
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
574574
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
575575
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
@@ -579,8 +579,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
579579
const int il = tid%16; // 0...15
580580
const uint8_t q = x[i].qs[il] >> (2*is);
581581
float * y = yy + i*QK_K + 16*is + il;
582-
float dall = x[i].dm.x;
583-
float dmin = x[i].dm.y;
582+
float dall = __low2half(x[i].dm);
583+
float dmin = __high2half(x[i].dm);
584584
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
585585
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
586586
#endif
@@ -666,8 +666,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
666666

667667
float * y = yy + i*QK_K + 64*il + n*ir;
668668

669-
const float dall = x[i].dm.x;
670-
const float dmin = x[i].dm.y;
669+
const float dall = __low2half(x[i].dm);
670+
const float dmin = __high2half(x[i].dm);
671671

672672
const uint8_t * q = x[i].qs + 32*il + n*ir;
673673

@@ -705,8 +705,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
705705

706706
float * y = yy + i*QK_K + 64*il + 2*ir;
707707

708-
const float dall = x[i].dm.x;
709-
const float dmin = x[i].dm.y;
708+
const float dall = __low2half(x[i].dm);
709+
const float dmin = __high2half(x[i].dm);
710710

711711
const uint8_t * ql = x[i].qs + 32*il + 2*ir;
712712
const uint8_t * qh = x[i].qh + 2*ir;
@@ -818,8 +818,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
818818
const float * y = yy + i * QK_K + y_offset;
819819
const uint8_t * q = x[i].qs + q_offset;
820820

821-
const float dall = x[i].dm.x;
822-
const float dmin = x[i].dm.y;
821+
const float dall = __low2half(x[i].dm);
822+
const float dmin = __high2half(x[i].dm);
823823

824824
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
825825
aux[0] = a[0] & 0x0f0f0f0f;
@@ -1039,8 +1039,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
10391039
const float * y1 = yy + i*QK_K + y_offset;
10401040
const float * y2 = y1 + 128;
10411041

1042-
const float dall = x[i].dm.x;
1043-
const float dmin = x[i].dm.y;
1042+
const float dall = __low2half(x[i].dm);
1043+
const float dmin = __high2half(x[i].dm);
10441044

10451045
const uint16_t * a = (const uint16_t *)x[i].scales;
10461046
aux[0] = a[im+0] & kmask1;
@@ -1172,8 +1172,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
11721172
const float * y1 = yy + i*QK_K + y_offset;
11731173
const float * y2 = y1 + 128;
11741174

1175-
const float dall = x[i].dm.x;
1176-
const float dmin = x[i].dm.y;
1175+
const float dall = __low2half(x[i].dm);
1176+
const float dmin = __high2half(x[i].dm);
11771177

11781178
const uint16_t * a = (const uint16_t *)x[i].scales;
11791179
aux[0] = a[im+0] & kmask1;

0 commit comments

Comments
 (0)