lookup: use hashmaps, select most frequent tokens, abort draft early if no good candidates #5462

Closed

@JohannesGaessler (Collaborator) commented Feb 12, 2024:

While I was working on #5398 I took a look at the conventional lookup example and noticed that it has some issues. This PR attempts to fix those. The changes are:

  • On master the lookup example always selects the first occurrence of an n-gram and then fills up the draft with the exact sequence of tokens following that n-gram. With this PR the example instead builds up the draft one token at a time, each time selecting the token that has most frequently followed the previous few tokens.
  • There are thresholds for sample size and empirical probability that the most likely token must meet in order to be drafted. If the token does not meet these thresholds it is not drafted and the draft is aborted early. This is done because the scaling of t/s with batch size is not perfect and unlikely tokens will on average just slow the generation down.
  • The interpretation of the value for ngram_min is fixed. On master, ngram_min = 1 actually results in a minimum n-gram size of 2; with this PR it is 1. Drafts based on the occurrence of only a single token are mostly useless on master anyway because their acceptance rate is very low, but the filters added with this PR greatly increase the acceptance rate of those drafts that pass them.
  • Hashmaps are used as the data structure to hold the information about how often specific tokens have followed specific n-grams. The inner hashmap maps tokens to the number of times that these tokens have followed an n-gram; the outer hashmap maps n-grams to these empirical distributions. The hashmaps can be built incrementally as new tokens are sampled: the time required to update them with a single token is constant. The time required to find the most likely token is asymptotically proportional to the number of tokens fed to the hashmap but will in practice be almost constant because for a given n-gram only a small fraction of tokens will actually appear as a continuation. This data structure is not strictly needed right now but it will be very useful for further lookup sampling projects. A sketch of the layout and of the drafting decision follows after this list.
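
To make the data layout and the early-abort logic more concrete, here is a minimal sketch. The type aliases mirror the ones used in this PR (`token_hashmap`, `all_token_hashmap`), but the function name, threshold names, and threshold values (`try_draft`, `draft_min_sample_size`, `draft_min_prob`) are illustrative, not the exact ones from the code:

```cpp
#include <cstdint>
#include <unordered_map>

typedef int32_t llama_token; // llama.cpp tokens are 32-bit integers

// Inner hashmap: token -> how often that token followed a given n-gram.
typedef std::unordered_map<llama_token, int> token_hashmap;
// Outer hashmap: n-gram (tokens packed into a uint64_t) -> empirical follower distribution.
typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap;

// Illustrative thresholds (names and values are hypothetical, not taken from the PR):
static const int   draft_min_sample_size = 2;     // minimum number of observations of the n-gram
static const float draft_min_prob        = 0.50f; // minimum empirical probability of the best token

// Return the most frequent continuation of `ngram`, or -1 to abort the draft early.
static llama_token try_draft(const all_token_hashmap & atc, const uint64_t ngram) {
    const all_token_hashmap::const_iterator token_counts_it = atc.find(ngram);
    if (token_counts_it == atc.end()) {
        return -1; // this n-gram has never been observed -> nothing to draft
    }

    int         sum_count = 0;
    int         max_count = 0;
    llama_token best      = -1;
    for (const std::pair<const llama_token, int> & tc : token_counts_it->second) {
        sum_count += tc.second;
        if (tc.second > max_count) {
            max_count = tc.second;
            best      = tc.first;
        }
    }

    // Abort if the sample size or the empirical probability of the best token is too low:
    // unlikely tokens would on average only slow the generation down.
    if (sum_count < draft_min_sample_size) {
        return -1;
    }
    if ((float) max_count < draft_min_prob * (float) sum_count) {
        return -1;
    }
    return best;
}
```

The draft loop would then call something like this once per candidate token, appending the result to both the draft and the token window used to form the next n-gram, and stop at the first -1.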

I used commands like this for testing:

export model_name=miqu-70b && export quantization=q5_k_m
export nd=3
./lookup --model models/opt/${model_name}-${quantization}.gguf -ngl 99 --ctx-size 4096 --split-mode row --n-predict 1024 --seed 1337 --temp 0 --ignore-eos --draft $nd --color --prompt "[INST] Write a love story about two stars that tragically ends in a type Ia supernova. Use a lot of emotional and dramatic language. [/INST]"

The prompt is intentionally chosen in an adversarial way: it contains few token sequences that can be copied verbatim to the generation. The model is Miqu q5_K_M run on 3 P40s. I get these results:

| --n-draft | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| n_drafted master | 0 | 202 | 346 | 492 | 668 | 820 | 978 | 1127 |
| n_drafted PR | 0 | 164 | 238 | 330 | 434 | 529 | 625 | 725 |
| n_accept master | 0 | 104 | 134 | 143 | 139 | 142 | 143 | 145 |
| n_accept PR | 0 | 98 | 119 | 126 | 127 | 130 | 131 | 131 |
| accept % master | - | 51.485% | 38.728% | 29.065% | 20.808% | 17.317% | 14.622% | 12.866% |
| accept % PR | - | 59.756% | 50.000% | 38.182% | 29.263% | 24.575% | 20.960% | 18.069% |
| t/s master | 8.905 | 8.662 | 8.802 | 8.692 | 7.913 | 7.718 | 7.511 | 7.342 |
| t/s PR | 8.896 | 8.814 | 9.047 | 8.994 | 8.456 | 8.345 | 8.194 | 8.044 |

Note: the batch size for lookup decoding is --draft + 1, which is why the table only goes up to 7. With this PR ~90-95% as many tokens as on master get correctly drafted, but with ~30-40% fewer incorrectly drafted tokens. As a consequence the total t/s increases. But because P40s have low compute relative to more modern GPUs, they scale comparatively poorly with batch sizes > 1. So for this hardware, and for a prompt that does not already contain a lot of usable token sequences, there is only a very small speedup, if any. I currently don't have a suitable instruct model on hand to test t/s on my RTX 3090, and non-instruct models (with these settings) tend to repeat themselves a lot, which is good for lookup decoding but bad in terms of output quality.

After this PR I intend to implement lookup decoding based not just on the current context but also on general text statistics and previous user generations. I think the best results will be achieved with a hierarchical system (sketched below): first look for suitable tokens in the current user session, then in previous user sessions, then in a more general text corpus like wikitext. The tradeoff is between sample size and relevance to the current generation; you could potentially use the higher-sample-size statistics to select among multiple candidates that are more relevant to the current generation. For this the hashmap data structure will be very useful because it only stores those n-gram -> token mappings that are actually observed and as such needs very little memory: in a prototype, the hashmap built from ~500 MiB of wikitext was only ~1 MiB in size.
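
As a rough illustration of the hierarchical idea, building on the `try_draft` sketch above (the function name, the cache names, and the simple first-hit fallback are assumptions for illustration, not the planned implementation):

```cpp
// Query several n-gram caches from most to least relevant to the current generation
// and take the first candidate that passes the thresholds. A more refined version
// could instead combine the levels, e.g. use the large-sample-size statistics of the
// static corpus to rank candidates observed in the more relevant caches.
static llama_token try_draft_hierarchical(
        const all_token_hashmap & cache_context, // n-grams from the current generation
        const all_token_hashmap & cache_dynamic, // n-grams from previous user sessions
        const all_token_hashmap & cache_static,  // n-grams from a general corpus, e.g. wikitext
        const uint64_t ngram) {
    const all_token_hashmap * levels[3] = { &cache_context, &cache_dynamic, &cache_static };
    for (int i = 0; i < 3; ++i) {
        const llama_token token = try_draft(*levels[i], ngram);
        if (token != -1) {
            return token;
        }
    }
    return -1; // no level produced a confident candidate -> abort the draft early
}
```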

When it comes to the implementation considerations laid out in #4235 , the only issue that should arise with the implementation in this PR is that setting a fixed size for the n-gram cache would skew the statistics used for creating the draft. But I think the caches will be small enough that this will not be necessary in the first place.

@JohannesGaessler (Collaborator, Author) commented:

I did another test for a prompt where you would more commonly use lookup decoding. I asked Miqu q5_K_M to analyze the hashmap code I added in this PR:

[INST] Explain to me what the following C++ code does:

```
auto update_hashmaps = [](all_token_hashmap * atcs, const llama_token * inp_data, const int inp_size, const int nnew) -> void {
    // atcs = all_token_counts: the hashmaps to modify.
    // inp_data: the token sequence on which the hashmaps are based.
    // inp_size: the current size of inp_data.
    // nnew: how many new tokens have been appended to inp_data since the last call to this function.
    //
    // In order to get correct results inp_data can ONLY BE APPENDED TO.
    // Changes in the middle need a complete rebuild.
    for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
        all_token_hashmap * atc = atcs + ngram_size - ngram_min;

        const int i_start = std::max(inp_size - nnew, ngram_size);
        for (int i = i_start; i < inp_size; ++i) {
            const int ngram_start = i - ngram_size;
            uint64_t ngram = inp_data[ngram_start];
            for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
                const uint64_t ngram_part = inp_data[j];
                ngram <<= 16;
                ngram |= ngram_part;
            }
            const llama_token token = inp_data[i];

            all_token_hashmap::iterator token_counts_it = atc->find(ngram);
            if (token_counts_it == atc->end()) {
                token_hashmap token_counts;
                token_counts.emplace(token, 1);
                atc->emplace(ngram, token_counts);
            } else {
                token_hashmap::iterator tc_it = token_counts_it->second.find(token);
                if (tc_it == token_counts_it->second.end()) {
                    token_counts_it->second.emplace(token, 1);
                } else {
                    tc_it->second++;
                }
            }
        }
    }
};
```

[/INST]

I reduced the number of generated tokens to 512 because otherwise the model would start generating nonsense towards the end. These are the results:

| --draft | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|---|
| n_drafted master | 0 | 146 | 246 | 345 | 468 | 570 | 672 | 777 |
| n_drafted PR | 0 | 143 | 257 | 333 | 445 | 532 | 617 | 700 |
| n_accept master | 0 | 67 | 90 | 98 | 96 | 99 | 101 | 101 |
| n_accept PR | 0 | 60 | 83 | 85 | 92 | 101 | 96 | 97 |
| accept % master | - | 45.890% | 36.585% | 28.406% | 20.513% | 17.368% | 15.030% | 12.999% |
| accept % PR | - | 41.958% | 32.296% | 25.526% | 20.674% | 18.985% | 15.559% | 13.857% |
| t/s master | 8.511 | 8.032 | 8.261 | 8.143 | 7.075 | 6.891 | 6.699 | 6.484 |
| t/s PR | 8.509 | 7.967 | 8.077 | 7.975 | 7.144 | 7.050 | 6.818 | 6.657 |

Against my expectation, the acceptance rate is lower than for the story prompt even though a lot of token sequences can be cited directly from the prompt. It may be an issue with the length of the generated response. Compared to the story prompt, the use of filters seems to be detrimental for small --draft values. But I don't want to spend a lot of time on tuning those filters since they will need to be readjusted anyway when more information from e.g. wikitext is added for the drafting.

@JohannesGaessler (Collaborator, Author) commented:

Obsoleted by #5479

@LeonEricsson (Contributor) commented:

I'm happy someone got around to improving on this; I've read your PRs and love the work. I hope to get around to reading your implementation some time soon. It's clear you've greatly improved on my somewhat lacking go at this 🚀
