
Commit 497a460

cascade812 authored and tlrmchlsmth committed

[Feature] support sequence parallelism using compilation pass (vllm-project#16155)

Signed-off-by: cascade812 <cascade812@outlook.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>

1 parent: 8603e06

21 files changed, +1072 −44 lines
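
The pass rewrites tensor-parallel all-reduces in the compiled FX graph into a reduce-scatter before the per-token ops (residual add, RMSNorm) and an all-gather afterwards, so those ops run on a sequence shard instead of being repeated on every rank. Below is a minimal sketch of the equivalence the rewrite relies on, using plain torch.distributed collectives rather than the commit's custom ops; the helper names are illustrative, not from the diff.

# Sketch only: assumes a process group is already initialized and that the
# leading (token) dimension of x is divisible by the world size.
import torch
import torch.distributed as dist


def tp_block_baseline(x: torch.Tensor, norm) -> torch.Tensor:
    # Plain tensor parallelism: full-size all-reduce, then every rank
    # repeats the per-token work on the whole sequence.
    dist.all_reduce(x)
    return norm(x)


def tp_block_sequence_parallel(x: torch.Tensor, norm,
                               world_size: int) -> torch.Tensor:
    # Sequence parallelism: reduce-scatter shards the token dimension,
    # the per-token work runs on 1/world_size of the tokens, and an
    # all-gather restores the full sequence for the next TP matmul.
    shard = x.new_empty(x.shape[0] // world_size, *x.shape[1:])
    dist.reduce_scatter_tensor(shard, x)
    shard = norm(shard)
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, shard)
    return out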

.buildkite/test-pipeline.yaml  (+3)

@@ -299,6 +299,7 @@ steps:
   commands:
   - pytest -v -s compile/test_pass_manager.py
   - pytest -v -s compile/test_fusion.py
+  - pytest -v -s compile/test_sequence_parallelism.py

 - label: PyTorch Fullgraph Smoke Test # 9min
   source_file_dependencies:
@@ -583,6 +584,8 @@ steps:
   - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
   - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
   - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
+  # test sequence parallel
+  - pytest -v -s distributed/test_sequence_parallel.py
   # this test fails consistently.
   # TODO: investigate and fix
   # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py

tests/compile/test_functionalization.py  (+8 −6)

@@ -10,7 +10,7 @@
                              kFp8DynamicTokenSym, kFp8StaticTensorSym)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, VllmConfig

 from .backend import TestBackend

@@ -49,13 +49,15 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
                               do_fusion: bool):
     torch.set_default_device("cuda")

-    config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
-                                          enable_noop=True)
-    noop_pass = NoOpEliminationPass(config)
-    fusion_pass = FusionPass.instance(config)
+    vllm_config = VllmConfig()
+    vllm_config.compilation_config = CompilationConfig(pass_config= \
+        CompilationConfig.PassConfig(enable_fusion=do_fusion,
+                                     enable_noop=True))
+    noop_pass = NoOpEliminationPass(vllm_config)
+    fusion_pass = FusionPass.instance(vllm_config)

     passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
-    func_pass = FixFunctionalizationPass(config)
+    func_pass = FixFunctionalizationPass(vllm_config)
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)

tests/compile/test_fusion.py  (+5 −4)

@@ -77,12 +77,13 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,

     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+    vllm_config.compilation_config.pass_config = \
+        CompilationConfig.PassConfig(enable_fusion=True,
+                                     enable_noop=True)
     with vllm.config.set_current_vllm_config(vllm_config):
         # Reshape pass is needed for the fusion pass to work
-        config = CompilationConfig.PassConfig(enable_fusion=True,
-                                              enable_noop=True)
-        noop_pass = NoOpEliminationPass(config)
-        fusion_pass = FusionPass.instance(config)
+        noop_pass = NoOpEliminationPass(vllm_config)
+        fusion_pass = FusionPass.instance(vllm_config)

         backend = TestBackend(noop_pass, fusion_pass)
         model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

tests/compile/test_pass_manager.py  (+5 −4)

@@ -6,7 +6,7 @@

 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.compilation.pass_manager import PostGradPassManager
-from vllm.config import CompilationConfig
+from vllm.config import VllmConfig


 # dummy custom pass that doesn't inherit
@@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph):

 # Should fail to add directly to the pass manager
 def test_bad_callable():
-    config = CompilationConfig().pass_config
+    config = VllmConfig()

     pass_manager = PostGradPassManager()
     pass_manager.configure(config)
@@ -43,7 +43,7 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None:
     ],
 )
 def test_pass_manager_uuid(callable):
-    config = CompilationConfig().pass_config
+    config = VllmConfig()

     pass_manager = PostGradPassManager()
     pass_manager.configure(config)
@@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable):

     # UUID should be different due to config change
     config2 = copy.deepcopy(config)
-    config2.enable_fusion = not config2.enable_fusion
+    config2.compilation_config.pass_config.enable_fusion = not \
+        config2.compilation_config.pass_config.enable_fusion
     pass_manager3 = PostGradPassManager()
     pass_manager3.configure(config2)
     pass_manager3.add(callable)
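
The change running through these three test files is an interface change: compilation passes and PostGradPassManager.configure now take the whole VllmConfig rather than a bare CompilationConfig.PassConfig, and the UUID test now flips a flag on the nested pass_config. A minimal sketch of the new call pattern, using only names that appear in the diffs above:

# Sketch of the new configuration pattern; VllmConfig() carries a default
# CompilationConfig, whose pass_config we overwrite before configuring.
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig, VllmConfig

config = VllmConfig()
config.compilation_config.pass_config = CompilationConfig.PassConfig(
    enable_fusion=True, enable_noop=True)

pass_manager = PostGradPassManager()
pass_manager.configure(config)  # previously: configure(pass_config)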
tests/compile/test_sequence_parallelism.py  (new file, +190)

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
                                       find_specified_fn,
                                       find_specified_fn_maybe, is_func)
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                         VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

from ..utils import multi_gpu_test
from .backend import TestBackend

OPS_IN_MODEL_BEFORE = [
    torch.ops.vllm.all_reduce.default,
]

OPS_IN_MODEL_AFTER = [
    torch.ops.vllm.reduce_scatter.default,
    torch.ops.vllm.all_gather.default,
]

OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class TestModel(torch.nn.Module):

    def __init__(self, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = torch.nn.Parameter(
            torch.empty((intermediate_size, hidden_size)))
        self.norm = RMSNorm(hidden_size, 1e-05)
        # Initialize weights
        torch.nn.init.normal_(self.gate_proj, std=0.02)

    def forward(self, hidden_states, residual):
        """
        Forward pass implementing the operations in the FX graph

        Args:
            hidden_states: Input tensor
            residual: Residual tensor from previous layer

        Returns:
            Tuple containing the output tensor
        """
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)

        # matrix multiplication
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)

        # Tensor parallel all-reduce
        all_reduce = tensor_model_parallel_all_reduce(mm)

        # layer normalization
        norm_output, residual_output = self.norm(all_reduce, residual)

        return norm_output, residual_output


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
                                   hidden_size: int, dtype: torch.dtype):
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        # need to use torch.mp.spawn otherwise will have problems with
        # torch.distributed and cuda
        torch.multiprocessing.spawn(fn,
                                    args=(num_processes, batch_size, seq_len,
                                          hidden_size, dtype),
                                    nprocs=nprocs)

    run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)


def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
                                            batch_size: int, seq_len: int,
                                            hidden_size: int,
                                            dtype: torch.dtype):
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # configure vllm config for SequenceParallelismPass
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=CompilationConfig.PassConfig(
            enable_sequence_parallelism=True, ), )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
    vllm_config.model_config = ModelConfig(model=model,
                                           task="auto",
                                           tokenizer=model,
                                           tokenizer_mode="auto",
                                           trust_remote_code=True,
                                           dtype=dtype,
                                           seed=42)

    sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
    backend_no_func = TestBackend(sequence_parallelism_pass)
    func_pass = FixFunctionalizationPass(vllm_config)
    backend_func = TestBackend(sequence_parallelism_pass, func_pass)

    model = TestModel(hidden_size, hidden_size * 2)
    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                dtype=dtype)
    residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)

    compiled_model_no_func = torch.compile(model, backend=backend_no_func)
    compiled_model_no_func(hidden_states, residual)
    compiled_model_func = torch.compile(model, backend=backend_func)
    compiled_model_func(hidden_states, residual)

    # Check substitution worked
    pre_nodes = backend_no_func.graph_pre_pass.nodes
    post_nodes = backend_no_func.graph_post_pass.nodes

    # In pre-nodes, all reduce should be there,
    # reduce scatter and all gather should not
    for op in OPS_IN_MODEL_BEFORE:
        find_specified_fn(pre_nodes, op)
    for op in OPS_IN_MODEL_AFTER:
        assert find_specified_fn_maybe(pre_nodes, op) is None

    # In post-nodes, reduce scatter and all gather should be there,
    # all reduce should not
    for op in OPS_IN_MODEL_AFTER:
        find_specified_fn(post_nodes, op)
    for op in OPS_IN_MODEL_BEFORE:
        assert find_specified_fn_maybe(post_nodes, op) is None

    # check if the functionalization pass is applied
    for op in OPS_IN_MODEL:
        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
                                  op) is None  # noqa: E501

    # make sure the ops were all de-functionalized
    found = dict()
    for node in backend_func.graph_post_pass.nodes:
        for op in OPS_IN_MODEL:
            if is_func(node, op):
                found[op] = True
    assert all(found[op] for op in OPS_IN_MODEL)
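
For reference, the knob this test flips is the new PassConfig field introduced by the commit. A minimal configuration sketch using only names from the diff; how the resulting VllmConfig is wired into a running engine is omitted here, since the test hands it directly to SequenceParallelismPass and TestBackend:

# Configuration sketch: enable the sequence-parallelism pass via the
# compilation pass config. It only has an effect when tensor parallelism
# shards the matmuls, since the pass rewrites tensor-parallel all-reduces.
from vllm.config import CompilationConfig, VllmConfig

vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
    pass_config=CompilationConfig.PassConfig(
        enable_sequence_parallelism=True))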

tests/distributed/test_comm_ops.py  (+30 −1)

@@ -14,7 +14,8 @@

 from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
-                              tensor_model_parallel_all_reduce)
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)

 from ..utils import init_test_distributed_environment, multi_process_parallel

@@ -47,6 +48,34 @@ def all_reduce_test_worker(
     torch.testing.assert_close(t, expected)


+@ray.remote(num_gpus=1, max_calls=1)
+def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
+                               pp_size: int, rank: int,
+                               distributed_init_port: str):
+    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
+    # so that each worker can see all the GPUs
+    # they will be able to set the device to the correct GPU
+    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    num_elements = 8
+    all_tensors = [
+        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
+        (r + 1) for r in range(tp_size)
+    ]
+
+    index = rank % tp_size
+    partition_size = num_elements // tp_size
+    all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
+    expected = all_reduce[index * partition_size:(index + 1) * partition_size]
+    t = all_tensors[index]
+    t = tensor_model_parallel_reduce_scatter(t, 0)
+    torch.testing.assert_close(t, expected)
+
+
 @ray.remote(num_gpus=1, max_calls=1)
 def all_gather_test_worker(
     monkeypatch: pytest.MonkeyPatch,
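
The expectation in the new worker can be checked by hand. A small single-process example of the same arithmetic (no process group needed), for tp_size=2 and num_elements=8:

# Single-process arithmetic check of the reduce-scatter expectation above.
import torch

tp_size, num_elements = 2, 8
all_tensors = [
    torch.arange(num_elements, dtype=torch.float32) * (r + 1)
    for r in range(tp_size)
]                                        # rank 0: 0..7, rank 1: 0, 2, ..., 14
summed = torch.stack(all_tensors).sum(dim=0)   # [0, 3, 6, 9, 12, 15, 18, 21]
partition_size = num_elements // tp_size       # 4 elements per rank
# Reduce-scatter hands each rank its slice of the element-wise sum:
# rank 0 -> [0, 3, 6, 9], rank 1 -> [12, 15, 18, 21]
for rank in range(tp_size):
    print(rank, summed[rank * partition_size:(rank + 1) * partition_size])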
