
A Peek in Byte-Pair Encoding

Ashish Khare

2 Aug 2025

Banner image courtesy of Andrej's minbpe (a screenshot of tiktokenizer).

Web scraping has taught me how important it is to structure your data in the right format, both for representation and for future use. Working with LLMs and transformers has only strengthened this belief. A better encoding has several advantages:

  1. A longer effective context window (more text fits into the same number of tokens)
  2. A smaller, sanitized corpus (though this may increase the vocabulary size)
  3. Improved ability of LLMs to form associations between tokens and their ordering
  4. A smaller inference footprint (fewer tokens to predict)

Recently, I watched the Tokenizer Video by Andrej[1], where he explained how byte-pair encoding reduces the complexity of the input corpus and how it was implemented for GPT-2, GPT-4, and Meta’s LLaMA. Two things stood out to me. Firstly, the algorithm itself is a greedy approach and super easy to implement (anyone with adequate coding knowledge can implement it from scratch). Secondly, they used regex to pre-split the text into chunks of words, numbers, apostrophe suffixes, and spaces. The pre-split keeps merges from crossing chunk boundaries, and it actually increased the encoding efficiency, as runs of spaces were treated cohesively.
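To make the pre-splitting idea concrete, here is a minimal sketch that uses only the standard `re` module. The pattern is a simplified, ASCII-only stand-in assumed for illustration; the pattern used by GPT-2 relies on Unicode character classes from the third-party `regex` package. `pre_split` is a hypothetical helper name, not something from the video.

```python
import re

# Simplified, ASCII-only approximation of a GPT-2 style split pattern (assumption).
SPLIT_PATTERN = re.compile(
    r"'(?:s|t|re|ve|m|ll|d)"   # common English contraction suffixes
    r"| ?[A-Za-z]+"            # a word, with an optional leading space
    r"| ?[0-9]+"               # a number, with an optional leading space
    r"| ?[^\sA-Za-z0-9]+"      # punctuation, with an optional leading space
    r"|\s+(?!\S)"              # whitespace run that is not directly followed by a non-space
    r"|\s+"                    # any remaining whitespace
)


def pre_split(text):
    """Split text into chunks; BPE merges would then be learned inside each chunk."""
    return SPLIT_PATTERN.findall(text)


print(pre_split("Hello world's 123!"))
print(pre_split("a   b"))
```

Every character ends up in exactly one chunk, so joining the chunks gives back the original text. Note how the run of three spaces becomes one two-space chunk plus the single space attached to the next word.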

I also read this paper called A Formal Perspective on Byte-Pair Encoding[2], which discusses optimizations to the greedy BPE algorithm that improve runtime complexity from O(nm) to O(n log m), where n is the input text length and m is the number of merges to perform. More on that later.

Byte-Pair Encoding

In simple terms, you count every pair of adjacent characters, replace the most frequent pair with a new symbol, and add this new symbol to your vocabulary.

As you may have guessed, this requires several passes over the input string, updating the pair-frequency store and making the appropriate replacements in the text. Each pass commits to the currently most frequent pair without looking ahead, hence the term greedy.

start   - p i c k e d p i c k l e d p i c k l e s
merge 1 - (pi) c k e d (pi) c k l e d (pi) c k l e s
merge 2 - (pi) (ck) e d (pi) (ck) l e d (pi) (ck) l e s
merge 3 - ((pi)(ck)) e d ((pi)(ck)) l e d ((pi)(ck)) l e s
merge 4 - ((pi)(ck)) (ed) ((pi)(ck)) l (ed) ((pi)(ck)) l e s
merge 5 - ((pi)(ck)) (ed) (((pi)(ck))l) (ed) (((pi)(ck))l) e s

Points to note:

Before any merges, the vocabulary has 8 characters: Σ = {'d', 'p', 'c', 's', 'l', 'k', 'e', 'i'}.

After performing all merges, the vocabulary size increases by 5, resulting in 13 elements: Σ = {'d', 'p', 'c', 's', 'l', 'k', 'e', 'i', 'pi', 'ck', 'ed', 'pick', 'pickl'}.
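The vocabulary growth from 8 to 13 can be reproduced with a small sketch. The helper names `count_pairs` and `merge_pair` are hypothetical, and the three words are treated as separate sequences. Ties are broken by whichever pair is seen first, which is an assumption on my part: with this rule the second merge produces 'pic' rather than 'ck', but the vocabulary size and the final segmentation come out the same as in the trace above.

```python
from collections import Counter


def count_pairs(words):
    """Count adjacent symbol pairs across all words."""
    pairs = Counter()
    for symbols in words:
        pairs.update(zip(symbols, symbols[1:]))
    return pairs


def merge_pair(symbols, pair):
    """Replace each left-to-right occurrence of pair with the joined symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged


words = [list(w) for w in "picked pickled pickles".split()]
vocab = {ch for w in words for ch in w}
base_size = len(vocab)

for _ in range(5):
    pairs = count_pairs(words)
    best = max(pairs, key=pairs.get)   # ties: first pair encountered wins
    vocab.add(best[0] + best[1])
    words = [merge_pair(w, best) for w in words]

print(base_size, len(vocab))
print(words)
```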

As a result of these operations, the vocabulary size grows, and so does the size of the embedding table used to store embeddings for each token. However, in exchange for a larger embedding table, the model can encode more information per input sequence and capture longer contexts more effectively.
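As a rough illustration of that trade-off, the number of weights in an embedding table is simply the vocabulary size times the embedding width. The sketch below compares a raw byte vocabulary with GPT-2's 50,257-token vocabulary at the 768-dimensional width of the smallest GPT-2 model; `embedding_parameters` is a hypothetical helper, and the count ignores every other part of the model.

```python
def embedding_parameters(vocab_size, d_model):
    """Number of weights in a token embedding table of shape (vocab_size, d_model)."""
    return vocab_size * d_model


D_MODEL = 768  # embedding width of the smallest GPT-2 model

byte_level = embedding_parameters(256, D_MODEL)    # raw bytes only
gpt2_bpe = embedding_parameters(50257, D_MODEL)    # GPT-2's BPE vocabulary

print(byte_level)  # 196608
print(gpt2_bpe)    # 38597376
```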

Here are the bpe and merge functions written in an extremely Pythonic style, taken from [2]. I encourage you to study how and why they chose this implementation—it mostly comes down to writing clean, idiomatic Python.

from collections import Counter
from typing import Union, Tuple, List

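def bpe(xs: Union[str, List], V: int):
    # Perform V merges, each time on the currently most frequent pair
    for _ in range(V):
        # Count every pair of adjacent symbols
        pairs = Counter(zip(xs, xs[1:]))
        top_pair = pairs.most_common(1)[0][0]
        xs = merge(list(xs), top_pair)
    return xs

def merge(xs: List, pair: Tuple):
    # Replace each occurrence of pair with the pair itself as a single (nested) symbol
    ys = []
    while xs:
        if tuple(xs[:2]) == pair:
            ys.append(pair)
            xs = xs[2:]
        else:
            ys.append(xs.pop(0))
    return ys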

OpenAI discussed the use of BPE in their GPT-2 paper, specifically in section 2.2.

Here is the implementation of a basic tokenizer class based on the video by Andrej.

class BasicTokenizer:
  def __init__(self):
    self.vocab = {}
    self.merges = {}

  def train(self, text, vocab_size, verbose=False):
    """
    Assumption: text is encoded as UTF-8, so every initial token is a byte value in [0, 256)
    """
    # Convert text to bytes
    seq = list(text.encode("utf-8"))

    # track the merges made, and start the vocab with the 256 byte values
    merges = {}
    vocab = {idx: bytes([idx]) for idx in range(256)}

    for i in range(vocab_size - 256):
      # find pairs in the sequence
      pairs = find_pairs(seq)

      # (Addition) Stop if no pairs exist
      if not pairs:
        print("Given vocab size is larger than the number of possible merges; stopping early")
        break

      # find the most frequent pair
      max_val_pair = max(pairs, key=pairs.get)

      # find the next token value
      replacement = 256 + i

      # replace the most frequent pair with the next token value
      seq = merge_tokens(seq, max_val_pair, replacement)

      # store the merge made
      merges[max_val_pair] = replacement

      # add the merge token to the vocab
      vocab[replacement] = vocab[max_val_pair[0]] + vocab[max_val_pair[1]]

      # (optionally) output the merge made
      if verbose:
        print(f'Merged {max_val_pair} -> {replacement}')

    self.merges = merges
    self.vocab = vocab

  def byte_encode(self, text):
    # string -> bytes
    return list(text.encode("utf-8"))

  def byte_decode(self, seq):
    # bytes -> string
    return bytes(seq).decode('utf-8')

  def encode(self, text):
    # Convert text to bytes
    seq = self.byte_encode(text)

    # Fewer than two tokens means there is nothing left to merge
    while len(seq) >= 2:

      # Get all pairs from the input text
      pairs = find_pairs(seq).keys()

      # Find the pair that was merged earliest during training (lowest token value)
      pair_to_replace = min(
          pairs,
          key=lambda p: self.merges.get(p, float("inf")),
      )

      # If no pair has a learned merge, we are done
      if pair_to_replace not in self.merges:
        break

      # Grab the tokens to replace
      replacement = self.merges[pair_to_replace]

      # Replace the pairs with the token
      seq = merge_tokens(seq, pair_to_replace, replacement)

    return seq


  def decode(self, seq):

    # Convert the tokens back to bytes
    text_bytes = b"".join(self.vocab[idx] for idx in seq)

    # Decode bytes back to characters
    text = text_bytes.decode("utf-8", errors="replace")

    return text

I've added comments for better readability. You can read the entire code in the Google Colab notebook.

I've made a small addition to the code: while training, the loop now breaks early if no more merges can be made, even if iterations are still remaining on the loop counter.
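The class above calls two helpers, find_pairs and merge_tokens, that are not shown here. A minimal sketch of what they might look like follows; this is my reconstruction from how they are called, and the definitions in the notebook may differ.

```python
from collections import Counter


def find_pairs(seq):
    """Count how often each adjacent pair of tokens occurs in seq."""
    return Counter(zip(seq, seq[1:]))


def merge_tokens(seq, pair, replacement):
    """Return a copy of seq where each occurrence of pair, scanning left to
    right without overlaps, is replaced by the single token replacement."""
    merged = []
    i = 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            merged.append(replacement)
            i += 2
        else:
            merged.append(seq[i])
            i += 1
    return merged


seq = list("aaabdaaabac".encode("utf-8"))
pairs = find_pairs(seq)
top = max(pairs, key=pairs.get)
print(top, pairs[top])
print(merge_tokens(seq, top, 256))
```

Since Counter is a dict subclass, the `pairs.get`, `pairs.keys()`, and `if not pairs` uses in the class work unchanged.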

Optimizations

The paper [2] presents two enhanced algorithms. The first discusses using a more efficient data structure for storing pair frequencies and the input text (or tokens). The second focuses on memoization and how it can outperform the brute-force algorithm in certain scenarios.

Linked list + Max heap updates

To achieve this, we need to maintain two data structures:

  1. Linked List of Input Strings: The input string is stored as a doubly linked list to make merge operations easier.

    • The token property corresponds to the index in the vocabulary.
    • The value property corresponds to the concatenated byte array for a merged or single token.

# Doubly linked list
class Node:
  def __init__(self, token = None, value = None) -> None:
    self.token = token
    self.value = value
    self.next = None
    self.prev = None
  
  # Get length of the linked list
  def getLength(self):
    cur = self
    count = 0
    while cur:
      count += 1
      cur = cur.next
    return count

  2. Max Heap of Pair Frequencies: The heap class has three components: the heap array, the count dictionary, and the positions dictionary.

    • The heap array stores tuples of (-frequency, pair), and all heap operations are performed on it. The frequency is negated because Python's heapq is a min-heap.
    • The count dictionary maps each pair to its integer frequency and is used to build the heap once all pairs have been added.
    • The positions dictionary maps each token pair to the set of node pairs in the linked list where it occurs.

import heapq
from collections import defaultdict

# Max Heap for maintaining positions from the linked list
class MaxHeap:
  def __init__(self) -> None:
    self.heap = []
    self.count = defaultdict(int)
    self.positions = defaultdict(set)
  
  # Add pair of nodes
  def addPair(self, a: Node, b: Node):
    pair = (a.token, b.token)
    self.count[pair] += 1
    self.positions[pair].add((a, b))
  
  # remove pair of nodes
  def removePair(self, a: Node, b: Node):
    pair = (a.token, b.token)
    if (a, b) in self.positions[pair]:
        self.positions[pair].remove((a, b))
        self.count[pair] -= 1
        if self.count[pair] == 0:
            del self.count[pair]
            del self.positions[pair]
  
  # create the heap
  def build(self):
    self.heap = [(-freq, pair) for pair, freq in self.count.items()]
    heapq.heapify(self.heap)
  
  # Pop the most frequent pair (if valid)
  def get_max_pair(self):
    while self.heap:
        neg_freq, pair = heapq.heappop(self.heap)
        if pair in self.positions and self.positions[pair]:
            return pair
    return None

Here is an illustration of how to imagine the positions being updated using the linked list and the positions dictionary. The dictionary contains all the positions of that token pair throughout the chain, and once that pair is picked, all its occurrences are merged in the chain.

Linked list updates with positions

Now, let's talk about implementing a tokenizer class for encoding and decoding text. I've also created two helper functions:

  1. listify – Converts the text string into a linked list.
  2. maxHeaper – Creates a max heap from the linked list of text tokens.

def listify(self, text):
  root = head = Node()

  for char in self.byte_encode(text):
    q = Node(token=char, value=bytes([char]))
    head.next = q
    q.prev = head
    head = q
  
  return root.next

def maxHeaper(self, root: Node):
  # Create a max-heap
  maxheap = MaxHeap()

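  # Feed all tokens into the maxheap.
  # listify returns the first real node, so start from root itself;
  # starting at root.next would skip the first pair.
  head = root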
  while head and head.next:
    maxheap.addPair(head, head.next)
    head = head.next

  # Heapify
  maxheap.build()

  return maxheap

The Process

The entire process is illustrated in the image from the paper [2]. It includes the following steps:

  1. Pop the most frequent pair from the heap
  2. Fetch all positions where this pair occurs
  3. Merge the nodes at these positions
  4. Update the heap with the new pairs
  5. Repeat the process

Linked List based optimization
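The five steps listed above can be sketched in one self-contained function. This is my own compact variant under stated assumptions, not the paper's code or the notebook's LLTokenizer: `train_linked` is a hypothetical name, positions are keyed by the left node of each occurrence, stale heap entries are skipped lazily, occurrences are sorted so merges happen left to right, and ties go to the numerically smallest pair.

```python
import heapq
from collections import defaultdict


class Node:
    def __init__(self, token, pos):
        self.token = token
        self.pos = pos      # original index, used to merge left to right
        self.prev = None
        self.next = None
        self.alive = True   # False once absorbed into its left neighbour


def train_linked(tokens, num_merges, first_new_token=256):
    """Learn up to num_merges merges; return (final tokens, merges)."""
    nodes = [Node(t, i) for i, t in enumerate(tokens)]
    for left, right in zip(nodes, nodes[1:]):
        left.next, right.prev = right, left

    # positions[pair] holds the left node of every occurrence of pair
    positions = defaultdict(set)
    for node in nodes[:-1]:
        positions[(node.token, node.next.token)].add(node)

    # heapq is a min-heap, so counts are negated
    heap = [(-len(occ), pair) for pair, occ in positions.items()]
    heapq.heapify(heap)

    merges = {}
    while heap and len(merges) < num_merges:
        neg_count, pair = heapq.heappop(heap)          # step 1: most frequent pair
        occurrences = positions.get(pair)
        if not occurrences or len(occurrences) != -neg_count:
            continue  # stale entry; a fresh one was pushed when the count changed

        new_token = first_new_token + len(merges)
        merges[pair] = new_token
        touched = set()

        for left in sorted(occurrences, key=lambda n: n.pos):   # step 2: positions
            right = left.next
            if not left.alive or right is None or (left.token, right.token) != pair:
                continue  # overlapped with a merge made earlier in this pass

            # Neighbouring pairs that are about to disappear
            if left.prev:
                old = (left.prev.token, left.token)
                positions[old].discard(left.prev)
                touched.add(old)
            if right.next:
                old = (right.token, right.next.token)
                positions[old].discard(right)
                touched.add(old)

            # Step 3: merge by splicing the right node out of the list
            left.token = new_token
            left.next = right.next
            if right.next:
                right.next.prev = left
            right.alive = False

            # Neighbouring pairs that now exist
            if left.prev:
                new = (left.prev.token, new_token)
                positions[new].add(left.prev)
                touched.add(new)
            if left.next:
                new = (new_token, left.next.token)
                positions[new].add(left)
                touched.add(new)

        positions.pop(pair, None)
        for changed in touched:                        # step 4: update the heap
            if positions.get(changed):
                heapq.heappush(heap, (-len(positions[changed]), changed))

    out = []
    node = nodes[0] if nodes else None
    while node:
        out.append(node.token)
        node = node.next
    return out, merges


tokens, merges = train_linked(list(b"aaabdaaabac"), 3)
print(tokens)
print(merges)
```

Each merge only touches the occurrences of the chosen pair and their immediate neighbours, instead of rescanning the whole sequence as the basic tokenizer does.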

  1. Encode: This function works as follows:

    1. Convert the input text string into a linked list of byte tokens
    2. Iteratively apply the learned merges to replace token pairs until no more merges can be applied
    3. Convert the final linked list into an array of token IDs

  2. Decode: This function converts the array of token IDs back into a byte sequence and then decodes it into the original string.

def encode(self, text):
  # Convert string to linked list
  head = self.listify(text)

  # Apply merges in the order they were learned.
  # Note: this assumes self.merges maps new_token -> (a_token, b_token)
  for new_token, (a_token, b_token) in self.merges.items():
      node = head
      while node and node.next:
          if (
            node.token == a_token and 
            node.next.token == b_token
          ):
              # Merge node and node.next
              node.token = new_token
              node.value += node.next.value
              node.next = node.next.next
              if node.next:
                  node.next.prev = node
          else:
              node = node.next

  # Extract token IDs
  tokens = []
  node = head
  while node:
      tokens.append(node.token)
      node = node.next

  return tokens

def decode(self, seq):

  # Convert the tokens back to bytes
  text_bytes = b"".join(self.vocab[idx] for idx in seq)

  # Decode bytes back to characters
  text = text_bytes.decode("utf-8", errors="replace")

  return text
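One detail in decode that is easy to overlook is errors="replace". Tokens are byte sequences, so an arbitrary list of token IDs (for example, one sampled from a model) can end in the middle of a multi-byte UTF-8 character. The small sketch below, using the euro sign as an arbitrary example, shows the difference between strict and lenient decoding.

```python
# The euro sign takes three bytes in UTF-8; a token boundary may fall inside it.
euro = "\u20ac".encode("utf-8")     # b'\xe2\x82\xac'
partial = euro[:1]                  # only the lead byte survives

# Lenient decoding substitutes the Unicode replacement character U+FFFD
lenient = partial.decode("utf-8", errors="replace")
print(repr(lenient))

# Strict decoding (the default) raises instead
try:
    partial.decode("utf-8")
    strict_failed = False
except UnicodeDecodeError:
    strict_failed = True
print(strict_failed)
```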

I've created a new tokenizer class named LLTokenizer in the Colab notebook [3], and I've tested its encoder and decoder there.

Further Optimizations

The paper [2] (Section 5) discusses the use of memoization:

While we are not able to devise a polynomial-time scheme, we are able to find an exact algorithm that is, in some cases, faster than the brute-force technique of enumerating all valid merge sequences.

So here's a challenge for the reader.

This implementation is computationally expensive, which is why I’ll skip it for now. However, you’re encouraged to try implementing it yourself.

How slow is it? It maintains a stack where each item includes a full copy of the token sequence. As you can imagine, this grows quickly in both time and memory, especially with longer inputs or higher merge counts.

My take

Introducing additional data structures means allocating extra space proportional to the size of the input sequence. This is the classic trade-off in computer science: you can either save on time or on space. However, given modern hardware and relaxed space constraints, we can afford to use more space to save time. Additionally, the enormous corpus sizes demand faster preprocessing. An O(nm) algorithm, which becomes quadratic as the merge count approaches the input length, can easily overwhelm the hardware.

I'd love to hear your thoughts!

References