server : fix logprobs, make it OAI-compatible #10783
Changes from 8 commits
```diff
@@ -416,6 +416,7 @@ inline std::string stop_type_to_str(stop_type type) {

 struct completion_token_output {
     llama_token tok;
+    float prob;
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
```
```diff
@@ -427,25 +428,46 @@ struct completion_token_output {
     json to_json() const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
+            std::string tok_str(p.tok_str);
+            tok_str.resize(validate_utf8(tok_str));
             probs_for_token.push_back(json {
-                {"tok_str", p.tok_str},
-                {"prob",    p.prob},
+                {"id",      p.tok},
+                {"token",   tok_str},
+                {"bytes",   str_to_bytes(p.tok_str)},
+                {"logprob", logarithm(p.prob)},
             });
         }
         return probs_for_token;
     }

     static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
         json out = json::array();
-        for (const auto & prob : probs) {
-            const std::string tok_str = prob.text_to_send;
+        for (const auto & it : probs) {
+            std::string tok_str(it.text_to_send);
+            tok_str.resize(validate_utf8(tok_str));
             out.push_back(json {
-                {"content", tok_str},
-                {"probs",   prob.to_json()},
+                {"id",           it.tok},
+                {"token",        tok_str},
+                {"logprob",      logarithm(it.prob)},
+                {"bytes",        str_to_bytes(it.text_to_send)},
+                {"top_logprobs", it.to_json()},
             });
         }
         return out;
     }

+    static float logarithm(float x) {
+        // nlohmann::json converts -inf to null, so we need to prevent that
+        return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
+    }
+
+    static std::vector<unsigned char> str_to_bytes(const std::string & str) {
+        std::vector<unsigned char> bytes;
+        for (unsigned char c : str) {
+            bytes.push_back(c);
+        }
+        return bytes;
+    }
 };

 struct server_task_result_cmpl_final : server_task_result {
```
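A note on the new serialization (not part of the diff itself): each array entry now mirrors an OpenAI `logprobs` content entry. Below is a sketch of the shape produced by `probs_vector_to_json`, with invented values; note that `logarithm` clamps `log(0)` to `std::numeric_limits<float>::lowest()` because nlohmann::json would otherwise serialize `-inf` as `null`.

```cpp
// Illustrative sketch only: the shape of one entry emitted by
// completion_token_output::probs_vector_to_json(). All values are invented.
#include <nlohmann/json.hpp>
using json = nlohmann::json;

json example_logprob_entry() {
    return json {
        {"id",      450},                       // llama_token id of the sampled token
        {"token",   " the"},                    // UTF-8-validated text of the token
        {"logprob", -0.0523},                   // logarithm(prob), never -inf
        {"bytes",   {0x20, 0x74, 0x68, 0x65}},  // raw bytes, useful for partial UTF-8 tokens
        {"top_logprobs", json::array({          // top n_probs candidates, via to_json()
            {{"id", 450}, {"token", " the"}, {"logprob", -0.0523}, {"bytes", {0x20, 0x74, 0x68, 0x65}}},
            {{"id", 264}, {"token", " a"},   {"logprob", -3.1142}, {"bytes", {0x20, 0x61}}},
        })},
    };
}
```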
```diff
@@ -506,7 +528,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"tokens_cached", n_tokens_cached},
             {"timings",       timings.to_json()},
         };
-        if (!probs_output.empty()) {
+        if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
         }
         return res;
```
```diff
@@ -518,19 +540,25 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }

-        json choices = json::array({json{
+        json choice = json{
             {"finish_reason", finish_reason},
             {"index", 0},
             {"message", json{
                 {"content", content},
                 {"role",    "assistant"}
             }
-        }}});
+        }};
+
+        if (!stream && probs_output.size() > 0) {
+            choice["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json(probs_output)},
+            };
+        }

         std::time_t t = std::time(0);

         json res = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"model",   oaicompat_model},
             {"object",  "chat.completion"},
```
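With the change above, the non-streaming OAI-compatible response embeds `logprobs` inside each choice, matching OpenAI's layout. A sketch of the resulting body (values invented, entries shaped as in the earlier sketch):

```cpp
// Illustrative sketch only: where "logprobs" lands in the final,
// non-streaming chat.completion body. All values are invented.
json example_final_response() {
    return json {
        {"choices", json::array({ json{
            {"finish_reason", "stop"},
            {"index", 0},
            {"message", json{{"content", "Hello!"}, {"role", "assistant"}}},
            {"logprobs", json{
                {"content", json::array()},  // per-token entries, as sketched earlier
            }},
        }})},
        {"created", 1734000000},             // std::time(0) at response time
        {"model",   "model-name"},           // oaicompat_model
        {"object",  "chat.completion"},
    };
}
```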
```diff
@@ -560,12 +588,14 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }

-        json choices = json::array({json{{"finish_reason", finish_reason},
-                                         {"index", 0},
-                                         {"delta", json::object()}}});
+        json choice = json{
+            {"finish_reason", finish_reason},
+            {"index", 0},
+            {"delta", json::object()}
+        };

         json ret = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"id",      oaicompat_cmpl_id},
             {"model",   oaicompat_model},
```
```diff
@@ -592,7 +622,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;

-    std::vector<completion_token_output> probs_output;
+    completion_token_output prob_output;
     result_timings timings;

     // OAI-compat fields
```
```diff
@@ -628,8 +658,8 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n > 0) {
             res.push_back({"timings", timings.to_json()});
         }
-        if (!probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        if (!prob_output.probs.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output});
         }
         return res;
     }
```
```diff
@@ -681,6 +711,14 @@ struct server_task_result_cmpl_partial : server_task_result {
             }});
         }

+        GGML_ASSERT(choices.size() >= 1);
+
+        if (prob_output.probs.size() > 0) {
+            choices[0]["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json({prob_output})},
+            };
+        }
+
         json ret = json {
             {"choices", choices},
             {"created", t},
```
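In streaming mode, the equivalent object is attached to `choices[0]` of each partial response, one sampled token per chunk. A sketch of one such chunk (values invented):

```cpp
// Illustrative sketch only: one streamed chat.completion.chunk with
// per-token logprobs attached to the only choice. All values are invented.
json example_chunk() {
    return json {
        {"choices", json::array({ json{
            {"finish_reason", nullptr},
            {"index", 0},
            {"delta", json{{"content", " the"}}},
            {"logprobs", json{
                {"content", json::array()},  // a single entry for this token
            }},
        }})},
        {"created", 1734000000},
        {"object",  "chat.completion.chunk"},
    };
}
```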
```diff
@@ -951,7 +989,6 @@ struct server_slot {

     // stats
     size_t n_sent_text = 0; // number of sent text character
-    size_t n_sent_token_probs = 0;

     int64_t t_start_process_prompt;
     int64_t t_start_generation;
```
```diff
@@ -973,7 +1010,6 @@ struct server_slot {
         stopping_word = "";
         n_past        = 0;
         n_sent_text   = 0;
-        n_sent_token_probs = 0;
         task_type = SERVER_TASK_TYPE_COMPLETION;

         generated_token_probs.clear();
```
```diff
@@ -1713,34 +1749,15 @@ struct server_context {
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
+        const std::string token_str = result.text_to_send;
         slot.sampled = result.tok;

         // search stop word and delete it
         slot.generated_text += token_str;
         slot.has_next_token = true;

         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
```
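`validate_utf8` itself is not shown in this diff; judging from its uses here and in `to_json` above, it presumably returns the length of the longest prefix of the string that ends on a complete UTF-8 sequence. A minimal sketch under that assumption, following the same byte checks as the removed loop:

```cpp
// Sketch of the assumed semantics of validate_utf8 (the helper is not part of
// this diff): return the number of leading bytes of `text` that form complete
// UTF-8 sequences, excluding a trailing incomplete multi-byte sequence.
#include <cstddef>
#include <string>

static size_t validate_utf8_sketch(const std::string & text) {
    const size_t len = text.size();
    // scan backwards over at most 4 bytes, looking for the start of a sequence
    for (size_t i = 1; i <= 4 && i <= len; ++i) {
        unsigned char c = text[len - i];
        if ((c & 0xC0) == 0x80) {
            continue; // continuation byte: 10xxxxxx
        }
        size_t seq_len = 1;
        if      ((c & 0xE0) == 0xC0) { seq_len = 2; } // 110xxxxx
        else if ((c & 0xF0) == 0xE0) { seq_len = 3; } // 1110xxxx
        else if ((c & 0xF8) == 0xF0) { seq_len = 4; } // 11110xxx
        if (seq_len > i) {
            return len - i; // trailing sequence is incomplete: cut before it
        }
        break; // trailing sequence is complete (or a 1-byte/invalid byte)
    }
    return len;
}
```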
```diff
@@ -1869,6 +1886,32 @@ struct server_context {
         return slot.has_next_token; // continue
     }

+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
+        std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+        int n_vocab = llama_n_vocab(llama_get_model(ctx));
+        size_t n_probs = slot.params.sampling.n_probs;
+
+        bool found_sampled_tok = false;
+        result.probs.reserve(n_probs);
+        for (int i = 0; i < n_vocab; i++) {
+            // set probability for sampled token
+            if (cur[i].id == result.tok) {
+                found_sampled_tok = true;
+                result.prob = cur[i].p;
+            }
+            // set probability for top n_probs tokens
+            result.probs.push_back({
+                cur[i].id,
+                common_detokenize(ctx, {cur[i].id}, special),
+                cur[i].p
+            });
+            // break if we have all the necessary data
+            if (result.probs.size() == n_probs && found_sampled_tok) {
+                break;
+            }
+        }
+    }
+
     void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         send_error(task.id, error, type);
     }
```
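`get_token_probabilities` is not visible in this hunk either. For the early `break` above to yield the top `n_probs` candidates, it must return the full-vocabulary distribution sorted by probability in descending order; presumably it applies a softmax over the logits at batch position `idx`. A sketch under that assumption (the helper name and body are assumptions; only public llama.cpp API calls are used):

```cpp
// Sketch of the assumed behaviour of get_token_probabilities (added elsewhere
// in this PR): softmax over the logits at batch position `idx`, returned for
// the full vocabulary, sorted so the first n entries are the top-n candidates.
#include <algorithm>
#include <cmath>
#include <vector>

#include "llama.h"

static std::vector<llama_token_data> get_token_probabilities_sketch(llama_context * ctx, int idx) {
    const float * logits  = llama_get_logits_ith(ctx, idx);
    const int     n_vocab = llama_n_vocab(llama_get_model(ctx));

    std::vector<llama_token_data> cur(n_vocab);
    for (llama_token tok = 0; tok < n_vocab; tok++) {
        cur[tok] = llama_token_data{tok, logits[tok], 0.0f};
    }

    // sorting by logit is equivalent to sorting by probability
    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
        return a.logit > b.logit;
    });

    // numerically stable softmax: subtract the max logit before exponentiating
    const float max_logit = cur[0].logit;
    float cum_sum = 0.0f;
    for (auto & td : cur) {
        td.p = std::exp(td.logit - max_logit);
        cum_sum += td.p;
    }
    for (auto & td : cur) {
        td.p /= cum_sum;
    }
    return cur;
}
```

Note the cost implied by this design: a full-vocabulary sort per sampled token, which only runs when the request asks for probabilities (`n_probs > 0`).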
```diff
@@ -1906,17 +1949,7 @@ struct server_context {

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
-            const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
-
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs,                       slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                res->probs_output = std::vector<completion_token_output>(
-                        slot.generated_token_probs.begin() + probs_pos,
-                        slot.generated_token_probs.begin() + probs_stop_pos);
-            }
+            res->prob_output = tkn; // copy the token probs
         }

         // populate timings if this is final response or timings_per_token is enabled
```
```diff
@@ -2728,7 +2761,9 @@ struct server_context {
                 continue; // continue loop of slots
             }

-            llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+            const int tok_idx = slot.i_batch - i;
+
+            llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);

             slot.i_batch = -1;
```
```diff
@@ -2747,17 +2782,12 @@ struct server_context {
                 slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

                 completion_token_output result;
-                result.tok = id;
-
-                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+                result.tok          = id;
+                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.prob         = 1.0f; // set later

-                for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
-                    auto tok_id = cur_p->data[i].id;
-                    result.probs.push_back({
-                        tok_id,
-                        tokens_to_output_formatted_string(ctx, tok_id),
-                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                    });
+                if (slot.params.sampling.n_probs > 0) {
+                    populate_token_probs(slot, result, params_base.special, tok_idx);
                 }

                 if (!process_token(result, slot)) {
```
```diff
@@ -2841,7 +2871,11 @@ struct server_context {
             for (size_t i = 0; i < ids.size(); ++i) {
                 completion_token_output result;

-                result.tok = ids[i];
+                result.tok          = ids[i];
+                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.prob         = 1.0f; // set later
+
+                // TODO: set result.probs

                 if (!process_token(result, slot)) {
                     // release slot because of stop condition
```

Review comment on this hunk: Here is the branch for speculative decoding. I'm not sure how I can get token probs here. Could you give me some clues? @ggerganov (Or we can skip this for now if it's complicated)

Review comment in reply: Think we need to update …