|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +import pytest |
| 4 | +import torch |
| 5 | + |
| 6 | +import vllm.envs as envs |
| 7 | +from vllm.compilation.fix_functionalization import FixFunctionalizationPass |
| 8 | +from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe, |
| 9 | + find_specified_fn, |
| 10 | + find_specified_fn_maybe, is_func) |
| 11 | +from vllm.compilation.sequence_parallelism import SequenceParallelismPass |
| 12 | +from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, |
| 13 | + VllmConfig) |
| 14 | +from vllm.distributed import tensor_model_parallel_all_reduce |
| 15 | +from vllm.distributed.parallel_state import (init_distributed_environment, |
| 16 | + initialize_model_parallel) |
| 17 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 18 | +from vllm.platforms import current_platform |
| 19 | +from vllm.utils import update_environment_variables |
| 20 | + |
| 21 | +from ..utils import multi_gpu_test |
| 22 | +from .backend import TestBackend |
| 23 | + |
| 24 | +OPS_IN_MODEL_BEFORE = [ |
| 25 | + torch.ops.vllm.all_reduce.default, |
| 26 | +] |
| 27 | + |
| 28 | +OPS_IN_MODEL_AFTER = [ |
| 29 | + torch.ops.vllm.reduce_scatter.default, |
| 30 | + torch.ops.vllm.all_gather.default, |
| 31 | +] |
| 32 | + |
| 33 | +OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default] |
| 34 | + |
| 35 | +prompts = [ |
| 36 | + "Hello, my name is", |
| 37 | + "The president of the United States is", |
| 38 | + "The capital of France is", |
| 39 | + "The future of AI is", |
| 40 | +] |
| 41 | + |
| 42 | + |
| 43 | +class TestModel(torch.nn.Module): |
| 44 | + |
| 45 | + def __init__(self, hidden_size=16, intermediate_size=32): |
| 46 | + super().__init__() |
| 47 | + self.hidden_size = hidden_size |
| 48 | + self.intermediate_size = intermediate_size |
| 49 | + self.gate_proj = torch.nn.Parameter( |
| 50 | + torch.empty((intermediate_size, hidden_size))) |
| 51 | + self.norm = RMSNorm(hidden_size, 1e-05) |
| 52 | + # Initialize weights |
| 53 | + torch.nn.init.normal_(self.gate_proj, std=0.02) |
| 54 | + |
| 55 | + def forward(self, hidden_states, residual): |
| 56 | + """ |
| 57 | + Forward pass implementing the operations in the FX graph |
| 58 | + |
| 59 | + Args: |
| 60 | + hidden_states: Input tensor |
| 61 | + residual: Residual tensor from previous layer |
| 62 | + |
| 63 | + Returns: |
| 64 | + Tuple containing the output tensor |
| 65 | + """ |
| 66 | + # Reshape input |
| 67 | + view = hidden_states.reshape(-1, self.hidden_size) |
| 68 | + |
| 69 | + #matrix multiplication |
| 70 | + permute = self.gate_proj.permute(1, 0) |
| 71 | + mm = torch.mm(view, permute) |
| 72 | + |
| 73 | + # Tensor parallel all-reduce |
| 74 | + all_reduce = tensor_model_parallel_all_reduce(mm) |
| 75 | + |
| 76 | + # layer normalization |
| 77 | + norm_output, residual_output = self.norm(all_reduce, residual) |
| 78 | + |
| 79 | + return norm_output, residual_output |
| 80 | + |
| 81 | + |
| 82 | +@multi_gpu_test(num_gpus=2) |
| 83 | +@pytest.mark.parametrize("batch_size", [8]) |
| 84 | +@pytest.mark.parametrize("seq_len", [16]) |
| 85 | +@pytest.mark.parametrize("hidden_size", [16]) |
| 86 | +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
| 87 | +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], |
| 88 | + reason="Only test on CUDA") |
| 89 | +def test_sequence_parallelism_pass(batch_size: int, seq_len: int, |
| 90 | + hidden_size: int, dtype: torch.dtype): |
| 91 | + num_processes = 2 |
| 92 | + |
| 93 | + def run_torch_spawn(fn, nprocs): |
| 94 | + # need to use torch.mp.spawn otherwise will have problems with |
| 95 | + # torch.distributed and cuda |
| 96 | + torch.multiprocessing.spawn(fn, |
| 97 | + args=(num_processes, batch_size, seq_len, |
| 98 | + hidden_size, dtype), |
| 99 | + nprocs=nprocs) |
| 100 | + |
| 101 | + run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) |
| 102 | + |
| 103 | + |
| 104 | +def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, |
| 105 | + batch_size: int, seq_len: int, |
| 106 | + hidden_size: int, |
| 107 | + dtype: torch.dtype): |
| 108 | + current_platform.seed_everything(0) |
| 109 | + |
| 110 | + device = torch.device(f"cuda:{local_rank}") |
| 111 | + torch.cuda.set_device(device) |
| 112 | + torch.set_default_device(device) |
| 113 | + torch.set_default_dtype(dtype) |
| 114 | + |
| 115 | + update_environment_variables({ |
| 116 | + 'RANK': str(local_rank), |
| 117 | + 'LOCAL_RANK': str(local_rank), |
| 118 | + 'WORLD_SIZE': str(world_size), |
| 119 | + 'MASTER_ADDR': 'localhost', |
| 120 | + 'MASTER_PORT': '12345', |
| 121 | + }) |
| 122 | + |
| 123 | + # initialize distributed |
| 124 | + init_distributed_environment() |
| 125 | + initialize_model_parallel(tensor_model_parallel_size=world_size) |
| 126 | + |
| 127 | + # configure vllm config for SequenceParallelismPass |
| 128 | + vllm_config = VllmConfig() |
| 129 | + vllm_config.compilation_config = CompilationConfig( |
| 130 | + pass_config=CompilationConfig.PassConfig( |
| 131 | + enable_sequence_parallelism=True, ), ) |
| 132 | + vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) |
| 133 | + |
| 134 | + # this is a fake model name to construct the model config |
| 135 | + # in the vllm_config, it's not really used. |
| 136 | + model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" |
| 137 | + vllm_config.model_config = ModelConfig(model=model, |
| 138 | + task="auto", |
| 139 | + tokenizer=model, |
| 140 | + tokenizer_mode="auto", |
| 141 | + trust_remote_code=True, |
| 142 | + dtype=dtype, |
| 143 | + seed=42) |
| 144 | + |
| 145 | + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) |
| 146 | + backend_no_func = TestBackend(sequence_parallelism_pass) |
| 147 | + func_pass = FixFunctionalizationPass(vllm_config) |
| 148 | + backend_func = TestBackend(sequence_parallelism_pass, func_pass) |
| 149 | + |
| 150 | + model = TestModel(hidden_size, hidden_size * 2) |
| 151 | + hidden_states = torch.randn((batch_size * seq_len, hidden_size), |
| 152 | + dtype=dtype) |
| 153 | + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) |
| 154 | + |
| 155 | + compiled_model_no_func = torch.compile(model, backend=backend_no_func) |
| 156 | + compiled_model_no_func(hidden_states, residual) |
| 157 | + compiled_model_func = torch.compile(model, backend=backend_func) |
| 158 | + compiled_model_func(hidden_states, residual) |
| 159 | + |
| 160 | + # Check substitution worked |
| 161 | + pre_nodes = backend_no_func.graph_pre_pass.nodes |
| 162 | + post_nodes = backend_no_func.graph_post_pass.nodes |
| 163 | + |
| 164 | + # In pre-nodes, all reduce should be there, |
| 165 | + # reduce scatter and all gather should not |
| 166 | + for op in OPS_IN_MODEL_BEFORE: |
| 167 | + find_specified_fn(pre_nodes, op) |
| 168 | + for op in OPS_IN_MODEL_AFTER: |
| 169 | + assert find_specified_fn_maybe(pre_nodes, op) is None |
| 170 | + |
| 171 | + # In post-nodes, reduce scatter and all gather should be there, |
| 172 | + # all reduce should not |
| 173 | + for op in OPS_IN_MODEL_AFTER: |
| 174 | + find_specified_fn(post_nodes, op) |
| 175 | + for op in OPS_IN_MODEL_BEFORE: |
| 176 | + assert find_specified_fn_maybe(post_nodes, op) is None |
| 177 | + |
| 178 | + # check if the functionalization pass is applied |
| 179 | + for op in OPS_IN_MODEL: |
| 180 | + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) |
| 181 | + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, |
| 182 | + op) is None # noqa: E501 |
| 183 | + |
| 184 | + # make sure the ops were all de-functionalized |
| 185 | + found = dict() |
| 186 | + for node in backend_func.graph_post_pass.nodes: |
| 187 | + for op in OPS_IN_MODEL: |
| 188 | + if is_func(node, op): |
| 189 | + found[op] = True |
| 190 | + assert all(found[op] for op in OPS_IN_MODEL) |
0 commit comments