Distributed model inference using Spark
Over the past few days, I have been working on running batch inference over a large amount of data using Spark. The goal was to generate embeddings:
- text embeddings using a fine-tuned BERT model
- image embeddings using a fine-tuned ViT model
I will use this TIL to share some of the key learnings from this experience.
Pandas UDF vs Spark UDF
Using a Pandas UDF instead of a Spark UDF helps improve performance: it runs on batches of rows per partition instead of a single row at a time (everything's explained here).
from pyspark.sql.functions import pandas_udf
import pandas as pd

@pandas_udf(...)  # declare the return type here, e.g. "array<float>" for embeddings
def predict(batch: pd.Series) -> pd.Series:
    ...
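The UDF is then applied column-wise to the DataFrame. A minimal sketch - the "text" and "embedding" column names are placeholders, and tuning the Arrow batch size is optional:

# hand larger batches of rows to each Pandas UDF call (the default is 10000)
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000")

df = df.withColumn("embedding", predict(df["text"]))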
Broadcasting the model
Broadcasting the model to the workers should help, since a read-only copy is kept on each Spark worker. It's also possible to broadcast the state_dict only (example for PyTorch) and load the model every time your Pandas UDF is called - just make sure you are using a large batch, otherwise you will spend more time loading the model than actually running inference.
broadcasted_model = spark.sparkContext.broadcast(model)
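The Pandas UDF from above can then grab the read-only copy through .value. A minimal sketch - the "array<float>" return type and the encode call are placeholders for whatever your model actually exposes:

@pandas_udf("array<float>")
def predict(batch: pd.Series) -> pd.Series:
    model = broadcasted_model.value  # read-only copy living on the worker
    embeddings = model.encode(batch.tolist())  # hypothetical inference call
    return pd.Series([list(e) for e in embeddings])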
GPU vs CPU
Inference using a single GPU was much faster than using a CPU-based cluster (even with 8 cores per worker, which basically means 8 threads per worker running in parallel).
model.to(device)
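For completeness, the usual PyTorch device selection looks like this, falling back to the CPU when no GPU is available:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # disable dropout etc. for inference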
Predictions in batch
Running predictions in batches, e.g. using PyTorch's DataLoader, helps a lot. This means you run inference for multiple inputs at a time (you can tune the batch size according to GPU usage and the memory available). Don't confuse these batches with the Pandas UDF batches - they are two different things.
import torch
from torch.utils.data import DataLoader

dataloader = DataLoader(input_dataset, batch_size=512)
with torch.no_grad():  # no gradients are needed for inference
    for batch in dataloader:
        batch = batch.to(device)  # .to() returns a new tensor, so reassign it
        predictions = model(batch)
GPU usage
Always validate the GPU usage. Try to keep it above 90% to make sure the time is actually being spent on inference and not on the iterator or something else - this can easily be checked through Ganglia's UI if you are using Databricks.
HDFS - Small files problem
When using images from HDFS, consider loading them into a table on Databricks first - reading each image from the filesystem is really expensive when every image is a separate file. I opted for loading them into a Delta table, which keeps the metadata alongside the images in fewer, less partitioned files and considerably improved the loading performance.
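A sketch of that approach - the path and table name are placeholders - using Spark's binaryFile reader and writing the result out as a Delta table:

# read every image file into a DataFrame with path, modificationTime, length and content columns
images_df = (
    spark.read.format("binaryFile")
    .option("pathGlobFilter", "*.jpg")
    .load("/mnt/raw/images")  # placeholder path
)

# persist as a Delta table so inference reads a few large files instead of many small ones
images_df.write.format("delta").saveAsTable("images_for_inference")  # placeholder table name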
Always validate the number of partitions
This last one might sound a bit stupid, but always make sure you are actually using multiple Spark workers. While debugging, I ran several experiments with df.limit(x) - the thing is that this returns a DataFrame with a single partition, so only one Spark worker was actually running the inference.
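A quick way to catch this is to check the number of partitions and repartition before running the UDF - a minimal sketch:

sample_df = df.limit(1000)
print(sample_df.rdd.getNumPartitions())  # often 1 after limit()

# spread the rows across the workers again before running inference
sample_df = sample_df.repartition(64)  # e.g. a multiple of the total number of cores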