
Commit 0b2b89e

ggerganov and ngxson authored and committed
server : output embeddings for all tokens when pooling = none (ggml-org#10861)
* server : add "tokens" output
* server : output embeddings for all tokens when pooling = none
* server : update readme [no ci]
* server : fix spacing [no ci]
* server : be explicit about the pooling type in the tests
* server : update /embeddings and /v1/embeddings endpoints
* server : do not normalize embeddings when there is no pooling
* server : update readme
* server : fixes
* tests : update server tests
* server : update readme [no ci]
* server : remove rebase artifact

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
1 parent 6d7f216 commit 0b2b89e

File tree

8 files changed: +158 −37 lines


common/common.cpp

+3 −1

```diff
@@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
             break;
         case 0: // max absolute
             for (int i = 0; i < n; i++) {
-                if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+                if (sum < std::abs(inp[i])) {
+                    sum = std::abs(inp[i]);
+                }
             }
             sum /= 32760.0; // make an int16 range
             break;
```
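For orientation, here is a minimal Python sketch of the two `embd_norm` modes visible in this change. The actual implementation is `common_embd_normalize` in `common/common.cpp`; this sketch only mirrors the semantics shown in the diff, and the error for other modes is an assumption for illustration:

```python
import math

def embd_normalize(inp: list[float], embd_norm: int) -> list[float]:
    # sketch only: mirrors the semantics visible in the diff above
    if embd_norm == 0:
        # max absolute: scale so the largest component maps into an int16 range
        s = max(abs(x) for x in inp) / 32760.0
    elif embd_norm == 2:
        # Euclidean (L2) norm: the mode the call sites below now pass explicitly
        s = math.sqrt(sum(x * x for x in inp))
    else:
        raise NotImplementedError("common.cpp supports further modes not sketched here")
    return [x / s for x in inp] if s > 0 else list(inp)
```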

common/common.h

+2 −1

```diff
@@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
 // Embedding utils
 //
 
-void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+// TODO: replace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
 
 float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
```

examples/gritlm/gritlm.cpp

+1 −1

```diff
@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     }
 
         std::vector<float> emb_norm(emb_unorm.size());
-        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
+        common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
         result.push_back(emb_norm);
 
 #ifdef GRIT_DEBUG
```

examples/retrieval/retrieval.cpp

+1 −1

```diff
@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        common_embd_normalize(embd, out, n_embd);
+        common_embd_normalize(embd, out, n_embd, 2);
     }
 }
```

examples/server/README.md

+42 −0

````diff
@@ -763,6 +763,8 @@ curl http://localhost:8080/v1/chat/completions \
 
 ### POST `/v1/embeddings`: OpenAI-compatible embeddings API
 
+This endpoint requires that the model uses a pooling type other than `none`. The embeddings are normalized using the Euclidean norm.
+
 *Options:*
 
 See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
@@ -795,6 +797,46 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
 }'
 ```
 
+### POST `/embeddings`: non-OpenAI-compatible embeddings API
+
+This endpoint supports all pooling types, including `--pooling none`. When the pooling type is `none`, the responses contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using the Euclidean norm.
+
+Note that the response format of this endpoint is different from `/v1/embeddings`.
+
+*Options:*
+
+Same as the `/v1/embeddings` endpoint.
+
+*Examples:*
+
+Same as the `/v1/embeddings` endpoint.
+
+**Response format**
+
+```json
+[
+  {
+    "index": 0,
+    "embedding": [
+      [ ... embeddings for token 0 ... ],
+      [ ... embeddings for token 1 ... ],
+      [ ... ],
+      [ ... embeddings for token N-1 ... ]
+    ]
+  },
+  ...
+  {
+    "index": P,
+    "embedding": [
+      [ ... embeddings for token 0 ... ],
+      [ ... embeddings for token 1 ... ],
+      [ ... ],
+      [ ... embeddings for token N-1 ... ]
+    ]
+  }
+]
+```
+
 ### GET `/slots`: Returns the current slots processing state
 
 > [!WARNING]
````
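To make the new endpoint concrete, here is a hypothetical client snippet. It assumes a local `llama-server` started with `--embedding --pooling none` on port 8080; the startup flags and host are assumptions, not part of this commit:

```python
import requests

# assumption: llama-server is running locally with --embedding --pooling none
res = requests.post("http://localhost:8080/embeddings",
                    json={"input": "hello hello hello"})
res.raise_for_status()

for item in res.json():
    token_vectors = item["embedding"]  # one unnormalized vector per input token
    print(f"prompt {item['index']}: {len(token_vectors)} tokens, "
          f"dim {len(token_vectors[0])}")
```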

examples/server/server.cpp

+56 −18

```diff
@@ -726,18 +726,32 @@ struct server_task_result_cmpl_partial : server_task_result {
 
 struct server_task_result_embd : server_task_result {
     int index = 0;
-    std::vector<float> embedding;
+    std::vector<std::vector<float>> embedding;
 
     int32_t n_tokens;
 
+    // OAI-compat fields
+    bool oaicompat = false;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
+
+    json to_json_non_oaicompat() {
+        return json {
+            {"index",     index},
+            {"embedding", embedding},
+        };
+    }
+
+    json to_json_oaicompat() {
         return json {
             {"index",            index},
-            {"embedding",        embedding},
+            {"embedding",        embedding[0]},
             {"tokens_evaluated", n_tokens},
         };
     }
@@ -2017,9 +2031,10 @@ struct server_context {
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
-        res->id    = slot.id_task;
-        res->index = slot.index;
-        res->n_tokens = slot.n_prompt_tokens;
+        res->id        = slot.id_task;
+        res->index     = slot.index;
+        res->n_tokens  = slot.n_prompt_tokens;
+        res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_n_embd(model);
 
@@ -2038,12 +2053,18 @@ struct server_context {
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res->embedding = std::vector<float>(n_embd, 0.0f);
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
             }
 
-            common_embd_normalize(embd, embd_res.data(), n_embd);
-            res->embedding = embd_res;
+            // normalize only when there is pooling
+            // TODO: configurable
+            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                res->embedding.push_back(embd_res);
+            } else {
+                res->embedding.push_back({ embd, embd + n_embd });
+            }
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2657,7 +2678,10 @@ struct server_context {
 
         // add prompt tokens for processing in the current batch
         while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-            common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
+            // without pooling, we want to output the embeddings for all the tokens in the batch
+            const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+            common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 
             if (slot.params.cache_prompt) {
                 slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3665,14 +3689,17 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
         const json body = json::parse(req.body);
-        bool oaicompat = false;
+
+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
 
         // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.contains("input")) {
-            oaicompat = true;
+        if (body.count("input") != 0) {
             prompt = body.at("input");
         } else if (body.contains("content")) {
             oaicompat = false;
@@ -3697,10 +3724,15 @@ int main(int argc, char ** argv) {
         {
             std::vector<server_task> tasks;
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-                server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+                server_task task   = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
                 task.id            = ctx_server.queue_tasks.get_new_id();
                 task.index         = i;
                 task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+                // OAI-compat
+                task.params.oaicompat = oaicompat;
+
                 tasks.push_back(task);
             }
@@ -3728,12 +3760,18 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses.size() == 1 ? responses[0] : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
         res_ok(res, root);
     };
 
+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@@ -3907,7 +3945,7 @@ int main(int argc, char ** argv) {
     svr->Post("/infill",        handle_infill);
     svr->Post("/embedding",     handle_embeddings); // legacy
     svr->Post("/embeddings",    handle_embeddings);
-    svr->Post("/v1/embeddings", handle_embeddings);
+    svr->Post("/v1/embeddings", handle_embeddings_oai);
    svr->Post("/rerank",        handle_rerank);
    svr->Post("/reranking",     handle_rerank);
    svr->Post("/v1/rerank",     handle_rerank);
```

examples/server/tests/unit/test_embedding.py

+50 −15

```diff
@@ -14,8 +14,9 @@ def create_server():
 
 def test_embedding_single():
     global server
+    server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "I believe the meaning of life is",
     })
     assert res.status_code == 200
@@ -29,8 +30,9 @@ def test_embedding_single():
 
 def test_embedding_multiple():
     global server
+    server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "Write a joke about AI from a very long prompt which will not be truncated",
@@ -46,7 +48,7 @@ def test_embedding_multiple():
 
 
 @pytest.mark.parametrize(
-    "content,is_multi_prompt",
+    "input,is_multi_prompt",
     [
         # single prompt
         ("string", False),
@@ -59,34 +61,65 @@ def test_embedding_multiple():
         ([[12, 34, 56], [12, "string", 34, 56]], True),
     ]
 )
-def test_embedding_mixed_input(content, is_multi_prompt: bool):
+def test_embedding_mixed_input(input, is_multi_prompt: bool):
     global server
     server.start()
-    res = server.make_request("POST", "/embeddings", data={"content": content})
+    res = server.make_request("POST", "/v1/embeddings", data={"input": input})
     assert res.status_code == 200
+    data = res.body['data']
     if is_multi_prompt:
-        assert len(res.body) == len(content)
-        for d in res.body:
+        assert len(data) == len(input)
+        for d in data:
             assert 'embedding' in d
             assert len(d['embedding']) > 1
     else:
-        assert 'embedding' in res.body
-        assert len(res.body['embedding']) > 1
+        assert 'embedding' in data[0]
+        assert len(data[0]['embedding']) > 1
+
+
+def test_embedding_pooling_none():
+    global server
+    server.pooling = 'none'
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "hello hello hello",
+    })
+    assert res.status_code == 200
+    assert 'embedding' in res.body[0]
+    assert len(res.body[0]['embedding']) == 5  # 3 text tokens + 2 special
+
+    # make sure embedding vector is not normalized
+    for emb in res.body[0]['embedding']:
+        assert abs(sum([x ** 2 for x in emb]) - 1) > EPSILON
+
+
+def test_embedding_pooling_none_oai():
+    global server
+    server.pooling = 'none'
+    server.start()
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": "hello hello hello",
+    })
+
+    # /v1/embeddings does not support pooling type 'none'
+    assert res.status_code == 400
 
 
 def test_embedding_openai_library_single():
     global server
+    server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
     assert len(res.data) == 1
     assert len(res.data[0].embedding) > 1
 
 
 def test_embedding_openai_library_multiple():
     global server
+    server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input=[
         "I believe the meaning of life is",
         "Write a joke about AI from a very long prompt which will not be truncated",
@@ -100,17 +133,19 @@ def test_embedding_openai_library_multiple():
 
 def test_embedding_error_prompt_too_long():
     global server
+    server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "This is a test " * 512,
     })
     assert res.status_code != 200
     assert "too large" in res.body["error"]["message"]
 
 
 def test_same_prompt_give_same_result():
+    server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "I believe the meaning of life is",
@@ -138,7 +173,7 @@ def test_same_prompt_give_same_result():
 def test_embedding_usage_single(content, n_tokens):
     global server
     server.start()
-    res = server.make_request("POST", "/embeddings", data={"input": content})
+    res = server.make_request("POST", "/v1/embeddings", data={"input": content})
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
     assert res.body['usage']['prompt_tokens'] == n_tokens
@@ -147,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
 def test_embedding_usage_multiple():
     global server
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "I believe the meaning of life is",
```
