
Commit 8d6303f

[Perf] Optimize rotary_emb implementation to use a Triton operator for improved inference performance
Signed-off-by: cynthieye <yexin93@qq.com>
Co-authored-by: MagnetoWang <magnetowang@outlook.com>
1 parent 99ef59c · commit 8d6303f

1 file changed: vllm/model_executor/layers/rotary_embedding.py (+46, -15)

@@ -28,10 +28,14 @@
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
+from transformers.utils import is_flash_attn_2_available
 
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
 
+if is_flash_attn_2_available():
+    from flash_attn.layers.rotary import apply_rotary_emb
+
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
@@ -46,20 +50,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
@@ -75,6 +71,27 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+def _apply_rotary_emb(x: torch.Tensor,
+                      cos: torch.Tensor,
+                      sin: torch.Tensor,
+                      is_neox_style: bool,
+                      use_flash_attn=False) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+        use_flash_attn: Whether to enable Flash Attention optimizations.
+    """
+    if use_flash_attn:
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
+                                not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
@@ -100,6 +117,10 @@ def __init__(
         cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
+        if is_flash_attn_2_available():
+            self._use_flash_attn = True
+        else:
+            self._use_flash_attn = False
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
@@ -141,14 +162,16 @@ def forward_native(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style,
+                                      self._use_flash_attn)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style,
+                                    self._use_flash_attn)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
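
Since forward_native only threads self._use_flash_attn through to _apply_rotary_emb, callers do not change. A minimal usage sketch (not part of the commit), assuming a CUDA device and the constructor arguments this file's RotaryEmbedding takes (head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype):

# Sketch (not from the commit): exercise the updated forward_native path.
import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

head_size, num_heads, num_tokens = 128, 8, 16
rope = RotaryEmbedding(head_size=head_size,
                       rotary_dim=head_size,
                       max_position_embeddings=4096,
                       base=10000,
                       is_neox_style=True,
                       dtype=torch.float16).cuda()

positions = torch.arange(num_tokens, device="cuda")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.float16, device="cuda")
key = torch.randn_like(query)

# If flash-attn 2 is importable, _use_flash_attn is True and the Triton
# kernel runs; otherwise the PyTorch fallback (_apply_rotary_emb_torch) runs.
query, key = rope.forward_native(positions, query, key)
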
@@ -309,9 +332,11 @@ def _apply_rotary_emb_neuron(
         key = key.view(num_tokens, -1, self.head_size)
 
         if self.rotary_dim == self.head_size:
-            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
+            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style,
+                                      self._use_flash_attn)
             query = query.reshape(query_shape)
-            key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
+            key = _apply_rotary_emb(key, cos, sin, self.is_neox_style,
+                                    self._use_flash_attn)
             key = key.reshape(key_shape)
         else:
             head_size = query.shape[-1]
@@ -938,6 +963,10 @@ def __init__(
         self.mrope_section = mrope_section
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
+        if is_flash_attn_2_available():
+            self._use_flash_attn = True
+        else:
+            self._use_flash_attn = False
 
     def forward(
         self,
@@ -977,14 +1006,16 @@ def forward(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style,
+                                      self._use_flash_attn)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style,
+                                    self._use_flash_attn)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 