Basics

A minimal end-to-end example of TF Serving

These notes use code from here and this tutorial on TF Serving.

Create The Model

Note

I didn’t want to use an existing model file from a tfserving tutorial, so I’m creating a new model from scratch.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
from train import get_model

vocab_size = 20000  # Only consider the top 20k words
maxlen = 200  # Pad or truncate each review to 200 tokens (pad_sequences keeps the last 200 by default)
embed_dim = 32  # Embedding size for each token
num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in feed forward network inside transformer
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=maxlen)
25000 Training sequences
25000 Validation sequences
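With its defaults (padding='pre', truncating='pre', pad value 0), pad_sequences left-pads short reviews with zeros and trims long reviews from the start, so it keeps the last maxlen tokens of each review. A minimal pure-Python sketch of that default behaviour for a single sequence, for illustration only (pad_pre is a hypothetical name, not a Keras function):

```python
def pad_pre(sequence, maxlen, value=0):
    """Pad or truncate one sequence the way keras pad_sequences does by default
    (padding='pre', truncating='pre')."""
    sequence = list(sequence)
    if len(sequence) >= maxlen:
        # truncating='pre': drop tokens from the beginning, keep the last maxlen
        return sequence[len(sequence) - maxlen:]
    # padding='pre': put the pad value in front of the existing tokens
    return [value] * (maxlen - len(sequence)) + sequence


print(pad_pre([5, 6, 7], maxlen=5))              # [0, 0, 5, 6, 7]
print(pad_pre([1, 2, 3, 4, 5, 6, 7], maxlen=5))  # [3, 4, 5, 6, 7]
```

Passing truncating='post' to pad_sequences would keep the first 200 tokens instead.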
Important

get_model is defined here

model = get_model(maxlen=maxlen, vocab_size=vocab_size, 
                  embed_dim=embed_dim, num_heads=num_heads, ff_dim=ff_dim)
Warning

Be careful to specify the dtype of the input layer explicitly, so that the exported signature declares the right input type and TF Serving's request validation works properly. Like this:

inputs = layers.Input(shape=(maxlen,), dtype='int32')
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 200)]             0         
                                                                 
 token_and_position_embeddin  (None, 200, 32)          646400    
 g (TokenAndPositionEmbeddin                                     
 g)                                                              
                                                                 
 transformer_block (Transfor  (None, 200, 32)          10656     
 merBlock)                                                       
                                                                 
 global_average_pooling1d (G  (None, 32)               0         
 lobalAveragePooling1D)                                          
                                                                 
 dropout_2 (Dropout)         (None, 32)                0         
                                                                 
 dense_2 (Dense)             (None, 20)                660       
                                                                 
 dropout_3 (Dropout)         (None, 20)                0         
                                                                 
 dense_3 (Dense)             (None, 2)                 42        
                                                                 
=================================================================
Total params: 657,758
Trainable params: 657,758
Non-trainable params: 0
_________________________________________________________________

Train Model

model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(
    x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val)
)
Epoch 1/2
782/782 [==============================] - 49s 58ms/step - loss: 0.3977 - accuracy: 0.8056 - val_loss: 0.2856 - val_accuracy: 0.8767
Epoch 2/2
782/782 [==============================] - 19s 24ms/step - loss: 0.1962 - accuracy: 0.9258 - val_loss: 0.3261 - val_accuracy: 0.8608

Save Model

You can serialize TensorFlow models to the SavedModel format using tf.saved_model.save(...). This format is documented here. We are saving two versions of the model in order to discuss how TF Serving can serve multiple model versions.

!rm -rf ./model
def save_model(model, model_version, model_dir="./model"):

    # Each version goes into its own integer-named subdirectory, e.g. ./model/1
    model_export_path = f"{model_dir}/{model_version}"

    tf.saved_model.save(
        model,
        export_dir=model_export_path,
    )

    print(f"SavedModel files: {os.listdir(model_export_path)}")

save_model(model, model_version=1)
save_model(model, model_version=2)
WARNING:absl:Found untraced functions such as embedding_layer_call_fn, embedding_layer_call_and_return_conditional_losses, embedding_1_layer_call_fn, embedding_1_layer_call_and_return_conditional_losses, multi_head_attention_layer_call_fn while saving (showing 5 of 26). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: ./model/1/assets
WARNING:absl:Found untraced functions such as embedding_layer_call_fn, embedding_layer_call_and_return_conditional_losses, embedding_1_layer_call_fn, embedding_1_layer_call_and_return_conditional_losses, multi_head_attention_layer_call_fn while saving (showing 5 of 26). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: ./model/2/assets
SavedModel files: ['fingerprint.pb', 'variables', 'assets', 'saved_model.pb']
SavedModel files: ['fingerprint.pb', 'variables', 'assets', 'saved_model.pb']

Model versioning is done by saving each version of your model into a subdirectory named with an integer (here ./model/1 and ./model/2). By default, TF Serving serves the version with the highest number. You can change this with a model config file.
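The "highest integer wins" rule can be sketched as a directory scan. The helpers below are hypothetical (they are not TF Serving APIs) and assume, as TF Serving's default behaviour roughly does, that only subdirectories named with an integer count as versions and that versions are compared numerically, so 10 beats 2 even though "10" sorts before "2" as a string:

```python
import os


def version_dirs(model_dir):
    """Return the integer version numbers of subdirectories like model/1, model/2."""
    versions = []
    for name in os.listdir(model_dir):
        # Only directories named with a plain integer count as versions;
        # other entries (e.g. a models.config file) are ignored.
        if name.isdecimal() and os.path.isdir(os.path.join(model_dir, name)):
            versions.append(int(name))
    return versions


def latest_version(model_dir):
    """Highest numeric version, or None if there are no version directories."""
    versions = version_dirs(model_dir)
    return max(versions) if versions else None


def next_version(model_dir):
    """Version number to use for the next export (1 for an empty directory)."""
    latest = latest_version(model_dir)
    return 1 if latest is None else latest + 1
```

For example, passing next_version("./model") as the model_version argument to save_model would export a new version without overwriting the existing ones.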

!ls model/
1  2

Validate the API Schema

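The command below prints the input names, dtypes, and shapes, as well as the output shapes, of the API that TF Serving will expose.

The flags are mostly boilerplate: --tag_set serve selects the MetaGraph tagged for serving, and --signature_def serving_default selects the default signature, i.e. the named function (with its input and output tensors) that TF Serving calls for predictions.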

!saved_model_cli show --dir ./model/2 --tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
  inputs['input_1'] tensor_info:
      dtype: DT_INT32
      shape: (-1, 200)
      name: serving_default_input_1:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['dense_3'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 2)
      name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict
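The signature above says the server expects a batch of int32 rows of length 200 (-1 means any batch size). As a sketch, a hypothetical client-side check like check_instances can catch shape or dtype mistakes before a request is sent; the 200 default and the int32 bounds come from this signature, and rows produced by x_val[...].tolist() contain plain Python ints, so they pass:

```python
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def check_instances(instances, seq_len=200):
    """Validate a batch against the signature shown above:
    dtype DT_INT32 and shape (-1, seq_len), where -1 means any batch size.
    Returns the batch size; raises ValueError on the first problem found."""
    for i, row in enumerate(instances):
        if len(row) != seq_len:
            raise ValueError(f"row {i}: expected {seq_len} tokens, got {len(row)}")
        for token in row:
            # bool is a subclass of int in Python, so reject it explicitly
            if isinstance(token, bool) or not isinstance(token, int):
                raise ValueError(f"row {i}: {token!r} is not an integer")
            if not INT32_MIN <= token <= INT32_MAX:
                raise ValueError(f"row {i}: {token} does not fit in int32")
    return len(instances)
```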

Launch the docker container

The TF Serving docs strongly recommend Docker, but you can also run the tensorflow_model_server CLI directly, which is what the Docker container runs. This is what their docs say:

The easiest and most straight-forward way of using TensorFlow Serving is with Docker images. We highly recommend this route unless you have specific needs that are not addressed by running in a container.

TIP: This is also the easiest way to get TensorFlow Serving working with GPU support.

It's worth looking at the Dockerfile for TF Serving:

ENV MODEL_BASE_PATH=/models
RUN mkdir -p ${MODEL_BASE_PATH}

# The only required piece is the model name in order to differentiate endpoints
ENV MODEL_NAME=model


RUN echo '#!/bin/bash \n\n\
tensorflow_model_server --port=8500 --rest_api_port=8501 \
--model_name=${MODEL_NAME} --model_base_path=${MODEL_BASE_PATH}/${MODEL_NAME} \
"$@"' > /usr/bin/tf_serving_entrypoint.sh \
&& chmod +x /usr/bin/tf_serving_entrypoint.sh

This means the server looks for the model in /models/model by default, so that is where we mount the local model directory inside the container.
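To see where /models/model comes from, here is a small Python sketch (entrypoint_command is hypothetical) that rebuilds the command the entrypoint script runs from the two ENV values in the Dockerfile, assuming the image uses this script as its entrypoint. The "$@" in the script appends any arguments given after the image name in docker run, which is how extra flags such as --model_config_file reach the server; overriding MODEL_NAME (for example with docker run -e MODEL_NAME=...) changes both the model name and the base path:

```python
def entrypoint_command(env=None, extra_args=()):
    """Rebuild the tensorflow_model_server command from tf_serving_entrypoint.sh.

    env stands in for the container environment; the defaults are the
    ENV values from the Dockerfile (MODEL_BASE_PATH=/models, MODEL_NAME=model).
    """
    env = env or {}
    base_path = env.get("MODEL_BASE_PATH", "/models")
    model_name = env.get("MODEL_NAME", "model")
    return [
        "tensorflow_model_server",
        "--port=8500",  # gRPC port
        "--rest_api_port=8501",  # REST port
        f"--model_name={model_name}",
        f"--model_base_path={base_path}/{model_name}",
        *extra_args,  # "$@": arguments given after the image name
    ]


print(" ".join(entrypoint_command()))
```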

Suppose my local model is located at /home/hamel/hamel/notes/serving/tfserving/model. This is how you would run the Docker container:

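docker run -p 8500:8500 -p 8501:8501 \
--mount type=bind,source=/home/hamel/hamel/notes/serving/tfserving/model,target=/models/model \
-t tensorflow/serving

Port 8500 is the gRPC API and 8501 is the REST API. Alternatively, --net=host shares the host's network and makes -p unnecessary (Docker ignores published ports in host mode).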

TFServing on a GPU

See the note on using GPUs in TF Serving.

However, enabling the GPU probably only makes sense if you are also going to enable batching, or if a single prediction is GPU-intensive (like Stable Diffusion).

Testing the API

According to the documentation we can see the status of our model like this:

GET http://host:port/v1/models/${MODEL_NAME}, which for us is:

curl http://localhost:8501/v1/models/model

! curl http://localhost:8501/v1/models/model
{
 "model_version_status": [
  {
   "version": "2",
   "state": "AVAILABLE",
   "status": {
    "error_code": "OK",
    "error_message": ""
   }
  }
 ]
}

Note how this shows only the highest version number by default. You can access other model versions through version-specific endpoints, after supplying a config file that makes those versions available.
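The status response is plain JSON, so it is easy to check programmatically, for example in a deployment health check. A minimal sketch with hypothetical helper names, assuming you already fetched the response body (e.g. with requests.get(url).text); note that TF Serving reports the version as a string:

```python
import json


def version_states(status_json):
    """Map each version in a GET /v1/models/<name> response to its state."""
    status = json.loads(status_json)
    # Versions arrive as strings (e.g. "2"), so convert them to ints for lookups
    return {
        int(entry["version"]): entry["state"]
        for entry in status.get("model_version_status", [])
    }


def is_available(status_json, version):
    """True if the given version is reported as AVAILABLE."""
    return version_states(status_json).get(int(version)) == "AVAILABLE"
```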

Model Versioning

Models that you save into the directory have a version number. For example, our model is saved at /home/hamel/hamel/notes/serving/tfserving/model, with subdirectories for versions 1 and 2.

!ls /home/hamel/hamel/notes/serving/tfserving/model
1  2

By default, TF Serving will always serve the model with the highest version number. However, you can change that with a model server config. You can also serve multiple versions of a model, add labels to models, etc. This is probably one of the most useful aspects of TF Serving. Here are some configs that allow you to serve multiple versions at the same time:

%%writefile ./model/models.config


model_config_list {
  config {
    name: 'model'
    base_path: '/models/model/'
    model_platform: 'tensorflow'
    model_version_policy: {all: {}}
  }
}
Overwriting ./model/models.config

If you want to serve only specific versions, you can list them instead of specifying all, like this:

%%writefile ./model/models-specific.config

model_config_list {
  config {
    name: 'model'
    base_path: '/models/model/'
    model_platform: 'tensorflow'
    model_version_policy {
      specific {
        versions: 1
        versions: 2
      }
    }
  }
}
Overwriting ./model/models-specific.config

To read the config files, we need to pass these additional flags when running the container:

docker run \
--mount type=bind,source=/home/hamel/hamel/notes/serving/tfserving/model,target=/models/model \
--net=host \
-t tensorflow/serving \
--model_config_file=/models/model/models-specific.config \
--model_config_file_poll_wait_seconds=60 

The flag --model_config_file_poll_wait_seconds=60 tells the server to re-read the config file at that path every 60 seconds. This is optional, but likely a good idea, because you can then change your config file without restarting the server.

To access a specific version of the model, you would make a request to

http://host:port/v1/models/${MODEL_NAME}[/versions/${VERSION}|/labels/${LABEL}]:predict. For example, for version 1 the endpoint would be http://localhost:8501/v1/models/model/versions/1:predict.

If you don't care about the version and just want the highest one, use the general endpoint without a version, which serves the highest version by default:

http://localhost:8501/v1/models/model:predict
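The URL pattern above can be captured in a small helper. This is a hypothetical sketch: predict_url is not part of any TF Serving client library, and the label value stable is only an example of a label you might define in a config:

```python
def predict_url(model_name, host="localhost", port=8501, version=None, label=None):
    """Build a TF Serving REST predict URL of the form
    /v1/models/<name>[/versions/<version>|/labels/<label>]:predict"""
    if version is not None and label is not None:
        raise ValueError("pass either version or label, not both")
    path = f"/v1/models/{model_name}"
    if version is not None:
        path += f"/versions/{int(version)}"
    elif label is not None:
        path += f"/labels/{label}"
    # Without a version or label, the server picks its default (highest) version
    return f"http://{host}:{port}{path}:predict"
```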

We can check that both versions are available to serve like so:

! curl http://localhost:8501/v1/models/model/versions/2
{
 "model_version_status": [
  {
   "version": "2",
   "state": "AVAILABLE",
   "status": {
    "error_code": "OK",
    "error_message": ""
   }
  }
 ]
}
! curl http://localhost:8501/v1/models/model/versions/1
{
 "model_version_status": [
  {
   "version": "1",
   "state": "AVAILABLE",
   "status": {
    "error_code": "OK",
    "error_message": ""
   }
  }
 ]
}
Warning

TF Serving doesn’t make all versions available by default, only the latest one (with the highest number). You have to supply a config file if you want multiple versions to be available at once. In production you should probably use labels (named aliases such as stable that map to a version number), so that URLs stay the same when the version behind them changes.
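Labels are configured in the same model config file through a version_labels field. The hypothetical helper below renders a config that serves specific versions and maps labels to them; the label names stable and canary are assumptions for illustration, and the check that each label points at a served version reflects our assumption that a label should only name a version the server actually loads:

```python
def models_config(name, base_path, versions, labels=None):
    """Render a model_config_list that serves specific versions and,
    optionally, assigns labels (e.g. {"stable": 1, "canary": 2})."""
    labels = labels or {}
    for label, version in labels.items():
        if version not in versions:
            raise ValueError(f"label {label!r} points at unserved version {version}")
    lines = [
        "model_config_list {",
        "  config {",
        f"    name: '{name}'",
        f"    base_path: '{base_path}'",
        "    model_platform: 'tensorflow'",
        "    model_version_policy {",
        "      specific {",
        *[f"        versions: {v}" for v in versions],
        "      }",
        "    }",
    ]
    # One version_labels block per label -> version mapping
    for label, version in labels.items():
        lines += [
            "    version_labels {",
            f"      key: '{label}'",
            f"      value: {version}",
            "    }",
        ]
    lines += ["  }", "}"]
    return "\n".join(lines) + "\n"


print(models_config("model", "/models/model/", [1, 2], {"stable": 1, "canary": 2}))
```

Written to a file such as ./model/models-labels.config (a hypothetical name) and passed with --model_config_file, this would let clients call http://localhost:8501/v1/models/model/labels/stable:predict regardless of which version stable currently points at.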

Make a prediction request

REST

Time to make a prediction request. We will start with the REST API, whose predict endpoint has the form below. Note that v1 is the version of the TF Serving API itself, not the version of the model:

POST http://host:port/v1/models/${MODEL_NAME}[/versions/${VERSION}|/labels/${LABEL}]:predict
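The request body below uses the REST API's row format: an "instances" list with one entry per example. Because this signature has a single input, each entry can be the bare token list; with several inputs you would send one object per example keyed by input name. The REST API also accepts a columnar format under an "inputs" key, and then returns "outputs" instead of "predictions". A hypothetical helper covering both shapes:

```python
import json


def predict_body(rows, input_name=None, columnar=False, signature_name="serving_default"):
    """Build a JSON body for TF Serving's REST predict API.

    Row format ("instances"): one entry per example.
    Columnar format ("inputs"): the whole batch under each input name.
    input_name is optional when the signature has a single input.
    """
    body = {"signature_name": signature_name}
    if columnar:
        body["inputs"] = {input_name: rows} if input_name else rows
    else:
        body["instances"] = [{input_name: row} for row in rows] if input_name else rows
    return json.dumps(body)
```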

import json, requests
import numpy as np

sample_data = x_val[:2, :]

data = json.dumps(
    {"signature_name": "serving_default", "instances": sample_data.tolist()}
)
url = "http://localhost:8501/v1/models/model:predict" # this would be "http://localhost:8501/v1/models/model/versions/1:predict" for version 1


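def predict_rest(json_data, url):
    # POST the JSON body to the REST predict endpoint
    json_response = requests.post(url, data=json_data)
    response = json.loads(json_response.text)
    # On failure TF Serving returns {"error": "..."} instead of "predictions"
    if "error" in response:
        raise RuntimeError(response["error"])
    rest_outputs = np.array(response["predictions"])
    return rest_outputs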

rest_outputs = predict_rest(data, url)
rest_outputs
array([[0.94086391, 0.05913605],
       [0.00317052, 0.99682945]])
model_outputs = model.predict(sample_data)
1/1 [==============================] - 0s 210ms/step

Let’s compare this to our model’s output. It’s close enough :)

assert np.allclose(rest_outputs, model_outputs, rtol=1e-4)
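np.allclose(a, b, rtol, atol) checks elementwise that |a - b| <= atol + rtol * |b| (atol defaults to 1e-8), so with rtol=1e-4 values may differ by about 0.01% of the reference. Note that b is treated as the reference, so the test is not perfectly symmetric. A pure-Python sketch of the same test for flat lists, ignoring NaN handling and broadcasting:

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Same elementwise test as np.allclose for two flat lists:
    every |a_i - b_i| must be <= atol + rtol * |b_i|."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))
```

For example, allclose([0.94086391], [0.9408639073371887], rtol=1e-4) is True, since the values differ by roughly 3e-9.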

gRPC

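  • The gRPC payload uses Protocol Buffers, a compact binary encoding that is smaller and faster to parse than JSON text, which can lower latency. This matters most for large payloads, like images.
  • gRPC supports streaming RPCs, including bidirectional streaming, where the client and server each send a sequence of messages over one call; REST is a single request/response. (TF Serving's Predict RPC itself is a plain request/response call.)
  • gRPC runs over HTTP/2, which multiplexes many concurrent requests over one connection and uses binary framing and header compression; REST APIs typically use HTTP/1.1.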
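A rough illustration of the size difference, not actual Protocol Buffers: TensorFlow's TensorProto stores float32 values in a packed repeated field, i.e. 4 bytes per value, and struct.pack approximates just that payload (without the field tags and lengths a real message adds). The JSON version spells each number out as text:

```python
import json
import struct

# The four float32 probabilities from the gRPC response later in these notes
values = [0.9408639073371887, 0.059136051684617996, 0.0031705177389085293, 0.9968294501304626]

# JSON: each number is written out as decimal text
json_bytes = json.dumps({"predictions": [values[:2], values[2:]]}).encode("utf-8")

# Binary: 4 little-endian bytes per float32, similar to a packed repeated float field
binary_bytes = struct.pack(f"<{len(values)}f", *values)

print(len(json_bytes), "bytes as JSON")
print(len(binary_bytes), "bytes as packed float32")
```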
import grpc

# Create a channel that will be connected to the gRPC port of the container
channel = grpc.insecure_channel("localhost:8500")
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
# Load the version-2 export and look up the input name of its serving_default signature
loaded_model = tf.saved_model.load("./model/2")
input_name = list(
    loaded_model.signatures["serving_default"].structured_input_signature[1].keys()
)[0]
input_name
'input_1'
def predict_grpc(data, input_name, stub):
    # Create a gRPC request made for prediction
    request = predict_pb2.PredictRequest()

    # Set the name of the model, for this use case it is "model"
    request.model_spec.name = "model"

    # Set which signature is used to format the gRPC query
    # here the default one "serving_default"
    request.model_spec.signature_name = "serving_default"

    # Set the input as the data
    # tf.make_tensor_proto turns a TensorFlow tensor into a Protobuf tensor
    request.inputs[input_name].CopyFrom(tf.make_tensor_proto(data))

    # Send the gRPC request to the TF Server
    result = stub.Predict(request)
    return result

sample_data = tf.convert_to_tensor(x_val[:2, :], dtype='int32')

grpc_outputs = predict_grpc(sample_data, input_name, stub)

Inspect the gRPC response

We can see all the fields of the gRPC response. The key that contains the predictions is the name of the model's final layer, which is dense_3 in this case.

grpc_outputs
outputs {
  key: "dense_3"
  value {
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: 2
      }
      dim {
        size: 2
      }
    }
    float_val: 0.9408639073371887
    float_val: 0.059136051684617996
    float_val: 0.0031705177389085293
    float_val: 0.9968294501304626
  }
}
model_spec {
  name: "model"
  version {
    value: 2
  }
  signature_name: "serving_default"
}

We can also read the output key (the name of the model's last layer) from the loaded signature:

loaded_model.signatures["serving_default"].structured_outputs
{'dense_3': TensorSpec(shape=(None, 2), dtype=tf.float32, name='dense_3')}

Reshaping the Response

shape = [x.size for x in grpc_outputs.outputs['dense_3'].tensor_shape.dim]

grpc_preds = np.reshape(grpc_outputs.outputs['dense_3'].float_val, shape)
grpc_preds
array([[0.94086391, 0.05913605],
       [0.00317052, 0.99682945]])
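float_val holds the tensor's values flattened in row-major order, and tensor_shape gives the dimensions, which is why the reshape above works. TensorFlow also provides tf.make_ndarray, which converts a TensorProto such as grpc_outputs.outputs['dense_3'] straight into a NumPy array. For illustration, a pure-Python sketch of the row-major reshape itself (reshape_row_major is a hypothetical name):

```python
def reshape_row_major(flat, shape):
    """Turn a flat row-major sequence (like TensorProto.float_val) into nested
    lists with the given shape, e.g. 4 values and [2, 2] -> [[a, b], [c, d]]."""
    flat = list(flat)
    total = 1
    for dim in shape:
        total *= dim
    if total != len(flat):
        raise ValueError(f"cannot reshape {len(flat)} values into shape {list(shape)}")
    if len(shape) == 0:
        return flat[0]  # a scalar tensor holds exactly one value
    if len(shape) == 1 or shape[0] == 0:
        return flat
    # Each slice along the first dimension holds total // shape[0] values
    step = total // shape[0]
    return [
        reshape_row_major(flat[i * step:(i + 1) * step], shape[1:])
        for i in range(shape[0])
    ]
```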

The predictions are close enough. I am not sure why they aren't exactly the same, which is why these comparisons use np.allclose with a relative tolerance instead of exact equality.

assert np.allclose(model_outputs, grpc_preds, rtol=1e-4)