@@ -1987,6 +1987,9 @@ struct llama_context {
1987
1987
std::vector<uint8_t> buf_compute_meta;
1988
1988
ggml_backend_sched_t sched = nullptr;
1989
1989
1990
+ ggml_abort_callback abort_callback = nullptr;
1991
+ void * abort_callback_data = nullptr;
1992
+
1990
1993
// input tensors
1991
1994
ggml_backend_buffer_t buf_input = nullptr;
1992
1995
ggml_context * ctx_input = nullptr;
@@ -8071,6 +8074,7 @@ static void llama_graph_compute(
8071
8074
8072
8075
if (lctx.backend_cpu != nullptr) {
8073
8076
ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
8077
+ ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
8074
8078
}
8075
8079
8076
8080
ggml_backend_sched_graph_compute(lctx.sched, gf);
@@ -11856,6 +11860,8 @@ struct llama_context_params llama_context_default_params() {
11856
11860
/*.embedding =*/ false,
11857
11861
/*.offload_kqv =*/ true,
11858
11862
/*.do_pooling =*/ true,
11863
+ /*.abort_callback =*/ nullptr,
11864
+ /*.abort_callback_data =*/ nullptr,
11859
11865
};
11860
11866
11861
11867
return result;
@@ -12038,8 +12044,11 @@ struct llama_context * llama_new_context_with_model(
12038
12044
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
12039
12045
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
12040
12046
12041
- ctx->rng = std::mt19937(params.seed);
12042
- ctx->logits_all = params.logits_all;
12047
+ ctx->abort_callback = params.abort_callback;
12048
+ ctx->abort_callback_data = params.abort_callback_data;
12049
+
12050
+ ctx->rng = std::mt19937(params.seed);
12051
+ ctx->logits_all = params.logits_all;
12043
12052
12044
12053
const ggml_type type_k = params.type_k;
12045
12054
const ggml_type type_v = params.type_v;
@@ -12989,6 +12998,11 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_
12989
12998
ctx->cparams.n_threads_batch = n_threads_batch;
12990
12999
}
12991
13000
13001
// Stores an abort callback and its opaque user-data pointer on the context,
// overwriting whatever was set before (including the values copied from
// llama_context_params at context creation).
//
// ctx                 - context to update; dereferenced without a null check.
// abort_callback      - function pointer taking the user-data pointer and
//                       returning bool; nullptr is also the field's default.
// abort_callback_data - pointer handed back to abort_callback; stored as-is.
//
// The stored pair is forwarded to the CPU backend by llama_graph_compute via
// ggml_backend_cpu_set_abort_callback (only when a CPU backend is present).
// NOTE(review): presumably a `true` return requests that the running graph
// computation stop early - confirm against the ggml backend documentation.
+ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
13002
+ ctx->abort_callback = abort_callback;
13003
+ ctx->abort_callback_data = abort_callback_data;
13004
+ }
13005
+
12992
13006
struct llama_batch llama_batch_get_one(
12993
13007
llama_token * tokens,
12994
13008
int32_t n_tokens,
0 commit comments