@@ -48,7 +48,7 @@ def test_embedding_multiple():
48
48
49
49
50
50
@pytest .mark .parametrize (
51
- "content ,is_multi_prompt" ,
51
+ "input ,is_multi_prompt" ,
52
52
[
53
53
# single prompt
54
54
("string" , False ),
@@ -61,19 +61,20 @@ def test_embedding_multiple():
61
61
([[12 , 34 , 56 ], [12 , "string" , 34 , 56 ]], True ),
62
62
]
63
63
)
64
- def test_embedding_mixed_input (content , is_multi_prompt : bool ):
64
+ def test_embedding_mixed_input (input , is_multi_prompt : bool ):
65
65
global server
66
66
server .start ()
67
- res = server .make_request ("POST" , "/embeddings" , data = {"content " : content })
67
+ res = server .make_request ("POST" , "/v1/ embeddings" , data = {"input " : input })
68
68
assert res .status_code == 200
69
+ data = res .body ['data' ]
69
70
if is_multi_prompt :
70
- assert len (res . body ) == len (content )
71
- for d in res . body :
71
+ assert len (data ) == len (input )
72
+ for d in data :
72
73
assert 'embedding' in d
73
74
assert len (d ['embedding' ]) > 1
74
75
else :
75
- assert 'embedding' in res . body
76
- assert len (res . body ['embedding' ]) > 1
76
+ assert 'embedding' in data [ 0 ]
77
+ assert len (data [ 0 ] ['embedding' ]) > 1
77
78
78
79
79
80
def test_embedding_pooling_none ():
@@ -85,7 +86,7 @@ def test_embedding_pooling_none():
85
86
})
86
87
assert res .status_code == 200
87
88
assert 'embedding' in res .body [0 ]
88
- assert len (res .body [0 ]['embedding' ]) == 3
89
+ assert len (res .body [0 ]['embedding' ]) == 5 # 3 text tokens + 2 special
89
90
90
91
# make sure embedding vector is not normalized
91
92
for x in res .body [0 ]['embedding' ]:
@@ -172,7 +173,7 @@ def test_same_prompt_give_same_result():
172
173
def test_embedding_usage_single (content , n_tokens ):
173
174
global server
174
175
server .start ()
175
- res = server .make_request ("POST" , "/embeddings" , data = {"input" : content })
176
+ res = server .make_request ("POST" , "/v1/ embeddings" , data = {"input" : content })
176
177
assert res .status_code == 200
177
178
assert res .body ['usage' ]['prompt_tokens' ] == res .body ['usage' ]['total_tokens' ]
178
179
assert res .body ['usage' ]['prompt_tokens' ] == n_tokens
@@ -181,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
181
182
def test_embedding_usage_multiple ():
182
183
global server
183
184
server .start ()
184
- res = server .make_request ("POST" , "/embeddings" , data = {
185
+ res = server .make_request ("POST" , "/v1/ embeddings" , data = {
185
186
"input" : [
186
187
"I believe the meaning of life is" ,
187
188
"I believe the meaning of life is" ,
0 commit comments