Transformers are one of the crucial building blocks of modern-day NLP. There is still much confusion around them, even though thousands of blog posts and animations exist. I finally clarified my understanding after a short five-minute conversation!
So here is my brief, opinionated explanation of a transformer!
Transformer:
- Is a parallelizable and efficient way of performing neural language modelling
- Dispenses with RNNs; uses only attention (the attention formula is shown just below this list)
- Is non-autoregressive over the input, but autoregressive over the output
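For reference, the attention operation that replaces recurrence is scaled dot-product attention, $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\big(QK^\top / \sqrt{d_k}\big)\,V$, where the queries $Q$, keys $K$, and values $V$ are linear projections of the token representations and $d_k$ is the key dimension.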
Because of those last two points, the architecture deals with inputs in a very special way!
Input: an entire sequence at once. Unlike an RNN, tokens are NOT fed in one-by-one! Because all positions are processed in parallel, we do require a positional encoding to tell the model where each token sits (a small sketch of one follows the Output note below).
Output: a single token at a time. To produce it, the transformer looks at the entire input sequence, as well as the previous tokens that it has already decoded/generated. In other words, it is non-autoregressive over the input sequence (it can operate on the whole thing at once, which is why it requires the entire input up front) but autoregressive over the output.
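As a concrete but minimal sketch of what a positional encoding can look like, here is the sinusoidal scheme from the original paper; learned positional embeddings are an equally common choice, and the function name and shapes below are just illustrative:

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    """Return a (seq_len, d_model) tensor of sinusoidal position encodings (assumes even d_model)."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)   # (seq_len, 1)
    # Frequencies decay geometrically across the embedding dimensions.
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe

# The encoding is simply added to the token embeddings before the first layer:
# x = token_embeddings + sinusoidal_positional_encoding(seq_len, d_model)
```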
Some notes: during training, it can be trained very rapidly. Since most models are generally trained with teacher forcing anyway, we actually do not care about which tokens the transformer previously generated. Hence, at training time, we are essentially just maximizing the probability of producing a specific token (the next one, at time step t) conditioned on the input sequence and the ground-truth decoded output tokens 0 to t-1, which are provided. This is where the causal triangular masking comes into play: we mask out the future output tokens so that position t cannot see anything beyond step t-1, which lets us train quickly, in a batched fashion, without waiting for the transformer to decode itself. Note this is also why we have an output positional embedding: so the transformer can correctly make sense of the ordering of its own output tokens.
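Here is a minimal sketch of that teacher-forced training step, using PyTorch's built-in nn.Transformer as a stand-in for whatever model you actually use; the names (model, src, tgt, to_vocab) and the toy dimensions are illustrative, and positional encodings are omitted for brevity:

```python
import torch
import torch.nn as nn

# Toy dimensions, purely illustrative.
vocab_size, d_model, seq_len, batch = 1000, 64, 16, 8

embed = nn.Embedding(vocab_size, d_model)
model = nn.Transformer(d_model=d_model, nhead=4, batch_first=True)
to_vocab = nn.Linear(d_model, vocab_size)

src = torch.randint(vocab_size, (batch, seq_len))   # full input sequence
tgt = torch.randint(vocab_size, (batch, seq_len))   # ground-truth output sequence

# Teacher forcing: the decoder is fed the ground-truth tokens shifted right,
# not its own previous predictions, so every time step trains in parallel.
decoder_in, labels = tgt[:, :-1], tgt[:, 1:]

# Causal (upper-triangular) mask: position t cannot attend to positions > t.
causal_mask = model.generate_square_subsequent_mask(decoder_in.size(1))

hidden = model(embed(src), embed(decoder_in), tgt_mask=causal_mask)
logits = to_vocab(hidden)                            # (batch, seq_len-1, vocab)

# Maximize P(token_t | input sequence, ground-truth tokens 0..t-1) for all t at once.
loss = nn.functional.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))
loss.backward()
```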
However, during testing/inference/generation, it must be run in a slow, autoregressive fashion. This is because here we do not have access to the ground-truth tokens! In order to generate N tokens, we must run N forward passes of the decoder (we can cache the encoder's output on the input sequence and reuse it at every step). This is similar to how a regular seq2seq model generates, as it moves left to right (kind of; depending on the architecture, a seq2seq can also 'just' require N passes of the decoder).
We have all the regular generation details when running the forward pass of the decoder: it is a function $d(x_t, h_{t-1})$, which takes in the previously generated token and the previous hidden state(s), and produces the output probabilities for the next token. We then take an argmax (or some other discrete sampling step) to get the actual token generated at that time step, and feed it back into the decoder at the next time step (passing through the positional encoding as well).
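A minimal greedy-decoding sketch of that loop, reusing the illustrative model, embed, and to_vocab names from the training snippet above; bos_id/eos_id are assumed special tokens, positional encodings are again omitted, and no key/value caching is done here (only the encoder output is computed once and reused):

```python
import torch

@torch.no_grad()
def greedy_decode(model, embed, to_vocab, src_ids, max_len=50, bos_id=1, eos_id=2):
    """Generate up to max_len tokens, one decoder pass per token."""
    # Encode the full input sequence once; this is the part we can cache
    # and reuse across every decoding step.
    memory = model.encoder(embed(src_ids))

    generated = torch.full((src_ids.size(0), 1), bos_id, dtype=torch.long)
    for _ in range(max_len):
        # Re-run the decoder on everything generated so far.
        causal_mask = model.generate_square_subsequent_mask(generated.size(1))
        hidden = model.decoder(embed(generated), memory, tgt_mask=causal_mask)
        logits = to_vocab(hidden[:, -1])                    # probabilities over the next token
        next_token = logits.argmax(dim=-1, keepdim=True)    # the greedy "max" step
        generated = torch.cat([generated, next_token], dim=1)
        if (next_token == eos_id).all():
            break
    return generated
```

Calling it is just `generated = greedy_decode(model, embed, to_vocab, src)`: N generated tokens really do cost N decoder passes, exactly as described above.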
Commentary: humans do not necessarily read exactly left to right. There is evidence that we first do some type of convolution-like, high-level scan of the words, and then go into more detail/precision (character by character) when we need to. This could be another research idea too! (Hierarchical NLP)
Deep-fake text generation has the potential to change the world in the future. We could tackle this problem by occupying that gap!
Recall the steps to writing the proposal. Overall, yesterday I learnt a lot of stuff! I went to the RL talk (on hierarchical reinforcement learning, where we have multiple agents, each responsible for a single task, with a constrained state space but a shared action space), and then also went to the ARIA, where I had great discussions about the Pricing Optimization, Transformer, and RL projects.