Attention Heatmaps

Understanding Text with Attention Heatmaps

Multi-head attention enhances the attention mechanism in natural language processing by allowing the model to focus on different aspects of a sentence simultaneously, with each head contributing its own perspective to a more comprehensive representation. The Attention Heatmap provides an intuitive graphical visualization of the resulting attention scores, allowing users to see which words receive the most emphasis in each attention head.

In Masked Language Model (MLM) tasks within BERT, these mechanisms work together to predict missing words from the surrounding context, revealing how the model represents relationships and meaning in text. Visualizing them helps explain how BERT captures language nuances and helps researchers refine training strategies for improved language comprehension.

Figure 1: Input Text

Consider a simple toy example built around a small number of sentences. In Figure 1, we have a collection of six sentence pairs (or sequences) that will be used to train a Masked Language Model (MLM). In MLM, certain words in a text sequence are intentionally hidden (masked), and the model's goal is to predict these missing words based on the surrounding context. This training approach teaches the model the meaning of words and the relationships between them, leading to better comprehension of the language. We use sentence pairs because they help the model learn how one sentence flows into the next.

This example showcases a small but meaningful set of sentence pairs. Each sentence pair provides context that helps the model understand how different elements of a sentence relate to one another. For instance, the sentence pair 'The cat chased the mouse. It was a fast chase.' gives the model context about the action and its attribute. Similarly, 'I need to buy milk. Please go to the store.' provides an understanding of the necessity and the action taken to fulfill that need.


By training on such sentence pairs, the MLM can learn to predict missing words in similar contexts, enhancing its ability to understand and generate human-like text.

In Figure 2, you see the sentences with certain words replaced by [MASK] tokens. This masking step is crucial for training the MLM: the model's goal is to predict the masked words based on the context provided by the surrounding words.
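As a concrete illustration, here is a minimal Python sketch of this masking step. The `mask_tokens` helper and the 15% masking rate are assumptions for illustration (15% matches BERT's original pre-training setup), not details taken from this tool.

```python
import random

MASK_TOKEN = "[MASK]"
MASK_PROB = 0.15  # assumed masking rate; BERT's original pre-training used 15%

def mask_tokens(tokens, mask_prob=MASK_PROB, seed=None):
    """Replace a random subset of tokens with [MASK] and record the targets."""
    rng = random.Random(seed)
    masked, targets = [], {}
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            targets[i] = tok          # the label the model must predict
            masked.append(MASK_TOKEN)
        else:
            masked.append(tok)
    return masked, targets

tokens = "the cat chased the mouse . it was a fast chase".split()
masked, targets = mask_tokens(tokens, seed=0)
print(masked)   # e.g. ['the', 'cat', '[MASK]', 'the', 'mouse', ...]
print(targets)  # e.g. {2: 'chased'} -- positions the loss is computed on
```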

Figure 2: Masked input text

Attention Score Calculation

Figure 3: Attention Score Calculation

The attention score calculation and heatmap shown in this tool are derived from the last attention layer of a simple BERT model trained on the given input sequences. This visualization reflects the final attention distributions that the model uses to predict masked words, demonstrating the learned context and relationships from all preceding layers. For a more detailed end-to-end demonstration, you can refer to another tool on this portal called "BERT Explorer".

The "Attention Score Calculation" section visually represents how input data (here, "Sample 0") is processed to compute attention scores within the multi-head attention mechanism.



Understanding the Attention Heatmap

The attention heatmap demonstrates the inner workings of multi-head attention by displaying attention scores in a grid format. These scores show how much focus each word in a sentence receives from every other word, which helps the model understand and generate context-aware representations.

Figure 4: Attention Heatmap

In the context of understanding how multi-head attention works within a Masked Language Model (MLM), it is crucial to visualize how attention scores are distributed across the words in a sentence. This visualization is called an attention heatmap; an example is shown in Figure 4.

Different input text samples can be used to observe how attention scores vary for each word in different sentences. For example, the sample heatmap is calculated for the input sample "the cat chased the mouse. it was a fast chase". The attention patterns also evolve between epochs as training progresses. For example, at epoch 3750 of 5000 (shown in Figure 5), we see an intermediate stage of learning and the attention distributions the model has captured at that point in training.

Another aspect of attention scores is that they are calculated separately for each head. Each head focuses on different aspects of the input data, offering a unique perspective on word relationships. For example, the relationships captured by Head 3 are shown in Figures 4 and 5.

In the attention heatmap, the rows represent words or tokens in the input sequence, while the columns show the words they attend to. Each cell in the grid displays a normalized attention score between 0 and 1, with higher scores shown in brighter colors and lower scores in darker ones. For instance, in the sentence "the cat chased the mouse. it was a fast chase", Head 3 shows strong attention between "cat" and "chased" and between "fast" and "chased", indicating this head's focus on the relationships between the verb, its subject, and its modifiers.
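A heatmap like Figure 4 can be reproduced with a few lines of matplotlib. The sketch below uses random placeholder weights in place of the trained model's Head 3 scores; the token labels, colormap, and 0-to-1 scale mirror the description above.

```python
import matplotlib.pyplot as plt
import numpy as np

tokens = "the cat chased the mouse . it was a fast chase".split()
# Placeholder attention weights; in practice these come from one head of the
# trained model's last attention layer (each row sums to 1 after softmax).
rng = np.random.default_rng(0)
weights = rng.random((len(tokens), len(tokens)))
weights = weights / weights.sum(axis=1, keepdims=True)

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(weights, cmap="viridis", vmin=0.0, vmax=1.0)
ax.set_xticks(range(len(tokens)), tokens, rotation=90)  # attended-to words
ax.set_yticks(range(len(tokens)), tokens)               # attending words
ax.set_title("Head 3 attention (illustrative)")
fig.colorbar(im, ax=ax, label="attention score")
fig.tight_layout()
plt.show()
```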

Although not shown in the diagram, the outputs from each head are concatenated to form a single matrix, which is then used to compute the final output of the attention mechanism. This combined matrix retains diverse information from all heads, capturing various relationships within the input data.
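The sketch below illustrates this concatenation step in NumPy. The head count, dimensions, and output projection `W_O` are illustrative assumptions, not the tool's actual configuration.

```python
import numpy as np

num_heads, seq_len, d_head = 4, 12, 16
d_model = num_heads * d_head                # 64 in this toy configuration

rng = np.random.default_rng(0)
head_outputs = [rng.normal(size=(seq_len, d_head)) for _ in range(num_heads)]

# Concatenate along the feature axis, then apply the output projection W_O
# so that information from all heads is mixed into a single representation.
concat = np.concatenate(head_outputs, axis=-1)   # (seq_len, d_model)
W_O = rng.normal(size=(d_model, d_model))
attention_output = concat @ W_O                  # final multi-head output
print(concat.shape, attention_output.shape)      # (12, 64) (12, 64)
```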

This mechanism of computing attention scores and generating context-aware word embeddings is a fundamental component of the encoder. An encoder in typical NLP models consists of a multi-head attention mechanism followed by a feed-forward neural network. By stacking multiple encoders, the model can build complex representations of the input data through successive layers, capturing intricate patterns and relationships, leading to more accurate predictions of masked words.
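Here is a compact PyTorch sketch of such an encoder layer and a small stack of them, assuming illustrative hyperparameters (the tool's actual sizes are not specified). Note that the attention weights returned by the last layer are exactly what an attention heatmap visualizes.

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    """Minimal encoder block: multi-head self-attention + feed-forward
    network, each wrapped with a residual connection and layer norm."""
    def __init__(self, d_model=64, num_heads=4, d_ff=256):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, attn_weights = self.attn(x, x, x)   # self-attention
        x = self.norm1(x + attn_out)                  # residual + norm
        x = self.norm2(x + self.ff(x))                # feed-forward block
        return x, attn_weights                        # weights feed the heatmap

# Stacking encoders builds successively richer representations.
layers = nn.ModuleList([EncoderLayer() for _ in range(2)])
x = torch.randn(1, 12, 64)                            # (batch, seq_len, d_model)
for layer in layers:
    x, weights = layer(x)
print(x.shape, weights.shape)  # final representations; last layer's weights
```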

The output from the final encoder layer, which contains rich, context-aware representations of the input tokens, is then fed into a prediction layer. This layer uses the learned context and relationships from all preceding encoders to predict the masked words. The full prediction pipeline, including the prediction layer, is demonstrated in another tool on this portal called "BERT Explorer" for a more comprehensive understanding of the BERT model.
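A minimal sketch of such a prediction layer is shown below: a linear projection from the final encoder output to vocabulary-sized logits, with the argmax at a masked position giving the predicted word. The toy vocabulary, dimensions, and random encoder output are assumptions for illustration.

```python
import torch
import torch.nn as nn

vocab = ["[PAD]", "[MASK]", "the", "cat", "chased", "mouse", ".",
         "it", "was", "a", "fast", "chase"]    # toy vocabulary
d_model = 64

# The prediction head maps each context-aware token vector to a score for
# every vocabulary word; softmax turns the scores into probabilities.
prediction_head = nn.Linear(d_model, len(vocab))

encoder_output = torch.randn(1, 12, d_model)   # stand-in for the final encoder
logits = prediction_head(encoder_output)       # (1, seq_len, vocab_size)
probs = logits.softmax(dim=-1)

masked_position = 2                            # suppose token 2 was [MASK]
predicted_id = probs[0, masked_position].argmax().item()
print(vocab[predicted_id])                     # the model's guess for the mask
```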

Figure 5: Attention Heatmap at epoch 3750 of 5000

After understanding the attention mechanisms and their role in predicting masked words, it's time to see the results. The final output prediction (Figure 6) showcases the model's ability to accurately fill in the blanks.

Figure 6: Prediction of Masked Words

This output displays the sentences after the fully trained model has predicted the masked words. Initially, the masked words in the input are replaced with [MASK] tokens; the model then predicts the most likely word for each token, reconstructing the sentence. This prediction process highlights the model's ability to comprehend and generate contextually appropriate words, demonstrating the efficacy of multi-head attention and the overall BERT architecture. By analyzing these predictions at different stages of training, one can gain insight into the model's learning progress and its understanding of language.

Summary

The exploration of the Attention Heatmap tool reveals how multi-head attention mechanisms within models like BERT help in understanding the intricate relationships between words in a text. By visualizing attention scores, the heatmap demonstrates how different words focus on each other, providing insights into the model's interpretation of the text. The tool shows how attention scores are distributed across multiple heads and epochs, highlighting how the model's focus shifts and refines over time. This visualization aids in comprehending how BERT captures various linguistic patterns and dependencies, enhancing its ability to understand and generate human language accurately.