More pre-trained models, please

Every classification task is linearly separable in the right feature space. This statement is a clue to building ML solutions that scale to many different use cases.

Here’s a visual explanation. Say you have data points in a given two-dimensional vector space. If you can use a straight line to separate the two classes, the dataset (and prediction problem) is linearly separable. Below, A is, and B isn’t.

[Image: examples of a linearly separable dataset (A) and one that is not linearly separable (B)]

Note that B is still separable! The boundary just isn’t a straight line; it’s a more complicated curve. Linear separability is great because it lets you use straightforward, computationally cheap, easy-to-interpret, easy-to-implement models.
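To make the “right feature space” idea concrete, here is a minimal sketch on toy data I made up for illustration: one class sits inside a circle and the other in a ring around it, much like case B. The function names, the ring radii, and the choice of a perceptron as the linear model are all my assumptions, not anything prescribed above.

```python
import math
import random


def make_rings(n_per_class=100, seed=0):
    """Toy data: class +1 inside the unit circle, class -1 in a ring around it."""
    rng = random.Random(seed)
    points, labels = [], []
    for label, r_min, r_max in ((1, 0.0, 1.0), (-1, 2.0, 3.0)):
        for _ in range(n_per_class):
            r = rng.uniform(r_min, r_max)
            theta = rng.uniform(0.0, 2.0 * math.pi)
            points.append((r * math.cos(theta), r * math.sin(theta)))
            labels.append(label)
    return points, labels


def train_perceptron(features, labels, max_epochs=1000):
    """Linear classifier; stops early once every training point is correct."""
    w = [0.0] * (len(features[0]) + 1)  # last entry is the bias
    for _ in range(max_epochs):
        mistakes = 0
        for x, y in zip(features, labels):
            xb = list(x) + [1.0]
            score = sum(wi * xi for wi, xi in zip(w, xb))
            if y * score <= 0:  # wrong side (or on the boundary): nudge the line
                w = [wi + y * xi for wi, xi in zip(w, xb)]
                mistakes += 1
        if mistakes == 0:
            break
    return w


def accuracy(w, features, labels):
    """Share of points that fall on the correct side of the learned boundary."""
    correct = 0
    for x, y in zip(features, labels):
        score = sum(wi * xi for wi, xi in zip(w, list(x) + [1.0]))
        correct += (score > 0) == (y > 0)
    return correct / len(labels)


points, labels = make_rings()
raw = [[x, y] for x, y in points]             # original coordinates
lifted = [[x * x, y * y] for x, y in points]  # hand-picked feature space

print("raw accuracy:   ", accuracy(train_perceptron(raw, labels), raw, labels))
print("lifted accuracy:", accuracy(train_perceptron(lifted, labels), lifted, labels))
```

In the raw coordinates no straight line can split the two classes, so the perceptron never gets everything right. After squaring the coordinates, the straight line x² + y² = 2.5 separates them, and the same simple model classifies every training point correctly.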

Now, there’s a problem with getting linear separability in image classification: the pixel space of all possible images is very complicated, and classes are not at all linearly separable in it. The solution since about 2012 is convolutional neural networks. You can imagine them learning two components simultaneously. First, a feature extractor: a function that takes in an image and gives out a vector. Second, a simple linear classifier on top: a function that takes in said vector and gives out class probabilities.

You can repurpose the feature extractor from any trained neural network. If you train your own linear classifier on top of the feature extractor, you have done transfer learning. That is, you’ve used the information learned in one task to make a different task faster to learn. Humans do it all the time, but it’s not very common in most machine learning pipelines I’ve seen.

So imagine you run a real estate ads website. You want to let users find images of specific rooms faster, so you need a classifier that can separate images of kitchens from images of bathrooms. The fastest way to get started is to use a neural net trained for object classification (separating cats, bananas, greenhouses, sweatshirts, etc.) as a feature extractor and use a small dataset of kitchen and bathroom images to train a linear classifier.

The benefit of such an approach is that you need much less task-specific data. For example, you could plausibly get to ~90% accuracy (assuming balanced classes) with only 10 labeled images of kitchens and bathrooms by starting from an extractor pre-trained on the ~1M images in ImageNet. The bulk of the data and compute budget is spent on the feature extractor, and the upside is that you don’t have to train your own, because ImageNet-trained nets are freely available online.
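Here is a rough sketch of that recipe: a frozen feature extractor plus a small linear classifier trained on ten labeled examples. Everything in it is an assumption for illustration. `frozen_extractor` is a hypothetical stand-in (in practice it would be a forward pass through a pre-trained network with its last layer removed), the “images” are synthetic lists of pixel intensities, and the linear head is plain logistic regression trained by gradient descent.

```python
import math
import random


def frozen_extractor(image):
    """Hypothetical stand-in for a pre-trained network without its last layer.

    Maps an image (here a flat list of pixel intensities in [0, 1]) to a
    feature vector. Nothing in this function is trained below.
    """
    mean = sum(image) / len(image)
    spread = math.sqrt(sum((p - mean) ** 2 for p in image) / len(image))
    return [mean, spread]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_linear_head(features, labels, lr=0.5, epochs=2000):
    """Logistic regression by batch gradient descent; labels are 0 or 1."""
    w, b, n = [0.0] * len(features[0]), 0.0, len(labels)
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * len(w), 0.0
        for x, y in zip(features, labels):
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
            grad_w = [g + err * xi / n for g, xi in zip(grad_w, x)]
            grad_b += err / n
        w = [wi - lr * g for wi, g in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def predict(image, w, b):
    """Extract features with the frozen extractor, then apply the linear head."""
    features = frozen_extractor(image)
    return int(sum(wi * xi for wi, xi in zip(w, features)) + b > 0)


def toy_image(label, rng, n_pixels=64):
    """Synthetic data for this sketch only: class 0 is dark, class 1 is bright."""
    low, high = (0.0, 0.4) if label == 0 else (0.6, 1.0)
    return [rng.uniform(low, high) for _ in range(n_pixels)]


rng = random.Random(0)
train_labels = [0, 1] * 5  # only 10 labeled images
train_images = [toy_image(y, rng) for y in train_labels]
w, b = train_linear_head([frozen_extractor(im) for im in train_images], train_labels)

test_labels = [0, 1] * 50
test_images = [toy_image(y, rng) for y in test_labels]
hits = sum(predict(im, w, b) == y for im, y in zip(test_images, test_labels))
print("held-out accuracy:", hits / len(test_labels))
```

The point of the structure is that only `w` and `b` are learned from the task-specific data. Swapping in a different extractor means changing one function and retraining a handful of weights.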

There are many public datasets on which you could pre-train. So, it’s surprising that people mostly seem to use ImageNet-based feature extractors, given its limited scope. Every extractor is only meant to detect things relevant for making good decisions on the original dataset. For ImageNet, that is “distinguishing physical objects in photos”. ImageNet-trained extractors have never seen cartoons, illustrations, or satellite images. There are few images taken in low-light conditions, with flash glare, etc.

ImageNet extractors are also trained to ignore “style” — everything other than content. A yellow-tinted versus a blue-tinted image of a house should produce the same classification on ImageNet. The extractor is not incentivized to differentiate the two, which might be necessary to predict a photo’s aesthetic quality.

It would be powerful to have lots of different feature extractors, each for a different task. One that cares a lot about style but very little about content. One that looks only at text layout but not content. One that looks at text content but ignores layout. Several extractors specialised in satellite images with resolution 1m/pixel, 10m/pixel, 100m/pixel, etc. There’s lots of room for variation. You might even have classical computer vision-based extractors with well-defined interpretations, like dominant colors.
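As an example of a classical extractor with a well-defined interpretation, here is a minimal sketch of a dominant-color feature. The function names, the coarse color grid, and the representation of an image as a list of (r, g, b) tuples are my assumptions; a real pipeline would likely read pixels from an image library and might cluster colors instead of binning them.

```python
from collections import Counter


def dominant_colors(pixels, k=3, bins=4):
    """Return the k most common quantized colors as (rgb_center, share) pairs.

    `pixels` is an iterable of (r, g, b) tuples with 8-bit channels.
    `bins` is assumed to divide 256 evenly (for example 2, 4, 8).
    """
    width = 256 // bins
    counts = Counter((r // width, g // width, b // width) for r, g, b in pixels)
    total = sum(counts.values())
    return [
        (tuple(c * width + width // 2 for c in cell), n / total)
        for cell, n in counts.most_common(k)
    ]


def color_histogram(pixels, bins=4):
    """Fixed-length feature vector: share of pixels in each of bins**3 color cells."""
    width = 256 // bins
    counts = Counter(
        (r // width) * bins * bins + (g // width) * bins + (b // width)
        for r, g, b in pixels
    )
    total = sum(counts.values())
    return [counts[i] / total for i in range(bins ** 3)]


# A tiny made-up "image": six red pixels, three blue, one green.
pixels = [(255, 0, 0)] * 6 + [(0, 0, 255)] * 3 + [(0, 255, 0)]
print(dominant_colors(pixels, k=2))  # [((224, 32, 32), 0.6), ((32, 32, 224), 0.3)]
print(len(color_histogram(pixels)))  # 64
```

`dominant_colors` is the human-readable view, while `color_histogram` gives a vector of fixed length that a linear model can consume directly.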

Given a broad library of such feature extractors, it becomes much easier to build an ML model: all you need to do is a) pick a feature extractor and b) train a linear model. This is simpler than training a whole network from scratch not only because it takes less compute, but also because it takes much less skill. A domain expert will usually have a good intuition for what the relevant features for a decision are. Content or style? Is color essential? What do the images depict?

Every feature extractor defines a multidimensional space, which is hard to intuit. But since you can calculate distances between points there, you can do tricks to let humans better understand the underlying structure of the space and dataset. You can show the most similar images to a given image in this space. You can show the main clusters or outliers that are furthest away from the central mass of examples. I’ve done this sort of exploration and gotten a very good feel for the datasets I’ve spent lots of time with.
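Two of those tricks, similar-image lookup and outlier surfacing, can be sketched in a few lines once every image has a feature vector. The function names, the use of Euclidean distance, and the tiny made-up feature list are my assumptions; cosine distance or an approximate nearest-neighbor index would be reasonable alternatives on real data.

```python
import math


def most_similar(query_index, features, k=5):
    """Indices of the k items closest to the query item (Euclidean distance)."""
    query = features[query_index]
    others = [i for i in range(len(features)) if i != query_index]
    return sorted(others, key=lambda i: math.dist(query, features[i]))[:k]


def outliers(features, k=5):
    """Indices of the k items furthest from the centroid of all feature vectors."""
    dims = len(features[0])
    centroid = [sum(f[d] for f in features) / len(features) for d in range(dims)]
    order = sorted(
        range(len(features)),
        key=lambda i: math.dist(centroid, features[i]),
        reverse=True,
    )
    return order[:k]


# Made-up 2-D feature vectors; real extractors give hundreds of dimensions.
features = [[0.0, 0.0], [0.0, 1.2], [1.0, 0.0], [10.0, 10.0], [0.5, 0.5]]
print(most_similar(0, features, k=2))  # [4, 2]
print(outliers(features, k=1))         # [3]
```

Showing the images behind those indices is what turns an opaque vector space into something a person can look at and judge.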


The value of all this is to use the thing humans are great at: intuitively figuring out the underlying structure. At the same time, computers do the heavy lifting: operationalizing that intuition, going through lots of images, and surfacing relevant ones. Put together, it’s the fastest way to build computer vision models, and coincidentally also the most general one!

If two companies (say, TransferWise and Revolut) both use models for processing passports for very similar purposes but have a few crucial differences in what is acceptable, they could still share the feature extractor but train their own models. You can even imagine a whole industry association pooling data and compute to build a shared feature extractor. Why? To get a model more powerful than what each could create on their own.

The reason I think about this is that it’s a step towards broader use of AI. If building and maintaining a simple ML-based feature requires a 200k€/year team, only large and motivated companies will even consider it. If it’s a matter of labeling 20 images in simple cloud software, I believe it can catch on in most small/startup technology companies.