Skip to content

Commit 191c62b

Browse files
committed
Add Command-R Model
Information about the Command-R 35B model can be found at: https://huggingface.co/CohereForAI/c4ai-command-r-v01 Based on the llama2 model with a few changes: 1) New hyper parameter to scale output logits (logit_scale) 2) Uses LayerNorm instead of RMSNorm 3) Transfomer layers have a single shared LayerNorm that feeds into both the self-attention and FFN layers in parallel. There is no post-attention LayerNorm. 4) No support for Rotary Position Embeddings (RoPE) scaling To convert model to GGUF format: 1) Download Command-R Hugging Face safetensors: git lfs install git clone https://huggingface.co/CohereForAI/c4ai-command-r-v01 2) Run: python3 convert-hf-to-gguf.py --outtype f16 ./c4ai-command-r-v01
1 parent 306d34b commit 191c62b

File tree

5 files changed

+216
-0
lines changed

5 files changed

+216
-0
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ Typically finetunes of the base models below are supported as well.
110110
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
111111
- [x] [Gemma](https://ai.google.dev/gemma)
112112
- [x] [Mamba](https://github.com/state-spaces/mamba)
113+
- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
113114

114115
**Multimodal models:**
115116

convert-hf-to-gguf.py

+10
Original file line numberDiff line numberDiff line change
@@ -1965,6 +1965,16 @@ def write_tensors(self):
19651965
self.gguf_writer.add_tensor(new_name, data)
19661966

19671967

1968+
@Model.register("CohereForCausalLM")
1969+
class CommandR2Model(Model):
1970+
model_arch = gguf.MODEL_ARCH.COMMAND_R
1971+
1972+
def set_gguf_parameters(self):
1973+
super().set_gguf_parameters()
1974+
self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
1975+
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
1976+
1977+
19681978
###### CONVERSION LOGIC ######
19691979

19701980

gguf-py/gguf/constants.py

+15
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class LLM:
4141
EXPERT_COUNT = "{arch}.expert_count"
4242
EXPERT_USED_COUNT = "{arch}.expert_used_count"
4343
POOLING_TYPE = "{arch}.pooling_type"
44+
LOGIT_SCALE = "{arch}.logit_scale"
4445

4546
class Attention:
4647
HEAD_COUNT = "{arch}.attention.head_count"
@@ -120,6 +121,7 @@ class MODEL_ARCH(IntEnum):
120121
GEMMA = auto()
121122
STARCODER2 = auto()
122123
MAMBA = auto()
124+
COMMAND_R = auto()
123125

124126

125127
class MODEL_TENSOR(IntEnum):
@@ -186,6 +188,7 @@ class MODEL_TENSOR(IntEnum):
186188
MODEL_ARCH.GEMMA: "gemma",
187189
MODEL_ARCH.STARCODER2: "starcoder2",
188190
MODEL_ARCH.MAMBA: "mamba",
191+
MODEL_ARCH.COMMAND_R: "command-r",
189192
}
190193

191194
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -578,6 +581,18 @@ class MODEL_TENSOR(IntEnum):
578581
MODEL_TENSOR.SSM_D,
579582
MODEL_TENSOR.SSM_OUT,
580583
],
584+
MODEL_ARCH.COMMAND_R: [
585+
MODEL_TENSOR.TOKEN_EMBD,
586+
MODEL_TENSOR.OUTPUT_NORM,
587+
MODEL_TENSOR.ATTN_NORM,
588+
MODEL_TENSOR.ATTN_Q,
589+
MODEL_TENSOR.ATTN_K,
590+
MODEL_TENSOR.ATTN_V,
591+
MODEL_TENSOR.ATTN_OUT,
592+
MODEL_TENSOR.FFN_GATE,
593+
MODEL_TENSOR.FFN_DOWN,
594+
MODEL_TENSOR.FFN_UP,
595+
],
581596
# TODO
582597
}
583598

gguf-py/gguf/gguf_writer.py

+3
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ def add_max_alibi_bias(self, bias: float) -> None:
346346
def add_clamp_kqv(self, value: float) -> None:
347347
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
348348

349+
def add_logit_scale(self, value: float) -> None:
350+
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
351+
349352
def add_expert_count(self, count: int) -> None:
350353
self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
351354

llama.cpp

+187
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ enum llm_arch {
214214
LLM_ARCH_GEMMA,
215215
LLM_ARCH_STARCODER2,
216216
LLM_ARCH_MAMBA,
217+
LLM_ARCH_COMMAND_R,
217218
LLM_ARCH_UNKNOWN,
218219
};
219220

@@ -243,6 +244,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
243244
{ LLM_ARCH_GEMMA, "gemma" },
244245
{ LLM_ARCH_STARCODER2, "starcoder2" },
245246
{ LLM_ARCH_MAMBA, "mamba" },
247+
{ LLM_ARCH_COMMAND_R, "command-r" },
246248
{ LLM_ARCH_UNKNOWN, "(unknown)" },
247249
};
248250

@@ -267,6 +269,7 @@ enum llm_kv {
267269
LLM_KV_EXPERT_COUNT,
268270
LLM_KV_EXPERT_USED_COUNT,
269271
LLM_KV_POOLING_TYPE,
272+
LLM_KV_LOGIT_SCALE,
270273

271274
LLM_KV_ATTENTION_HEAD_COUNT,
272275
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -330,6 +333,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
330333
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
331334
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
332335
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
336+
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
333337

334338
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
335339
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -836,6 +840,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
836840
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
837841
},
838842
},
843+
{
844+
LLM_ARCH_COMMAND_R,
845+
{
846+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
847+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
848+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
849+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
850+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
851+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
852+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
853+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
854+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
855+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
856+
},
857+
},
839858
{
840859
LLM_ARCH_UNKNOWN,
841860
{
@@ -1610,6 +1629,7 @@ enum e_model {
16101629
MODEL_20B,
16111630
MODEL_30B,
16121631
MODEL_34B,
1632+
MODEL_35B,
16131633
MODEL_40B,
16141634
MODEL_65B,
16151635
MODEL_70B,
@@ -1656,6 +1676,7 @@ struct llama_hparams {
16561676

16571677
float f_clamp_kqv = 0.0f;
16581678
float f_max_alibi_bias = 0.0f;
1679+
float f_logit_scale = 0.0f;
16591680

16601681
bool causal_attn = true;
16611682
bool need_kq_pos = false;
@@ -3237,6 +3258,7 @@ static const char * llama_model_type_name(e_model type) {
32373258
case MODEL_20B: return "20B";
32383259
case MODEL_30B: return "30B";
32393260
case MODEL_34B: return "34B";
3261+
case MODEL_35B: return "35B";
32403262
case MODEL_40B: return "40B";
32413263
case MODEL_65B: return "65B";
32423264
case MODEL_70B: return "70B";
@@ -3628,6 +3650,15 @@ static void llm_load_hparams(
36283650
default: model.type = e_model::MODEL_UNKNOWN;
36293651
}
36303652
} break;
3653+
case LLM_ARCH_COMMAND_R:
3654+
{
3655+
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
3656+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
3657+
switch (hparams.n_layer) {
3658+
case 40: model.type = e_model::MODEL_35B; break;
3659+
default: model.type = e_model::MODEL_UNKNOWN;
3660+
}
3661+
} break;
36313662
default: (void)0;
36323663
}
36333664

@@ -3937,6 +3968,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
39373968
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
39383969
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
39393970
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
3971+
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
39403972
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
39413973
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
39423974
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
@@ -4910,6 +4942,37 @@ static bool llm_load_tensors(
49104942
layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
49114943
}
49124944
} break;
4945+
case LLM_ARCH_COMMAND_R:
4946+
{
4947+
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
4948+
4949+
// output
4950+
{
4951+
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
4952+
// init output from the input tok embed
4953+
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
4954+
ml.n_created--; // artificial tensor
4955+
ml.size_data += ggml_nbytes(model.output);
4956+
}
4957+
4958+
for (int i = 0; i < n_layer; ++i) {
4959+
ggml_context * ctx_layer = ctx_for_layer(i);
4960+
ggml_context * ctx_split = ctx_for_layer_split(i);
4961+
4962+
auto & layer = model.layers[i];
4963+
4964+
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
4965+
4966+
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
4967+
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
4968+
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
4969+
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
4970+
4971+
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
4972+
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
4973+
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
4974+
}
4975+
} break;
49134976
default:
49144977
throw std::runtime_error("unknown architecture");
49154978
}
@@ -8302,6 +8365,125 @@ struct llm_build_context {
83028365

83038366
return gf;
83048367
}
8368+
8369+
struct ggml_cgraph * build_command_r() {
8370+
8371+
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
8372+
8373+
const int64_t n_embd_head = hparams.n_embd_head_v;
8374+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
8375+
const float f_logit_scale = hparams.f_logit_scale;
8376+
8377+
struct ggml_tensor * cur;
8378+
struct ggml_tensor * inpL;
8379+
8380+
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
8381+
cb(inpL, "inp_embd", -1);
8382+
8383+
// inp_pos - contains the positions
8384+
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
8385+
cb(inp_pos, "inp_pos", -1);
8386+
8387+
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
8388+
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
8389+
cb(KQ_mask, "KQ_mask", -1);
8390+
8391+
for (int il = 0; il < n_layer; ++il) {
8392+
8393+
// norm
8394+
cur = llm_build_norm(ctx0, inpL, hparams,
8395+
model.layers[il].attn_norm, NULL,
8396+
LLM_NORM, cb, il);
8397+
cb(cur, "attn_norm", il);
8398+
struct ggml_tensor * ffn_inp = cur;
8399+
8400+
// self-attention
8401+
{
8402+
// compute Q and K and RoPE them
8403+
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
8404+
cb(Qcur, "Qcur", il);
8405+
if (model.layers[il].bq) {
8406+
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
8407+
cb(Qcur, "Qcur", il);
8408+
}
8409+
8410+
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
8411+
cb(Kcur, "Kcur", il);
8412+
if (model.layers[il].bk) {
8413+
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
8414+
cb(Kcur, "Kcur", il);
8415+
}
8416+
8417+
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
8418+
cb(Vcur, "Vcur", il);
8419+
if (model.layers[il].bv) {
8420+
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
8421+
cb(Vcur, "Vcur", il);
8422+
}
8423+
8424+
Qcur = ggml_rope_custom(
8425+
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
8426+
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
8427+
ext_factor, attn_factor, beta_fast, beta_slow
8428+
);
8429+
cb(Qcur, "Qcur", il);
8430+
8431+
Kcur = ggml_rope_custom(
8432+
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
8433+
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
8434+
ext_factor, attn_factor, beta_fast, beta_slow
8435+
);
8436+
cb(Kcur, "Kcur", il);
8437+
8438+
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
8439+
model.layers[il].wo, model.layers[il].bo,
8440+
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
8441+
cb(cur, "kqv_out", il);
8442+
}
8443+
8444+
struct ggml_tensor * attn_out = cur;
8445+
8446+
// feed-forward network
8447+
{
8448+
cur = llm_build_ffn(ctx0, ffn_inp,
8449+
model.layers[il].ffn_up, NULL,
8450+
model.layers[il].ffn_gate, NULL,
8451+
model.layers[il].ffn_down, NULL,
8452+
NULL,
8453+
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
8454+
cb(cur, "ffn_out", il);
8455+
}
8456+
8457+
// add together residual + FFN + self-attention
8458+
cur = ggml_add(ctx0, cur, inpL);
8459+
cur = ggml_add(ctx0, cur, attn_out);
8460+
cb(cur, "l_out", il);
8461+
8462+
// input for next layer
8463+
inpL = cur;
8464+
}
8465+
8466+
cur = inpL;
8467+
8468+
cur = llm_build_norm(ctx0, cur, hparams,
8469+
model.output_norm, NULL,
8470+
LLM_NORM, cb, -1);
8471+
cb(cur, "result_norm", -1);
8472+
8473+
// lm_head
8474+
cur = ggml_mul_mat(ctx0, model.output, cur);
8475+
8476+
if (f_logit_scale) {
8477+
cur = ggml_scale(ctx0, cur, f_logit_scale);
8478+
}
8479+
8480+
cb(cur, "result_output", -1);
8481+
8482+
ggml_build_forward_expand(gf, cur);
8483+
8484+
return gf;
8485+
8486+
}
83058487
};
83068488

83078489
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -8473,6 +8655,10 @@ static struct ggml_cgraph * llama_build_graph(
84738655
{
84748656
result = llm.build_mamba();
84758657
} break;
8658+
case LLM_ARCH_COMMAND_R:
8659+
{
8660+
result = llm.build_command_r();
8661+
} break;
84768662
default:
84778663
GGML_ASSERT(false);
84788664
}
@@ -13053,6 +13239,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
1305313239
case LLM_ARCH_ORION:
1305413240
case LLM_ARCH_INTERNLM2:
1305513241
case LLM_ARCH_MINICPM:
13242+
case LLM_ARCH_COMMAND_R:
1305613243
return LLAMA_ROPE_TYPE_NORM;
1305713244

1305813245
// the pairs of head values are offset by n_rot/2

0 commit comments

Comments
 (0)