Inductive bias, and why overfitting half a trillion words looks sexy

Why deep learning does so well at generating natural language, and why it doesn't.

The syntax of natural language is a giant bag of rules for how different words and phrases fit together. For example, "I am hungry" and "I hunger" are syntactically valid, but "I hungry" is not. If you've ever tried to conjugate context-dependent irregular verbs in a second language, you intuitively know how challenging syntax is.

One of the primary strengths of deep learning is its unparalleled ability to learn arbitrarily complex statistical relationships, given enough data. I wrote previously about how various "syntactic features" of a language (e.g., case, agreement, and word order) produce statistical regularities in natural language text. With enough data, deep learning can learn a model of these regularities that encodes what is essentially a long list of probabilistic syntactic rules. Learning these statistical trends is more feasible than exhaustively hard-coding an explicit list of syntax rules.
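As a rough illustration of what "learning probabilistic syntactic rules from statistical regularities" means, here is a count-based bigram model in Python. To be clear, this is not deep learning: it is a deliberately simple stand-in of our own choosing, and the four-sentence corpus is made up. The model estimates the probability of each word given the previous word and scores a sentence by multiplying those probabilities.

```python
from collections import Counter, defaultdict

def train_bigrams(sentences):
    """Count how often each word follows each previous word."""
    counts = defaultdict(Counter)
    for sentence in sentences:
        words = ["<s>"] + sentence.lower().split()  # <s> marks the sentence start
        for prev, nxt in zip(words, words[1:]):
            counts[prev][nxt] += 1
    return counts

def prob(counts, prev, nxt):
    """Maximum-likelihood estimate of P(nxt | prev); 0.0 for an unseen context."""
    total = sum(counts[prev].values())
    return counts[prev][nxt] / total if total else 0.0

def sentence_prob(counts, sentence):
    """Score a sentence as the product of its bigram probabilities."""
    words = ["<s>"] + sentence.lower().split()
    p = 1.0
    for prev, nxt in zip(words, words[1:]):
        p *= prob(counts, prev, nxt)
    return p

# A tiny, made-up training corpus.
corpus = ["I am hungry", "I hunger", "you are hungry", "I am tired"]
model = train_bigrams(corpus)
print(sentence_prob(model, "I am hungry"))  # prints 0.25
print(sentence_prob(model, "I hungry"))     # prints 0.0
```

Because "hungry" never follows "i" in the toy corpus, "I hungry" gets probability zero. Roughly speaking, a neural language model learns a smoother, longer-context version of the same kind of statistics from vastly more text.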

If all deep learning could do was string words together into syntactically valid orderings, that alone would be extremely valuable. Of course, deep learning goes beyond syntax. "The rug believes steel chocolate" is a syntactically valid yet meaningless sentence. Standard deep learning language models generate far more coherent sentences than this.

Does deep learning's ability to learn and generate coherent utterances mean it learns the meaning of natural language (in more precise terms, its semantics and pragmatics)? That is a point of contention among AI researchers. I argue it does not (see my Stranger Things post). In short, to say that deep learning learns meaning is to say that meaning is nothing but statistical patterns between words.

In my view, this is the linguistic equivalent of saying "correlation equals causation."

So what elements of language can deep learning capture? To answer this, we need to understand the inductive biases of popular deep neural network architectures applied to language modeling.

Inductive bias is the set of assumptions a learning algorithm uses to generalize beyond its training data. It is a necessary component of any machine learning algorithm we expect to use on practical problems. But inductive biases always involve trade-offs: sometimes they steer the algorithm away from statistical regularities you want to capture.
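Here is a minimal sketch of why inductive bias matters, using two hypothetical learners rather than anything from the post. Both see the same three training points and both fit them perfectly, yet they disagree about an unseen input because they assume different things: least-squares line fitting assumes the relationship is linear, while 1-nearest-neighbor assumes an unseen point resembles the closest point it has already seen.

```python
def fit_line(xs, ys):
    """Least-squares line: assumes the target is a linear function of x."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    slope = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    intercept = mean_y - slope * mean_x
    return lambda x: slope * x + intercept

def fit_nearest_neighbor(xs, ys):
    """1-nearest-neighbor: assumes an input behaves like the closest seen input."""
    pairs = list(zip(xs, ys))
    return lambda x: min(pairs, key=lambda p: abs(p[0] - x))[1]

# Made-up training data.
xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
line, nn = fit_line(xs, ys), fit_nearest_neighbor(xs, ys)

# Both learners reproduce the training data exactly...
print([line(x) for x in xs], [nn(x) for x in xs])  # [2.0, 4.0, 6.0] [2.0, 4.0, 6.0]
# ...but they generalize differently outside it.
print(line(10.0), nn(10.0))  # 20.0 6.0
```

Neither answer is wrong given the data alone; which one generalizes better depends on whether the learner's assumptions match the problem, which is exactly the trade-off described above.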

The inductive biases of recurrent neural networks

Recurrent neural networks (RNNs) are a broad class of artificial neural networks used to model sequences. The connections in an RNN have an ordering that captures statistical patterns in sequential phenomena. Researchers have applied them to stock prices changing over time, genetic information changing across nucleotides in DNA, and, of course, words changing across sentences. The literature frequently characterizes RNNs as having a "recency bias."

The recency bias means that when learning statistical patterns between elements of a sequence, the RNN (1) favors making connections between elements that are close to one another, and (2) doesn't care where in the sequence the elements are, just how close they are to one another. The recency bias emerges from the fact that a standard (unidirectional) RNN only detects patterns along the input's sequential order. For example, it can learn that "deep" tends to be followed by "learning." However, it can't learn that "learning" tends to be preceded by "deep"; it predicts what words come next, not what words came before. And because information about an earlier word reaches a later position only by passing through every intermediate update of the network's state, nearby words tend to have a stronger influence than distant ones.
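The following toy sketch illustrates the distance effect numerically. It assumes a one-dimensional RNN with hand-picked weights (w = 0.5, u = 1.0), not a trained model, and estimates by finite differences how much nudging the first input changes the final hidden state as more tokens follow it. Each update passes the old state through a factor of w and a tanh whose slope is at most 1, so the first token's influence shrinks roughly geometrically with distance.

```python
import math

def run_rnn(inputs, w=0.5, u=1.0):
    """Toy scalar RNN: h_t = tanh(w * h_{t-1} + u * x_t), starting from h_0 = 0."""
    h = 0.0
    for x in inputs:
        h = math.tanh(w * h + u * x)
    return h

def influence_of_first_token(num_tokens, eps=1e-6):
    """Finite-difference estimate of d(h_T)/d(x_1) for a sequence of num_tokens inputs."""
    base = [0.1] * num_tokens
    nudged = [0.1 + eps] + [0.1] * (num_tokens - 1)
    return (run_rnn(nudged) - run_rnn(base)) / eps

for n in [1, 2, 5, 10]:
    # The influence of token 1 on the last state falls off as n grows.
    print(n, influence_of_first_token(n))
```

The printed influence drops sharply as the sequence grows; during training, the same shrinking product of factors is what makes gradients vanish.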

The recency bias makes some sense in natural language. Imagine you were a robot reading this text. Wouldn't it make sense to assume ordered pairs like ("deep", "learning") and ("inductive", "bias") are connected, while pairs like ("recurrent", "research") are less so?

Still, the recency bias has its drawbacks in modeling natural language. It is common in language for the meaning of an utterance to depend on words separated by a long sequence of other words. RNNs can have difficulty learning connections between such words*. A simple example comes from word order. English has a subject-verb-object word ordering, while Japanese has a subject-object-verb word ordering. Thus in Japanese, the subject and the verb of a phrase can have many words between them. Studies have shown RNNs work better on subject-verb-object languages than on subject-object-verb languages.

Second, the recency bias means RNNs have trouble with paraprosdokians: figures of speech in which the latter part of a sentence changes the meaning in a way that is hard to predict from the first part. E.g., "Take my wife—please!"

Compositionality, hierarchy, and recurrence in neural nets

The principles of hierarchy and compositionality state that the meaning of a complex expression is composed of the meanings of its constituent expressions. Rules govern how those constituent expressions combine to form the overall meaning of the complex expression. We can derive these rules by reducing the meaningful components of an expression to symbols. For example, the sentence "Socrates was a man" can be reduced to something like "S was a M."
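To make the idea concrete, here is a toy sketch of compositional interpretation in Python. The lexicon and the single combination rule (apply whichever child is a function to the other child, checking the left child first, a simplified Montague-style scheme) are our own illustrative assumptions, not the post's. The point is that the rule for "S was a M" does not care which S or M fills the slots: the same tree shape yields a meaning for any subject in the lexicon.

```python
# Hypothetical toy lexicon: each word's meaning is a Python value or function.
MEN = {"socrates", "plato"}

LEXICON = {
    "Socrates": "socrates",                       # a name denotes an entity
    "Fido": "fido",
    "man": lambda x: x in MEN,                    # a noun denotes a predicate
    "a": lambda pred: pred,                       # "a" passes the predicate through
    "was": lambda pred: lambda subj: pred(subj),  # "was" applies a predicate to a subject
}

def meaning(tree):
    """Compute a tree's meaning purely from its children's meanings."""
    if isinstance(tree, str):
        return LEXICON[tree]
    left, right = (meaning(child) for child in tree)
    # Function application: whichever side is a function consumes the other.
    return left(right) if callable(left) else right(left)

# Binary parse trees with the shape "S (was (a M))".
print(meaning(("Socrates", ("was", ("a", "man")))))  # True
print(meaning(("Fido", ("was", ("a", "man")))))      # False
```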

We can use the principle of compositionality as an inductive bias in natural language machine learning. A learning algorithm with this inductive bias would prefer generalizations that make some connection between S and M. One way to induce this bias is to build an explicit graph representation into the model. For example, tree-structured neural networks encode a particular tree representation that captures the compositional semantics of natural language (Bowman 2015).
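A tree-structured network builds that bias into the architecture by composing vectors along a parse tree supplied from outside the model. The following is a minimal, untrained illustration in the spirit of such networks, not the model from Bowman (2015): the 2-dimensional embeddings and the weight matrix W are made-up values, and real models learn these parameters (and usually add bias terms).

```python
import math

# Hypothetical 2-dimensional word embeddings (made up for illustration).
EMBED = {
    "Socrates": [0.9, 0.1],
    "was": [0.0, 0.5],
    "a": [0.1, 0.0],
    "man": [0.4, 0.8],
}

# Hypothetical composition weights: a 2x4 matrix applied to [left; right].
W = [[0.5, -0.2, 0.3, 0.1],
     [0.1, 0.4, -0.3, 0.6]]

def compose(left, right):
    """One tree node: parent = tanh(W @ concat(left, right))."""
    concat = left + right  # list concatenation, length 4
    return [math.tanh(sum(w * v for w, v in zip(row, concat))) for row in W]

def encode(tree):
    """Encode a binary parse tree bottom-up, reusing compose() at every node."""
    if isinstance(tree, str):
        return EMBED[tree]
    left, right = tree
    return compose(encode(left), encode(right))

print(encode(("Socrates", ("was", ("a", "man")))))
```

Each internal node reuses the same compose() function, so the vector for "a man" is computed before it is combined with "was", mirroring the bracketing of the tree rather than the left-to-right order of the words.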

The "recurrent" in "recurrent neural network" refers to how the RNN is structured. When we train an RNN model on natural language text, the model processes words in order. Given a word, the model learns a bit about the probability of seeing that word given the words that came before it. But it doesn't "remember" those previous words directly. Rather, it applies a function to the new word and to its current internal state, which indirectly captures the information in the previous words. It applies this same function recursively to every word as it proceeds.
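The recursion can be sketched as a left fold: one step function, applied word by word, that maps (old state, new word) to a new state. This is a hypothetical one-dimensional toy with made-up word values and weights; a real RNN uses learned vector embeddings and weight matrices, but the control flow is the same.

```python
import math
from functools import reduce

# Hypothetical 1-dimensional "embeddings" for a few words.
EMBED = {"deep": 0.8, "learning": -0.3, "is": 0.1, "fun": 0.6}

def step(h, word, w=0.5, u=1.0):
    """One recurrent update: the new state depends only on the old state and the new word."""
    return math.tanh(w * h + u * EMBED[word])

def encode(words, h0=0.0):
    """Apply the same step() to every word in order (a left fold)."""
    return reduce(step, words, h0)

sentence = ["deep", "learning", "is", "fun"]
h = encode(sentence)
# The fold is just nested applications of one function.
same = step(step(step(step(0.0, "deep"), "learning"), "is"), "fun")
print(h == same)                    # True
print(h == encode(sentence[::-1]))  # False: word order matters
```

The state after the last word is a fixed-size summary of everything read so far, and reversing the words changes it because the function is applied in order.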

Repeatedly applying the same function in this way seems to create a kind of "recurrency bias" that captures compositionality and hierarchy (Tran 2018).

If that sounds hand-wavy, that is because it is. In deep learning, inductive bias is described in terms of properties of the architecture. There is a gap between the properties of a neural net architecture (e.g., recurrence) and the abstractions of the domain you wish to model (e.g., natural language and the things natural language talks about).

Transformer networks lack inductive bias (and that's not great)

Transformer networks are the cutting edge of neural language models. The most notable example at the time of writing is OpenAI's GPT-3. Much has been written about GPT-3, and you can find apps online that demonstrate its (or its predecessor GPT-2's) impressive feats of natural language generation.

The transformer model does away with recurrence, relying entirely on a self-attention mechanism. A token is a word, a piece of a word, a punctuation mark, or another symbol. Each token's representation is directly informed by the representations of all the other tokens (in GPT-style text generators, all of the preceding tokens). This hyperconnectivity makes transformers more expressive than RNNs when representing context shared across longer sequences of tokens.
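Here is a simplified sketch of scaled dot-product self-attention in plain Python. To keep it short, it assumes the queries, keys, and values are the input vectors themselves, and it omits the learned projection matrices, multiple heads, positional encodings, and the causal mask that real transformers (including GPT-style models) use. It shows two things: every token's output mixes in every token directly, regardless of distance, and without positional information, shuffling the input merely shuffles the output.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(xs):
    """Simplified scaled dot-product self-attention over a list of vectors.

    Simplifying assumption: queries, keys and values are the inputs themselves
    (no learned projections, no heads, no positional encodings, no causal mask).
    Returns the output vectors and the attention-weight matrix.
    """
    d = len(xs[0])
    outputs, weights = [], []
    for q in xs:
        # Score this token's query against every token's key, including its own.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in xs]
        w = softmax(scores)
        weights.append(w)
        # Output = attention-weighted average of all value vectors.
        outputs.append([sum(wi * v[j] for wi, v in zip(w, xs)) for j in range(d)])
    return outputs, weights

# Three made-up 2-dimensional token vectors.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, attn = self_attention(tokens)
print([[round(w, 3) for w in row] for row in attn])  # every weight is nonzero

# Shuffle the tokens: the outputs come back shuffled the same way.
perm = [2, 0, 1]
out_perm, _ = self_attention([tokens[i] for i in perm])
print(all(abs(a - b) < 1e-12
          for i, row in zip(perm, out_perm)
          for a, b in zip(out[i], row)))  # True
```

That second property is one way to see the missing recency bias: the attention computation itself has no notion of which tokens are adjacent. Real transformers add positional encodings so that word order is available at all.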

However, dropping recurrence also removes the RNN's inductive biases towards recency and compositionality. As a result, transformers struggle to generalize as well as RNNs given a fixed amount of data. This weaker generalization is evident in work showing that RNNs do much better than transformers on tasks that require capturing hierarchical structure (Dehghani 2018).

Put simply, transformers dump the inductive biases of RNNs and don't seem to replace them with anything else.

Is that a bad thing? After all, GPT-3 has accomplished impressive feats of synthetic language generation.

Inductive bias is necessary to generalize beyond the training data. In massive models like GPT-3, the training data is a corpus of internet documents with half a trillion tokens. The deep transformer network encodes the statistical relationships between all those words in 175 billion parameters.
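For a sense of scale, the two figures quoted above can be compared directly. This is back-of-the-envelope arithmetic using only the post's round numbers; it glosses over the fact that a model is not necessarily trained on every token of its corpus exactly once.

```python
corpus_tokens = 0.5e12  # "half a trillion tokens" (the post's round figure)
parameters = 175e9      # GPT-3's parameter count
ratio = corpus_tokens / parameters
print(f"about {ratio:.2f} tokens of text per parameter")  # about 2.86 tokens of text per parameter
```

In other words, the network has roughly one parameter for every three tokens of text.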

In other words, even if GPT-3 doesn't generalize that well, it still gets impressive results because interpolating within a training dataset that huge gets you pretty far.


*Difficulty in learning connections across long distances is a special case of a general technical challenge called the vanishing gradient problem.