
Commit 517a3e6

Refactor Cohere Model (#30027)
* changes
* addressing comments
* smol fix
1 parent 75b76a5 commit 517a3e6

2 files changed (+46, -20 lines)


src/transformers/models/cohere/configuration_cohere.py

Lines changed: 4 additions & 0 deletions
@@ -85,6 +85,8 @@ class CohereConfig(PretrainedConfig):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        use_qk_norm (`bool`, *optional*, defaults to `False`):
+            Whether to use query-key normalization in the attention
 
     ```python
     >>> from transformers import CohereModel, CohereConfig
@@ -123,6 +125,7 @@ def __init__(
         rope_theta=10000.0,
         attention_bias=False,
         attention_dropout=0.0,
+        use_qk_norm=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -145,6 +148,7 @@ def __init__(
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.use_qk_norm = use_qk_norm
 
         super().__init__(
             pad_token_id=pad_token_id,
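
For context, a minimal usage sketch of the new option from the config side. It assumes a transformers build that includes this commit; the tiny size values are made up purely for illustration (use_qk_norm stays off by default).

from transformers import CohereConfig, CohereModel

# Hypothetical, deliberately small sizes so the model builds quickly;
# only use_qk_norm=True exercises the behavior added in this commit.
config = CohereConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    max_position_embeddings=128,
    use_qk_norm=True,
)
model = CohereModel(config)
print(model.config.use_qk_norm)  # True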

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 42 additions & 20 deletions
@@ -76,10 +76,10 @@ def _get_unpad_data(attention_mask):
 
 
 class CohereLayerNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, bias=False):
+    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
+        """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
@@ -89,8 +89,6 @@ def forward(self, hidden_states):
         variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
         hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
         hidden_states = self.weight.to(torch.float32) * hidden_states
-        if self.bias is not None:
-            hidden_states = hidden_states + self.bias.to(torch.float32)
         return hidden_states.to(input_dtype)
 
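The tuple form of hidden_size introduced above is what the new q/k norms rely on: the weight becomes a per-head scale of shape (num_heads, head_dim), while the mean and variance are still computed over the last axis only. A small sketch, with made-up sizes and the import path taken from this file (assumes a build containing this commit):

import torch
from transformers.models.cohere.modeling_cohere import CohereLayerNorm

bsz, q_len, num_heads, head_dim = 2, 5, 8, 64  # hypothetical sizes

# Tuple hidden_size -> weight of shape (num_heads, head_dim); statistics
# are still taken over head_dim, the last axis of the input.
qk_norm = CohereLayerNorm(hidden_size=(num_heads, head_dim), eps=1e-5)
query_states = torch.randn(bsz, q_len, num_heads, head_dim)

print(qk_norm(query_states).shape)  # torch.Size([2, 5, 8, 64])
print(qk_norm.weight.shape)         # torch.Size([8, 64])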

@@ -122,7 +120,7 @@ def forward(self, x, position_ids):
         emb = torch.repeat_interleave(freqs, 2, dim=-1)
         cos = emb.cos()
         sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        return cos, sin
 
 
 def rotate_half(x):
@@ -133,7 +131,6 @@ def rotate_half(x):
     return rot_x
 
 
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
@@ -154,11 +151,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    dtype = q.dtype
+    q = q.float()
+    k = k.float()
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
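
With the dtype cast removed from the rotary embedding and the upcast moved into apply_rotary_pos_emb, the rotation itself now runs in float32 while callers still get tensors back in their original dtype. A quick sketch with hypothetical shapes (import path taken from this file):

import torch
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb

bsz, num_heads, q_len, head_dim = 1, 2, 4, 8  # hypothetical shapes

q = torch.randn(bsz, num_heads, q_len, head_dim, dtype=torch.float16)
k = torch.randn(bsz, num_heads, q_len, head_dim, dtype=torch.float16)
cos = torch.randn(bsz, q_len, head_dim)  # float32 stand-in for the rotary embedding output
sin = torch.randn(bsz, q_len, head_dim)

# q/k are upcast to float32 for the rotation, then cast back on return.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.dtype, k_rot.dtype)  # torch.float16 torch.float16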
@@ -192,7 +192,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaAttention Llama->Cohere
 class CohereAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -216,13 +215,21 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
+        self.use_qk_norm = config.use_qk_norm
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
 
+        if self.use_qk_norm:
+            # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
+            self.q_norm = CohereLayerNorm(hidden_size=(self.num_heads, self.head_dim), eps=config.layer_norm_eps)
+            self.k_norm = CohereLayerNorm(
+                hidden_size=(self.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
+            )
+
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
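
To see where the new sub-modules land, a hedged sketch that builds the attention block directly from a made-up tiny config (assumes a build containing this change; head_dim works out to hidden_size // num_attention_heads = 16):

from transformers import CohereConfig
from transformers.models.cohere.modeling_cohere import CohereAttention

# Hypothetical tiny config, only for inspecting the new q_norm / k_norm modules.
config = CohereConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4, use_qk_norm=True)
attn = CohereAttention(config, layer_idx=0)

print(attn.q_norm.weight.shape)  # torch.Size([4, 16]) -> (num_heads, head_dim)
print(attn.k_norm.weight.shape)  # torch.Size([4, 16]) -> (num_key_value_heads, head_dim)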
@@ -255,8 +262,14 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
@@ -335,11 +348,14 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -505,7 +521,7 @@ class CohereSdpaAttention(CohereAttention):
     SDPA API.
     """
 
-    # Adapted from CohereAttention.forward
+    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -538,8 +554,14 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states)
+            key_states = self.k_norm(key_states)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -599,7 +621,7 @@ def __init__(self, config: CohereConfig, layer_idx: int):
         self.self_attn = COHERE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = CohereMLP(config)
-        self.input_layernorm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -822,7 +844,7 @@ def __init__(self, config: CohereConfig):
         self.layers = nn.ModuleList(
             [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
