
server : fix logprobs, make it OAI-compatible #10783


Merged (21 commits) on Dec 19, 2024
Changes from 8 commits
39 changes: 32 additions & 7 deletions examples/server/README.md
@@ -343,6 +343,10 @@ node index.js

### POST `/completion`: Given a `prompt`, it returns the predicted completion.

> [!IMPORTANT]
>
> This endpoint is **not** OAI-compatible

*Options:*

`prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, if `cache_prompt` is `true`, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. A `BOS` token is inserted at the start, if all of the following conditions are true:
@@ -448,27 +452,48 @@ These words will not be included in the completion, so make sure to add them to

- Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion.

- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`, which contains at most `n_probs` elements (see the parsing sketch below):

```json
{
  "content": "<the generated completion text>",
  ...
  "completion_probabilities": [
    {
      "id": <token id>,
      "logprob": float,
      "token": "<the token selected by the model>",
      "bytes": [int, int, ...],
      "top_logprobs": [
        {
          "id": <token id>,
          "logprob": float,
          "token": "<most likely token>",
          "bytes": [int, int, ...]
        },
        {
          "id": <token id>,
          "logprob": float,
          "token": "<second most likely token>",
          "bytes": [int, int, ...]
        },
        ...
      ]
    },
    {
      "id": <token id>,
      "logprob": float,
      "token": "<the token selected by the model>",
      "bytes": [int, int, ...],
      "top_logprobs": [
        ...
      ]
    },
    ...
  ]
}
```

Note that the outer array has one entry per generated token, while each nested `top_logprobs` array contains at most `n_probs` entries.

- `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
- `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
- `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
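For illustration, here is a minimal client-side sketch of reading the structure documented above with nlohmann::json (the JSON library the server itself uses). The payload, token ids, and values are made up for the example, and none of this code is part of the PR; it only assumes `nlohmann/json.hpp` is available.

```cpp
// Sketch only: parse a hypothetical /completion response and print each
// generated token's logprob together with its top alternatives.
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
    // made-up response fragment shaped like the documented format
    const std::string body = R"({
        "content": "Hello",
        "completion_probabilities": [
            {
                "id": 15043, "logprob": -0.12, "token": "Hello",
                "bytes": [72, 101, 108, 108, 111],
                "top_logprobs": [
                    { "id": 15043, "logprob": -0.12, "token": "Hello", "bytes": [72, 101, 108, 108, 111] },
                    { "id": 10994, "logprob": -2.34, "token": "Hi",    "bytes": [72, 105] }
                ]
            }
        ]
    })";

    const nlohmann::json res = nlohmann::json::parse(body);
    for (const auto & tok : res.at("completion_probabilities")) {
        std::cout << tok.at("token").get<std::string>()
                  << " logprob=" << tok.at("logprob").get<double>() << "\n";
        for (const auto & alt : tok.at("top_logprobs")) {
            std::cout << "  candidate: " << alt.at("token").get<std::string>()
                      << " logprob=" << alt.at("logprob").get<double>() << "\n";
        }
    }
    return 0;
}
```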
160 changes: 97 additions & 63 deletions examples/server/server.cpp
@@ -416,6 +416,7 @@ inline std::string stop_type_to_str(stop_type type) {

struct completion_token_output {
llama_token tok;
float prob;
std::string text_to_send;
struct token_prob {
llama_token tok;
@@ -427,25 +428,46 @@ struct completion_token_output {
json to_json() const {
json probs_for_token = json::array();
for (const auto & p : probs) {
std::string tok_str(p.tok_str);
tok_str.resize(validate_utf8(tok_str));
probs_for_token.push_back(json {
{"tok_str", p.tok_str},
{"prob", p.prob},
{"id", p.tok},
{"token", tok_str},
{"bytes", str_to_bytes(p.tok_str)},
{"logprob", logarithm(p.prob)},
});
}
return probs_for_token;
}

static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
json out = json::array();
for (const auto & prob : probs) {
const std::string tok_str = prob.text_to_send;
for (const auto & it : probs) {
std::string tok_str(it.text_to_send);
tok_str.resize(validate_utf8(tok_str));
out.push_back(json {
{"content", tok_str},
{"probs", prob.to_json()},
{"id", it.tok},
{"token", tok_str},
{"logprob", logarithm(it.prob)},
{"bytes", str_to_bytes(it.text_to_send)},
{"top_logprobs", it.to_json()},
});
}
return out;
}

static float logarithm(float x) {
// nlohmann::json converts -inf to null, so we need to prevent that
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
}

static std::vector<unsigned char> str_to_bytes(const std::string & str) {
std::vector<unsigned char> bytes;
for (unsigned char c : str) {
bytes.push_back(c);
}
return bytes;
}
};
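A side note on the `logarithm` helper above: as its comment says, nlohmann::json serializes non-finite numbers as `null`, so `log(0)` would otherwise turn the `logprob` field into `null` in the output. A tiny standalone check of that behaviour (not part of the patch, and assuming `nlohmann/json.hpp` is available) could look like this:

```cpp
#include <cmath>
#include <iostream>
#include <limits>
#include <nlohmann/json.hpp>

int main() {
    // log(0.0f) is -infinity; nlohmann::json dumps non-finite numbers as null
    nlohmann::json raw = std::log(0.0f);
    std::cout << raw.dump() << "\n";      // prints: null

    // clamping to the lowest finite float keeps the field numeric,
    // which is what the logarithm() helper above does
    nlohmann::json clamped = std::numeric_limits<float>::lowest();
    std::cout << clamped.dump() << "\n";  // prints a large negative number
}
```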

struct server_task_result_cmpl_final : server_task_result {
@@ -506,7 +528,7 @@ struct server_task_result_cmpl_final : server_task_result {
{"tokens_cached", n_tokens_cached},
{"timings", timings.to_json()},
};
if (!stream && !probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
}
return res;
@@ -518,19 +540,25 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop";
}

json choice = json{
{"finish_reason", finish_reason},
{"index", 0},
{"message", json{
{"content", content},
{"role", "assistant"}
}
}};

if (!stream && probs_output.size() > 0) {
choice["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json(probs_output)},
};
}

std::time_t t = std::time(0);

json res = json {
{"choices", choices},
{"choices", json::array({choice})},
{"created", t},
{"model", oaicompat_model},
{"object", "chat.completion"},
@@ -560,12 +588,14 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop";
}

json choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
json choice = json{
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}
};

json ret = json {
{"choices", choices},
{"choices", json::array({choice})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
@@ -592,7 +622,7 @@ struct server_task_result_cmpl_partial : server_task_result {
int32_t n_decoded;
int32_t n_prompt_tokens;

completion_token_output prob_output;
result_timings timings;

// OAI-compat fields
@@ -628,8 +658,8 @@ struct server_task_result_cmpl_partial : server_task_result {
if (timings.prompt_n > 0) {
res.push_back({"timings", timings.to_json()});
}
if (!prob_output.probs.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output});
}
return res;
}
@@ -681,6 +711,14 @@ struct server_task_result_cmpl_partial : server_task_result {
}});
}

GGML_ASSERT(choices.size() >= 1);

if (prob_output.probs.size() > 0) {
choices[0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output})},
};
}

json ret = json {
{"choices", choices},
{"created", t},
@@ -951,7 +989,6 @@ struct server_slot {

// stats
size_t n_sent_text = 0; // number of sent text character

int64_t t_start_process_prompt;
int64_t t_start_generation;
@@ -973,7 +1010,6 @@ struct server_slot {
stopping_word = "";
n_past = 0;
n_sent_text = 0;
task_type = SERVER_TASK_TYPE_COMPLETION;

generated_token_probs.clear();
@@ -1713,34 +1749,15 @@ struct server_context {

bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
slot.sampled = result.tok;

// search stop word and delete it
slot.generated_text += token_str;
slot.has_next_token = true;

// check if there is incomplete UTF-8 character at the end
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

if (!incomplete) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
@@ -1869,6 +1886,32 @@ struct server_context {
return slot.has_next_token; // continue
}

void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
int n_vocab = llama_n_vocab(llama_get_model(ctx));
size_t n_probs = slot.params.sampling.n_probs;

bool found_sampled_tok = false;
result.probs.reserve(n_probs);
for (int i = 0; i < n_vocab; i++) {
// set probability for sampled token
if (cur[i].id == result.tok) {
found_sampled_tok = true;
result.prob = cur[i].p;
}
// set probability for top n_probs tokens
result.probs.push_back({
cur[i].id,
common_detokenize(ctx, {cur[i].id}, special),
cur[i].p
});
// break if we have all the necessary data
if (result.probs.size() == n_probs && found_sampled_tok) {
break;
}
}
}

void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(task.id, error, type);
}
@@ -1906,17 +1949,7 @@ struct server_context {

// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
res->prob_output = tkn; // copy the token probs
}

// populate timings if this is final response or timings_per_token is enabled
@@ -2728,7 +2761,9 @@ struct server_context {
continue; // continue loop of slots
}

const int tok_idx = slot.i_batch - i;

llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);

slot.i_batch = -1;

@@ -2747,17 +2782,12 @@ struct server_context {
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
result.prob = 1.0f; // set later

if (slot.params.sampling.n_probs > 0) {
populate_token_probs(slot, result, params_base.special, tok_idx);
}

if (!process_token(result, slot)) {
@@ -2841,7 +2871,11 @@ struct server_context {
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;

result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
result.prob = 1.0f; // set later
@ngxson (Collaborator, Author) commented on Dec 11, 2024:

Here is the branch for speculative decoding. I'm not sure how I can get token probs here. Could you give me some clues? @ggerganov

(Or we can skip this for now if it's complicated)

Member:

Think we need to update `common_sampler_sample_and_accept_n()` to return the probs. But let's fix this later.


// TODO: set result.probs

if (!process_token(result, slot)) {
// release slot because of stop condition