Modality-Agnostic Pre-Training

Cedric Warny
11 min read · Apr 18, 2024

Self-supervised pre-training has become a common technique in machine learning, especially in natural language processing (NLP). In fact, pre-training has almost become the entire focus of NLP. ChatGPT is basically self-supervised pre-training, followed by a phase of reinforcement learning from human feedback (RLHF). The bulk of ChatGPT’s capabilities, though, comes from the pre-training phase, as does the bulk of the compute spend. Self-supervised pre-training is also becoming more and more common in the other big machine learning domain, namely computer vision, though it still lags behind NLP there. Those different “domains” of machine learning often map to the so-called “modalities” they deal with, i.e. the kind of input: NLP deals with the text modality, while computer vision deals with images. In yet other modalities, such as pointclouds, self-supervised pre-training remains rare.

In this post, I want to present a taxonomy of the existing pre-training techniques and discuss which ones scale best across modalities. This matters because machine learning models are becoming more and more multimodal: from chatbots that can process both text and images, to autonomous vehicles perceiving their environment via various sensors. This post is inspired by a recent series of papers on joint-embedding predictive architectures from the FAIR lab at Meta (more on those later).

What’s self-supervision and why is it useful?

Self-supervised pre-training is a phase in machine learning where you train a model to solve a very general task that isn’t really useful in itself. Importantly, the “labels” for that task are not supplied by humans, which is why it’s called “self-supervision”. In NLP, an example of such a task is guessing the next word in a sentence: you don’t need a human annotator to provide the correct guess, because it’s already there in the text. I describe this task as “not really useful in itself” because who cares about a model that is really good at predicting the next word in a sentence? A neural net trained in that fashion isn’t, by itself, helpful for a useful task like translating from one language to another. In computer vision, an example of a pre-training task would be matching pictures of the same thing taken from different perspectives. Is it useful in itself? Questionable. But we’ll see why such tasks can be very useful in a larger context.
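
Here is a minimal sketch of that kind of self-supervised objective, just to make it concrete. The toy model, vocabulary size, and random data below are placeholders, not a real language model; the point is simply that the “label” for each position is the next token in the sequence itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy next-token prediction: the "labels" are the input shifted by one position,
# so no human annotation is needed. (Hypothetical toy model, not a real LM.)
vocab_size, dim = 1000, 64

model = nn.Sequential(
    nn.Embedding(vocab_size, dim),
    nn.Linear(dim, vocab_size),  # predicts a distribution over the next token
)

tokens = torch.randint(0, vocab_size, (8, 32))  # batch of token ids
logits = model(tokens[:, :-1])                  # predict from the prefix
loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),             # flatten (batch, time)
    tokens[:, 1:].reshape(-1),                  # targets = next tokens
)
loss.backward()
```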

Typically, those neural networks are considered “foundation models”, “base models”, or “backbones” (interchangeable terms) that are meant to be inserted into slightly more complex architectures or procedures geared toward solving specific problems. For ChatGPT, the pre-trained base model goes through a subsequent RLHF phase to make it actually useful. In computer vision, one typically attaches a “head” (another, smaller neural net, typically untrained) to the pre-trained backbone and trains the whole thing (or just the head) on a task of real interest (e.g., image classification). The training phase that comes after the pre-training phase is typically human-supervised, rather than self-supervised: as the name implies, RLHF is human-supervised, and image classification is typically trained on labels supplied by humans.
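
As a hypothetical sketch of what “attaching a head” looks like in practice (the backbone below is a stand-in for a real pre-trained encoder, and all sizes are made up): freeze the backbone, bolt on a small task-specific head, and optimize only the head’s parameters.

```python
import torch
import torch.nn as nn

# Placeholder backbone standing in for a real pre-trained encoder.
backbone = nn.Sequential(nn.Linear(768, 768), nn.GELU())
head = nn.Linear(768, 10)  # task-specific head, e.g. 10 classes

for p in backbone.parameters():
    p.requires_grad = False  # freeze the backbone

optimizer = torch.optim.AdamW(head.parameters(), lr=1e-3)  # only the head is optimized

features = torch.randn(16, 768)            # placeholder batch of inputs
labels = torch.randint(0, 10, (16,))       # human-supplied labels
with torch.no_grad():
    embeddings = backbone(features)        # frozen feature extraction
loss = nn.functional.cross_entropy(head(embeddings), labels)
loss.backward()
optimizer.step()
```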

People often confuse self-supervised pre-training with semi-supervised pre-training, and claim they do the former when in fact they are just doing the latter. The two are quite different. In semi-supervision, one trains the full neural net (not just the backbone) on the task of interest (not an unrelated “useless” task), but the labels aren’t provided by humans; they are provided by another model (typically a larger version of the same model being trained, often referred to as the “teacher model”). It therefore feels like “self-supervision”. After all, it’s a model training another model! But the teacher model itself was trained with human supervision, so it doesn’t count as self-supervision.
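
For contrast, here is a hypothetical sketch of semi-supervision via a teacher model (both models are placeholders): the student’s “labels” are the teacher’s predictions on unlabeled data, but the teacher itself was trained with human supervision.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder models: the teacher stands in for a larger, already-trained model.
teacher = nn.Linear(128, 10)
student = nn.Linear(128, 10)

inputs = torch.randn(32, 128)  # unlabeled data
with torch.no_grad():
    pseudo_labels = teacher(inputs).argmax(dim=-1)  # "labels" come from a model, not a human

loss = F.cross_entropy(student(inputs), pseudo_labels)
loss.backward()
```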

So if those pre-training tasks are useless, what’s the point? Why are we spending many millions of dollars to train those foundation models? It’s because pre-training backbones has been shown to have many benefits across ML domains, from natural language processing to computer vision, both in terms of performance and cost savings on a downstream task of interest. In terms of performance, pre-trained models generally outperform from-scratch models ceteris paribus. This means that we can choose to trade some of that performance boost to save on cost instead (e.g., train on less data).

Self-supervised learning is fundamentally about understanding the inner structure of the world, in a task-agnostic manner. This is akin to learning a useful inductive bias in the form of a good “world model”. Armed with that inductive bias, the model typically learns target tasks more easily, i.e. achieves better performance with the same amount of data, or the same performance with less data. In computer vision, learning a “good world model” means the backbone learns basic physics and develops expectations and priors: it learns what cars typically look like and how they tend to move. There’s even evidence that the backbone learns about things like object permanence. These aren’t things that semi-supervision can teach. In autonomous vehicles, a better world model doesn’t just yield better average performance; it can also directly help solve edge case problems, such as re-identifying a car that we lost track of because it was occluded by another car, thanks to the model having developed some basic notion of object permanence. In NLP, a good world model means the backbone has learned grammar and semantics.

An additional important property of pre-training is that you often don’t need to fine-tune the entire model on your target task. Fine-tuning just the task-specific head often works well. This has further potential for cost-saving as the model release cycle could just involve fine-tuning the head, while the fixed cost of pre-training the backbone is amortized over multiple releases.

Pre-training approaches

There are three main categories of pre-training methods:

  1. Invariance-based methods
  2. Generative methods
  3. Embedding prediction methods

Invariance-based methods learn to represent something in a way that is invariant to some transformations of that thing. In Figure 1 below, x is obtained from y by some transformation or augmentation. Both x and y are then encoded and their embeddings compared. In computer vision, this would translate to mapping an image and a transformation of that image to similar embeddings. Such transformations include decolorization, saturation, cropping, rotation, etc. Typically such transformations are modality-specific and hand-crafted (“rotating sounds” doesn’t make sense).

Figure 1: Invariance-based pre-training methods. Source: V-JEPA.
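
A hypothetical sketch of the invariance idea, with a placeholder encoder and a trivial flip augmentation. Real methods (SimCLR, BYOL, iBOT, etc.) add contrastive negatives or other machinery to keep the embeddings from collapsing; this sketch only shows the core objective of pulling the two views together.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 256))  # placeholder encoder

def augment(images):
    # Placeholder augmentation: horizontal flip (real pipelines use crops, color jitter, etc.)
    return torch.flip(images, dims=[-1])

images = torch.randn(8, 3, 32, 32)           # y in Figure 1
views = augment(images)                      # x, a transformed version of y
z_y = F.normalize(encoder(images), dim=-1)
z_x = F.normalize(encoder(views), dim=-1)
loss = (1 - (z_x * z_y).sum(dim=-1)).mean()  # cosine distance between the two embeddings
loss.backward()
```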

Generative methods learn to reconstruct a corrupted input. A typical corruption technique involves removing parts of the input and asking the model to guess them back, as illustrated in Figure 2 below. This is also known as masking and is a modality-agnostic pre-training method: masking images, videos, text, sounds, pointclouds, etc. makes equal sense. Generative methods are called “generative” because they create something in input space, something that can be consumed by a human brain: an image, an audio clip, a video, a text, etc.

Figure 2: Generative pre-training methods.
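
A hypothetical sketch of a generative (masking) objective in the spirit of MAE, with a placeholder autoencoder operating on flat patch vectors: hide most of the patches and reconstruct them in input space.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

patch_dim, num_patches = 48, 64
autoencoder = nn.Sequential(nn.Linear(patch_dim, 128), nn.GELU(), nn.Linear(128, patch_dim))

patches = torch.randn(8, num_patches, patch_dim)   # patchified inputs (any modality)
mask = torch.rand(8, num_patches) < 0.75           # mask 75% of the patches
corrupted = patches.masked_fill(mask.unsqueeze(-1), 0.0)

reconstruction = autoencoder(corrupted)
loss = F.mse_loss(reconstruction[mask], patches[mask])  # loss only on the masked patches
loss.backward()
```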

Embedding prediction methods learn to represent different parts of an input so that their representations are predictive of each other. Those methods are similar to generative methods in that they involve recovering masked parts of the input, but they differ in that this happens in latent space, not in input space, as illustrated in Figure 3 below. In that sense, they are not generative, as they do not create something that can be consumed by a human brain.

Figure 3: Embedding prediction pre-training methods.
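
A hypothetical sketch of the embedding prediction idea, loosely in the spirit of I-JEPA/V-JEPA but with placeholder linear modules (the real architectures use transformers, an EMA-updated target encoder, and a predictor conditioned on the positions of the masked patches): the loss lives entirely in latent space.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

patch_dim, dim, num_patches = 48, 128, 64
context_encoder = nn.Linear(patch_dim, dim)
target_encoder = copy.deepcopy(context_encoder)  # in practice updated as an EMA of the context encoder
predictor = nn.Linear(dim, dim)

patches = torch.randn(8, num_patches, patch_dim)
mask = torch.rand(8, num_patches) < 0.75          # patches whose embeddings must be predicted

with torch.no_grad():
    targets = target_encoder(patches)             # target embeddings (no gradient)

visible = patches.masked_fill(mask.unsqueeze(-1), 0.0)   # hide the masked patches
predicted = predictor(context_encoder(visible))          # predict their embeddings from the visible ones
loss = F.mse_loss(predicted[mask], targets[mask])        # compare in latent space only
loss.backward()
```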

All those methods are used to pre-train backbones, which are then included in task-specific models. Those methods can be compared in terms of semanticness and generality.

Semanticness captures whether a representation is “high level” or “low level”. A high-level representation encodes the gist of an input, without too much attention to detail. A low-level representation retains fine-grained information about the input. Generally speaking, a pre-trained backbone is said to have good semantics if, on a target task, it achieves good frozen evaluation (only the head is fine-tuned), or good end-to-end evaluation (the full neural net is fine-tuned) with few shots (jargon for “only a little bit of human supervision”).

Generality, in turn, captures whether the pre-trained model adapts equally well to a variety of target tasks, but also whether the pre-training method works for a variety of modalities.

In Table 1 below, I score the various pre-training methods according to their semanticness and generality.

Table 1: Comparison of pre-training methods.

The case for the embedding prediction pre-training method

A good pre-training methodology should:

  • Work well for different input modalities
  • Have both high semanticness and high generality
  • Be cheap

Joint-embedding predictive architectures (JEPA) meet all these criteria.

In this section, I’ll discuss the embedding prediction methodology (based mostly on the I-JEPA and V-JEPA papers) with an eye toward autonomous vehicles, which are a good example of an ML system with multiple input modalities (camera, lidar, radar, etc.) and which could therefore benefit from a modality-agnostic pre-training methodology. For an even more thorough discussion of those papers, see my explainer video.

JEPAs have been shown to require less compute than alternative pre-training methods. Figure 4 below compares models whose backbones were pre-trained using an invariance-based method (iBOT), a generative method (MAE), and an embedding prediction method (I-JEPA) in terms of accuracy on ImageNet for various backbone sizes. Models whose backbones were pre-trained using an embedding prediction scheme consistently show better performance at lower pre-training cost.

Figure 4: Comparison of various pre-training methods in terms of accuracy on ImageNet and pre-training compute cost. Accuracy is computed after end-to-end fine-tuning on 1% of ImageNet.

On the Something-Something-v2 video classification task, models whose backbones were pre-trained using an embedding prediction scheme again achieve better performance with a significantly shorter pre-training schedule, compared to the generative pre-training method known as MAE (see Figure 5).

Figure 5: Comparison between a generative pre-training method (MAE) and an embedding prediction pre-training method (V-JEPA) on a downstream video-based recognition task. V-JEPA achieves better downstream performance with significantly less pre-training time. Y axis corresponds to classification accuracy.

In addition to being cheaper, JEPAs save on developer time, because the same pre-training methodology can be applied to all our input modalities. While there is a plethora of generative pre-training methods for image or video data, there are almost none for lidar data, due to the sparsity of the input space (pointclouds). The advent of embedding prediction methods allows us to unify the pre-training procedure across all the input modalities in autonomous vehicles, which makes development more efficient.

JEPAs have also been shown to have good performance in frozen evaluation (evaluation after only fine-tuning the model head) and are good few-shot learners (i.e., can learn well with only a little help from friendly humans). Good frozen evaluation means a pre-trained backbone need not be fine-tuned at every release cycle, thereby speeding up development, as well as reducing compute cost by only training the model’s head for each release. A model pre-trained with the embedding prediction scheme, when evaluated on a video-based action recognition task using frozen evaluation, is a mere 5 percentage points off from its end-to-end evaluation performance (see Table 2 below).

Table 2: Top 1 accuracy on the Something-Something-v2 dataset for a vision transformer backbone with an attentive probe as head. The vision transformer was pre-trained with the embedding prediction methodology. In frozen evaluation, we only fine-tune the head. In end-to-end evaluation, we fine-tune the entire model (more expensive).

Being a good few-shot learner yields further cost savings in terms of data annotation. Models pre-trained with the embedding prediction methodology have been shown to maintain their performance much better than alternative pre-training methodologies on downstream tasks as the amount of supervision in the fine-tuning phase is reduced (see Table 3).

Table 3: Top-1 accuracy on Something-Something-v2 as the fraction of labels used in fine-tuning (i.e., amount of supervision) is reduced. The gap between the two methodologies increases, indicating that backbones pre-trained with the embedding prediction methodology are better few-shot learners than backbones pre-trained with a generative methodology.

This greater efficiency, both in terms of compute and labels, comes down to a better world model. Importantly, such a world model need not simply improve things “on average”. The autonomous vehicle (AV) industry is already doing well “on average” and is now in the business of solving edge cases. But a better world model can also help with edge cases. For example, many AV companies are still struggling with embarrassing perception errors involving spatio-temporal inconsistency (unrealistic jitter, position changes inconsistent with velocity estimates, elastic shapes for rigid objects) and recovery from occlusion. Pre-training can help with those edge cases. Backbones pre-trained with JEPAs have been shown to exhibit spatio-temporal consistency between the masked and unmasked regions of a video. In particular, the V-JEPA paper argues that “some of the samples […] demonstrate an understanding of object-permanence, as the visual objects remain consistent after partial occlusion.”

Putting it all together, the general modality-agnostic pre-training scheme I would advocate for is illustrated in Figure 6 below. The same scheme is reusable across our input modalities, with the tokenization process being the only difference between modalities. The scheme is therefore highly general and its development cost can be amortized over at least two detection backbones.

Figure 6: High-level modality-agnostic pre-training scheme. The tokenization process would be the only part of the pipeline that would be modality-specific.

“Tokenization” is just the process of chunking a raw input. In NLP, tokenization means splitting a piece of text into words, or even into “sentence pieces” or “word pieces” that do not necessarily correspond to words. In computer vision, tokenization can mean splitting an image into its constituent pixels, or into little patches of pixels. Tokenization basically determines the basic unit into which a raw input is decomposed. Typically, each token is then accompanied by some metadata, such as where in the sentence a given sentence piece was, or where in the image a given patch was. As you decompose the raw input into discrete units, you necessarily need to attach some metadata to each unit in order to relate it back to the whole.
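
A hypothetical sketch of image tokenization (the image size and 16x16 patch size are arbitrary): split the image into patches and record each patch’s (row, column) position as the metadata that relates it back to the whole.

```python
import torch

image = torch.randn(3, 224, 224)  # channels, height, width
patch = 16

# Cut the image into non-overlapping 16x16 patches and flatten each one into a token.
patches = image.unfold(1, patch, patch).unfold(2, patch, patch)   # (3, 14, 14, 16, 16)
patches = patches.permute(1, 2, 0, 3, 4).reshape(14 * 14, -1)     # 196 tokens of 3*16*16 = 768 values

# Positional metadata: the (row, col) location of each patch in the original image.
rows, cols = torch.meshgrid(torch.arange(14), torch.arange(14), indexing="ij")
positions = torch.stack([rows.flatten(), cols.flatten()], dim=-1)  # (196, 2)

print(patches.shape, positions.shape)  # torch.Size([196, 768]) torch.Size([196, 2])
```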

In summary, JEPAs yield three types of efficiency:

  1. Developer efficiency: The same pre-training methodology can be replicated across input modalities.
  2. Compute efficiency: Pre-training model backbones using an embedding prediction scheme is cheaper than alternative pre-training schemes. In addition, releases need only fine-tune the model’s head.
  3. Label efficiency: Models whose backbones were pre-trained using an embedding prediction scheme need fewer labels to achieve the same performance as models whose backbones were pre-trained using alternative schemes.

This efficiency doesn’t just come from better average performance, but could also impact performance on edge cases where many AV companies are still struggling (e.g., better recovery from occlusion via developing an “object permanence” inductive bias during pre-training).
