Skip to content

Commit d9e4870

Browse files
committed
wire to grpc
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent 01e2e3d commit d9e4870

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

backend/cpp/llama/grpc-server.cpp

+46
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,12 @@ struct llama_server_context
527527
}
528528
}
529529

530+
// Enable reranking if embeddings are enabled
531+
if (params.embedding) {
532+
params.rerank = true;
533+
LOG_INFO("Reranking enabled (embeddings are enabled)", {});
534+
}
535+
530536
common_init_result common_init = common_init_from_params(params);
531537
model = common_init.model.release();
532538
ctx = common_init.context.release();
@@ -2670,6 +2676,46 @@ class BackendServiceImpl final : public backend::Backend::Service {
26702676
return grpc::Status::OK;
26712677
}
26722678

2679+
grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
    // Build the task payload: the query is passed as the prompt; the
    // candidate documents and the requested top_n are forwarded verbatim.
    json data = {
        {"prompt", request->query()},
        {"documents", request->documents()},
        {"top_n", request->top_n()}
    };

    // Generate a new task ID and register it so the results queue keeps
    // the answer around until we collect it below.
    const int task_id = llama.queue_tasks.get_new_id();
    llama.queue_results.add_waiting_task_id(task_id);

    // Queue the task with reranking mode enabled (the final bool flag).
    llama.request_completion(task_id, data, false, false, true, -1);

    // Block until the worker produces a result, then release the waiting slot.
    task_result result = llama.queue_results.recv(task_id);
    llama.queue_results.remove_waiting_task_id(task_id);

    // Surface backend failures to the caller instead of silently returning
    // an empty-but-OK response (the original swallowed the error entirely).
    if (result.error) {
        return grpc::Status(grpc::StatusCode::INTERNAL,
                            result.result_json.value("message", "reranking failed"));
    }

    if (result.stop) {
        // Token accounting: the backend reports a single "tokens" count,
        // so prompt and total are set to the same value here.
        backend::Usage* usage = rerankResult->mutable_usage();
        usage->set_total_tokens(result.result_json.value("tokens", 0));
        usage->set_prompt_tokens(result.result_json.value("tokens", 0));

        // NOTE(review): the backend result carries one top-level "score",
        // so every document below receives the same relevance score.
        // Confirm whether result_json can provide per-document scores and,
        // if so, use those instead.
        float score = result.result_json.value("score", 0.0f);

        // Emit one DocumentResult per input document, echoing the text back.
        for (int i = 0; i < request->documents_size(); i++) {
            backend::DocumentResult* doc_result = rerankResult->add_results();
            doc_result->set_index(i);
            doc_result->set_text(request->documents(i));
            doc_result->set_relevance_score(score);
        }
    }

    return grpc::Status::OK;
}
2718+
26732719
grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
26742720
llama_client_slot* active_slot = llama.get_active_slot();
26752721

0 commit comments

Comments
 (0)