GPUs & Batching

Dynamic batching and using a GPU

Batching

Warning

I was not able to simulate a situation where dynamic batching outperforms no batching. Getting it right apparently takes time and a lot of experimentation; follow this guide for more information. This is a topic I may revisit in the future.

According to the docs:

Model Server has the ability to batch requests in a variety of settings in order to realize better throughput. The scheduling for this batching is done globally for all models and model versions on the server to ensure the best possible utilization of the underlying resources no matter how many models or model versions are currently being served by the server. You can enable this by using the --enable_batching flag and control it with the --batching_parameters_file.

This is an example batching parameters file:

%%writefile batch-config.cfg
max_batch_size { value: 1000 }
batch_timeout_micros { value: 1000 }
max_enqueued_batches { value: 16 }
num_batch_threads { value: 16 }
Overwriting batch-config.cfg
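If you sweep these parameters, it can help to generate the file from Python instead of editing it by hand. This is a minimal sketch that assumes every field is an integer wrapped in a `{ value: ... }` message, as in the file above. `render_batching_config` is a hypothetical helper, not part of TF Serving.

```python
def render_batching_config(params: dict) -> str:
    """Render batching parameters in the `name { value: N }` text format used above.

    Hypothetical helper: assumes every field is an integer wrapped in a
    `{ value: ... }` message, which matches the example file.
    """
    lines = [f"{name} {{ value: {int(value)} }}" for name, value in params.items()]
    return "\n".join(lines) + "\n"


# Same values as the batch-config.cfg written above.
example = {
    "max_batch_size": 1000,
    "batch_timeout_micros": 1000,
    "max_enqueued_batches": 16,
    "num_batch_threads": 16,
}

config_text = render_batching_config(example)
print(config_text)
```

Writing `config_text` to `batch-config.cfg` reproduces the file used in the rest of this note.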
Guidance on batch configuration

Guidance for these config files is here; there is no “right answer”. For GPUs, the guidance is this:

GPU: One Approach

If your model uses a GPU device for part or all of its inference work, consider the following approach:

  1. Set num_batch_threads to the number of CPU cores.

  2. Temporarily set batch_timeout_micros to a really high value while you tune max_batch_size to achieve the desired balance between throughput and average latency. Consider values in the hundreds or thousands.

  3. For online serving, tune batch_timeout_micros to rein in tail latency. The idea is that batches normally get filled to max_batch_size, but occasionally when there is a lapse in incoming requests, to avoid introducing a latency spike it makes sense to process whatever’s in the queue even if it represents an underfull batch. The best value for batch_timeout_micros is typically a few milliseconds, and depends on your context and goals. Zero is a value to consider; it works well for some workloads. (For bulk processing jobs, choose a large value, perhaps a few seconds, to ensure good throughput but not wait too long for the final (and likely underfull) batch.)
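The three steps can be expressed as a sequence of parameter sets. This sketch is only one possible reading. The 10-second "really high" timeout, the candidate batch sizes, and the 2000 µs online timeout are illustrative assumptions, and the throughput and latency measurement needed at each step is left out. `gpu_tuning_config` is a hypothetical helper.

```python
import os


def gpu_tuning_config(max_batch_size: int, phase: str = "tune_batch_size",
                      timeout_micros: int = 2000) -> dict:
    """Batching parameters for one phase of the GPU tuning approach (hypothetical helper).

    phase="tune_batch_size": step 2, timeout set very high while sweeping max_batch_size.
    phase="online": step 3, timeout of a few milliseconds to rein in tail latency.
    """
    params = {
        # Step 1: one batch thread per CPU core.
        "num_batch_threads": os.cpu_count() or 1,
        "max_batch_size": max_batch_size,
    }
    if phase == "tune_batch_size":
        # Step 2: an effectively unlimited wait (10 s here, an arbitrary choice).
        params["batch_timeout_micros"] = 10_000_000
    elif phase == "online":
        # Step 3: a few milliseconds; 2000 us is only an illustrative default.
        params["batch_timeout_micros"] = timeout_micros
    else:
        raise ValueError(f"unknown phase: {phase!r}")
    return params


# Step 2: candidate batch sizes "in the hundreds or thousands".
candidates = [gpu_tuning_config(size) for size in (128, 256, 512, 1024)]

# Step 3: after picking a batch size (512 is a placeholder), tighten the timeout.
final = gpu_tuning_config(512, phase="online", timeout_micros=2000)
print(final)
```

For bulk processing jobs, the guidance above suggests a much larger timeout (a few seconds) in the final step instead.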

Test the server

The model we are going to serve is generated in this note.

I’m going to start two TF Serving instances: one that runs on the CPU and one that runs on the GPU with batching enabled. I’m running both commands from the /home/hamel/tf-serving/ directory.

CPU Version

docker run \
--mount type=bind,source=/home/hamel/hamel/notes/serving/tfserving/model/,target=/models/model \
--net=host -t tensorflow/serving --grpc_max_threads=1000
Note

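--net=host makes the container share the host's network, so the server's ports are reachable directly on the host without any -p port mappings. This is convenient for testing.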

Test the CPU version:

! curl http://localhost:8501/v1/models/model
{
 "model_version_status": [
  {
   "version": "1",
   "state": "AVAILABLE",
   "status": {
    "error_code": "OK",
    "error_message": ""
   }
  }
 ]
}
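When scripting against the server, you may want to wait until the model is loaded before sending traffic. The sketch below checks the status response shown above. It assumes only the JSON shape printed by curl, and `is_model_available` is a hypothetical helper name.

```python
import json


def is_model_available(status_json: str) -> bool:
    """Return True if any version in a model-status response is in the AVAILABLE state."""
    status = json.loads(status_json)
    return any(
        version.get("state") == "AVAILABLE"
        for version in status.get("model_version_status", [])
    )


# The response printed by curl above.
response_text = """
{
 "model_version_status": [
  {
   "version": "1",
   "state": "AVAILABLE",
   "status": {"error_code": "OK", "error_message": ""}
  }
 ]
}
"""
print(is_model_available(response_text))  # True
```

In practice you would pass it the body of a GET request to `http://localhost:8501/v1/models/model` and retry until it returns True.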

GPU Version

Pre-requisites

You must install nvidia-docker (now part of the NVIDIA Container Toolkit) first.

Docker Command

You can pass additional arguments like --enable_batching to the docker run ... command, just as you would if you were running TF Serving locally. Anything after the image name is forwarded to tensorflow_model_server.

Note that the --gpus all flag is needed to enable GPUs with nvidia-docker. Use the latest-gpu image tag as well, and set --port and --rest_api_port so that this instance doesn’t conflict with the CPU instance I already have running:

docker run --gpus all \
--mount type=bind,source=/home/hamel/hamel/notes/serving/tfserving,target=/models \
--net=host -t tensorflow/serving:latest-gpu --enable_batching \
--batching_parameters_file=/models/batch-config.cfg --port=8505 \
--rest_api_port=8506 --grpc_max_threads=1000
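If you launch servers from a notebook, building the command as an argument list makes the ordering explicit: Docker options come before the image name, and TF Serving flags come after it. This is a sketch that mirrors the command above. `tf_serving_docker_cmd` is a hypothetical helper, and the command is only printed here, not run.

```python
import shlex


def tf_serving_docker_cmd(host_dir, image="tensorflow/serving:latest-gpu",
                          gpus=True, server_flags=()):
    """Build a `docker run` argument list for TF Serving (hypothetical helper).

    Docker options go before the image name; everything after the image
    is forwarded by the image's entrypoint to tensorflow_model_server.
    """
    cmd = ["docker", "run"]
    if gpus:
        cmd += ["--gpus", "all"]
    cmd += [
        "--mount", f"type=bind,source={host_dir},target=/models",
        "--net=host", "-t", image,
    ]
    cmd += list(server_flags)  # model-server flags must come last
    return cmd


cmd = tf_serving_docker_cmd(
    "/home/hamel/hamel/notes/serving/tfserving",
    server_flags=[
        "--enable_batching",
        "--batching_parameters_file=/models/batch-config.cfg",
        "--port=8505",
        "--rest_api_port=8506",
        "--grpc_max_threads=1000",
    ],
)
print(shlex.join(cmd))
# subprocess.run(cmd) would start the server; it is not executed here.
```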
--grpc_max_threads flag

I found that in non-batch mode I could easily overwhelm the server with gRPC requests, but I wasn’t able to overwhelm it over REST. Setting --grpc_max_threads=1000 takes care of this.

Other flags

There are lots of flags. Hannes uses these additional ones, and they seem to make things a bit faster.

--enable_model_warmup  \
--tensorflow_intra_op_parallelism=4 \
--tensorflow_inter_op_parallelism=4
Understanding the volume mount

On the host, the config file is located at /home/hamel/hamel/notes/serving/tfserving/batch-config.cfg and the model is located at /home/hamel/hamel/notes/serving/tfserving/model/

The Dockerfile for the tensorflow/serving image locates the model like this:

# Set where models should be stored in the container
ENV MODEL_BASE_PATH=/models
RUN mkdir -p ${MODEL_BASE_PATH}

# The only required piece is the model name in order to differentiate endpoints
ENV MODEL_NAME=model

# Create a script that runs the model server so we can use environment variables
# while also passing in arguments from the docker command line
RUN echo '#!/bin/bash \n\n\
tensorflow_model_server --port=8500 --rest_api_port=8501 \
--model_name=${MODEL_NAME} --model_base_path=${MODEL_BASE_PATH}/${MODEL_NAME} \
"$@"' > /usr/bin/tf_serving_entrypoint.sh \
&& chmod +x /usr/bin/tf_serving_entrypoint.sh

By default, it loads models from ${MODEL_BASE_PATH}/${MODEL_NAME}, which is /models/model. So when we mount /home/hamel/hamel/notes/serving/tfserving from the host to /models in the container:

  • The model files will be available at /models/model, as expected
  • The config file will be available at /models/batch-config.cfg
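The mapping above is just a prefix swap. This sketch checks it with `pathlib`. It assumes the same host directory and the Dockerfile defaults `MODEL_BASE_PATH=/models` and `MODEL_NAME=model`. `to_container_path` is a hypothetical helper.

```python
from pathlib import PurePosixPath

HOST_MOUNT = PurePosixPath("/home/hamel/hamel/notes/serving/tfserving")
CONTAINER_MOUNT = PurePosixPath("/models")


def to_container_path(host_path: str) -> PurePosixPath:
    """Translate a host path under the bind mount into its path inside the container."""
    # relative_to raises ValueError if host_path is not under the mounted directory.
    return CONTAINER_MOUNT / PurePosixPath(host_path).relative_to(HOST_MOUNT)


# Where the Dockerfile expects the model: ${MODEL_BASE_PATH}/${MODEL_NAME}
MODEL_BASE_PATH, MODEL_NAME = "/models", "model"
expected_model_dir = PurePosixPath(MODEL_BASE_PATH) / MODEL_NAME

model_dir = to_container_path("/home/hamel/hamel/notes/serving/tfserving/model/")
config_file = to_container_path("/home/hamel/hamel/notes/serving/tfserving/batch-config.cfg")
print(model_dir, model_dir == expected_model_dir)  # /models/model True
print(config_file)  # /models/batch-config.cfg
```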

Test the TF Serving GPU API:

! curl http://localhost:8506/v1/models/model
{
 "model_version_status": [
  {
   "version": "1",
   "state": "AVAILABLE",
   "status": {
    "error_code": "OK",
    "error_message": ""
   }
  }
 ]
}

Benchmark

“All benchmarks are wrong, some are useful”

We are going to send 10,000 requests, each with 5 instances to score, and measure the total inference time. The 10,000 requests are parallelized with threads. As a reminder, the model we are going to serve is generated in this note.
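The benchmarks below use fastcore's `parallel` with a thread pool. For readers without fastcore, this is a rough standard-library equivalent using `concurrent.futures.ThreadPoolExecutor`. It is a sketch, not the code used for the timings. `benchmark` and `fake_predict` are hypothetical names, and `fake_predict` only sleeps so the example runs without a server.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def benchmark(fn, items, n_workers=500):
    """Apply fn to every item on a thread pool; return (results, wall-clock seconds)."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # pool.map returns results in the same order as items
        results = list(pool.map(fn, items))
    return results, time.perf_counter() - start


def fake_predict(batch):
    """Hypothetical stand-in for a network call: wait 10 ms, return one score per instance."""
    time.sleep(0.01)
    return [0.5] * len(batch)


# 100 requests, each a batch of 5 instances of length 200 (a scaled-down `data`).
items = [[[0] * 200] * 5] * 100
results, seconds = benchmark(fake_predict, items, n_workers=50)
print(f"{len(items)} requests in {seconds:.3f} s ({len(items) / seconds:.0f} req/s)")
```

Swapping `fake_predict` for a real prediction function and `items` for `data` gives the same kind of wall-clock measurement that `%%time` reports below.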

Prepare the data

from tensorflow import keras

vocab_size = 20000  # Only consider the top 20k words
maxlen = 200  # Only consider the first 200 words of each movie review

_, (x_val, _) = keras.datasets.imdb.load_data(num_words=vocab_size)
x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=maxlen)

sample_data = x_val[:5, :]
data = [sample_data] * 10000

The prediction code

import json, requests
import numpy as np

from fastcore.parallel import parallel
from functools import partial
parallel_pred = partial(parallel, threadpool=True, n_workers=500)


def predict_rest(data, port):
    json_data = json.dumps(
        {"signature_name": "serving_default", "instances": data.tolist()}
    )
    url = f"http://localhost:{port}/v1/models/model:predict"

    json_response = requests.post(url, data=json_data)
    response = json.loads(json_response.text)
    rest_outputs = np.array(response["predictions"])
    return rest_outputs

rest_outputs = predict_rest(sample_data, '8501')
rest_outputs
array([[0.89650154, 0.10349847],
       [0.00330466, 0.9966954 ],
       [0.13089457, 0.8691054 ],
       [0.49083445, 0.50916553],
       [0.0377177 , 0.96228224]])
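Each row of the output is a probability distribution over the model's two classes, one row per instance. This short sketch turns it into predicted class indices with `argmax`. The values are copied from the output above; with a live server you would call `rest_outputs.argmax(axis=1)` directly. Which index corresponds to which sentiment depends on how the model was trained, which this note does not state.

```python
import numpy as np

# Copied from the REST output above: one row per instance, one column per class.
probs = np.array([[0.89650154, 0.10349847],
                  [0.00330466, 0.9966954 ],
                  [0.13089457, 0.8691054 ],
                  [0.49083445, 0.50916553],
                  [0.0377177 , 0.96228224]])

# Index of the most likely class for each instance.
labels = probs.argmax(axis=1)
print(labels)  # [0 1 1 1 1]
```

Note that the fourth instance is close to 50/50, so its label is much less certain than the others.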

gRPC

This is the code that will be used to make gRPC prediction requests. For more discussion about gRPC, see this note.

import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc


def predict_grpc(data, input_name='input_1', port='8505'):
    # Create a channel connected to the gRPC port of the container
    # (the default, 8505, is the gRPC port of the GPU server with batching)
    options = [('grpc.max_receive_message_length', 100 * 1024 * 1024)]
    channel = grpc.insecure_channel(f"localhost:{port}", options=options)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    # Create a gRPC prediction request
    request = predict_pb2.PredictRequest()

    # Set the name of the model, for this use case it is "model"
    request.model_spec.name = "model"

    # Set which signature is used to format the gRPC query
    # here the default one "serving_default"
    request.model_spec.signature_name = "serving_default"

    # Set the input as the data
    # tf.make_tensor_proto turns the NumPy array into a Protobuf tensor (TensorProto)
    request.inputs[input_name].CopyFrom(tf.make_tensor_proto(data))

    # Send the gRPC request to the TF Server
    result = stub.Predict(request)
    return result

CPU Server

The CPU server is listening on the default ports: 8500 for gRPC and 8501 for REST.

REST CPU

This hits the REST API endpoint on the CPU-bound server.

cpu_pred = partial(predict_rest, port = '8501')
%%time
results = parallel_pred(cpu_pred, data)
CPU times: user 27.7 s, sys: 5.56 s, total: 33.3 s
Wall time: 26 s

gRPC CPU

This is using the same CPU-bound TF Serving server, but is hitting the gRPC endpoint.

predict_grpc_cpu = partial(predict_grpc, port='8500')
%%time
results = parallel_pred(predict_grpc_cpu, data)
CPU times: user 7.5 s, sys: 2.33 s, total: 9.84 s
Wall time: 7.63 s

GPU Server with batching

The GPU server with batching is listening on port 8505 for gRPC and 8506 for REST (we already started it above).

REST

gpu_pred = partial(predict_rest, port = '8506')
%%time
results = parallel_pred(gpu_pred, data)
CPU times: user 27.1 s, sys: 3.44 s, total: 30.6 s
Wall time: 27 s

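gRPC with batching

This is much faster than the REST endpoint (6.6 s vs 27 s wall time) and somewhat faster than gRPC on the CPU server (7.63 s) for this specific example. However, batching doesn’t appear to provide any speedup: the GPU server without batching (below) takes almost exactly the same time (6.65 s). Note that this cell calls parallel with its default settings rather than parallel_pred, so it does not use the 500-thread pool configured for the other runs.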

%%time
result = parallel(predict_grpc, data)
CPU times: user 2.71 s, sys: 551 ms, total: 3.26 s
Wall time: 6.6 s

GPU server without batching

docker run --gpus all \
--mount type=bind,source=/home/hamel/hamel/notes/serving/tfserving,target=/models \
--net=host -t tensorflow/serving:latest-gpu --port=8507 \
--rest_api_port=8508 --grpc_max_threads=1000

REST

gpu_pred_no_batch = partial(predict_rest, port = '8508')
%%time
results = parallel_pred(gpu_pred_no_batch, data)
CPU times: user 26.9 s, sys: 3.61 s, total: 30.5 s
Wall time: 25.7 s

gRPC without batching

When I first ran this, I got an error that said “Resources Exhausted”. I solved it by increasing the number of gRPC threads with the --grpc_max_threads=1000 flag when running the Docker container.

predict_grpc_no_batch = partial(predict_grpc, port='8507')
%%time
result = parallel_pred(predict_grpc_no_batch, data)
CPU times: user 5.06 s, sys: 1.42 s, total: 6.48 s
Wall time: 6.65 s
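To compare the six runs side by side, the wall times above can be converted to throughput. This is a small sketch; the numbers are copied from the %%time outputs in this note. Each configuration was timed once, so small gaps such as 6.6 s vs 6.65 s are probably within run-to-run noise.

```python
N_REQUESTS, INSTANCES_PER_REQUEST = 10_000, 5

# Wall times in seconds, copied from the %%time outputs above.
wall_times = {
    "REST, CPU": 26.0,
    "gRPC, CPU": 7.63,
    "REST, GPU + batching": 27.0,
    "gRPC, GPU + batching": 6.6,
    "REST, GPU, no batching": 25.7,
    "gRPC, GPU, no batching": 6.65,
}

# Fastest first: requests per second and instances scored per second.
for name, seconds in sorted(wall_times.items(), key=lambda kv: kv[1]):
    rps = N_REQUESTS / seconds
    print(f"{name:<24} {seconds:>6.2f} s {rps:>7.0f} req/s {rps * INSTANCES_PER_REQUEST:>7.0f} inst/s")
```

The table makes the main pattern easy to see: the protocol (gRPC vs REST) matters far more here than CPU vs GPU or batching vs no batching.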