LLM tokenization: BPE algorithm
I recently learned the basic tokenization algorithm for LLMs.
It's called BPE (Byte-Pair Encoding), and it's a fun algorithmic problem that doesn't require other ML background. Optimizing the encoding step is like a LeetCode hard.
I explain it here in my own words from memory, as it's my preferred way of solidifying knowledge.
Sharing as an article because of the post length limit.
Overview
There are three pieces/algorithms:
1. Creating the token alphabet (the vocabulary)
2. Encoding: turning text into a token sequence
3. Decoding: turning a token sequence back into text
Creating the token alphabet
As the B in BPE suggests, we work at the byte level.
Every byte is by itself a valid token, meaning that we start with 2^8 = 256 valid tokens.
Tokens also have a numeric ID, which, for single-byte tokens, is just the byte value.
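For example, in Python, the initial byte-level token IDs of a string are just its UTF-8 byte values (this snippet is only an illustration, not part of any particular tokenizer):

```python
# Single-byte tokens: the token ID is just the byte value.
list("hi ☺".encode("utf-8"))   # -> [104, 105, 32, 226, 152, 186]
```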
One thing we need to decide upfront is how big we want our token alphabet to be.
We can set it to anything, but a common size is ~50k tokens. It's a trade-off:
More tokens means more compression when encoding text (context won't fill as quickly), but also more computation (every time we do a 'next token' prediction, we compute one probability per possible token).
The other thing we need is a text dataset to base our vocabulary on. Our vocab will be engineered to "compress" that specific dataset, so we want it to be large and representative of the kind of text we'll be encoding.
The vocab is constructed by a greedy algorithm.
We think of the dataset as a sequence of tokens. In practice, this looks like a sequence of numerical token IDs.
Initially, every byte in the dataset is its own token.
We then repeat the following steps in a loop until we reach the desired vocab size:
1. Find the most frequent pair of consecutive tokens in the sequence.
2. Add a new token to the vocab representing the concatenation of that pair; its ID is the next unused one.
3. Replace every occurrence of the pair in the sequence with the new token.
The intuition is that every time we combine two tokens, the token sequence gets shorter. By targeting the most frequent pairs, we compress text the most.
Here is a full example:
dataset: "BCDEDEDE"
target vocab size: 258 (we want to add two tokens to the initial 256)
1. initial token sequence: B, C, D, E, D, E, D, E
most frequent consecutive pair: D, E (3 occurrences)
new token: "DE" (ID 256)
2. new token sequence: B, C, DE, DE, DE
most frequent pair: DE, DE (2 occurrences)
new token: "DEDE" (ID 257)
3. final token sequence: B, C, DEDE, DE
Note: we need a deterministic tie-breaking rule for choosing pairs to merge when the most frequent pair is the same token twice, like in the case of DE, DE. We could end up with DEDE, DE, or DE, DEDE. By convention, we merge the first one.
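To make this concrete, here is a minimal Python sketch of the vocab creation loop. The function name and the ((x1, x2), new_id) merge-list format are my own; the later sketches reuse this `merges` list:

```python
from collections import Counter

def train_bpe(data: bytes, vocab_size: int):
    """Greedy BPE vocab creation: repeatedly merge the most frequent pair.
    Returns the merge rules as a list of ((x1, x2), new_id) tuples, plus the
    final token sequence for the dataset."""
    seq = list(data)            # initial tokens: one per byte, ID = byte value
    merges = []
    next_id = 256
    while next_id < vocab_size:
        pair_counts = Counter(zip(seq, seq[1:]))
        if not pair_counts:
            break
        (x1, x2), _ = pair_counts.most_common(1)[0]
        merges.append(((x1, x2), next_id))
        # Replace occurrences left to right, so a run like DE, DE, DE
        # merges its first pair (the tie-breaking convention above).
        out, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i + 1]) == (x1, x2):
                out.append(next_id)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        seq = out
        next_id += 1
    return merges, seq

# merges, seq = train_bpe(b"BCDEDEDE", 258)
# seq -> [66, 67, 257, 256]   i.e. B, C, DEDE, DE
```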
Without getting into the weeds, it is not hard to implement each iteration of the loop in O(n) time, where n is the dataset length. Thus, the total runtime is O(V * n), where V is the desired vocab size.
I haven't looked into optimizations for this step, but it doesn't seem too important because V is not too large (~50k) and this is a one-off preprocessing step.
Encoding
This step is about turning input text into a sequence of tokens.
The general idea is to apply the same sequence of token replacements as we did when creating the vocab.
Let S be the input byte sequence.
Note: the most frequent pair in the dataset used to create the vocab may not be the most frequent pair in S. In fact, we don't use token-pair frequencies in S for anything. But the more aligned they are with the dataset, the more compact the encoding will be.
We can describe the encoding by looking at the naive algorithm, which merges token pairs in the same order in which they were created during vocab creation.
We think of S as a sequence of tokens. Initially, every byte in S is its own token.
Then, for each token X in the vocab starting from ID 256:
Let x1 and x2 be the two tokens that we need to concatenate to get X. Scan S from left to right and replace every consecutive occurrence of x1, x2 with X.
Again, we need to be careful with the case x1 == x2 to make the encoding deterministic. If we find x1, x1, x1, we want to end up with X, x1, not x1, X.
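Here is a minimal sketch of this naive encoder, reusing the `merges` list from the training sketch above (the left-to-right scan handles the x1 == x2 case as described):

```python
def encode_naive(s: bytes, merges) -> list:
    """Replay the merge rules, in the order they were learned, on the input."""
    seq = list(s)                             # one token per byte
    for (x1, x2), new_id in merges:
        out, i = [], 0
        while i < len(seq):
            # Left-to-right scan: for x1 == x2, a run x1, x1, x1 becomes X, x1.
            if i + 1 < len(seq) and (seq[i], seq[i + 1]) == (x1, x2):
                out.append(new_id)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        seq = out
    return seq

# encode_naive(b"DEDEDEB", merges) -> [257, 256, 66]
```

This does a full pass over the input for every vocab token, i.e. O(V * n) per input, which is what motivates the optimization below.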
Optimizing the encoding
Since we do this operation extremely often, it makes sense to optimize it.
The key idea is to use a priority queue to quickly find the highest-priority pair of tokens to merge.
First, we initialize a hash map: (int, int) -> int, with one entry for each vocab token past the initial 256 byte-based tokens, capturing the merge rule to obtain it: (x1, x2) -> X.
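Assuming the `merges` list produced by the training sketch above, this map is a one-liner:

```python
# Merge rules keyed by token pair: (x1, x2) -> X
merge_map = {pair: new_id for pair, new_id in merges}
```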
As usual, we think of S as a sequence of tokens, starting with the initial byte tokens.
Since we are focusing on efficiency, let's be precise about the data structure for the token sequence: a doubly-linked list. It allows us to efficiently insert and remove elements from the middle of the list.
In addition to next and prev pointers, each list node carries a few extra fields, among them the token ID it currently holds, its index in the original byte sequence, and a deleted flag. The need for the extra fields will become apparent soon.
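A minimal sketch of such a node (the exact field layout is an assumption of these examples):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    token: int                    # token ID currently held by this node
    index: int                    # position in the original byte sequence
    deleted: bool = False         # set when the node is removed by a merge
    prev: Optional["Node"] = None
    next: Optional["Node"] = None
```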
The priority queue (PQ) is our final data structure. We initialize it by iterating through the initial linked list. For each consecutive pair, (x1, x2), if it exists in the hash map, (x1, x2) -> X, we add an entry to the PQ.
The entry consists of a pointer to x1's list node. The priority is based on the token ID of X, with lower IDs having higher priority (they were merged earlier during the vocab creation process). In other words, we use a min-heap.
We use the index of x1 to break priority ties (lower indices have higher priority). This ensures that the encoding is deterministic, and it's why we need the index field.
For example: if the hash map contains (x1, x1) -> X, and the token sequence contains x1, x1, x1, we add two entries to the PQ, one for the first x1 and one for the second x1, but the first one has higher priority because of the lower index.
The deleted flag is needed because we'll use a technique known as "lazy deletions" in the PQ. It will be used to identify stale entries in the PQ (see Step 2 of the main loop).
The main loop runs until the PQ is empty:
1. Pop the highest-priority entry from the PQ; it points to x1's node and proposes merging (x1, x2) into X, where x2 is the token right after x1.
2. If the entry is stale (x1's node or its successor has been marked deleted, or the pair they currently hold no longer merges into X), discard it and continue: this is the lazy deletion.
3. Otherwise, perform the merge: x1's node now holds X, and x2's node is unlinked from the list and marked deleted.
4. The merge can create up to two new consecutive pairs, (prev, X) and (X, next); for each one that appears in the hash map, push a new entry to the PQ.
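Putting the pieces together, here is a minimal sketch of the optimized encoder. It assumes the Node class and the `merges` list from the earlier sketches and uses Python's heapq as the min-heap; to avoid comparing Node objects inside the heap, each entry stores x1's index and the node is looked up in a side array:

```python
import heapq

def encode_fast(s: bytes, merges) -> list:
    merge_map = {pair: new_id for pair, new_id in merges}   # (x1, x2) -> X

    # Doubly-linked list of initial byte tokens.
    nodes = [Node(token=b, index=i) for i, b in enumerate(s)]
    for a, b in zip(nodes, nodes[1:]):
        a.next, b.prev = b, a

    pq = []  # min-heap of (merged token ID X, index of x1)

    def maybe_push(node):
        # Add a PQ entry if (node, node.next) is a known merge rule.
        if node is not None and node.next is not None:
            new_id = merge_map.get((node.token, node.next.token))
            if new_id is not None:
                heapq.heappush(pq, (new_id, node.index))

    for node in nodes:
        maybe_push(node)

    while pq:
        new_id, idx = heapq.heappop(pq)
        left = nodes[idx]
        right = left.next
        # Step 2: lazy deletion, skip entries that no longer describe a valid merge.
        if (left.deleted or right is None or right.deleted
                or merge_map.get((left.token, right.token)) != new_id):
            continue
        # Merge: the left node now holds X, the right node is unlinked.
        left.token = new_id
        right.deleted = True
        left.next = right.next
        if right.next is not None:
            right.next.prev = left
        # The merge may create new mergeable pairs with both neighbours.
        maybe_push(left.prev)
        maybe_push(left)

    # Collect the surviving tokens in order.
    out = []
    node = nodes[0] if nodes else None
    while node is not None:
        if not node.deleted:
            out.append(node.token)
        node = node.next
    return out

# encode_fast(b"BCDEDEDE", merges) -> [66, 67, 257, 256]   i.e. B, C, DEDE, DE
```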
Big O analysis: building the linked list and the initial PQ takes O(n log n). Every successful merge shortens the token sequence by one, so there are at most n merges, and each merge pushes at most two new PQ entries. The PQ therefore handles O(n) entries in total (including the stale ones skipped by lazy deletion), and each push/pop costs O(log n).
The total runtime is O(n log n).
Decoding
Decoding is the most straightforward part, as we don't need to "replay" the token combination rules in a careful order.
We can precompute a table of token ID -> byte sequence: the initial 256 tokens map to their own single byte, and each merged token X = x1 + x2 maps to the concatenation of the byte sequences of x1 and x2. Processing tokens in increasing ID order guarantees that x1 and x2 are already in the table.
We then simply construct the output text by concatenating byte_seq(X) for each token X in the token sequence.
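A minimal sketch, again assuming the `merges` list from the training sketch:

```python
def build_decode_table(merges) -> dict:
    """Token ID -> byte sequence. Single-byte tokens map to their own byte;
    a merged token's bytes are the concatenation of its two parts' bytes."""
    table = {i: bytes([i]) for i in range(256)}
    for (x1, x2), new_id in merges:          # ID order: parts are already known
        table[new_id] = table[x1] + table[x2]
    return table

def decode(token_ids, table) -> bytes:
    return b"".join(table[t] for t in token_ids)

# table = build_decode_table(merges)
# decode([66, 67, 257, 256], table) -> b"BCDEDEDE"
```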
Final words
This covers vanilla BPE. Production tokenizers add tweaks like special handling of whitespace and punctuation. But the algorithm above is at the core of the tokenization process surrounding every LLM call.
Correction regarding: "More tokens means more compression when encoding text (context won't fill as quickly), but also more computation (every time we do a 'next token' prediction, we compute one probability per possible token)." The claim about more computation may be backward. A smaller token vocab means a longer token sequence for the same text, and that's probably the more important scaling factor (each attention pattern/matrix grows quadratically with the context length). Perhaps a better reason why more tokens may be bad is that, if you have too many, each token won't appear often enough in the training data for the model to learn anything about it. (For an extreme example, consider tokens as long as entire sentences.)
I got nerd-sniped by this problem! Here is a follow-up post about my failed attempts to optimize the BPE encoding. It's on X because it exceeds the post length limit and isn't worth another article: https://x.com/Nil053/status/2018232578080424375