Max Inference Engine

Attempting to load Mistral-7b in Modular’s new Max Inference Engine

February 29, 2024


Today, MAX runs only on CPUs, but the roadmap points to GPUs coming soon.


Mark Saroufim and I attempted to load Mistral-7b into the Max Engine.

Here are my initial impressions.

  • The Modular team is super engaging and friendly! I recommend bringing questions to their Discord.
  • PyTorch currently feels like a second-class citizen (but may not be for long).
    • Loading a PyTorch model takes many more steps than loading a TensorFlow model, and those steps are not clearly documented.¹
    • However, I’ve heard from the team that they plan on investing heavily in PyTorch, even more so than TensorFlow.
  • I’m not sure why they led with BERT/TensorFlow examples. I would like to see paved paths for modern LLMs like Llama or Mistral. Keep an eye on this repo for examples, as these will be added soon.
  • Model compilation and loading took over 5 minutes. A progress bar for compilation would be really helpful.
  • TorchScript is an older serialization format that is now in maintenance mode, but Max doesn’t yet support the newer torch.export path (or torch.compile). Discussion is here. You will probably have better luck with LLMs by first exporting models to ONNX via Optimum.
  • Printing the model is not as informative as printing a torch model: it doesn’t show you all the layers and shapes.
  • We couldn’t quite make sense of the model’s output, and we eventually hypothesized that TorchScript (via torch.jit.trace) is not the right serialization path for Mistral, but we aren’t sure. I think users may get confused by this.
  • Max currently targets CPUs rather than GPUs. I am not concerned by this as the roadmap points to GPUs coming soon. I’m hoping that the team can make AMD GPUs fly so we can break the hegemony of Nvidia.

I’m optimistic that these papercuts will be resolved soon. I’m pretty bullish on the talent level of the team working on these things.

Attempting To Load Mistral In The Max Engine

Today, the Modular team released the Max Inference Engine:

MAX Engine is a next-generation compiler and runtime system for neural network graphs. It supercharges the execution of AI models in any format (including TensorFlow, PyTorch, and ONNX), on a wide variety of hardware. MAX Engine also allows you to extend these models with custom ops that MAX Engine can analyze and optimize with other ops in the graph.

These docs show how to load a TensorFlow model, but I want to load a PyTorch LLM like Mistral-7b. I documented my attempt below.

1. Serialize Model as TorchScript

To load your model in the Max engine, you must first serialize it as TorchScript. We can do this by tracing the model graph with torch.jit.trace and then saving the traced model with torch.jit.save.

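import time
from functools import partial
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# These three values were not shown in the original post; the ones below are assumptions
hf_path = "mistralai/Mistral-7B-v0.1"  # assumed Hugging Face checkpoint for Mistral-7b
max_seq_len = 128                      # hypothetical fixed sequence length
model_path = "mistral-7b.torchscript"  # hypothetical output path

# load model artifacts from the hub (torchscript=True prepares the model for tracing)
model = AutoModelForCausalLM.from_pretrained(hf_path, torchscript=True)
tokenizer = AutoTokenizer.from_pretrained(hf_path)
tokenizer.pad_token = tokenizer.eos_token  # Mistral has no pad token, so reuse EOS

# trace the model and save it with TorchScript
text = "This is text to be used for tracing"
# I'm setting the tokenizer arguments once so I can reuse them later (they need to be consistent)
max_tokenizer = partial(tokenizer, return_tensors="pt", padding="max_length", max_length=max_seq_len)
inputs = max_tokenizer(text)
traced_model = torch.jit.trace(model, (inputs['input_ids'], inputs['attention_mask']))
torch.jit.save(traced_model, model_path)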
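torch.jit.trace records the operations run on the example inputs, and the compile step below is given fixed input shapes, so every later input has to be encoded exactly like the tracing example. That is why the tokenizer arguments are frozen with partial and padding="max_length" is used. As a conceptual sketch only (pad_to_max_length is a hypothetical helper, not the transformers implementation; the left padding side and pad id 2 are illustrative assumptions), padding to a fixed length works like this:

```python
def pad_to_max_length(token_ids, max_length, pad_id, side="left"):
    """Toy stand-in for padding="max_length": returns (input_ids, attention_mask),
    each exactly max_length long. Hypothetical helper, not the transformers API."""
    if len(token_ids) > max_length:
        raise ValueError("sequence longer than max_length; truncate first")
    n_pad = max_length - len(token_ids)
    pads, real = [pad_id] * n_pad, list(token_ids)
    mask_pads, mask_real = [0] * n_pad, [1] * len(token_ids)  # 0 = ignore, 1 = attend
    if side == "left":
        return pads + real, mask_pads + mask_real
    return real + pads, mask_real + mask_pads


# Two prompts of different lengths end up with the same shape.
ids_a, mask_a = pad_to_max_length([1, 523, 22], max_length=6, pad_id=2)
ids_b, mask_b = pad_to_max_length([1, 99], max_length=6, pad_id=2)
print(ids_a, mask_a)  # [2, 2, 2, 1, 523, 22] [0, 0, 0, 1, 1, 1]
print(ids_b, mask_b)  # [2, 2, 2, 2, 1, 99] [0, 0, 0, 0, 1, 1]
```

The attention mask is what tells the model to ignore the pad positions; the token ids used here are arbitrary.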

2. Specify Input Shape

Compilation requires a fixed input shape.

This next bit is adapted from this code. Apparently there is a way to specify dynamic values for the sequence length and batch size, but we couldn’t easily figure that out from the docs.

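from max import engine

# One input spec per traced input (input_ids, attention_mask),
# using the exact shapes the model was traced with
input_spec_list = [
    engine.TorchInputSpec(shape=tensor.size(), dtype=engine.DType.int64)
    for tensor in inputs.values()
]
options = engine.TorchLoadOptions(input_spec_list)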
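To make the fixed-versus-dynamic distinction concrete, here is a toy, pure-Python model of what an input spec constrains. ToyInputSpec and matches are hypothetical names, not MAX's API, and using None for a dynamic dimension is an assumption borrowed from common practice; we did not confirm how MAX spells it.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class ToyInputSpec:
    """Hypothetical stand-in for an input spec: None marks a dynamic dimension."""
    shape: Tuple[Optional[int], ...]
    dtype: str


def matches(spec: ToyInputSpec, shape: Sequence[int], dtype: str) -> bool:
    """True if a concrete (shape, dtype) is accepted by the spec."""
    if dtype != spec.dtype or len(shape) != len(spec.shape):
        return False
    # A fixed dimension must match exactly; a None dimension accepts any size.
    return all(want is None or want == got for want, got in zip(spec.shape, shape))


static = ToyInputSpec(shape=(1, 128), dtype="int64")       # fixed batch and length
dynamic = ToyInputSpec(shape=(None, None), dtype="int64")  # any batch, any length

print(matches(static, (1, 128), "int64"))  # True
print(matches(static, (1, 64), "int64"))   # False: length differs from the spec
print(matches(dynamic, (4, 64), "int64"))  # True
```

With a fully static spec like the one built from the traced inputs, every prompt must be padded to the same length, which is the trade-off the dynamic option would remove.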

3. Compile and Load Model

start_time = time.time()
session = engine.InferenceSession()
model = session.load(model_path, options)

end_time = time.time()

Wow! The model takes over 5 minutes to compile and load. Subsequent compilations are faster, but NOT if I restart the Jupyter kernel.

elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")
Elapsed time: 371.86387610435486 seconds
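time.time() works here, but time.perf_counter() is the standard-library clock intended for measuring intervals. A small context manager (the name timed is ours, not part of any library) keeps the measurement next to the slow call and reports minutes as well. It only measures wall-clock time; it cannot show compilation progress.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label):
    """Time the wrapped block with a monotonic clock and report seconds and minutes."""
    result = {}
    start = time.perf_counter()
    try:
        yield result
    finally:
        result["seconds"] = time.perf_counter() - start
        print(f"{label}: {result['seconds']:.1f} s ({result['seconds'] / 60:.1f} min)")


# In step 3 this would wrap the slow call, e.g.:
#     with timed("compile + load"):
#         model = session.load(model_path, options)
with timed("demo") as t:
    time.sleep(0.2)
```

For reference, the 371.9 seconds measured above is about 6.2 minutes.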

4. Inference

We failed to get this to work

Even though we could call model.execute, the outputs didn’t make much sense to us, even after some investigation. Our hypothesis is that execute runs only the traced forward pass rather than the generation loop in model.generate. This is where we gave up.
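For context on that hypothesis: a traced module contains only forward, which maps a batch of token ids to scores (logits) for every position in one pass. Text generation, as done by Hugging Face's generate in its simplest greedy setting, is a loop around that forward pass: take the scores at the last position, pick the best token, append it, repeat. A minimal pure-Python sketch of greedy decoding, where toy_forward is a hypothetical stand-in for a forward call that returns one row of scores per position:

```python
from typing import Callable, List

Logits = List[List[float]]  # one row of vocabulary scores per input position


def greedy_generate(forward: Callable[[List[int]], Logits],
                    prompt_ids: List[int], max_new_tokens: int,
                    eos_id: int) -> List[int]:
    """Greedy decoding: rerun the forward pass and append the argmax of the
    *last* position's scores until EOS or the token budget is reached."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        last_row = forward(ids)[-1]
        next_id = max(range(len(last_row)), key=last_row.__getitem__)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids[len(prompt_ids):]  # only the newly generated tokens


# Hypothetical toy "model" over a 5-token vocabulary: it always scores
# (current token + 1) highest, so it counts upward until EOS (id 4).
def toy_forward(ids: List[int]) -> Logits:
    vocab = 5
    return [[1.0 if v == (t + 1) % vocab else 0.0 for v in range(vocab)] for t in ids]


print(greedy_generate(toy_forward, [0, 1], max_new_tokens=10, eos_id=4))  # [2, 3, 4]
```

Real implementations cache past keys and values instead of rerunning the whole prefix, but the control flow is the same; a single execute call corresponds to one iteration's forward pass, not the loop.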

Be sure to set return_token_type_ids=False. Note that I’m calling the max_tokenizer I defined earlier, so padding and max_length match the values used for tracing and the input shape stays consistent.

INPUT="Why did the chicken cross the road?"
inp = max_tokenizer(INPUT, return_token_type_ids=False)
out = model.execute(**inp)
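Reusing max_tokenizer is what keeps the shape consistent: functools.partial stores the padding and length keywords once, and keywords passed at call time (like return_token_type_ids=False) are merged in. A stdlib sketch with a hypothetical fake_tokenizer that just echoes its arguments (128 stands in for max_seq_len and is an assumption):

```python
from functools import partial


def fake_tokenizer(text, **kwargs):
    """Hypothetical stand-in for a tokenizer call; returns what it was called with."""
    return {"text": text, **kwargs}


max_tokenizer = partial(fake_tokenizer, return_tensors="pt",
                        padding="max_length", max_length=128)

call = max_tokenizer("Why did the chicken cross the road?", return_token_type_ids=False)
print(call)
```

One caveat of this design: call-time keywords also override the stored ones, so passing a different max_length at inference time would silently change the input shape.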

Get the token ids (predictions) and decode them:

preds = out['result0'].argmax(axis=-1)

We tried to debug this but could not figure out why the output looks like this. Compare it with what the output is supposed to look like in the HuggingFace Comparison section.


' '.join(tokenizer.batch_decode(preds, skip_special_tokens=False))
'ммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммм # do you chicken cross the road?\n'
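One possible reading of this output, offered as a hypothesis rather than a diagnosis: a single forward pass of a causal LM returns logits of shape [batch, seq_len, vocab], and the argmax at position i is the model's guess for token i + 1. Decoded, those predictions line up with the input shifted one place to the left, which is consistent with the tail of the string reading like the prompt ("do you chicken cross the road?"), while positions that are only padding yield arbitrary tokens. The pure-Python sketch below uses a made-up five-word vocabulary and made-up logits to show the alignment:

```python
from typing import List


def argmax_last_axis(logits: List[List[List[float]]]) -> List[List[int]]:
    """Pure-Python equivalent of logits.argmax(axis=-1) for a [batch][seq][vocab] list."""
    return [[max(range(len(row)), key=row.__getitem__) for row in seq] for seq in logits]


vocab = ["<pad>", "Why", "did", "the", "chicken"]
# Hypothetical logits for one sequence of 3 input tokens: "<pad> Why did".
logits = [[
    [0.1, 0.2, 0.3, 0.1, 0.0],  # after <pad>: arbitrary guess ("did")
    [0.0, 0.1, 0.9, 0.2, 0.1],  # after "Why": predicts "did"
    [0.0, 0.0, 0.1, 0.8, 0.3],  # after "did": predicts "the"
]]
preds = argmax_last_axis(logits)  # [[2, 2, 3]]
decoded = [vocab[i] for i in preds[0]]

# Pair each input token with the prediction for the *next* position.
print(list(zip(["<pad>", "Why", "did"], decoded)))
# [('<pad>', 'did'), ('Why', 'did'), ('did', 'the')]
```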

We were intrigued by the wall of м characters, and Mark joked that it was an Illuminati secret code injected into the model (i.e., M for Modular), which I thought was funny :)

Our theory is that TorchScript is not the right way to serialize this model and that this is some kind of silent failure, but it is hard to know.

HuggingFace Comparison

Because the outputs of the Max model seemed wonky, we did a sanity check to see what the outputs look like when using HuggingFace transformers.


The code below loads the model onto the GPU to quickly generate predictions for comparison (Max doesn’t work with GPUs yet).

from transformers import AutoTokenizer, AutoModelForCausalLM
hfmodel = AutoModelForCausalLM.from_pretrained(hf_path,torchscript=True).cuda()
hftokenizer = AutoTokenizer.from_pretrained(hf_path)
hftokenizer.pad_token = hftokenizer.eos_token

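_p = "Why did the chicken cross the road?"
# The arguments on the next lines were cut off in the original post;
# this completion (greedy decoding, decoding only the new tokens) is an assumption.
input_ids = hftokenizer(_p, return_tensors="pt").input_ids.cuda()
out_ids = hfmodel.generate(input_ids=input_ids, max_new_tokens=15, do_sample=False)
new_ids = out_ids[:, input_ids.shape[1]:]  # drop the prompt tokens
out = hftokenizer.batch_decode(new_ids.detach().cpu().numpy(), skip_special_tokens=True)
print(out[0])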

To get to the other side.

Why did the chicken


  1. The documentation states: “This example uses is a TensorFlow model (which must be converted to SavedModel format), and it’s just as easy to load a model from PyTorch (which must be converted to TorchScript format).” The “just as easy” part raised my expectations a bit too high, as the PyTorch path is not as seamless.