Skip to content

Commit 2e4db48

Browse files
committed
ggml : update get_rows f16 and q
1 parent ac3f7d8 commit 2e4db48

File tree

2 files changed

+26
-18
lines changed

2 files changed

+26
-18
lines changed

Makefile

+5
Original file line number | Diff line number | Diff line change
@@ -396,6 +396,11 @@ ifdef LLAMA_CUBLAS
396396
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
397397
OBJS += ggml-cuda.o
398398
NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
399+
400+
ifdef LLAMA_DEBUG
401+
NVCCFLAGS += -lineinfo
402+
endif
403+
399404
ifdef LLAMA_CUDA_NVCC
400405
NVCC = $(LLAMA_CUDA_NVCC)
401406
else

ggml.c

+21-18
Original file line number | Diff line number | Diff line change
@@ -4086,7 +4086,7 @@ struct ggml_tensor * ggml_mul_mat_id(
40864086
GGML_ASSERT(ids->ne[1] == b->ne[1]);
40874087
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
40884088
GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
4089-
GGML_ASSERT(id >= 0 && id < n_as);
4089+
GGML_ASSERT(id >= 0 && id < ids->ne[0]);
40904090

40914091
bool is_node = false;
40924092

@@ -10345,7 +10345,7 @@ static void ggml_compute_forward_get_rows_q(
1034510345
GGML_TENSOR_BINARY_OP_LOCALS
1034610346

1034710347
const int64_t nc = ne00;
10348-
const int64_t nr = ggml_nelements(src1);
10348+
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
1034910349

1035010350
const enum ggml_type type = src0->type;
1035110351
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -10356,14 +10356,16 @@ static void ggml_compute_forward_get_rows_q(
1035610356
assert(ggml_nrows(dst) == nr);
1035710357

1035810358
// TODO: multi-thread
10359-
for (int64_t i = 0; i < nr; ++i) {
10360-
const int64_t r = ((int32_t *) src1->data)[i];
10361-
10362-
const int64_t i02 = i/ne10;
10359+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
10360+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
10361+
for (int64_t i10 = 0; i10 < ne10; ++i10) {
10362+
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
1036310363

10364-
dequantize_row_q(
10365-
(const void *) ((char *) src0->data + i02*nb02 + r*nb01),
10366-
(float *) ((char *) dst->data + i*nb1), nc);
10364+
dequantize_row_q(
10365+
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
10366+
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
10367+
}
10368+
}
1036710369
}
1036810370
}
1036910371

@@ -10381,22 +10383,23 @@ static void ggml_compute_forward_get_rows_f16(
1038110383
GGML_TENSOR_BINARY_OP_LOCALS
1038210384

1038310385
const int64_t nc = ne00;
10384-
const int64_t nr = ggml_nelements(src1);
10386+
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
1038510387

1038610388
assert(ne0 == nc);
1038710389
assert(ne02 == ne11);
1038810390
assert(nb00 == sizeof(ggml_fp16_t));
1038910391
assert(ggml_nrows(dst) == nr);
1039010392

1039110393
// TODO: multi-thread
10392-
for (int64_t i = 0; i < nr; ++i) {
10393-
const int64_t r = ((int32_t *) src1->data)[i];
10394-
10395-
const int64_t i02 = i/ne10;
10394+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
10395+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
10396+
for (int64_t i10 = 0; i10 < ne10; ++i10) {
10397+
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
1039610398

10397-
for (int j = 0; j < nc; ++j) {
10398-
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
10399-
((float *) ((char *) dst->data + i*nb1))[j] = GGML_FP16_TO_FP32(v);
10399+
ggml_fp16_to_fp32_row(
10400+
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
10401+
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
10402+
}
1040010403
}
1040110404
}
1040210405
}
@@ -10415,14 +10418,14 @@ static void ggml_compute_forward_get_rows_f32(
1041510418
GGML_TENSOR_BINARY_OP_LOCALS
1041610419

1041710420
const int64_t nc = ne00;
10421+
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
1041810422

1041910423
assert(ne0 == nc);
1042010424
assert(ne02 == ne11);
1042110425
assert(nb00 == sizeof(float));
1042210426
assert(ggml_nrows(dst) == nr);
1042310427

1042410428
// TODO: multi-thread
10425-
// TODO: same impl for get_rows_q and get_rows_f16
1042610429
for (int64_t i12 = 0; i12 < ne12; ++i12) {
1042710430
for (int64_t i11 = 0; i11 < ne11; ++i11) {
1042810431
for (int64_t i10 = 0; i10 < ne10; ++i10) {

0 commit comments

Comments
 (0)