
Commit df4f43c

sarckk authored and mawong-amd committed
Make key optional for rotary embedding (vllm-project#17566)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent 8ac9b3a commit df4f43c

File tree

10 files changed: +221 −151 lines changed


csrc/cpu/pos_encoding.cpp

Lines changed: 24 additions & 15 deletions
@@ -9,7 +9,8 @@ void rotary_embedding_impl(
     scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
                                    /// head_size] or [num_tokens, num_heads,
                                    /// head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr (optional) or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -85,10 +86,13 @@ void rotary_embedding_impl(
       compute_loop(token_head, cache_ptr, query);
     }
 
-    for (int i = 0; i < num_kv_heads; ++i) {
-      const int head_idx = i;
-      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-      compute_loop(token_head, cache_ptr, key);
+    if (key != nullptr) {
+      for (int i = 0; i < num_kv_heads; ++i) {
+        const int head_idx = i;
+        const int64_t token_head =
+            token_idx * key_stride + head_idx * head_size;
+        compute_loop(token_head, cache_ptr, key);
+      }
     }
   }
 }
@@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl(
     scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
                                    /// head_size] or [num_tokens, num_heads,
                                    /// head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr (optional) or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl(
     }
   }
 
+  if (key == nullptr) {
+    return;
+  }
+
 #pragma omp parallel for collapse(2)
   for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
     for (int i = 0; i < num_kv_heads; ++i) {
@@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl(
 };  // namespace
 
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                      torch::Tensor& key, int64_t head_size,
+                      std::optional<torch::Tensor> key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox) {
   int num_tokens = positions.numel();
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t key_stride = key.stride(-2);
+  int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads;
+  int64_t key_stride = key.has_value() ? key->stride(-2) : 0;
   int64_t query_stride = query.stride(-2);
 
   VLLM_DISPATCH_FLOATING_TYPES(
@@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
         if (is_neox) {
           rotary_embedding_impl(
               positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
-              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
-              head_size, num_tokens);
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size, num_tokens);
         } else {
           rotary_embedding_gptj_impl(
               positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
-              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
-              head_size, num_tokens);
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size, num_tokens);
         }
 
         CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
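
The pattern is the same in both CPU kernels: the wrapper maps an absent key to a null pointer, a zero key_stride, and num_kv_heads = num_heads, and the key loop is skipped when the pointer is null. Below is a minimal, torch-free sketch of that pattern, for illustration only (fake_rotary_impl, run_rotary, and the "+= 1.0f" stand-in for the real rotation math are invented, not vLLM code):

// Illustrative sketch of the optional-key pattern used above; not vLLM code.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

namespace {

// Stand-in for rotary_embedding_impl: rotates query always, key only when a
// non-null pointer is supplied.
void fake_rotary_impl(float* query, float* key, int num_tokens, int num_heads,
                      int num_kv_heads, int head_size, int64_t key_stride) {
  for (int t = 0; t < num_tokens; ++t) {
    for (int h = 0; h < num_heads; ++h) {
      query[t * num_heads * head_size + h * head_size] += 1.0f;  // placeholder "rotation"
    }
    if (key != nullptr) {  // same guard the CPU kernel now uses
      for (int h = 0; h < num_kv_heads; ++h) {
        key[t * key_stride + h * head_size] += 1.0f;
      }
    }
  }
}

// Stand-in for the wrapper: derives defaults when key is absent.
void run_rotary(std::vector<float>& query,
                std::optional<std::vector<float>>& key, int num_tokens,
                int num_heads, int head_size) {
  int num_kv_heads =
      key.has_value()
          ? static_cast<int>(key->size()) / (num_tokens * head_size)
          : num_heads;                                  // default: match query
  int64_t key_stride = key.has_value() ? num_kv_heads * head_size : 0;
  fake_rotary_impl(query.data(), key.has_value() ? key->data() : nullptr,
                   num_tokens, num_heads, num_kv_heads, head_size, key_stride);
}

}  // namespace

int main() {
  const int num_tokens = 2, num_heads = 4, head_size = 8;
  std::vector<float> query(num_tokens * num_heads * head_size, 0.0f);

  std::optional<std::vector<float>> no_key;  // query-only call, key skipped
  run_rotary(query, no_key, num_tokens, num_heads, head_size);

  std::optional<std::vector<float>> key(
      std::vector<float>(num_tokens * 2 * head_size, 0.0f));  // 2 kv heads
  run_rotary(query, key, num_tokens, num_heads, head_size);

  std::cout << "query[0]=" << query[0] << " key[0]=" << (*key)[0] << "\n";
}
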

csrc/cpu/torch_bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
   ops.def(
       "rotary_embedding(Tensor positions, Tensor! query,"
-      " Tensor! key, int head_size,"
+      " Tensor!? key, int head_size,"
       " Tensor cos_sin_cache, bool is_neox) -> ()");
   ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
 
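On the schema change: in the custom-op schema string, Tensor! marks an argument that may be mutated in place, and the trailing ? makes it optional (None is accepted from Python), so "Tensor!? key" is an in-place-rotated key that may be omitted. A hedged sketch of how such a schema pairs with a std::optional<torch::Tensor> argument, using the plain TORCH_LIBRARY macro rather than vLLM's TORCH_LIBRARY_EXPAND wrapper; the my_ops namespace and the empty stub body are illustrative only:

// Illustrative registration sketch; mirrors the binding style above but is
// not vLLM's actual code.
#include <torch/all.h>
#include <torch/library.h>

#include <optional>

namespace {
// Stub with the same shape as the real entry point; a real implementation
// would rotate query (and key, when key.has_value()) here.
void rotary_embedding_stub(torch::Tensor& positions, torch::Tensor& query,
                           std::optional<torch::Tensor> key, int64_t head_size,
                           torch::Tensor& cos_sin_cache, bool is_neox) {}
}  // namespace

TORCH_LIBRARY(my_ops, ops) {
  // "Tensor!" = mutated in place, "?" = may be None from Python.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      " Tensor!? key, int head_size,"
      " Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding_stub);
}
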

csrc/ops.h

Lines changed: 4 additions & 4 deletions
@@ -86,13 +86,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
                                       std::optional<torch::Tensor> residual);
 
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                      torch::Tensor& key, int64_t head_size,
+                      std::optional<torch::Tensor> key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
 
 void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                              torch::Tensor& key, int64_t head_size,
-                              torch::Tensor& cos_sin_cache, bool is_neox,
-                              int64_t rot_dim,
+                              std::optional<torch::Tensor> key,
+                              int64_t head_size, torch::Tensor& cos_sin_cache,
+                              bool is_neox, int64_t rot_dim,
                               torch::Tensor& cos_sin_cache_offsets);
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
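
With these declarations, C++ callers can pass either a tensor (which converts implicitly to std::optional<torch::Tensor>) or std::nullopt. A small hypothetical caller-side sketch, assuming libtorch is available; demo_rotary and the tensor shapes are invented for illustration and only mirror the parameter types, not the kernels' behavior:

// Illustrative call-site sketch; not vLLM code.
#include <torch/all.h>

#include <cstdint>
#include <optional>

void demo_rotary(torch::Tensor& positions, torch::Tensor& query,
                 std::optional<torch::Tensor> key, int64_t head_size,
                 torch::Tensor& cos_sin_cache, bool is_neox) {
  TORCH_CHECK(query.size(0) == positions.size(0));
  TORCH_CHECK(query.size(-1) % head_size == 0);
  if (key.has_value()) {
    TORCH_CHECK(key->size(-1) % head_size == 0);
  }
  TORCH_CHECK(cos_sin_cache.dim() == 2);  // [max_position, rot_dim]
}

int main() {
  auto positions = torch::arange(4, torch::kInt64);
  auto query = torch::zeros({4, 8 * 64});        // 8 query heads
  auto key = torch::zeros({4, 2 * 64});          // 2 kv heads
  auto cos_sin_cache = torch::zeros({1024, 64});

  // Tensor converts implicitly to std::optional<torch::Tensor>.
  demo_rotary(positions, query, key, 64, cos_sin_cache, /*is_neox=*/true);
  // Key omitted entirely.
  demo_rotary(positions, query, std::nullopt, 64, cos_sin_cache, /*is_neox=*/true);
}
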

csrc/pos_encoding_kernels.cu

Lines changed: 55 additions & 43 deletions
@@ -38,7 +38,8 @@ inline __device__ void apply_rotary_embedding(
     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                    // head_size] or [num_tokens, num_heads,
                                    // head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* cache_ptr, const int head_size, const int num_heads,
@@ -57,13 +58,15 @@ inline __device__ void apply_rotary_embedding(
         query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
   }
 
-  const int nk = num_kv_heads * embed_dim;
-  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-    const int rot_offset = i % embed_dim;
-    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
-        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+  if (key != nullptr) {
+    const int nk = num_kv_heads * embed_dim;
+    for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+      const int head_idx = i / embed_dim;
+      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+      const int rot_offset = i % embed_dim;
+      apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+          key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+    }
   }
 }
 
@@ -74,7 +77,8 @@ __global__ void rotary_embedding_kernel(
     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                    // head_size] or [num_tokens, num_heads,
                                    // head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -98,7 +102,8 @@ __global__ void batched_rotary_embedding_kernel(
     scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                    // head_size] or [num_tokens, num_heads,
                                    // head_size]
-    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+    scalar_t* __restrict__ key,  // nullptr or
+                                 // [batch_size, seq_len, num_kv_heads,
                                  // head_size] or [num_tokens, num_kv_heads,
                                  // head_size]
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
@@ -127,51 +132,53 @@ void rotary_embedding(
     // [num_tokens, num_heads * head_size] or
     // [batch_size, seq_len, num_heads, head_size] or
     // [num_tokens, num_heads, head_size]
-    torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size] or
-                         // [batch_size, seq_len, num_heads, head_size] or
-                         // [num_tokens, num_heads, head_size]
+    std::optional<torch::Tensor> key,
+    // null or
+    // [batch_size, seq_len, num_kv_heads * head_size] or
+    // [num_tokens, num_kv_heads * head_size] or
+    // [batch_size, seq_len, num_heads, head_size] or
+    // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox) {
   // num_tokens = batch_size * seq_len
   int64_t num_tokens = positions.numel();
   int positions_ndim = positions.dim();
 
-  // Make sure num_tokens dim is consistent across positions, query, and key.
+  // Make sure num_tokens dim is consistent across positions, query, and key
   TORCH_CHECK(
       positions_ndim == 1 || positions_ndim == 2,
       "positions must have shape [num_tokens] or [batch_size, seq_len]");
   if (positions_ndim == 1) {
-    TORCH_CHECK(
-        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
-        "query, key and positions must have the same number of tokens");
+    TORCH_CHECK(query.size(0) == positions.size(0) &&
+                    (!key.has_value() || key->size(0) == positions.size(0)),
+                "query, key and positions must have the same number of tokens");
   }
   if (positions_ndim == 2) {
     TORCH_CHECK(
         query.size(0) == positions.size(0) &&
-            key.size(0) == positions.size(0) &&
+            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
-            key.size(1) == positions.size(1),
+            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
   }
 
   // Make sure head_size is valid for query and key
   // hidden_size = num_heads * head_size
   int query_hidden_size = query.numel() / num_tokens;
-  int key_hidden_size = key.numel() / num_tokens;
+  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
   TORCH_CHECK(query_hidden_size % head_size == 0);
   TORCH_CHECK(key_hidden_size % head_size == 0);
 
   // Make sure query and key have consistent number of heads
   int num_heads = query_hidden_size / head_size;
-  int num_kv_heads = key_hidden_size / head_size;
+  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
   TORCH_CHECK(num_heads % num_kv_heads == 0);
 
   int rot_dim = cos_sin_cache.size(1);
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
-  int64_t key_stride = key.stride(seq_dim_idx);
+  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -181,15 +188,16 @@ void rotary_embedding(
     if (is_neox) {
       vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
           positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
-          query_stride, key_stride, num_heads, num_kv_heads, head_size);
+          key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+          cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
+          num_heads, num_kv_heads, head_size);
     } else {
       vllm::rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
-              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
-              head_size);
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size);
     }
   });
 }
@@ -204,10 +212,12 @@ void batched_rotary_embedding(
     // [num_tokens, num_heads * head_size] or
     // [batch_size, seq_len, num_heads, head_size] or
     // [num_tokens, num_heads, head_size]
-    torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size] or
-                         // [batch_size, seq_len, num_heads, head_size] or
-                         // [num_tokens, num_heads, head_size]
+    std::optional<torch::Tensor>
+        key,  // null or
+              // [batch_size, seq_len, num_kv_heads * head_size] or
+              // [num_tokens, num_kv_heads * head_size] or
+              // [batch_size, seq_len, num_heads, head_size] or
+              // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox, int64_t rot_dim,
@@ -221,38 +231,38 @@ void batched_rotary_embedding(
       "cos_sin_cache_offsets");
 
   int positions_ndim = positions.dim();
-  // Make sure num_tokens dim is consistent across positions, query, and key.
+  // Make sure num_tokens dim is consistent across positions, query, and key
   TORCH_CHECK(
      positions_ndim == 1 || positions_ndim == 2,
      "positions must have shape [num_tokens] or [batch_size, seq_len]");
   if (positions_ndim == 1) {
-    TORCH_CHECK(
-        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
-        "query, key and positions must have the same number of tokens");
+    TORCH_CHECK(query.size(0) == positions.size(0) &&
+                    (!key.has_value() || key->size(0) == positions.size(0)),
+                "query, key and positions must have the same number of tokens");
  }
   if (positions_ndim == 2) {
     TORCH_CHECK(
        query.size(0) == positions.size(0) &&
-            key.size(0) == positions.size(0) &&
+            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
-            key.size(1) == positions.size(1),
+            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
   }
 
   // Make sure head_size is valid for query and key
   int query_hidden_size = query.numel() / num_tokens;
-  int key_hidden_size = key.numel() / num_tokens;
+  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
   TORCH_CHECK(query_hidden_size % head_size == 0);
   TORCH_CHECK(key_hidden_size % head_size == 0);
 
   // Make sure query and key have concistent number of heads
   int num_heads = query_hidden_size / head_size;
-  int num_kv_heads = key_hidden_size / head_size;
+  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
   TORCH_CHECK(num_heads % num_kv_heads == 0);
 
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
-  int64_t key_stride = key.stride(seq_dim_idx);
+  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -263,14 +273,16 @@ void batched_rotary_embedding(
       vllm::batched_rotary_embedding_kernel<scalar_t, true>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::batched_rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    }
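
When key is absent, the host wrappers fall back to key_hidden_size = 0 (so the divisibility check still passes trivially), num_kv_heads = num_heads (so the GQA grouping check still holds), and key_stride = 0 (never dereferenced, because the kernels receive a null key pointer). Below is a torch-free sketch of just that defaulting and validation logic; KeyShape, DerivedDims, and derive_kv_dims are illustrative names, not vLLM code:

// Illustrative sketch of the optional-key defaulting logic; not vLLM code.
#include <cassert>
#include <cstdint>
#include <optional>

struct KeyShape {   // stand-in for the optional key tensor's metadata
  int64_t numel;
  int64_t stride;   // stride along the token/seq dimension
};

struct DerivedDims {
  int num_kv_heads;
  int64_t key_stride;
};

DerivedDims derive_kv_dims(const std::optional<KeyShape>& key,
                           int64_t num_tokens, int64_t head_size,
                           int num_heads) {
  // With no key, hidden size is treated as 0 and the check still passes
  // (0 % head_size == 0), mirroring the wrappers above.
  int64_t key_hidden_size = key.has_value() ? key->numel / num_tokens : 0;
  assert(key_hidden_size % head_size == 0);

  int num_kv_heads =
      key.has_value() ? static_cast<int>(key_hidden_size / head_size)
                      : num_heads;              // default: match query heads
  assert(num_heads % num_kv_heads == 0);        // GQA grouping still checked

  int64_t key_stride = key.has_value() ? key->stride : 0;  // unused when null
  return {num_kv_heads, key_stride};
}

int main() {
  auto with_key = derive_kv_dims(KeyShape{4 * 2 * 64, 2 * 64}, 4, 64, 8);
  auto no_key = derive_kv_dims(std::nullopt, 4, 64, 8);
  assert(with_key.num_kv_heads == 2 && no_key.num_kv_heads == 8);
  return 0;
}
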

csrc/torch_bindings.cpp

Lines changed: 2 additions & 2 deletions
@@ -176,15 +176,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
   ops.def(
       "rotary_embedding(Tensor positions, Tensor! query,"
-      " Tensor! key, int head_size,"
+      " Tensor!? key, int head_size,"
       " Tensor cos_sin_cache, bool is_neox) -> ()");
   ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
 
   // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
   // (supports multiple loras).
   ops.def(
       "batched_rotary_embedding(Tensor positions, Tensor! query,"
-      " Tensor! key, int head_size,"
+      " Tensor!? key, int head_size,"
       " Tensor cos_sin_cache, bool is_neox,"
       " int rot_dim,"
       " Tensor cos_sin_cache_offsets) -> ()");
