@@ -76,10 +76,10 @@ def _get_unpad_data(attention_mask):
 
 
 class CohereLayerNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, bias=False):
+    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
+        """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
@@ -89,8 +89,6 @@ def forward(self, hidden_states):
         variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
         hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
         hidden_states = self.weight.to(torch.float32) * hidden_states
-        if self.bias is not None:
-            hidden_states = hidden_states + self.bias.to(torch.float32)
         return hidden_states.to(input_dtype)
 
 
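For readers following the QKNorm part of this change, here is a small illustrative sketch (not part of the diff). It assumes the CohereLayerNorm class as defined in the two hunks above and uses made-up shapes; a tuple hidden_size gives a per-head weight that broadcasts over states kept in (batch, seq_len, num_heads, head_dim) layout.

import torch

# Assumes CohereLayerNorm from the hunks above is in scope.
num_heads, head_dim = 8, 64
norm = CohereLayerNorm(hidden_size=(num_heads, head_dim), eps=1e-5)

# Hypothetical query states, still in (batch, seq_len, num_heads, head_dim) layout.
q = torch.randn(2, 16, num_heads, head_dim, dtype=torch.float16)
q_normed = norm(q)

print(norm.weight.shape)  # torch.Size([8, 64]) -- one scale per (head, channel)
print(q_normed.shape)     # torch.Size([2, 16, 8, 64]); normalization is over the last dim
print(q_normed.dtype)     # torch.float16 -- forward casts back to the input dtype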
@@ -122,7 +120,7 @@ def forward(self, x, position_ids):
         emb = torch.repeat_interleave(freqs, 2, dim=-1)
         cos = emb.cos()
         sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        return cos, sin
 
 
 def rotate_half(x):
@@ -133,7 +131,6 @@ def rotate_half(x):
     return rot_x
 
 
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
@@ -154,11 +151,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    dtype = q.dtype
+    q = q.float()
+    k = k.float()
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
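A hedged usage sketch of the patched apply_rotary_pos_emb (illustrative, not from the PR): the rotation now runs in float32 and the result is cast back to the incoming query/key dtype. The cos/sin tables below are random stand-ins with assumed shapes.

import torch

# Assumes rotate_half and apply_rotary_pos_emb from the hunks above are in scope.
bsz, num_heads, seq_len, head_dim = 1, 8, 16, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=torch.float16)
k = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=torch.float16)

# The rotary embedding now returns cos/sin without the .to(x.dtype) cast, i.e. in float32.
cos = torch.randn(bsz, seq_len, head_dim)  # stand-in for the real cos table
sin = torch.randn(bsz, seq_len, head_dim)  # stand-in for the real sin table

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts over heads
print(q_rot.dtype, k_rot.dtype)  # torch.float16 torch.float16 -- restored after fp32 math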
@@ -192,7 +192,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaAttention Llama->Cohere
 class CohereAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -216,13 +215,21 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
+        self.use_qk_norm = config.use_qk_norm
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
 
+        if self.use_qk_norm:
+            # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
+            self.q_norm = CohereLayerNorm(hidden_size=(self.num_heads, self.head_dim), eps=config.layer_norm_eps)
+            self.k_norm = CohereLayerNorm(
+                hidden_size=(self.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
+            )
+
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
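To see the new flag end to end, a hedged construction sketch (it assumes a transformers build that already contains this change; the config values are small and arbitrary, not Command R+ defaults):

from transformers import CohereConfig
from transformers.models.cohere.modeling_cohere import CohereAttention

config = CohereConfig(
    hidden_size=512,
    num_attention_heads=8,
    num_key_value_heads=8,
    use_qk_norm=True,  # the flag read by CohereAttention.__init__ above
)
attn = CohereAttention(config, layer_idx=0)

print(attn.q_norm.weight.shape)  # torch.Size([8, 64])  -- (num_heads, head_dim)
print(attn.k_norm.weight.shape)  # torch.Size([8, 64])  -- (num_key_value_heads, head_dim)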
@@ -255,8 +262,14 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
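One note on the reordering above (my reading, not stated in the PR): the head reshape, then the optional q_norm/k_norm, and only then the transpose matter because the norm weight is shaped (num_heads, head_dim) and broadcasts against the last two dims of the pre-transpose layout. A tiny stand-alone check with made-up shapes:

import torch

bsz, q_len, num_heads, head_dim = 2, 5, 8, 64
weight = torch.ones(num_heads, head_dim)             # same shape as the QK-norm weight

pre = torch.randn(bsz, q_len, num_heads, head_dim)   # layout when q_norm/k_norm run
print((weight * pre).shape)                          # broadcasts: torch.Size([2, 5, 8, 64])

post = pre.transpose(1, 2)                           # (bsz, num_heads, q_len, head_dim)
try:
    weight * post                                    # trailing dims (q_len, head_dim) mismatch
except RuntimeError as err:
    print("after transpose the per-head weight no longer broadcasts:", err)

The same view / norm / transpose pattern is applied in the flash-attention and SDPA forwards below.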
@@ -335,11 +348,14 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -505,7 +521,7 @@ class CohereSdpaAttention(CohereAttention):
     SDPA API.
     """
 
-    # Adapted from CohereAttention.forward
+    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -538,8 +554,14 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -599,7 +621,7 @@ def __init__(self, config: CohereConfig, layer_idx: int):
         self.self_attn = COHERE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = CohereMLP(config)
-        self.input_layernorm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -822,7 +844,7 @@ def __init__(self, config: CohereConfig):
         self.layers = nn.ModuleList(
             [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
-        self.norm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing