
Commit 80dd7ff

vulkan: Optimize contiguous copies (#10254)
* tests: Fix memory bandwidth calculation for perf tests

  Add a flops calculation for flash attention.

  Add one GGML_OP_CPY perf test.

* vulkan: Optimize contiguous copies

  Add a variant of the copy shader for when the tensors are contiguous.
  Avoid the complex addressing calculations, and do four elements per
  invocation to hide some other overhead.

  Apply similar changes to the scale shader, since scale is always contiguous.

  Add a "progress bar" for shader compiles.
1 parent 54ef9cf commit 80dd7ff

13 files changed  +144 -27 lines changed
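
To make "avoid the complex addressing calculations" concrete: the generic unary shaders rebuild 4-D coordinates from a flat element index and re-linearize them through per-tensor strides (the src0_idx/dst_idx calls visible in the old scale.comp below), while the contiguous variant can use the flat index directly. A rough CPU-side sketch of the difference, in illustrative C++ rather than the commit's GLSL; the shape/stride parameters here are hypothetical:

    #include <cstddef>
    #include <cstdint>

    // Generic path: rebuild (i0, i1, i2, i3) from the flat index, then apply the
    // tensor's strides (nb, in elements) -- several divisions per element.
    static size_t strided_offset(uint32_t idx, const uint32_t ne[4], const size_t nb[4]) {
        const uint32_t i3 =  idx / (ne[2]*ne[1]*ne[0]);
        const uint32_t i2 = (idx - i3*ne[2]*ne[1]*ne[0]) / (ne[1]*ne[0]);
        const uint32_t i1 = (idx - i3*ne[2]*ne[1]*ne[0] - i2*ne[1]*ne[0]) / ne[0];
        const uint32_t i0 =  idx - i3*ne[2]*ne[1]*ne[0] - i2*ne[1]*ne[0] - i1*ne[0];
        return i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3];
    }

    // Contiguous path: the offset is the index itself, so the new shaders skip the
    // divisions entirely and spend the budget copying four elements per invocation.
    static size_t contig_offset(uint32_t idx) {
        return idx;
    }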

ggml/src/ggml-vulkan.cpp

+57-19
@@ -196,6 +196,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_repeat_f32;
     vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
+    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
@@ -722,6 +723,12 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         std::lock_guard<std::mutex> guard(compile_count_mutex);
         assert(compile_count > 0);
         compile_count--;
+
+        // "Progress bar" for shader compiles
+        static uint32_t total_compile_count = 0;
+        if ((total_compile_count++ % 10) == 0) {
+            std::cerr << ".";
+        }
     }
     compile_count_cond.notify_all();
 }
@@ -1200,6 +1207,8 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
 static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
 
+    std::cerr << "ggml_vulkan: Compiling shaders";
+
     // mulmat
     std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
@@ -1759,6 +1768,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
@@ -1817,6 +1830,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     for (auto &c : compiles) {
         c.wait();
     }
+    std::cerr << "Done!" << std::endl;
 }
 
 static vk_device ggml_vk_get_device(size_t idx) {
@@ -3061,18 +3075,34 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
-    if (from == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
-        return ctx->device->pipeline_cpy_f32_f32;
+static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
+
+    // Choose "contiguous copy" shader if src/dst are contiguous
+    bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
+
+    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f32_f32;
+        } else {
+            return ctx->device->pipeline_cpy_f32_f32;
+        }
     }
-    if (from == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
-        return ctx->device->pipeline_cpy_f32_f16;
+    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f32_f16;
+        } else {
+            return ctx->device->pipeline_cpy_f32_f16;
+        }
     }
-    if (from == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
-        return ctx->device->pipeline_cpy_f16_f16;
+    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f16_f16;
+        } else {
+            return ctx->device->pipeline_cpy_f16_f16;
+        }
     }
 
-    std::cerr << "Missing CPY op for types: " << ggml_type_name(from) << " " << ggml_type_name(to) << std::endl;
+    std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
     GGML_ABORT("fatal error");
 }
 
@@ -3082,6 +3112,15 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     const int tensor_type_size = ggml_type_size(tensor->type);
 
     const uint32_t ne = ggml_nelements(tensor);
+    std::array<uint32_t, 3> elements;
+
+    if (ne > 262144) {
+        elements = { 512, 512, CEIL_DIV(ne, 262144) };
+    } else if (ne > 512) {
+        elements = { 512, CEIL_DIV(ne, 512), 1 };
+    } else {
+        elements = { ne, 1, 1 };
+    }
 
     const vk_op_unary_push_constants pc = {
         (uint32_t)ne,
@@ -3091,7 +3130,7 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
         0.0f, 0.0f,
     };
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
 }
 
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -3176,12 +3215,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     vk_pipeline to_fp16_vk_1 = nullptr;
 
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
@@ -3361,10 +3400,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
@@ -3745,12 +3784,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     vk_pipeline to_fp16_vk_1 = nullptr;
 
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
@@ -3938,10 +3977,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
@@ -4148,7 +4187,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CPY:
     case GGML_OP_CONT:
    case GGML_OP_DUP:
-        return ggml_vk_get_cpy_pipeline(ctx, src0->type, dst->type);
+        return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_norm_f32;
@@ -4281,7 +4320,6 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
-    case GGML_OP_SCALE:
     case GGML_OP_SQR:
     case GGML_OP_SIN:
     case GGML_OP_COS:
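
The elements calculation added to ggml_vk_cpy_to_contiguous spreads a large element count over a 3-D dispatch instead of putting everything on the x axis; 262144 is 512 * 512, and the shader side recombines the three coordinates into a flat index (pad.comp shows the pattern, z * 262144 + y * 512 + x, and get_idx() in generic_unary_head.comp presumably does the same). A standalone sketch of that split in illustrative C++ (the rationale about per-dimension dispatch limits is an assumption, not stated in the commit):

    #include <array>
    #include <cstdint>

    #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

    // Mirror of the elements selection above; keeps each dispatch dimension small,
    // presumably to stay well under Vulkan's per-dimension dispatch limits.
    static std::array<uint32_t, 3> split_dispatch(uint32_t ne) {
        if (ne > 262144) {
            return { 512, 512, CEIL_DIV(ne, 262144) };
        } else if (ne > 512) {
            return { 512, CEIL_DIV(ne, 512), 1 };
        }
        return { ne, 1, 1 };
    }

    int main() {
        // The new perf test copies a 512 x 3072 tensor: ne = 1572864 -> { 512, 512, 6 }.
        std::array<uint32_t, 3> elements = split_dispatch(512 * 3072);
        (void)elements;
        return 0;
    }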

ggml/src/vulkan-shaders/clamp.comp

+2
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_unary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint idx = get_idx();

ggml/src/vulkan-shaders/contig_copy.comp

+42
@@ -0,0 +1,42 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+#extension GL_EXT_control_flow_attributes : require
+
+const uint num_threads = 128;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    uint idx = get_idx();
+
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 4;
+
+    // fast path for when all four iterations are in-bounds
+    if (idx + (num_iter-1)*num_threads < p.ne) {
+        [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+#ifndef OPTIMIZATION_ERROR_WORKAROUND
+            data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
+#else
+            data_d[p.d_offset + idx] = data_a[idx];
+#endif
+            idx += num_threads;
+        }
+    } else {
+        [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+            if (idx >= p.ne) {
+                continue;
+            }
+
+#ifndef OPTIMIZATION_ERROR_WORKAROUND
+            data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
+#else
+            data_d[p.d_offset + idx] = data_a[idx];
+#endif
+            idx += num_threads;
+        }
+    }
+}
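
The shader comment about num_threads * num_iter having to equal 512 is the invariant that makes this work: a workgroup of 128 threads doing 4 strided iterations tiles exactly the 512 elements that the {512, 1, 1} wg_denoms assign to it, and the reworked scale.comp below relies on the same scheme. A small self-check of that coverage, in illustrative C++ (not part of the commit):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t num_threads = 128, num_iter = 4, wg_elems = 512;
        assert(num_threads * num_iter == wg_elems);

        bool covered[512] = {};                 // one flag per element in the workgroup's tile
        for (uint32_t t = 0; t < num_threads; ++t) {
            uint32_t idx = t;                   // get_idx() would add the workgroup's base offset
            for (uint32_t i = 0; i < num_iter; ++i) {
                covered[idx] = true;            // same access pattern as contig_copy.comp
                idx += num_threads;             // stride by the thread count, as in the shader
            }
        }
        for (uint32_t e = 0; e < wg_elems; ++e) {
            assert(covered[e]);                 // every element hit, no gaps and no overlaps
        }
        return 0;
    }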

ggml/src/vulkan-shaders/copy.comp

+2
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_unary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint idx = get_idx();

ggml/src/vulkan-shaders/cos.comp

+2
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_unary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint idx = get_idx();

ggml/src/vulkan-shaders/generic_unary_head.comp

+1-2
@@ -1,4 +1,5 @@
 #extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : require
 
 layout (push_constant) uniform parameter
 {
@@ -9,8 +10,6 @@ layout (push_constant) uniform parameter
     float param1; float param2;
 } p;
 
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

ggml/src/vulkan-shaders/pad.comp

+2
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_unary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

ggml/src/vulkan-shaders/repeat.comp

+2
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_unary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 uint src0_idx_mod(uint idx) {
     const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
     const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;

ggml/src/vulkan-shaders/scale.comp

+15-5
@@ -3,12 +3,22 @@
 #include "types.comp"
 #include "generic_unary_head.comp"
 
+const uint num_threads = 128;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
-    const uint idx = get_idx();
+    uint idx = get_idx();
 
-    if (idx >= p.ne) {
-        return;
-    }
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 4;
 
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(p.param1));
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+
+        data_d[p.d_offset + idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1));
+        idx += num_threads;
+    }
 }

ggml/src/vulkan-shaders/sin.comp

+2
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_unary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint idx = get_idx();

ggml/src/vulkan-shaders/square.comp

+2
@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "generic_unary_head.comp"
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
     const uint idx = get_idx();

ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp

+3
@@ -350,6 +350,9 @@ void process_shaders() {
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 
     string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});

tests/test-backend-ops.cpp

+12-1
@@ -681,13 +681,15 @@ struct test_case {
 
         // run
         int64_t total_time_us = 0;
+        int64_t total_mem = 0;
         int total_runs = 0;
         do {
            int64_t start_time = ggml_time_us();
            ggml_backend_graph_compute(backend, gf);
            int64_t end_time = ggml_time_us();
 
            total_time_us += end_time - start_time;
+           total_mem += mem;
            total_runs += n_runs;
         } while (total_time_us < 1000*1000); // run for at least 1 second
 
@@ -717,7 +719,7 @@ struct test_case {
         } else {
            printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m",
                op_size(out) / 1024,
-               mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
+               total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
         }
         printf("\n");
 
@@ -2740,6 +2742,13 @@ struct test_flash_attn_ext : public test_case {
         return 5e-4;
     }
 
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        // Just counting matmul costs:
+        // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
+        return 2 * 2 * nh * nb * hs * kv;
+    }
+
     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
                         bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
         : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
@@ -3779,6 +3788,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
 
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
+
     for (int bs : {1, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
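
For a sense of scale: with the default test_flash_attn_ext parameters (hs = 128, nh = 32, kv = 96, nb = 8) the new op_flops returns 2 * 2 * 32 * 8 * 128 * 96 = 12,582,912, roughly 12.6 MFLOP per run, and the total_mem fix matters because the timing loop can run the graph many times before a second has elapsed. A small illustrative check in C++ (the per-iteration byte count and iteration count are made-up numbers, not measurements):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // op_flops: 2 FLOPs per multiply-add, two matmuls (Q*K^T and P*V) per head.
        const uint64_t hs = 128, nh = 32, kv = 96, nb = 8;
        const uint64_t flops = 2 * 2 * nh * nb * hs * kv;   // 12582912 with these defaults

        // Bandwidth: accumulate bytes moved across *all* timed iterations so the
        // GB/s figure matches the accumulated total_time_us (the old code divided
        // a single iteration's bytes by the total elapsed time).
        const uint64_t mem_per_iter = 64ull * 1024 * 1024;  // assume 64 MiB moved per iteration
        const uint64_t iterations   = 40;                   // assume the 1 s loop ran 40 times
        const double   total_time_s = 2.0;
        const uint64_t total_mem    = mem_per_iter * iterations;
        const double   gbps = total_mem / total_time_s / 1024.0 / 1024.0 / 1024.0;
        printf("%llu flops/run, %.2f GB/s\n", (unsigned long long)flops, gbps); // 1.25 GB/s here
        return 0;
    }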
