
Commit 2dbd8e5

shihaobai authored

add support for multinode tp (#751)

Co-authored-by: wufeiyang <jayfeather9@qq.com>
Co-authored-by: root <root@pt-290ac8041d114af0b1647509a5544872-master-0.pt-290ac8041d114af0b1647509a5544872.ns-devoversea-d41e68bd.svc.cluster.local>
Co-authored-by: shihaobai <baishihao@sensetime.com>
Co-authored-by: root <root@pt-511f450a52c24c2d9df9b20f0c8ebdb7-master-0.pt-511f450a52c24c2d9df9b20f0c8ebdb7.ns-devoversea-d41e68bd.svc.cluster.local>
Co-authored-by: Feiyang Wu <wufeiyang@sensetime.com>
Co-authored-by: wangzaijun <wzjhelloworld@qq.com>
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent 4a974c2 commit 2dbd8e5

File tree

36 files changed: +608 −220 lines


docs/CN/source/getting_started/quickstart.rst

Lines changed: 23 additions & 0 deletions
@@ -56,6 +56,22 @@
 .. note::
     The ``--model_dir`` argument in the command above must be changed to the actual model path on your machine.

+To deploy the DeepSeek-R1 model on a single H200 node, launch with:
+
+.. code-block:: console
+
+    $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 8 --graph_max_batch_size 100
+
+.. note::
+    LOADWORKER specifies the number of threads for model loading, which can speed up loading. --graph_max_batch_size specifies the number of CUDA graphs to capture; graphs are captured for batch sizes from 1 to 100.
+
+To deploy the DeepSeek-R1 model across two H100 nodes, launch with:
+
+.. code-block:: console
+
+    $ # Node 0
+    $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 0
+    $ # Node 1
+    $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 1

 3. (Optional) Test the model service
 -------------------------

@@ -75,3 +91,10 @@
     $ }'

+For the DeepSeek-R1 model, you can test the service with the following script:
+
+.. code-block:: console
+
+    $ cd test
+    $ python benchmark_client.py --num_clients 100 --input_num 2000 --tokenizer_path /nvme/DeepSeek-R1/ --url http://127.0.0.1:8000/generate_stream
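Beyond the benchmark script, a quick programmatic smoke test can stream tokens from the launched server. The sketch below is illustrative only: it assumes the server listens on http://127.0.0.1:8000 and that /generate_stream accepts the same JSON payload (inputs plus a parameters dict) as the curl example earlier in this quickstart.

# Minimal streaming smoke test (sketch; payload schema assumed from the
# quickstart's curl example -- adjust "parameters" to your deployment).
import requests

url = "http://127.0.0.1:8000/generate_stream"
payload = {
    "inputs": "What is AI?",
    "parameters": {"max_new_tokens": 128, "do_sample": False},
}

with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:  # skip keep-alive blank lines
            print(line.decode("utf-8"))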

docs/EN/source/getting_started/quickstart.rst

Lines changed: 9 additions & 1 deletion
@@ -53,7 +53,7 @@ After downloading the Llama-2-7b-chat model, use the following command in the terminal
 .. note::
     The ``--model_dir`` parameter in the above command should be changed to the actual path of your model on your machine.

-For the DeepSeek-R1 model on H200, it can be launched with the following command:
+For the DeepSeek-R1 model on a single H200, it can be launched with the following command:

 .. code-block:: console

@@ -62,6 +62,14 @@ For the DeepSeek-R1 model on H200, it can be launched with the following command
 .. note::
     LOADWORKER specifies the number of threads for model loading, which speeds up model loading. The --graph_max_batch_size parameter specifies the number of CUDA graphs to capture, covering batch sizes from 1 to 100.

+For the DeepSeek-R1 model on two H100 nodes, it can be launched with the following command:
+
+.. code-block:: console
+
+    $ # Node 0
+    $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 0
+    $ # Node 1
+    $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 16 --graph_max_batch_size 100 --nccl_host master_addr --nnodes 2 --node_rank 1

 3. (Optional) Test the Model Service
 --------------------------------------
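For orientation: the two-node command requests a global tensor-parallel degree of 16 split over 2 nodes, so each node drives 8 local GPUs, and global ranks 8-15 must land on local CUDA devices 0-7 of node 1. A back-of-the-envelope sketch of that mapping (illustrative arithmetic only; the real logic lives in lightllm.utils.dist_utils and may differ):

# Illustrative rank-to-device arithmetic for --tp 16 --nnodes 2 (not the
# actual LightLLM implementation).
tp = 16
nnodes = 2
gpus_per_node = tp // nnodes  # 8

for global_rank in range(tp):
    node_rank = global_rank // gpus_per_node    # which --node_rank hosts it
    local_device = global_rank % gpus_per_node  # which cuda:<i> it uses
    print(f"global rank {global_rank:2d} -> node {node_rank}, cuda:{local_device}")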

lightllm/common/basemodel/layer_weights/base_layer_weight.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 import numpy as np
 import threading
 from lightllm.common.basemodel.layer_weights.meta_weights import BaseWeight
+from lightllm.utils.dist_utils import get_current_device_id


 class BaseLayerWeight:

@@ -37,4 +38,4 @@ def _cuda(self, cpu_tensor):
         if self.tp_rank_ is None:
             return cpu_tensor.contiguous().to(self.data_type_).cuda()
         else:
-            return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
+            return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id())
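This change matters precisely in the multinode case: tp_rank_ is now a global rank, and on the second node global ranks 8-15 are not valid local CUDA device indices. A small illustration of the failure mode, with assumed helper semantics (get_current_device_id() is presumed to return the local device index):

# Illustration only: why .cuda(tp_rank_) fails off node 0.
import torch

global_rank = 9      # second rank on node 1 in a 2-node, 8-GPU-per-node setup
local_gpu_count = 8  # torch.cuda.device_count() on that node

# torch.zeros(4).cuda(global_rank) would raise "invalid device ordinal":
# there is no cuda:9 on an 8-GPU node.
device_id = global_rank % local_gpu_count  # 1, the local index
t = torch.zeros(4).cuda(device_id)         # lands on cuda:1 as intended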

lightllm/common/basemodel/layer_weights/hf_load_utils.py

Lines changed: 2 additions & 2 deletions
@@ -3,14 +3,14 @@
 import gc
 from safetensors import safe_open
 import lightllm.utils.petrel_helper as utils
+from lightllm.utils.dist_utils import get_current_device_id


 def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
     # bug fix: during multithreaded loading, each thread's CUDA device falls back to 0; this change ensures the bug no longer occurs
     import torch.distributed as dist

-    tp_rank = dist.get_rank()
-    torch.cuda.set_device(tp_rank)
+    torch.cuda.set_device(get_current_device_id())

     if use_safetensors:
         weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
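The comment in load_func refers to a real PyTorch pitfall: the current CUDA device is per-thread state, so a freshly spawned loader thread defaults to cuda:0 no matter what the main thread selected. A self-contained sketch of the pitfall and the set_device fix:

# Sketch: PyTorch's current CUDA device is thread-local.
import threading
import torch

def loader(device_id: int):
    # Without this call, allocations in this thread default to cuda:0.
    torch.cuda.set_device(device_id)
    t = torch.empty(1024, device="cuda")
    print(threading.current_thread().name, "->", t.device)

if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
    threads = [threading.Thread(target=loader, args=(i,)) for i in range(2)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()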

lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py

Lines changed: 3 additions & 4 deletions
@@ -1,7 +1,6 @@
 import torch
 from abc import ABC, abstractmethod
-from lightllm.utils.dist_utils import get_world_size, get_rank
-from lightllm.utils.device_utils import get_current_device_id
+from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id


 class BaseWeight(ABC):

@@ -19,8 +18,8 @@ def verify_load(self):

 class BaseWeightTpl(BaseWeight):
     def __init__(self):
-        self.world_size_ = get_world_size()
-        self.tp_rank_ = get_rank()
+        self.world_size_ = get_global_world_size()
+        self.tp_rank_ = get_global_rank()
         self.device_id_ = get_current_device_id()

     def load_hf_weights(self, weights):
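The rename makes each value's scope explicit: get_global_rank() counts across all nodes, while get_current_device_id() stays local to the node. A plausible shape for these helpers, offered only as an assumption about their semantics (the real definitions live in lightllm.utils.dist_utils and may differ):

# Assumed semantics of the renamed helpers -- not LightLLM's actual code.
import torch
import torch.distributed as dist

def get_global_world_size() -> int:
    return dist.get_world_size()  # total ranks across all nodes

def get_global_rank() -> int:
    return dist.get_rank()        # 0..world_size-1, spanning nodes

def get_current_device_id() -> int:
    # Local CUDA index on this node, never exceeding device_count() - 1.
    return get_global_rank() % torch.cuda.device_count()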

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py

Lines changed: 3 additions & 4 deletions
@@ -5,9 +5,8 @@
 from .base_weight import BaseWeight
 from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
 from lightllm.common.quantization.quantize_method import QuantizationMethod
-from lightllm.utils.dist_utils import get_world_size, get_rank
+from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id
 from lightllm.common.vllm_kernel import _custom_ops as ops
-from lightllm.utils.device_utils import get_current_device_id


 class FusedMoeWeight(BaseWeight):

@@ -39,7 +38,7 @@ def __init__(
         self.n_routed_experts = n_routed_experts
         self.split_inter_size = split_inter_size
         self.data_type_ = data_type
-        self.tp_rank_ = get_rank()
+        self.tp_rank_ = get_global_rank()
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
         self.experts_up_proj_scales = [None] * self.n_routed_experts

@@ -159,7 +158,7 @@ def _fuse_weight_scale(self):
         delattr(self, "experts_gate_proj_scales")

     def _load_hf_weights_etp(self, weights):
-        world_size_ = get_world_size()
+        world_size_ = get_global_world_size()
         assert self.n_routed_experts % world_size_ == 0
         n_expert_ep = self.n_routed_experts // world_size_
lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 17 additions & 19 deletions
@@ -4,6 +4,7 @@
 from typing import Optional, Tuple, List, Dict, Any
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.quantization.quantize_method import QuantizationMethod
+from lightllm.utils.dist_utils import get_current_device_id


 def generate_scale_name(name, weight_scale_suffix, act_scale_suffix):

@@ -73,20 +74,17 @@ def _post_load_weights(self) -> None:
             and (not self.static_activation or self.input_scale is not None)
         ):
             if self.weight_scale.ndim > 1:
-                # make the k dim more contiguous; most split-k kernels will likely be faster
-                self.weight_scale = self.weight_scale.cuda(self.device_id_).transpose(0, 1)
+                self.weight_scale = self.weight_scale.transpose(0, 1).cuda(get_current_device_id())
                 self.weight = [
-                    # make the k dim more contiguous; most split-k kernels will likely be faster
-                    self.weight.cuda(self.device_id_).transpose(0, 1),
+                    self.weight.cuda(get_current_device_id()).transpose(0, 1),
                     self.weight_scale,
                     self.input_scale,
                 ]
             else:
-                self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_))
+                self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(get_current_device_id()))
             return
-
         # make the k dim more contiguous; most split-k kernels will likely be faster
-        self.weight = self.weight.to(self.data_type_).cuda(self.device_id_).transpose(0, 1)
+        self.weight = self.weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1)


 class MMWeight(MMWeightTpl):

@@ -133,7 +131,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
             self.weight = weight[self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
-            self.bias = bias.cuda(self.device_id_)
+            self.bias = bias.cuda(get_current_device_id())

         if self.weight_scale_name is not None and self.weight_scale_name in weights:
             block_size = 1

@@ -154,7 +152,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:

         if self.act_scale_name is not None and self.act_scale_name in weights:
             input_scale = weights[self.act_scale_name].to(torch.float)
-            self.input_scale = input_scale.cuda()
+            self.input_scale = input_scale.cuda(get_current_device_id())

         if weight is None and weight_scale is None and input_scale is None:
             return

@@ -198,7 +196,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
             self.weight = weight[:, self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name]
-            self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.device_id_)
+            self.bias = (bias / self.world_size_).to(self.data_type_).cuda(get_current_device_id())

         if self.quantized_weight and self.weight_scale_name in weights:
             block_size = 1

@@ -216,7 +214,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:

         if self.static_activation and self.act_scale_name in weights:
             input_scale = weights[self.act_scale_name].to(torch.float)
-            self.input_scale = input_scale.cuda()
+            self.input_scale = input_scale.cuda(get_current_device_id())

         if weight is None and weight_scale is None and input_scale is None:
             return

@@ -294,19 +292,19 @@ def _fuse(self) -> None:
             delattr(self, "weights")

         if self.weight_scale is None and (None not in self.weight_scales):
-            self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda()
+            self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda(get_current_device_id())
             self._post_load_weights()
             delattr(self, "weight_scales")

         if self.static_activation and self.input_scale is None and (None not in self.input_scales):
             input_scales = torch.stack(self.input_scales, dim=0)
-            self.input_scale = torch.max(input_scales).cuda()
+            self.input_scale = torch.max(input_scales).cuda(get_current_device_id())
             self._post_load_weights()
             delattr(self, "input_scales")

         if self.has_bias:
             if self.bias is None and (None not in self.biases):
-                self.bias = torch.cat(self.biases, dim=0).cuda(self.device_id_)
+                self.bias = torch.cat(self.biases, dim=0).cuda(get_current_device_id())
             delattr(self, "biases")
         return self

@@ -449,10 +447,10 @@ def _post_load_weights(self) -> None:
             and (not self.static_activation or self.input_scale is not None)
         ):
             if self.weight_scale.ndim > 1:
-                self.weight_scale = self.weight_scale.cuda(self.device_id_)
-                self.weight = [self.weight.cuda(self.device_id_), self.weight_scale, self.input_scale]
+                self.weight_scale = self.weight_scale.cuda(get_current_device_id())
+                self.weight = [self.weight.cuda(get_current_device_id()), self.weight_scale, self.input_scale]
             return
-        self.weight = self.weight.cuda(self.device_id_)
+        self.weight = self.weight.cuda(get_current_device_id())


 class BMMWeight(BMMWeightTpl):

@@ -518,7 +516,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
             self.weight = weight[self.start : self.end]
         if self.bias_name in weights:
             bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end]
-            self.bias = bias.cuda(self.device_id_)
+            self.bias = bias.cuda(get_current_device_id())

         if self.weight_scale_name is not None and self.weight_scale_name in weights:
             weight_scale = weights[self.weight_scale_name]

@@ -532,7 +530,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:

         if self.act_scale_name is not None and self.act_scale_name in weights:
             input_scale = weights[self.act_scale_name].to(torch.float)
-            self.input_scale = input_scale.cuda()
+            self.input_scale = input_scale.cuda(get_current_device_id())

         if weight is None and weight_scale is None and input_scale is None:
             return
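The "make the k dim more contiguous" comment is about memory layout, not data movement: transpose(0, 1) is a zero-copy stride swap, so a row-major [n, k] weight presented as [k, n] exposes k as the stride-1 axis that split-k GEMM kernels walk. A minimal demonstration (shapes are illustrative):

# transpose(0, 1) swaps strides without copying; k becomes the fast-moving
# axis of the [k, n] view handed to the kernels.
import torch

n, k = 4096, 11008                    # illustrative GEMM shapes
w = torch.randn(n, k)                 # row-major: stride (k, 1)
wt = w.transpose(0, 1)                # shape [k, n], stride (1, k)

assert wt.stride() == (1, k)
print(w.data_ptr() == wt.data_ptr())  # True: same storage, no copy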

lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py

Lines changed: 6 additions & 5 deletions
@@ -1,5 +1,6 @@
 import torch
 from .base_weight import BaseWeightTpl
+from lightllm.utils.dist_utils import get_current_device_id


 class NormWeight(BaseWeightTpl):

@@ -13,9 +14,9 @@ def __init__(self, weight_name, data_type, bias_name=None):

     def load_hf_weights(self, weights):
         if self.weight_name in weights:
-            self.weight = weights[self.weight_name].to(self.data_type_).cuda(self.device_id_)
+            self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id())
         if self.bias_name in weights:
-            self.bias = weights[self.bias_name].to(self.data_type_).cuda(self.device_id_)
+            self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id())

     def verify_load(self):
         load_ok = True

@@ -33,7 +34,7 @@ def __init__(self, weight_name, data_type, bias_name=None):

     def load_hf_weights(self, weights):
         if self.weight_name in weights:
-            self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(self.device_id_)
+            self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id())


 class TpNormWeight(NormWeight):

@@ -46,6 +47,6 @@ def load_hf_weights(self, weights):
         end = self.split_n_embed * (self.tp_rank_ + 1)

         if self.weight_name in weights:
-            self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(self.device_id_)
+            self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id())
         if self.bias_name in weights:
-            self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(self.device_id_)
+            self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id())
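TpNormWeight's slicing shows how a global rank picks its shard: each rank keeps a contiguous split_n_embed-wide span of the weight and moves only that span to its local GPU. A worked example with assumed sizes:

# Shard arithmetic used by TpNormWeight.load_hf_weights (illustrative sizes).
n_embed = 8192
world_size = 16                        # get_global_world_size()
split_n_embed = n_embed // world_size  # 512 elements per rank

tp_rank = 9                            # get_global_rank()
start = split_n_embed * tp_rank        # 4608
end = split_n_embed * (tp_rank + 1)    # 5120
# weights[name][start:end] is the slice rank 9 copies to its local GPU via
# .cuda(get_current_device_id()).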
