116
116
#include " ggml.h"
117
117
#include " ggml-backend-impl.h"
118
118
119
+ #define CC_PASCAL 600
119
120
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
120
121
#define CC_VOLTA 700
121
122
#define CC_OFFSET_AMD 1000000
@@ -585,6 +586,14 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
585
586
return a;
586
587
}
587
588
589
+ static __device__ __forceinline__ half2 warp_reduce_sum (half2 a) {
590
+ #pragma unroll
591
+ for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
592
+ a = __hadd2 (a, __shfl_xor_sync (0xffffffff , a, mask, 32 ));
593
+ }
594
+ return a;
595
+ }
596
+
588
597
static __device__ __forceinline__ float warp_reduce_max (float x) {
589
598
#pragma unroll
590
599
for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
@@ -593,6 +602,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
593
602
return x;
594
603
}
595
604
605
+ static __device__ __forceinline__ half2 warp_reduce_max (half2 x) {
606
+ #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
607
+ (void ) x;
608
+ bad_arch ();
609
+ #else
610
+ #pragma unroll
611
+ for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
612
+ x = __hmax2 (x, __shfl_xor_sync (0xffffffff , x, mask, 32 ));
613
+ }
614
+ return x;
615
+ #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
616
+ }
617
+
596
618
static __device__ __forceinline__ float op_repeat (const float a, const float b) {
597
619
return b;
598
620
GGML_UNUSED (a);
@@ -5201,75 +5223,227 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
5201
5223
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
5202
5224
}
5203
5225
5204
- static __global__ void soft_max_f32 (const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
5226
+ template <int ncols_template, int block_size_template, bool need_check>
5227
+ static __global__ void soft_max_f16 (const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
5228
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
5229
+ const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
5230
+ const int ncols_smem = GGML_PAD (ncols_data/2 , WARP_SIZE);
5231
+
5232
+ const int tid = threadIdx .x ;
5233
+ const int rowx = blockIdx .x ;
5234
+ const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
5235
+
5236
+ const int block_size = block_size_template == 0 ? blockDim .x : block_size_template;
5237
+
5238
+ const int warp_id = threadIdx .x / WARP_SIZE;
5239
+ const int lane_id = threadIdx .x % WARP_SIZE;
5240
+
5241
+ extern __shared__ half2 data_soft_max_f16[];
5242
+ half2 * vals = data_soft_max_f16 + 0 ; // shared memory buffer to cache values between iterations
5243
+ half * buf_iw = (half *) (data_soft_max_f16 + ncols_smem); // shared memory buffer for inter-warp communication
5244
+
5245
+ half2 max_val = make_half2 (-INFINITY, -INFINITY);
5246
+
5247
+ #pragma unroll
5248
+ for (int col0 = 0 ; col0 < ncols_smem; col0 += block_size) {
5249
+ const int col_smem = col0 + tid;
5250
+ const int col_data = 2 *col0 + 2 *WARP_SIZE*warp_id + lane_id;
5251
+
5252
+ if (ncols_template == 0 && col_smem >= ncols_smem) {
5253
+ break ;
5254
+ }
5255
+
5256
+ const int ix = rowx*ncols_data + col_data;
5257
+ const int iy = rowy*ncols_data + col_data;
5258
+
5259
+ half2 val;
5260
+ val.x = x[ix + 0 ]*scale + (y ? y[iy + 0 ] : 0 .0f );
5261
+ if (need_check && col_data + WARP_SIZE >= ncols_data) {
5262
+ val.y = -INFINITY;
5263
+ } else {
5264
+ val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0 .0f );
5265
+ }
5266
+ vals[col_smem] = val;
5267
+ max_val = __hmax2 (max_val, val);
5268
+ }
5269
+
5270
+ // find the max value in the block
5271
+ max_val = warp_reduce_max (max_val);
5272
+ if (block_size > WARP_SIZE) {
5273
+ if (warp_id == 0 ) {
5274
+ buf_iw[lane_id] = -INFINITY;
5275
+ }
5276
+ __syncthreads ();
5277
+
5278
+ if (lane_id == 0 ) {
5279
+ buf_iw[warp_id] = __hmax (max_val.x , max_val.y );
5280
+ }
5281
+ __syncthreads ();
5282
+
5283
+ max_val = __half2half2 (buf_iw[lane_id]);
5284
+ max_val = warp_reduce_max (max_val);
5285
+ } else {
5286
+ max_val = __half2half2 (__hmax (max_val.x , max_val.y ));
5287
+ }
5288
+
5289
+ half2 tmp = make_half2 (0 .0f , 0 .0f ); // partial sums
5290
+
5291
+ #pragma unroll
5292
+ for (int col0 = 0 ; col0 < ncols_smem; col0 += block_size) {
5293
+ const int col_smem = col0 + tid;
5294
+
5295
+ if (ncols_template == 0 && col_smem >= ncols_smem) {
5296
+ break ;
5297
+ }
5298
+
5299
+ const half2 val = h2exp (vals[col_smem] - max_val);
5300
+
5301
+ tmp += val;
5302
+ vals[col_smem] = val;
5303
+ }
5304
+
5305
+ // find the sum of exps in the block
5306
+ tmp = warp_reduce_sum (tmp);
5307
+ if (block_size > WARP_SIZE) {
5308
+ if (warp_id == 0 ) {
5309
+ buf_iw[lane_id] = 0 .0f ;
5310
+ }
5311
+ __syncthreads ();
5312
+
5313
+ if (lane_id == 0 ) {
5314
+ buf_iw[warp_id] = tmp.x + tmp.y ;
5315
+ }
5316
+ __syncthreads ();
5317
+
5318
+ tmp = __half2half2 (buf_iw[lane_id]);
5319
+ tmp = warp_reduce_sum (tmp);
5320
+ } else {
5321
+ tmp = __half2half2 (tmp.x + tmp.y );
5322
+ }
5323
+
5324
+ const half2 inv_sum = make_half2 (1 .0f , 1 .0f ) / tmp;
5325
+
5326
+ #pragma unroll
5327
+ for (int col0 = 0 ; col0 < ncols_smem; col0 += block_size) {
5328
+ const int col_smem = col0 + tid;
5329
+ const int col_data = 2 *col0 + 2 *WARP_SIZE*warp_id + lane_id;
5330
+
5331
+ if (ncols_template == 0 && col_data >= ncols_data) {
5332
+ return ;
5333
+ }
5334
+
5335
+ const int idst = rowx*ncols_data + col_data;
5336
+ const half2 result = vals[col_smem] * inv_sum;
5337
+ dst[idst] = result.x ;
5338
+
5339
+ if (need_check && col_data + WARP_SIZE >= ncols_data) {
5340
+ return ;
5341
+ }
5342
+
5343
+ dst[idst + WARP_SIZE] = result.y ;
5344
+ }
5345
+ #else
5346
+ (void ) x; (void ) y; (void ) dst; (void ) ncols_par; (void ) nrows_y; (void ) scale;
5347
+ bad_arch ();
5348
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
5349
+ }
5350
+
5351
+ template <int ncols_template, int block_size_template>
5352
+ static __global__ void soft_max_f32 (const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
5353
+ const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
5354
+
5205
5355
const int tid = threadIdx .x ;
5206
5356
const int rowx = blockIdx .x ;
5207
5357
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
5208
5358
5209
- const int block_size = blockDim .x ;
5359
+ const int block_size = block_size_template == 0 ? blockDim .x : block_size_template ;
5210
5360
5211
5361
const int warp_id = threadIdx .x / WARP_SIZE;
5212
5362
const int lane_id = threadIdx .x % WARP_SIZE;
5213
5363
5214
- __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
5364
+ extern __shared__ float data_soft_max_f32[];
5365
+ float * vals = data_soft_max_f32 + 0 ; // shared memory buffer to cache values between iterations
5366
+ float * buf_iw = data_soft_max_f32 + ncols; // shared memory buffer for inter-warp communication
5215
5367
5216
5368
float max_val = -INFINITY;
5217
5369
5218
- for (int col = tid; col < ncols; col += block_size) {
5370
+ #pragma unroll
5371
+ for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
5372
+ const int col = col0 + tid;
5373
+
5374
+ if (ncols_template == 0 && col >= ncols) {
5375
+ break ;
5376
+ }
5377
+
5219
5378
const int ix = rowx*ncols + col;
5220
5379
const int iy = rowy*ncols + col;
5221
- max_val = max (max_val, x[ix]*scale + (y ? y[iy] : 0 .0f ));
5380
+
5381
+ const float val = x[ix]*scale + (y ? y[iy] : 0 .0f );
5382
+ vals[col] = val;
5383
+ max_val = max (max_val, val);
5222
5384
}
5223
5385
5224
5386
// find the max value in the block
5225
5387
max_val = warp_reduce_max (max_val);
5226
5388
if (block_size > WARP_SIZE) {
5227
5389
if (warp_id == 0 ) {
5228
- buf [lane_id] = -INFINITY;
5390
+ buf_iw [lane_id] = -INFINITY;
5229
5391
}
5230
5392
__syncthreads ();
5231
5393
5232
5394
if (lane_id == 0 ) {
5233
- buf [warp_id] = max_val;
5395
+ buf_iw [warp_id] = max_val;
5234
5396
}
5235
5397
__syncthreads ();
5236
5398
5237
- max_val = buf [lane_id];
5399
+ max_val = buf_iw [lane_id];
5238
5400
max_val = warp_reduce_max (max_val);
5239
5401
}
5240
5402
5241
- float tmp = 0 .f ;
5403
+ float tmp = 0 .0f ; // partial sum
5242
5404
5243
- for (int col = tid; col < ncols; col += block_size) {
5244
- const int ix = rowx*ncols + col;
5245
- const int iy = rowy*ncols + col;
5246
- const float val = expf ((x[ix]*scale + (y ? y[iy] : 0 .0f )) - max_val);
5405
+ #pragma unroll
5406
+ for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
5407
+ const int col = col0 + tid;
5408
+
5409
+ if (ncols_template == 0 && col >= ncols) {
5410
+ break ;
5411
+ }
5412
+
5413
+ const float val = expf (vals[col] - max_val);
5247
5414
tmp += val;
5248
- dst[ix ] = val;
5415
+ vals[col ] = val;
5249
5416
}
5250
5417
5251
5418
// find the sum of exps in the block
5252
5419
tmp = warp_reduce_sum (tmp);
5253
5420
if (block_size > WARP_SIZE) {
5254
5421
if (warp_id == 0 ) {
5255
- buf [lane_id] = 0 .f ;
5422
+ buf_iw [lane_id] = 0 .0f ;
5256
5423
}
5257
5424
__syncthreads ();
5258
5425
5259
5426
if (lane_id == 0 ) {
5260
- buf [warp_id] = tmp;
5427
+ buf_iw [warp_id] = tmp;
5261
5428
}
5262
5429
__syncthreads ();
5263
5430
5264
- tmp = buf [lane_id];
5431
+ tmp = buf_iw [lane_id];
5265
5432
tmp = warp_reduce_sum (tmp);
5266
5433
}
5267
5434
5268
- const float inv_tmp = 1 .f / tmp;
5435
+ const float inv_sum = 1 .0f / tmp;
5269
5436
5270
- for (int col = tid; col < ncols; col += block_size) {
5271
- const int i = rowx*ncols + col;
5272
- dst[i] *= inv_tmp;
5437
+ #pragma unroll
5438
+ for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
5439
+ const int col = col0 + tid;
5440
+
5441
+ if (ncols_template == 0 && col >= ncols) {
5442
+ return ;
5443
+ }
5444
+
5445
+ const int idst = rowx*ncols + col;
5446
+ dst[idst] = vals[col] * inv_sum;
5273
5447
}
5274
5448
}
5275
5449
@@ -6543,12 +6717,80 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
6543
6717
diag_mask_inf_f32<<<block_nums, block_dims, 0 , stream>>> (x, dst, ncols_x, rows_per_channel, n_past);
6544
6718
}
6545
6719
6720
+ static void soft_max_f16_cuda (const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
6721
+ int nth = WARP_SIZE;
6722
+ while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2 ;
6723
+ const dim3 block_dims (nth, 1 , 1 );
6724
+ const dim3 block_nums (nrows_x, 1 , 1 );
6725
+ const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof (half);
6726
+ static_assert (CUDA_SOFT_MAX_BLOCK_SIZE == 1024 , " These values need to be adjusted." );
6727
+ switch (ncols_x) {
6728
+ case 32 :
6729
+ soft_max_f16<32 , 32 , true ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6730
+ break ;
6731
+ case 64 :
6732
+ soft_max_f16<64 , 32 , false ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6733
+ break ;
6734
+ case 128 :
6735
+ soft_max_f16<128 , 64 , false ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6736
+ break ;
6737
+ case 256 :
6738
+ soft_max_f16<256 , 128 , false ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6739
+ break ;
6740
+ case 512 :
6741
+ soft_max_f16<512 , 256 , false ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6742
+ break ;
6743
+ case 1024 :
6744
+ soft_max_f16<1024 , 512 , false ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6745
+ break ;
6746
+ case 2048 :
6747
+ soft_max_f16<2048 , 1024 , false ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6748
+ break ;
6749
+ case 4096 :
6750
+ soft_max_f16<4096 , 1024 , false ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6751
+ break ;
6752
+ default :
6753
+ soft_max_f16<0 , 0 , true ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6754
+ break ;
6755
+ }
6756
+ }
6757
+
6546
6758
static void soft_max_f32_cuda (const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
6547
6759
int nth = WARP_SIZE;
6548
6760
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2 ;
6549
6761
const dim3 block_dims (nth, 1 , 1 );
6550
6762
const dim3 block_nums (nrows_x, 1 , 1 );
6551
- soft_max_f32<<<block_nums, block_dims, 0 , stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6763
+ const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof (float );
6764
+ static_assert (CUDA_SOFT_MAX_BLOCK_SIZE == 1024 , " These values need to be adjusted." );
6765
+ switch (ncols_x) {
6766
+ case 32 :
6767
+ soft_max_f32<32 , 32 ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6768
+ break ;
6769
+ case 64 :
6770
+ soft_max_f32<64 , 64 ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6771
+ break ;
6772
+ case 128 :
6773
+ soft_max_f32<128 , 128 ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6774
+ break ;
6775
+ case 256 :
6776
+ soft_max_f32<256 , 256 ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6777
+ break ;
6778
+ case 512 :
6779
+ soft_max_f32<512 , 512 ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6780
+ break ;
6781
+ case 1024 :
6782
+ soft_max_f32<1024 , 1024 ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6783
+ break ;
6784
+ case 2048 :
6785
+ soft_max_f32<2048 , 1024 ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6786
+ break ;
6787
+ case 4096 :
6788
+ soft_max_f32<4096 , 1024 ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6789
+ break ;
6790
+ default :
6791
+ soft_max_f32<0 , 0 ><<<block_nums, block_dims, shmem, stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6792
+ break ;
6793
+ }
6552
6794
}
6553
6795
6554
6796
static void im2col_f32_f16_cuda (const float * x, half* dst,
@@ -7873,7 +8115,21 @@ static void ggml_cuda_op_soft_max(
7873
8115
float scale = 1 .0f ;
7874
8116
memcpy (&scale, dst->op_params , sizeof (float ));
7875
8117
7876
- soft_max_f32_cuda (src0_dd, src1 ? src1_dd : nullptr , dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8118
+ #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8119
+ const bool use_f16_soft_max = false ;
8120
+ #else
8121
+ #ifdef GGML_CUDA_F16
8122
+ const bool use_f16_soft_max = true ;
8123
+ #else
8124
+ const bool use_f16_soft_max = false ;
8125
+ #endif // GGML_CUDA_F16
8126
+ #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8127
+
8128
+ if (use_f16_soft_max) {
8129
+ soft_max_f16_cuda (src0_dd, src1 ? src1_dd : nullptr , dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8130
+ } else {
8131
+ soft_max_f32_cuda (src0_dd, src1 ? src1_dd : nullptr , dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8132
+ }
7877
8133
7878
8134
(void ) dst;
7879
8135
}
0 commit comments