@@ -3,7 +3,7 @@
 import copy
 import math
 import re
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.distributed
@@ -110,7 +110,17 @@ def _forward(
             variance = tensor_model_parallel_all_reduce(
                 variance) / self.tp_world
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+
+        weight = self.weight
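+        # Guard against a size mismatch between the norm weight and the input:
+        # tile or truncate the weight so it broadcasts over x's last dimension.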
+        if x.size(-1) != self.weight.size(0):
+            if self.weight.size(0) < x.size(-1):
+                repeat_count = (x.size(-1) + self.weight.size(0) -
+                                1) // self.weight.size(0)
+                full_weight = self.weight.repeat(repeat_count)
+                weight = full_weight[:x.size(-1)]
+            else:
+                weight = self.weight[:x.size(-1)]
+
+        x = x.to(orig_dtype) * weight
         return x
 
     def forward(
@@ -421,6 +431,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                                attn_metadata):
         hidden = []
         for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
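+            # Defensive bound checks: stop early if query_start_loc or
+            # state_indices_tensor holds fewer entries than num_prefills reports.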
+            if _prefill_idx >= len(attn_metadata.query_start_loc):
+                break
+            if _prefill_idx >= len(state_indices_tensor):
+                break
             _start = attn_metadata.query_start_loc[_prefill_idx]
             _end = attn_metadata.query_start_loc[_prefill_idx + 1]
             slot_id = state_indices_tensor[_prefill_idx]
@@ -443,6 +457,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
             hidden.append(
                 self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
                                    attn_metadata))
+
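+        # With no prefill or decode output, return an empty (0, q.size(-1))
+        # tensor instead of concatenating an empty list.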
+        if not hidden:
+            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
+
         hidden = torch.concat(hidden, dim=0).contiguous()
         return hidden
 
@@ -663,6 +681,9 @@ def __init__(
         self.shared_moe = False
 
         shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
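+        # shared_intermediate_size may be a per-layer list; select this
+        # layer's entry and fall back to 0 (no shared MLP) when out of range.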
+        if isinstance(shared_intermediate, list):
+            shared_intermediate = shared_intermediate[
+                layer_id] if layer_id < len(shared_intermediate) else 0
         if shared_intermediate > 0:
             self.shared_moe = True
             self.shared_mlp = MiniMaxText01MLP(
@@ -875,6 +896,8 @@ def _clear_prefill_cache(self, attn_metadata,
 
         slots_to_clear = []
         for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
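+            # Stop if seq_id_map holds fewer entries than num_prefills reports.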
+            if _prefill_id >= len(seq_id_map):
+                break
             seq_id = seq_id_map[_prefill_id]
             if attn_metadata.context_lens_tensor[
                     _prefill_id] == 0 and seq_id in seq_to_slot_maps:
@@ -886,13 +909,18 @@ def _clear_prefill_cache(self, attn_metadata,
                                         dtype=torch.long)
             minimax_cache_tensors[:, slots_tensor, ...] = 0
 
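+    # Expose the embedding lookup so the outer model class can delegate to it.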
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(self,
                 input_ids: Optional[torch.Tensor],
                 positions: torch.Tensor,
-                kv_caches: List[torch.Tensor],
-                intermediate_tensors=None,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
-                **kwargs) -> torch.Tensor:
+                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
         if attn_metadata is None:
@@ -901,6 +929,7 @@ def forward(self,
             kwargs["request_ids_to_seq_ids"] = {}
         if "finished_requests_ids" not in kwargs:
             kwargs["finished_requests_ids"] = []
+
         (
             minimax_cache_tensors,
             state_indices_tensor,
@@ -922,15 +951,11 @@ def forward(self,
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        kv_cache_index = 0
         minimax_cache_index = 0
         attn_metadata.rotary_emb = self.rotary_emb
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             _caches = None
-            if isinstance(layer.self_attn, MiniMaxText01Attention):
-                _caches = kv_caches[kv_cache_index]
-                kv_cache_index += 1
             if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
                 current_state_layer = minimax_cache_index
                 _caches = minimax_cache_params.at_layer_idx(
@@ -1009,15 +1034,20 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
             batch_size)
 
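+    # Delegate embedding lookup to the inner model.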
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, self.kv_cache,
-                                   intermediate_tensors, inputs_embeds,
-                                   **kwargs)
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds, **kwargs)
 
         return hidden_states
 
@@ -1043,8 +1073,9 @@ def make_empty_intermediate_tensors(
         })
 
     def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
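+        # Track every successfully loaded parameter name so it can be returned.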
+        loaded_params: Set[str] = set()
 
         def which_layer(name: str) -> int:
             if "layers" in name:
@@ -1108,6 +1139,7 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
                               weight_name,
                               expert_id=expert_id,
                               shard_id=shard_id)
+                loaded_params.add(name)
                 break
             else:
                 if is_pp_missing_parameter(name, self):
@@ -1117,6 +1149,7 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
                                         default_weight_loader)
                 weight_loader = weight_loader_with_alias(name)(weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
             return
 
         def is_shared_mlp_weight(name: str) -> bool:
@@ -1154,6 +1187,7 @@ def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor,
             else:
                 raise AssertionError(
                     "MLP weight not in [gate_up_proj, down_proj]")
+            loaded_params.add(name)
             return
 
         def is_mha_weight(name: str) -> bool:
@@ -1170,6 +1204,7 @@ def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor,
                 MiniMaxText01LinearAttention.weight_direct_load)
             weight_loader = weight_loader_with_alias(name)(weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
             return
 
         def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
@@ -1194,6 +1229,7 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
                                         default_weight_loader)
                 weight_loader = weight_loader_with_alias(name)(weight_loader)
                 weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
                 break
             else:
                 if is_pp_missing_parameter(name, self):
@@ -1204,6 +1240,7 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
                                         default_weight_loader)
                 weight_loader = weight_loader_with_alias(name)(weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
             return
 
         def is_layer_norm_weight(name: str) -> bool:
@@ -1219,6 +1256,7 @@ def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor,
                                     default_weight_loader)
             weight_loader = weight_loader_with_alias(name)(weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
             return
 
         def load_basic_weight(name: str, loaded_weight: torch.Tensor,
@@ -1230,6 +1268,7 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
                                     default_weight_loader)
             weight_loader = weight_loader_with_alias(name)(weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
             return
 
         for name, loaded_weight in weights:
@@ -1258,4 +1297,4 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
                 continue
 
             load_basic_weight(name, loaded_weight, self)
-        return
+        return loaded_params