
Commit ba779f1

FSSRepo, ggerganov, and slaren authored
ggml : replace conv 1D - 2D stage_0 and stage_1 with im2col and mul_mat (#564)
* added conv2d stage 0 - 1 cuda kernels
* add im2col + refactor conv1d and conv2d
* fix params invalid index
* add conv1d and conv2d unit tests
* resolving wrong values and fix mul_mat validation
* improve tests + reduce code duplication
* add cuda kernels
* more data test
* fix ggml_op_count to 70
* add temp test - gemm != mul_mat
* tests : fix test-mul-mat matrix multiplication
* test-mul-mat match gemm == ggml_mul_mat with conv2d op
* replaced gemm by ggml_mul_mat
* ggml_mul_mat cpu backend support fp16 src1
* ggml_mul_mat cuda backend fp16 fixed
* remove unnecessary ggml_cont and removed deprecated conv1d-2d functions
* some fixes
* explain conv1d reshapes
* ggml : fix tests on Arm + do not use BLAS for F16 data
* tests : fix FP16 handling on Arm
* ggml : avoid ggml_cont and ggml_transpose in ggml_conv_xd
* ci : switch back to release
* cuda : fix wrong pointer usage
* ggml : add metal support for im2col and f16xf16 mul mat
* ggml : im2col opts
* Update src/ggml-cuda.cu

Co-authored-by: slaren <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: slaren <[email protected]>
1 parent 239defe commit ba779f1

File tree

9 files changed: +1582 −1098 lines changed


include/ggml/ggml.h

Lines changed: 13 additions & 6 deletions
@@ -403,13 +403,8 @@ extern "C" {
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
-        GGML_OP_CONV_1D,
-        GGML_OP_CONV_1D_STAGE_0,  // internal
-        GGML_OP_CONV_1D_STAGE_1,  // internal
         GGML_OP_CONV_TRANSPOSE_1D,
-        GGML_OP_CONV_2D,
-        GGML_OP_CONV_2D_STAGE_0,  // internal
-        GGML_OP_CONV_2D_STAGE_1,  // internal
+        GGML_OP_IM2COL,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,

@@ -1398,6 +1393,18 @@ extern "C" {
             float                 min,
             float                 max);

+    GGML_API struct ggml_tensor * ggml_im2col(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s0,
+            int                   s1,
+            int                   p0,
+            int                   p1,
+            int                   d0,
+            int                   d1,
+            bool                  is_2D);
+
     GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
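For orientation: the deleted CONV ops each needed internal two-stage kernels per backend; after this change a convolution is built from the single GGML_OP_IM2COL op plus the existing GGML_OP_MUL_MAT. A minimal sketch of that decomposition against the API above (shape comments follow this diff; the exact ggml_conv_2d body in ggml.c may arrange the reshapes slightly differently):

    // sketch: conv2d as im2col + mul_mat
    // a: kernels [KW, KH, IC, OC] (F16)
    // b: input   [IW, IH, IC, N ] (F32)
    static struct ggml_tensor * conv_2d_via_im2col(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int s0, int s1, int p0, int p1, int d0, int d1) {
        // unfold input patches: [IC*KH*KW, OW, OH, N], stored as F16
        struct ggml_tensor * col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true);

        // [IC*KH*KW, OW*OH*N] x [IC*KH*KW, OC] -> [OW*OH*N, OC]
        // (im2col yields F16 data, which is why this commit also teaches
        //  ggml_mul_mat's CPU/CUDA/Metal backends to accept F16 src1)
        struct ggml_tensor * res = ggml_mul_mat(ctx,
                ggml_reshape_2d(ctx, col, col->ne[0], col->ne[1]*col->ne[2]*col->ne[3]),
                ggml_reshape_2d(ctx, a,   a->ne[0]*a->ne[1]*a->ne[2], a->ne[3]));

        // back to image layout: [OW, OH, N, OC] -> [OW, OH, OC, N]
        res = ggml_reshape_4d(ctx, res, col->ne[1], col->ne[2], col->ne[3], a->ne[3]);
        return ggml_cont(ctx, ggml_permute(ctx, res, 0, 1, 3, 2));
    }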

src/ggml-cuda.cu

Lines changed: 101 additions & 1 deletion
@@ -39,6 +39,7 @@
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceGetMemPool hipDeviceGetMemPool
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t

@@ -48,13 +49,15 @@
 #define cudaEvent_t hipEvent_t
 #define cudaEventDestroy hipEventDestroy
 #define cudaFree hipFree
+#define cudaFreeAsync hipFreeAsync
 #define cudaFreeHost hipHostFree
 #define cudaGetDevice hipGetDevice
 #define cudaGetDeviceCount hipGetDeviceCount
 #define cudaGetDeviceProperties hipGetDeviceProperties
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
 #define cudaMalloc hipMalloc
+#define cudaMallocFromPoolAsync hipMallocFromPoolAsync
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpy2DAsync hipMemcpy2DAsync

@@ -63,6 +66,9 @@
 #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
 #define cudaMemcpyHostToDevice hipMemcpyHostToDevice
 #define cudaMemcpyKind hipMemcpyKind
+#define cudaMemPool_t hipMemPool_t
+#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
+#define cudaMemPoolSetAttribute hipMemPoolSetAttribute
 #define cudaMemset hipMemset
 #define cudaMemsetAsync hipMemsetAsync
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
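The new aliases cover CUDA's stream-ordered (memory-pool) allocator so the file's async pool allocation also compiles for HIP builds. As a hedged illustration of what these four symbols do (generic CUDA runtime usage, not ggml's exact pool code):

    // sketch: stream-ordered allocation from a device's default memory pool
    static void * pool_alloc_example(int device, size_t size, cudaStream_t stream) {
        cudaMemPool_t pool;
        cudaDeviceGetMemPool(&pool, device);               // default pool of `device`

        // keep freed blocks cached in the pool instead of returning them to the OS
        uint64_t threshold = UINT64_MAX;
        cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);

        void * buf = NULL;
        cudaMallocFromPoolAsync(&buf, size, pool, stream); // alloc ordered on `stream`
        return buf;                                        // later: cudaFreeAsync(buf, stream);
    }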
@@ -4470,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
 }

+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4723,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }

+static __global__ void im2col_f32_f16(
+        const float * x, half * dst,
+        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    const int offset_dst =
+        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = __float2half(0.0f);
+    } else {
+        const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
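The kernel's index arithmetic is dense; it is easiest to read against the launch layout used by the im2col_f32_f16_cuda wrapper later in this diff: grid = (IC, OH, OW), block = (N, KH, KW). A plain-C restatement of the same mapping (a sketch; the bounds and the element strides ofs0/ofs1 are assumed in scope, and the CUDA kernel additionally converts each value to half):

    // dst is the [N*OH*OW, IC*KH*KW] im2col matrix, row-major
    for (int in  = 0; in  < N;  in++)
    for (int ic  = 0; ic  < IC; ic++)
    for (int ioh = 0; ioh < OH; ioh++)
    for (int iow = 0; iow < OW; iow++)
    for (int ikh = 0; ikh < KH; ikh++)
    for (int ikw = 0; ikw < KW; ikw++) {
        const int iiw = iow*s0 + ikw*d0 - p0;        // input column after stride/dilation/padding
        const int iih = ioh*s1 + ikh*d1 - p1;        // input row
        const int row = (in*OH + ioh)*OW + iow;      // one dst row per output pixel
        const int col = (ic*KH + ikh)*KW + ikw;      // one dst column per kernel tap
        dst[row*(IC*KH*KW) + col] =
            (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)
                ? 0.0f                               // zero-padding
                : src[in*ofs0 + ic*ofs1 + iih*IW + iiw];
    }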
@@ -5612,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }

+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5695,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }

+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH, int OW, int IC,
+    int KH, int KW, int N, int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N, KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256

@@ -6477,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
         size_t dst_f16_as = 0;
         half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
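This appears to be the "cuda : fix wrong pointer usage" item from the commit message: with F16 src1 now reaching the cuBLAS path (im2col output is F16), the F16 branch must read src1_ddf_i, the same device buffer the to_fp16_cuda conversion above consumes, rather than the quantized-data pointer src1_ddq_i.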

@@ -6653,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }

+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW = src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW = src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW = dst->ne[1];
+
+    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
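Both the CUDA op above and the Metal op below unpack the same seven int32 values from dst->op_params by index. For reference, a sketch of how the graph-side ggml_im2col presumably packs them in ggml.c (ggml_set_op_params is ggml's internal helper for this; field order mirrors the reads above):

    // inside ggml_im2col(), after creating the output tensor `result`
    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
    ggml_set_op_params(result, params, sizeof(params));

    result->op     = GGML_OP_IM2COL;
    result->src[0] = a;   // kernel
    result->src[1] = b;   // input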
@@ -7543,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7574,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }

+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -7937,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
         default:
             return false;
     }

src/ggml-metal.m

Lines changed: 73 additions & 9 deletions
@@ -86,6 +86,7 @@
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);

@@ -114,6 +115,7 @@
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(im2col_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);

@@ -287,6 +289,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);

@@ -317,6 +320,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
+        GGML_METAL_ADD_KERNEL(im2col_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);

@@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);

@@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(im2col_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -1030,7 +1036,7 @@
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                         [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+                        [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];

                        [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
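This and the two analogous changes further down round the requested threadgroup-memory size up to a 16-byte multiple, which Metal requires of setThreadgroupMemoryLength. GGML_PAD is ggml's round-up macro (as defined in ggml.h; n must be a power of two):

    // round x up to the next multiple of n
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))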
@@ -1139,20 +1145,26 @@
                         switch (src0t) {
                             case GGML_TYPE_F32:
                                 {
+                                    GGML_ASSERT(src1t == GGML_TYPE_F32);
                                     [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                     nrows = 4;
                                 } break;
                             case GGML_TYPE_F16:
                                 {
                                     nth0 = 32;
                                     nth1 = 1;
-                                    if (ne11 * ne12 < 4) {
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
-                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
-                                        nrows = ne11;
+                                    if (src1t == GGML_TYPE_F32) {
+                                        if (ne11 * ne12 < 4) {
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                        } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                            nrows = ne11;
+                                        } else {
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                            nrows = 4;
+                                        }
                                     } else {
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
                                         nrows = 4;
                                     }
                                 } break;
@@ -1342,7 +1354,7 @@
                         [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                         [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                         [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+                        [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];

                         const int64_t nrows = ggml_nrows(src0);

@@ -1361,7 +1373,7 @@
                         [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                         [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                         [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                        [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];

                         const int64_t nrows = ggml_nrows(src0);

@@ -1464,6 +1476,58 @@ void ggml_metal_graph_compute(

                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
+                case GGML_OP_IM2COL:
+                    {
+                        GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                        GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+                        const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                        const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                        const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                        const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                        const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                        const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+                        const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                        const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                        const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                        const int32_t IH = is_2D ? src1->ne[1] : 1;
+                        const int32_t IW =         src1->ne[0];
+
+                        const int32_t KH = is_2D ? src0->ne[1] : 1;
+                        const int32_t KW =         src0->ne[0];
+
+                        const int32_t OH = is_2D ? dst->ne[2] : 1;
+                        const int32_t OW =         dst->ne[1];
+
+                        const int32_t CHW = IC * KH * KW;
+
+                        const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                        const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                        switch (src0->type) {
+                            case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                            default: GGML_ASSERT(false);
+                        };
+
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+                        [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+                        [encoder setBytes:&IW   length:sizeof( int32_t) atIndex:4];
+                        [encoder setBytes:&IH   length:sizeof( int32_t) atIndex:5];
+                        [encoder setBytes:&CHW  length:sizeof( int32_t) atIndex:6];
+                        [encoder setBytes:&s0   length:sizeof( int32_t) atIndex:7];
+                        [encoder setBytes:&s1   length:sizeof( int32_t) atIndex:8];
+                        [encoder setBytes:&p0   length:sizeof( int32_t) atIndex:9];
+                        [encoder setBytes:&p1   length:sizeof( int32_t) atIndex:10];
+                        [encoder setBytes:&d0   length:sizeof( int32_t) atIndex:11];
+                        [encoder setBytes:&d1   length:sizeof( int32_t) atIndex:12];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                    } break;
                 case GGML_OP_DUP:
                 case GGML_OP_CPY:
                 case GGML_OP_CONT:
