Skip to content

Commit eeb5d68

Browse files
CUDA: faster softmax via shared memory + fp16 math
1 parent 540938f commit eeb5d68

File tree

1 file changed

+279
-23
lines changed

1 file changed

+279
-23
lines changed

ggml-cuda.cu

+279-23
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
#include "ggml.h"
117117
#include "ggml-backend-impl.h"
118118

119+
#define CC_PASCAL 600
119120
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
120121
#define CC_VOLTA 700
121122
#define CC_OFFSET_AMD 1000000
@@ -585,6 +586,14 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
585586
return a;
586587
}
587588

589+
// Warp-wide element-wise sum of a half2.
// Uses an XOR-shuffle butterfly over offsets 16, 8, 4, 2, 1, so after the loop every lane
// holds the sum of the values contributed by all 32 lanes. The shuffle uses the full
// 0xffffffff mask, i.e. all 32 lanes of the warp must call this function.
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
        // value held by the lane whose index differs from ours in exactly the `offset` bit
        const half2 other = __shfl_xor_sync(0xffffffff, a, offset, 32);
        a = __hadd2(a, other);
    }
    return a;
}
596+
588597
static __device__ __forceinline__ float warp_reduce_max(float x) {
589598
#pragma unroll
590599
for (int mask = 16; mask > 0; mask >>= 1) {
@@ -593,6 +602,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
593602
return x;
594603
}
595604

605+
// Warp-wide element-wise maximum of a half2.
// On CUDA this is an XOR-shuffle butterfly (offsets 16..1) with the full 0xffffffff mask, so
// all 32 lanes must participate and every lane ends up with the warp-wide maximum.
// On AMD HIP builds the function is not implemented and calls bad_arch() instead.
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
#pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
        // value held by the lane whose index differs from ours in exactly the `offset` bit
        const half2 other = __shfl_xor_sync(0xffffffff, x, offset, 32);
        x = __hmax2(x, other);
    }
    return x;
#else
    (void) x;
    bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
617+
596618
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
597619
return b;
598620
GGML_UNUSED(a);
@@ -5201,75 +5223,227 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
52015223
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
52025224
}
52035225

5204-
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
5226+
// Row-wise softmax computed with half2 arithmetic:
//     dst[row] = softmax(x[row]*scale + y[row % nrows_y])        (y may be nullptr => no mask)
//
// Template parameters:
//   ncols_template      - row length known at compile time, or 0 to use ncols_par at runtime
//   block_size_template - block size known at compile time, or 0 to use blockDim.x
//   need_check          - must be true unless the row length is a multiple of 2*WARP_SIZE;
//                         enables the bounds checks for the tail of the row
//
// Launch contract:
//   - grid:  one block per row of x (blockIdx.x is the row index)
//   - block: 1D, a multiple of WARP_SIZE (full-warp shuffles), at most WARP_SIZE warps
//            (buf_iw has one slot per warp)
//   - dynamic shared memory: one half2 per cached column pair; (ncols + WARP_SIZE)*sizeof(half)
//     bytes is always sufficient
//
// Data layout: lane l of the warp-sized chunk c owns the two columns 2*WARP_SIZE*c + l and
// 2*WARP_SIZE*c + l + WARP_SIZE, packed as (.x, .y) of one half2, so that global memory accesses
// of a warp stay contiguous. Every thread only ever touches its own vals[] slots, therefore the
// vals[] buffer itself needs no block-level synchronization.
//
// Only compiled for NVIDIA targets with compute capability >= CC_PASCAL; otherwise bad_arch().
template <int ncols_template, int block_size_template, bool need_check>
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
    // Number of half2 slots needed to iterate over a whole row in warp-sized chunks.
    // The row length itself is padded (not ncols_data/2): padding the halved, rounded-down
    // length loses the last column whenever ncols_data % (2*WARP_SIZE) == 1.
    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE) / 2;

    const int tid  = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    extern __shared__ half2 data_soft_max_f16[];
    half2 * vals = data_soft_max_f16; // shared memory buffer to cache values between iterations
    // Inter-warp communication buffer. Kept in static shared memory so that the dynamic
    // allocation only has to hold vals[].
    __shared__ half buf_iw[WARP_SIZE];

    half2 max_val = make_half2(-INFINITY, -INFINITY);

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_smem = col0 + tid;
        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;

        if (ncols_template == 0 && col_smem >= ncols_smem) {
            break;
        }
        // Past the end of the row: nothing to load and no vals[] slot is used. col_data only
        // grows with col0, so this thread has no further work in this loop.
        if (need_check && col_data >= ncols_data) {
            break;
        }

        const int ix = rowx*ncols_data + col_data;
        const int iy = rowy*ncols_data + col_data;

        half2 val;
        val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
        if (need_check && col_data + WARP_SIZE >= ncols_data) {
            val.y = -INFINITY; // neutral element: ignored by max, exp() maps it to 0
        } else {
            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
        }
        vals[col_smem] = val;
        max_val = __hmax2(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY; // slots of non-existent warps must not affect the max
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
        }
        __syncthreads();

        max_val = __half2half2(buf_iw[lane_id]);
        max_val = warp_reduce_max(max_val);
    } else {
        max_val = __half2half2(__hmax(max_val.x, max_val.y));
    }

    half2 tmp = make_half2(0.0f, 0.0f); // partial sums

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_smem = col0 + tid;
        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;

        if (ncols_template == 0 && col_smem >= ncols_smem) {
            break;
        }
        // same guard as in the load loop: slots past the end of the row were never written
        if (need_check && col_data >= ncols_data) {
            break;
        }

        const half2 val = h2exp(vals[col_smem] - max_val);

        tmp += val;
        vals[col_smem] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // buf_iw still holds the per-warp maxima; every warp must have finished reading them
        // before warp 0 is allowed to overwrite the buffer.
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f; // slots of non-existent warps must not affect the sum
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp.x + tmp.y;
        }
        __syncthreads();

        tmp = __half2half2(buf_iw[lane_id]);
        tmp = warp_reduce_sum(tmp);
    } else {
        tmp = __half2half2(tmp.x + tmp.y);
    }

    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_smem = col0 + tid;
        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;

        if (ncols_template == 0 && col_smem >= ncols_smem) {
            break;
        }
        if (need_check && col_data >= ncols_data) {
            break;
        }

        const int idst = rowx*ncols_data + col_data;
        const half2 result = vals[col_smem] * inv_sum;
        dst[idst] = result.x;

        // the second column of the pair is outside of the row, and so is everything after it
        if (need_check && col_data + WARP_SIZE >= ncols_data) {
            break;
        }

        dst[idst + WARP_SIZE] = result.y;
    }
#else
    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
    bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
5350+
5351+
// Row-wise softmax in fp32:
//     dst[row] = softmax(x[row]*scale + y[row % nrows_y])        (y may be nullptr => no mask)
//
// Template parameters:
//   ncols_template      - row length known at compile time, or 0 to use ncols_par at runtime
//   block_size_template - block size known at compile time, or 0 to use blockDim.x
//
// Launch contract:
//   - grid:  one block per row of x (blockIdx.x is the row index)
//   - block: 1D, a multiple of WARP_SIZE (full-warp shuffles), at most WARP_SIZE warps
//            (buf_iw has one slot per warp)
//   - dynamic shared memory: (ncols + WARP_SIZE)*sizeof(float) bytes
//
// Every thread only reads and writes its own vals[] entries (col = col0 + tid in all loops),
// so vals[] needs no block-level synchronization; only buf_iw is shared between warps.
template <int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

    const int tid  = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    extern __shared__ float data_soft_max_f32[];
    float * vals   = data_soft_max_f32 + 0;     // shared memory buffer to cache values between iterations
    float * buf_iw = data_soft_max_f32 + ncols; // shared memory buffer for inter-warp communication

    float max_val = -INFINITY;

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        // only needed for runtime-sized rows; templated sizes are multiples of the block size
        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const int ix = rowx*ncols + col;
        const int iy = rowy*ncols + col;

        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
        vals[col] = val;
        max_val = max(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY; // slots of non-existent warps must not affect the max
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = max_val;
        }
        __syncthreads();

        max_val = buf_iw[lane_id];
        max_val = warp_reduce_max(max_val);
    }

    float tmp = 0.0f; // partial sum

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const float val = expf(vals[col] - max_val);
        tmp += val;
        vals[col] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // buf_iw still holds the per-warp maxima; every warp must have finished reading them
        // before warp 0 is allowed to overwrite the buffer.
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f; // slots of non-existent warps must not affect the sum
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp;
        }
        __syncthreads();

        tmp = buf_iw[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float inv_sum = 1.0f / tmp;

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        // no barrier follows this loop, so leaving the kernel early is safe here
        if (ncols_template == 0 && col >= ncols) {
            return;
        }

        const int idst = rowx*ncols + col;
        dst[idst] = vals[col] * inv_sum;
    }
}
52755449

@@ -6543,12 +6717,80 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
65436717
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
65446718
}
65456719

6720+
// Launches the fp16 softmax kernel on `stream`: one block per row of x (nrows_x blocks).
//
// x, dst: nrows_x rows of ncols_x floats; y: optional mask (nullptr for none) that the kernel
// broadcasts over rows via row % nrows_y; scale: factor applied to x before the mask is added.
//
// The switch selects kernel instantiations with compile-time row length and block size for the
// listed power-of-two row lengths (bounds checks disabled where the row is a multiple of
// 2*WARP_SIZE); all other row lengths use the generic, bounds-checked instantiation.
// NOTE(review): shmem grows linearly with ncols_x and is not checked against the device's
// shared memory limit here - confirm that callers cannot pass rows that are too long.
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
    // each thread handles two columns per iteration, so half as many threads as columns suffice
    int nth = WARP_SIZE;
    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    const dim3 block_dims(nth,     1, 1);
    const dim3 block_nums(nrows_x, 1, 1);
    // The kernel caches the row as half2 pairs in warp-sized chunks, so its value buffer can span
    // up to ncols_x rounded up to a multiple of 2*WARP_SIZE halves, followed by WARP_SIZE halves
    // for the inter-warp reduction. Sizing by the unpadded ncols_x under-allocates whenever
    // ncols_x is not a multiple of 2*WARP_SIZE.
    const int64_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
    switch (ncols_x) {
        case 32:
            soft_max_f16<32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 64:
            soft_max_f16<64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 128:
            soft_max_f16<128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 256:
            soft_max_f16<256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 512:
            soft_max_f16<512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 1024:
            soft_max_f16<1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 2048:
            soft_max_f16<2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 4096:
            soft_max_f16<4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        default:
            // runtime row length and block size, with bounds checks for the tail of the row
            soft_max_f16<0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
    }
}
6757+
65466758
// Launches the fp32 softmax kernel on `stream`: one block per row of x (nrows_x blocks).
//
// x, dst: nrows_x rows of ncols_x floats; y: optional mask (nullptr for none) that the kernel
// broadcasts over rows via row % nrows_y; scale: factor applied to x before the mask is added.
//
// Row lengths 32..4096 (powers of two) get instantiations with compile-time row length and
// block size; every other row length uses the generic <0, 0> instantiation.
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
    // double the block size, starting at one warp, until it covers the row or hits the cap
    int n_threads = WARP_SIZE;
    for (; n_threads < ncols_x && n_threads < CUDA_SOFT_MAX_BLOCK_SIZE; n_threads *= 2) {}

    const dim3 block(n_threads, 1, 1);
    const dim3 grid(nrows_x,    1, 1);

    // ncols_x cached values + WARP_SIZE slots for the inter-warp reduction
    const int64_t smem_bytes = (ncols_x + WARP_SIZE)*sizeof(float);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

// All instantiations share the same launch configuration and arguments.
#define SOFT_MAX_F32_LAUNCH(NCOLS, BLOCK_SIZE) \
    soft_max_f32<NCOLS, BLOCK_SIZE><<<grid, block, smem_bytes, stream>>>(x, y, dst, ncols_x, nrows_y, scale)

    switch (ncols_x) {
        case   32: SOFT_MAX_F32_LAUNCH(  32,   32); break;
        case   64: SOFT_MAX_F32_LAUNCH(  64,   64); break;
        case  128: SOFT_MAX_F32_LAUNCH( 128,  128); break;
        case  256: SOFT_MAX_F32_LAUNCH( 256,  256); break;
        case  512: SOFT_MAX_F32_LAUNCH( 512,  512); break;
        case 1024: SOFT_MAX_F32_LAUNCH(1024, 1024); break;
        case 2048: SOFT_MAX_F32_LAUNCH(2048, 1024); break;
        case 4096: SOFT_MAX_F32_LAUNCH(4096, 1024); break;
        default:   SOFT_MAX_F32_LAUNCH(   0,    0); break;
    }

#undef SOFT_MAX_F32_LAUNCH
}
65536795

65546796
static void im2col_f32_f16_cuda(const float* x, half* dst,
@@ -7873,7 +8115,21 @@ static void ggml_cuda_op_soft_max(
78738115
float scale = 1.0f;
78748116
memcpy(&scale, dst->op_params, sizeof(float));
78758117

7876-
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8118+
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8119+
const bool use_f16_soft_max = false;
8120+
#else
8121+
#ifdef GGML_CUDA_F16
8122+
const bool use_f16_soft_max = true;
8123+
#else
8124+
const bool use_f16_soft_max = false;
8125+
#endif // GGML_CUDA_F16
8126+
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8127+
8128+
if (use_f16_soft_max) {
8129+
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8130+
} else {
8131+
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8132+
}
78778133

78788134
(void) dst;
78798135
}

0 commit comments

Comments
 (0)