Mécanisme d’attention et Transformers#

Attention is all you need.

— Ashish Vaswani et al., Attention Is All You Need (2017)

Le chapitre 20 a présenté les réseaux récurrents (RNN, LSTM, GRU) et mis en évidence leurs deux limitations fondamentales : la difficulté à capturer des dépendances à longue portée et l’impossibilité de paralléliser le calcul sur la dimension temporelle. Le mécanisme d’attention, proposé initialement comme complément aux architectures seq2seq, apporte une solution élégante au premier problème en permettant au modèle de « regarder » directement toutes les positions de la séquence d’entrée. L’architecture Transformer, introduite par Vaswani et al. en 2017, pousse cette idée à son terme en abandonnant entièrement la récurrence au profit de l”auto-attention (self-attention), résolvant ainsi simultanément les deux limitations. Ce chapitre présente le mécanisme d’attention, l’auto-attention, l’attention multi-têtes, l’encodage positionnel et l’architecture Transformer complète, avec des implémentations détaillées en PyTorch.

Hide code cell source

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

sns.set_theme(style="whitegrid", palette="muted", font_scale=1.1)
torch.manual_seed(42)
np.random.seed(42)

Le problème du goulot d’étranglement informationnel#

Rappel : l’architecture seq2seq#

Dans l’architecture encodeur-décodeur classique (chapitre 20), l’encodeur RNN lit la séquence d’entrée \((x_1, \ldots, x_T)\) et produit un unique vecteur de contexte \(\mathbf{c} = h_T\), le dernier état caché. Le décodeur doit ensuite générer toute la séquence de sortie à partir de ce seul vecteur.

Remarque 233

Ce vecteur \(\mathbf{c}\) constitue un goulot d’étranglement (bottleneck) : toute l’information de la séquence d’entrée, quelle que soit sa longueur, doit être compressée dans un vecteur de dimension fixe. En pratique, les performances des modèles seq2seq se dégradent significativement lorsque la longueur des séquences d’entrée augmente, car l’information des premiers tokens est progressivement « écrasée » par celle des tokens suivants.

Hide code cell source

# Illustration du goulot d'étranglement seq2seq
fig, axes = plt.subplots(2, 1, figsize=(9, 7))

# Gauche : dégradation de l'information dans le vecteur de contexte
seq_lens = np.arange(5, 105, 5)
info_retained = np.exp(-0.015 * seq_lens) * 100
axes[0].plot(seq_lens, info_retained, 'o-', color='#E24A33', markersize=4, linewidth=2)
axes[0].set_xlabel("Longueur de la séquence $T$")
axes[0].set_ylabel("Information retenue (%)")
axes[0].set_title("Dégradation du vecteur de contexte (schématique)")
axes[0].axhline(y=50, color='gray', linestyle='--', alpha=0.5)
axes[0].set_ylim(0, 105)

# Droite : score BLEU vs longueur de phrase (simulé)
lengths = np.arange(5, 55, 5)
bleu_no_attn = 35 - 0.5 * lengths + np.random.randn(len(lengths)) * 1.5
bleu_attn = 32 - 0.1 * lengths + np.random.randn(len(lengths)) * 1.0
axes[1].plot(lengths, bleu_no_attn, 's-', label="Seq2Seq (sans attention)", color='#4C72B0')
axes[1].plot(lengths, bleu_attn, 'o-', label="Seq2Seq + attention", color='#55A868')
axes[1].set_xlabel("Longueur de la phrase source")
axes[1].set_ylabel("Score BLEU")
axes[1].set_title("Impact de l'attention sur les phrases longues")
axes[1].legend()

plt.tight_layout()
plt.show()
_images/b88b74043139569bc33b0cadf149aaa3d696927a3fb88afb096aa160e5d2ae01.png

Mécanisme d’attention#

L’idée centrale du mécanisme d’attention est simple : plutôt que de compresser toute la séquence en un unique vecteur, on permet au décodeur d”accéder directement à tous les états cachés de l’encodeur à chaque pas de temps de la génération. Un système de scores détermine l’importance relative de chaque état caché de l’encodeur pour le pas de décodage courant.

Attention de Bahdanau (additive)#

Le mécanisme d’attention a été introduit par Bahdanau, Cho et Bengio en 2015 dans le contexte de la traduction automatique neuronale.

Définition 267 (Attention additive (Bahdanau))

Soient \((\bar{h}_1, \ldots, \bar{h}_T)\) les états cachés de l’encodeur et \(s_{t-1}\) l’état caché courant du décodeur. L”attention de Bahdanau calcule un vecteur de contexte \(\mathbf{c}_t\) propre à chaque pas de décodage \(t\) :

  1. Score d’alignement (fonction de score additive) :

\[e_{t,j} = \mathbf{v}^\top \tanh(W_s \, s_{t-1} + W_h \, \bar{h}_j)\]
  1. Poids d’attention (normalisation par softmax) :

\[\alpha_{t,j} = \frac{\exp(e_{t,j})}{\sum_{k=1}^{T} \exp(e_{t,k})}\]
  1. Vecteur de contexte (combinaison convexe) :

\[\mathbf{c}_t = \sum_{j=1}^{T} \alpha_{t,j} \, \bar{h}_j\]

\(W_s \in \mathbb{R}^{d_a \times d_s}\), \(W_h \in \mathbb{R}^{d_a \times d_h}\) et \(\mathbf{v} \in \mathbb{R}^{d_a}\) sont des paramètres apprenables, et \(d_a\) est la dimension de l’espace d’alignement.

Attention de Luong (multiplicative)#

Luong et al. (2015) ont proposé des fonctions de score alternatives, plus simples et souvent plus efficaces.

Définition 268 (Fonctions de score d’attention)

Soit \(s_t\) l’état du décodeur et \(\bar{h}_j\) un état de l’encodeur. Les principales fonctions de score sont :

Nom

Formule

Complexité

Dot product

\(e_{t,j} = s_t^\top \bar{h}_j\)

\(O(d)\)

Général (bilinéaire)

\(e_{t,j} = s_t^\top W_a \, \bar{h}_j\)

\(O(d^2)\)

Additif (Bahdanau)

\(e_{t,j} = \mathbf{v}^\top \tanh(W_s s_t + W_h \bar{h}_j)\)

\(O(d)\)

Scaled dot product

\(e_{t,j} = \frac{s_t^\top \bar{h}_j}{\sqrt{d}}\)

\(O(d)\)

Le scaled dot product divise par \(\sqrt{d}\) pour éviter que les valeurs du produit scalaire ne deviennent trop grandes en haute dimension, ce qui écraserait les gradients du softmax.

Implémentation de l’attention additive#

Hide code cell source

class BahdanauAttention(nn.Module):
    """Mécanisme d'attention additive (Bahdanau et al., 2015)."""

    def __init__(self, hidden_dim, attention_dim):
        super().__init__()
        self.W_s = nn.Linear(hidden_dim, attention_dim, bias=False)
        self.W_h = nn.Linear(hidden_dim, attention_dim, bias=False)
        self.v = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, decoder_state, encoder_outputs):
        """
        Args:
            decoder_state: (batch, hidden_dim)
            encoder_outputs: (batch, seq_len, hidden_dim)
        Returns:
            context: (batch, hidden_dim), weights: (batch, seq_len)
        """
        # decoder_state: (batch, 1, attention_dim)
        score_s = self.W_s(decoder_state).unsqueeze(1)
        # encoder_outputs: (batch, seq_len, attention_dim)
        score_h = self.W_h(encoder_outputs)

        # Scores d'alignement : (batch, seq_len, 1) -> (batch, seq_len)
        energy = self.v(torch.tanh(score_s + score_h)).squeeze(-1)

        # Poids d'attention
        weights = F.softmax(energy, dim=-1)

        # Vecteur de contexte
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, weights

# Démonstration
batch_size, seq_len, hidden_dim, attn_dim = 2, 8, 64, 32
encoder_out = torch.randn(batch_size, seq_len, hidden_dim)
dec_state = torch.randn(batch_size, hidden_dim)

attn = BahdanauAttention(hidden_dim, attn_dim)
ctx, weights = attn(dec_state, encoder_out)
print(f"Vecteur de contexte : {ctx.shape}")
print(f"Poids d'attention   : {weights.shape}")
print(f"Somme des poids     : {weights.sum(dim=-1)}")
Vecteur de contexte : torch.Size([2, 64])
Poids d'attention   : torch.Size([2, 8])
Somme des poids     : tensor([1.0000, 1.0000], grad_fn=<SumBackward1>)

Hide code cell source

# Visualisation des poids d'attention
fig, ax = plt.subplots(figsize=(8, 3))
w = weights.detach().numpy()
sns.heatmap(w, annot=True, fmt=".2f", cmap="YlOrRd", ax=ax,
            xticklabels=[f"enc_{i}" for i in range(seq_len)],
            yticklabels=[f"batch_{i}" for i in range(batch_size)])
ax.set_xlabel("Position dans la séquence source")
ax.set_ylabel("Échantillon du batch")
ax.set_title("Poids d'attention de Bahdanau")
plt.tight_layout()
plt.show()
_images/a5d8d76b05e79302225005f4fa1fbece923f975d2e168e52bb081ff2d1c7bed3.png

Auto-attention (Self-Attention)#

L’attention telle que décrite précédemment relie un décodeur à un encodeur : c’est une attention croisée (cross-attention). L”auto-attention applique le même principe au sein d’une même séquence : chaque position peut « regarder » toutes les autres positions de la même séquence pour construire sa représentation.

Formalisme Query-Key-Value#

Définition 269 (Auto-attention avec Query, Key, Value)

Soit \(X \in \mathbb{R}^{T \times d}\) une séquence de \(T\) vecteurs de dimension \(d\). L’auto-attention projette chaque vecteur en trois rôles à l’aide de matrices apprenables :

\[Q = X W^Q, \quad K = X W^K, \quad V = X W^V\]

\(W^Q, W^K \in \mathbb{R}^{d \times d_k}\) et \(W^V \in \mathbb{R}^{d \times d_v}\).

  • \(Q\) (queries) : ce que chaque position « cherche »

  • \(K\) (keys) : ce que chaque position « annonce »

  • \(V\) (values) : l’information que chaque position « transmet »

La sortie est :

\[\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V\]

Remarque 234

L’analogie avec un système de recherche d’information est éclairante : chaque position émet une requête (\(Q\)), et le score de pertinence entre cette requête et la clé (\(K\)) de chaque autre position détermine combien de valeur (\(V\)) cette position apporte. Le facteur \(\frac{1}{\sqrt{d_k}}\) empêche les produits scalaires de devenir trop grands lorsque \(d_k\) est élevé, ce qui concentrerait le softmax sur un seul élément et annulerait les gradients.

Dérivation détaillée#

Considérons une séquence de \(T\) positions. La matrice \(QK^\top \in \mathbb{R}^{T \times T}\) contient les scores de similarité entre toutes les paires de positions :

\[(QK^\top)_{i,j} = \mathbf{q}_i^\top \mathbf{k}_j = \sum_{m=1}^{d_k} q_{i,m} \, k_{j,m}\]

Si les composantes de \(Q\) et \(K\) sont des variables aléatoires i.i.d. de moyenne nulle et de variance unitaire, alors \(\mathbb{E}[\mathbf{q}_i^\top \mathbf{k}_j] = 0\) et \(\text{Var}[\mathbf{q}_i^\top \mathbf{k}_j] = d_k\). Diviser par \(\sqrt{d_k}\) ramène la variance à 1, ce qui maintient le softmax dans une zone à gradients raisonnables.

Propriété 1 (Complexité de l’auto-attention)

Pour une séquence de longueur \(T\) et une dimension \(d_k\) :

  • Complexité temporelle : \(O(T^2 \cdot d_k)\) pour le calcul de \(QK^\top\)

  • Complexité mémoire : \(O(T^2)\) pour stocker la matrice d’attention

  • Longueur du chemin de gradient : \(O(1)\) entre deux positions quelconques (contre \(O(T)\) pour un RNN)

La complexité quadratique en \(T\) est le principal inconvénient de l’auto-attention, mais elle est largement compensée par la parallélisation totale du calcul et la capacité à capturer des dépendances à longue portée.

Implémentation de l’auto-attention#

Hide code cell source

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention (Vaswani et al., 2017)."""

    def __init__(self, d_k):
        super().__init__()
        self.scale = d_k ** 0.5

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: (batch, T_q, d_k)
            K: (batch, T_k, d_k)
            V: (batch, T_k, d_v)
            mask: (batch, T_q, T_k) optionnel, True = masqué
        Returns:
            output: (batch, T_q, d_v), weights: (batch, T_q, T_k)
        """
        # Scores : (batch, T_q, T_k)
        scores = torch.bmm(Q, K.transpose(1, 2)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        output = torch.bmm(weights, V)
        return output, weights

# Démonstration
T, d_model, d_k = 6, 32, 32
X = torch.randn(1, T, d_model)

W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)

Q, K, V = W_Q(X), W_K(X), W_V(X)
attn = ScaledDotProductAttention(d_k)
output, weights = attn(Q, K, V)

print(f"Entrée X      : {X.shape}")
print(f"Sortie         : {output.shape}")
print(f"Poids attention: {weights.shape}")
Entrée X      : torch.Size([1, 6, 32])
Sortie         : torch.Size([1, 6, 32])
Poids attention: torch.Size([1, 6, 6])

Hide code cell source

# Visualisation de la matrice d'auto-attention
fig, ax = plt.subplots(figsize=(6, 5))
w = weights[0].detach().numpy()
labels = [f"pos {i}" for i in range(T)]
sns.heatmap(w, annot=True, fmt=".2f", cmap="Blues", ax=ax,
            xticklabels=labels, yticklabels=labels, square=True)
ax.set_xlabel("Clé (Key)")
ax.set_ylabel("Requête (Query)")
ax.set_title("Matrice d'auto-attention ($T=6$)")
plt.tight_layout()
plt.show()
_images/903911c09a54fae54b0f9c327d5e1dca602517b369b8c921e2fabda51efc32a0.png

Exemple 28 (Intuition de l’auto-attention)

Considérons la phrase « Le chat dort sur le canapé car il est fatigué ». Lorsque l’auto-attention traite le mot « il », elle doit déterminer à quoi ce pronom fait référence. Les scores d’attention entre la query de « il » et les keys de tous les autres mots permettront idéalement d’attribuer un poids élevé à « chat », capturant ainsi la relation de coréférence — une dépendance à longue portée que les RNN peinent à modéliser.

Attention multi-têtes (Multi-Head Attention)#

Motivation#

Une seule tête d’attention ne peut capturer qu’un seul « type » de relation entre positions. Or, dans une phrase, les relations sont multiples : syntaxiques, sémantiques, coréférentielles, etc. L’attention multi-têtes résout ce problème en exécutant plusieurs mécanismes d’attention en parallèle, chacun dans un sous-espace différent.

Définition 270 (Attention multi-têtes)

L”attention multi-têtes avec \(h\) têtes est définie par :

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \, W^O\]

où chaque tête \(i\) est :

\[\text{head}_i = \text{Attention}(Q W_i^Q, \, K W_i^K, \, V W_i^V)\]

avec \(W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}\), \(W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}\), \(W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}\) et \(W^O \in \mathbb{R}^{h \cdot d_v \times d_{\text{model}}}\).

Typiquement, \(d_k = d_v = d_{\text{model}} / h\), de sorte que le coût total est comparable à celui d’une seule tête de dimension \(d_{\text{model}}\).

Implémentation#

Hide code cell source

class MultiHeadAttention(nn.Module):
    """Attention multi-têtes (Vaswani et al., 2017)."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model doit être divisible par n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        self.scale = self.d_k ** 0.5

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q, K, V: (batch, seq_len, d_model)
            mask: optionnel
        Returns:
            output: (batch, seq_len, d_model)
            weights: (batch, n_heads, seq_len, seq_len)
        """
        B, T_q, _ = Q.shape
        T_k = K.shape[1]

        # Projections linéaires et reshape en (B, n_heads, T, d_k)
        Q = self.W_Q(Q).view(B, T_q, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(B, T_k, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention par tête
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1), float('-inf'))
        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, V)

        # Concaténation des têtes et projection de sortie
        context = context.transpose(1, 2).contiguous().view(B, T_q, self.d_model)
        output = self.W_O(context)
        return output, weights

# Démonstration
d_model, n_heads, T = 64, 8, 10
X = torch.randn(2, T, d_model)
mha = MultiHeadAttention(d_model, n_heads)
out, attn_weights = mha(X, X, X)  # auto-attention : Q=K=V=X
print(f"Entrée        : {X.shape}")
print(f"Sortie         : {out.shape}")
print(f"Poids attention: {attn_weights.shape}  (batch, têtes, T_q, T_k)")
Entrée        : torch.Size([2, 10, 64])
Sortie         : torch.Size([2, 10, 64])
Poids attention: torch.Size([2, 8, 10, 10])  (batch, têtes, T_q, T_k)

Hide code cell source

# Visualisation des poids d'attention pour chaque tête
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for i, ax in enumerate(axes.flat):
    w = attn_weights[0, i].detach().numpy()
    sns.heatmap(w, cmap="viridis", ax=ax, square=True,
                cbar=False, xticklabels=False, yticklabels=False)
    ax.set_title(f"Tête {i+1}", fontsize=10)
fig.suptitle("Poids d'attention des 8 têtes (auto-attention)", fontsize=13, y=1.01)
plt.tight_layout()
plt.show()
_images/345670baa7448f0613fa0213b534f552fc55aeb4100c8f7a755e430333cf7587.png

Remarque 235

Chaque tête apprend à détecter un type de relation différent. Des analyses empiriques montrent que certaines têtes se spécialisent dans les relations syntaxiques (sujet-verbe), d’autres dans les dépendances positionnelles (positions adjacentes), et d’autres encore dans les relations sémantiques à longue portée. La projection finale \(W^O\) apprend à combiner ces informations complémentaires.

Encodage positionnel (Positional Encoding)#

Pourquoi encoder la position ?#

L’auto-attention est une opération invariante par permutation : si l’on permute l’ordre des tokens dans la séquence, les poids d’attention changent, mais la relation entre chaque paire reste la même. Or, l’ordre des mots est crucial en langue naturelle (« le chat mange la souris » \(\neq\) « la souris mange le chat »). Il faut donc injecter explicitement l’information de position.

Définition 271 (Encodage positionnel sinusoïdal)

L”encodage positionnel sinusoïdal de Vaswani et al. associe à chaque position \(\text{pos}\) et chaque dimension \(i\) du modèle un signal :

\[PE_{(\text{pos}, 2i)} = \sin\!\left(\frac{\text{pos}}{10000^{2i / d_{\text{model}}}}\right)\]
\[PE_{(\text{pos}, 2i+1)} = \cos\!\left(\frac{\text{pos}}{10000^{2i / d_{\text{model}}}}\right)\]

Ce choix garantit que :

  1. Chaque position reçoit un encodage unique.

  2. La distance relative entre deux positions peut être exprimée comme une transformation linéaire de leurs encodages : \(PE_{\text{pos}+k}\) est une fonction linéaire de \(PE_{\text{pos}}\) pour tout décalage fixe \(k\).

  3. Les longueurs de séquence non vues à l’entraînement peuvent être traitées grâce à l’extrapolation naturelle des fonctions sinusoïdales.

Hide code cell source

class PositionalEncoding(nn.Module):
    """Encodage positionnel sinusoïdal."""

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """x: (batch, seq_len, d_model)"""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

# Démonstration
pe_module = PositionalEncoding(d_model=64, max_len=100, dropout=0.0)
dummy = torch.zeros(1, 100, 64)
pe_values = pe_module(dummy)[0].numpy()
print(f"Forme de l'encodage positionnel : {pe_values.shape}")
Forme de l'encodage positionnel : (100, 64)

Hide code cell source

# Visualisation de l'encodage positionnel
fig, axes = plt.subplots(2, 1, figsize=(9, 9))

# Gauche : heatmap des encodages positionnels
im = axes[0].imshow(pe_values[:50, :48], aspect='auto', cmap='RdBu_r',
                     interpolation='nearest')
axes[0].set_xlabel("Dimension $i$")
axes[0].set_ylabel("Position")
axes[0].set_title("Encodage positionnel sinusoïdal")
plt.colorbar(im, ax=axes[0], fraction=0.046)

# Droite : courbes pour quelques dimensions
dims_to_plot = [0, 1, 4, 5, 10, 11]
positions = np.arange(100)
for d in dims_to_plot:
    label = f"dim {d} ({'sin' if d % 2 == 0 else 'cos'})"
    axes[1].plot(positions, pe_values[:, d], label=label, alpha=0.8, linewidth=1.5)
axes[1].set_xlabel("Position")
axes[1].set_ylabel("Valeur")
axes[1].set_title("Signaux positionnels pour différentes dimensions")
axes[1].legend(fontsize=8, ncol=2)

plt.tight_layout()
plt.show()
_images/8ca309393a06e7d656f57b2e15a3e143da099ef5015accdff16684e59a623773.png

Remarque 236

Les dimensions basses (fréquence haute) varient rapidement avec la position, encodant les relations locales. Les dimensions hautes (fréquence basse) varient lentement, encodant la position absolue à grande échelle. L’ensemble forme un « spectre » de fréquences analogue à une transformée de Fourier de la position. Depuis 2017, d’autres schémas d’encodage positionnel ont été proposés — encodages apprenables, RoPE (Rotary Position Embedding), ALiBi — mais l’encodage sinusoïdal reste la référence pédagogique.

Architecture Transformer#

Vue d’ensemble#

L’architecture Transformer est composée d’un encodeur et d’un décodeur, chacun constitué d’un empilement de blocs identiques. L’encodeur transforme la séquence d’entrée en une représentation contextuelle riche, et le décodeur génère la séquence de sortie token par token en s’appuyant sur la sortie de l’encodeur.

Définition 272 (Architecture Transformer)

Le Transformer (Vaswani et al., 2017) est composé de :

Encodeur (\(N\) blocs identiques) : chaque bloc contient :

  1. Une couche d”auto-attention multi-têtes

  2. Une couche feed-forward position-wise : \(\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2\)

  3. Des connexions résiduelles et une normalisation de couche (layer norm) après chaque sous-couche : \(\text{LayerNorm}(x + \text{SubLayer}(x))\)

Décodeur (\(N\) blocs identiques) : chaque bloc contient :

  1. Une couche d”auto-attention multi-têtes masquée (empêche de « regarder le futur »)

  2. Une couche d”attention croisée multi-têtes (queries du décodeur, keys/values de l’encodeur)

  3. Une couche feed-forward position-wise

  4. Des connexions résiduelles et layer norm après chaque sous-couche

La configuration standard utilise \(N = 6\) blocs, \(d_{\text{model}} = 512\), \(h = 8\) têtes, et \(d_{ff} = 2048\).

Bloc encodeur#

Hide code cell source

class FeedForward(nn.Module):
    """Réseau feed-forward position-wise."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class TransformerEncoderBlock(nn.Module):
    """Un bloc encodeur du Transformer."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Auto-attention + résiduelle + layer norm
        attn_out, attn_weights = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_out))

        # Feed-forward + résiduelle + layer norm
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_out))

        return x, attn_weights

# Démonstration
d_model, n_heads, d_ff = 64, 8, 256
encoder_block = TransformerEncoderBlock(d_model, n_heads, d_ff)
X = torch.randn(2, 10, d_model)
out, weights = encoder_block(X)
print(f"Entrée  : {X.shape}")
print(f"Sortie  : {out.shape}")
print(f"Poids   : {weights.shape}")
Entrée  : torch.Size([2, 10, 64])
Sortie  : torch.Size([2, 10, 64])
Poids   : torch.Size([2, 8, 10, 10])

Propriété 2 (Connexions résiduelles et normalisation)

Les connexions résiduelles (residual connections) facilitent l’entraînement de réseaux profonds en permettant au gradient de se propager directement à travers les couches. La normalisation de couche (Layer Normalization) stabilise les activations en normalisant sur la dimension des features :

\[\text{LayerNorm}(\mathbf{x}) = \frac{\mathbf{x} - \mu}{\sigma + \epsilon} \odot \gamma + \beta\]

\(\mu\) et \(\sigma\) sont la moyenne et l’écart-type calculés sur la dernière dimension, et \(\gamma\), \(\beta\) sont des paramètres apprenables. Contrairement au Batch Normalization (utilisé dans les CNN), la Layer Norm ne dépend pas de la taille du batch et fonctionne identiquement en entraînement et en inférence.

Bloc décodeur#

Hide code cell source

class TransformerDecoderBlock(nn.Module):
    """Un bloc décodeur du Transformer."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.masked_attention = MultiHeadAttention(d_model, n_heads)
        self.cross_attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, encoder_out, src_mask=None, tgt_mask=None):
        # Auto-attention masquée (le décodeur ne voit pas le futur)
        attn_out, _ = self.masked_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout1(attn_out))

        # Attention croisée (queries=décodeur, keys/values=encodeur)
        cross_out, cross_weights = self.cross_attention(x, encoder_out, encoder_out, src_mask)
        x = self.norm2(x + self.dropout2(cross_out))

        # Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm3(x + self.dropout3(ff_out))

        return x, cross_weights

# Masque causal pour le décodeur
def create_causal_mask(seq_len):
    """Crée un masque triangulaire supérieur (True = masqué)."""
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return mask.unsqueeze(0)  # (1, T, T)

# Démonstration
T_dec = 8
causal_mask = create_causal_mask(T_dec)
print(f"Masque causal ({T_dec}x{T_dec}) :")
print(causal_mask[0].int())
Masque causal (8x8) :
tensor([[0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)

Hide code cell source

# Visualisation du masque causal
fig, axes = plt.subplots(2, 1, figsize=(9, 9))

# Gauche : masque causal
sns.heatmap(causal_mask[0].int().numpy(), cmap="RdYlGn_r", ax=axes[0],
            annot=True, fmt="d", square=True, cbar=False,
            xticklabels=[f"$t_{{{i}}}$" for i in range(T_dec)],
            yticklabels=[f"$t_{{{i}}}$" for i in range(T_dec)])
axes[0].set_title("Masque causal (1 = masqué)")
axes[0].set_xlabel("Position clé")
axes[0].set_ylabel("Position requête")

# Droite : scores d'attention après masquage (simulé)
scores_raw = torch.randn(T_dec, T_dec)
scores_masked = scores_raw.masked_fill(causal_mask[0], float('-inf'))
attn_masked = F.softmax(scores_masked, dim=-1).numpy()
sns.heatmap(attn_masked, cmap="Blues", ax=axes[1], annot=True, fmt=".2f",
            square=True, xticklabels=[f"$t_{{{i}}}$" for i in range(T_dec)],
            yticklabels=[f"$t_{{{i}}}$" for i in range(T_dec)])
axes[1].set_title("Poids d'attention après masquage causal")
axes[1].set_xlabel("Position clé")
axes[1].set_ylabel("Position requête")

plt.tight_layout()
plt.show()
_images/a2f9c990724941dd19c3a22ee8abad93f8c3a3b5bfc3960a28a2287b22e86733.png

Transformer complet#

Hide code cell source

class Transformer(nn.Module):
    """Architecture Transformer complète pour des tâches seq2seq."""

    def __init__(self, src_vocab, tgt_vocab, d_model=64, n_heads=4,
                 n_layers=2, d_ff=256, max_len=128, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Embeddings et encodage positionnel
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)

        # Blocs encodeur
        self.encoder_blocks = nn.ModuleList([
            TransformerEncoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Blocs décodeur
        self.decoder_blocks = nn.ModuleList([
            TransformerDecoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Couche de sortie
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_mask=None):
        x = self.pos_enc(self.src_embed(src) * (self.d_model ** 0.5))
        for block in self.encoder_blocks:
            x, _ = block(x, src_mask)
        return x

    def decode(self, tgt, encoder_out, src_mask=None, tgt_mask=None):
        x = self.pos_enc(self.tgt_embed(tgt) * (self.d_model ** 0.5))
        for block in self.decoder_blocks:
            x, _ = block(x, encoder_out, src_mask, tgt_mask)
        return x

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        encoder_out = self.encode(src, src_mask)
        decoder_out = self.decode(tgt, encoder_out, src_mask, tgt_mask)
        logits = self.fc_out(decoder_out)
        return logits

# Instanciation
model = Transformer(src_vocab=100, tgt_vocab=100, d_model=64, n_heads=4,
                    n_layers=2, d_ff=256)
n_params = sum(p.numel() for p in model.parameters())
print(f"Nombre de paramètres : {n_params:,}")
Nombre de paramètres : 251,236

Hide code cell source

# Diagramme de l'architecture Transformer
fig, ax = plt.subplots(figsize=(12, 10))
ax.set_xlim(0, 12); ax.set_ylim(0, 12)
ax.set_aspect('equal'); ax.axis('off')
ax.set_title("Architecture Transformer (vue schématique)", fontsize=14, pad=15)

colors = {
    'embed': '#C5B0D5', 'attn': '#4C72B0', 'ff': '#55A868',
    'norm': '#F7B6D2', 'output': '#DD8452', 'mask': '#E24A33'
}

# --- Encodeur (colonne gauche) ---
enc_x = 2.5
ax.text(enc_x, 11.3, "ENCODEUR", fontsize=12, ha='center', fontweight='bold',
        color=colors['attn'])

# Embedding + PE
rect = plt.Rectangle((enc_x - 1.2, 10.2), 2.4, 0.6, facecolor=colors['embed'],
                       edgecolor='white', alpha=0.8, linewidth=1.5)
ax.add_patch(rect)
ax.text(enc_x, 10.5, "Embedding\n+ Pos. Encoding", fontsize=7, ha='center', va='center')

# Bloc encodeur (x2)
for j, y_base in enumerate([8.5, 6.5]):
    label_block = f"Bloc {j+1}"
    ax.text(enc_x + 1.6, y_base + 0.8, label_block, fontsize=7, ha='center',
            color='gray', style='italic')

    rect_attn = plt.Rectangle((enc_x - 1.2, y_base + 0.6), 2.4, 0.55,
                                facecolor=colors['attn'], edgecolor='white',
                                alpha=0.7, linewidth=1.5)
    ax.add_patch(rect_attn)
    ax.text(enc_x, y_base + 0.87, "Multi-Head Self-Attention", fontsize=6.5,
            ha='center', va='center', color='white')

    rect_norm = plt.Rectangle((enc_x - 1.2, y_base + 0.3), 2.4, 0.25,
                                facecolor=colors['norm'], edgecolor='white',
                                alpha=0.7, linewidth=1.5)
    ax.add_patch(rect_norm)
    ax.text(enc_x, y_base + 0.42, "Add & LayerNorm", fontsize=6, ha='center', va='center')

    rect_ff = plt.Rectangle((enc_x - 1.2, y_base - 0.2), 2.4, 0.45,
                              facecolor=colors['ff'], edgecolor='white',
                              alpha=0.7, linewidth=1.5)
    ax.add_patch(rect_ff)
    ax.text(enc_x, y_base + 0.02, "Feed-Forward", fontsize=6.5,
            ha='center', va='center', color='white')

    rect_norm2 = plt.Rectangle((enc_x - 1.2, y_base - 0.5), 2.4, 0.25,
                                 facecolor=colors['norm'], edgecolor='white',
                                 alpha=0.7, linewidth=1.5)
    ax.add_patch(rect_norm2)
    ax.text(enc_x, y_base - 0.38, "Add & LayerNorm", fontsize=6, ha='center', va='center')

# --- Décodeur (colonne droite) ---
dec_x = 8.5
ax.text(dec_x, 11.3, "DÉCODEUR", fontsize=12, ha='center', fontweight='bold',
        color=colors['mask'])

rect = plt.Rectangle((dec_x - 1.2, 10.2), 2.4, 0.6, facecolor=colors['embed'],
                       edgecolor='white', alpha=0.8, linewidth=1.5)
ax.add_patch(rect)
ax.text(dec_x, 10.5, "Embedding\n+ Pos. Encoding", fontsize=7, ha='center', va='center')

for j, y_base in enumerate([8.0, 5.0]):
    label_block = f"Bloc {j+1}"
    ax.text(dec_x + 1.6, y_base + 1.5, label_block, fontsize=7, ha='center',
            color='gray', style='italic')

    # Masked self-attention
    rect_mask = plt.Rectangle((dec_x - 1.2, y_base + 1.1), 2.4, 0.5,
                                facecolor=colors['mask'], edgecolor='white',
                                alpha=0.7, linewidth=1.5)
    ax.add_patch(rect_mask)
    ax.text(dec_x, y_base + 1.35, "Masked Self-Attention", fontsize=6.5,
            ha='center', va='center', color='white')

    # Cross-attention
    rect_cross = plt.Rectangle((dec_x - 1.2, y_base + 0.5), 2.4, 0.5,
                                 facecolor=colors['attn'], edgecolor='white',
                                 alpha=0.7, linewidth=1.5)
    ax.add_patch(rect_cross)
    ax.text(dec_x, y_base + 0.75, "Cross-Attention", fontsize=6.5,
            ha='center', va='center', color='white')

    # Flèche de l'encodeur vers le cross-attention
    ax.annotate('', xy=(dec_x - 1.2, y_base + 0.75),
                xytext=(enc_x + 1.2, y_base + 0.75 if j == 0 else 6.0),
                arrowprops=dict(arrowstyle='->', color=colors['attn'], lw=1.5, alpha=0.6))

    # Feed-forward
    rect_ff = plt.Rectangle((dec_x - 1.2, y_base - 0.1), 2.4, 0.5,
                              facecolor=colors['ff'], edgecolor='white',
                              alpha=0.7, linewidth=1.5)
    ax.add_patch(rect_ff)
    ax.text(dec_x, y_base + 0.15, "Feed-Forward", fontsize=6.5,
            ha='center', va='center', color='white')

# Sortie
rect_out = plt.Rectangle((dec_x - 1.2, 3.5), 2.4, 0.5,
                           facecolor=colors['output'], edgecolor='white',
                           alpha=0.8, linewidth=1.5)
ax.add_patch(rect_out)
ax.text(dec_x, 3.75, "Linéaire + Softmax", fontsize=7, ha='center', va='center',
        color='white')

# Étiquettes d'entrée
ax.text(enc_x, 9.7, "Entrée source", fontsize=9, ha='center', color='gray')
ax.annotate('', xy=(enc_x, 10.2), xytext=(enc_x, 9.9),
            arrowprops=dict(arrowstyle='->', color='gray'))
ax.text(dec_x, 9.7, "Sortie décalée", fontsize=9, ha='center', color='gray')
ax.annotate('', xy=(dec_x, 10.2), xytext=(dec_x, 9.9),
            arrowprops=dict(arrowstyle='->', color='gray'))

plt.tight_layout()
plt.show()
_images/83e0961160d281dc95dd2ff7b12a2f50f56b65c50a25a38e9e54f1656df4cb62.png

Entraînement d’un Transformer sur une tâche jouet#

Pour illustrer le fonctionnement du Transformer, entraînons-le sur une tâche simple de copie de séquence : le modèle doit apprendre à reproduire la séquence d’entrée en sortie. C’est un test classique pour vérifier qu’une architecture seq2seq fonctionne correctement.

Préparation des données#

Hide code cell source

def generate_copy_data(n_samples, seq_len, vocab_size, pad_idx=0, sos_idx=1):
    """Génère des données pour la tâche de copie.

    Entrée:  [3, 7, 2, 5, 9]
    Cible:   [SOS, 3, 7, 2, 5, 9]
    """
    # Tokens entre 2 et vocab_size-1 (0=PAD, 1=SOS)
    src = torch.randint(2, vocab_size, (n_samples, seq_len))
    tgt_input = torch.cat([torch.full((n_samples, 1), sos_idx), src], dim=1)
    tgt_output = torch.cat([src, torch.full((n_samples, 1), pad_idx)], dim=1)
    return src, tgt_input, tgt_output

vocab_size = 20
seq_len = 8
n_train = 2000
n_test = 200

src_train, tgt_in_train, tgt_out_train = generate_copy_data(n_train, seq_len, vocab_size)
src_test, tgt_in_test, tgt_out_test = generate_copy_data(n_test, seq_len, vocab_size)

print(f"Exemple d'entrée source : {src_train[0].tolist()}")
print(f"Exemple d'entrée cible  : {tgt_in_train[0].tolist()}")
print(f"Exemple de sortie cible : {tgt_out_train[0].tolist()}")
Exemple d'entrée source : [5, 19, 13, 2, 17, 14, 13, 6]
Exemple d'entrée cible  : [1, 5, 19, 13, 2, 17, 14, 13, 6]
Exemple de sortie cible : [5, 19, 13, 2, 17, 14, 13, 6, 0]

Planification du taux d’apprentissage (warmup)#

Définition 273 (Warmup du taux d’apprentissage)

Le Transformer utilise une planification du taux d’apprentissage avec échauffement (warmup) suivie d’une décroissance :

\[\text{lr}(\text{step}) = d_{\text{model}}^{-0.5} \cdot \min\!\left(\text{step}^{-0.5}, \; \text{step} \cdot \text{warmup\_steps}^{-1.5}\right)\]

Le taux augmente linéairement pendant les premiers warmup_steps pas, puis décroît proportionnellement à l’inverse de la racine carrée du numéro de pas. Cette stratégie évite les instabilités au début de l’entraînement, lorsque les paramètres sont encore aléatoires.

Hide code cell source

def transformer_lr_schedule(step, d_model, warmup_steps=400):
    """Planification du LR avec warmup (Vaswani et al., 2017)."""
    if step == 0:
        step = 1
    return d_model ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))

# Visualisation de la planification
steps = np.arange(1, 4001)
lrs = [transformer_lr_schedule(s, d_model=64, warmup_steps=400) for s in steps]

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(steps, lrs, color='#4C72B0', linewidth=2)
ax.axvline(x=400, color='gray', linestyle='--', alpha=0.5, label="Fin du warmup")
ax.set_xlabel("Pas d'entraînement")
ax.set_ylabel("Taux d'apprentissage")
ax.set_title("Planification du taux d'apprentissage (warmup + décroissance)")
ax.legend()
plt.tight_layout()
plt.show()
_images/fe1c01db8ac98aae1baed9cb16f6520d684e49926e2e96b70f55d3aebcd991c2.png

Boucle d’entraînement#

Hide code cell source

# Hyperparamètres
d_model = 64
n_heads = 4
n_layers = 2
d_ff = 128
n_epochs = 30
batch_size = 64
warmup_steps = 200

# Modèle
torch.manual_seed(42)
model = Transformer(src_vocab=vocab_size, tgt_vocab=vocab_size,
                    d_model=d_model, n_heads=n_heads, n_layers=n_layers,
                    d_ff=d_ff, max_len=seq_len + 2, dropout=0.1)

# Optimiseur Adam avec betas du papier original
optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)

# Planification du LR
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: transformer_lr_schedule(step + 1, d_model, warmup_steps)
)

criterion = nn.CrossEntropyLoss(ignore_index=0)

# Entraînement
train_losses = []
test_accuracies = []

for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0.0
    n_batches = 0

    for i in range(0, n_train, batch_size):
        src = src_train[i:i+batch_size]
        tgt_in = tgt_in_train[i:i+batch_size]
        tgt_out = tgt_out_train[i:i+batch_size]

        tgt_mask = create_causal_mask(tgt_in.size(1))

        logits = model(src, tgt_in, tgt_mask=tgt_mask)
        loss = criterion(logits.reshape(-1, vocab_size), tgt_out.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        epoch_loss += loss.item()
        n_batches += 1

    train_losses.append(epoch_loss / n_batches)

    # Évaluation
    model.eval()
    with torch.no_grad():
        tgt_mask_test = create_causal_mask(tgt_in_test.size(1))
        logits_test = model(src_test, tgt_in_test, tgt_mask=tgt_mask_test)
        preds = logits_test.argmax(dim=-1)
        # On ignore la dernière position (PAD dans la cible)
        correct = (preds[:, :-1] == tgt_out_test[:, :-1]).float().mean().item()
        test_accuracies.append(correct * 100)

    if (epoch + 1) % 5 == 0:
        print(f"Époque {epoch+1:3d} | Loss: {train_losses[-1]:.4f} | "
              f"Précision test: {test_accuracies[-1]:.1f}%")
Époque   5 | Loss: 3.1138 | Précision test: 4.8%
Époque  10 | Loss: 3.0425 | Précision test: 5.4%
Époque  15 | Loss: 2.9912 | Précision test: 6.2%
Époque  20 | Loss: 2.9563 | Précision test: 6.8%
Époque  25 | Loss: 2.9285 | Précision test: 7.4%
Époque  30 | Loss: 2.9080 | Précision test: 8.3%

Hide code cell source

# Courbes d'entraînement
fig, axes = plt.subplots(2, 1, figsize=(9, 9))

axes[0].plot(range(1, n_epochs + 1), train_losses, 'o-', color='#E24A33',
             markersize=3, linewidth=2)
axes[0].set_xlabel("Époque")
axes[0].set_ylabel("Loss (entropie croisée)")
axes[0].set_title("Perte d'entraînement")
axes[0].set_yscale('log')

axes[1].plot(range(1, n_epochs + 1), test_accuracies, 's-', color='#55A868',
             markersize=3, linewidth=2)
axes[1].set_xlabel("Époque")
axes[1].set_ylabel("Précision (%)")
axes[1].set_title("Précision sur l'ensemble de test")
axes[1].set_ylim(0, 105)
axes[1].axhline(y=100, color='gray', linestyle='--', alpha=0.3)

plt.suptitle("Entraînement du Transformer sur la tâche de copie", fontsize=13)
plt.tight_layout()
plt.show()
_images/8851e260917a8bf3400f01def6c15e889df4134ca2fa9c328bc15287546b1eb7.png

Hide code cell source

# Exemples de prédictions
model.eval()
with torch.no_grad():
    idx = np.random.choice(n_test, 5, replace=False)
    tgt_mask_ex = create_causal_mask(tgt_in_test.size(1))
    logits_ex = model(src_test[idx], tgt_in_test[idx], tgt_mask=tgt_mask_ex)
    preds_ex = logits_ex.argmax(dim=-1)

print("Exemples de copie (source → prédiction) :")
print("-" * 50)
for i, j in enumerate(idx):
    src_seq = src_test[j].tolist()
    pred_seq = preds_ex[i, :-1].tolist()
    match = "OK" if src_seq == pred_seq else "ERREUR"
    print(f"Source    : {src_seq}")
    print(f"Prédiction: {pred_seq}  [{match}]")
    print()
Exemples de copie (source → prédiction) :
--------------------------------------------------
Source    : [11, 4, 10, 6, 4, 8, 7, 10]
Prédiction: [11, 4, 15, 9, 7, 16, 4, 6]  [ERREUR]

Source    : [12, 15, 9, 18, 7, 13, 7, 2]
Prédiction: [11, 17, 9, 9, 9, 17, 7, 17]  [ERREUR]

Source    : [15, 16, 15, 13, 13, 12, 19, 13]
Prédiction: [15, 9, 13, 9, 7, 7, 9, 9]  [ERREUR]

Source    : [11, 4, 5, 11, 13, 15, 15, 17]
Prédiction: [11, 14, 15, 13, 4, 15, 5, 5]  [ERREUR]

Source    : [13, 4, 3, 7, 6, 11, 6, 3]
Prédiction: [11, 3, 16, 2, 3, 8, 4, 8]  [ERREUR]

Visualisation de l’attention apprise#

Hide code cell source

# Extraction des poids d'attention du premier bloc encodeur
model.eval()
with torch.no_grad():
    sample_src = src_test[:1]
    x = model.pos_enc(model.src_embed(sample_src) * (model.d_model ** 0.5))
    _, attn_w = model.encoder_blocks[0](x)

fig, axes = plt.subplots(2, 2, figsize=(10, 9))
for h in range(4):
    w = attn_w[0, h].numpy()
    ax = axes[h // 2, h % 2]
    sns.heatmap(w, cmap="Blues", ax=ax, square=True,
                cbar=False, xticklabels=sample_src[0].tolist(),
                yticklabels=sample_src[0].tolist())
    ax.set_title(f"Tête {h+1}", fontsize=10)
    ax.tick_params(labelsize=7)

fig.suptitle("Poids d'attention appris (bloc encodeur 1, tâche de copie)", fontsize=12)
plt.tight_layout()
plt.show()
_images/c394a37d07245dfe489fe25c558bb3fe8b665b69d65b0d7dbaf793a846a12d9e.png

Impact et postérité#

L’article Attention Is All You Need de Vaswani et al. (2017) a provoqué une révolution dans l’apprentissage profond. En éliminant la récurrence, le Transformer a permis un passage à l’échelle sans précédent et a engendré une famille de modèles qui dominent aujourd’hui le traitement du langage naturel, la vision par ordinateur et bien d’autres domaines.

Modèles fondateurs#

Exemple 29 (Descendants majeurs du Transformer)

Encodeur seul :

  • BERT (Bidirectional Encoder Representations from Transformers, Devlin et al., 2019) : pré-entraîné par prédiction de mots masqués (masked language modeling) et prédiction de phrase suivante. A défini le paradigme « pré-entraînement + fine-tuning » en NLP. Détails au chapitre 24.

Décodeur seul :

  • GPT (Generative Pre-trained Transformer, Radford et al., 2018) : pré-entraîné en modélisation de langage auto-régressive (\(P(x_t \mid x_1, \ldots, x_{t-1})\)). A donné naissance à GPT-2, GPT-3, GPT-4 et ChatGPT.

Encodeur-décodeur :

  • T5 (Text-to-Text Transfer Transformer, Raffel et al., 2020) : formule toutes les tâches NLP comme des problèmes texte-vers-texte.

  • BART (Lewis et al., 2020) : combinaison d’un encodeur bidirectionnel et d’un décodeur auto-régressif.

Vision :

  • ViT (Vision Transformer, Dosovitskiy et al., 2021) : applique le Transformer aux images en les découpant en patches. Détails au chapitre 25.

Hide code cell source

# Chronologie des modèles Transformer
fig, ax = plt.subplots(figsize=(15, 4))
ax.set_xlim(2016.5, 2025.5)
ax.set_ylim(-1.5, 3.5)
ax.axis('off')
ax.set_title("Chronologie des architectures fondées sur le Transformer", fontsize=13, pad=15)

events = [
    (2017, "Transformer\n(Vaswani et al.)", '#E24A33', 0),
    (2018, "GPT\n(Radford et al.)", '#DD8452', 1),
    (2018.6, "ELMo\n(Peters et al.)", '#8B6DAF', -1),
    (2019, "BERT\n(Devlin et al.)", '#4C72B0', 0),
    (2019.5, "GPT-2\n(Radford et al.)", '#DD8452', 1),
    (2020, "T5\n(Raffel et al.)", '#55A868', -1),
    (2020.5, "GPT-3\n(Brown et al.)", '#DD8452', 0),
    (2021, "ViT\n(Dosovitskiy et al.)", '#C44E52', 1),
    (2022, "ChatGPT\n(OpenAI)", '#DD8452', -1),
    (2023, "GPT-4\n(OpenAI)", '#DD8452', 0),
    (2024, "Mamba, Mixtral\nJamba...", '#777777', 1),
]

ax.axhline(y=0, color='gray', linewidth=2, alpha=0.3, xmin=0.02, xmax=0.98)

for year, label, color, offset in events:
    y_text = 1.0 + offset * 0.9
    ax.plot(year, 0, 'o', color=color, markersize=10, zorder=5)
    ax.annotate(label, xy=(year, 0), xytext=(year, y_text),
                fontsize=7, ha='center', va='bottom', color=color,
                fontweight='bold',
                arrowprops=dict(arrowstyle='-', color=color, alpha=0.4))

plt.tight_layout()
plt.show()
_images/10c6387b24ef672c37878961b86d1337265b3f3b59db49992f0f448c85557655.png

Remarque 237

Le Transformer n’est pas limité au texte. Son architecture a été adaptée avec succès aux images (ViT, Swin Transformer), à l’audio (Whisper), à la biologie (AlphaFold 2), à la chimie (MolBERT), à la robotique (RT-2) et même aux jeux (MuZero). La modularité du mécanisme d’attention — sa capacité à modéliser des relations arbitraires entre éléments d’un ensemble — en fait une brique de base véritablement universelle de l’apprentissage profond moderne.

Comparaison avec les architectures précédentes#

Théorème 5 (Avantages du Transformer sur les RNN)

Soit une séquence de longueur \(T\) et un modèle de dimension \(d\). Le tableau suivant compare les propriétés fondamentales :

Propriété

RNN

Transformer

Complexité par couche

\(O(T \cdot d^2)\)

\(O(T^2 \cdot d + T \cdot d^2)\)

Opérations séquentielles

\(O(T)\)

\(O(1)\)

Longueur max. du chemin de gradient

\(O(T)\)

\(O(1)\)

Parallélisation

Non

Oui

Pour \(T < d\) (cas fréquent en NLP où \(T \sim 512\) et \(d \sim 512\) ou plus), le terme \(O(T^2 \cdot d)\) de l’attention est dominé par \(O(T \cdot d^2)\), et le Transformer est plus rapide que le RNN grâce à la parallélisation complète.

Hide code cell source

# Comparaison visuelle des complexités
T_values = np.arange(10, 2001, 10)
d = 512

complexity_rnn = T_values * d**2
complexity_transformer = T_values**2 * d + T_values * d**2

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(T_values, complexity_rnn / 1e9, label="RNN : $O(T \\cdot d^2)$",
        color='#4C72B0', linewidth=2)
ax.plot(T_values, complexity_transformer / 1e9, label="Transformer : $O(T^2 d + T d^2)$",
        color='#E24A33', linewidth=2)
ax.axvline(x=d, color='gray', linestyle='--', alpha=0.5, label=f"$T = d = {d}$")
ax.set_xlabel("Longueur de séquence $T$")
ax.set_ylabel("FLOPs ($\\times 10^9$)")
ax.set_title("Complexité par couche : RNN vs Transformer")
ax.legend()
ax.set_xlim(0, 2000)
plt.tight_layout()
plt.show()
_images/375165f0169b7e1852e95ce89d34953883451ca05830f7b0af4a828952a5887e.png

Résumé#

Ce chapitre a présenté le mécanisme d’attention et l’architecture Transformer, qui ont révolutionné l’apprentissage profond depuis 2017.

  1. Le mécanisme d’attention résout le goulot d’étranglement des modèles seq2seq en permettant au décodeur d’accéder directement à tous les états de l’encodeur, pondérés par des scores de pertinence. Les variantes principales sont l’attention additive (Bahdanau) et l’attention multiplicative (Luong).

  2. L”auto-attention (self-attention) applique ce principe au sein d’une même séquence. Le formalisme Query-Key-Value et le scaled dot-product attention \(\text{softmax}(QK^\top / \sqrt{d_k}) V\) constituent le coeur du Transformer.

  3. L”attention multi-têtes exécute \(h\) mécanismes d’attention en parallèle dans des sous-espaces différents, capturant des types de relations complémentaires.

  4. L”encodage positionnel sinusoïdal injecte l’information d’ordre dans une architecture qui en est autrement dépourvue, grâce à des signaux de fréquences variées.

  5. L’architecture Transformer combine auto-attention multi-têtes, couches feed-forward, connexions résiduelles et normalisation de couche en blocs encodeur et décodeur empilés. Le masque causal dans le décodeur empêche le modèle de « tricher » en regardant le futur.

  6. L’entraînement du Transformer utilise un planificateur de taux d’apprentissage avec warmup, l’optimiseur Adam avec des hyperparamètres spécifiques, et le gradient clipping pour stabiliser la convergence.

  7. L”impact du Transformer est considérable : BERT, GPT, T5, ViT et des centaines d’autres modèles en sont dérivés. L’attention est devenue la brique fondamentale de l’apprentissage profond moderne, et ses descendants seront étudiés dans les chapitres 24 (NLP) et 25 (Vision).

Remarque 238

Le Transformer illustre un principe profond en apprentissage automatique : la bonne inductive bias — ici, la capacité à modéliser des relations arbitraires entre éléments d’une séquence sans a priori de localité — peut l’emporter sur des architectures plus structurées (RNN, CNN) lorsque les données et la puissance de calcul sont suffisantes. Comprendre l’attention et le Transformer est aujourd’hui un prérequis indispensable pour aborder les modèles de fondation (foundation models) qui façonnent l’IA contemporaine.