From 17b8a2e669a5066575a278328258b24aec0a283a Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Mon, 11 Nov 2024 21:31:41 +0000 Subject: [PATCH 1/3] sycl : Fixes RWKV6 broken build in the cuda backend --- ggml/src/ggml-sycl/outprod.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index c2779df0ecf91..e61cdc2ca5d53 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -1,4 +1,5 @@ #include +#include #include "outprod.hpp" @@ -39,7 +40,7 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr try { // Perform matrix multiplication using oneMKL GEMM - oneapi::mkl::blas::gemm(*stream, + oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, From f6ea8b73f87870ad881a8344656aa4827a9443c4 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Mon, 11 Nov 2024 21:33:41 +0000 Subject: [PATCH 2/3] sycl : marks permuted MUL_MAT as unsupported --- ggml/src/ggml-sycl.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 255bc64c6badd..2dba15d237e92 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -4350,6 +4350,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g if (op->op == GGML_OP_MUL_MAT) { a = op->src[0]; b = op->src[1]; + if (ggml_is_permuted(a) || ggml_is_permuted(b)) { + // TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021 + return false; + } } else { a = op->src[2]; b = op->src[1]; From 6a2c025684b328b7d7c3514bc2087b8b26f604b9 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Tue, 12 Nov 2024 09:31:07 +0000 Subject: [PATCH 3/3] sycl : fix norm asserts in debug build --- ggml/src/ggml-sycl/norm.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index b3159b9d1b94d..72d8fdb878c8d 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep const int nthreads = item_ct1.get_local_range(2); const int nwarps = nthreads / WARP_SIZE; - assert(nwarps % WARP_SIZE == 0); sycl::float2 mean_var = sycl::float2(0.f, 0.f); for (int col = tid; col < ncols; col += block_size) { @@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con int end = start + group_size; const int nthreads = item_ct1.get_local_range(2); const int nwarps = nthreads / WARP_SIZE; - assert(nwarps % WARP_SIZE == 0); start += item_ct1.get_local_id(2); int nreduce = nwarps / WARP_SIZE; @@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa const int tid = item_ct1.get_local_id(2); const int nthreads = item_ct1.get_local_range(2); const int nwarps = nthreads / WARP_SIZE; - assert(nwarps % WARP_SIZE == 0); float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { @@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols, } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); const sycl::range<3> block_dims(1, 1, work_group_size); /* DPCT1049:17: The work-group size passed to the SYCL kernel may exceed @@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst, } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); const sycl::range<3> block_dims(1, 1, work_group_size); /* DPCT1049:18: The work-group size passed to the SYCL kernel may exceed @@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); const sycl::range<3> block_dims(1, 1, work_group_size); /* DPCT1049:19: The work-group size passed to the SYCL kernel may exceed