Machine learning keeps surprising us with incredible applications. Computers can recognize people and objects, and can even drive a car on their own. Many of these systems are based on deep neural networks (DNNs), which are loosely inspired by the way the human brain works.
Imagine a baby's brain. It starts out knowing very little and has to learn from scratch how to tell different objects and people apart. This process can take several months, even years. DNNs are much the same.
At the beginning of the training process, a DNN cannot tell many things apart, but as training progresses it becomes more and more accurate. However, just like a baby, a DNN takes quite a long time to learn new things.
Another drawback of a DNN is that training it from scratch requires a huge amount of data. Because of the time and data involved, very few people train an entire DNN from scratch.
As a real-world scenario, suppose we want our computer to distinguish between 5 different kinds of flowers. The first thing we need is a large dataset of flowers. Each entry in the dataset must contain a flower image and a label saying what kind of flower it is.
To solve this problem, we have two different approaches:
- Train the DNN from scratch: We will need a huge dataset, with at least 10,000 images per class, and several days of training.
- Reuse a pretrained DNN: We can take a network that has already been trained on ImageNet and use what it has learned to build a new classifier.
This second approach is called transfer learning, and it is the topic of this post.
It is possible to take the output of any layer in a deep neural network and use it as the input to your own architecture. One common practice is to feed the output of the last convolutional stage into a new classifier. This output is a one-dimensional vector that contains the features the DNN has extracted from the image, using the weights it learned previously.
Removing the last fully connected layer therefore lets you treat the rest of the network as a feature extractor for the new dataset. In our case, we removed the last fully connected layer, which classifies the 1,000 ImageNet classes, and replaced it with our own classifier for the 5 types of flowers.
As you can see in the image above, the DNN produces a bottleneck containing the main features of the image you pass to it. That bottleneck is then fed to the flower classifier, which actually classifies bottlenecks, not images!
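The feature-extraction idea described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not the code from the repository mentioned later: a single fixed random layer (`W_FROZEN`, a hypothetical stand-in for the pretrained convolutional layers) produces the bottleneck, and only a new softmax classifier is trained on top of it. All function names and sizes are made up for the example.

```python
import numpy as np

# Stand-in for the pretrained network (assumption): fixed weights that are
# never updated. In the article this role is played by a pretrained DNN.
_rng = np.random.default_rng(0)
W_FROZEN = _rng.normal(size=(64, 16)) / np.sqrt(64)  # 64 inputs -> 16-dim bottleneck


def extract_bottleneck(images):
    """Run images through the frozen layer and return one bottleneck vector each."""
    return np.maximum(images @ W_FROZEN, 0.0)  # linear layer followed by ReLU


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def train_classifier(bottlenecks, labels, num_classes, lr=0.1, epochs=200):
    """Train only a new softmax layer on precomputed bottlenecks."""
    n, d = bottlenecks.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    one_hot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        probs = softmax(bottlenecks @ W + b)
        grad = (probs - one_hot) / n  # gradient of mean cross-entropy w.r.t. logits
        W -= lr * bottlenecks.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b


def predict(images, W, b):
    """Classify raw images: frozen feature extractor first, new classifier second."""
    return softmax(extract_bottleneck(images) @ W + b).argmax(axis=1)
```

Because the frozen weights never change, the bottleneck of each image can be computed once and reused in every training epoch, which is a large part of why this approach is fast.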
Besides replacing the last fully connected layer of the DNN, you can also fine-tune the weights of the other layers by continuing backpropagation. You could retrain every single layer of the DNN if you wanted to, but that is rarely efficient. DNNs have a particular characteristic: the earlier layers of the network contain more generic features (such as edge or color detectors), while the later layers become more specific to the original dataset.
Another good way to reuse a DNN is therefore to freeze the weights of the earliest layers, which generate the general features, and retrain only the last two or three convolutional layers with the new dataset. This works well when, as with our flower classifier, the pretrained network is already very accurate, so its weights are a good starting point.
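Freezing the early layers while retraining the later ones can be illustrated with a tiny fully connected network in NumPy. This is a sketch under assumptions, not a real convolutional model: `init_network`, `forward`, and `fine_tune_step` are hypothetical helpers, the loss is a plain squared error, and the `trainable` list is simply a per-layer flag that decides which weights receive gradient updates.

```python
import numpy as np


def init_network(rng, sizes):
    """Create one weight matrix per layer; stands in for pretrained weights."""
    return [rng.normal(scale=1.0 / np.sqrt(m), size=(m, n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def forward(weights, x):
    """Return the activations of every layer (ReLU on all but the last)."""
    activations = [x]
    for i, W in enumerate(weights):
        z = activations[-1] @ W
        is_last = i == len(weights) - 1
        activations.append(z if is_last else np.maximum(z, 0.0))
    return activations


def fine_tune_step(weights, x, y, trainable, lr=0.01):
    """One gradient step on squared error, updating only the trainable layers."""
    acts = forward(weights, x)
    loss = float(np.sum((acts[-1] - y) ** 2) / len(x))
    delta = 2.0 * (acts[-1] - y) / len(x)  # gradient of the loss w.r.t. the output
    for i in reversed(range(len(weights))):
        grad_W = acts[i].T @ delta
        if i > 0:
            # Propagate the error to the previous layer before changing weights[i].
            delta = (delta @ weights[i].T) * (acts[i] > 0)
        if trainable[i]:
            weights[i] -= lr * grad_W  # frozen layers are skipped here
    return loss


# Usage: freeze the first layer, fine-tune the last two.
rng = np.random.default_rng(0)
weights = init_network(rng, [8, 6, 4, 1])
x, y = rng.normal(size=(32, 8)), rng.normal(size=(32, 1))
for _ in range(50):
    fine_tune_step(weights, x, y, trainable=[False, True, True])
```

The gradient still flows through the frozen layer during the forward pass, but its weights are never modified, so the generic features it computes are preserved.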
It sounds quite easy, doesn't it? However, fine-tuning and feature extraction do not always work. You need to take into account your own dataset and how the original DNN was trained. The size of the new dataset and its similarity to the original dataset are the key factors for successful fine-tuning.
If your dataset is small, like the flowers dataset, you should probably retrain just a classifier, since fine-tuning the network on so little data could lead to overfitting. If the two datasets are similar, we expect the features from the last layers of the DNN to be relevant to the new data as well, so you can take the bottleneck from the last layer of the DNN and train a new classifier on it. This is the case we have with the flowers dataset. If the datasets are not similar, you should consider extracting the features from an earlier stage of the network rather than from the last layer, because the more general features are generated in the early layers.
If the dataset is large, fine-tuning is a good option. How much of the network to train depends on how similar the datasets are. If the new dataset is similar to the original one, you only need to train the last layers. If the datasets are not similar, it is still beneficial to initialize the new net with the weights from a pretrained model, and the dataset is large enough to fine-tune the entire network. In both cases, having a large dataset reduces the risk of overfitting the net.
| | Similar datasets | Not similar datasets |
|---|---|---|
| Small dataset | Train a new classifier with the bottleneck of the original net. | Train a new classifier with the output of an earlier layer, because it will have more generic features. |
| Big dataset | Fine-tune only the last layers of the original net. | Fine-tune the entire original net. |
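The table can also be written as a small lookup function. This is just a restatement of the four cells in code; `choose_strategy` is a hypothetical helper, and since the article gives no exact threshold for what counts as a "big" dataset, that judgment is left to the caller as a boolean.

```python
def choose_strategy(dataset_is_big: bool, datasets_are_similar: bool) -> str:
    """Return the transfer learning strategy suggested by the table."""
    if dataset_is_big:
        if datasets_are_similar:
            return "fine-tune only the last layers of the original net"
        return "fine-tune the entire original net"
    if datasets_are_similar:
        return "train a new classifier with the bottleneck of the original net"
    return "train a new classifier with the output of an earlier layer"


# The flowers dataset is small and similar to the original dataset.
print(choose_strategy(dataset_is_big=False, datasets_are_similar=True))
```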
Back to the flower classifier: we have a small dataset that is similar to the one the pretrained net was trained on, so we train a new classifier with the bottleneck of the original net.
It is time to get down to work! You can find a working transfer learning example in TensorFlow on our GitHub, where the flower classification problem proposed above is solved. We took an Inception V3 network pretrained on ImageNet and used it as a feature extractor for our new flower dataset. Then we trained a classifier whose inputs are Inception's bottlenecks. You can try to classify other objects, but don't forget to follow the tips explained above!