Réseaux antagonistes génératifs (GAN)#

Toute création est d’abord un acte de destruction.

— Pablo Picasso, Propos sur l’art

Le chapitre précédent a présenté les auto-encodeurs variationnels (VAE), un modèle génératif fondé sur l’inférence variationnelle et un encodeur-décodeur probabiliste. Les VAE optimisent une borne inférieure de la vraisemblance (ELBO) et produisent des échantillons en décodant des vecteurs latents tirés d’une distribution gaussienne. Si cette approche offre un cadre mathématique élégant, les images générées sont souvent floues en raison de la perte de reconstruction \(L^2\). Ce chapitre introduit une approche radicalement différente de la modélisation générative : les réseaux antagonistes génératifs (Generative Adversarial Networks, GAN), proposés par Ian Goodfellow et al. en 2014. L’idée fondatrice est d’entraîner simultanément deux réseaux — un générateur qui produit des données synthétiques et un discriminateur qui tente de distinguer les données réelles des données générées — dans un jeu à somme nulle inspiré de la théorie des jeux. Cette dynamique adversariale pousse le générateur à produire des échantillons de plus en plus réalistes, sans jamais nécessiter de reconstruction explicite.

Hide code cell source

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

sns.set_theme(style="whitegrid", palette="muted", font_scale=1.1)
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Périphérique utilisé : {device}")
Périphérique utilisé : cpu

Introduction : modèles génératifs et apprentissage adversarial#

Rappel : le cadre génératif#

En apprentissage automatique, on distingue deux grandes familles de modèles. Les modèles discriminatifs apprennent la distribution conditionnelle \(p(y \mid \mathbf{x})\) pour classifier ou régresser. Les modèles génératifs, en revanche, apprennent la distribution des données \(p_{\text{data}}(\mathbf{x})\) elle-même, ce qui permet de générer de nouveaux échantillons ressemblant aux données d’entraînement.

Définition 259 (Modèle génératif)

Un modèle génératif est un modèle statistique qui apprend la distribution \(p_{\text{data}}(\mathbf{x})\) des données d’entraînement \(\mathbf{x} \in \mathbb{R}^d\). Une fois entraîné, il permet de :

  1. Échantillonner de nouveaux points \(\mathbf{x}_{\text{new}} \sim p_{\text{model}}(\mathbf{x})\).

  2. Évaluer la vraisemblance (ou une approximation) d’un point donné.

Les principales approches sont :

  • Modèles à variable latente : VAE, modèles de diffusion.

  • Modèles autorégressifs : PixelCNN, GPT.

  • Modèles adversariaux : GAN et ses variantes.

Du VAE au GAN : une autre philosophie#

Les VAE définissent explicitement un modèle probabiliste \(p_\theta(\mathbf{x} \mid \mathbf{z})\) et optimisent l’ELBO. Les GAN adoptent une stratégie radicalement différente : le générateur apprend une transformation déterministe \(G(\mathbf{z})\) d’un bruit aléatoire \(\mathbf{z}\) vers l’espace des données, sans jamais définir explicitement la densité \(p_{\text{model}}(\mathbf{x})\). La qualité des échantillons est jugée par un réseau adversaire — le discriminateur — et non par une fonction de perte de reconstruction.

Remarque 227

VAE vs GAN — Le VAE fournit une densité explicite et un encodeur permettant l’inférence, mais produit des images floues. Le GAN ne fournit ni densité explicite ni encodeur, mais génère des images nettement plus nettes. Ces deux approches sont complémentaires et peuvent être combinées (VAE-GAN, par exemple).

Hide code cell source

# Schéma comparatif : VAE vs GAN
fig, axes = plt.subplots(2, 1, figsize=(9, 7))

# VAE
ax = axes[0]
ax.set_xlim(0, 10); ax.set_ylim(0, 6)
ax.set_title("Auto-encodeur variationnel (VAE)", fontsize=11, fontweight='bold')
ax.axis('off')
boxes_vae = [("$\\mathbf{x}$", 0.5, 3), ("Encodeur\n$q_\\phi(\\mathbf{z}|\\mathbf{x})$", 3, 3),
             ("$\\mathbf{z}$", 5.5, 3), ("Décodeur\n$p_\\theta(\\mathbf{x}|\\mathbf{z})$", 7.5, 3),
             ("$\\hat{\\mathbf{x}}$", 9.5, 3)]
for label, x, y in boxes_vae:
    ax.annotate(label, xy=(x, y), fontsize=10, ha='center', va='center',
                bbox=dict(boxstyle='round,pad=0.4', facecolor='#AEC6CF', alpha=0.8))
for i in range(len(boxes_vae) - 1):
    ax.annotate('', xy=(boxes_vae[i+1][1]-0.6, 3), xytext=(boxes_vae[i][1]+0.6, 3),
                arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
ax.text(5, 0.8, "Optimise : $\\mathcal{L}_{\\mathrm{ELBO}}$ (reconstruction + KL)", ha='center', fontsize=9, style='italic')

# GAN
ax = axes[1]
ax.set_xlim(0, 10); ax.set_ylim(0, 6)
ax.set_title("Réseau antagoniste génératif (GAN)", fontsize=11, fontweight='bold')
ax.axis('off')
boxes_gan = [("$\\mathbf{z}$", 0.5, 3), ("Générateur\n$G_\\theta$", 3, 3),
             ("$G(\\mathbf{z})$", 5.5, 4), ("Discriminateur\n$D_\\phi$", 7.5, 3)]
for label, x, y in boxes_gan:
    ax.annotate(label, xy=(x, y), fontsize=10, ha='center', va='center',
                bbox=dict(boxstyle='round,pad=0.4',
                          facecolor='#FFB7B2' if 'D_' in label else '#AEC6CF', alpha=0.8))
ax.annotate('', xy=(2.2, 3), xytext=(1.1, 3), arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
ax.annotate('', xy=(4.8, 3.7), xytext=(3.8, 3.2), arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
ax.annotate("$\\mathbf{x}_{\\mathrm{réel}}$", xy=(5.5, 2), fontsize=10, ha='center', va='center',
            bbox=dict(boxstyle='round,pad=0.4', facecolor='#B5EAD7', alpha=0.8))
ax.annotate('', xy=(6.8, 3.5), xytext=(6.1, 3.8), arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
ax.annotate('', xy=(6.8, 2.5), xytext=(6.1, 2.2), arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
ax.annotate("Vrai / Faux", xy=(9.3, 3), fontsize=10, ha='center', va='center',
            bbox=dict(boxstyle='round,pad=0.3', facecolor='#FFDAC1', alpha=0.8))
ax.annotate('', xy=(8.7, 3), xytext=(8.2, 3), arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
ax.text(5, 0.8, "Optimise : jeu minimax adversarial", ha='center', fontsize=9, style='italic')

plt.tight_layout()
plt.show()
_images/b0f09ae26bbcb77de4c4768a17a21ab2e84d3cbe5a9732d7ef4edfe6aa6ba930.png

Architecture GAN#

Générateur et discriminateur#

Un GAN est composé de deux réseaux de neurones entraînés simultanément :

  • Le générateur \(G_\theta : \mathbb{R}^{n_z} \to \mathbb{R}^d\) transforme un vecteur de bruit latent \(\mathbf{z} \sim p_z(\mathbf{z})\) (typiquement \(\mathcal{N}(\mathbf{0}, \mathbf{I})\)) en un échantillon synthétique \(G(\mathbf{z})\) de même dimension que les données réelles.

  • Le discriminateur \(D_\phi : \mathbb{R}^d \to [0, 1]\) reçoit un point \(\mathbf{x}\) et produit la probabilité que \(\mathbf{x}\) provienne des données réelles plutôt que du générateur.

L’analogie classique est celle du faussaire et de l”expert en art : le générateur tente de produire de faux tableaux indiscernables des originaux, tandis que le discriminateur tente de détecter les contrefaçons.

Définition 260 (Réseau antagoniste génératif (GAN))

Un réseau antagoniste génératif est défini par le jeu minimax suivant :

\[\min_G \max_D \; V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}} [\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_z} [\log(1 - D(G(\mathbf{z})))]\]

\(V(D, G)\) est la fonction de valeur du jeu. Le discriminateur \(D\) cherche à maximiser \(V\) (classifier correctement réels et faux), tandis que le générateur \(G\) cherche à minimiser \(V\) (tromper le discriminateur).

Discriminateur optimal#

Pour un générateur \(G\) fixé, on peut calculer analytiquement le discriminateur optimal.

Théorème 3 (Discriminateur optimal)

Pour un générateur \(G\) fixé, le discriminateur optimal \(D^*_G\) est donné par :

\[D^*_G(\mathbf{x}) = \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\]

\(p_g\) est la distribution implicite induite par \(G(\mathbf{z})\) lorsque \(\mathbf{z} \sim p_z\).

Preuve (esquisse) : Pour un \(\mathbf{x}\) fixé, la fonction à maximiser par rapport à \(D(\mathbf{x})\) est \(f(d) = a \log d + b \log(1-d)\) avec \(a = p_{\text{data}}(\mathbf{x})\) et \(b = p_g(\mathbf{x})\). La dérivée s’annule en \(d = \frac{a}{a+b}\).

Équilibre de Nash et divergence de Jensen-Shannon#

En substituant le discriminateur optimal dans la fonction de valeur, on obtient une expression qui fait apparaître la divergence de Jensen-Shannon entre \(p_{\text{data}}\) et \(p_g\).

Théorème 4 (Convergence du GAN)

Lorsque le discriminateur est optimal, la fonction de valeur du GAN se réécrit :

\[V(D^*_G, G) = -\log 4 + 2 \, D_{\text{JS}}(p_{\text{data}} \| p_g)\]

\(D_{\text{JS}}\) est la divergence de Jensen-Shannon. Le minimum global est atteint si et seulement si \(p_g = p_{\text{data}}\), auquel cas \(D_{\text{JS}} = 0\) et \(D^*(\mathbf{x}) = \frac{1}{2}\) partout.

Hide code cell source

# Illustration : discriminateur optimal pour deux gaussiennes
x = np.linspace(-6, 10, 500)
p_data = 0.6 * np.exp(-0.5 * ((x - 2) / 1.0)**2) / (1.0 * np.sqrt(2 * np.pi)) + \
         0.4 * np.exp(-0.5 * ((x - 5) / 0.8)**2) / (0.8 * np.sqrt(2 * np.pi))
p_g = np.exp(-0.5 * ((x - 3) / 1.5)**2) / (1.5 * np.sqrt(2 * np.pi))

# Normalisation pour visualisation
p_data = p_data / p_data.max()
p_g = p_g / p_g.max()

D_star = p_data / (p_data + p_g + 1e-8)

fig, ax = plt.subplots(figsize=(10, 5))
ax.fill_between(x, p_data, alpha=0.3, color='steelblue', label='$p_{\\mathrm{data}}(\\mathbf{x})$')
ax.fill_between(x, p_g, alpha=0.3, color='coral', label='$p_g(\\mathbf{x})$')
ax.plot(x, D_star, color='forestgreen', linewidth=2.5, label='$D^*_G(\\mathbf{x})$')
ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='$D^* = 0.5$ (équilibre)')
ax.set_xlabel("$x$"); ax.set_ylabel("Densité / Probabilité")
ax.set_title("Discriminateur optimal entre $p_{\\mathrm{data}}$ et $p_g$")
ax.legend(fontsize=10)
plt.tight_layout()
plt.show()
_images/bd83dfbd601f495fe398273979a8a747a674ec569acc5a4a2315ebb014536b30.png

Entraînement d’un GAN#

Algorithme d’entraînement#

L’entraînement d’un GAN alterne entre la mise à jour du discriminateur et celle du générateur. En pratique, on effectue \(k\) pas de gradient sur \(D\) pour chaque pas sur \(G\) (Goodfellow et al. recommandent \(k = 1\)).

Définition 261 (Algorithme d’entraînement GAN)

Pour chaque itération d’entraînement :

1. Mise à jour du discriminateur (monter le gradient de \(V\)) :

  • Échantillonner un mini-batch \(\{\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(m)}\}\) de données réelles.

  • Échantillonner un mini-batch \(\{\mathbf{z}^{(1)}, \ldots, \mathbf{z}^{(m)}\}\) de bruit \(\mathbf{z} \sim p_z\).

  • Mettre à jour \(\phi\) par montée de gradient :

\[\nabla_\phi \frac{1}{m} \sum_{i=1}^{m} \left[ \log D_\phi(\mathbf{x}^{(i)}) + \log(1 - D_\phi(G_\theta(\mathbf{z}^{(i)}))) \right]\]

2. Mise à jour du générateur (descendre le gradient de \(V\)) :

  • Échantillonner un mini-batch \(\{\mathbf{z}^{(1)}, \ldots, \mathbf{z}^{(m)}\}\) de bruit.

  • Mettre à jour \(\theta\) par descente de gradient :

\[\nabla_\theta \frac{1}{m} \sum_{i=1}^{m} \log(1 - D_\phi(G_\theta(\mathbf{z}^{(i)})))\]

Remarque 228

Astuce pratique — En début d’entraînement, \(D\) rejette facilement les sorties de \(G\), donc \(\log(1 - D(G(\mathbf{z})))\) est proche de \(0\) et le gradient pour \(G\) est très faible. Une astuce courante est de remplacer la perte du générateur par \(-\log D(G(\mathbf{z}))\), qui fournit des gradients plus forts lorsque \(D(G(\mathbf{z}))\) est petit. On appelle cela le non-saturating GAN loss.

Implémentation : GAN sur un mélange de gaussiennes 2D#

Pour illustrer le fonctionnement d’un GAN, commençons par un exemple simple : apprendre un mélange de gaussiennes en 2D.

Hide code cell source

# Génération des données cibles : mélange de 8 gaussiennes en cercle
def make_ring_data(n_samples=2000, n_modes=8, radius=3.0, std=0.3):
    """Génère un mélange de gaussiennes disposées en cercle."""
    angles = np.linspace(0, 2 * np.pi, n_modes, endpoint=False)
    centers = np.stack([radius * np.cos(angles), radius * np.sin(angles)], axis=1)
    data = []
    for _ in range(n_samples):
        idx = np.random.randint(n_modes)
        point = centers[idx] + std * np.random.randn(2)
        data.append(point)
    return np.array(data, dtype=np.float32)

data_real = make_ring_data(n_samples=5000)

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(data_real[:, 0], data_real[:, 1], s=3, alpha=0.5, color='steelblue')
ax.set_title("Données cibles : mélange de 8 gaussiennes")
ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
ax.set_aspect('equal')
plt.tight_layout()
plt.show()
_images/d11bb666403eb134a87c987926b9c8fbacfdcee59f6e747b6246ddddf2b6d886.png

Hide code cell source

# Définition du générateur et du discriminateur
class Generator2D(nn.Module):
    def __init__(self, latent_dim=16, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, z):
        return self.net(z)


class Discriminator2D(nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)

Hide code cell source

# Entraînement du GAN 2D
latent_dim = 16
G = Generator2D(latent_dim=latent_dim).to(device)
D = Discriminator2D().to(device)

opt_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.999))

dataset = TensorDataset(torch.from_numpy(data_real))
loader = DataLoader(dataset, batch_size=256, shuffle=True)

n_epochs = 30
losses_D, losses_G = [], []
snapshots = {}  # Pour visualiser l'évolution

for epoch in range(n_epochs):
    epoch_loss_D, epoch_loss_G = 0, 0
    for (real_batch,) in loader:
        real_batch = real_batch.to(device)
        batch_size = real_batch.size(0)

        # --- Entraîner D ---
        z = torch.randn(batch_size, latent_dim, device=device)
        fake = G(z).detach()
        pred_real = D(real_batch)
        pred_fake = D(fake)
        loss_D = -torch.mean(torch.log(pred_real + 1e-8) + torch.log(1 - pred_fake + 1e-8))
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # --- Entraîner G (non-saturating loss) ---
        z = torch.randn(batch_size, latent_dim, device=device)
        fake = G(z)
        pred_fake = D(fake)
        loss_G = -torch.mean(torch.log(pred_fake + 1e-8))
        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

        epoch_loss_D += loss_D.item()
        epoch_loss_G += loss_G.item()

    losses_D.append(epoch_loss_D / len(loader))
    losses_G.append(epoch_loss_G / len(loader))

    if epoch in [0, 10, 50, 100, 199]:
        with torch.no_grad():
            z_test = torch.randn(2000, latent_dim, device=device)
            snapshots[epoch] = G(z_test).cpu().numpy()

print("Entraînement terminé.")
Entraînement terminé.

Hide code cell source

# Visualisation de l'évolution du générateur
fig, axes = plt.subplots(5, 1, figsize=(8, 18))
for ax, (epoch, samples) in zip(axes, sorted(snapshots.items())):
    ax.scatter(data_real[:, 0], data_real[:, 1], s=2, alpha=0.2, color='steelblue', label='Réel')
    ax.scatter(samples[:, 0], samples[:, 1], s=2, alpha=0.5, color='coral', label='Généré')
    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
    ax.set_aspect('equal')
    ax.set_title(f"Époque {epoch + 1}", fontsize=10)
    if ax == axes[0]:
        ax.legend(fontsize=8, loc='upper left')
plt.suptitle("Évolution des échantillons générés au cours de l'entraînement", fontsize=12, y=1.02)
plt.tight_layout()
plt.show()
_images/6eb68c485c0a74237b2736dde501c97e54dffbf13f3ca9f9cbbe92db2e07819f.png

Hide code cell source

# Courbes de perte
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses_D, label='Perte $D$', color='coral', alpha=0.8)
ax.plot(losses_G, label='Perte $G$', color='steelblue', alpha=0.8)
ax.set_xlabel("Époque"); ax.set_ylabel("Perte")
ax.set_title("Évolution des pertes du discriminateur et du générateur")
ax.legend()
plt.tight_layout()
plt.show()
_images/acae54e68495e442584701eeffc7044f9911d1e3b475fa20c40ab7a87c09a4cf.png

On observe que le générateur apprend progressivement la structure en anneau du mélange de gaussiennes. Les premières époques produisent des points concentrés autour de l’origine, puis la distribution s’étale pour couvrir les 8 modes.

DCGAN : GAN convolutif profond#

Principes architecturaux#

Le DCGAN (Deep Convolutional GAN), proposé par Radford et al. en 2016, est un ensemble de directives architecturales qui stabilisent l’entraînement des GAN sur des images. Ces directives sont devenues des pratiques standard.

Définition 262 (Directives architecturales DCGAN)

Les principales recommandations du DCGAN sont :

  1. Remplacer les couches de pooling par des convolutions à pas (strided convolutions) dans le discriminateur et des convolutions transposées (transposed convolutions) dans le générateur.

  2. Utiliser la normalisation par batch (batch normalization) dans les deux réseaux, sauf dans la couche de sortie du générateur et la couche d’entrée du discriminateur.

  3. Supprimer les couches entièrement connectées dans les architectures profondes.

  4. Utiliser ReLU dans le générateur (sauf la sortie : \(\tanh\)) et LeakyReLU dans le discriminateur.

Exemple 26 (Architecture DCGAN pour images 28x28 (MNIST))

Générateur (vecteur latent \(\mathbf{z} \in \mathbb{R}^{100} \to\) image \(1 \times 28 \times 28\)) :

  • Linéaire : \(100 \to 256 \times 7 \times 7\) puis reshape

  • ConvTranspose2d : \(256 \to 128\), noyau \(4 \times 4\), stride \(2\), padding \(1\) + BatchNorm + ReLU

  • ConvTranspose2d : \(128 \to 1\), noyau \(4 \times 4\), stride \(2\), padding \(1\) + Tanh

Discriminateur (image \(1 \times 28 \times 28 \to\) scalaire \([0, 1]\)) :

  • Conv2d : \(1 \to 64\), noyau \(4 \times 4\), stride \(2\), padding \(1\) + LeakyReLU(0.2)

  • Conv2d : \(64 \to 128\), noyau \(4 \times 4\), stride \(2\), padding \(1\) + BatchNorm + LeakyReLU(0.2)

  • Linéaire : \(128 \times 7 \times 7 \to 1\) + Sigmoid

Implémentation PyTorch#

Hide code cell source

# Architecture DCGAN pour MNIST (28x28)
class DCGenerator(nn.Module):
    """Générateur DCGAN pour images 28x28 en niveaux de gris."""
    def __init__(self, latent_dim=100, ngf=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.project = nn.Sequential(
            nn.Linear(latent_dim, ngf * 2 * 7 * 7),
            nn.BatchNorm1d(ngf * 2 * 7 * 7),
            nn.ReLU(True),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, ngf, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, 1, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.project(z)
        x = x.view(x.size(0), -1, 7, 7)
        return self.conv(x)


class DCDiscriminator(nn.Module):
    """Discriminateur DCGAN pour images 28x28 en niveaux de gris."""
    def __init__(self, ndf=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, ndf, 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Linear(ndf * 2 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.conv(x)
        features = features.view(features.size(0), -1)
        return self.classifier(features)

Hide code cell source

# Initialisation des poids (recommandation DCGAN)
def weights_init(m):
    """Initialisation des poids selon les recommandations DCGAN."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
    elif classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)

# Instanciation
latent_dim = 100
netG = DCGenerator(latent_dim=latent_dim).to(device)
netD = DCDiscriminator().to(device)
netG.apply(weights_init)
netD.apply(weights_init)

print(f"Générateur : {sum(p.numel() for p in netG.parameters()):,} paramètres")
print(f"Discriminateur : {sum(p.numel() for p in netD.parameters()):,} paramètres")
Générateur : 1,818,624 paramètres
Discriminateur : 138,625 paramètres

Entraînement sur MNIST#

Hide code cell source

# Chargement de MNIST
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # Images dans [-1, 1] pour Tanh
])

mnist_full = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Sous-ensemble pour accélérer l'entraînement (CPU)
from torch.utils.data import Subset
mnist_train = Subset(mnist_full, range(5000))
train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True, num_workers=0)

Hide code cell source

# Boucle d'entraînement DCGAN
criterion = nn.BCELoss()
opt_G = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))

n_epochs = 2
fixed_noise = torch.randn(64, latent_dim, device=device)  # Pour visualisation

history = {'loss_D': [], 'loss_G': [], 'D_x': [], 'D_Gz': []}
gen_images_history = []

for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(train_loader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)
        label_real = torch.ones(batch_size, 1, device=device) * 0.9   # Label smoothing
        label_fake = torch.zeros(batch_size, 1, device=device)

        # --- Entraîner D ---
        opt_D.zero_grad()
        output_real = netD(real_imgs)
        loss_D_real = criterion(output_real, label_real)

        z = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = netG(z)
        output_fake = netD(fake_imgs.detach())
        loss_D_fake = criterion(output_fake, label_fake)

        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        opt_D.step()

        # --- Entraîner G ---
        opt_G.zero_grad()
        output_fake = netD(fake_imgs)
        loss_G = criterion(output_fake, torch.ones(batch_size, 1, device=device))
        loss_G.backward()
        opt_G.step()

        history['loss_D'].append(loss_D.item())
        history['loss_G'].append(loss_G.item())
        history['D_x'].append(output_real.mean().item())
        history['D_Gz'].append(output_fake.mean().item())

    # Sauvegarder des images générées à chaque époque
    with torch.no_grad():
        gen_imgs = netG(fixed_noise).cpu()
        gen_images_history.append(gen_imgs)

    print(f"Époque [{epoch+1}/{n_epochs}]  "
          f"Loss_D: {history['loss_D'][-1]:.4f}  Loss_G: {history['loss_G'][-1]:.4f}  "
          f"D(x): {history['D_x'][-1]:.3f}  D(G(z)): {history['D_Gz'][-1]:.3f}")
Époque [1/2]  Loss_D: 0.5153  Loss_G: 2.8416  D(x): 0.781  D(G(z)): 0.062
Époque [2/2]  Loss_D: 0.5268  Loss_G: 2.3544  D(x): 0.776  D(G(z)): 0.105

Hide code cell source

# Visualisation des images générées à travers les époques
epochs_to_show = list(range(len(gen_images_history)))
fig, axes = plt.subplots(len(epochs_to_show), 8, figsize=(14, 2 * len(epochs_to_show)))

for row, ep in enumerate(epochs_to_show):
    for col in range(8):
        img = gen_images_history[ep][col].squeeze().numpy()
        img = (img + 1) / 2  # De [-1,1] à [0,1]
        axes[row, col].imshow(img, cmap='gray', vmin=0, vmax=1)
        axes[row, col].axis('off')
    axes[row, 0].set_ylabel(f"Ép. {ep+1}", fontsize=10, rotation=0, labelpad=35)

plt.suptitle("Images MNIST générées par le DCGAN au fil des époques", fontsize=12, y=1.01)
plt.tight_layout()
plt.show()
_images/a71afa7894dd591d494084b6240b3fd771de461db26d9c648c0acbbc0fda43ed.png

Hide code cell source

# Courbes d'entraînement DCGAN
fig, axes = plt.subplots(2, 1, figsize=(9, 7))

# Pertes
ax = axes[0]
window = 50
loss_D_smooth = np.convolve(history['loss_D'], np.ones(window)/window, mode='valid')
loss_G_smooth = np.convolve(history['loss_G'], np.ones(window)/window, mode='valid')
ax.plot(loss_D_smooth, label='$\\mathcal{L}_D$', color='coral', alpha=0.8)
ax.plot(loss_G_smooth, label='$\\mathcal{L}_G$', color='steelblue', alpha=0.8)
ax.set_xlabel("Itération"); ax.set_ylabel("Perte (lissée)")
ax.set_title("Pertes du discriminateur et du générateur")
ax.legend()

# D(x) et D(G(z))
ax = axes[1]
dx_smooth = np.convolve(history['D_x'], np.ones(window)/window, mode='valid')
dgz_smooth = np.convolve(history['D_Gz'], np.ones(window)/window, mode='valid')
ax.plot(dx_smooth, label='$D(\\mathbf{x})$ (réel)', color='forestgreen', alpha=0.8)
ax.plot(dgz_smooth, label='$D(G(\\mathbf{z}))$ (généré)', color='purple', alpha=0.8)
ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel("Itération"); ax.set_ylabel("Sortie de $D$")
ax.set_title("Confiance du discriminateur")
ax.legend()

plt.tight_layout()
plt.show()
_images/aa4d52c2235bcc87fbc39811df185742dbc7b1edc9b06417964f2297cc180d50.png

Hide code cell source

# Exploration de l'espace latent : interpolation entre deux chiffres
netG.eval()
with torch.no_grad():
    z1 = torch.randn(1, latent_dim, device=device)
    z2 = torch.randn(1, latent_dim, device=device)
    alphas = np.linspace(0, 1, 10)
    interpolated = []
    for alpha in alphas:
        z_interp = (1 - alpha) * z1 + alpha * z2
        img = netG(z_interp).cpu().squeeze().numpy()
        interpolated.append((img + 1) / 2)

fig, axes = plt.subplots(10, 1, figsize=(8, 18))
for i, (ax, img) in enumerate(zip(axes, interpolated)):
    ax.imshow(img, cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
    ax.set_title(f"$\\alpha={alphas[i]:.1f}$", fontsize=8)
plt.suptitle("Interpolation linéaire dans l'espace latent", fontsize=12, y=1.05)
plt.tight_layout()
plt.show()
_images/8de57a6424079da55055d76843133e005fd0e40ae69d1269b3a1a028b98e0ab8.png

L’interpolation dans l’espace latent produit des transitions fluides entre les chiffres, ce qui suggère que le générateur a appris une représentation structurée et continue.

Problèmes de stabilité#

L’entraînement des GAN est notoirement instable. Contrairement à l’optimisation classique d’une seule fonction de perte, le GAN résout un problème d”optimisation à deux joueurs dont la convergence n’est pas garantie. Trois problèmes majeurs se posent.

Effondrement de mode (mode collapse)#

Définition 263 (Effondrement de mode)

L”effondrement de mode (mode collapse) survient lorsque le générateur produit des échantillons concentrés sur un sous-ensemble restreint des modes de la distribution cible. Dans le cas extrême, \(G\) produit toujours la même sortie quel que soit \(\mathbf{z}\) :

\[G(\mathbf{z}_1) \approx G(\mathbf{z}_2) \quad \forall \, \mathbf{z}_1, \mathbf{z}_2\]

Cela se produit lorsque \(G\) trouve un unique point \(\mathbf{x}^*\) qui trompe systématiquement \(D\), et exploite cette stratégie au lieu d’explorer la diversité de \(p_{\text{data}}\).

Hide code cell source

# Illustration de l'effondrement de mode
fig, axes = plt.subplots(3, 1, figsize=(9, 11))

# Mode 1 : distribution cible multimodale
np.random.seed(0)
modes_x = [np.random.randn(200) + c for c in [-3, 0, 3]]
modes_y = [np.random.randn(200) + c for c in [0, 3, 0]]

ax = axes[0]
for mx, my in zip(modes_x, modes_y):
    ax.scatter(mx, my, s=10, alpha=0.5)
ax.set_title("Distribution cible\n(3 modes)", fontsize=10)
ax.set_xlim(-6, 6); ax.set_ylim(-3, 6); ax.set_aspect('equal')

# Mode 2 : effondrement partiel
ax = axes[1]
ax.scatter(modes_x[0], modes_y[0], s=10, alpha=0.3, color='gray', label='Non couvert')
ax.scatter(modes_x[2], modes_y[2], s=10, alpha=0.3, color='gray')
collapsed = np.random.randn(600, 2) * 0.3 + np.array([0, 3])
ax.scatter(collapsed[:, 0], collapsed[:, 1], s=10, alpha=0.5, color='coral', label='Généré')
ax.set_title("Effondrement partiel\n(1 mode sur 3)", fontsize=10)
ax.set_xlim(-6, 6); ax.set_ylim(-3, 6); ax.set_aspect('equal')
ax.legend(fontsize=8)

# Mode 3 : effondrement total
ax = axes[2]
for mx, my in zip(modes_x, modes_y):
    ax.scatter(mx, my, s=10, alpha=0.3, color='gray')
collapsed_total = np.random.randn(600, 2) * 0.1 + np.array([0.5, 1.5])
ax.scatter(collapsed_total[:, 0], collapsed_total[:, 1], s=10, alpha=0.5, color='coral')
ax.set_title("Effondrement total\n(1 seul point)", fontsize=10)
ax.set_xlim(-6, 6); ax.set_ylim(-3, 6); ax.set_aspect('equal')

plt.suptitle("Illustration de l'effondrement de mode", fontsize=12, y=1.02)
plt.tight_layout()
plt.show()
_images/3dff075ab9c1f491f0edfa32709ccd1162edb02ea4769f65e8dc4420ab6ffeb2.png

Gradient évanescent pour le générateur#

Lorsque le discriminateur est trop puissant, il classe les échantillons générés comme faux avec une confiance proche de 1. La perte du générateur \(\log(1 - D(G(\mathbf{z})))\) sature alors près de \(\log(1) = 0\), et les gradients pour \(G\) deviennent négligeables. Le générateur ne reçoit plus de signal d’apprentissage utile.

Oscillations et non-convergence#

Le GAN n’optimise pas une fonction de perte unique mais résout un jeu minimax. Dans les cas pathologiques, les mises à jour de \(G\) et \(D\) s’annulent mutuellement, conduisant à des oscillations sans convergence vers un équilibre.

Solutions et techniques de stabilisation#

Définition 264 (Techniques de stabilisation des GAN)

Plusieurs techniques ont été proposées pour améliorer la stabilité de l’entraînement :

  1. Label smoothing : remplacer les étiquettes \(1\) (réel) par des valeurs dans \([0.7, 1.0]\) pour éviter que le discriminateur soit trop confiant.

  2. Normalisation spectrale (spectral normalization, Miyato et al., 2018) : normaliser les poids de \(D\) par leur plus grande valeur singulière, contraignant la constante de Lipschitz du discriminateur.

  3. Croissance progressive (progressive growing, Karras et al., 2018) : entraîner le GAN à des résolutions croissantes (\(4 \times 4 \to 8 \times 8 \to \cdots \to 1024 \times 1024\)) pour stabiliser l’apprentissage de structures à grande échelle.

  4. Entraînement à deux échelles de temps (TTUR) : utiliser un taux d’apprentissage plus élevé pour \(D\) que pour \(G\) afin de maintenir le discriminateur proche de son optimum.

  5. Perte de gradient penalty : ajouter une pénalisation sur la norme du gradient de \(D\) (utilisée dans le WGAN-GP).

Hide code cell source

# Illustration : effet du label smoothing sur la confiance de D
labels_hard = np.array([1.0] * 50 + [0.0] * 50)
labels_smooth = np.array([0.9] * 50 + [0.0] * 50)  # Label smoothing unilatéral

fig, axes = plt.subplots(2, 1, figsize=(9, 5))
axes[0].bar(range(100), labels_hard, color=['steelblue']*50 + ['coral']*50, alpha=0.7)
axes[0].set_title("Étiquettes dures : réel = 1, faux = 0", fontsize=10)
axes[0].set_ylabel("Étiquette"); axes[0].set_xlabel("Échantillon")

axes[1].bar(range(100), labels_smooth, color=['steelblue']*50 + ['coral']*50, alpha=0.7)
axes[1].set_title("Label smoothing : réel = 0.9, faux = 0", fontsize=10)
axes[1].set_ylabel("Étiquette"); axes[1].set_xlabel("Échantillon")

plt.tight_layout()
plt.show()
_images/b81b57637c7194161cc7f6f04259acc0bd1698985892dba4bb3a0a5f3b7c4c54.png

Variantes et extensions#

WGAN : distance de Wasserstein#

Le WGAN (Wasserstein GAN, Arjovsky et al., 2017) remplace la divergence de Jensen-Shannon par la distance de Wasserstein-1 (distance de Earth Mover), qui fournit un signal de gradient plus informatif même lorsque les distributions \(p_{\text{data}}\) et \(p_g\) ne se chevauchent pas.

Définition 265 (WGAN et distance de Wasserstein)

La distance de Wasserstein-1 entre deux distributions \(p\) et \(q\) est :

\[W_1(p, q) = \inf_{\gamma \in \Pi(p, q)} \mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \gamma} \left[ \|\mathbf{x} - \mathbf{y}\| \right]\]

Par la dualité de Kantorovich-Rubinstein :

\[W_1(p, q) = \sup_{\|f\|_L \leq 1} \left( \mathbb{E}_{\mathbf{x} \sim p}[f(\mathbf{x})] - \mathbb{E}_{\mathbf{x} \sim q}[f(\mathbf{x})] \right)\]

où le supremum porte sur les fonctions 1-Lipschitz. Le discriminateur devient un critique \(f_w\) (sans sigmoïde) et la contrainte de Lipschitz est imposée par :

  • Weight clipping (WGAN original) : \(w \leftarrow \text{clip}(w, -c, c)\)

  • Gradient penalty (WGAN-GP, Gulrajani et al., 2017) : \(\lambda \, \mathbb{E}_{\hat{\mathbf{x}}} \left[ (\|\nabla_{\hat{\mathbf{x}}} f_w(\hat{\mathbf{x}})\|_2 - 1)^2 \right]\)

\(\hat{\mathbf{x}}\) est un point interpolé entre un échantillon réel et un échantillon généré.

Remarque 229

Avantages du WGAN :

  • La perte du critique est une estimation de la distance de Wasserstein, ce qui la rend interprétable comme mesure de qualité. Contrairement au GAN classique, une perte décroissante du critique corrèle avec une amélioration de la qualité des échantillons.

  • Le gradient penalty (WGAN-GP) est préféré au weight clipping car il évite les problèmes de gradients explosifs ou évanescents liés au clipping.

Hide code cell source

# Implémentation du critique WGAN-GP
class WGANCritic(nn.Module):
    """Critique WGAN (pas de sigmoïde en sortie)."""
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),  # Pas de Sigmoid !
        )

    def forward(self, x):
        return self.net(x)


def gradient_penalty(critic, real, fake, device):
    """Calcule la pénalité de gradient pour WGAN-GP."""
    alpha = torch.rand(real.size(0), 1, device=device)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interpolated = critic(interpolated)
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]
    grad_norm = gradients.norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

Hide code cell source

# Entraînement WGAN-GP sur le mélange de gaussiennes 2D
G_w = Generator2D(latent_dim=16).to(device)
C_w = WGANCritic().to(device)
opt_G_w = optim.Adam(G_w.parameters(), lr=1e-4, betas=(0.0, 0.9))
opt_C_w = optim.Adam(C_w.parameters(), lr=1e-4, betas=(0.0, 0.9))

n_critic = 5  # Nombre de pas du critique par pas du générateur
lambda_gp = 10.0
n_epochs_w = 30
losses_W = []

dataset_w = TensorDataset(torch.from_numpy(data_real))
loader_w = DataLoader(dataset_w, batch_size=256, shuffle=True)

for epoch in range(n_epochs_w):
    for (real_batch,) in loader_w:
        real_batch = real_batch.to(device)
        bs = real_batch.size(0)

        # --- Entraîner le critique (n_critic pas) ---
        for _ in range(n_critic):
            z = torch.randn(bs, 16, device=device)
            fake = G_w(z).detach()
            loss_C = C_w(fake).mean() - C_w(real_batch).mean()
            gp = gradient_penalty(C_w, real_batch, fake, device)
            loss_C_total = loss_C + lambda_gp * gp

            opt_C_w.zero_grad()
            loss_C_total.backward()
            opt_C_w.step()

        # --- Entraîner G ---
        z = torch.randn(bs, 16, device=device)
        fake = G_w(z)
        loss_G_w = -C_w(fake).mean()

        opt_G_w.zero_grad()
        loss_G_w.backward()
        opt_G_w.step()

    losses_W.append(-loss_C.item())  # Estimation de W_1

# Visualisation
with torch.no_grad():
    z_test = torch.randn(3000, 16, device=device)
    samples_wgan = G_w(z_test).cpu().numpy()

fig, axes = plt.subplots(2, 1, figsize=(9, 9))
ax = axes[0]
ax.scatter(data_real[:, 0], data_real[:, 1], s=3, alpha=0.3, color='steelblue', label='Réel')
ax.scatter(samples_wgan[:, 0], samples_wgan[:, 1], s=3, alpha=0.5, color='coral', label='WGAN-GP')
ax.set_xlim(-5, 5); ax.set_ylim(-5, 5); ax.set_aspect('equal')
ax.set_title("Échantillons WGAN-GP"); ax.legend(fontsize=9)

ax = axes[1]
ax.plot(losses_W, color='steelblue')
ax.set_xlabel("Époque"); ax.set_ylabel("Distance de Wasserstein estimée")
ax.set_title("Évolution de la distance de Wasserstein")

plt.tight_layout()
plt.show()
_images/914979f2783548191566a2480c8d6718ef3f9eee760fe0bf2d0b1ec51b51bcd7.png

GAN conditionnel (cGAN)#

Le GAN conditionnel (Conditional GAN, Mirza & Osindero, 2014) introduit une information conditionnelle \(\mathbf{y}\) (par exemple une étiquette de classe) dans les deux réseaux. Cela permet de contrôler la génération.

Définition 266 (GAN conditionnel)

Dans un cGAN, le générateur et le discriminateur reçoivent tous deux l’information conditionnelle \(\mathbf{y}\) :

\[\min_G \max_D \; V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}} [\log D(\mathbf{x} \mid \mathbf{y})] + \mathbb{E}_{\mathbf{z} \sim p_z} [\log(1 - D(G(\mathbf{z} \mid \mathbf{y}) \mid \mathbf{y}))]\]

En pratique, \(\mathbf{y}\) est concaténé au vecteur latent \(\mathbf{z}\) pour le générateur, et à l’entrée \(\mathbf{x}\) pour le discriminateur. Pour des étiquettes discrètes, on utilise un embedding ou un vecteur one-hot.

Hide code cell source

# Implémentation simplifiée d'un cGAN pour MNIST
class ConditionalGenerator(nn.Module):
    """Générateur conditionnel : z + embedding(y) -> image."""
    def __init__(self, latent_dim=100, n_classes=10, ngf=128):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, 50)
        self.project = nn.Sequential(
            nn.Linear(latent_dim + 50, ngf * 2 * 7 * 7),
            nn.BatchNorm1d(ngf * 2 * 7 * 7),
            nn.ReLU(True),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, ngf, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, 1, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z, y):
        y_emb = self.label_emb(y)
        x = torch.cat([z, y_emb], dim=1)
        x = self.project(x)
        x = x.view(x.size(0), -1, 7, 7)
        return self.conv(x)


class ConditionalDiscriminator(nn.Module):
    """Discriminateur conditionnel : image + embedding(y) -> vrai/faux."""
    def __init__(self, n_classes=10, ndf=64):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, 28 * 28)
        self.conv = nn.Sequential(
            nn.Conv2d(2, ndf, 4, stride=2, padding=1, bias=False),  # 2 canaux : image + label
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Linear(ndf * 2 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, y):
        y_emb = self.label_emb(y).view(y.size(0), 1, 28, 28)
        x_cat = torch.cat([x, y_emb], dim=1)
        features = self.conv(x_cat)
        return self.classifier(features.view(features.size(0), -1))

print("Architecture cGAN définie.")
print(f"  Générateur conditionnel : {sum(p.numel() for p in ConditionalGenerator().parameters()):,} paramètres")
print(f"  Discriminateur conditionnel : {sum(p.numel() for p in ConditionalDiscriminator().parameters()):,} paramètres")
Architecture cGAN définie.
  Générateur conditionnel : 2,446,324 paramètres
  Discriminateur conditionnel : 147,489 paramètres

Autres variantes notables#

Remarque 230

Panorama des variantes de GAN :

  • StyleGAN (Karras et al., 2019) : architecture fondée sur un réseau de mapping qui transforme \(\mathbf{z}\) en un espace de style \(\mathbf{w}\), injecté à chaque résolution via la normalisation adaptative d’instance (AdaIN). StyleGAN2 et StyleGAN3 améliorent la qualité et la cohérence spatiale des images générées.

  • CycleGAN (Zhu et al., 2017) : permet la traduction d’image à image non appariée (unpaired image-to-image translation). Deux générateurs \(G_{A \to B}\) et \(G_{B \to A}\) et deux discriminateurs sont entraînés avec une perte de cycle \(\|G_{B \to A}(G_{A \to B}(\mathbf{x})) - \mathbf{x}\|_1\) qui assure la cohérence.

  • Pix2Pix (Isola et al., 2017) : traduction d’image à image appariée, combinant une perte adversariale et une perte \(L^1\) pixel par pixel.

  • BigGAN (Brock et al., 2019) : mise à l’échelle des GAN conditionnels avec de grands batches, orthogonal regularization et class-conditional batch normalization pour générer des images ImageNet à haute résolution.

Hide code cell source

# Chronologie des architectures GAN
fig, ax = plt.subplots(figsize=(14, 3.5))
ax.set_xlim(2013, 2025); ax.set_ylim(-1, 2.5)
ax.axis('off')
ax.set_title("Chronologie des architectures GAN", fontsize=12, pad=10)

events = [
    (2014, "GAN\n(Goodfellow)", '#4C72B0'),
    (2014.7, "cGAN\n(Mirza)", '#55A868'),
    (2016, "DCGAN\n(Radford)", '#DD8452'),
    (2017, "WGAN\n(Arjovsky)", '#8B6DAF'),
    (2017.5, "CycleGAN\n(Zhu)", '#C44E52'),
    (2018, "SAGAN\n(Zhang)", '#937860'),
    (2019, "StyleGAN\n(Karras)", '#E24A33'),
    (2020, "StyleGAN2\n(Karras)", '#DA8BC3'),
    (2021, "StyleGAN3\n(Karras)", '#8172B3'),
]

ax.axhline(y=0, color='gray', linewidth=2, alpha=0.3, xmin=0.02, xmax=0.98)
for year, label, color in events:
    ax.plot(year, 0, 'o', color=color, markersize=10, zorder=5)
    offset_y = 0.5 if events.index((year, label, color)) % 2 == 0 else 1.3
    ax.annotate(f"{int(year)}\n{label}", xy=(year, 0), xytext=(year, offset_y),
                fontsize=7.5, ha='center', va='bottom', color=color,
                arrowprops=dict(arrowstyle='-', color=color, alpha=0.5))

plt.tight_layout()
plt.show()
_images/9e4aa8a561049dd9cbe80bbf654ad1e998dc58ea6385372bf528480f74da9329.png

Applications des GAN#

Les GAN ont trouvé des applications dans de nombreux domaines, bien au-dela de la simple génération d’images.

Génération d’images réalistes#

L’application la plus emblématique : générer des images photoréalistes de visages, d’objets ou de scènes qui n’existent pas. Les modèles StyleGAN génèrent des visages indiscernables de photographies réelles.

Super-résolution#

Le SRGAN (Super-Resolution GAN, Ledig et al., 2017) et son successeur ESRGAN augmentent la résolution d’une image basse résolution. Le discriminateur force le générateur à produire des détails haute fréquence réalistes, là où les méthodes \(L^2\) classiques produisent des résultats flous.

Transfert de style et traduction d’images#

CycleGAN permet des transformations spectaculaires : photo \(\leftrightarrow\) peinture, cheval \(\leftrightarrow\) zèbre, été \(\leftrightarrow\) hiver. Pix2Pix traduit des croquis en images réalistes, des cartes en photos aériennes, etc.

Augmentation de données#

Les GAN peuvent générer des données synthétiques pour enrichir un jeu d’entraînement, particulièrement utile dans les domaines où les données annotées sont rares (imagerie médicale, détection de défauts industriels).

Exemple 27 (Applications des GAN par domaine)

Domaine

Application

Architecture

Vision par ordinateur

Génération de visages

StyleGAN

Imagerie médicale

Augmentation de données

cGAN, Pix2Pix

Super-résolution

Amélioration d’images

SRGAN, ESRGAN

Art et design

Transfert de style

CycleGAN

Texte → Image

Génération depuis description

StackGAN, AttnGAN

Vidéo

Prédiction de trames

MoCoGAN

Audio

Synthèse vocale

WaveGAN, MelGAN

Sciences

Simulation physique

Physics-informed GAN

Hide code cell source

# Grille finale : 100 images générées par le DCGAN entraîné
with torch.no_grad():
    z_final = torch.randn(100, latent_dim, device=device)
    generated = netG(z_final).cpu()

fig, axes = plt.subplots(10, 10, figsize=(10, 10))
for i in range(10):
    for j in range(10):
        img = generated[i * 10 + j].squeeze().numpy()
        img = (img + 1) / 2
        axes[i, j].imshow(img, cmap='gray', vmin=0, vmax=1)
        axes[i, j].axis('off')
plt.suptitle("100 images générées par le DCGAN (MNIST)", fontsize=14, y=1.01)
plt.tight_layout()
plt.show()
_images/97df4df94be633216ce571e5c3e6e192ac56066ac427da029762b4a11a4700d4.png

Hide code cell source

# Métriques de qualité : diversité des échantillons générés
with torch.no_grad():
    z_eval = torch.randn(1000, latent_dim, device=device)
    gen_eval = netG(z_eval).cpu().view(1000, -1).numpy()

# Variance inter-échantillons comme proxy de diversité
pixel_variance = np.var(gen_eval, axis=0)
mean_variance = np.mean(pixel_variance)

fig, axes = plt.subplots(2, 1, figsize=(9, 7))

ax = axes[0]
ax.hist(pixel_variance, bins=50, color='steelblue', alpha=0.7, edgecolor='white')
ax.axvline(mean_variance, color='coral', linestyle='--', label=f'Moyenne = {mean_variance:.4f}')
ax.set_xlabel("Variance par pixel"); ax.set_ylabel("Nombre de pixels")
ax.set_title("Distribution de la variance inter-échantillons")
ax.legend()

ax = axes[1]
variance_map = pixel_variance.reshape(28, 28)
im = ax.imshow(variance_map, cmap='hot')
ax.set_title("Carte de variance spatiale")
ax.axis('off')
plt.colorbar(im, ax=ax, fraction=0.046)

plt.tight_layout()
plt.show()
_images/fc0f9d22859e341686da0fc167af6c219dde36bf21a99e314162e0fc627298f6.png

Remarque 231

Métriques d’évaluation des GAN — L’évaluation des GAN est un problème ouvert. Deux métriques sont largement utilisées :

  • Le FID (Fréchet Inception Distance) mesure la distance entre les distributions des features extraites par un réseau Inception pré-entraîné pour les images réelles et générées. Un FID plus bas indique une meilleure qualité.

  • L”IS (Inception Score) mesure à la fois la qualité (les images générées sont-elles classifiables ?) et la diversité (la distribution des classes est-elle uniforme ?).

Ces deux métriques reposent sur un réseau pré-entraîné et présentent des biais connus, mais elles restent les standards de facto pour comparer les modèles génératifs.

Résumé#

Ce chapitre a présenté les réseaux antagonistes génératifs (GAN), une famille de modèles génératifs fondée sur l’apprentissage adversarial.

  1. Un GAN est composé d’un générateur \(G\) qui transforme du bruit en données synthétiques et d’un discriminateur \(D\) qui distingue le vrai du faux. Les deux réseaux sont entraînés dans un jeu minimax dont l’équilibre de Nash correspond à \(p_g = p_{\text{data}}\).

  2. Le discriminateur optimal est \(D^*(\mathbf{x}) = \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\) et la fonction de valeur à l’équilibre est liée à la divergence de Jensen-Shannon entre les deux distributions.

  3. Le DCGAN introduit des directives architecturales (convolutions transposées, batch normalization, LeakyReLU) qui stabilisent l’entraînement sur des images et produisent des résultats de qualité avec des architectures relativement simples.

  4. L’entraînement des GAN souffre de problèmes majeurs : effondrement de mode, gradients évanescents pour le générateur et oscillations. Des techniques comme le label smoothing, la normalisation spectrale et la croissance progressive permettent de les atténuer.

  5. Le WGAN remplace la divergence de Jensen-Shannon par la distance de Wasserstein, fournissant un signal de gradient plus stable et une perte corrélée à la qualité des échantillons. Le WGAN-GP impose la contrainte de Lipschitz par pénalisation du gradient.

  6. Le GAN conditionnel (cGAN) permet de contrôler la génération en conditionnant les deux réseaux sur une information auxiliaire (classe, texte, image).

  7. Les applications des GAN couvrent la génération d’images, la super-résolution, le transfert de style, l’augmentation de données et bien d’autres domaines.

Remarque 232

Les GAN ont été les modèles génératifs dominants entre 2014 et 2020, produisant des images d’un réalisme sans précédent. Cependant, les modèles de diffusion (diffusion models), introduits par Sohl-Dickstein et al. (2015) et popularisés par Ho et al. (2020), les ont progressivement supplantés pour la génération d’images, grâce à une stabilité d’entraînement supérieure et une meilleure couverture des modes. Les GAN restent néanmoins pertinents pour des applications nécessitant une génération rapide (inférence en un seul passage) et continuent d’inspirer de nouvelles architectures hybrides combinant les avantages des deux approches.