Skip to content

Commit 2a5510e

Browse files
committed
tests : update server tests
ggml-ci
1 parent 87df601 commit 2a5510e

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

examples/server/server.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -744,16 +744,16 @@ struct server_task_result_embd : server_task_result {
744744

745745
json to_json_non_oaicompat() {
746746
return json {
747-
{"index", index},
748-
{"embedding", embedding},
749-
{"tokens_evaluated", n_tokens},
747+
{"index", index},
748+
{"embedding", embedding},
750749
};
751750
}
752751

753752
json to_json_oaicompat() {
754753
return json {
755-
{"index", index},
756-
{"embedding", embedding[0]},
754+
{"index", index},
755+
{"embedding", embedding[0]},
756+
{"tokens_evaluated", n_tokens},
757757
};
758758
}
759759
};

examples/server/tests/unit/test_embedding.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_embedding_multiple():
4848

4949

5050
@pytest.mark.parametrize(
51-
"content,is_multi_prompt",
51+
"input,is_multi_prompt",
5252
[
5353
# single prompt
5454
("string", False),
@@ -61,19 +61,20 @@ def test_embedding_multiple():
6161
([[12, 34, 56], [12, "string", 34, 56]], True),
6262
]
6363
)
64-
def test_embedding_mixed_input(content, is_multi_prompt: bool):
64+
def test_embedding_mixed_input(input, is_multi_prompt: bool):
6565
global server
6666
server.start()
67-
res = server.make_request("POST", "/embeddings", data={"content": content})
67+
res = server.make_request("POST", "/v1/embeddings", data={"input": input})
6868
assert res.status_code == 200
69+
data = res.body['data']
6970
if is_multi_prompt:
70-
assert len(res.body) == len(content)
71-
for d in res.body:
71+
assert len(data) == len(input)
72+
for d in data:
7273
assert 'embedding' in d
7374
assert len(d['embedding']) > 1
7475
else:
75-
assert 'embedding' in res.body
76-
assert len(res.body['embedding']) > 1
76+
assert 'embedding' in data[0]
77+
assert len(data[0]['embedding']) > 1
7778

7879

7980
def test_embedding_pooling_none():
@@ -85,7 +86,7 @@ def test_embedding_pooling_none():
8586
})
8687
assert res.status_code == 200
8788
assert 'embedding' in res.body[0]
88-
assert len(res.body[0]['embedding']) == 3
89+
assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
8990

9091
# make sure embedding vector is not normalized
9192
for x in res.body[0]['embedding']:
@@ -172,7 +173,7 @@ def test_same_prompt_give_same_result():
172173
def test_embedding_usage_single(content, n_tokens):
173174
global server
174175
server.start()
175-
res = server.make_request("POST", "/embeddings", data={"input": content})
176+
res = server.make_request("POST", "/v1/embeddings", data={"input": content})
176177
assert res.status_code == 200
177178
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
178179
assert res.body['usage']['prompt_tokens'] == n_tokens
@@ -181,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
181182
def test_embedding_usage_multiple():
182183
global server
183184
server.start()
184-
res = server.make_request("POST", "/embeddings", data={
185+
res = server.make_request("POST", "/v1/embeddings", data={
185186
"input": [
186187
"I believe the meaning of life is",
187188
"I believe the meaning of life is",

0 commit comments

Comments
 (0)