From d04ee928a24df14dda233132ddc008ae838e4ccb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 3 Dec 2023 21:31:05 +0200 Subject: [PATCH 01/14] llama : support quantum K cache (wip) --- ggml-metal.m | 2 +- llama.cpp | 30 +++++++++++++++++++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 3343bc8a3af37..c24e0fe206036 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1114,7 +1114,7 @@ void ggml_metal_graph_compute( !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && ne00 % 32 == 0 && ne00 >= 64 && - ne11 > ne11_mm_min) { + (ne11 > ne11_mm_min || ne12 > 1)) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); switch (src0->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; diff --git a/llama.cpp b/llama.cpp index d23a14469a0f0..04d524fdebed7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1522,7 +1522,8 @@ struct llama_context { static bool llama_kv_cache_init( const struct llama_hparams & hparams, struct llama_kv_cache & cache, - ggml_type wtype, + ggml_type ktype, + ggml_type vtype, uint32_t n_ctx, int n_gpu_layers, bool offload) { @@ -1541,7 +1542,7 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(n_ctx); - cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*n_layer*ggml_tensor_overhead()); + cache.buf.resize(n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype)) + 2u*n_layer*ggml_tensor_overhead()); memset(cache.buf.data, 0, cache.buf.size); struct ggml_init_params params; @@ -1566,8 +1567,8 @@ static bool llama_kv_cache_init( GGML_UNUSED(offload); for (int i = 0; i < (int) n_layer; i++) { - ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, wtype, n_embd*n_ctx); - ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, wtype, n_embd*n_ctx); + ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd*n_ctx); + ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, vtype, n_embd*n_ctx); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); @@ -3558,8 +3559,8 @@ static void llm_build_k_shift( ggml_rope_custom_inplace(ctx, ggml_view_3d(ctx, kv.k_l[il], n_embd_head, n_head_kv, n_ctx, - ggml_element_size(kv.k_l[il])*n_embd_head, - ggml_element_size(kv.k_l[il])*n_embd_gqa, + ggml_type_sizef(kv.k_l[il]->type)*n_embd_head, + ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa, 0), K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -3588,7 +3589,7 @@ static void llm_build_kv_store( cb(v_cur_t, "v_cur_t", il); struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa, - (ggml_element_size(kv.k_l[il])*n_embd_gqa)*kv_head); + (ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head); cb(k_cache_view, "k_cache_view", il); struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa, @@ -3747,8 +3748,8 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il], n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv.k_l[il])*n_embd_gqa, - ggml_element_size(kv.k_l[il])*n_embd_head, + ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa, + ggml_type_sizef(kv.k_l[il]->type)*n_embd_head, 0); cb(k, "k", il); @@ -8734,11 +8735,18 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; + //const ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; + + // TODO: move as params + const ggml_type k_type = GGML_TYPE_Q4_0; + const ggml_type v_type = GGML_TYPE_F16; + + GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(k_type) == 0); + GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(v_type) == 0); // reserve memory for context buffers if (!hparams.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, k_type, v_type, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; From bcfebf241ddda20b08a5a6c2cefba9768a037748 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 4 Dec 2023 10:42:10 +0200 Subject: [PATCH 02/14] metal : add F32 -> Q8_0 copy kernel --- ggml-metal.m | 14 +++++++++--- ggml-metal.metal | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ llama.cpp | 2 +- 3 files changed, 70 insertions(+), 4 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index c24e0fe206036..d9bb9211ecf77 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -118,6 +118,7 @@ GGML_METAL_DECL_KERNEL(im2col_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); + GGML_METAL_DECL_KERNEL(cpy_f32_q8_0); GGML_METAL_DECL_KERNEL(cpy_f16_f16); GGML_METAL_DECL_KERNEL(concat); GGML_METAL_DECL_KERNEL(sqr); @@ -324,6 +325,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(im2col_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); + GGML_METAL_ADD_KERNEL(cpy_f32_q8_0); GGML_METAL_ADD_KERNEL(cpy_f16_f16); GGML_METAL_ADD_KERNEL(concat); GGML_METAL_ADD_KERNEL(sqr); @@ -425,6 +427,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(im2col_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); + GGML_METAL_DEL_KERNEL(cpy_f32_q8_0); GGML_METAL_DEL_KERNEL(cpy_f16_f16); GGML_METAL_DEL_KERNEL(concat); GGML_METAL_DEL_KERNEL(sqr); @@ -1549,14 +1552,19 @@ void ggml_metal_graph_compute( case GGML_OP_CPY: case GGML_OP_CONT: { - const int nth = MIN(1024, ne00); + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); switch (src0t) { case GGML_TYPE_F32: { + GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); + switch (dstt) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; + case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break; default: GGML_ASSERT(false && "not implemented"); }; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 9a79f815f3a72..468689ed931a6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1460,6 +1460,64 @@ kernel void kernel_cpy_f32_f32( } } +kernel void kernel_cpy_f32_q8_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; + + device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK8_0].d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst_data[i00/QK8_0].qs[j] = round(x0); + } + } +} + kernel void kernel_concat( device const char * src0, device const char * src1, diff --git a/llama.cpp b/llama.cpp index 04d524fdebed7..ca23408b42735 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8738,7 +8738,7 @@ struct llama_context * llama_new_context_with_model( //const ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // TODO: move as params - const ggml_type k_type = GGML_TYPE_Q4_0; + const ggml_type k_type = GGML_TYPE_Q8_0; const ggml_type v_type = GGML_TYPE_F16; GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(k_type) == 0); From a1bf6c09f8ce04e374d71c0c640eacee14767f35 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 4 Dec 2023 15:08:36 +0200 Subject: [PATCH 03/14] cuda : add F32 -> Q8_0 copy kernel ggml-ci --- ggml-cuda.cu | 93 ++++++++++++++++++++++++++++++++++++++++++---------- llama.cpp | 4 +-- 2 files changed, 78 insertions(+), 19 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 9019a849f0bff..3ad0d305d3403 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4559,6 +4559,53 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } +static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q8_0 * dsti = (block_q8_0 *) cdsti; + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = xi[j]; + amax = fmaxf(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dsti->d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = xi[j]*id; + + dsti->qs[j] = roundf(x0); + } +} + +// TODO: generalize for all quants +template +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { + const int i = (blockDim.x*blockIdx.x + threadIdx.x)*QK8_0; + + if (i >= ne) { + return; + } + + const int i02 = i / (ne00*ne01); + const int i01 = (i - i02*ne01*ne00) / ne00; + const int i00 = (i - i02*ne01*ne00 - i01*ne00); + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; + + const int i12 = i / (ne10*ne11); + const int i11 = (i - i12*ne10*ne11) / ne10; + const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -5737,6 +5784,17 @@ static void ggml_cpy_f32_f16_cuda( (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } +static void ggml_cpy_f32_q8_0_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + GGML_ASSERT(ne % QK8_0 == 0); + const int num_blocks = ne / QK8_0; + cpy_f32_q<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + static void ggml_cpy_f16_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, @@ -6093,20 +6151,21 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); const int64_t bs = ggml_blck_size(type); - int64_t i1_diff = i1_high - i1_low; + const int64_t i1_diff = i1_high - i1_low; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; - if (nb0 == ts && nb1 == ts*ne0/bs) { + if (nb0 == ts && nb1 == ts*(ne0/bs)) { return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream); } if (nb0 == ts) { - return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream); + return cudaMemcpy2DAsync(dst_ptr, ts*(ne0/bs), x, nb1, ts*(ne0/bs), i1_diff, kind, stream); } + GGML_ASSERT(bs == 1 && "TODO: implement bs != 1"); for (int64_t i1 = 0; i1 < i1_diff; i1++) { const void * rx = (const void *) ((const char *) x + i1*nb1); - void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + void * rd = (void *) (dst_ptr + i1*ts*ne0); // pretend the row is a matrix with cols=1 - cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream); + cudaError_t r = cudaMemcpy2DAsync(rd, ts, rx, nb0, ts, ne0, kind, stream); if (r != cudaSuccess) { return r; } } return cudaSuccess; @@ -6533,7 +6592,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( size_t ash; dfloat * src1_dfloat = nullptr; // dfloat == half - bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || + bool src1_convert_f16 = + src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; @@ -7103,10 +7163,9 @@ static void ggml_cuda_op_mul_mat( const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; const bool src0_is_contiguous = ggml_is_contiguous(src0); - const bool src1_is_contiguous = ggml_is_contiguous(src1); - const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ? - ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; + + const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; GGML_ASSERT(!(split && ne02 > 1)); @@ -7231,7 +7290,7 @@ static void ggml_cuda_op_mul_mat( const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; // for split tensors the data begins at i0 == i0_offset_low - char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs; + char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10; char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset; float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); @@ -7694,7 +7753,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { + if (ggml_nrows(src1) == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { #ifdef GGML_CUDA_FORCE_DMMV const bool use_mul_mat_vec_q = false; #else @@ -7770,14 +7829,13 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg char * src1_ddc = (char *) src1_extra->data_device[g_main_device]; if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); @@ -7788,6 +7846,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg } static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + // TODO: why do we pass dst as src1 here? ggml_cuda_cpy(src0, dst, nullptr); (void) src1; } diff --git a/llama.cpp b/llama.cpp index ca23408b42735..a70e40dba7838 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1246,7 +1246,6 @@ struct llama_cparams { bool mul_mat_q; bool offload_kqv; - }; struct llama_layer { @@ -1562,7 +1561,7 @@ static bool llama_kv_cache_init( cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); - const int i_gpu_start = n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start); + const int i_gpu_start = (int) n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start); GGML_UNUSED(offload); @@ -5696,6 +5695,7 @@ static int llama_decode_internal( // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + //kv_self.n = llama_kv_cache_cell_max(kv_self); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); From b881f630ca9d18cc10e9600ab42c1b7fe2e8b31e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 4 Dec 2023 15:41:20 +0200 Subject: [PATCH 04/14] cuda : use mmv kernel for quantum cache ops --- ggml-cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3ad0d305d3403..16d17f8013502 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6533,6 +6533,8 @@ inline void ggml_cuda_op_mul_mat_vec_q( const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, const cudaStream_t & stream) { + GGML_ASSERT(ggml_nrows(src1) == 1); + const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; @@ -7753,11 +7755,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - if (ggml_nrows(src1) == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { + if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { #ifdef GGML_CUDA_FORCE_DMMV const bool use_mul_mat_vec_q = false; #else - const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); + const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1; #endif // GGML_CUDA_FORCE_DMMV if (use_mul_mat_vec_q) { From 3ce30e07c98a8cf5ce7a22a866d4d0fd5436216f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 15:40:23 +0200 Subject: [PATCH 05/14] llama : pass KV cache type through API --- common/common.cpp | 34 ++++++++++++++++++++++++++++++++++ common/common.h | 5 ++++- llama.cpp | 29 ++++++++++++++++++----------- llama.h | 3 +++ 4 files changed, 59 insertions(+), 12 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 43c374d5ce936..77332d5dbb77b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -500,6 +500,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.dump_kv_cache = true; } else if (arg == "-nkvo" || arg == "--no-kv-offload") { params.no_kv_offload = true; + } else if (arg == "-ctk" || arg == "--cache-type-k") { + params.cache_type_k = argv[++i]; + } else if (arg == "-ctv" || arg == "--cache-type-v") { + params.cache_type_v = argv[++i]; } else if (arg == "--multiline-input") { params.multiline_input = true; } else if (arg == "--simple-io") { @@ -844,6 +848,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" verbose print of the KV cache\n"); printf(" -nkvo, --no-kv-offload\n"); printf(" disable KV offload\n"); + printf(" -ctk TYPE, --cache-type-k TYPE\n"); + printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str()); + printf(" -ctv TYPE, --cache-type-v TYPE\n"); + printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str()); printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); @@ -908,6 +916,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & return mparams; } +static ggml_type kv_cache_type_from_str(const std::string & s) { + if (s == "f16") { + return GGML_TYPE_F16; + } + if (s == "q8_0") { + return GGML_TYPE_Q8_0; + } + if (s == "q4_0") { + return GGML_TYPE_Q4_0; + } + if (s == "q4_1") { + return GGML_TYPE_Q4_1; + } + if (s == "q5_0") { + return GGML_TYPE_Q5_0; + } + if (s == "q5_1") { + return GGML_TYPE_Q5_1; + } + + throw std::runtime_error("Invalid cache type: " + s); +} + struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto cparams = llama_context_default_params(); @@ -930,6 +961,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.offload_kqv = !params.no_kv_offload; + cparams.type_k = kv_cache_type_from_str(params.cache_type_k); + cparams.type_v = kv_cache_type_from_str(params.cache_type_v); + return cparams; } diff --git a/common/common.h b/common/common.h index 2664c8fc175b7..7f0d03e41dbd9 100644 --- a/common/common.h +++ b/common/common.h @@ -125,9 +125,12 @@ struct gpt_params { bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool no_kv_offload = false; // disable KV offloading + std::string cache_type_k = "f16"; // KV cache data type for the K + std::string cache_type_v = "f16"; // KV cache data type for the V + // multimodal models (see examples/llava) std::string mmproj = ""; // path to multimodal projector - std::string image = ""; // path to an image file + std::string image = ""; // path to an image file }; bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); diff --git a/llama.cpp b/llama.cpp index a70e40dba7838..3f951dbe31952 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8580,6 +8580,8 @@ struct llama_context_params llama_context_default_params() { /*.yarn_beta_fast =*/ 32.0f, /*.yarn_beta_slow =*/ 1.0f, /*.yarn_orig_ctx =*/ 0, + /*.type_k =*/ GGML_TYPE_F16, + /*.type_v =*/ GGML_TYPE_F16, /*.mul_mat_q =*/ true, /*.f16_kv =*/ true, /*.logits_all =*/ false, @@ -8737,31 +8739,36 @@ struct llama_context * llama_new_context_with_model( //const ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - // TODO: move as params - const ggml_type k_type = GGML_TYPE_Q8_0; - const ggml_type v_type = GGML_TYPE_F16; + const ggml_type type_k = params.type_k; + const ggml_type type_v = params.type_v; - GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(k_type) == 0); - GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(v_type) == 0); + GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_k) == 0); + GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_v) == 0); // reserve memory for context buffers if (!hparams.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, k_type, v_type, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; } { - // const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); - size_t memory_size = 0; + size_t memory_size_k = 0; + size_t memory_size_v = 0; + for (auto & k : ctx->kv_self.k_l) { - memory_size += ggml_nbytes(k); + memory_size_k += ggml_nbytes(k); } + for (auto & v : ctx->kv_self.v_l) { - memory_size += ggml_nbytes(v); + memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: kv self size = %7.2f MiB\n", __func__, memory_size / 1024.0 / 1024.0); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } // resized during inference diff --git a/llama.h b/llama.h index c1593c9b03331..e45f12975c908 100644 --- a/llama.h +++ b/llama.h @@ -191,6 +191,9 @@ extern "C" { float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size + ggml_type type_k; // data type for K cache + ggml_type type_v; // data type for V cache + // Keep the booleans together to avoid misalignment during copy-by-value. bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) bool f16_kv; // use fp16 for KV cache, fp32 otherwise From 7864a2cd9bbddb04cec793f2a5c803585c94d395 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 15:43:25 +0200 Subject: [PATCH 06/14] llama : fix build ggml-ci --- llama.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.h b/llama.h index e45f12975c908..f6c9d17519e7e 100644 --- a/llama.h +++ b/llama.h @@ -191,8 +191,8 @@ extern "C" { float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size - ggml_type type_k; // data type for K cache - ggml_type type_v; // data type for V cache + enum ggml_type type_k; // data type for K cache + enum ggml_type type_v; // data type for V cache // Keep the booleans together to avoid misalignment during copy-by-value. bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) From 9d69ecc0c9b8a27411d8d7f509013c10573c1cf4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 16:01:50 +0200 Subject: [PATCH 07/14] metal : add F32 -> Q4_0 copy kernel --- ggml-metal.m | 16 ++++++++++++ ggml-metal.metal | 68 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/ggml-metal.m b/ggml-metal.m index d9bb9211ecf77..418d3003e0e4a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -119,6 +119,10 @@ GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f32_q8_0); + GGML_METAL_DECL_KERNEL(cpy_f32_q4_0); + //GGML_METAL_DECL_KERNEL(cpy_f32_q4_1); + //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0); + //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1); GGML_METAL_DECL_KERNEL(cpy_f16_f16); GGML_METAL_DECL_KERNEL(concat); GGML_METAL_DECL_KERNEL(sqr); @@ -326,6 +330,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f32_q8_0); + GGML_METAL_ADD_KERNEL(cpy_f32_q4_0); + //GGML_METAL_ADD_KERNEL(cpy_f32_q4_1); + //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0); + //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1); GGML_METAL_ADD_KERNEL(cpy_f16_f16); GGML_METAL_ADD_KERNEL(concat); GGML_METAL_ADD_KERNEL(sqr); @@ -428,6 +436,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f32_q8_0); + GGML_METAL_DEL_KERNEL(cpy_f32_q4_0); + //GGML_METAL_DEL_KERNEL(cpy_f32_q4_1); + //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0); + //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1); GGML_METAL_DEL_KERNEL(cpy_f16_f16); GGML_METAL_DEL_KERNEL(concat); GGML_METAL_DEL_KERNEL(sqr); @@ -1565,6 +1577,10 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break; + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break; + //case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break; + //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break; + //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break; default: GGML_ASSERT(false && "not implemented"); }; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 468689ed931a6..3ca21b8d5652c 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3,6 +3,7 @@ using namespace metal; #define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) #define QK4_0 32 #define QR4_0 2 @@ -1518,6 +1519,73 @@ kernel void kernel_cpy_f32_q8_0( } } +kernel void kernel_cpy_f32_q4_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0; + + device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_0].d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_0/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + dst_data[i00/QK4_0].qs[j] = xi0; + dst_data[i00/QK4_0].qs[j] |= xi1 << 4; + } + } +} + kernel void kernel_concat( device const char * src0, device const char * src1, From 6b58ae98921884fc9a85efde0cdefc6a2f4c73b0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 16:09:16 +0200 Subject: [PATCH 08/14] metal : add F32 -> Q4_1 copy kernel --- ggml-metal.m | 8 +++--- ggml-metal.metal | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 418d3003e0e4a..3023aa6db5bf1 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -120,7 +120,7 @@ GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f32_q8_0); GGML_METAL_DECL_KERNEL(cpy_f32_q4_0); - //GGML_METAL_DECL_KERNEL(cpy_f32_q4_1); + GGML_METAL_DECL_KERNEL(cpy_f32_q4_1); //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0); //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1); GGML_METAL_DECL_KERNEL(cpy_f16_f16); @@ -331,7 +331,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f32_q8_0); GGML_METAL_ADD_KERNEL(cpy_f32_q4_0); - //GGML_METAL_ADD_KERNEL(cpy_f32_q4_1); + GGML_METAL_ADD_KERNEL(cpy_f32_q4_1); //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0); //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1); GGML_METAL_ADD_KERNEL(cpy_f16_f16); @@ -437,7 +437,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f32_q8_0); GGML_METAL_DEL_KERNEL(cpy_f32_q4_0); - //GGML_METAL_DEL_KERNEL(cpy_f32_q4_1); + GGML_METAL_DEL_KERNEL(cpy_f32_q4_1); //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0); //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1); GGML_METAL_DEL_KERNEL(cpy_f16_f16); @@ -1578,7 +1578,7 @@ void ggml_metal_graph_compute( case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break; - //case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break; //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break; //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break; default: GGML_ASSERT(false && "not implemented"); diff --git a/ggml-metal.metal b/ggml-metal.metal index 3ca21b8d5652c..9f5ffcbafe8fc 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1586,6 +1586,72 @@ kernel void kernel_cpy_f32_q4_0( } } +kernel void kernel_cpy_f32_q4_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; + + device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_1].d = d; + dst_data[i00/QK4_1].m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst_data[i00/QK4_1].qs[j] = xi0; + dst_data[i00/QK4_1].qs[j] |= xi1 << 4; + } + } +} + kernel void kernel_concat( device const char * src0, device const char * src1, From e8457c90a07067cdc3a2dc6e2ba91119dd0b8e15 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 16:29:52 +0200 Subject: [PATCH 09/14] cuda : wip --- ggml-cuda.cu | 67 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 16d17f8013502..c2ce1769ca356 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4582,12 +4582,43 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { } } -// TODO: generalize for all quants -template +static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_0 * dsti = (block_q4_0 *) cdsti; + + float amax = 0.0f; + float max = 0.0f; + + for (int j = 0; j < QK4_0; ++j) { + const float v = xi[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = xi[0 + j]*id; + const float x1 = xi[QK4_0/2 + j]*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + +template static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { - const int i = (blockDim.x*blockIdx.x + threadIdx.x)*QK8_0; + const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; @@ -4600,7 +4631,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, const int i12 = i / (ne10*ne11); const int i11 = (i - i12*ne10*ne11) / ne10; - const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0; + const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk; const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; cpy_blck(cx + x_offset, cdst + dst_offset); @@ -5791,7 +5822,29 @@ static void ggml_cpy_f32_q8_0_cuda( GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; - cpy_f32_q<<>> + cpy_f32_q<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_q4_0_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + GGML_ASSERT(ne % QK4_0 == 0); + const int num_blocks = ne / QK4_0; + cpy_f32_q<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_q4_1_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + GGML_ASSERT(ne % QK4_1 == 0); + const int num_blocks = ne / QK4_1; + cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } @@ -7836,6 +7889,10 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else { From b2acedeb1a2f7440426d50bc4c01b5a3ea82bd76 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 16:47:34 +0200 Subject: [PATCH 10/14] cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels --- ggml-cuda.cu | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c2ce1769ca356..53e53a0d1634b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7,6 +7,7 @@ #include #include #include +#include #if defined(GGML_USE_HIPBLAS) #include @@ -4587,20 +4588,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { block_q4_0 * dsti = (block_q4_0 *) cdsti; float amax = 0.0f; - float max = 0.0f; + float vmax = 0.0f; for (int j = 0; j < QK4_0; ++j) { const float v = xi[j]; if (amax < fabsf(v)) { amax = fabsf(v); - max = v; + vmax = v; } } - const float d = max / -8; + const float d = vmax / -8; const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + dsti->d = d; for (int j = 0; j < QK4_0/2; ++j) { const float x0 = xi[0 + j]*id; @@ -4614,6 +4615,38 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { } } +static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_1 * dsti = (block_q4_1 *) cdsti; + + float vmin = FLT_MAX; + float vmax = -FLT_MAX; + + for (int j = 0; j < QK4_1; ++j) { + const float v = xi[j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dsti->dm.x = d; + dsti->dm.y = vmin; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (xi[0 + j] - vmin)*id; + const float x1 = (xi[QK4_1/2 + j] - vmin)*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + template static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, From 903167a7771166011fe3179b5deef74b70e72c98 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 5 Dec 2023 16:32:53 +0100 Subject: [PATCH 11/14] llama-bench : support type_k/type_v --- examples/llama-bench/llama-bench.cpp | 111 ++++++++++++++++++++++----- 1 file changed, 91 insertions(+), 20 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 9bd82d565834a..6617c050ddfec 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -53,6 +53,13 @@ static std::vector split(const std::string & str, char delim) { return values; } +template +static std::vector transform_to_str(const std::vector & values, F f) { + std::vector str_values; + std::transform(values.begin(), values.end(), std::back_inserter(str_values), f); + return str_values; +} + template static T avg(const std::vector & v) { if (v.empty()) { @@ -126,7 +133,8 @@ struct cmd_params { std::vector n_prompt; std::vector n_gen; std::vector n_batch; - std::vector f32_kv; + std::vector type_k; + std::vector type_v; std::vector n_threads; std::vector n_gpu_layers; std::vector main_gpu; @@ -142,7 +150,8 @@ static const cmd_params cmd_params_defaults = { /* n_prompt */ {512}, /* n_gen */ {128}, /* n_batch */ {512}, - /* f32_kv */ {false}, + /* type_k */ {GGML_TYPE_F16}, + /* type_v */ {GGML_TYPE_F16}, /* n_threads */ {get_num_physical_cores()}, /* n_gpu_layers */ {99}, /* main_gpu */ {0}, @@ -162,7 +171,8 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -p, --n-prompt (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str()); printf(" -n, --n-gen (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str()); printf(" -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str()); - printf(" --memory-f32 <0|1> (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str()); + printf(" -ctk , --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); + printf(" -ctv , --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); @@ -173,9 +183,32 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); +} +static ggml_type ggml_type_from_name(const std::string & s) { + if (s == "f16") { + return GGML_TYPE_F16; + } + if (s == "q8_0") { + return GGML_TYPE_Q8_0; + } + if (s == "q4_0") { + return GGML_TYPE_Q4_0; + } + if (s == "q4_1") { + return GGML_TYPE_Q4_1; + } + if (s == "q5_0") { + return GGML_TYPE_Q5_0; + } + if (s == "q5_1") { + return GGML_TYPE_Q5_1; + } + + return GGML_TYPE_COUNT; } + static cmd_params parse_cmd_params(int argc, char ** argv) { cmd_params params; std::string arg; @@ -224,13 +257,38 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = split(argv[i], split_delim); params.n_batch.insert(params.n_batch.end(), p.begin(), p.end()); - } else if (arg == "--memory-f32") { + } else if (arg == "-ctk" || arg == "--cache-type-k") { if (++i >= argc) { invalid_param = true; break; } - auto p = split(argv[i], split_delim); - params.f32_kv.insert(params.f32_kv.end(), p.begin(), p.end()); + auto p = split(argv[i], split_delim); + std::vector types; + for (const auto & t : p) { + ggml_type gt = ggml_type_from_name(t); + if (gt == GGML_TYPE_COUNT) { + invalid_param = true; + break; + } + types.push_back(gt); + } + params.type_k.insert(params.type_k.end(), types.begin(), types.end()); + } else if (arg == "-ctv" || arg == "--cache-type-v") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = split(argv[i], split_delim); + std::vector types; + for (const auto & t : p) { + ggml_type gt = ggml_type_from_name(t); + if (gt == GGML_TYPE_COUNT) { + invalid_param = true; + break; + } + types.push_back(gt); + } + params.type_v.insert(params.type_v.end(), types.begin(), types.end()); } else if (arg == "-t" || arg == "--threads") { if (++i >= argc) { invalid_param = true; @@ -321,7 +379,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; } if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; } if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; } - if (params.f32_kv.empty()) { params.f32_kv = cmd_params_defaults.f32_kv; } + if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; } + if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; } if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; } if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; } @@ -336,7 +395,8 @@ struct cmd_params_instance { int n_prompt; int n_gen; int n_batch; - bool f32_kv; + ggml_type type_k; + ggml_type type_v; int n_threads; int n_gpu_layers; int main_gpu; @@ -365,7 +425,8 @@ struct cmd_params_instance { cparams.n_ctx = n_prompt + n_gen; cparams.n_batch = n_batch; - cparams.f16_kv = !f32_kv; + cparams.type_k = type_k; + cparams.type_v = type_v; cparams.mul_mat_q = mul_mat_q; return cparams; @@ -380,7 +441,8 @@ static std::vector get_cmd_params_instances_int(const cmd_p for (const auto & mg : params.main_gpu) for (const auto & ts : params.tensor_split) for (const auto & nb : params.n_batch) - for (const auto & fk : params.f32_kv) + for (const auto & tk : params.type_k) + for (const auto & tv : params.type_v) for (const auto & mmq : params.mul_mat_q) for (const auto & nt : params.n_threads) { cmd_params_instance instance = { @@ -388,7 +450,8 @@ static std::vector get_cmd_params_instances_int(const cmd_p /* .n_prompt = */ n_prompt, /* .n_gen = */ n_gen, /* .n_batch = */ nb, - /* .f32_kv = */ fk, + /* .type_k = */ tk, + /* .type_v = */ tv, /* .n_threads = */ nt, /* .n_gpu_layers = */ nl, /* .main_gpu = */ mg, @@ -410,7 +473,8 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & mg : params.main_gpu) for (const auto & ts : params.tensor_split) for (const auto & nb : params.n_batch) - for (const auto & fk : params.f32_kv) + for (const auto & tk : params.type_k) + for (const auto & tv : params.type_v) for (const auto & mmq : params.mul_mat_q) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { @@ -422,7 +486,8 @@ static std::vector get_cmd_params_instances(const cmd_param /* .n_prompt = */ n_prompt, /* .n_gen = */ 0, /* .n_batch = */ nb, - /* .f32_kv = */ fk, + /* .type_k = */ tk, + /* .type_v = */ tv, /* .n_threads = */ nt, /* .n_gpu_layers = */ nl, /* .main_gpu = */ mg, @@ -441,7 +506,8 @@ static std::vector get_cmd_params_instances(const cmd_param /* .n_prompt = */ 0, /* .n_gen = */ n_gen, /* .n_batch = */ nb, - /* .f32_kv = */ fk, + /* .type_k = */ tk, + /* .type_v = */ tv, /* .n_threads = */ nt, /* .n_gpu_layers = */ nl, /* .main_gpu = */ mg, @@ -489,7 +555,8 @@ struct test { uint64_t model_n_params; int n_batch; int n_threads; - bool f32_kv; + ggml_type type_k; + ggml_type type_v; int n_gpu_layers; int main_gpu; bool mul_mat_q; @@ -508,7 +575,8 @@ struct test { model_n_params = llama_model_n_params(lmodel); n_batch = inst.n_batch; n_threads = inst.n_threads; - f32_kv = inst.f32_kv; + type_k = inst.type_k; + type_v = inst.type_v; n_gpu_layers = inst.n_gpu_layers; main_gpu = inst.main_gpu; mul_mat_q = inst.mul_mat_q; @@ -571,7 +639,7 @@ struct test { "cuda", "opencl", "metal", "gpu_blas", "blas", "cpu_info", "gpu_info", "model_filename", "model_type", "model_size", "model_n_params", - "n_batch", "n_threads", "f16_kv", + "n_batch", "n_threads", "type_k", "type_v", "n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -621,7 +689,7 @@ struct test { std::to_string(cuda), std::to_string(opencl), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas), cpu_info, gpu_info, model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params), - std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv), + std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str, std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -805,8 +873,11 @@ struct markdown_printer : public printer { if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) { fields.push_back("n_batch"); } - if (params.f32_kv.size() > 1 || params.f32_kv != cmd_params_defaults.f32_kv) { - fields.push_back("f16_kv"); + if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) { + fields.push_back("type_k"); + } + if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) { + fields.push_back("type_v"); } if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) { fields.push_back("main_gpu"); From dd86df82e60b34779c36fb267e07f47dddc8b899 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 18:14:04 +0200 Subject: [PATCH 12/14] metal : use mm kernel only for quantum KV cache --- ggml-metal.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-metal.m b/ggml-metal.m index 3023aa6db5bf1..be4ab0f2ed47c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1129,7 +1129,7 @@ void ggml_metal_graph_compute( !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || ne12 > 1)) { + (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); switch (src0->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; From 4adb1d69d9f090eda7486e76feb43d4824803a4f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 18:15:51 +0200 Subject: [PATCH 13/14] cuda : add comment --- ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 53e53a0d1634b..1200d1c888b42 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7849,6 +7849,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 #endif // GGML_CUDA_FORCE_DMMV if (use_mul_mat_vec_q) { + // NOTE: this kernel does not support ggml_nrows(src1) > 1 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); } else { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); From af99c6fbfc815df7dad94d8c1f20d55927b2203a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 18:18:16 +0200 Subject: [PATCH 14/14] llama : remove memory_f16 and kv_f16 flags --- common/common.cpp | 6 ------ common/common.h | 1 - examples/quantize-stats/quantize-stats.cpp | 1 - examples/server/server.cpp | 4 ---- llama.cpp | 3 --- llama.h | 1 - 6 files changed, 16 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 77332d5dbb77b..a5b5c468c802b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -278,8 +278,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.yarn_beta_slow = std::stof(argv[i]); - } else if (arg == "--memory-f32") { - params.memory_f16 = false; } else if (arg == "--top-p") { if (++i >= argc) { invalid_param = true; @@ -804,8 +802,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); printf(" --no-penalize-nl do not penalize newline token\n"); - printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); - printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); @@ -948,7 +944,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; cparams.mul_mat_q = params.mul_mat_q; cparams.seed = params.seed; - cparams.f16_kv = params.memory_f16; cparams.logits_all = params.logits_all; cparams.embedding = params.embedding; cparams.rope_scaling_type = params.rope_scaling_type; @@ -1375,7 +1370,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l } fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); - fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); diff --git a/common/common.h b/common/common.h index 7f0d03e41dbd9..4cf471c7a8a16 100644 --- a/common/common.h +++ b/common/common.h @@ -98,7 +98,6 @@ struct gpt_params { size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS - bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 2712824774ae7..773024160f839 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -321,7 +321,6 @@ int main(int argc, char ** argv) { auto cparams = llama_context_default_params(); cparams.n_ctx = 256; cparams.seed = 1; - cparams.f16_kv = false; ctx = llama_new_context_with_model(model, cparams); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 911f7bbe1f85a..ef2a95004f453 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2108,10 +2108,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.yarn_beta_slow = std::stof(argv[i]); } - else if (arg == "--memory-f32" || arg == "--memory_f32") - { - params.memory_f16 = false; - } else if (arg == "--threads" || arg == "-t") { if (++i >= argc) diff --git a/llama.cpp b/llama.cpp index 3f951dbe31952..800951ab8de72 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8583,7 +8583,6 @@ struct llama_context_params llama_context_default_params() { /*.type_k =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16, /*.mul_mat_q =*/ true, - /*.f16_kv =*/ true, /*.logits_all =*/ false, /*.embedding =*/ false, /*.offload_kqv =*/ true, @@ -8737,8 +8736,6 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; - //const ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - const ggml_type type_k = params.type_k; const ggml_type type_v = params.type_v; diff --git a/llama.h b/llama.h index f6c9d17519e7e..ead37562e37f3 100644 --- a/llama.h +++ b/llama.h @@ -196,7 +196,6 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) - bool f16_kv; // use fp16 for KV cache, fp32 otherwise bool logits_all; // the llama_eval() call computes all logits, not just the last one bool embedding; // embedding mode only bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU