Optimizing latency

An exploration of ways to optimize latency.
Published: March 1, 2024

Summary

Below is a summary of my findings:

  • 🏁 mlc is the fastest. It is so fast that I'm skeptical, and I am now motivated to measure quality (if I have time). When I checked the outputs manually, they didn't seem that different from other approaches.
  • ❤️ CTranslate2 is my favorite tool: it is among the fastest and also the easiest to use. Its documentation is the best out of all of the solutions I tried, and I think the ergonomics are excellent for the models it supports. Unlike vLLM, CTranslate2 doesn't seem to support distributed inference just yet.
  • 🛠️ vLLM is really fast, but CTranslate2 with int8 quantization can be faster still. On the other hand, vLLM supports distributed inference, which you will need for larger models. vLLM might be the sweet spot for serving very large models.
  • 😐 Text Generation Inference (TGI) is an OK option (though nowhere near as fast as vLLM) if you want to deploy HuggingFace LLMs in a standard way. TGI has some nice features, like telemetry baked in (via OpenTelemetry) and integration with the HF ecosystem, such as Inference Endpoints. Note that as of 7/28/2023, the TGI license was changed to a more restrictive one that may interfere with certain commercial uses. I am personally not a fan of the license.

Rough Benchmarks

This study focuses on various approaches to optimizing latency. Specifically, I want to know which tools are the most effective at optimizing latency for open source LLMs. In order to focus on latency, I hold the following variables constant:

  • Batch size of n = 1 for all prediction requests (holding throughput constant).[1]
  • All experiments were conducted on an Nvidia A6000 GPU, unless otherwise noted.
  • Max output tokens were always set to 200.
  • All numbers are averages over a fixed set of 9 prompts.
  • The model used is meta-llama/Llama-2-7b-hf on the HuggingFace Hub.[2]

In addition to these controls, I warmed up each model by sending an initial inference request before measuring latency.
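The measurement protocol above (one discarded warm-up request, then averages over the prompts, one request at a time) can be sketched as a small harness. This is a minimal sketch, not the benchmarking code used for the table: `benchmark` is a hypothetical helper, and the `predict` callable it takes is assumed to return a dict with a `tok_count` entry, like the `predict` functions in the scripts later in this post. Averaging per-request tokens/sec (rather than total tokens over total time) is also an assumption.

```python
import statistics
import time
from typing import Callable, Dict, List


def benchmark(predict: Callable[[str], Dict], prompts: List[str],
              warmup_prompt: str) -> Dict[str, float]:
    """Time `predict` on each prompt after one discarded warm-up call.

    `predict` takes a prompt and returns a dict with a 'tok_count' entry
    (number of generated tokens). Requests are sent one at a time (batch size 1).
    """
    predict(warmup_prompt)  # warm-up: neither the result nor its timing is kept

    times, tok_counts = [], []
    for prompt in prompts:
        start = time.perf_counter()
        result = predict(prompt)
        times.append(time.perf_counter() - start)
        tok_counts.append(result["tok_count"])

    return {
        # mean of per-request throughput, in tokens per second
        "avg_tok_per_sec": statistics.mean(n / t for n, t in zip(tok_counts, times)),
        "avg_time_s": statistics.mean(times),
        "avg_tok_count": statistics.mean(tok_counts),
    }


if __name__ == "__main__":
    # Stand-in for a real model call that pretends to generate 200 tokens
    def fake_predict(prompt: str) -> Dict:
        time.sleep(0.01)
        return {"tok_count": 200}

    print(benchmark(fake_predict, ["q1", "q2", "q3"], warmup_prompt="warm-up"))
```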

Llama-v2-7b benchmark: batch size = 1, max output tokens = 200

| platform | options | gpu | avg tok/sec | avg time (seconds) | avg output token count |
|---|---|---|---|---|---|
| CTranslate2 | float16 quantization | A6000 | 44.8 | 4.5 | 200.0 |
| CTranslate2 | int8 quantization | A6000 | 62.6 | 3.2 | 200.0 |
| HF Hosted Inference Endpoint | - | A10G | 30.4 | 6.6 | 202.0 |
| HuggingFace Transformers (no server) | - | A6000 | 24.6 | 7.5 | 181.4 |
| HuggingFace Transformers (no server) | nf4 4bit quantization bitsandbytes | A6000 | 24.3 | 7.6 | 181.4 |
| TGI | - | A6000 | 21.1 | 9.5 | 200.0 |
| TGI | quantized w/ GPTQ | A6000 | 23.6 | 8.8 | 200.0 |
| TGI | quantized w/ bitsandbytes | A6000 | 1.9 | 103.0 | 200.0 |
| mlc | q4f16 | A6000 | 117.1 | 1.3 | 153.9 |
| text-generation-webui | exllama | A6000 | 77.0 | 1.7 | 134.0 |
| vllm | - | A100 (on Modal Labs) | 41.5 | 3.4 | 143.1 |
| vllm | - | A6000 | 46.4 | 3.8 | 178.0 |

In some cases I did not use an A6000 because the platform didn't have that particular GPU available. You can ignore these rows if you like, but I still think they are valuable information. I had access to an A6000, so I just used what I had.

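I noticed that the output of the LLM was quite different (fewer tokens) when using vLLM. I am not sure whether I did something wrong here or whether vLLM changes the behavior of the LLM. One difference worth noting: the vLLM benchmarking script later in this post samples with temperature=1.0, whereas the CTranslate2 script decodes greedily (sampling_topk=1), which could affect output length.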

Furthermore, the goal was not to be super precise with these benchmarks, but rather to get a general sense of how the tools work and how they compare to each other out of the box. Some of the tools above are inference servers that perform logging, tracing, etc. in addition to optimizing models, and this extra work affects latency. The idea is to see where there are significant differences between tools. I discussed this more here.
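Each row of the table is a set of averages over per-prompt records. As a sketch of how such a summary could be computed, the hypothetical `summarize` helper below groups rows shaped like the CSVs the scripts in this post write (columns `tok_count`, `time`, `note`, among others) and averages the three table columns. Whether the table averaged per-request tokens/sec or divided total tokens by total time is an assumption; this sketch does the former.

```python
import csv
import io
import statistics
from collections import defaultdict


def summarize(rows):
    """Group per-prompt rows by 'note' and average the three table columns."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["note"]].append((float(row["tok_count"]), float(row["time"])))

    summary = {}
    for note, records in groups.items():
        summary[note] = {
            "avg tok/sec": round(statistics.mean(n / t for n, t in records), 1),
            "avg time (seconds)": round(statistics.mean(t for _, t in records), 1),
            "avg output token count": round(statistics.mean(n for n, _ in records), 1),
        }
    return summary


if __name__ == "__main__":
    # Illustrative rows (not measurements) in the layout of bench-ctranslate-int8.csv
    sample = io.StringIO(
        "tok_count,time,question,answer,note\n"
        "200,3.0,q1,a1,CTranslate2 int8 quantization\n"
        "200,4.0,q2,a2,CTranslate2 int8 quantization\n"
    )
    print(summarize(csv.DictReader(sample)))
```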

Background

One capability you need to be successful with open source LLMs is the ability to serve models efficiently. There are two categories of tools for model inference:

  • Inference servers: these provide a web server that exposes a REST, gRPC, or other interface so you can interact with your model as a service. Inference servers usually have parameters that help you make trade-offs between throughput and latency, and some come with additional features like telemetry, model versioning and more. You can learn more about this topic in the serving section of these notes. For LLMs, popular inference servers include Text Generation Inference (TGI) and vLLM.

  • Model Optimization: these modify your model to make it faster for inference. Examples include quantization, Paged Attention, Exllama and more.
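Quantization, the first optimization listed above, stores weights in fewer bits plus a scale factor. As a rough illustration only, here is symmetric absmax int8 quantization of one weight vector in pure Python. The helper names are hypothetical, and the tools in this post (CTranslate2 int8, bitsandbytes nf4, GPTQ, mlc's q4f16) each use their own, more sophisticated schemes.

```python
from typing import List, Tuple


def quantize_absmax_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Map floats to integers in [-127, 127] using one shared scale factor."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0  # avoid dividing by zero
    codes = [round(w / scale) for w in weights]
    return codes, scale


def dequantize(codes: List[int], scale: float) -> List[float]:
    """Recover approximate floats from the int8 codes."""
    return [c * scale for c in codes]


if __name__ == "__main__":
    w = [0.12, -0.5, 0.33, 0.01]
    codes, scale = quantize_absmax_int8(w)
    print(codes)  # [30, -127, 84, 3]
    print(dequantize(codes, scale))
```

Each dequantized value lands within half a scale step of the original; that lost precision is the price paid for storing weights in half the space of float16.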

It is common to use both inference servers and model optimization techniques in conjunction. Some inference servers, like TGI and vLLM, even help you apply optimization techniques.[3]

Notes On Tools

Other than benchmarking, an important goal of this study was to understand how to use different platforms & tools.

mlc

Start by compiling the model as shown in these docs.

After installing MLC, you can compile meta-llama/Llama-2-7b-chat-hf like so:

python3 -m mlc_llm.build \
--hf-path meta-llama/Llama-2-7b-chat-hf \
--target cuda --quantization q4f16_1

The arguments for the compilation are documented here. This puts the model in the ./dist/ folder with the name Llama-2-7b-chat-hf-q4f16_1.

You can use their python client to interact with the compiled model:

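from mlc_chat import ChatModule, ChatConfig

# Cap generation at 200 tokens to match the other benchmarks
cfg = ChatConfig(max_gen_len=200)
# Refers to the model compiled above (placed in ./dist/)
cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1", chat_config=cfg)
prompt = "What is the product of 9 and 8?"
output = cm.generate(prompt=prompt)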

You can see the full benchmarking code here.

Warning

I wasn't able to get meta-llama/Llama-2-7b-hf to run correctly with the supplied Python client, so I am using the chat variant (Llama-2-7b-chat-hf) as a proxy. I asked the kind folks who work on the mlc project, and they said the Python client is currently designed for chat, so it has this system prompt hard-coded for Llama models:

  conv.system =
      ("[INST] <<SYS>>\n\nYou are a helpful, respectful and honest assistant. "
       "Always answer as helpfully as possible, while being safe. "
       "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, "
       "or illegal content. "
       "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
       "If a question does not make any sense, or is not factually coherent, explain why instead "
       "of answering something not correct. "
       "If you don't know the answer to a question, please don't share false "
       "information.\n<</SYS>>\n\n ");

To fix this, edit mlc-chat-config.json and change conv_template to LM. These docs say more about mlc-chat-config.json.

The config file is located in ./dist/<model-name>/params/mlc-chat-config.json. For example:

> cat ./dist/Llama-2-7b-hf-q4f16_1/params/mlc-chat-config.json

{
    "model_lib": "Llama-2-7b-hf-q4f16_1",
    "local_id": "Llama-2-7b-hf-q4f16_1",
    "conv_template": "llama-2",
    "temperature": 0.7,
    "repetition_penalty": 1.0,
    "top_p": 0.95,
    "mean_gen_len": 128,
    "max_gen_len": 512,
    "shift_fill_factor": 0.3,
    "tokenizer_files": [
        "tokenizer.json",
        "tokenizer.model"
    ],
    "model_category": "llama",
    "model_name": "Llama-2-7b-hf"
}
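The manual edit described above can also be scripted. A minimal sketch using only the standard library; `set_conv_template` is a hypothetical helper, and it assumes the file is plain JSON like the example above.

```python
import json
from pathlib import Path


def set_conv_template(config_path, template="LM"):
    """Set conv_template in an mlc-chat-config.json file, keeping all other keys."""
    path = Path(config_path)
    config = json.loads(path.read_text())
    config["conv_template"] = template
    # indent=4 matches the layout of the example file; key order is preserved
    path.write_text(json.dumps(config, indent=4) + "\n")
    return config
```

For example, set_conv_template("./dist/Llama-2-7b-hf-q4f16_1/params/mlc-chat-config.json") changes "llama-2" to "LM" and leaves the other settings untouched.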

CTranslate2

CTranslate2 is an optimization tool that can make models ridiculously fast. h/t to Anton. The documentation for CTranslate2 contains specific instructions for llama models.

To optimize Llama v2, we first need to convert the model to the CTranslate2 format and quantize it (here to int8). This can be done like so:

ct2-transformers-converter --model meta-llama/Llama-2-7b-hf --quantization int8 --output_dir llama-2-7b-ct2 --force

meta-llama/Llama-2-7b-hf refers to the HuggingFace repo for this model. The benchmarking code is as follows (can also be found here):

import time
import ctranslate2
import transformers
import sys
sys.path.append('../common/')
from questions import questions
import pandas as pd

generator = ctranslate2.Generator("llama-2-7b-ct2", device="cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

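def predict(prompt:str):
    """Generate text given a prompt and time the request."""
    start = time.perf_counter()
    # CTranslate2 expects token strings rather than token ids
    tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
    # sampling_topk=1 means greedy decoding; only the generated tokens are returned
    results = generator.generate_batch([tokens], sampling_topk=1, max_length=200, include_prompt_in_result=False)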
    tokens = results[0].sequences_ids[0]
    output = tokenizer.decode(tokens)
    request_time = time.perf_counter() - start
    return {'tok_count': len(tokens),
            'time': request_time,
            'question': prompt,
            'answer': output,
            'note': 'CTranslate2 int8 quantization'}

if __name__ == '__main__':
    responses = []

    for i, q in enumerate(questions):
        response = predict(q)
        # The first request only warms up the model, so it is not recorded
        if i > 0:
            responses.append(response)

    df = pd.DataFrame(responses)
    df.to_csv('bench-ctranslate-int8.csv', index=False)

Text Generation Inference (TGI)

License Restrictions

The license for TGI was recently changed away from Apache 2.0 to be more restrictive. Be careful when using TGI in commercial applications.

Text Generation Inference, often referred to as "TGI", was easy to use without any optimization. You can run it like this:

start_server.sh
#!/bin/bash

if [ -z "$HUGGING_FACE_HUB_TOKEN" ]
then
  echo "HUGGING_FACE_HUB_TOKEN is not set. Please set it before running this script."
  exit 1
fi

# The model is chosen at run time with --model-id (forwarded via "$@" below).
# Weights are cached in ./data, which is mounted into the container at /data.
volume=$PWD/data

# GPTQ_BITS and GPTQ_GROUPSIZE only matter when serving a GPTQ model.
docker run --gpus all \
 -e HUGGING_FACE_HUB_TOKEN="$HUGGING_FACE_HUB_TOKEN" \
 -e GPTQ_BITS=4 -e GPTQ_GROUPSIZE=128 \
 --shm-size 5g -p 8081:80 \
 -v "$volume":/data ghcr.io/huggingface/text-generation-inference \
 --max-best-of 1 "$@"

We can then run the server with this command:

bash start_server.sh --model-id "meta-llama/Llama-2-7b-hf"
Help

You can see all the options for the TGI container with the help flag like so:

docker run ghcr.io/huggingface/text-generation-inference --help | less
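Once the container is running, the server can be queried over HTTP on the host port published by start_server.sh (8081). A minimal client sketch using only the standard library, assuming TGI's /generate route, which takes `inputs` plus a `parameters` object and, when `details` is true, reports `generated_tokens` in its response. The helper names are hypothetical, not part of TGI.

```python
import json
import time
import urllib.request

# Host port published by start_server.sh (-p 8081:80)
TGI_URL = "http://localhost:8081/generate"


def build_request(prompt: str, max_new_tokens: int = 200,
                  url: str = TGI_URL) -> urllib.request.Request:
    """Build a POST request for TGI's /generate route."""
    payload = {
        "inputs": prompt,
        # details=True asks TGI to report how many tokens it generated
        "parameters": {"max_new_tokens": max_new_tokens, "details": True},
    }
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def parse_response(body: bytes):
    """Return (generated_text, generated_token_count) from a /generate response."""
    data = json.loads(body)
    return data["generated_text"], data["details"]["generated_tokens"]


def predict(prompt: str) -> dict:
    """Send a single request (batch size 1) and time it end to end."""
    start = time.perf_counter()
    with urllib.request.urlopen(build_request(prompt)) as resp:
        text, tok_count = parse_response(resp.read())
    return {"tok_count": tok_count,
            "time": time.perf_counter() - start,
            "question": prompt,
            "answer": text}
```

With the server up, predict("What is the product of 9 and 8?") returns a record with the same tok_count/time/question/answer fields as the other benchmarking scripts. Note that the timing includes HTTP overhead, which is part of what an inference server adds to latency.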

Quantization

Quantization was very difficult to get working. There is a --quantize flag that accepts bitsandbytes and gptq. The bitsandbytes approach makes inference much slower, as others have also reported.

Making GPTQ work for Llama v2 models requires a lot of work. You have to install text-generation-server, which can take a while and is brittle to get right; I had to step through the Makefile carefully. After that, you have to download the weights with:

text-generation-server download-weights meta-llama/Llama-2-7b-hf

You can run the following command to perform the quantization (the last argument is the destination directory for the quantized weights).

text-generation-server quantize "meta-llama/Llama-2-7b-hf" data/quantized/

However, this step is not needed for the most popular models, as someone will likely already have quantized and uploaded them to the Hub.

Pre-Quantized Models

Alternatively, you can use a pre-quantized model that has been uploaded to the Hub; TheBloke/Llama-2-7B-GPTQ is a good example. To get this to work, you have to set the GPTQ_BITS and GPTQ_GROUPSIZE environment variables to match the model's config. For example, this config requires GPTQ_BITS=4 and GPTQ_GROUPSIZE=128, which are already set in start_server.sh shown above. This PR will eventually fix that.

To use TheBloke/Llama-2-7B-GPTQ with TGI, I can use the same bash script with the following arguments:

bash start_server.sh --model-id TheBloke/Llama-2-7B-GPTQ --quantize gptq
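Since GPTQ_BITS and GPTQ_GROUPSIZE have to match the model's quantization config, it can help to derive them from that config instead of hard-coding them. A minimal sketch, assuming an AutoGPTQ-style quantize_config.json with `bits` and `group_size` keys (check the actual file in the repo you use); `gptq_docker_env_flags` is a hypothetical helper.

```python
import json


def gptq_docker_env_flags(quantize_config: dict) -> list:
    """Turn a GPTQ quantize config into `docker run -e` flags for TGI."""
    env = {
        "GPTQ_BITS": quantize_config["bits"],
        "GPTQ_GROUPSIZE": quantize_config["group_size"],
    }
    flags = []
    for name, value in env.items():
        flags += ["-e", f"{name}={value}"]
    return flags


if __name__ == "__main__":
    # Assumed contents of quantize_config.json for a 4-bit, group-size-128 model
    config = json.loads('{"bits": 4, "group_size": 128, "desc_act": false}')
    print(" ".join(gptq_docker_env_flags(config)))
    # -e GPTQ_BITS=4 -e GPTQ_GROUPSIZE=128
```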

Comparison Without TGI Server

When I first drafted this study, I got the following response on Twitter:

Phillip certainly has a point! I am indeed testing both. I'm looking for big differences between tools here, and since some inference servers include optimization tools while some optimization tools do not come with an inference server, I cannot do a true apples-to-apples comparison. However, I think it's still useful to try different tools as advertised to see what is possible, and to take note of really significant gaps in latency between them.

Therefore, I ran the following tests to apply optimizations similar to TGI's, but without the server, to see what happened:

HuggingFace Transformers

As Phillip predicted, I got slightly better performance without the TGI server, but this did not account for the massive gap between some tools (which is exactly the kind of thing I was looking for).

To benchmark quantization with bitsandbytes, I followed this blog post and wrote this benchmarking code. I quantized the model by loading it like this:

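import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Store weights in 4-bit NF4 and run the computation in bfloat16
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_compute_dtype=torch.bfloat16
)
model_nf4 = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)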

Unlike with TGI, I was able to get bitsandbytes to work properly here, but, just as with TGI, it didn't improve inference latency for me. As the benchmark table shows, the results were nearly identical to transformers without any optimizations (24.3 vs. 24.6 tok/sec).

GPTQ

I also quantized the model using AutoGPTQ without an inference server to compare against TGI. The code for that is here.

The results were so bad (~5 tok/sec) that I left them out of the table, because they seemed quite off to me.

Text Generation WebUI

Aman let me know about text-generation-webui, and also these instructions for quickly experimenting with ExLlama and ggml. Unfortunately, I wasn't able to get the ggml variant to work properly. If you are really serious about using ExLlama, I recommend using it without the text generation UI and looking at the exllama repo, specifically test_benchmark_inference.py. (I didn't have time for this, but if I were going to use ExLlama for anything serious, I would go this route.)

From the root of the text-generation-webui repo, you can run the following commands to start an inference server optimized with ExLlama:

python3 download-model.py TheBloke/Llama-2-7B-GPTQ
python3 server.py --listen --extensions openai --loader exllama_hf --model TheBloke_Llama-2-7B-GPTQ

After the server was started, I used this code to conduct the benchmark.

Overall, I didn't like this particular piece of software much. It's a bit bloated because it's trying to do too many things at once (an inference server, web UIs, and other optimizations). That being said, the documentation is good and it is easy to use.

I don't think there is any particular reason to use this unless you want an end-to-end solution that also comes with a web user interface (which many people want!).

vLLM

When I ran these benchmarks, vLLM only worked with CUDA 11.8, which I configured using this approach. After configuring CUDA and installing the right version of PyTorch, you need to install the bleeding-edge version from git:

pip install -U git+https://github.com/vllm-project/vllm.git

A good recipe for vLLM can be found in these Modal docs. Surprisingly, I got higher throughput on a local A6000 than on a hosted A100 on Modal Labs (46.4 vs. 41.5 tok/sec). It's possible that I did something wrong here. Of the tools I tried, vLLM is the fastest option when you need distributed inference (i.e. when your model doesn't fit on a single GPU).
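Whether a model "fits on a single GPU" can be estimated with simple arithmetic: weight memory is roughly the parameter count times the bytes per parameter. A minimal sketch under stated assumptions: it counts weights only (the KV cache, activations and framework overhead all need extra room, so real requirements are higher), uses nominal parameter counts of 7B and 70B, and takes 48 GB as the memory of an A6000. The helper names are hypothetical.

```python
import math

# Bytes needed to store one parameter at each precision
BYTES_PER_PARAM = {"float16": 2.0, "int8": 1.0, "4bit": 0.5}


def weight_gb(n_params: float, dtype: str) -> float:
    """Approximate size of the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


def min_gpus_for_weights(n_params: float, dtype: str, gpu_gb: float) -> int:
    """Lower bound on GPUs needed just to hold the weights (ignores KV cache etc.)."""
    return math.ceil(weight_gb(n_params, dtype) / gpu_gb)


if __name__ == "__main__":
    A6000_GB = 48
    for name, n in [("Llama-2-7b", 7e9), ("Llama-2-70b", 70e9)]:
        for dtype in BYTES_PER_PARAM:
            print(f"{name} {dtype}: {weight_gb(n, dtype):.1f} GB, "
                  f">= {min_gpus_for_weights(n, dtype, A6000_GB)} x A6000")
```

For example, a 70B model in float16 needs about 140 GB for its weights alone, so it cannot fit on one 48 GB A6000, while a 7B model (about 14 GB) fits easily.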

vLLM offers a server, but I benchmarked the model locally using its Python API instead. The code for the benchmarking can be found here:

import time

import pandas as pd
from vllm import SamplingParams, LLM

# Adapted from https://modal.com/docs/guide/ex/vllm_inference

questions = [
    # Coding questions
    "Implement a Python function to compute the Fibonacci numbers.",
    "Write a Rust function that performs binary exponentiation.",
    "What are the differences between Javascript and Python?",
    # Literature
    "Write a story in the style of James Joyce about a trip to the Australian outback in 2083, to see robots in the beautiful desert.",
    "Who does Harry turn into a balloon?",
    "Write a tale about a time-traveling historian who's determined to witness the most significant events in human history.",
    # Math
    "What is the product of 9 and 8?",
    "If a train travels 120 kilometers in 2 hours, what is its average speed?",
    "Think through this step by step. If the sequence a_n is defined by a_1 = 3, a_2 = 5, and a_n = a_(n-1) + a_(n-2) for n > 2, find a_6.",
]

MODEL_DIR = "/home/ubuntu/hamel-drive/vllm-models"

def download_model_to_folder():
    from huggingface_hub import snapshot_download
    import os

    snapshot_download(
        "meta-llama/Llama-2-7b-hf",
        local_dir=MODEL_DIR,
        token=os.environ["HUGGING_FACE_HUB_TOKEN"],
    )
    return LLM(MODEL_DIR)


def generate(question, llm, note=None):
    response = {'question': question, 'note': note}
    sampling_params = SamplingParams(
        temperature=1.0,
        top_p=1,
        max_tokens=200,
    )
    
    start = time.perf_counter()
    result = llm.generate(question, sampling_params)
    request_time = time.perf_counter() - start

    for output in result:
        response['tok_count'] = len(output.outputs[0].token_ids)
        response['time'] = request_time
        response['answer'] = output.outputs[0].text
    
    return response

if __name__ == '__main__':
    llm = download_model_to_folder()
    counter = 1
    responses = []

    for q in questions:
        response = generate(question=q, llm=llm, note='vLLM')
        if counter >= 2:
            responses.append(response)
        counter += 1
    
    df = pd.DataFrame(responses)
    df.to_csv('bench-vllm.csv', index=False)

HuggingFace Inference Endpoint

I deployed an inference endpoint on HuggingFace for meta-llama/Llama-2-7b-hf on an Nvidia A10G GPU. I didn't turn on any optimizations such as quantization, because I wanted to see what the default performance would be like.

The documentation for these interfaces can be found here. There is also a Python client.

Their documentation says they use TGI under the hood. However, my latency was significantly lower on their hosted inference platform than with TGI locally (30.4 vs. 21.1 tok/sec). This could be related to the different GPUs (an A10G on their platform vs. an A6000 locally), but it's worth looking further into why this discrepancy exists.

The code for this benchmark can be found here.

Footnotes

  1. It is common to explore the latency vs. throughput frontier when conducting inference benchmarks. I did not do this, since I was most interested in latency. Here is an example of how to conduct inference benchmarks that consider both throughput and latency.

  2. For Llama v2 models, you must be careful to use the models ending in -hf, as those are the ones that are compatible with the transformers library.

  3. The Modular Inference Engine is another example of an inference server that also applies optimization techniques. At the time of this writing, it is proprietary technology, but it's worth keeping an eye on it in the future.