Commit 677ad0a

Add initial/naive CUDA kernels for the GGML_OP_SSM_CONV and GGML_OP_SSM_SCAN ops

1 parent 2ac95c9 commit 677ad0a

5 files changed, +353 -0 lines changed

ggml-cuda.cu (+8)
@@ -29,6 +29,8 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/ssm_conv.cuh"
+#include "ggml-cuda/ssm_scan.cuh"

 #include <algorithm>
 #include <array>
@@ -2350,6 +2352,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cuda_flash_attn_ext(ctx, dst);
             break;
+        case GGML_OP_SSM_CONV:
+            ggml_cuda_op_ssm_conv(ctx, dst);
+            break;
+        case GGML_OP_SSM_SCAN:
+            ggml_cuda_op_ssm_scan(ctx, dst);
+            break;
         default:
             return false;
     }
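
With these two cases added, ggml_cuda_compute_forward() handles GGML_OP_SSM_CONV and GGML_OP_SSM_SCAN instead of hitting the default branch, whose false return tells the caller the op is unsupported on this backend. The standalone mock below is a minimal sketch of that dispatch contract; all names are hypothetical and none of this is ggml code.

#include <stdio.h>

// Mock of the dispatch contract above (hypothetical names, not ggml's types):
// each supported op gets a case that launches its kernel; anything else
// returns false so the caller can fall back to another backend.
enum mock_op { MOCK_OP_SSM_CONV, MOCK_OP_SSM_SCAN, MOCK_OP_OTHER };

static bool mock_compute_forward(enum mock_op op) {
    switch (op) {
        case MOCK_OP_SSM_CONV:
            printf("launch ssm_conv kernel\n");
            break;
        case MOCK_OP_SSM_SCAN:
            printf("launch ssm_scan kernel\n");
            break;
        default:
            return false; // unsupported op: caller falls back
    }
    return true;
}

int main() {
    printf("handled: %d\n", mock_compute_forward(MOCK_OP_SSM_SCAN)); // 1
    printf("handled: %d\n", mock_compute_forward(MOCK_OP_OTHER));    // 0
    return 0;
}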

ggml-cuda/ssm_conv.cu (+159)
@@ -0,0 +1,159 @@
#include "ssm_conv.cuh"

template <int block_size>
static __global__ void ssm_conv_f32(
    const float * src0, const float * src1, const float * src2, const float * src3,
    const int src0_ne0, const int src0_nb1, const int src0_nb2,
    const int src1_nb0, const int src1_nb1,
    const int src2_nb1, const int src2_nb2,
    const int src3_nb1,
    float * dst,
    const int nc, const int nr, const int n_t, const int n_kv) {

    // const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    const int ith = tid;
    const int nth = WARP_SIZE;

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = min(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

    if (n_kv > 1) {
        // multiple sequences means it's hard to know when it's the first time a state is read,
        // so copy them all over to the destination, just to be sure.
        for (int i3 = 0; i3 < n_kv; ++i3) {
            float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2);
            float * s  = (float *) ((char *) dst  + ir0*src2_nb1 + i3*src2_nb2 + nr*n_t*sizeof(float));
            // can't use memcpy because of d_conv vs d_conv - 1
            for (int i1 = 0; i1 < ir; ++i1) {
                for (int i0 = 0; i0 < nc - 1; ++i0) {
                    // copy s0 to last (d_conv - 1) columns of s
                    s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
                }
            }
        }
    }

    for (int i2 = 0; i2 < n_t; ++i2) {
        int32_t * sq = (int32_t *) ((char *) src3 + i2*src3_nb1); // {n_kv, n_tokens}
        float *   x  = (float *) ((char *) dst  + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
        float *   s  = (float *) ((char *) dst  + ir0*src2_nb1 + sq[0]*src2_nb2 + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
        float *   s0; // {d_conv - 1, d_inner, n_kv}
        float *   x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1); // {d_inner, n_tokens}
        float *   c  = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
        int ne0s0;

        // avoid needing to copy the state for the first token
        if (i2 == 0) {
            s0 = (float *) ((char *) src0 + ir0*src0_nb1 + sq[0]*src0_nb2); // {d_conv - 1, d_inner, n_kv}
            ne0s0 = src0_ne0;
        } else {
            // the source is the last (d_conv - 1) columns of the destination
            s0 = s + 1;
            ne0s0 = nc;
        }

        // d_inner
        for (int i1 = 0; i1 < ir; ++i1) {
            // shift state left
            for (int i0 = 0; i0 < nc - 1; ++i0) {
                s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
            }
            // insert x on the last column
            s[(nc - 1) + i1*nc] = x0[i1];
        }

        // handle copies when there are multiple output states
        for (int i3 = 1; i3 < n_kv; ++i3) {
            int32_t seq = sq[i3];
            if (0 <= seq && seq < n_kv) {
                float * s1 = s + (seq - sq[0])*nc*nr;

                //memcpy(s1, s, nc*ir*sizeof(float));
                for (int i4 = 0; i4 < nc*ir; i4++) {
                    s1[i4] = s[i4];
                }
            } else {
                // stop at negative or too big seq_ids
                break;
            }
        }

        // it seems a little faster when this is separate from the state shift
        for (int i1 = 0; i1 < ir; ++i1) {
            // rowwise dot product
            float sumf = 0.0f;
            for (int i0 = 0; i0 < nc; ++i0) {
                int i = i0 + i1*nc;
                sumf += s[i] * c[i];
            }
            x[i1] = sumf;
        }
    }
}

static void ssm_conv_f32_cuda(
    const float * src0, const float * src1, const float * src2, const float * src3,
    const int src0_ne0, const int src0_nb1, const int src0_nb2,
    const int src1_nb0, const int src1_nb1,
    const int src2_nb1, const int src2_nb2,
    const int src3_nb1,
    float * dst,
    const int nc, const int nr, const int n_t, const int n_kv, cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, 1, 1);
    const int nblocks = 1; // TODO

    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
        src0, src1, src2, src3,
        src0_ne0, src0_nb1, src0_nb2,
        src1_nb0, src1_nb1,
        src2_nb1, src2_nb2,
        src3_nb1,
        dst,
        nc, nr, n_t, n_kv);
}

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
    const struct ggml_tensor * src1 = dst->src[1]; // x
    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
    const struct ggml_tensor * src3 = dst->src[3]; // state_seq

    const int nc   = src2->ne[0]; // d_conv
    const int nr   = src0->ne[1]; // d_inner
    const int n_t  = src1->ne[1]; // n_tokens
    const int n_kv = src0->ne[2]; // max number of sequences in the batch

    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
    // for use with the destination state offset between sequences
    GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    const float * src2_d = (const float *) src2->data;
    const float * src3_d = (const float *) src3->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    ssm_conv_f32_cuda(src0_d, src1_d, src2_d, src3_d,
        src0->ne[0], src0->nb[1], src0->nb[2],
        src1->nb[0], src1->nb[1],
        src2->nb[1], src2->nb[2],
        src3->nb[1],
        dst_d, nc, nr, n_t, n_kv, stream);
}
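
For intuition, the per-row, per-token work of ssm_conv_f32 reduces to a short scalar loop. Below is a minimal host-side sketch of that update, assuming a d_conv-wide rolling window per inner channel; the function name, layout, and main() are illustrative only and not part of this commit.

#include <stdio.h>

// Hedged reference for one inner channel and one token: shift the rolling
// window of the last d_conv inputs left by one, append the newest input on
// the last column, then take the rowwise dot product with the conv weights.
static float ssm_conv_row_ref(float * state, const float * weights, float x_new, int d_conv) {
    // shift state left
    for (int i = 0; i < d_conv - 1; ++i) {
        state[i] = state[i + 1];
    }
    // insert x on the last column
    state[d_conv - 1] = x_new;
    // rowwise dot product
    float sum = 0.0f;
    for (int i = 0; i < d_conv; ++i) {
        sum += state[i] * weights[i];
    }
    return sum;
}

int main() {
    float state[4]   = {0.0f, 1.0f, 2.0f, 3.0f};
    float weights[4] = {0.1f, 0.2f, 0.3f, 0.4f};
    // window after the shift is {1, 2, 3, 5}: 0.1 + 0.4 + 0.9 + 2.0 = 3.4
    printf("%f\n", ssm_conv_row_ref(state, weights, 5.0f, 4));
    return 0;
}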

ggml-cuda/ssm_conv.cuh (+3)
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml-cuda/ssm_scan.cu (+180)
@@ -0,0 +1,180 @@
#include "ssm_scan.cuh"

template <int block_size>
static __global__ void ssm_scan_f32(
    const float * src0, const float * src1, const float * src2, const float * src3,
    const float * src4, const float * src5, const float * src6,
    const int src0_nb1, const int src0_nb2,
    const int src1_nb0, const int src1_nb1, const int src1_nb2,
    const int src2_nb0, const int src2_nb1,
    const int src3_nb1,
    const int src4_nb1,
    const int src5_nb1,
    const int src6_nb1,
    float * dst,
    const int nc, const int nr, const int n_t, const int n_kv) {

    // const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    const int ith = tid;
    const int nth = WARP_SIZE;

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = min(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

    if (n_kv > 1) {
        // it's hard to know if the source states have already been copied
        // when there are multiple, so copy them already.
        for (int i3 = 0; i3 < n_kv; ++i3) {
            float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2);
            float * s  = (float *) ((char *) dst  + ir0*src0_nb1 + i3*src0_nb2 + src1_nb2);

            //memcpy(s, s0, nc*ir*sizeof(float));
            for (int i4 = 0; i4 < nc*ir; i4++) {
                s[i4] = s0[i4];
            }
        }
    }

    for (int i2 = 0; i2 < n_t; ++i2) {
        int32_t * sq = (int32_t *) ((char *) src6 + i2*src6_nb1); // {n_kv, n_tokens}
        float *   y  = (float *) ((char *) dst  + ir0*src1_nb0 + i2*src1_nb1); // {d_inner, n_tokens}
        float *   s  = (float *) ((char *) dst  + ir0*src0_nb1 + sq[0]*src0_nb2 + src1_nb2); // {d_state, d_inner, n_kv}
        float *   s0;
        float *   x  = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1); // {d_inner, n_tokens}
        float *   dt = (float *) ((char *) src2 + ir0*src2_nb0 + i2*src2_nb1); // {d_inner, n_tokens}
        float *   A  = (float *) ((char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
        float *   B  = (float *) ((char *) src4 + i2*src4_nb1); // {d_state, n_tokens}
        float *   C  = (float *) ((char *) src5 + i2*src5_nb1); // {d_state, n_tokens}

        // avoid needing to copy the state for the first token
        if (i2 == 0) {
            s0 = (float *) ((char *) src0 + ir0*(src0_nb1) + sq[0]*src0_nb2); // {d_state, d_inner, n_kv}
        } else {
            // otherwise the source is the same as the destination
            s0 = s;
        }

        // d_inner
        for (int i1 = 0; i1 < ir; ++i1) {
            // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
            float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
            float x_dt = x[i1] * dt_soft_plus;
            float sumf = 0.0f;
            // d_state
            for (int i0 = 0; i0 < nc; ++i0) {
                int i = i0 + i1*nc;
                // state = prev_state * dA + dB * x
                float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
                // y = rowwise_dotprod(state, C)
                sumf += state * C[i0];
                s[i] = state;
            }
            y[i1] = sumf;
        }

        // handle copies when there are multiple output states
        for (int i3 = 1; i3 < n_kv; ++i3) {
            int32_t seq = sq[i3];
            if (0 <= seq && seq < n_kv) {
                float * s1 = s + (seq - sq[0])*nc*nr;
                //memcpy(s1, s, nc*ir*sizeof(float));
                for (int i4 = 0; i4 < nc*ir; i4++) {
                    s1[i4] = s[i4];
                }
            } else {
                // stop at negative or too big seq_ids
                break;
            }
        }
    }
}

static void ssm_scan_f32_cuda(
    const float * src0, const float * src1, const float * src2, const float * src3,
    const float * src4, const float * src5, const float * src6,
    const int src0_nb1, const int src0_nb2,
    const int src1_nb0, const int src1_nb1, const int src1_nb2,
    const int src2_nb0, const int src2_nb1,
    const int src3_nb1,
    const int src4_nb1,
    const int src5_nb1,
    const int src6_nb1,
    float * dst,
    const int nc, const int nr, const int n_t, const int n_kv, cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, 1, 1);
    const int nblocks = 1; // TODO

    ssm_scan_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
        src0, src1, src2, src3, src4, src5, src6,
        src0_nb1, src0_nb2,
        src1_nb0, src1_nb1, src1_nb2,
        src2_nb0, src2_nb1,
        src3_nb1,
        src4_nb1,
        src5_nb1,
        src6_nb1,
        dst,
        nc, nr, n_t, n_kv);
}

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // s
    const struct ggml_tensor * src1 = dst->src[1]; // x
    const struct ggml_tensor * src2 = dst->src[2]; // dt
    const struct ggml_tensor * src3 = dst->src[3]; // A
    const struct ggml_tensor * src4 = dst->src[4]; // B
    const struct ggml_tensor * src5 = dst->src[5]; // C
    const struct ggml_tensor * src6 = dst->src[6]; // sq

    const int64_t nc   = src0->ne[0]; // d_state
    const int64_t nr   = src0->ne[1]; // d_inner
    const int64_t n_t  = src1->ne[1]; // number of tokens in the batch
    const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch

    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    // required for the dot product between s and C, and when copying the states
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
    // required for per-sequence offsets for states
    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
    // required to get correct offset for state destination (i.e. src1->nb[2])
    GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    const float * src2_d = (const float *) src2->data;
    const float * src3_d = (const float *) src3->data;
    const float * src4_d = (const float *) src4->data;
    const float * src5_d = (const float *) src5->data;
    const float * src6_d = (const float *) src6->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    ssm_scan_f32_cuda(
        src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d,
        src0->nb[1], src0->nb[2],
        src1->nb[0], src1->nb[1], src1->nb[2],
        src2->nb[0], src2->nb[1],
        src3->nb[1],
        src4->nb[1],
        src5->nb[1],
        src6->nb[1],
        dst_d,
        nc, nr, n_t, n_kv, stream);
}
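
For intuition, the inner loop of ssm_scan_f32 follows the Mamba selective_state_update reference linked in the kernel comment. Below is a minimal host-side sketch of the per-channel recurrence, with illustrative names that are not part of this commit: a clipped softplus discretizes dt, the state decays by exp(dt*A) and absorbs B*x, and the output is a dot product with C.

#include <math.h>
#include <stdio.h>

// Hedged reference for one inner channel and one token, over d_state
// components: state = prev_state * exp(dt_sp * A) + B * (x * dt_sp),
// y = dot(state, C), with dt_sp = softplus(dt) clipped past 20 for stability.
static float ssm_scan_row_ref(float * state, const float * A, const float * B,
                              const float * C, float x, float dt, int d_state) {
    const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt; // clipped softplus
    const float x_dt  = x * dt_sp;
    float y = 0.0f;
    for (int i = 0; i < d_state; ++i) {
        // state = prev_state * dA + dB * x
        state[i] = state[i] * expf(dt_sp * A[i]) + B[i] * x_dt;
        // y = rowwise_dotprod(state, C)
        y += state[i] * C[i];
    }
    return y;
}

int main() {
    float state[2]   = { 0.5f, -0.5f};
    const float A[2] = {-1.0f, -2.0f};
    const float B[2] = { 1.0f,  0.0f};
    const float C[2] = { 1.0f,  1.0f};
    printf("y = %f\n", ssm_scan_row_ref(state, A, B, C, /*x=*/1.0f, /*dt=*/0.1f, 2));
    return 0;
}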

ggml-cuda/ssm_scan.cuh (+3)
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
