Tokenization Gotchas

Footguns with tokenizers and inferencing LLMs

Background

Lots of people experience fiddly, unexpected behavior when using LLMs.

If you aren’t careful, these issues can be very hard to debug, because tokenizers work in subtle ways that are not always easy to see by looking at the text.

Example

The example below demonstrates how prompt construction can get confusing and drift between training and inference time.

from transformers import AutoTokenizer
from functools import partial
model_id = 'Open-Orca/Mistral-7B-OpenOrca'
tok = AutoTokenizer.from_pretrained(model_id)
enc = partial(tok.encode, add_special_tokens=False)
dec = partial(tok.decode)

Many frameworks do prompt construction by concatenating tokens

Popular frameworks like axolotl construct prompts by concatenating tokens instead of strings.1 A reasonable sanity check is to decode the training data to see what the prompt template looks like.

For example, a prompt may be constructed like this:

axolotl = enc('Ok\n') + enc('<|im_start|>')
print(dec(axolotl))
Ok
<|im_start|>
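
Working in token space is not arbitrary: the loss mask has to line up with token boundaries (see the footnote). Below is a minimal, purely illustrative sketch of prompt masking, using -100 as the ignore index so prompt tokens are excluded from the loss; the ChatML strings are just examples.

prompt_ids   = enc('<|im_start|>user\nOk<|im_end|>\n<|im_start|>assistant\n')
response_ids = enc('Sure!<|im_end|>\n')

input_ids = prompt_ids + response_ids
labels    = [-100] * len(prompt_ids) + response_ids  # only the response contributes to the loss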

Let’s say you have an inference server

It’s common for inference servers to assemble the prompt for you. The code below looks like it should be fine, right?

def inf_server(inp): 
    return f'{inp}\n<|im_start|>'

srv = inf_server('Ok')
print(srv)
Ok
<|im_start|>

Drift between your server and the way the model is trained

Wrong! Notice the difference between the tokens in the training data and the tokens your server produces. This is a subtle problem that can be hard to debug.

print(f'axolotl training data:  {axolotl}')
print(f"your server's decoding: {enc(srv)}")
axolotl training data:  [6504, 13, 32001]
your server's decoding: [6504, 32001]
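
One way to close the gap is to assemble the inference prompt in token space, mirroring how the training data was built. This is just a sketch (inf_server_token_ids is a hypothetical replacement for the server function above), not what any particular inference server does:

def inf_server_token_ids(inp):
    # concatenate tokens exactly the way axolotl built the training data
    return enc(inp + '\n') + enc('<|im_start|>')

print(inf_server_token_ids('Ok'))
[6504, 13, 32001]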

Solutions

1. Decode your inference data

Decode your inference data right before the forward pass. In this example, you’ll notice the newline is missing, which is one way to tell that something fishy is going on.

dec(enc(srv))
'Ok<|im_start|>'
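
You can go one step further and automate the comparison. The helper below is hypothetical (not part of transformers or axolotl); it simply checks the tokens your server produces against the training-time tokens:

def check_drift(server_prompt, training_ids):
    server_ids = enc(server_prompt)
    if server_ids != training_ids:
        raise ValueError(f'token drift: training={training_ids} vs inference={server_ids}')

check_drift(srv, axolotl)  # raises: training=[6504, 13, 32001] vs inference=[6504, 32001]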

2. Use HF chat templating

Use the new HuggingFace chat template when possible. This will help avoid these issues (however, I would still check using method #1 to be sure!). Related GitHub Issue comment.
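
At inference time, a minimal sketch looks like this, assuming the tokenizer you serve ships a chat template (add_generation_prompt=True appends the assistant header so the model knows it should respond):

messages = [{"role": "user", "content": "Ok"}]
ids = tok.apply_chat_template(messages, add_generation_prompt=True)
print(dec(ids))  # inspect exactly what the model will see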

Example: Axolotl vs. HuggingFace Chat Templates

This is a real example of how tokenization drift can bite you.

Chat Template From HuggingFace

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")

chat = [
   {"role": "system", "content": "lorem"},
   {"role": "user", "content": "abc"},
   {"role": "assistant", "content": "ipsum"},
   {"role": "user", "content": "123"},
   {"role": "assistant", "content": "sit"},
]

ids = tokenizer.apply_chat_template(chat)
print(tokenizer.decode(ids))
<s>[INST] <<SYS>>
lorem
<</SYS>>

abc [/INST] ipsum</s><s>[INST] 123 [/INST] sit</s>

Same thing decoded from Axolotl (with a space after <s>)

Got the token ids from this test.

axolotl_ids = [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 
                29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 
                13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 
                518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 
                29914, 25580, 29962, 7845, 2]
print(tokenizer.decode(axolotl_ids))
<s> [INST] <<SYS>>
lorem
<</SYS>>

abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>

Let’s decode HF tokens one at a time

for i in ids[:9]:
    print(f'{i}: {tokenizer.decode(i)}')
1: <s>
29961: [
25580: INST
29962: ]
3532: <<
14816: SY
29903: S
6778: >>
13: 

Let’s decode Axolotl tokens one at a time

Note the second token, 518. This is a mismatch with the HF chat template, which uses 29961 instead.

for i in axolotl_ids[:9]:
    print(f'{i}: {tokenizer.decode(i)}')
1: <s>
518: [
25580: INST
29962: ]
3532: <<
14816: SY
29903: S
6778: >>
13: 

Why does this happen?

Axolotl assembles prompts in token space rather than string space. Encoding '[INST]' as a standalone string makes the SentencePiece tokenizer prepend a word-boundary marker, so Axolotl ends up with token 518 ('[' with a leading space) where the chat template produces 29961 ('['). That is where the extra space after <s> comes from.

tokenizer.encode('<s>', add_special_tokens=False) + tokenizer.encode('[INST]', add_special_tokens=False)
[1, 518, 25580, 29962]

HF Chat templates interpolate strings instead

tokenizer.encode('<s>[INST]', add_special_tokens=False)
[1, 29961, 25580, 29962]

Other Examples

These are other examples of people being bitten by differences in tokenization between training and inference time:

  1. This GitHub Issue.
  2. This Tweet.

Footnotes

  1. This is for good reason, as masking must also be done at the token level.↩︎