Transfer learning with PyTorch
The future of machine learning is bright, because incentives are aligned across the board: big players eagerly open-source tools and invest in faster hardware to escape their ad-based business models. Tinkerers discover niche applications unheard of before. Data becomes more fungible for private, public, scientific, leisure and adverse use. I could talk for hours about the perfect storm ahead (and maybe some contrarian thoughts), but let's stick to a practical matter: how to make use of the recent uptick in available pre-trained machine learning models.
Transfer learning
Related tasks require a small set of shared underlying abilities to discriminate unseen data. For example, whether you play the guitar or the piano, you will have an easier time picking out chords than someone not used to playing at all. In machine learning terms, that means you can copy someone else's training effort and adapt it, say, to rapidly classify pictures as hotdog or not hotdog, or to generate rather novel product reviews.
Sebastian Ruder nicely recapped what compounding gains to expect from model reuse.
In particular, transfer learning improves on all dimensions that determine the technical aspects of a machine learning project:
- Human efficiency. You need experts if you want to squeeze out the last ounce of signal and make models interpretable, tractable and robust. Thanks to academic research, architectures bubble up that are battle-tested across related tasks.
- Compute efficiency. A state-of-the-art paper usually trains for around two weeks on clusters of 2 to 8 GPUs, though there is really no upper limit. With transfer learning, you save on tuning internal parameters by doing so only partially or less often.
- Data efficiency. If someone else trained on large datasets (which they don't even need to disclose), less domain-specific data is needed in most cases. In fact, it's incredible how you can repurpose a model contained in a download of less than 5 MB.
Landmark Architectures
Transfer learning is widely used to better distinguish specific image categories. Canziani et al. (2016) compared the compute efficiency of prime architectures competing on the ImageNet dataset.
Specifically, you can further divide compute costs into training time, inference time and memory needs, depending on your device constraints. For a recent side gig, I needed to dig deeper, given very specific constraints.
First, new data comes in often. Second, those images are potentially proprietary. Thus, retraining must happen reliably on a local mid-tier GPU without hand-holding by external experts. From a user's perspective, retraining is reliable if it gives consistent results in a consistent amount of time, just as if you'd click on a print job and wait for it to finish. Therefore, the baseline benchmark uses a simple optimization method, which favours convergence over data efficiency. Third, each prediction needs to happen in near real-time for a batch of around four images, so we care about inference times. Finally, we care about the exact class of an image (or part of an image) as an input to business decisions, so we take top-1 accuracy into consideration.
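To make the inference constraint concrete, here is a minimal timing sketch; `model` stands in for any loaded network, and the 224×224 input size is an assumption that holds for most torchvision models:

```python
import time
import torch

# Time top-1 predictions for a batch of four images (illustrative only).
model.eval()  # assumed: a loaded torchvision model
batch = torch.randn(4, 3, 224, 224)  # four RGB images, 224x224 assumed
with torch.no_grad():
    start = time.perf_counter()
    top1 = model(batch).argmax(dim=1)  # top-1 class index per image
    elapsed = time.perf_counter() - start
print(f"batch of 4 took {elapsed * 1000:.1f} ms")
```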
Now, to bring expert costs down to virtually nil for an evolving dataset and to fit compute times to hard constraints, let's look into the model zoo of one of the most coveted tools of the trade: PyTorch.
Results
There are a few reasons why PyTorch gets its swag nowadays.
Fair enough: the framework is reasonably complete, succinct, entirely defined in code and easy to debug. Six archetypes can be loaded with a single line of code from the torchvision package: AlexNet, DenseNets, Inception, ResNets, SqueezeNet and VGG. To get an overview of how they came about, I recommend this read.
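For instance, loading pre-trained weights really is a one-liner per model (using the torchvision API as of this writing):

```python
from torchvision import models

# Each call downloads ImageNet weights on first use and caches them locally.
resnet34 = models.resnet34(pretrained=True)
squeezenet = models.squeezenet1_1(pretrained=True)
vgg13 = models.vgg13(pretrained=True)
```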
Unfortunately, the API does not allow loading pre-trained models with a custom number of classes without reimplementing the final layers manually. Therefore, I wrote a function that solves that in a principled way. It merges pre-trained parameters so that they won't interfere with custom output layers. Thus, we can systematically compare the performance metrics we care about across all available architectures.
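A minimal sketch of such a helper might look like the following; this is not the published function, just an illustration for three of the architectures, relying on the layer layouts torchvision exposes:

```python
import torch.nn as nn
from torchvision import models

def load_pretrained(arch: str, num_classes: int) -> nn.Module:
    """Load an ImageNet-pre-trained model and swap its final layer for a
    freshly initialized one with `num_classes` outputs (sketch only)."""
    if arch == "resnet34":
        model = models.resnet34(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif arch == "vgg13":
        model = models.vgg13(pretrained=True)
        in_features = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(in_features, num_classes)
    elif arch == "squeezenet1_1":
        model = models.squeezenet1_1(pretrained=True)
        # SqueezeNet classifies with a 1x1 convolution instead of a Linear layer.
        model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=1)
        model.num_classes = num_classes
    else:
        raise ValueError(f"unsupported architecture: {arch}")
    return model
```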
The models in the chart were retrained on their final layers only (shallow), over the entire set of parameters (deep), or from a freshly initialized state (from scratch). In all runs, the dual K80 GPUs ran at about 75% utilization.
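The difference between shallow and deep retraining boils down to which parameters receive gradients; a sketch for a ResNet-style model, where `model.fc` is the replaced final layer:

```python
import torch

# Shallow retraining: freeze the pre-trained backbone, train only the new head.
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():  # the freshly initialized final layer
    param.requires_grad = True

# Deep retraining would simply leave requires_grad = True everywhere.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad),
    lr=0.001,
    momentum=0.9,
)
```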
As the data shows, SqueezeNet 1.1 delivers on its promise of being a compute-efficient architecture. Because the divide into retrained and static layers is somewhat arbitrary, though, a conclusion based solely on shallowly retrained models would be too far-fetched. For example, the final classifier of VGG13 has 8,194 parameters, whereas ResNet34's final layer is narrower with 1,026 parameters. Therefore, only a hyperparameter search over learning strategies would make comparisons for the given objective truly valid. However, SqueezeNet learned incredibly fast both shallowly and deeply for the low number of classes in this dataset (binary, to be precise).
Conclusion
One thing to notice is how deep retraining yields less accuracy than shallow retraining for the same amount of training time. That has happened to me with other models on other tasks, too. My theory is that wringing deeper convolutions out of their previous local optima happens slowly on average (small error gradients) and on a less smooth manifold than if they were randomly initialized. Thus parameters might end up juxtaposed in a transitioning phase. Such an intermittent mess should be especially pronounced if unseen features differ in scale compared to the original data, which is the case for today's toy dataset: sometimes there are ants up close and sometimes you observe an entire swarm.
So, in fact, you'd probably be well advised to raise the learning rate for deep retraining, or at some point switch to more scale-invariant architectures.
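One way to act on that advice is per-group learning rates; `model.fc` again names the replaced head, and the rates themselves are illustrative, not tuned values:

```python
import torch

# Assumed: `model` is a ResNet-style network whose head is `model.fc`.
head_params = list(model.fc.parameters())
head_ids = {id(p) for p in head_params}
backbone_params = [p for p in model.parameters() if id(p) not in head_ids]

# A higher rate nudges the fresh head quickly; a lower one moves the
# pre-trained backbone gently out of its previous optima.
optimizer = torch.optim.SGD(
    [
        {"params": backbone_params, "lr": 1e-4},
        {"params": head_params, "lr": 1e-2},
    ],
    momentum=0.9,
)
```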
For the purpose of assessing compute efficiency composed of inference and retraining times, this is a pretty good start. If you want to increase retraining efficiency further, you can skip data augmentation and cache the outputs of the frozen layers. Of course, conclusions will differ with the number of classes, the other factors mentioned before and your particular constraints. That's why you want to arrive at your own. I published the raw metrics and accompanying code on GitHub for exactly that purpose.
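Caching could look roughly like this; `train_loader` is an assumed DataLoader of already preprocessed images, and the backbone split shown is specific to ResNet-style models:

```python
import torch
import torch.nn as nn
from torchvision import models

# Split ResNet34 into a frozen backbone and a trainable head, then run
# the expensive backbone over the data once and store its activations.
resnet = models.resnet34(pretrained=True)
backbone = nn.Sequential(*list(resnet.children())[:-1])  # all layers up to fc
backbone.eval()

features, labels = [], []
with torch.no_grad():
    for x, y in train_loader:  # assumed: yields (images, targets) batches
        features.append(backbone(x).flatten(1))
        labels.append(y)

# Later epochs train the head on these tensors, skipping the backbone entirely.
torch.save((torch.cat(features), torch.cat(labels)), "cached_features.pt")
```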
Happy training!