
Commit 7418d8b

tc-mb authored and mglambda committed
llava : support Minicpm-omni (ggml-org#11289)
* init
* add readme
* update readme
* no use make
* update fix code
* fix editorconfig-checker
* no change convert py
* use clip_image_u8_free
1 parent 123214e commit 7418d8b

6 files changed: +100 -15 lines

examples/llava/README-minicpmo2.6.md (new file, +46)
## MiniCPM-o 2.6

Currently, this readme covers only MiniCPM-o 2.6's image capabilities; full omni-modal support will be added as soon as possible.

### Prepare models and code

Download the [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from Hugging Face into the "MiniCPM-o-2_6" folder.

Clone llama.cpp and check out the minicpm-omni branch:

```bash
git clone git@github.com:OpenBMB/llama.cpp.git
cd llama.cpp
git checkout minicpm-omni
```

### Usage of MiniCPM-o 2.6

Convert the PyTorch model to gguf files (you can also download the pre-converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) files):

```bash
python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6
python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4
python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model

# quantize int4 version
./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
```

Build llama.cpp using `CMake` (see https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md):

```bash
cmake -B build
cmake --build build --config Release
```

Inference on Linux or Mac:

```bash
# run f16 version
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"

# run quantized int4 version
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"

# or run in interactive mode
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
```

examples/llava/clip.cpp (+27 -2)
```diff
@@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             else if (ctx->minicpmv_version == 3) {
                 pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
             }
+            else if (ctx->minicpmv_version == 4) {
+                pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
+            }
             ggml_set_name(pos_embed, "pos_embed");
             ggml_set_input(pos_embed);
         }
@@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 n_head = hidden_size/d_head;
                 num_query = 64;
             }
+            else if (ctx->minicpmv_version == 4) {
+                hidden_size = 3584;
+                n_head = hidden_size/d_head;
+                num_query = 64;
+            }
 
             struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
             Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
```
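Both hunks treat `minicpmv_version == 4` (MiniCPM-o 2.6) exactly like version 3: a 3584-wide resampler with 64 learned queries. A minimal sketch of the shape arithmetic these branches set up — `d_head = 128` is assumed from the surrounding clip.cpp code, not part of this diff:

```cpp
#include <cstdio>

int main() {
    const int d_head      = 128;                   // per-head width; assumed from the surrounding clip.cpp code
    const int hidden_size = 3584;                  // versions 3 and 4 share this width
    const int n_head      = hidden_size / d_head;  // 28 attention heads
    const int num_query   = 64;                    // 64 learned queries in the resampler
    printf("n_head=%d, num_query=%d\n", n_head, num_query);
    return 0;
}
```

The 64 queries are also why `clip_n_patches_by_img` below reports 64 patches for version 4: each image slice leaves the resampler as exactly 64 embeddings.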
```diff
@@ -2041,6 +2049,7 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
                 images[images.size()-1].push_back(patch);
             }
         }
+        clip_image_u8_free(refine_image);
     }
     return images;
 }
@@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
                 clip_image_f32_free(res);
             }
         }
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            for (size_t j = 0; j < imgs[i].size(); ++j) {
+                if (imgs[i][j] != nullptr) {
+                    clip_image_u8_free(imgs[i][j]);
+                }
+            }
+        }
         return true;
     }
     else if (ctx->has_qwen2vl_merger) {
```
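These two hunks plug memory leaks in the MiniCPM-V image path: `uhd_slice_image` now frees its temporary `refine_image`, and `clip_image_preprocess` frees every `clip_image_u8` slice once its float conversion exists (the companion `delete[] img_res_v.data` in llava.cpp appears further down). A condensed sketch of the ownership rule being enforced, using hypothetical stand-in types rather than the real `clip_image_u8` API:

```cpp
#include <cstdio>
#include <vector>

struct image_u8  { std::vector<unsigned char> buf; };
struct image_f32 { std::vector<float> buf; };

static image_u8 * slice_patch(size_t n)         { return new image_u8{std::vector<unsigned char>(n)}; }
static void       image_u8_free(image_u8 * img) { delete img; }

static image_f32 to_f32(const image_u8 & u8) {
    image_f32 f{std::vector<float>(u8.buf.size())};
    for (size_t i = 0; i < u8.buf.size(); ++i) {
        f.buf[i] = u8.buf[i] / 255.0f;  // after conversion, the u8 copy is dead weight
    }
    return f;
}

int main() {
    std::vector<image_u8 *> slices = { slice_patch(16), slice_patch(16) };  // slicing heap-allocates u8 patches
    std::vector<image_f32>  converted;
    for (image_u8 * s : slices) {
        converted.push_back(to_f32(*s));
    }
    for (image_u8 * s : slices) {
        image_u8_free(s);  // the step the commit adds: release every u8 slice after conversion
    }
    printf("kept %zu f32 images\n", converted.size());
    return 0;
}
```

RAII wrappers (say, `std::unique_ptr` with a custom deleter) would make such leaks impossible, but the commit stays with the C-style init/free API that clip.h exposes.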
```diff
@@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         else if (ctx->minicpmv_version == 3) {
             n_patches = 64;
         }
+        else if (ctx->minicpmv_version == 4) {
+            n_patches = 64;
+        }
     } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         int patch_size = params.patch_size * 2;
         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
@@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
         struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
         int* positions_data = (int*)malloc(ggml_nbytes(positions));
-        int bucket_coords_h[70];
-        int bucket_coords_w[70];
+        int bucket_coords_h[1024];
+        int bucket_coords_w[1024];
         for (int i = 0; i < pos_h; i++){
             bucket_coords_h[i] = std::floor(70.0*i/pos_h);
         }
```
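The second hunk fixes a stack buffer overflow. The SigLIP-style position interpolation maps each patch row and column into one of 70 buckets, but the fill loops write one entry per row/column of the post-slice patch grid (`pos_h`, `pos_w`), which can exceed 70 for large sliced inputs; with 70-element arrays those writes ran past the end. A minimal sketch of the bucketing, showing why the array bound must track the grid size rather than the bucket count (a simplification; the real code lives in `clip_image_batch_encode`):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_buckets = 70;   // fixed bucket count from the SigLIP reference implementation
    const int pos_h     = 96;   // patch-grid height; may exceed n_buckets
    // One entry per grid row, so the buffer must hold pos_h values -- sizing it
    // to n_buckets (the old code's 70) overflows whenever pos_h > 70.
    std::vector<int> bucket_coords_h(pos_h);
    for (int i = 0; i < pos_h; i++) {
        bucket_coords_h[i] = (int) std::floor(n_buckets * 1.0 * i / pos_h);  // bucket index in [0, 70)
    }
    printf("rows=%d last bucket=%d\n", pos_h, bucket_coords_h[pos_h - 1]);
    return 0;
}
```

The commit keeps fixed-size arrays but enlarges them to 1024; a dynamically sized buffer as above would remove the hard cap entirely, at the cost of an allocation per encode.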
```diff
@@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         else if (ctx->minicpmv_version == 3) {
             embed_dim = 3584;
         }
+        else if (ctx->minicpmv_version == 4) {
+            embed_dim = 3584;
+        }
         auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
 
         float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
```
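Version 4 reuses version 3's 3584-dimensional table for the 2D sine-cosine position embedding fed to the resampler. For reference, a self-contained sketch of the standard 2D sin-cos construction that `get_2d_sincos_pos_embed` implements — a simplification, not the exact channel layout used in clip.cpp:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// 1D sin-cos table: position p -> [sin(p*w_0)..sin(p*w_{D/2-1}), cos(p*w_0)..],
// with w_k = 1 / 10000^(2k/D).
static std::vector<float> sincos_1d(int dim, float pos) {
    std::vector<float> out(dim);
    for (int k = 0; k < dim / 2; ++k) {
        float omega = 1.0f / std::pow(10000.0f, 2.0f * k / dim);
        out[k]           = std::sin(pos * omega);
        out[k + dim / 2] = std::cos(pos * omega);
    }
    return out;
}

int main() {
    const int embed_dim = 3584;  // versions 3 and 4 both use this width
    const int pos_w = 2, pos_h = 2;
    // 2D table: half the channels encode one grid axis, half the other.
    std::vector<std::vector<float>> pos_embed;
    for (int y = 0; y < pos_h; ++y) {
        for (int x = 0; x < pos_w; ++x) {
            std::vector<float> ex = sincos_1d(embed_dim / 2, (float) x);
            std::vector<float> ey = sincos_1d(embed_dim / 2, (float) y);
            ex.insert(ex.end(), ey.begin(), ey.end());
            pos_embed.push_back(ex);
        }
    }
    printf("%zu positions x %zu dims\n", pos_embed.size(), pos_embed[0].size());
    return 0;
}
```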
```diff
@@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         else if (ctx->minicpmv_version == 3) {
             return 3584;
         }
+        else if (ctx->minicpmv_version == 4) {
+            return 3584;
+        }
     }
     if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         return ctx->vision_model.mm_1_b->ne[0];
```

examples/llava/llava.cpp (+5 -8)
```diff
@@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
     return true;
 }
 
-static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) {
+static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) {
     int width = image->nx;
     int height = image->ny;
     int num_patches = (height / patch_size) * (width / patch_size);
@@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
             encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
         }
         else {
-            int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
-            if (has_minicpmv_projector == 2) {
-                encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
-            }
-            else if (has_minicpmv_projector == 3) {
-                encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
-            }
+            encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
         }
 
         if (!encoded) {
```
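The rename marks a behavioral change: previously only projector version 2 (MiniCPM-V 2.5) had its slices reshaped before encoding while version 3 was passed through untouched; now every MiniCPM version goes through `reshape_by_patch`. A simplified single-channel sketch of the patch regrouping idea — the real function additionally handles three interleaved channels and emits a different output layout:

```cpp
#include <cstdio>
#include <vector>

// Regroup a H x W image so the pixels of each patch_size x patch_size patch
// become one contiguous run (simplified: one channel, dimensions divisible
// by patch_size).
static std::vector<float> reshape_by_patch(const std::vector<float> & img,
                                           int width, int height, int patch) {
    std::vector<float> out;
    out.reserve(img.size());
    for (int py = 0; py < height; py += patch) {        // walk patches row-major
        for (int px = 0; px < width; px += patch) {
            for (int y = py; y < py + patch; ++y) {     // copy one patch contiguously
                for (int x = px; x < px + patch; ++x) {
                    out.push_back(img[y * width + x]);
                }
            }
        }
    }
    return out;
}

int main() {
    const int W = 4, H = 4, P = 2;
    std::vector<float> img(W * H);
    for (int i = 0; i < W * H; ++i) img[i] = (float) i;
    std::vector<float> out = reshape_by_patch(img, W, H, P);
    for (float v : out) printf("%.0f ", v);  // 0 1 4 5  2 3 6 7  8 9 12 13  10 11 14 15
    printf("\n");
    return 0;
}
```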
```diff
@@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
         load_image_size->height = img->ny;
         clip_add_load_image_size(ctx_clip, load_image_size);
         LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
+        delete[] img_res_v.data;
+        img_res_v.size = 0;
+        img_res_v.data = nullptr;
     }
     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
         // flat / default llava-1.5 type embedding
```

examples/llava/minicpmv-cli.cpp (+9 -1)
```diff
@@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
     else if (has_minicpmv_projector == 3) {
         system_prompt = "<|im_start|>user\n";
     }
+    else if (has_minicpmv_projector == 4) {
+        system_prompt = "<|im_start|>user\n";
+    }
     LOG_INF("%s: image token past: %d\n", __func__, n_past);
     eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
     process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
@@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
         else if (has_minicpmv_projector == 3) {
             user_prompt = "<|im_start|>user\n" + prompt;
         }
+        else if (has_minicpmv_projector == 4) {
+            user_prompt = "<|im_start|>user\n" + prompt;
+        }
     }
 
     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
@@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
     else if (has_minicpmv_projector == 3) {
         eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
     }
+    else if (has_minicpmv_projector == 4) {
+        eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
+    }
 
     // generate the response
```
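All three hunks add a version-4 branch that duplicates the version-3 string, since MiniCPM-V 2.6 and MiniCPM-o 2.6 share the same ChatML-style template. A hypothetical consolidation (not part of the commit) of how these branches wrap a user turn for versions 3 and up:

```cpp
#include <cstdio>
#include <string>

// Versions 3 and 4 share the same wrapping, so a single ">= 3" check would do.
static std::string wrap_user_turn(int minicpmv_version, const std::string & prompt) {
    if (minicpmv_version >= 3) {  // MiniCPM-V 2.6 and MiniCPM-o 2.6
        return "<|im_start|>user\n" + prompt + "<|im_end|><|im_start|>assistant\n";
    }
    return prompt;  // older versions use their own templates (see minicpmv-cli.cpp)
}

int main() {
    printf("%s", wrap_user_turn(4, "What is in the image?").c_str());
    return 0;
}
```

The commit instead keeps explicit per-version branches, matching the file's existing if/else-if pattern.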

```diff
@@ -308,7 +317,6 @@ int main(int argc, char ** argv) {
         const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
         response += tmp;
         if (strcmp(tmp, "</s>") == 0) break;
-        if (strstr(tmp, "###")) break; // Yi-VL behavior
         printf("%s", tmp);// mistral llava-1.6
         if (strstr(response.c_str(), "<user>")) break; // minicpm-v
         fflush(stdout);
```

Dropping the Yi-VL-specific "###" stop heuristic means generation no longer cuts off when a MiniCPM response happens to contain "###" (markdown headings, for instance).

examples/llava/minicpmv-convert-image-encoder-to-gguf.py (+12 -3)
```diff
@@ -501,7 +501,7 @@ def bytes_to_unicode():
 default_image_std = [0.26862954, 0.26130258, 0.27577711]
 ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
 ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
-ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2)
+ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2)
 
 # with proper
 args = ap.parse_args()
@@ -545,12 +545,19 @@ def bytes_to_unicode():
 
 minicpmv_version = args.minicpmv_version
 emb_dim = 4096
+block_count = 26
 if minicpmv_version == 1:
     emb_dim = 2304
+    block_count = 26
 elif minicpmv_version == 2:
     emb_dim = 4096
+    block_count = 27
 elif minicpmv_version == 3:
     emb_dim = 3584
+    block_count = 27
+elif minicpmv_version == 4:
+    emb_dim = 3584
+    block_count = 27
 
 default_vision_config = {
     "hidden_size": 1152,
@@ -567,6 +574,9 @@ def bytes_to_unicode():
 if minicpmv_version == 3:
     vision_config = SiglipVisionConfig(**default_vision_config)
     model = SiglipVisionTransformer(vision_config)
+elif minicpmv_version == 4:
+    vision_config = SiglipVisionConfig(**default_vision_config)
+    model = SiglipVisionTransformer(vision_config)
 
 processor = None
 # if model.attn_pool is not None:
```
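Collected in one place, the per-version constants that the branches above scatter through the script (values straight from the diff; a reference sketch in the project's implementation language):

```cpp
#include <cstdio>

struct minicpmv_params { int version; int emb_dim; int block_count; };

static const minicpmv_params k_params[] = {
    { 1, 2304, 26 },  // MiniCPM-V 2
    { 2, 4096, 27 },  // MiniCPM-V 2.5
    { 3, 3584, 27 },  // MiniCPM-V 2.6
    { 4, 3584, 27 },  // MiniCPM-o 2.6 (added by this commit)
};

int main() {
    for (const minicpmv_params & p : k_params) {
        printf("version %d: emb_dim=%d block_count=%d\n", p.version, p.emb_dim, p.block_count);
    }
    return 0;
}
```

Threading `block_count` through per version also fixes the gguf header: previously a hardcoded `block_count = 26` just before the write (removed in the last hunk below) applied to every model, while versions 2 through 4 actually have 27 vision blocks.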
```diff
@@ -587,7 +597,7 @@ def bytes_to_unicode():
     fname_middle = "mmproj-"
     has_text_encoder = False
     has_minicpmv_projector = True
-    minicpmv_version = 3
+    minicpmv_version = 4
 elif args.vision_only:
     fname_middle = "vision-"
     has_text_encoder = False
@@ -625,7 +635,6 @@ def bytes_to_unicode():
     fout.add_uint32("clip.vision.projection_dim", 0)
     fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
     fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
-    block_count = 26
     fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)
 
     if processor is not None:
```

examples/llava/minicpmv-surgery.py (+1 -1)
```diff
@@ -8,7 +8,7 @@
 args = ap.parse_args()
 
 # find the model part that includes the the multimodal projector weights
-model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
+model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16)
 checkpoint = model.state_dict()
 
 # get a list of mm tensor names
```

Loading the checkpoint in `torch.bfloat16` keeps the weights at half the memory footprint of float32 while the state dict is extracted.
