import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
from train import get_model
vocab_size = 20000  # Only consider the top 20k words
maxlen = 200  # Only consider the first 200 words of each movie review
embed_dim = 32  # Embedding size for each token
num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in feed forward network inside transformer
Basics
These notes use code from here and this tutorial on tf serving.
Create The Model
I didn’t want to use an existing model file from a tfserving tutorial, so I’m creating a new model from scratch.
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=maxlen)
25000 Training sequences
25000 Validation sequences
get_model is defined here.
model = get_model(maxlen=maxlen, vocab_size=vocab_size,
                  embed_dim=embed_dim, num_heads=num_heads, ff_dim=ff_dim)
You should be careful to specify dtype properly for the input layer, so that the TF Serving API validation will work properly. Like this:
inputs = layers.Input(shape=(maxlen,), dtype='int32')
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 200)] 0
token_and_position_embeddin (None, 200, 32) 646400
g (TokenAndPositionEmbeddin
g)
transformer_block (Transfor (None, 200, 32) 10656
merBlock)
global_average_pooling1d (G (None, 32) 0
lobalAveragePooling1D)
dropout_2 (Dropout) (None, 32) 0
dense_2 (Dense) (None, 20) 660
dropout_3 (Dropout) (None, 20) 0
dense_3 (Dense) (None, 2) 42
=================================================================
Total params: 657,758
Trainable params: 657,758
Non-trainable params: 0
_________________________________________________________________
Train Model
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(
    x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val)
)
Epoch 1/2
782/782 [==============================] - 49s 58ms/step - loss: 0.3977 - accuracy: 0.8056 - val_loss: 0.2856 - val_accuracy: 0.8767
Epoch 2/2
782/782 [==============================] - 19s 24ms/step - loss: 0.1962 - accuracy: 0.9258 - val_loss: 0.3261 - val_accuracy: 0.8608
Save Model
You can serialize your TensorFlow models to the SavedModel format using tf.saved_model.save(...). This format is documented here. We are saving two versions of the model in order to discuss how TF Serving can serve multiple model versions.
!rm -rf ./model
def save_model(model_version, model_dir="./model"):
    model_export_path = f"{model_dir}/{model_version}"

    tf.saved_model.save(
        model,
        export_dir=model_export_path,
    )

    print(f"SavedModel files: {os.listdir(model_export_path)}")

save_model(model_version=1)
save_model(model_version=2)
WARNING:absl:Found untraced functions such as embedding_layer_call_fn, embedding_layer_call_and_return_conditional_losses, embedding_1_layer_call_fn, embedding_1_layer_call_and_return_conditional_losses, multi_head_attention_layer_call_fn while saving (showing 5 of 26). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: ./model/1/assets
INFO:tensorflow:Assets written to: ./model/1/assets
SavedModel files: ['fingerprint.pb', 'variables', 'assets', 'saved_model.pb']
WARNING:absl:Found untraced functions such as embedding_layer_call_fn, embedding_layer_call_and_return_conditional_losses, embedding_1_layer_call_fn, embedding_1_layer_call_and_return_conditional_losses, multi_head_attention_layer_call_fn while saving (showing 5 of 26). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: ./model/2/assets
INFO:tensorflow:Assets written to: ./model/2/assets
SavedModel files: ['fingerprint.pb', 'variables', 'assets', 'saved_model.pb']
Model versioning is done by saving your model into a directory named with an integer version number. By default, the directory with the highest integer will be served; you can change this with config files.
!ls model/
1 2
Validate the API Schema
The output of the command below shows the input schema and shape, as well as the output shape of the API we will create with TF Serving.
The flags below are mostly boilerplate. I don’t know what signature really means just yet.
!saved_model_cli show --dir ./model/2 --tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
inputs['input_1'] tensor_info:
dtype: DT_INT32
shape: (-1, 200)
name: serving_default_input_1:0
The given SavedModel SignatureDef contains the following output(s):
outputs['dense_3'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 2)
name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict
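As a cross-check, you can pull the same signature information out in Python by loading the SavedModel and inspecting its serving_default signature. This is a minimal sketch, assuming the ./model/2 directory created above:

import tensorflow as tf

# Load the exported version 2 SavedModel and grab the default serving signature
loaded = tf.saved_model.load("./model/2")
serving_fn = loaded.signatures["serving_default"]

# Input names, dtypes, and shapes (should match the saved_model_cli output above)
print(serving_fn.structured_input_signature)

# Output names, dtypes, and shapes
print(serving_fn.structured_outputs)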
Launch the docker container
The TF Serving docs really want you to use Docker, but you can use the CLI tensorflow_model_server instead, which is what is packaged inside the Docker container. This is what their docs say:
The easiest and most straight-forward way of using TensorFlow Serving is with Docker images. We highly recommend this route unless you have specific needs that are not addressed by running in a container.
TIP: This is also the easiest way to get TensorFlow Serving working with GPU support.
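For reference, since the Docker image is just a thin wrapper around the binary, running it directly would look roughly like this. This is a sketch built from the entrypoint in the Dockerfile below, pointed at the local model directory:

tensorflow_model_server --port=8500 --rest_api_port=8501 \
    --model_name=model \
    --model_base_path=/home/hamel/hamel/notes/serving/tfserving/model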
It is worth looking at the Dockerfile for TF Serving:
ENV MODEL_BASE_PATH=/models
RUN mkdir -p ${MODEL_BASE_PATH}
# The only required piece is the model name in order to differentiate endpoints
ENV MODEL_NAME=model
RUN echo '#!/bin/bash \n\n\
tensorflow_model_server --port=8500 --rest_api_port=8501 \
--model_name=${MODEL_NAME} --model_base_path=${MODEL_BASE_PATH}/${MODEL_NAME} \
"$@"' > /usr/bin/tf_serving_entrypoint.sh \
&& chmod +x /usr/bin/tf_serving_entrypoint.sh
This means that it is looking in /models/model by default. We should keep this in mind when mounting the local model directory into the container.
Suppose my local model is located at /home/hamel/hamel/notes/serving/tfserving/model
. This is how you would run the Docker container:
docker run -p 8500:8500 \
    --mount type=bind,source=/home/hamel/hamel/notes/serving/tfserving/model,target=/models/model \
    -t tensorflow/serving --net=host
TFServing on a GPU
See the note on using GPUs in TF Serving.
However, it probably only makes sense to enable the GPU if you are going to enable batching, or if a single prediction is GPU intensive (like Stable Diffusion).
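If you do enable batching, it is configured with extra tensorflow_model_server flags plus a batching parameters file. A hedged sketch (the flag names come from the TF Serving batching docs; the file name and parameter values are placeholders I chose, not tuned):

docker run --net=host \
    --mount type=bind,source=/home/hamel/hamel/notes/serving/tfserving/model,target=/models/model \
    -t tensorflow/serving \
    --enable_batching=true \
    --batching_parameters_file=/models/model/batching.config

where /models/model/batching.config is a text proto along these lines:

max_batch_size { value: 32 }
batch_timeout_micros { value: 1000 }
num_batch_threads { value: 4 }
max_enqueued_batches { value: 100 }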
Testing the API
According to the documentation, we can see the status of our model with GET http://host:port/v1/models/${MODEL_NAME}, which for us is http://localhost:8501/v1/models/model:
! curl http://localhost:8501/v1/models/model
{
"model_version_status": [
{
"version": "2",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": ""
}
}
]
}
Note how this shows the highest version number by default. You can access different model versions through different endpoints and by supplying the right config files.
Model Versioning
Models that you save into the model directory have a version number. For example, our model is saved at /home/hamel/hamel/notes/serving/tfserving/model, with subdirectories for versions 1 and 2.
!ls /home/hamel/hamel/notes/serving/tfserving/model
1 2
By default, TF Serving will always serve the model with the highest version number. However, you can change that with a model server config. You can also serve multiple versions of a model, add labels to models, etc. This is probably one of the most useful aspects of TF Serving. Here are some configs that allow you to serve multiple versions at the same time:
%%writefile ./model/models.config
model_config_list {
  config {
    name: 'model'
    base_path: '/models/model/'
    model_platform: 'tensorflow'
    model_version_policy: {all: {}}
  }
}
Overwriting ./model/models.config
If you wanted to serve only specific versions, you could name the versions instead of specifying all, like this:
%%writefile ./model/models-specific.config
model_config_list {
  config {
    name: 'model'
    base_path: '/models/model/'
    model_platform: 'tensorflow'
    model_version_policy {
      specific {
        versions: 1
        versions: 2
      }
    }
  }
}
Overwriting ./model/models-specific.config
To read the config files, we need to pass these additional flags when running the container:
docker run \
    --mount type=bind,source=/home/hamel/hamel/notes/serving/tfserving/model,target=/models/model \
    --net=host \
    -t tensorflow/serving \
    --model_config_file=/models/model/models-specific.config \
    --model_config_file_poll_wait_seconds=60
The flag --model_config_file_poll_wait_seconds=60
tells the server to check for a new config file at the path every 60 seconds. This is optional but likely a good idea so you can change your config file without rebooting the server.
To access a specific version of the model, you would make a request to
http://host:port/v1/models/${MODEL_NAME}[/versions/${VERSION}|/labels/${LABEL}]:predict
. For example, for version 1 the endpoint would be http://localhost:8501/v1/models/model/versions/1:predict
.
If you do not care about the version and just want the highest one, you can use the general endpoint without a version number, which serves the highest version by default:
http://localhost:8501/v1/models/model:predict
We can test that each of these versions is available to serve like so:
! curl http://localhost:8501/v1/models/model/versions/2
{
"model_version_status": [
{
"version": "2",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": ""
}
}
]
}
! curl http://localhost:8501/v1/models/model/versions/1
{
"model_version_status": [
{
"version": "1",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": ""
}
}
]
}
TF Serving doesn’t make all versions available by default, only the latest one (with the highest number). You have to supply a config file if you want multiple versions to be made available at once. You probably should use labels to make URLs consistent in production scenarios.
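As a sketch of what that might look like (the stable and canary label names are hypothetical, and labels can normally only be assigned to versions that are already loaded):

model_config_list {
  config {
    name: 'model'
    base_path: '/models/model/'
    model_platform: 'tensorflow'
    model_version_policy {
      specific {
        versions: 1
        versions: 2
      }
    }
    version_labels { key: 'stable' value: 1 }
    version_labels { key: 'canary' value: 2 }
  }
}

A labeled version would then be reachable at an endpoint like http://localhost:8501/v1/models/model/labels/stable:predict.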
Make a prediction request
REST
Time to make a prediction request. We will first try the REST API, whose endpoint is shown below. Note that v1 is hardcoded and refers to the version of the TF Serving API, not the version of the model:
POST http://host:port/v1/models/${MODEL_NAME}[/versions/${VERSION}|/labels/${LABEL}]:predict
import json, requests
import numpy as np
sample_data = x_val[:2, :]

data = json.dumps(
    {"signature_name": "serving_default", "instances": sample_data.tolist()}
)
url = "http://localhost:8501/v1/models/model:predict"  # this would be "http://localhost:8501/v1/models/model/versions/1:predict" for version 1
def predict_rest(json_data, url):
    json_response = requests.post(url, data=json_data)
    response = json.loads(json_response.text)
    rest_outputs = np.array(response["predictions"])
    return rest_outputs

rest_outputs = predict_rest(data, url)
rest_outputs
array([[0.94086391, 0.05913605],
[0.00317052, 0.99682945]])
model_outputs = model.predict(sample_data)
1/1 [==============================] - 0s 210ms/step
Let’s compare this to our model’s output. It’s close enough :)
assert np.allclose(rest_outputs, model_outputs, rtol=1e-4)
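If you want to pin a request to a specific version rather than the latest, the same helper works against a version-scoped URL. A quick sketch reusing predict_rest, assuming version 1 is being served (e.g. with models-specific.config above):

# Hypothetical: query version 1 explicitly instead of the latest version
v1_url = "http://localhost:8501/v1/models/model/versions/1:predict"
v1_outputs = predict_rest(data, v1_url)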
gRPC
- The payload format for gRPC uses Protocol Buffers, which are more compact than JSON, so latency can be lower. This makes more of a difference for larger payloads, like images.
- gRPC has some kind of bi-directional streaming, whereas REST is just a request/response model. I don’t know what this means yet.
- gRPC uses a newer HTTP protocol (HTTP/2) than REST. I don’t know what this means yet.
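Note: the client code below relies on the grpcio package and on the tensorflow_serving.apis modules, which come from the tensorflow-serving-api package on PyPI (an assumption about your environment; if the imports below fail, that is the likely missing piece).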
import grpc
# Create a channel that will be connected to the gRPC port of the container
channel = grpc.insecure_channel("localhost:8500")
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
# Get the serving_input key
loaded_model = tf.saved_model.load(model_export_path)
input_name = list(
    loaded_model.signatures["serving_default"].structured_input_signature[1].keys()
)[0]

input_name
'input_1'
def predict_grpc(data, input_name, stub):
    # Create a gRPC request made for prediction
    request = predict_pb2.PredictRequest()

    # Set the name of the model, for this use case it is "model"
    request.model_spec.name = "model"

    # Set which signature is used to format the gRPC query,
    # here the default one "serving_default"
    request.model_spec.signature_name = "serving_default"

    # Set the input as the data
    # tf.make_tensor_proto turns a TensorFlow tensor into a Protobuf tensor
    request.inputs[input_name].CopyFrom(tf.make_tensor_proto(data))

    # Send the gRPC request to the TF Server
    result = stub.Predict(request)
    return result
sample_data = tf.convert_to_tensor(x_val[:2, :], dtype='int32')

grpc_outputs = predict_grpc(sample_data, input_name, stub)
Inspect the gRPC response
We can see all the fields that the gRPC response has. In this situation, the name of the final layer of our model is the key that contains the predictions, which is dense_3 in this case.
grpc_outputs
outputs {
key: "dense_3"
value {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 2
}
dim {
size: 2
}
}
float_val: 0.9408639073371887
float_val: 0.059136051684617996
float_val: 0.0031705177389085293
float_val: 0.9968294501304626
}
}
model_spec {
name: "model"
version {
value: 2
}
signature_name: "serving_default"
}
We can also get the name of the last layer of the model like this:
"serving_default"].structured_outputs loaded_model.signatures[
{'dense_3': TensorSpec(shape=(None, 2), dtype=tf.float32, name='dense_3')}
Reshaping the Response
shape = [x.size for x in grpc_outputs.outputs['dense_3'].tensor_shape.dim]

grpc_preds = np.reshape(grpc_outputs.outputs['dense_3'].float_val, shape)
grpc_preds
array([[0.94086391, 0.05913605],
[0.00317052, 0.99682945]])
The predictions are close enough. I am not sure why they wouldn’t be exactly the same.
assert np.allclose(model_outputs, grpc_preds, rtol=1e-4)