Optimizing style transfer to run on mobile with TFLite
Neural style transfer is an optimization technique used to take two images, a content image (such as a building) and a style image (such as artwork by an iconic painter), and blend them together so the output image looks like the content image “painted” in the style of the reference image. Today, we are excited to share a pre-trained style transfer TensorFlow Lite model that is optimized for mobile, along with Android and iOS sample apps that use the model to stylize any image.
In this article, we will walk you through the journey of optimizing the large TensorFlow model for mobile deployment, and how to use it efficiently in a mobile app with TensorFlow Lite. We hope that you can use our pre-trained style transfer model or leverage our insights for your use cases.
Figure: An example of style transfer
Style transfer was first published in the paper A Neural Algorithm of Artistic Style. The original technique, however, was computationally expensive and could take several seconds to stylize an image even on high-end GPUs. Subsequent work by several authors showed how to speed up style transfer.
After evaluating several model architectures, we decided to start with a pre-trained arbitrary style transfer model from Magenta for our sample app. The model can take any content and style image as input, then use a feedforward neural network to generate a stylized output image. This model allows much faster style transfer compared to the technique in Gatys’s paper, but it is still quite large (44 MB) and slow (2340 ms on Pixel 4 CPU). Therefore, we needed to optimize the model to make it suitable for use in mobile applications. This article shares our experience doing so, with resources you can take advantage of in your work.
Optimizing the model architecture
Figure: The structure of our style transfer model
Magenta’s arbitrary style transfer model consists of two subnetworks:
- Style prediction network: converts the style image to a style embedding vector.
- Style transform network: applies the style embedding vector on the content image to generate a stylized image.
Magenta’s style prediction network has an InceptionV3 backbone, which we replaced with a MobileNetV2 backbone, an architecture optimized for mobile. The style transform network consists of several convolution layers. We applied the width multiplier idea from MobileNet, scaling down the number of output channels of all convolution layers by a factor of 4.
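The effect of a width multiplier can be illustrated with a small calculation. The sketch below is not the actual transform network: the layer widths and kernel sizes are hypothetical, chosen only to show how dividing the output channels of each convolution by 4 changes the number of weights.

```python
def scale_channels(channels, divisor=4, minimum=1):
    """Scale a channel count down by `divisor`, never going below `minimum`."""
    return max(minimum, channels // divisor)


def conv_weight_count(in_channels, out_channels, kernel_size):
    """Number of weights in a standard 2D convolution, ignoring biases."""
    return kernel_size * kernel_size * in_channels * out_channels


def stack_weight_count(layers, input_channels=3):
    """Total weights for a chain of conv layers given as (out_channels, kernel_size)."""
    total = 0
    in_channels = input_channels
    for out_channels, kernel_size in layers:
        total += conv_weight_count(in_channels, out_channels, kernel_size)
        in_channels = out_channels  # the next layer consumes this layer's output
    return total


# Hypothetical layer widths, NOT the real transform network.
original = [(32, 9), (64, 3), (128, 3)]
slim = [(scale_channels(channels), kernel) for channels, kernel in original]

original_weights = stack_weight_count(original)
slim_weights = stack_weight_count(slim)
print(original_weights, slim_weights)
```

Because an inner layer's input channels shrink along with its output channels, its weight count drops by about 16x rather than 4x. Only the first layer, whose RGB input stays at 3 channels, shrinks by 4x.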
Then, we had to decide how to train our model. We experimented with multiple options, including training the mobile model from scratch and distilling from Magenta’s pre-trained model. We found that fixing the weights of MobileNetV2 while optimizing the other parameters from scratch gave the best result.
We were able to achieve a similar level of style and content loss, while significantly shrinking and speeding up the model.
* Benchmarked on Pixel 4 CPU using TensorFlow Lite with 2 threads, April 2020.
* See this paper for more details about the definition of the loss function used in this style transfer model.
Once we had settled on the model architecture, we shrank our mobile model further with quantization, using the TensorFlow Model Optimization Toolkit. This is an important technique that applies to most mobile deployments of TensorFlow models, as it can shrink the model size by up to 4x and speed up model inference with an insignificant quality trade-off.
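To make the size and quality trade-off concrete, here is a minimal sketch of the affine mapping behind 8-bit quantization, in which a real value is approximated as scale * (q - zero_point). It is a plain-Python illustration under simplified assumptions (a single scale for the whole tensor, made-up weight values), not the toolkit's implementation.

```python
def quantization_params(values, qmin=-128, qmax=127):
    """Pick a scale and zero point so that [min, max] maps onto [qmin, qmax]."""
    lo = min(min(values), 0.0)  # the range includes 0.0 so that zero stays exact
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for all-zero input
    zero_point = round(qmin - lo / scale)
    return scale, zero_point


def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    """Map floats to integers in [qmin, qmax]."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


def dequantize(quantized, scale, zero_point):
    """Recover approximate floats from the stored integers."""
    return [scale * (q - zero_point) for q in quantized]


weights = [-0.5, -0.12, 0.0, 0.33, 1.0]  # made-up float weights
scale, zero_point = quantization_params(weights)
q = quantize(weights, scale, zero_point)
restored = dequantize(q, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, max_error)
```

Each weight is stored in one byte instead of the four bytes of a float32, which is where the up-to-4x size reduction comes from. The price is a rounding error of at most half a quantization step per value.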
Among the quantization options that TensorFlow provides, we decided to use post-training integer quantization because it has the right balance of simplicity and model quality. We only needed to provide a small portion of our training dataset when converting the TensorFlow model to TensorFlow Lite.
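A conversion along these lines might look like the sketch below. It uses the public `tf.lite.TFLiteConverter` API, but the SavedModel directory, the `samples` iterable, and the helper names are placeholders, and the exact settings used for the released models are not stated in this post.

```python
def make_representative_dataset(samples, limit=100):
    """Return a generator function yielding up to `limit` calibration inputs.

    Each sample should be a sequence with one array per model input,
    preprocessed the same way as the training data.
    """
    def generator():
        for index, sample in enumerate(samples):
            if index >= limit:
                break
            yield list(sample)
    return generator


def convert_to_int8_tflite(saved_model_dir, samples):
    """Convert a SavedModel to an integer-quantized TensorFlow Lite model."""
    import tensorflow as tf  # imported lazily; only this function needs it

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Calibration data lets the converter estimate activation ranges.
    converter.representative_dataset = make_representative_dataset(samples)
    # Make conversion fail if an op cannot be quantized to int8.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    return converter.convert()  # serialized .tflite model as bytes
```

The returned bytes can be written to a `.tflite` file and bundled with the app. The `limit` parameter reflects the point above that only a small portion of the training data is needed for calibration.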
After quantization, our model is more than an order of magnitude smaller and faster than the original model, while maintaining the same level of style and content loss.
* Benchmarked on Pixel 4 CPU using TensorFlow Lite with 2 threads, April 2020.
Deployment to mobile
We implemented an Android app to demonstrate how to use the style transfer model. The app takes a style image and a content image, and outputs an image that mixes the style and content of the input images.
We use the phone’s camera to capture the content images with the Camera2 API and provide a set of famous paintings to be used as style images. As mentioned above, there are two steps to apply a style to a content image. First, we extract the style as an array of floats using the style prediction network. Then we apply this style to the content image using the style transform network.
To achieve the best performance on both CPU and GPU, we created two sets of TensorFlow Lite models, one optimized for each. We use the int8 quantized models for CPU inference and the float16 quantized models for GPU inference. The GPU generally achieves better performance than the CPU, but it currently supports only float models, which are larger than int8 quantized models. Here is how the int8 and float16 models perform.
* Benchmarked on Pixel 4 using TensorFlow Lite, April 2020.
Another possible performance gain is to cache the results of the style prediction network if you only plan to support a fixed set of style images in your mobile app. This will make your app smaller, as you do not need to include the style prediction network, which accounts for 91% of the total network size. This is the main reason why the process is split into two models instead of only one.
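The caching idea can be sketched as follows. This is a minimal illustration with stand-in functions in place of the two networks. `StyleCache`, `stylize`, and the fake networks are hypothetical names, not part of the sample app. Passing `precomputed` embeddings and no prediction function corresponds to shipping an app without the style prediction network.

```python
class StyleCache:
    """Keeps style embeddings so the prediction network runs at most once per style."""

    def __init__(self, predict_style=None, precomputed=None):
        self._predict_style = predict_style  # callable: style image -> embedding
        self._embeddings = dict(precomputed or {})

    def embedding_for(self, style_id, load_style_image=None):
        if style_id not in self._embeddings:
            if self._predict_style is None or load_style_image is None:
                raise KeyError(f"no cached embedding for style {style_id!r}")
            self._embeddings[style_id] = self._predict_style(load_style_image())
        return self._embeddings[style_id]


def stylize(content_image, style_id, cache, transform, load_style_image=None):
    """Two-step pipeline: fetch (or compute) the embedding, then transform."""
    embedding = cache.embedding_for(style_id, load_style_image)
    return transform(content_image, embedding)


# Stand-ins for the two networks, used only to exercise the cache.
predict_calls = []

def fake_predict(style_image):
    predict_calls.append(style_image)
    return [float(len(style_image))]

def fake_transform(content_image, embedding):
    return (content_image, tuple(embedding))

cache = StyleCache(fake_predict)
first = stylize("photo_1", "style_a", cache, fake_transform, lambda: "style_a_pixels")
second = stylize("photo_2", "style_a", cache, fake_transform, lambda: "style_a_pixels")
print(len(predict_calls))  # the prediction stand-in ran once for two photos
```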
The sample can be found on GitHub and the main class applying style is the StyleTransferModelExecutor.
It is important that we do not run style transfer on the UI thread, as it is computationally expensive. We instead use the ViewModel class from AndroidX and a Coroutine to run it on a dedicated background thread and easily update the view. In addition, when running a model using the GPU delegate, TF Lite interpreter initialization, GPU delegate initialization, and inference all have to run on the same thread.
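The sample app does this in Kotlin, but the thread-confinement idea is language-independent. As a rough Python analogy only, the sketch below uses a single-worker executor so that initialization and every inference call run on one background thread while the caller is not blocked. `FakeInterpreter` and `BackgroundStylizer` are hypothetical stand-ins, not TensorFlow Lite classes.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class FakeInterpreter:
    """Stand-in for an interpreter that must be used on the thread that created it."""

    def __init__(self):
        self.owner_thread = threading.get_ident()

    def run(self, content_image):
        if threading.get_ident() != self.owner_thread:
            raise RuntimeError("interpreter used from a different thread")
        return f"stylized({content_image})"


class BackgroundStylizer:
    def __init__(self):
        # A single worker means every submitted task runs on the same thread.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._interpreter = None

    def _run(self, content_image):
        if self._interpreter is None:  # lazy initialization on the worker thread
            self._interpreter = FakeInterpreter()
        return self._interpreter.run(content_image)

    def stylize_async(self, content_image):
        """Return a Future so the calling thread is not blocked."""
        return self._executor.submit(self._run, content_image)

    def close(self):
        self._executor.shutdown(wait=True)


stylizer = BackgroundStylizer()
results = [stylizer.stylize_async(name).result() for name in ("a.jpg", "b.jpg")]
worker_thread = stylizer._interpreter.owner_thread
stylizer.close()
print(results)
```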
Style transfer in production
The Google Arts & Culture app recently added Art Transfer that uses TensorFlow Lite to run style transfer on-device. The model used is very similar to the one above but prioritizes quality over speed and model size. Try it out if you are interested in seeing style transfer in production.
If you want to add style transfer to your own app, you can start by downloading the mobile sample. Both model versions, the float16 version (predict network, transform network) and the int8 quantized version (predict network, transform network), are available on TensorFlow Hub. We can’t wait to see what you create! Don’t forget to share your creations with us.
Running machine learning models on-device has the benefits of keeping users’ data private while enabling features with low latency.
In this post, we have shown that directly converting a TensorFlow model to TensorFlow Lite might be just the first step. To achieve good performance, developers should optimize their model with quantization, and find the right trade-off between model quality, model size, and inference time.
We used the resources below to create our model. They might also be applicable to your on-device machine learning use cases.
- Magenta model repository
Magenta is an open source project powered by TensorFlow. It uses machine learning to make music and art. There are many models that can be converted to TensorFlow Lite, including this style transfer model.
- TensorFlow Model Optimization Toolkit
Model Optimization Toolkit provides multiple methods to optimize your model, including quantization and pruning.
- TensorFlow Lite delegates
TensorFlow Lite can leverage many different types of hardware accelerators available on devices, including GPUs and DSPs, to speed up model inference.