@@ -38,7 +38,8 @@ inline __device__ void apply_rotary_embedding(
     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                    // head_size] or [num_tokens, num_heads,
                                    // head_size]
-    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,    // nullptr or
+                                   // [batch_size, seq_len, num_kv_heads,
                                    // head_size] or [num_tokens, num_kv_heads,
                                    // head_size]
     const scalar_t* cache_ptr, const int head_size, const int num_heads,
@@ -57,13 +58,15 @@ inline __device__ void apply_rotary_embedding(
         query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
   }
 
-  const int nk = num_kv_heads * embed_dim;
-  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-    const int rot_offset = i % embed_dim;
-    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
-        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+  if (key != nullptr) {
+    const int nk = num_kv_heads * embed_dim;
+    for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+      const int head_idx = i / embed_dim;
+      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+      const int rot_offset = i % embed_dim;
+      apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+          key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+    }
   }
 }
 
@@ -74,7 +77,8 @@ __global__ void rotary_embedding_kernel(
     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                    // head_size] or [num_tokens, num_heads,
                                    // head_size]
-    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,    // nullptr or
+                                   // [batch_size, seq_len, num_kv_heads,
                                    // head_size] or [num_tokens, num_kv_heads,
                                    // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -98,7 +102,8 @@ __global__ void batched_rotary_embedding_kernel(
     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                    // head_size] or [num_tokens, num_heads,
                                    // head_size]
-    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,    // nullptr or
+                                   // [batch_size, seq_len, num_kv_heads,
                                    // head_size] or [num_tokens, num_kv_heads,
                                    // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -127,51 +132,53 @@ void rotary_embedding(
                           // [num_tokens, num_heads * head_size] or
                           // [batch_size, seq_len, num_heads, head_size] or
                           // [num_tokens, num_heads, head_size]
-    torch::Tensor& key,   // [batch_size, seq_len, num_kv_heads * head_size] or
-                          // [num_tokens, num_kv_heads * head_size] or
-                          // [batch_size, seq_len, num_heads, head_size] or
-                          // [num_tokens, num_heads, head_size]
+    std::optional<torch::Tensor> key,
+    // null or
+    // [batch_size, seq_len, num_kv_heads * head_size] or
+    // [num_tokens, num_kv_heads * head_size] or
+    // [batch_size, seq_len, num_heads, head_size] or
+    // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox) {
   // num_tokens = batch_size * seq_len
   int64_t num_tokens = positions.numel();
   int positions_ndim = positions.dim();
 
-  // Make sure num_tokens dim is consistent across positions, query, and key.
+  // Make sure num_tokens dim is consistent across positions, query, and key
   TORCH_CHECK(
       positions_ndim == 1 || positions_ndim == 2,
       "positions must have shape [num_tokens] or [batch_size, seq_len]");
   if (positions_ndim == 1) {
-    TORCH_CHECK(
-        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
-        "query, key and positions must have the same number of tokens");
+    TORCH_CHECK(query.size(0) == positions.size(0) &&
+                    (!key.has_value() || key->size(0) == positions.size(0)),
+                "query, key and positions must have the same number of tokens");
   }
   if (positions_ndim == 2) {
     TORCH_CHECK(
         query.size(0) == positions.size(0) &&
-            key.size(0) == positions.size(0) &&
+            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
-            key.size(1) == positions.size(1),
+            (!key.has_value() || key->size(1) == positions.size(1)),
         "query, key and positions must have the same batch_size and seq_len");
   }
 
   // Make sure head_size is valid for query and key
   // hidden_size = num_heads * head_size
   int query_hidden_size = query.numel() / num_tokens;
-  int key_hidden_size = key.numel() / num_tokens;
+  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
   TORCH_CHECK(query_hidden_size % head_size == 0);
   TORCH_CHECK(key_hidden_size % head_size == 0);
 
   // Make sure query and key have consistent number of heads
   int num_heads = query_hidden_size / head_size;
-  int num_kv_heads = key_hidden_size / head_size;
+  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
   TORCH_CHECK(num_heads % num_kv_heads == 0);
 
   int rot_dim = cos_sin_cache.size(1);
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
-  int64_t key_stride = key.stride(seq_dim_idx);
+  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -181,15 +188,16 @@ void rotary_embedding(
     if (is_neox) {
       vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
           positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
-          query_stride, key_stride, num_heads, num_kv_heads, head_size);
+          key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+          cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
+          num_heads, num_kv_heads, head_size);
     } else {
       vllm::rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-             key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
-             rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
-             head_size);
+             key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+             cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
+             key_stride, num_heads, num_kv_heads, head_size);
     }
   });
 }
@@ -204,10 +212,12 @@ void batched_rotary_embedding(
                           // [num_tokens, num_heads * head_size] or
                           // [batch_size, seq_len, num_heads, head_size] or
                           // [num_tokens, num_heads, head_size]
-    torch::Tensor& key,   // [batch_size, seq_len, num_kv_heads * head_size] or
-                          // [num_tokens, num_kv_heads * head_size] or
-                          // [batch_size, seq_len, num_heads, head_size] or
-                          // [num_tokens, num_heads, head_size]
+    std::optional<torch::Tensor>
+        key,  // null or
+              // [batch_size, seq_len, num_kv_heads * head_size] or
+              // [num_tokens, num_kv_heads * head_size] or
+              // [batch_size, seq_len, num_heads, head_size] or
+              // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox, int64_t rot_dim,
@@ -221,38 +231,38 @@ void batched_rotary_embedding(
       "cos_sin_cache_offsets");
 
   int positions_ndim = positions.dim();
-  // Make sure num_tokens dim is consistent across positions, query, and key.
+  // Make sure num_tokens dim is consistent across positions, query, and key
   TORCH_CHECK(
       positions_ndim == 1 || positions_ndim == 2,
       "positions must have shape [num_tokens] or [batch_size, seq_len]");
   if (positions_ndim == 1) {
-    TORCH_CHECK(
-        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
-        "query, key and positions must have the same number of tokens");
+    TORCH_CHECK(query.size(0) == positions.size(0) &&
+                    (!key.has_value() || key->size(0) == positions.size(0)),
+                "query, key and positions must have the same number of tokens");
   }
   if (positions_ndim == 2) {
     TORCH_CHECK(
         query.size(0) == positions.size(0) &&
-            key.size(0) == positions.size(0) &&
+            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
-            key.size(1) == positions.size(1),
+            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
   }
 
   // Make sure head_size is valid for query and key
   int query_hidden_size = query.numel() / num_tokens;
-  int key_hidden_size = key.numel() / num_tokens;
+  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
   TORCH_CHECK(query_hidden_size % head_size == 0);
   TORCH_CHECK(key_hidden_size % head_size == 0);
 
   // Make sure query and key have concistent number of heads
   int num_heads = query_hidden_size / head_size;
-  int num_kv_heads = key_hidden_size / head_size;
+  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
   TORCH_CHECK(num_heads % num_kv_heads == 0);
 
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
-  int64_t key_stride = key.stride(seq_dim_idx);
+  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -263,14 +273,16 @@ void batched_rotary_embedding(
       vllm::batched_rotary_embedding_kernel<scalar_t, true>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
     } else {
       vllm::batched_rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
     }
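
For context, here is a minimal sketch of how the host entry point might be called once `key` is optional. The leading `positions` and `query` parameters are not visible in the hunks above, so the forward declaration below is an assumption made for illustration only; it mirrors the tail of the signature shown in the diff (`std::optional<torch::Tensor> key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox`).

```cpp
// Sketch only: the full signature (including positions/query) is assumed.
#include <optional>
#include <torch/torch.h>

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      std::optional<torch::Tensor> key, int64_t head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);

void rotate_query_and_key(torch::Tensor& positions, torch::Tensor& query,
                          torch::Tensor& key, torch::Tensor& cos_sin_cache,
                          int64_t head_size) {
  // Previous behavior, still supported: rotate both query and key in place.
  rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                   /*is_neox=*/true);
}

void rotate_query_only(torch::Tensor& positions, torch::Tensor& query,
                       torch::Tensor& cos_sin_cache, int64_t head_size) {
  // New behavior: pass std::nullopt so the kernel receives key == nullptr and
  // skips the key loop; key_stride falls back to 0 and num_kv_heads to
  // num_heads on the host side, as in the hunks above.
  rotary_embedding(positions, query, std::nullopt, head_size, cos_sin_cache,
                   /*is_neox=*/true);
}
```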