---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.0
kernelspec:
  name: python3
  display_name: Python 3
---

# MCMC et inférence bayésienne computationnelle

> Quand le posterior est trop complexe pour être calculé, on l'échantillonne.

## Introduction : pourquoi le MCMC ?

Le chapitre précédent a montré que l'inférence bayésienne requiert de calculer la distribution posterior :

$$P(\theta \mid \mathbf{y}) = \frac{P(\mathbf{y} \mid \theta) \, P(\theta)}{\int P(\mathbf{y} \mid \theta') \, P(\theta') \, d\theta'}$$

Le problème est le dénominateur : l'**evidence** $P(\mathbf{y})$ est une intégrale en dimension $d$ (le nombre de paramètres). Pour des modèles réalistes avec des priors non-conjugués ou des structures hiérarchiques, cette intégrale est **analytiquement intractable** et le coût de l'intégration numérique croît exponentiellement avec $d$.

Les méthodes **MCMC** (Markov Chain Monte Carlo) contournent ce problème en construisant une chaîne de Markov dont la distribution stationnaire est précisément le posterior souhaité. On n'a pas besoin de normaliser !

```{code-cell} python
:tags: [hide-input]

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import pandas as pd
from scipy import stats
from scipy.stats import norm, multivariate_normal

sns.set_theme(style="whitegrid", palette="muted", font_scale=1.1)

# Illustration : posterior 2D intractable
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1. Prior
ax = axes[0]
mu_grid = np.linspace(-4, 8, 200)
sigma_grid = np.linspace(0.1, 5, 200)
MU, SIGMA = np.meshgrid(mu_grid, sigma_grid)
prior_2d = (norm.pdf(MU, 2, 2) *
            stats.gamma.pdf(SIGMA, 2, scale=1))
im1 = ax.contourf(MU, SIGMA, prior_2d, levels=30, cmap='Blues')
ax.set_xlabel('μ'); ax.set_ylabel('σ')
ax.set_title('Prior joint P(μ, σ)')
plt.colorbar(im1, ax=ax)

# 2. Vraisemblance
np.random.seed(42)
data_ex = np.random.normal(3.5, 1.5, 20)
ax = axes[1]
like_2d = np.array([
    [np.sum(norm.logpdf(data_ex, mu, sigma)) for mu in mu_grid]
    for sigma in sigma_grid
])
like_2d_exp = np.exp(like_2d - like_2d.max())
im2 = ax.contourf(MU, SIGMA, like_2d_exp, levels=30, cmap='Reds')
ax.set_xlabel('μ'); ax.set_ylabel('σ')
ax.set_title('Vraisemblance L(μ, σ | données)')
plt.colorbar(im2, ax=ax)

# 3. Posterior (non normalisé)
ax = axes[2]
post_2d = like_2d_exp * prior_2d
post_2d /= post_2d.max()
im3 = ax.contourf(MU, SIGMA, post_2d, levels=30, cmap='Greens')
ax.set_xlabel('μ'); ax.set_ylabel('σ')
ax.set_title('Posterior P(μ, σ | données)\n(non normalisé)')
plt.colorbar(im3, ax=ax)

plt.suptitle('Inférence bayésienne 2D — Illustration du problème de normalisation',
             fontsize=12, fontweight='bold', y=1.01)
plt.tight_layout()
plt.savefig('_static/14_posterior_2d.png', dpi=120, bbox_inches='tight')
plt.show()
```

## Chaînes de Markov : rappels

Une **chaîne de Markov** est une suite de variables aléatoires $\theta^{(1)}, \theta^{(2)}, \ldots$ telle que l'état suivant ne dépend que de l'état courant (propriété de Markov) :

$$P(\theta^{(t+1)} \mid \theta^{(t)}, \theta^{(t-1)}, \ldots) = P(\theta^{(t+1)} \mid \theta^{(t)})$$

Pour que la chaîne converge vers le posterior cible $\pi(\theta)$, elle doit être :

- **Irréductible** : tout état est accessible depuis tout autre état
- **Apériodique** : la chaîne ne s'enferme pas dans des cycles
- **Réversible** : condition de bilan détaillé $\pi(\theta) q(\theta \to \theta') = \pi(\theta') q(\theta' \to \theta)$

## Algorithme de Metropolis-Hastings

L'algorithme de **Metropolis-Hastings** (MH) est le fondement de la plupart des méthodes MCMC. Il construit une chaîne en proposant un nouveau point $\theta^*$ et en l'acceptant ou le rejetant selon un ratio.

**Algorithme :**

1. Initialiser $\theta^{(0)}$ arbitrairement
2. Pour $t = 1, 2, \ldots, T$ :
   - Proposer $\theta^* \sim q(\theta^* \mid \theta^{(t-1)})$
   - Calculer le ratio d'acceptation : $\alpha = \min\left(1, \frac{\pi(\theta^*) \, q(\theta^{(t-1)} \mid \theta^*)}{\pi(\theta^{(t-1)}) \, q(\theta^* \mid \theta^{(t-1)})}\right)$
   - Accepter avec probabilité $\alpha$ : $\theta^{(t)} = \theta^*$ ; sinon $\theta^{(t)} = \theta^{(t-1)}$

```{admonition} Astuce clé : travail en log
:class: note

En pratique, le ratio est calculé en log pour éviter les sous-débordements numériques :
$\log \alpha = \log \pi(\theta^*) - \log \pi(\theta^{(t-1)}) + \log q(\theta^{(t-1)}|\theta^*) - \log q(\theta^*|\theta^{(t-1)})$

Pour une marche aléatoire gaussienne (proposition symétrique), les termes $q$ s'annulent.
```

## Implémentation complète from scratch

### Exemple : inférence de la moyenne et variance d'une Normale

```{code-cell} python
def log_prior(mu, log_sigma):
    """Log-prior : mu ~ N(0,10), log(sigma) ~ N(0,1)."""
    lp_mu = norm.logpdf(mu, 0, 10)
    lp_logsigma = norm.logpdf(log_sigma, 0, 1)
    return lp_mu + lp_logsigma

def log_likelihood(data, mu, log_sigma):
    """Log-vraisemblance gaussienne."""
    sigma = np.exp(log_sigma)
    return np.sum(norm.logpdf(data, mu, sigma))

def log_posterior(data, mu, log_sigma):
    """Log-posterior non normalisé."""
    return log_prior(mu, log_sigma) + log_likelihood(data, mu, log_sigma)


def metropolis_hastings(data, n_iter=10000, step_size=0.15, seed=42):
    """
    Metropolis-Hastings avec marche aléatoire gaussienne.
    Paramètres : (mu, log_sigma) pour estimer N(mu, sigma²).
    """
    rng = np.random.default_rng(seed)
    # Initialisation
    theta_curr = np.array([0.0, 0.0])  # (mu, log_sigma)
    lp_curr = log_posterior(data, theta_curr[0], theta_curr[1])

    chain = np.zeros((n_iter, 2))
    accepted = 0

    for t in range(n_iter):
        # Proposition : marche aléatoire gaussienne
        theta_prop = theta_curr + rng.normal(0, step_size, 2)
        lp_prop = log_posterior(data, theta_prop[0], theta_prop[1])

        # Ratio d'acceptation (log-espace)
        log_alpha = lp_prop - lp_curr
        log_u = np.log(rng.uniform())

        if log_u < log_alpha:
            theta_curr = theta_prop
            lp_curr = lp_prop
            accepted += 1

        chain[t] = theta_curr

    taux_accept = accepted / n_iter
    return chain, taux_accept


# Génération des données
np.random.seed(2024)
true_mu, true_sigma = 3.7, 1.4
data_mh = np.random.normal(true_mu, true_sigma, 50)

print(f"Vraies valeurs : μ = {true_mu}, σ = {true_sigma}")
print(f"Statistiques empiriques : ȳ = {data_mh.mean():.3f}, s = {data_mh.std():.3f}")

# Lancer MCMC
chain, taux = metropolis_hastings(data_mh, n_iter=15000, step_size=0.15)
chain_mu = chain[:, 0]
chain_sigma = np.exp(chain[:, 1])  # retour en espace σ

print(f"\nTaux d'acceptation : {taux:.1%}")
print(f"(Idéal : 20-40% pour marche aléatoire)")
```

```{code-cell} python
# Résultats après burn-in
burn_in = 3000
chain_mu_post = chain_mu[burn_in:]
chain_sigma_post = chain_sigma[burn_in:]

print("Inférence bayésienne (MCMC) :")
print(f"  μ — Moyenne : {chain_mu_post.mean():.3f}  "
      f"IC 95% : [{np.percentile(chain_mu_post, 2.5):.3f}, {np.percentile(chain_mu_post, 97.5):.3f}]")
print(f"  σ — Moyenne : {chain_sigma_post.mean():.3f}  "
      f"IC 95% : [{np.percentile(chain_sigma_post, 2.5):.3f}, {np.percentile(chain_sigma_post, 97.5):.3f}]")
print(f"\nVraies valeurs : μ = {true_mu}, σ = {true_sigma}")
```

```{code-cell} python
:tags: [hide-input]

fig = plt.figure(figsize=(15, 10))
gs = gridspec.GridSpec(3, 3, figure=fig, hspace=0.45, wspace=0.35)

# Trace plots
ax1 = fig.add_subplot(gs[0, :2])
ax1.plot(chain_mu, color='steelblue', lw=0.6, alpha=0.8)
ax1.axvline(burn_in, color='tomato', ls='--', lw=2, label=f'Burn-in ({burn_in})')
ax1.axhline(true_mu, color='green', ls=':', lw=1.5, label=f'Vraie valeur μ={true_mu}')
ax1.set_xlabel('Itération'); ax1.set_ylabel('μ')
ax1.set_title('Trace plot — μ')
ax1.legend(fontsize=8)

ax2 = fig.add_subplot(gs[1, :2])
ax2.plot(chain_sigma, color='darkorange', lw=0.6, alpha=0.8)
ax2.axvline(burn_in, color='tomato', ls='--', lw=2, label=f'Burn-in ({burn_in})')
ax2.axhline(true_sigma, color='green', ls=':', lw=1.5, label=f'Vraie valeur σ={true_sigma}')
ax2.set_xlabel('Itération'); ax2.set_ylabel('σ')
ax2.set_title('Trace plot — σ')
ax2.legend(fontsize=8)

# Posterior μ
ax3 = fig.add_subplot(gs[0, 2])
ax3.hist(chain_mu_post, bins=60, density=True, color='steelblue',
         alpha=0.7, edgecolor='white')
mu_range = np.linspace(chain_mu_post.min(), chain_mu_post.max(), 300)
# Prior pour comparaison
ax3.plot(mu_range, norm.pdf(mu_range, 0, 10), 'gray', lw=1.5,
         ls='--', alpha=0.7, label='Prior N(0,10)')
ax3.axvline(true_mu, color='green', lw=2, label=f'Vraie val. {true_mu}')
ax3.axvline(chain_mu_post.mean(), color='black', lw=1.5, ls='--',
            label=f'Post. moy. {chain_mu_post.mean():.2f}')
ax3.set_title('Posterior μ'); ax3.legend(fontsize=7)

# Posterior σ
ax4 = fig.add_subplot(gs[1, 2])
ax4.hist(chain_sigma_post, bins=60, density=True, color='darkorange',
         alpha=0.7, edgecolor='white')
ax4.axvline(true_sigma, color='green', lw=2, label=f'Vraie val. {true_sigma}')
ax4.axvline(chain_sigma_post.mean(), color='black', lw=1.5, ls='--',
            label=f'Post. moy. {chain_sigma_post.mean():.2f}')
ax4.set_title('Posterior σ'); ax4.legend(fontsize=7)

# Autocorrélation μ
ax5 = fig.add_subplot(gs[2, 0])
lags = range(0, 60)
autocorr = [np.corrcoef(chain_mu_post[:-l], chain_mu_post[l:])[0, 1]
            if l > 0 else 1.0 for l in lags]
ax5.bar(list(lags), autocorr, color='steelblue', alpha=0.7, width=0.8)
ax5.axhline(0, color='black', lw=0.8)
ax5.axhline(1.96/np.sqrt(len(chain_mu_post)), color='tomato', ls='--',
            lw=1.2, label='IC 95%')
ax5.axhline(-1.96/np.sqrt(len(chain_mu_post)), color='tomato', ls='--', lw=1.2)
ax5.set_xlabel('Lag'); ax5.set_ylabel('Autocorrélation')
ax5.set_title('Autocorrélation — μ'); ax5.legend(fontsize=7)

# Autocorrélation σ
ax6 = fig.add_subplot(gs[2, 1])
autocorr_s = [np.corrcoef(chain_sigma_post[:-l], chain_sigma_post[l:])[0, 1]
              if l > 0 else 1.0 for l in lags]
ax6.bar(list(lags), autocorr_s, color='darkorange', alpha=0.7, width=0.8)
ax6.axhline(0, color='black', lw=0.8)
ax6.axhline(1.96/np.sqrt(len(chain_sigma_post)), color='tomato', ls='--', lw=1.2)
ax6.axhline(-1.96/np.sqrt(len(chain_sigma_post)), color='tomato', ls='--', lw=1.2)
ax6.set_xlabel('Lag'); ax6.set_ylabel('Autocorrélation')
ax6.set_title('Autocorrélation — σ')

# Joint posterior
ax7 = fig.add_subplot(gs[2, 2])
ax7.scatter(chain_mu_post[::5], chain_sigma_post[::5],
            alpha=0.15, color='steelblue', s=8)
ax7.scatter(true_mu, true_sigma, color='tomato', s=150, marker='*',
            zorder=10, label=f'Vraie val. ({true_mu},{true_sigma})')
ax7.set_xlabel('μ'); ax7.set_ylabel('σ')
ax7.set_title('Joint posterior (μ, σ)')
ax7.legend(fontsize=8)

plt.savefig('_static/14_mcmc_diag.png', dpi=120, bbox_inches='tight')
plt.show()
```

## Diagnostics MCMC

### Taux d'acceptation

```{code-cell} python
# Effet de la taille de pas sur le taux d'acceptation
step_sizes = [0.01, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0]
results_steps = []

for step in step_sizes:
    chain_test, taux_test = metropolis_hastings(data_mh, n_iter=5000, step_size=step, seed=1)
    chain_test_post = chain_test[1000:, 0]
    results_steps.append({
        'step_size': step,
        'taux_acceptation': taux_test,
        'ESS_approx': len(chain_test_post) / (1 + 2 * sum(
            [np.corrcoef(chain_test_post[:-l], chain_test_post[l:])[0, 1]
             for l in range(1, min(50, len(chain_test_post)//4))]
        ))
    })

df_steps = pd.DataFrame(results_steps)
print("Effet de la taille de pas :")
print(df_steps.to_string(index=False))
print("\n→ Taux optimal : 20-40% (balance exploration/exploitation)")
```

### R-hat (Gelman-Rubin)

Le diagnostic de **Gelman-Rubin** ($\hat{R}$) compare la variance entre plusieurs chaînes et la variance intra-chaîne. Une valeur proche de 1 indique la convergence.

```{code-cell} python
def gelman_rubin(chains):
    """
    Calcule R-hat de Gelman-Rubin pour plusieurs chaînes.
    chains : liste de tableaux 1D de même longueur.
    """
    M = len(chains)    # nombre de chaînes
    N = len(chains[0]) # longueur (après burn-in)

    chain_means = np.array([c.mean() for c in chains])
    overall_mean = chain_means.mean()

    # Variance entre chaînes
    B = N / (M - 1) * np.sum((chain_means - overall_mean) ** 2)
    # Variance intra-chaînes
    W = np.mean([c.var(ddof=1) for c in chains])

    # Variance marginale estimée
    var_hat = (1 - 1/N) * W + B/N
    R_hat = np.sqrt(var_hat / W)
    return R_hat

# Lancer 4 chaînes avec initialisations différentes
chains_mu = []
for seed_i, init_mu in zip([42, 123, 456, 789], [-3, 0, 5, 8]):
    def mh_init(data, init, n_iter=8000, step_size=0.2, seed=42):
        rng = np.random.default_rng(seed)
        theta_curr = np.array([init, 0.0])
        lp_curr = log_posterior(data, theta_curr[0], theta_curr[1])
        chain = np.zeros((n_iter, 2))
        for t in range(n_iter):
            theta_prop = theta_curr + rng.normal(0, step_size, 2)
            lp_prop = log_posterior(data, theta_prop[0], theta_prop[1])
            if np.log(rng.uniform()) < lp_prop - lp_curr:
                theta_curr = theta_prop
                lp_curr = lp_prop
            chain[t] = theta_curr
        return chain

    ch = mh_init(data_mh, init_mu, seed=seed_i)
    chains_mu.append(ch[2000:, 0])  # après burn-in

R_hat = gelman_rubin(chains_mu)
print(f"R-hat (Gelman-Rubin) pour μ : {R_hat:.4f}")
print(f"→ R-hat < 1.01 : convergence excellente")
print(f"→ R-hat < 1.05 : convergence acceptable")
print(f"→ R-hat > 1.1  : chaîne non convergée — augmenter n_iter")
```

```{code-cell} python
:tags: [hide-input]

# Visualisation des 4 chaînes avec R-hat
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors_chains = ['steelblue', 'tomato', 'darkgreen', 'orange']
labels_chains = ['Chaîne 1 (init=-3)', 'Chaîne 2 (init=0)',
                 'Chaîne 3 (init=5)', 'Chaîne 4 (init=8)']

# Générer les chaînes complètes pour visualisation
full_chains = []
for seed_i, init_mu in zip([42, 123, 456, 789], [-3, 0, 5, 8]):
    ch_full = mh_init(data_mh, init_mu, n_iter=5000, step_size=0.2, seed=seed_i)
    full_chains.append(ch_full[:, 0])

ax = axes[0]
for ch, color, label in zip(full_chains, colors_chains, labels_chains):
    ax.plot(ch, color=color, lw=0.8, alpha=0.8, label=label)
ax.axvline(2000, color='black', ls='--', lw=1.5, label='Burn-in (2000)')
ax.axhline(true_mu, color='gray', ls=':', lw=1.5, label=f'Vraie valeur μ={true_mu}')
ax.set_xlabel('Itération'); ax.set_ylabel('μ')
ax.set_title(f'Trace plots — 4 chaînes parallèles\nR-hat = {R_hat:.4f}')
ax.legend(fontsize=7, loc='upper right')

# Distributions des posteriors par chaîne (après burn-in)
ax = axes[1]
for ch, color, label in zip(chains_mu, colors_chains, labels_chains):
    from scipy.stats import gaussian_kde
    kde = gaussian_kde(ch)
    x_range = np.linspace(ch.min(), ch.max(), 300)
    ax.plot(x_range, kde(x_range), color=color, lw=2, alpha=0.8, label=label)
ax.axvline(true_mu, color='black', ls='--', lw=1.5, label=f'Vraie valeur {true_mu}')
ax.set_xlabel('μ'); ax.set_ylabel('Densité estimée')
ax.set_title('Posteriors des 4 chaînes\n(distributions superposées = convergence)')
ax.legend(fontsize=7)

plt.tight_layout()
plt.savefig('_static/14_rhat.png', dpi=120, bbox_inches='tight')
plt.show()
```

### Taille effective d'échantillon (ESS)

L'autocorrélation réduit l'information effective de la chaîne. L'ESS (*Effective Sample Size*) estime le nombre d'échantillons indépendants équivalents :

$$\text{ESS} \approx \frac{N}{1 + 2\sum_{k=1}^{\infty} \rho_k}$$

```{code-cell} python
def ess(chain, max_lag=200):
    """Calcule l'ESS par la somme des autocorrélations."""
    n = len(chain)
    rho_sum = 0
    for lag in range(1, min(max_lag, n//4)):
        rho = np.corrcoef(chain[:-lag], chain[lag:])[0, 1]
        if rho < 0.05:  # stop quand autocorr négligeable
            break
        rho_sum += rho
    return n / (1 + 2 * rho_sum)

ess_mu = ess(chain_mu_post)
ess_sigma = ess(chain_sigma_post)
n_post = len(chain_mu_post)

print(f"Taille de la chaîne (après burn-in) : {n_post}")
print(f"ESS pour μ : {ess_mu:.0f} ({ess_mu/n_post:.1%} d'efficacité)")
print(f"ESS pour σ : {ess_sigma:.0f} ({ess_sigma/n_post:.1%} d'efficacité)")
print(f"\n→ ESS > 400 est généralement suffisant pour estimer la moyenne et l'IC 95%.")
```

## Burn-in et thinning

```{code-cell} python
:tags: [hide-input]

fig, axes = plt.subplots(2, 2, figsize=(13, 8))

# Burn-in
ax = axes[0, 0]
ax.plot(chain_mu[:5000], color='steelblue', lw=0.8, alpha=0.8)
ax.axvspan(0, burn_in, alpha=0.12, color='tomato', label=f'Burn-in ({burn_in} iter.)')
ax.axhline(true_mu, color='green', ls=':', lw=1.5, label=f'Vraie valeur {true_mu}')
ax.set_xlabel('Itération'); ax.set_ylabel('μ')
ax.set_title('Burn-in : la chaîne "chauffe" depuis μ=0')
ax.legend(fontsize=8)

# Thinning
ax = axes[0, 1]
# Comparaison avec et sans thinning
chain_full = chain_mu_post
chain_thin = chain_mu_post[::10]  # thinning x10

for i, (c, label, color) in enumerate(zip(
        [chain_full[:200], chain_thin[:200]],
        ['Sans thinning (200 iter.)', 'Avec thinning x10 (200 iter. = 2000 total)'],
        ['steelblue', 'tomato'])):
    ax.plot(c, color=color, lw=1.2, alpha=0.8, label=label)
ax.set_xlabel('Index'); ax.set_ylabel('μ')
ax.set_title('Thinning : réduire l\'autocorrélation')
ax.legend(fontsize=8)

# Autocorrélation avant et après thinning
ax = axes[1, 0]
lags_ac = range(0, 80)
ac_full = [np.corrcoef(chain_full[:-l], chain_full[l:])[0, 1]
           if l > 0 else 1.0 for l in lags_ac]
ac_thin = [np.corrcoef(chain_thin[:-l], chain_thin[l:])[0, 1]
           if l > 0 and l < len(chain_thin)//2 else 1.0 for l in lags_ac]

ax.plot(list(lags_ac), ac_full, 'steelblue', lw=1.5, label='Sans thinning')
ax.plot(list(lags_ac), ac_thin, 'tomato', lw=1.5, label='Avec thinning x10')
ax.axhline(0, color='black', lw=0.8)
ax.axhline(1.96/np.sqrt(len(chain_full)), color='gray', ls='--',
           lw=1, label='Seuil IC 95%')
ax.set_xlabel('Lag'); ax.set_ylabel('Autocorrélation')
ax.set_title('Effet du thinning sur l\'autocorrélation')
ax.legend(fontsize=8)

# Taux d'acceptation moyen par fenêtre
ax = axes[1, 1]
window_size = 500
n_windows = len(chain_mu) // window_size
acceptance_windows = []
for i in range(n_windows):
    window = chain_mu[i*window_size:(i+1)*window_size]
    n_unique = len(np.unique(window.round(8)))
    acceptance_windows.append(n_unique / window_size)

ax.plot(range(n_windows), acceptance_windows, 'steelblue', lw=1.5)
ax.axhline(0.234, color='tomato', ls='--', lw=1.5,
           label='Optimal théorique (23.4%)')
ax.axhline(0.44, color='orange', ls=':', lw=1.5,
           label='Optimal 1D (44%)')
ax.axvline(burn_in // window_size, color='gray', ls=':', lw=1.5)
ax.set_xlabel(f'Fenêtre (×{window_size} iter.)')
ax.set_ylabel('Taux d\'acceptation')
ax.set_title('Taux d\'acceptation par fenêtre')
ax.legend(fontsize=8)

plt.tight_layout()
plt.savefig('_static/14_burnin_thinning.png', dpi=120, bbox_inches='tight')
plt.show()
```

## Gibbs Sampling

Le **Gibbs sampling** est un cas particulier de MH où l'on tire chaque paramètre depuis sa distribution conditionnelle complète, en fixant tous les autres. Taux d'acceptation = 100%.

```{code-cell} python
def gibbs_sampler_normal(data, n_iter=8000, seed=42):
    """
    Gibbs sampler pour N(mu, sigma²) avec priors conjugués :
    mu | sigma² ~ N(mu0, tau0²)
    sigma² ~ InvGamma(a0, b0)
    """
    rng = np.random.default_rng(seed)
    n = len(data)
    y_bar = np.mean(data)

    # Hyperparamètres des priors
    mu0, tau0_sq = 0.0, 100.0  # prior sur mu
    a0, b0 = 0.1, 0.1           # prior sur sigma² (quasi non-informatif)

    # Initialisation
    mu_curr = y_bar
    sigma2_curr = np.var(data, ddof=1)

    chain_gibbs = np.zeros((n_iter, 2))

    for t in range(n_iter):
        # 1. Tirer mu | sigma², données
        tau_n_sq = 1 / (1/tau0_sq + n/sigma2_curr)
        mu_n = tau_n_sq * (mu0/tau0_sq + n*y_bar/sigma2_curr)
        mu_curr = rng.normal(mu_n, np.sqrt(tau_n_sq))

        # 2. Tirer sigma² | mu, données
        a_n = a0 + n/2
        b_n = b0 + 0.5 * np.sum((data - mu_curr)**2)
        sigma2_curr = 1 / rng.gamma(a_n, 1/b_n)  # InvGamma via Gamma

        chain_gibbs[t] = [mu_curr, np.sqrt(sigma2_curr)]

    return chain_gibbs

chain_gibbs = gibbs_sampler_normal(data_mh, n_iter=8000)
burn_gibbs = 2000
chain_g_mu = chain_gibbs[burn_gibbs:, 0]
chain_g_sigma = chain_gibbs[burn_gibbs:, 1]

print("Résultats Gibbs Sampling :")
print(f"  μ : {chain_g_mu.mean():.3f}  "
      f"IC 95% : [{np.percentile(chain_g_mu, 2.5):.3f}, {np.percentile(chain_g_mu, 97.5):.3f}]")
print(f"  σ : {chain_g_sigma.mean():.3f}  "
      f"IC 95% : [{np.percentile(chain_g_sigma, 2.5):.3f}, {np.percentile(chain_g_sigma, 97.5):.3f}]")
print(f"\nVraies valeurs : μ = {true_mu}, σ = {true_sigma}")
print(f"ESS μ (Gibbs) : {ess(chain_g_mu):.0f}")
print(f"ESS μ (MH)    : {ess(chain_mu_post):.0f}")
```

```{code-cell} python
:tags: [hide-input]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Joint posterior MH
ax = axes[0]
ax.scatter(chain_mu_post[::5], chain_sigma[burn_in::5],
           alpha=0.12, color='steelblue', s=8)
ax.scatter(true_mu, true_sigma, color='tomato', s=200, marker='*', zorder=10)
ax.set_xlabel('μ'); ax.set_ylabel('σ')
ax.set_title(f'Joint posterior — Metropolis-Hastings\n(ESS μ ≈ {ess(chain_mu_post):.0f})')

# Joint posterior Gibbs
ax = axes[1]
ax.scatter(chain_g_mu[::5], chain_g_sigma[::5],
           alpha=0.12, color='darkgreen', s=8)
ax.scatter(true_mu, true_sigma, color='tomato', s=200, marker='*', zorder=10)
ax.set_xlabel('μ'); ax.set_ylabel('σ')
ax.set_title(f'Joint posterior — Gibbs Sampling\n(ESS μ ≈ {ess(chain_g_mu):.0f})')

# Comparaison marginales
ax = axes[2]
from scipy.stats import gaussian_kde
kde_mh = gaussian_kde(chain_mu_post)
kde_gb = gaussian_kde(chain_g_mu)
x_range = np.linspace(2.5, 5.0, 300)
ax.plot(x_range, kde_mh(x_range), 'steelblue', lw=2.5, label='MH')
ax.plot(x_range, kde_gb(x_range), 'darkgreen', lw=2.5, ls='--', label='Gibbs')
ax.axvline(true_mu, color='tomato', ls=':', lw=2, label=f'Vraie valeur {true_mu}')
ax.set_xlabel('μ'); ax.set_ylabel('Densité')
ax.set_title('Posterior marginal μ\nMH vs Gibbs (résultats identiques)')
ax.legend(fontsize=9)

plt.tight_layout()
plt.savefig('_static/14_mh_vs_gibbs.png', dpi=120, bbox_inches='tight')
plt.show()
```

## Mentions : PyMC et approximation variationnelle

```{admonition} PyMC (non installé dans cet environnement)
:class: note

[PyMC](https://www.pymc.io) est la bibliothèque Python de référence pour la modélisation bayésienne probabiliste. Elle fournit une interface de haut niveau pour spécifier des modèles et utilise des algorithmes avancés comme **NUTS** (No-U-Turn Sampler, une variante adaptative de HMC).

Exemple de syntaxe PyMC pour le même modèle :

```python
import pymc as pm
with pm.Model() as model:
    mu = pm.Normal('mu', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=1)
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=data)
    trace = pm.sample(2000, tune=1000, chains=4, return_inferencedata=True)
    pm.plot_trace(trace)
```

```{admonition} Approximation variationnelle (VI)
:class: note

L'**inférence variationnelle** remplace l'échantillonnage par un problème d'optimisation : on cherche la distribution $q(\theta)$ dans une famille paramétrique (ex : Gaussienne) qui minimise la divergence KL avec le posterior. Plus rapide que MCMC mais l'approximation peut être inexacte pour des posteriors complexes ou multimodaux. Implémentée dans PyMC (`pm.fit()`), TensorFlow Probability, et Pyro.
```

## Résumé

```{admonition} Points clés — MCMC
:class: note

- Le MCMC permet d'échantillonner des posteriors sans calculer la constante de normalisation.
- **Metropolis-Hastings** : algorithme général, taux d'acceptation cible 20-40% pour une marche aléatoire.
- **Gibbs Sampling** : efficace quand les conditionnelles complètes sont disponibles, taux d'acceptation 100%.
- Diagnostics indispensables : trace plots (chaîne mixée ?), R-hat < 1.01 (convergence ?), ESS > 400 (précision ?), autocorrélation (nécessité de thinning ?).
- Burn-in : jeter les premières itérations pendant la phase de "chauffe".
- En pratique, utiliser PyMC ou Stan qui implémentent des algorithmes plus efficaces (NUTS/HMC).
```
