
Meta-Learning

Learning from a distribution over tasks — three lenses (gradient, Bayesian, metric) unified via hierarchical Bayes, with PAC-Bayes meta-bounds and MAML convergence guarantees

§1. Motivation and the task-distribution structure

1.1 Few-shot learning and what gradient descent from scratch can’t do

Suppose we hand a freshly initialized neural network five labeled examples, just five, and ask it to fit a sinusoid $y = A\sin(x + \phi)$ for some unknown amplitude $A$ and phase $\phi$. SGD on mean-squared error will memorize the five points and produce a curve that disagrees with the true sinusoid everywhere else. This is the small-sample regime in which standard supervised learning offers nothing useful.

But the framing was incomplete. We don’t have only five examples — we have access to many tasks of this form, each drawn from the same family of sinusoids with random amplitude and phase. If we can learn the structure of the family, then for a new task we only need to identify which member of the family we’re looking at, and five points are plenty to do that.

This is the central reframing of meta-learning: when individual tasks come with too few labels for from-scratch supervised learning to succeed, we transfer the inductive bias from the task distribution. The single-task learning algorithm doesn’t change shape — it’s still gradient descent, or it’s still Bayesian inference, or it’s still nearest-prototype assignment — but it now runs in a regime where most of the structure has been preconfigured by meta-training over the task distribution.

The geometric picture, which we’ll return to in every section: a single task’s loss surface over parameters $\theta$ may have many minima; some are easy to reach from a generic initialization, some aren’t. If we knew a region of $\theta$-space that contained an easy-to-reach minimum for every task in our family, we could initialize there. That initialization isn’t task-specific; it’s a property of the task distribution. Meta-learning is the search for that initialization, or its Bayesian analog, or its embedding-space analog.

Figure 1. A randomly initialized MLP trained from scratch for 200 SGD steps on five sinusoidal datapoints. The fit passes through the support points and disagrees with the truth everywhere else.

1.2 The task distribution and the support/query split

Let $p(\mathcal{T})$ denote a distribution over tasks. Each task $\mathcal{T}$ is a pair $(\mathcal{D}^{\mathrm{S}}, \mathcal{D}^{\mathrm{Q}})$, where $\mathcal{D}^{\mathrm{S}} = \{(\mathbf{x}_i^{\mathrm{S}}, y_i^{\mathrm{S}})\}_{i=1}^{K}$ is the support set and $\mathcal{D}^{\mathrm{Q}} = \{(\mathbf{x}_j^{\mathrm{Q}}, y_j^{\mathrm{Q}})\}_{j=1}^{M}$ is the query set. Both are drawn from the task’s own data-generating distribution $p(\mathbf{x}, y \mid \mathcal{T})$.

Notation: the script $\mathcal{T}$ is a single task, i.e. a pair of datasets plus an implicit private distribution. The roman superscripts S and Q label support and query. $K$ is the shot count (usually small: 1, 5, or 10), and $M$ is large enough that the empirical query loss is a low-variance estimate of the true task loss.

The meta-objective is the expected query loss after adapting on the support set:

$$\mathcal{L}_{\mathrm{meta}}(\theta_0) \;=\; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})} \!\left[\, \mathcal{L}_{\mathcal{T}}^{\mathrm{Q}}\!\big(\,\mathcal{A}(\theta_0, \mathcal{D}_{\mathcal{T}}^{\mathrm{S}})\,\big) \,\right] \tag{1.1}$$

where $\mathcal{A}(\theta_0, \mathcal{D}^{\mathrm{S}})$ is the adaptation procedure: the algorithm that takes the meta-parameter $\theta_0$ and the support set and returns task-specific parameters. Different choices of $\mathcal{A}$ give different meta-learning lenses. The support/query split is the meta-learning analog of the train/test split in ordinary supervised learning.
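A minimal Monte Carlo sketch of (1.1) follows, under stated assumptions: a toy sinusoid family that is exactly linear in the features (sin x, cos x), inputs drawn uniformly from [-5, 5], and an adaptation procedure $\mathcal{A}$ made of a few plain gradient steps on the support MSE. All helper names (features, sample_task, adapt, meta_objective) are hypothetical and not taken from the notebook.

```python
import numpy as np


def features(x):
    # Assumed toy feature map. Since A*sin(x + phi) = A*cos(phi)*sin(x) + A*sin(phi)*cos(x),
    # every task in this family is exactly linear in (sin x, cos x).
    return np.stack([np.sin(x), np.cos(x)], axis=-1)


def sample_task(rng, K=5, M=50):
    # One task: amplitude and phase as in the sinusoid demo; inputs on [-5, 5] (assumed).
    A = rng.uniform(0.1, 5.0)
    phi = rng.uniform(0.0, 2.0 * np.pi)
    x = rng.uniform(-5.0, 5.0, size=K + M)
    y = A * np.sin(x + phi)
    return (x[:K], y[:K]), (x[K:], y[K:])  # (support set, query set)


def mse(theta, x, y):
    return float(np.mean((features(x) @ theta - y) ** 2))


def adapt(theta0, support, alpha=0.1, n_steps=5):
    # Hypothetical adaptation procedure A(theta0, D^S): plain gradient steps on the support MSE.
    x, y = support
    phi_x = features(x)
    theta = theta0.copy()
    for _ in range(n_steps):
        grad = 2.0 * phi_x.T @ (phi_x @ theta - y) / len(y)
        theta = theta - alpha * grad
    return theta


def meta_objective(theta0, rng, n_tasks=2000, **adapt_kwargs):
    # Monte Carlo estimate of (1.1): average query loss after adapting on each task's support set.
    losses = []
    for _ in range(n_tasks):
        support, query = sample_task(rng)
        losses.append(mse(adapt(theta0, support, **adapt_kwargs), *query))
    return float(np.mean(losses))
```

In this toy family the meta-objective is minimized at $\theta_0 = 0$, the mean of the task parameters $(A\cos\phi, A\sin\phi)$; a few adaptation steps from there lower the query loss further than no adaptation, and moving $\theta_0$ far from the family mean raises it.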

1.3 Three lenses on one problem

Gradient-based meta-learning (MAML, Finn–Abbeel–Levine 2017). The adaptation procedure runs a few steps of SGD on the support loss, starting from $\theta_0$. The meta-parameter $\theta_0$ is the initialization from which task-specific gradient descent works well in expectation over $p(\mathcal{T})$. The outer loop differentiates the query loss with respect to $\theta_0$, threading the gradient back through the inner SGD trajectory. This is a bilevel optimization whose outer step requires second-order gradient information; Section §2 derives the bilevel objective in full.

Bayesian meta-learning (Neural Processes, Garnelo et al. 2018). The adaptation procedure is amortized posterior inference: an encoder network reads the support set and emits parameters of a posterior over task-specific predictions. The meta-parameters are the weights of the encoder–decoder pair. Tasks are draws from a stochastic process, and the inference network learns to do approximate Bayesian regression over that process. Section §4 derives the conditional and latent variants.

Metric-learning meta-learning (Prototypical Networks, Snell–Swersky–Zemel 2017). The adaptation procedure computes class prototypes by averaging support-set embeddings under a meta-learned feature map $\phi_{\theta_0}(\mathbf{x})$, then classifies query points by nearest prototype. There’s no inner-loop optimization (adaptation reduces to a few embedding evaluations), but the embedding network is meta-trained over task batches via cross-entropy on the query set. Section §5 develops this view.
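The adaptation step of this lens fits in a few lines. Below is a NumPy sketch that takes the embeddings $\phi_{\theta_0}(\mathbf{x})$ as already computed and scores queries by a softmax over negative squared Euclidean distances to the prototypes; the function names are hypothetical, and training the embedding network, which is where the meta-learning happens, is omitted.

```python
import numpy as np


def prototypes(support_emb, support_labels, n_way):
    # c_k: mean embedding of the support points labelled k.
    return np.stack([support_emb[support_labels == k].mean(axis=0) for k in range(n_way)])


def classify(query_emb, protos):
    # Softmax over negative squared Euclidean distances to each prototype.
    d2 = ((query_emb[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)  # (n_query, n_way)
    logits = -d2
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=1, keepdims=True)
    return probs.argmax(axis=1), probs
```

With an identity map standing in for $\phi_{\theta_0}$ and well-separated class means, nearest-prototype assignment is already near-perfect; the meta-learned embedding’s job is to make real task families look like that.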

These three look different in implementation, but they share a deep common structure: each is a posterior approximation in a hierarchical-Bayes model where group membership corresponds to task identity. The hyperparameter $\theta_0$ plays the role of the group-level prior parameter in all three lenses. Section §6 makes that statement rigorous.

1.4 The three running demos at a glance

The notebook validates each lens on a small synthetic problem, chosen to match the canonical demonstration in the original paper, scaled to run on CPU within the 60-second budget.

The first demo is sinusoidal regression for the gradient lens. Tasks are $y = A\sin(x + \phi)$ with $(A, \phi) \sim \mathrm{Uniform}([0.1, 5] \times [0, 2\pi])$ and support size $K = 5$. MAML learns an initialization from which one or a few SGD steps recover the correct sinusoid. This reproduces Finn–Abbeel–Levine’s Figure 2 with the architecture they originally used: two hidden layers of 40 ReLU units each.

The second is 1D GP few-shot regression for the Bayesian lens. Each task is a draw from a Gaussian process with an RBF kernel of randomized lengthscale $\ell$. A Neural Process is trained to perform amortized regression on context-target splits of each task. The GP backbone reuses the kernel-evaluation and Cholesky machinery from formalML’s Gaussian Processes notebook.
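Sampling one task from this family takes a kernel evaluation and a Cholesky factor. A NumPy sketch, assuming the lengthscale range $\ell \sim \mathrm{Uniform}(0.5, 2.0)$ used in §4.5, an input range of [-2, 2], and a small diagonal jitter for numerical stability; the input range, jitter, and function names are illustrative choices, not the notebook’s.

```python
import numpy as np


def rbf_kernel(x1, x2, lengthscale, variance=1.0):
    # k(x, x') = variance * exp(-(x - x')^2 / (2 * lengthscale^2))
    d2 = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * d2 / lengthscale**2)


def sample_gp_task(rng, n_points=50, ell_range=(0.5, 2.0), x_range=(-2.0, 2.0), jitter=1e-6):
    # One task = one function drawn from a GP prior with a task-specific lengthscale.
    ell = rng.uniform(*ell_range)
    x = np.sort(rng.uniform(*x_range, size=n_points))
    K = rbf_kernel(x, x, ell) + jitter * np.eye(n_points)  # jitter keeps the Cholesky stable
    L = np.linalg.cholesky(K)
    f = L @ rng.standard_normal(n_points)  # f ~ N(0, K)
    return x, f, ell
```

Context and target sets are then random subsets of the sampled (x, f) pairs.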

The third is 2D synthetic prototypical clustering for the metric lens. Five-way classification with class means drawn uniformly on a 2D disk, five support examples per class, fifteen query examples per class. A prototypical network with a two-layer MLP embedding learns to classify query points by nearest prototype in embedding space. This replaces Omniglot for budget reasons.

Figures 2–4 show sample tasks from each of the three families.

Figure 2. Sinusoidal task family: four tasks with K=5 support each. Each task is individually under-determined; the family shares structure.
Figure 3. GP task family: three samples at each of four RBF lengthscales. Different lengthscales produce qualitatively different functions; meta-learning extracts the shared kernel-family structure.
Figure 4. 2D prototypical task family: two tasks with distinct class means. Class identity changes from task to task; the embedding network is invariant.

1.5 Roadmap

Sections §2 and §3 develop the gradient lens. §2 derives MAML in full, including the second-order Hessian-vector product. §3 covers the cheaper first-order approximations — FOMAML, Reptile, and Implicit MAML. Section §4 develops the Bayesian lens via Neural Processes. Section §5 develops the metric lens via Prototypical Networks.

Section §6 unifies the three lenses via hierarchical Bayes (Grant et al.’s recast of MAML, empirical Bayes, and the partial-pooling story from formalStatistics: hierarchical-bayes-and-partial-pooling §28). Sections §7 and §8 are the theory: §7 lifts McAllester’s PAC-Bayes bound (from the PAC-Bayes Bounds topic) to the task-distribution level (the Amit–Meir bound), and §8 gives the Fallah–Mokhtari–Ozdaglar convergence theorem.

Section §9 runs cross-lens comparisons with sample-efficiency curves and ablations. Section §10 collects PyTorch idioms readers will hit if they reimplement, and §11 places meta-learning in the broader curriculum.


§2. Gradient-based meta-learning: MAML

2.1 The bilevel meta-objective

Finn, Abbeel, and Levine’s 2017 proposal commits to a specific choice of the adaptation procedure $\mathcal{A}$ from §1.2: $N$ steps of stochastic gradient descent on the support loss, starting from a shared meta-parameter $\theta_0$. Write the inner trajectory as

$$\theta_{\mathcal{T}}^{(0)} = \theta_0, \qquad \theta_{\mathcal{T}}^{(n)} = \theta_{\mathcal{T}}^{(n-1)} - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}}^{\mathrm{S}}\!\left(\theta_{\mathcal{T}}^{(n-1)}\right) \quad\text{for } n = 1,\ldots,N. \tag{2.1}$$

Here $\alpha > 0$ is the inner-loop step size, $\mathcal{L}_\mathcal{T}^{\mathrm{S}}$ is the support loss for task $\mathcal{T}$, and the superscript $(n)$ counts inner steps.

The MAML meta-objective is the expected query loss after running this trajectory:

$$\mathcal{L}_{\mathrm{MAML}}(\theta_0) \;=\; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\!\left[\,\mathcal{L}_{\mathcal{T}}^{\mathrm{Q}}\!\big(\theta_{\mathcal{T}}^{(N)}(\theta_0)\big)\,\right]. \tag{2.2}$$

Optimizing (2.2) over $\theta_0$ is a bilevel problem: the outer level picks the initialization, the inner level runs SGD from that initialization, and the outer level’s gradient must account for how the entire inner trajectory changes when $\theta_0$ shifts.

In practice we Monte-Carlo the outer expectation: at each meta-iteration we draw a task batch $\mathcal{B} = \{\mathcal{T}_1, \ldots, \mathcal{T}_B\}$, run the inner trajectory for each task, and take the empirical mean of the query losses. The outer optimizer is typically Adam with meta-learning rate $\beta$; (2.3) shows the plain-gradient form of the update, and Adam rescales the same averaged meta-gradient:

$$\theta_0 \;\leftarrow\; \theta_0 - \beta \cdot \frac{1}{B}\sum_{k=1}^{B} \nabla_{\theta_0} \mathcal{L}_{\mathcal{T}_k}^{\mathrm{Q}}\!\big(\theta_{\mathcal{T}_k}^{(N)}(\theta_0)\big). \tag{2.3}$$

2.2 Inner loop: task-specific SGD on the support set

The inner loop has nothing exotic in it. The SGD recursion is unchanged from ordinary supervised learning; what changes is what we optimize $\theta_0$ to be. Two notes. First, the inner loop is cheap: $N$ forward and $N$ backward passes per task on the $K$-shot support set, with $N$ typically 1–5 and $K$ typically 5. Second, we do not detach the inner trajectory from the computation graph: every $\theta_{\mathcal{T}}^{(n)}$ remains a differentiable function of $\theta_0$, which the outer loop needs. The original MAML paper uses $N = 1$ for sinusoid regression and $N \in \{1, 5, 10\}$ across its classification benchmarks.

2.3 Outer loop: differentiating through the trajectory

Hold a single task fixed and drop the subscript $\mathcal{T}$. We want $\nabla_{\theta_0} \mathcal{L}^{\mathrm{Q}}(\theta^{(N)}(\theta_0))$.

The chain rule on (2.1) gives the inner Jacobian. Differentiating $\theta^{(n)} = \theta^{(n-1)} - \alpha \nabla \mathcal{L}^{\mathrm{S}}(\theta^{(n-1)})$ with respect to $\theta_0$,

$$\frac{\partial \theta^{(n)}}{\partial \theta_0} \;=\; \Big( I - \alpha \nabla^2 \mathcal{L}^{\mathrm{S}}\big(\theta^{(n-1)}\big) \Big) \frac{\partial \theta^{(n-1)}}{\partial \theta_0}.$$

Iterating from $\partial \theta^{(0)} / \partial \theta_0 = I$, with each new factor multiplying on the left,

$$\frac{\partial \theta^{(N)}}{\partial \theta_0} \;=\; \Big( I - \alpha \nabla^2 \mathcal{L}^{\mathrm{S}}\big(\theta^{(N-1)}\big) \Big) \cdots \Big( I - \alpha \nabla^2 \mathcal{L}^{\mathrm{S}}\big(\theta^{(0)}\big) \Big). \tag{2.4}$$

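The outer gradient follows by the chain rule on the query loss: it is the transpose of the Jacobian (2.4) applied to $\nabla \mathcal{L}^{\mathrm{Q}}(\theta^{(N)})$, and since every Hessian is symmetric, transposing just reverses the order of the factors:

$$\nabla_{\theta_0} \mathcal{L}^{\mathrm{Q}}\big(\theta^{(N)}(\theta_0)\big) \;=\; \Big( I - \alpha \nabla^2 \mathcal{L}^{\mathrm{S}}\big(\theta^{(0)}\big) \Big) \cdots \Big( I - \alpha \nabla^2 \mathcal{L}^{\mathrm{S}}\big(\theta^{(N-1)}\big) \Big)\, \nabla \mathcal{L}^{\mathrm{Q}}\big(\theta^{(N)}\big). \tag{2.5}$$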

Specializing to $N = 1$:

$$\nabla_{\theta_0} \mathcal{L}^{\mathrm{Q}}\big(\theta^{(1)}(\theta_0)\big) \;=\; \Big( I - \alpha \nabla^2 \mathcal{L}^{\mathrm{S}}(\theta_0) \Big) \, \nabla \mathcal{L}^{\mathrm{Q}}\big(\theta_0 - \alpha \nabla \mathcal{L}^{\mathrm{S}}(\theta_0)\big). \tag{2.6}$$

Equation (2.6) is the form in which the MAML meta-gradient is most often quoted. The Hessian of the support loss appears in the outer-loop gradient because $\theta_0$ moves $\theta^{(1)}$ through the inner gradient of $\mathcal{L}^{\mathrm{S}}$.

Three readings of (2.6). As a corrector: the $I$ piece is what FOMAML keeps, and the $-\alpha \nabla^2 \mathcal{L}^{\mathrm{S}}$ piece is the second-order correction. Geometrically: with $H = \nabla^2 \mathcal{L}^{\mathrm{S}}(\theta_0)$, the matrix $I - \alpha H$ is the Jacobian of the inner step, and (2.6) pulls the query gradient back through it. Computationally: the HVP $\nabla^2 \mathcal{L}^{\mathrm{S}}\, \nabla \mathcal{L}^{\mathrm{Q}}$ is the only second-order quantity that appears, never the full Hessian.
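Equation (2.6) can be checked numerically. The sketch below assumes linear-model MSE losses, whose gradients and Hessians have closed forms, computes the $N = 1$ meta-gradient as $(I - \alpha\nabla^2\mathcal{L}^{\mathrm{S}})\,\nabla\mathcal{L}^{\mathrm{Q}}(\theta^{(1)})$, and compares it with central finite differences of the one-step meta-loss. Because a linear model has a constant Hessian, this tests the algebra of (2.6), not its behavior on curved network losses. Function names are hypothetical.

```python
import numpy as np


def quadratic_loss(phi, y):
    # MSE of a linear model phi @ theta: gradient and (constant) Hessian in closed form.
    n = len(y)

    def loss(theta):
        return float(np.mean((phi @ theta - y) ** 2))

    def grad(theta):
        return 2.0 * phi.T @ (phi @ theta - y) / n

    hess = 2.0 * phi.T @ phi / n
    return loss, grad, hess


def maml_meta_grad(theta0, support, query, alpha):
    _, grad_s, hess_s = support
    _, grad_q, _ = query
    theta1 = theta0 - alpha * grad_s(theta0)  # one inner step, (2.1) with N = 1
    v = grad_q(theta1)                        # query gradient at the adapted point
    return v - alpha * hess_s @ v             # (I - alpha * Hessian) v, as in (2.6)


def one_step_meta_loss(theta0, support, query, alpha):
    loss_q = query[0]
    return loss_q(theta0 - alpha * support[1](theta0))
```

The FOMAML gradient, which keeps only the query-gradient term, generally differs from this value; the gap is exactly the second-order correction.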

2.4 The Hessian-vector product and its computational cost

The MAML outer gradient (2.6) involves $I - \alpha \nabla^2 \mathcal{L}^{\mathrm{S}}(\theta_0)$ acting on $\nabla \mathcal{L}^{\mathrm{Q}}(\theta^{(1)})$. We never form the Hessian. The Pearlmutter (1994) identity makes the HVP cheap:

$$\nabla^2_\theta f(\theta) \, \mathbf{v} \;=\; \nabla_\theta\!\left[\nabla_\theta f(\theta)^{\top} \mathbf{v}\right]. \tag{2.7}$$

Form $\nabla_\theta f$, dot it with the fixed vector $\mathbf{v}$, then differentiate the resulting scalar. Modern autograd implements (2.7) through nested gradient calls; PyTorch’s torch.func.grad and torch.autograd.grad(..., create_graph=True) both expose this.
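The identity can be checked without an autodiff library. This NumPy sketch assumes a toy objective $f(\theta) = \sum_i \log\cosh\theta_i + \tfrac12\theta^\top A\theta$ with a hand-written gradient, and approximates $\nabla_\theta[\nabla f(\theta)^\top\mathbf{v}]$ by a central difference of the gradient along $\mathbf{v}$. That is a finite-difference stand-in for the exact autograd computation, not Pearlmutter’s method itself, but it shows the two properties that matter here: two gradient evaluations suffice, and the Hessian is never formed (f_hess exists only to check the result).

```python
import numpy as np


def f_grad(theta, A):
    # Gradient of f(theta) = sum(log cosh(theta)) + 0.5 * theta^T A theta, A symmetric.
    return np.tanh(theta) + A @ theta


def f_hess(theta, A):
    # Explicit Hessian, used only to verify the HVP below.
    return np.diag(1.0 - np.tanh(theta) ** 2) + A


def hvp_fd(theta, v, A, eps=1e-5):
    # Directional derivative of the gradient along v. By symmetry of the Hessian this equals
    # grad_theta[grad f(theta)^T v] = H v, at the cost of two gradient evaluations.
    return (f_grad(theta + eps * v, A) - f_grad(theta - eps * v, A)) / (2.0 * eps)
```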

Cost ledger for one MAML outer step per task with $N$ inner steps:

$$N \cdot T_{\mathrm{grad}}^{\mathrm{S}} + T_{\mathrm{grad}}^{\mathrm{Q}} + N \cdot T_{\mathrm{HVP}}^{\mathrm{S}} \;\approx\; (3N + 1) \cdot T_{\mathrm{grad}},$$

where the approximation takes one HVP to cost about two gradient evaluations, $T_{\mathrm{HVP}} \approx 2\,T_{\mathrm{grad}}$.

For $N = 5$, that’s $16\times$ the cost of a single gradient evaluation per task per meta-iteration. Multiply by the task batch size $B$; torch.func.vmap parallelizes the task dimension. Practical consequences: cost grows linearly in $N$, and $N \le 5$ is the practical sweet spot where adaptation saturates; small $\alpha$ shrinks the second-order correction, which is what makes the first-order shortcuts of §3 attractive in cost-bound regimes.

2.5 Sinusoidal regression: MAML in action

Implementation uses three pieces of torch.func: functional_call expresses the forward pass as a pure function of (params, x); grad computes the inner-loop gradient in a vmap-compatible way; vmap parallelizes the per-task adaptation. The outer loss.backward() unwinds the second-order chain through the functional inner gradients.

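```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, vmap, grad

# Sinusoid regressor: two hidden layers of 40 ReLU units (the §1.4 architecture).
meta_model = nn.Sequential(
    nn.Linear(1, 40), nn.ReLU(),
    nn.Linear(40, 40), nn.ReLU(),
    nn.Linear(40, 1),
)

def loss_fn(params, x, y):
    # Forward pass as a pure function of the parameter dict.
    pred = functional_call(meta_model, params, (x,))
    return F.mse_loss(pred, y)

def adapt(params, support_x, support_y, alpha=0.01, n_inner=5):
    # Inner loop (2.1). No detach: adapted params stay differentiable in theta_0.
    for _ in range(n_inner):
        g = grad(loss_fn)(params, support_x, support_y)
        params = {k: p - alpha * g[k] for k, p in params.items()}
    return params

def task_meta_loss(params, sx, sy, qx, qy, alpha, n_inner):
    # Query loss at the adapted parameters: one term of (2.2).
    adapted = adapt(params, sx, sy, alpha, n_inner)
    return loss_fn(adapted, qx, qy)

# Batch over tasks: params, alpha, n_inner are shared; data tensors carry a leading task dim.
batched_meta_loss = vmap(task_meta_loss, in_dims=(None, 0, 0, 0, 0, None, None))
```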

Training: $B = 4$ tasks per meta-iteration, $N = 5$ inner steps, inner-loop SGD rate $\alpha = 0.01$, outer-loop Adam rate $\beta = 10^{-3}$, 100 meta-iterations, under 60 s of CPU time. The notebook prints the meta-loss at iterations 0, 50, and 99: approximately 4.46, 5.46, and 4.15. The meta-loss oscillates rather than dropping monotonically over 100 iterations; the adaptation-evolution figure (Figure 6) is the qualitatively stronger evidence that the meta-trained $\theta_0$ has acquired family structure.

Figure 5. MAML on sinusoidal regression — meta-training curve, log scale. Loss fluctuates without strong monotonic decrease at this short horizon; the across-task averaging at B=4 produces high variance.

Interactive figure. Left: meta-training loss curve at a selected $(\alpha, N)$ from a 3×3 grid. Right: held-out task adaptation at the default $(\alpha = 0.01, N = 5)$. Across the grid the curve stays bounded but shows no strong monotonic descent, so the convergence claim here is qualitative across this hyperparameter range.

Figure 6. A held-out task at A=3.2, φ=1.7. Four panels show predictions at n=0, 1, 3, 5 inner steps. The meta-trained θ₀ predicts approximately the zero function (conditional mean over the task family); five inner steps recover the sinusoid.

The adaptation-evolution figure reproduces Finn–Abbeel–Levine’s Figure 2. The meta-trained $\theta_0$ predicts approximately the zero function, which is the conditional mean $\mathbb{E}_{A,\phi}[A \sin(x + \phi)] = 0$ under uniform $\phi$, and each inner step pulls predictions toward the new task’s sinusoid. Compare with Figure 1 in §1: 200 SGD steps from a random initialization memorized the 5 points and failed, while the meta-trained $\theta_0$ achieves a better fit in 5 steps. The difference is the initialization, not the algorithm.


§3. First-order approximations and implicit gradients

3.1 FOMAML: dropping the second-order term

The full MAML outer gradient (2.6) has two pieces: $\nabla \mathcal{L}^{\mathrm{Q}}(\theta^{(1)})$ and $-\alpha \nabla^2 \mathcal{L}^{\mathrm{S}}(\theta_0)\,\nabla \mathcal{L}^{\mathrm{Q}}(\theta^{(1)})$. FOMAML drops the second piece:

$$\nabla_{\theta_0}^{\mathrm{FOMAML}} \mathcal{L}^{\mathrm{Q}} \;:=\; \nabla_\theta \mathcal{L}^{\mathrm{Q}}(\theta) \Big|_{\theta = \theta^{(N)}}. \tag{3.1}$$

Run the inner loop as usual, take the query gradient at the adapted point, and use it as the meta-gradient. The per-task cost per meta-iteration drops from $(3N+1)\,T_{\mathrm{grad}}$ to $(N+1)\,T_{\mathrm{grad}}$, a 2.7× speedup at $N = 5$.
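The cost ledger is simple arithmetic; this hypothetical helper just encodes the counts from §2.4 and (3.1), under the assumption that one HVP costs about two gradient evaluations. The Reptile line assumes the update (3.3), which needs only the inner-loop gradients.

```python
def cost_per_task(n_inner, method, hvp_cost=2.0):
    """Per-task cost of one meta-iteration, in units of T_grad.

    hvp_cost is an assumption: one HVP is taken to cost about two gradient
    evaluations, which is what the (3N + 1) count in §2.4 implies.
    """
    if method == "maml":
        return n_inner + 1 + n_inner * hvp_cost  # N support grads + 1 query grad + N HVPs
    if method == "fomaml":
        return n_inner + 1  # same gradients, HVPs dropped
    if method == "reptile":
        return n_inner  # inner-loop gradients only; no query gradient (assumed)
    raise ValueError(f"unknown method: {method}")
```

At $N = 5$ this gives 16 versus 6 gradient units for MAML and FOMAML (the 2.7× above), and 5 units for Reptile.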

For $N = 1$, the FOMAML meta-gradient is a biased estimate of the true MAML meta-gradient:

$$\nabla_{\theta_0}^{\mathrm{MAML}} - \nabla_{\theta_0}^{\mathrm{FOMAML}} \;=\; -\alpha \, \nabla^2 \mathcal{L}^{\mathrm{S}}(\theta_0)\, \nabla \mathcal{L}^{\mathrm{Q}}(\theta^{(1)}). \tag{3.2}$$

To leading order the bias is linear in $\alpha$, and its norm is bounded by $\alpha\, \lVert\nabla^2 \mathcal{L}^{\mathrm{S}}\rVert\, \lVert\nabla \mathcal{L}^{\mathrm{Q}}\rVert$. Small $\alpha$, the standard choice, keeps the bias small. Finn et al. report that on miniImageNet FOMAML achieves nearly the same final accuracy as full MAML; (3.2) suggests that on inner-loop landscapes with sharp curvature the gap can grow, with FOMAML’s plateau lying meaningfully above MAML’s.
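A quick check of the leading-order scaling, assuming $N = 1$ and linear-model MSE losses (so the support Hessian is constant): the norm of the bias (3.2) doubles when $\alpha$ doubles, up to $O(\alpha^2)$ corrections. The function name is hypothetical.

```python
import numpy as np


def fomaml_bias(theta0, phi_s, y_s, phi_q, y_q, alpha):
    # MAML minus FOMAML meta-gradient, (3.2), for N = 1 and linear-model MSE losses (assumed).
    n_s, n_q = len(y_s), len(y_q)
    grad_s = 2.0 * phi_s.T @ (phi_s @ theta0 - y_s) / n_s
    hess_s = 2.0 * phi_s.T @ phi_s / n_s
    theta1 = theta0 - alpha * grad_s
    grad_q = 2.0 * phi_q.T @ (phi_q @ theta1 - y_q) / n_q
    return -alpha * hess_s @ grad_q
```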

3.2 Reptile: implicit meta-gradient via repeated inner SGD

Nichol, Achiam, and Schulman (2018) propose a different shortcut. Run inner-loop SGD for $N$ steps and use the direction of travel as the meta-update:

$$\theta_0 \;\leftarrow\; \theta_0 + \beta \big(\theta^{(N)} - \theta_0\big). \tag{3.3}$$

There is no query loss, no second-order term, and canonically no support/query split. For $N = 1$, this reduces to plain SGD on the expected support loss. The interesting regime is $N \ge 2$.

For $N = 2$, Taylor-expand the second gradient around $\theta_0$:

$$\theta^{(2)} - \theta_0 \;=\; -2\alpha \nabla \mathcal{L}^{\mathrm{S}}(\theta_0) + \alpha^2 \nabla^2 \mathcal{L}^{\mathrm{S}}(\theta_0)\, \nabla \mathcal{L}^{\mathrm{S}}(\theta_0) + O(\alpha^3). \tag{3.4}$$
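Expansion (3.4), and its general-$N$ form $\theta^{(N)} - \theta_0 = -N\alpha\nabla\mathcal{L}^{\mathrm{S}} + \binom{N}{2}\alpha^2\nabla^2\mathcal{L}^{\mathrm{S}}\nabla\mathcal{L}^{\mathrm{S}} + O(\alpha^3)$ (each pair of inner steps contributes one cross term), can be checked numerically. This sketch assumes a single deterministic 1D task loss $\mathcal{L}(\theta) = \log\cosh(\theta - 1) + 0.1\,\theta^4$, so there is no expectation over tasks or minibatches; if the expansion is right, the mismatch shrinks by about 8× when $\alpha$ halves. The loss and function names are illustrative.

```python
import numpy as np

# Assumed 1D toy task loss: L(theta) = log cosh(theta - 1) + 0.1 * theta**4.


def grad(theta):
    return np.tanh(theta - 1.0) + 0.4 * theta**3


def hess(theta):
    return 1.0 - np.tanh(theta - 1.0) ** 2 + 1.2 * theta**2


def reptile_displacement(theta0, alpha, n_steps):
    # theta^(N) - theta_0 after N plain gradient steps on the same loss.
    theta = theta0
    for _ in range(n_steps):
        theta = theta - alpha * grad(theta)
    return theta - theta0


def predicted_displacement(theta0, alpha, n_steps):
    # Second-order expansion: -N*alpha*L' + C(N, 2) * alpha^2 * L'' * L'.
    g, h = grad(theta0), hess(theta0)
    return -n_steps * alpha * g + 0.5 * n_steps * (n_steps - 1) * alpha**2 * h * g
```

Dropping the $\alpha^2$ term leaves an error that only shrinks about 4× per halving, which is the second-order signal Reptile’s implicit objective is built from.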

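Nichol et al. work out the general-$N$ case. The $N$ inner steps contribute $\binom{N}{2}$ cross terms of the form $\alpha^2 \nabla^2\mathcal{L}^{\mathrm{S}}\,\nabla\mathcal{L}^{\mathrm{S}}$, and since $\nabla \tfrac12\lVert\nabla\mathcal{L}\rVert^2 = \nabla^2\mathcal{L}\,\nabla\mathcal{L}$, taking the expectation over tasks shows that the leading-order expected Reptile update is $-N\alpha$ times the gradient of

$$\mathbb{E}_\mathcal{T}\!\left[\mathcal{L}^{\mathrm{S}}_\mathcal{T}(\theta_0)\right] \;-\; \tfrac{\alpha(N-1)}{4}\, \mathbb{E}_\mathcal{T}\!\left[\big\lVert\nabla \mathcal{L}^{\mathrm{S}}_\mathcal{T}(\theta_0)\big\rVert^2\right] + O(\alpha^2). \tag{3.5}$$

The second term enters with a minus sign, so descending on (3.5) favors initializations where each task’s support gradient is large, i.e. where a few inner steps make fast progress on every task. When support and query losses coincide, a first-order expansion of MAML’s own objective, $\mathbb{E}_\mathcal{T}[\mathcal{L}_\mathcal{T}(\theta_0 - \alpha\nabla\mathcal{L}_\mathcal{T}(\theta_0))] \approx \mathbb{E}_\mathcal{T}[\mathcal{L}_\mathcal{T}(\theta_0)] - \alpha\,\mathbb{E}_\mathcal{T}\lVert\nabla\mathcal{L}_\mathcal{T}(\theta_0)\rVert^2$, has the same structure. When successive inner steps use different minibatches, Nichol et al. show that the squared norm becomes an inner product between gradients of different minibatches from the same task, which rewards within-task gradient alignment. This is what makes Reptile a meta-learning algorithm despite its first-order simplicity.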

3.3 The bias of first-order approximations

For FOMAML, the bias is (3.2) verbatim: $-\alpha\, \nabla^2 \mathcal{L}^{\mathrm{S}}(\theta_0)\, \nabla \mathcal{L}^{\mathrm{Q}}(\theta^{(1)})$. For Reptile, the comparison is harder because Reptile doesn’t target the MAML objective; it targets the implicit objective (3.5).

Three takeaways. First, FOMAML is the natural drop-in replacement for MAML: same support/query split, same per-task loss structure, and bias bounded by $\alpha\, \lVert\nabla^2 \mathcal{L}^{\mathrm{S}}\rVert\, \lVert\nabla \mathcal{L}^{\mathrm{Q}}\rVert$. Second, Reptile is doing something different: there is no query loss, and its implicit objective rewards gradient alignment rather than minimizing post-adaptation loss directly. Third, the computational savings are real and the empirical price is often small: both methods run at roughly 1/3 the compute of full MAML, with near-MAML accuracy on standard benchmarks.

3.4 Implicit MAML and the inverse-Hessian-vector product

Rajeswaran, Finn, Kakade, and Levine (2019) avoid unrolling the inner trajectory altogether. Replace the $N$-step SGD with an exact regularized minimizer:

$$\theta^{*}(\theta_0) \;:=\; \arg\min_\theta\; \mathcal{L}^{\mathrm{S}}(\theta) + \frac{\lambda}{2}\,\lVert\theta - \theta_0\rVert^2. \tag{3.6}$$

The regularization $\lambda > 0$ pulls $\theta^*$ toward $\theta_0$ and plays the role of $1/(N\alpha)$ in standard MAML. Memory is $O(1)$ in the inner-loop count rather than $O(N \cdot |\theta|)$, because the trajectory graph is never stored. The cost is the linear-system solve derived next.

3.5 The implicit-function-theorem derivation

The first-order optimality condition at θ\theta^*:

$$\nabla \mathcal{L}^{\mathrm{S}}(\theta^{*}) + \lambda (\theta^{*} - \theta_0) \;=\; 0. \tag{3.7}$$

The implicit function theorem (see formalCalculus: inverse-implicit) says that when $\nabla^2 \mathcal{L}^{\mathrm{S}}(\theta^{*}) + \lambda I$ is invertible (guaranteed for $\lambda > -\lambda_{\min}(\nabla^2 \mathcal{L}^{\mathrm{S}}(\theta^{*}))$), $\theta^{*}$ is a differentiable function of $\theta_0$. Differentiating (3.7),

$$\frac{\partial \theta^{*}}{\partial \theta_0} \;=\; \lambda \big(\nabla^2 \mathcal{L}^{\mathrm{S}}(\theta^{*}) + \lambda I\big)^{-1}. \tag{3.8}$$

The meta-gradient by chain rule:

$$\nabla_{\theta_0} \mathcal{L}^{\mathrm{Q}}\big(\theta^{*}(\theta_0)\big) \;=\; \lambda \big(\nabla^2 \mathcal{L}^{\mathrm{S}}(\theta^{*}) + \lambda I\big)^{-1} \nabla \mathcal{L}^{\mathrm{Q}}(\theta^{*}). \tag{3.9}$$

This is an inverse-Hessian-vector product: solve $(\nabla^2 \mathcal{L}^{\mathrm{S}}(\theta^{*}) + \lambda I)\, x = \nabla \mathcal{L}^{\mathrm{Q}}(\theta^{*})$ by conjugate gradients, which needs only HVPs, then scale $x$ by $\lambda$. CG converges in $\tilde{O}(\sqrt{\kappa})$ iterations, where $\kappa$ is the condition number of $\nabla^2 \mathcal{L}^{\mathrm{S}}(\theta^{*}) + \lambda I$; the $\lambda I$ shift improves conditioning.
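The implicit gradient (3.9) can be computed with conjugate gradients that touch the Hessian only through matrix-vector products. The NumPy sketch below assumes toy quadratic losses, $\mathcal{L}^{\mathrm{S}}(\theta) = \tfrac12(\theta - c)^\top A(\theta - c)$ with $A$ positive semi-definite and $\mathcal{L}^{\mathrm{Q}}(\theta) = \tfrac12\lVert\theta - q\rVert^2$, so the inner minimizer (3.6) has a closed form and the result can be compared against finite differences of $\mathcal{L}^{\mathrm{Q}}(\theta^*(\theta_0))$. Names and losses are illustrative.

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-10, max_iter=100):
    # Solve M x = b for symmetric positive-definite M, touching M only through matvec.
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) < tol:
            break
        Mp = matvec(p)
        step = rs / (p @ Mp)
        x = x + step * p
        r = r - step * Mp
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def imaml_meta_grad(theta0, A, c, q, lam):
    # Toy losses (assumed): L^S = 0.5 (theta - c)^T A (theta - c), L^Q = 0.5 ||theta - q||^2.
    d = len(theta0)
    # Exact inner minimizer of (3.6): closed form here, an iterative solver in practice.
    theta_star = np.linalg.solve(A + lam * np.eye(d), A @ c + lam * theta0)
    grad_q = theta_star - q

    def hvp_plus_shift(v):
        # (Hessian of L^S + lam * I) v, using only a Hessian-vector product.
        return A @ v + lam * v

    # (3.9): lam * (H + lam I)^{-1} grad L^Q(theta*)
    return lam * conjugate_gradient(hvp_plus_shift, grad_q), theta_star
```

For a neural network, $\theta^*$ would come from an approximate inner solver and hvp_plus_shift would wrap an autograd HVP such as (2.7); the CG routine itself is unchanged.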

| | MAML ($N$-step) | Implicit MAML |
|---|---|---|
| Inner loop | $N$ SGD steps | Exact minimizer of (3.6) |
| Memory | $O(N \cdot \lvert\theta\rvert)$ | $O(\lvert\theta\rvert)$ |
| Outer compute | $N$ HVPs | $\sim\!\sqrt{\kappa}$ HVPs (CG) |
| Bias | Truncation from finite $N$ | None at the fixed point |
| Best for | Few-step adaptation | Long-horizon adaptation |

For our small sinusoidal demo, MAML with $N = 5$ is the right choice. Implicit MAML earns its keep at long horizons (RL fine-tuning, large-model meta-tuning), where unrolling becomes the binding constraint.

3.6 Head-to-head on sinusoidal regression

The notebook implements MAML (from §2), FOMAML, and Reptile on the same sinusoidal task distribution with matched hyperparameters. Implicit MAML is skipped at code level because the CG inner solver dominates the runtime budget on a problem where standard MAML already finishes in seconds.

At $\alpha = 0.01$, $N = 5$, $B = 4$, the three methods sit in nearby regions of meta-loss space. Notebook outputs: MAML 4.46 → 4.15; FOMAML 7.09 → 5.33; Reptile 3.52 → 3.21. Reptile reaches the lowest absolute meta-loss at this short horizon, but Reptile’s meta-loss is computed on the union of support and query (a different quantity from MAML’s pure post-adaptation query loss), so the absolute values aren’t directly comparable across methods. What the head-to-head figure does show is that all three trajectories stay within the meta-loss range [3, 7] on the log axis, which is consistent with, though not a direct test of, the first-order theory’s claim that the FOMAML bias is small at small $\alpha$.

Figure 7. Head-to-head: MAML / FOMAML / Reptile on sinusoidal regression. The three trajectories cluster in nearby regions of log-meta-loss space. Curve shape differs because Reptile optimizes the implicit objective (3.5) rather than MAML's (2.2).

§4. Bayesian meta-learning: Neural Processes

4.1 Tasks as samples from a stochastic process

In §§2–3, each task was a parametric optimization problem solved by SGD from a learned initialization. The Bayesian lens reframes the same setup at a higher level of abstraction: each task is a function, drawn from a stochastic process, and the inference procedure is amortized Bayesian regression. The meta-parameters are the weights of an inference network that reads a context set and emits a posterior over predictions at any target location.

A stochastic process is a distribution over functions $f \colon \mathcal{X} \to \mathcal{Y}$. The canonical example is the Gaussian process (formalML’s Gaussian Processes notebook). For Bayesian meta-learning, the task distribution $p(\mathcal{T})$ is a stochastic process $P$: each task $\mathcal{T} \sim P$ is realized as a function $f_\mathcal{T}$, the support set is the context, and the query set is the target.

For a GP context-target task with known kernel, exact GP regression is optimal. Meta-learning’s interesting regime: the kernel is unknown, or the task distribution is not a GP, and we want a learned inference procedure that takes a context set of any size, produces a posterior predictive at any target location, is order-invariant in the context, and amortizes inference across tasks.

The Neural Process family (Garnelo et al. 2018a, 2018b; Kim et al. 2019) commits to two architectural choices. Permutation invariance over the context: per-point embeddings $h_i = \mathrm{enc}(x_i, y_i)$ are pooled through a permutation-invariant aggregator, canonically the mean,

$$r = \frac{1}{N_C} \sum_{i=1}^{N_C} h_i. \tag{4.1}$$

Parameter sharing across context sizes: the encoder applies independently to each context point, then the aggregator collapses to a fixed-dimensional representation regardless of cardinality.
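A NumPy sketch of the encoder-plus-mean-pool pattern, with randomly initialized weights standing in for a trained encoder (biases omitted for brevity; hidden width 64 and embedding size 64 follow §4.5, and the function names are hypothetical). Because the pool is a mean over the context axis, reordering the context leaves $r$ unchanged, and any context size yields a vector of the same dimension.

```python
import numpy as np


def encode(x, y, W1, W2):
    # Per-point encoder h_i = enc(x_i, y_i): a 2-layer ReLU MLP applied to each (x_i, y_i) row.
    xy = np.stack([x, y], axis=-1)          # (N_C, 2)
    return np.maximum(xy @ W1, 0.0) @ W2    # (N_C, d_r)


def aggregate(h):
    # (4.1): permutation-invariant mean over the context dimension.
    return h.mean(axis=0)
```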

4.2 The conditional Neural Process: deterministic context aggregation

The Conditional Neural Process (Garnelo et al. 2018a) is the minimal version: encoder $h_i = \mathrm{enc}_\theta(x_i, y_i)$; aggregator $r = (1/N_C) \sum_i h_i$; decoder $(\mu(x_*), \sigma(x_*)) = \mathrm{dec}_\theta(x_*, r)$. The predictive is Gaussian per target: $p(y_* \mid x_*, \mathrm{ctx}) = \mathcal{N}(\mu(x_*), \sigma(x_*)^2)$.

Training objective, maximized over $\theta$:

$$\mathcal{L}_{\mathrm{CNP}}(\theta) \;=\; \mathbb{E}_\mathcal{T}\, \mathbb{E}_{C, T \subset \mathcal{T}}\, \sum_{(x_*, y_*) \in T} \log \mathcal{N}\!\left(y_*;\, \mu_\theta(x_*; C),\, \sigma_\theta(x_*; C)^2\right). \tag{4.2}$$
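The per-target term of (4.2) is a Gaussian log-density. A NumPy sketch; the lower-bounded softplus that keeps $\sigma$ positive is a common parameterization assumed here, not something the text specifies, and both function names are hypothetical. Training maximizes the summed log-likelihood, equivalently minimizes its negative.

```python
import numpy as np


def gaussian_log_lik(y, mu, sigma):
    # sum_i log N(y_i; mu_i, sigma_i^2), the inner sum of (4.2)
    return float(np.sum(-0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2))


def cnp_sigma(raw):
    # Assumed parameterization: sigma = 0.1 + 0.9 * softplus(raw), so sigma >= 0.1.
    return 0.1 + 0.9 * np.logaddexp(0.0, raw)
```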

CNP limitations: the predictive at each target is independent (no posterior correlation between target locations, missing the joint GP structure), and the model is deterministic given the context (no posterior uncertainty over the latent function). The Latent NP addresses both.

4.3 The latent Neural Process and the variational lower bound

The Latent NP (Garnelo et al. 2018b) introduces a global latent $z \in \mathbb{R}^{d_z}$:

$$z \sim p(z), \qquad y_i \mid x_i, z \sim p_\theta(y_i \mid x_i, z). \tag{4.3}$$

Inference: an encoder $q_\phi(z \mid D)$ takes any dataset and outputs Gaussian posterior parameters. The same encoder is used for $q_\phi(z \mid D_C)$ and $q_\phi(z \mid D_C \cup D_T)$, applied to different subsets.

Theorem 1 (Latent NP evidence lower bound).

Under the latent-variable model (4.3) with encoder $q_\phi$, let $\tilde p_\theta(y_T \mid x_T, D_C) := \int p_\theta(y_T \mid x_T, z)\, q_\phi(z \mid D_C)\, dz$ be the model’s implied predictive. Then

$$\log \tilde p_\theta(y_T \mid x_T, D_C) \;\geq\; \mathbb{E}_{z \sim q_\phi(z \mid D_C \cup D_T)}\!\left[\sum_{(x_*, y_*) \in D_T} \log p_\theta(y_* \mid x_*, z)\right] - \mathrm{KL}\!\left(q_\phi(z \mid D_C \cup D_T) \,\big\|\, q_\phi(z \mid D_C)\right). \tag{4.4}$$
Proof.

Start with the log-evidence and introduce the variational distribution multiplicatively:

$$\log p_\theta(y_T \mid x_T, D_C) \;=\; \log \int p_\theta(y_T \mid x_T, z) \, p_\theta(z \mid D_C) \, dz.$$

Multiply and divide by $q_\phi(z \mid D_C \cup D_T)$ and apply Jensen’s inequality to the concave $\log$:

$$\geq \int q_\phi(z \mid D_C \cup D_T) \log \frac{p_\theta(y_T \mid x_T, z)\, p_\theta(z \mid D_C)}{q_\phi(z \mid D_C \cup D_T)} \, dz.$$

Split into reconstruction + log-prior - log-q terms; combine the last two into a KL:

$$= \mathbb{E}_{q_\phi(z \mid D_C \cup D_T)}\big[\log p_\theta(y_T \mid x_T, z)\big] - \mathrm{KL}\big(q_\phi(z \mid D_C \cup D_T) \,\|\, p_\theta(z \mid D_C)\big).$$

This is the standard ELBO with the inconvenient feature that pθ(zDC)p_\theta(z \mid D_C) is intractable. The Neural Process substitutes qϕ(zDC)q_\phi(z \mid D_C) for pθ(zDC)p_\theta(z \mid D_C) in the KL:

logpθ(yTxT,DC)    Eqϕ(zDCDT)[logpθ(yTxT,z)]KL ⁣(qϕ(zDCDT)qϕ(zDC)).\log p_\theta(y_T \mid x_T, D_C) \;\geq\; \mathbb{E}_{q_\phi(z \mid D_C \cup D_T)}[\log p_\theta(y_T \mid x_T, z)] - \mathrm{KL}\!\left(q_\phi(z \mid D_C \cup D_T) \,\|\, q_\phi(z \mid D_C)\right).

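The substitution changes what is being bounded: by the same Jensen argument with $q_\phi(z \mid D_C)$ in place of $p_\theta(z \mid D_C)$, the right-hand side lower-bounds the model's implied predictive $\int p_\theta(y_T \mid x_T, z)\, q_\phi(z \mid D_C)\, dz$ rather than $p_\theta(y_T \mid x_T, D_C)$ itself. That implied predictive is the operationally correct quantity at inference time, and it is how the left-hand side of (4.4) should be read.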

Reading the bound: the reconstruction term rewards predicting target outputs well using a latent inferred from both context and target; the KL term penalizes the gap between the context-plus-target posterior and the context-only posterior, pushing the encoder toward latents that are already mostly determined by the context.
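A minimal PyTorch sketch of a Monte Carlo estimate of the right-hand side of (4.4), assuming diagonal-Gaussian encoder outputs and a Gaussian decoder. The function name `latent_np_objective` and the `decode(x, z) -> (mean, std)` interface are hypothetical, not the notebook's code; the encoder networks that produce the two sets of Gaussian parameters are left out.

```python
import torch
from torch.distributions import Normal, kl_divergence

def latent_np_objective(mu_ct, sigma_ct, mu_c, sigma_c, decode, x_t, y_t, n_samples=8):
    """Monte Carlo estimate of the RHS of (4.4), to be maximized.

    mu_ct, sigma_ct: diagonal-Gaussian parameters of q_phi(z | D_C ∪ D_T), shape (d_z,)
    mu_c, sigma_c:   diagonal-Gaussian parameters of q_phi(z | D_C), shape (d_z,)
    decode(x, z):    hypothetical decoder returning (mean, std) of p_theta(y | x, z)
    """
    q_ct = Normal(mu_ct, sigma_ct)
    q_c = Normal(mu_c, sigma_c)
    # Reparameterized samples z ~ q_phi(z | D_C ∪ D_T), shape (n_samples, d_z)
    z = q_ct.rsample((n_samples,))
    recon = torch.zeros(())
    for s in range(n_samples):
        mean, std = decode(x_t, z[s])
        # Sum of log-likelihoods over the target points (x_*, y_*) in D_T
        recon = recon + Normal(mean, std).log_prob(y_t).sum()
    recon = recon / n_samples
    # KL between diagonal Gaussians, summed over latent dimensions
    kl = kl_divergence(q_ct, q_c).sum()
    return recon - kl
```

Training minimizes the negative of this quantity. At test time only the context is available, so predictions sample $z$ from $q_\phi(z \mid D_C)$.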

4.4 Attention-based context aggregation

Mean-pooling (4.1) assigns every context point equal weight, regardless of where the target $x_*$ sits. The Attentive Neural Process (Kim et al. 2019) replaces the mean pool with cross-attention from targets to context:

$$r(x_*) \;=\; \sum_{i=1}^{N_C} \mathrm{softmax}_i\!\left(\frac{q(x_*)^\top k_i}{\sqrt{d}}\right) v_i,$$

where $q(x_*) = W_Q x_*$, $k_i = W_K x_i$, $v_i = W_V h_i$. The effect: each target weights the context points whose keys match its query most heavily, so the representation varies with $x_*$ and gives local-context fidelity comparable to GP regression. The ANP combines both forms: the attention path $r(x_*)$ plus the mean-pooled latent $z$. We don't implement the ANP in the notebook, since multi-head attention would exceed its compute budget, but the architecture follows directly from this template.
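A single-head sketch of the cross-attention aggregator above, in PyTorch. It is a simplification, assuming one head and bias-free projections; the ANP itself uses multi-head attention, and `CrossAttentionAggregator` is a hypothetical name.

```python
import math
import torch
import torch.nn as nn

class CrossAttentionAggregator(nn.Module):
    """r(x_*) = sum_i softmax_i(q(x_*)^T k_i / sqrt(d)) v_i, single head."""

    def __init__(self, x_dim, h_dim, d):
        super().__init__()
        self.W_Q = nn.Linear(x_dim, d, bias=False)  # queries from target inputs x_*
        self.W_K = nn.Linear(x_dim, d, bias=False)  # keys from context inputs x_i
        self.W_V = nn.Linear(h_dim, d, bias=False)  # values from context encodings h_i
        self.d = d

    def forward(self, x_target, x_context, h_context):
        q = self.W_Q(x_target)                    # (N_T, d)
        k = self.W_K(x_context)                   # (N_C, d)
        v = self.W_V(h_context)                   # (N_C, d)
        scores = q @ k.T / math.sqrt(self.d)      # (N_T, N_C) scaled dot products
        weights = torch.softmax(scores, dim=-1)   # each row sums to 1 over the context
        return weights @ v                        # (N_T, d): one representation per target
```

If all query weights are zero, every score is zero, the attention weights become uniform, and $r(x_*)$ reduces to the mean of the values, i.e. to mean-pooling.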

4.5 1D GP few-shot demo

We train CNP and Latent NP on tasks drawn from a GP with random RBF lengthscale $\ell \sim \mathrm{Uniform}(0.5, 2.0)$, comparing predictions on a held-out task against the exact GP posterior (closed-form NumPy, reused from gaussian-processes).

Architecture: 2-layer MLP encoder (hidden 64, embedding 64), 2-layer MLP decoder. Latent NP has $z$ dimension 16. Training: 1000 epochs, one task per epoch, Adam at learning rate $10^{-3}$. CPU runtime ~15 s.
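One way to generate such training tasks and the exact-GP reference, sketched in NumPy. The input range $[-2, 2]$, 50 points per task, unit kernel variance, and noise level $10^{-2}$ are assumptions, not the notebook's settings; `sample_gp_task` and `gp_posterior` are hypothetical helpers.

```python
import numpy as np

def rbf_kernel(a, b, lengthscale, variance=1.0):
    # a: (n,), b: (m,) -> (n, m) squared-exponential kernel matrix
    sqdist = (a[:, None] - b[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale**2)

def sample_gp_task(rng, n_points=50, x_range=(-2.0, 2.0), noise=1e-2):
    # One task: draw a lengthscale ~ Uniform(0.5, 2.0), then a noisy function from the GP prior.
    ell = rng.uniform(0.5, 2.0)
    x = np.sort(rng.uniform(*x_range, size=n_points))
    K = rbf_kernel(x, x, ell) + noise * np.eye(n_points)
    y = rng.multivariate_normal(np.zeros(n_points), K)
    return x, y, ell

def gp_posterior(x_c, y_c, x_t, ell, noise=1e-2):
    # Exact GP predictive mean and latent-function variance at x_t given context (x_c, y_c).
    K_cc = rbf_kernel(x_c, x_c, ell) + noise * np.eye(len(x_c))
    K_tc = rbf_kernel(x_t, x_c, ell)
    L = np.linalg.cholesky(K_cc)                          # K_cc = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_c))  # K_cc^{-1} y_c
    mean = K_tc @ alpha
    v = np.linalg.solve(L, K_tc.T)
    var = rbf_kernel(x_t, x_t, ell).diagonal() - np.sum(v**2, axis=0)
    return mean, var
```

The exact posterior uses the true lengthscale of each task; the NPs never see $\ell$ and must infer it implicitly from the context.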

Notebook outputs at our scale and horizon: CNP NLL rises from 31.92 to 42.91 over 1000 epochs; Latent NP negative ELBO rises from 31.37 to 43.91. At this budget of one task per epoch, neither training objective decreases over the run. For the Latent NP, the KL term of the negative ELBO also contributes to the upward drift. The figures show smoothed curves so that the trend is visible above the noisy per-epoch values.

The held-out comparison below is the stronger evidence that the trained NPs have learned something useful: both predictive means interpolate the context well and grow uncertainty away from context regions, qualitatively matching the exact GP posterior even though the NPs don't know the kernel. The training dynamics suggest that clean, decreasing loss curves require the scale at which Garnelo et al. originally trained NPs (multiple tasks per batch, longer horizons, larger models).

Two panels: CNP NLL and Latent NP negative-ELBO training curves over 1000 epochs. Both curves rise across training; smoothed line makes the trend visible above the per-iter noise.
Figure 8. CNP and Latent NP training curves. At one task per epoch and the notebook's small architectures, neither curve descends monotonically — the curves trend upward. The held-out evaluation (Figure 9) is the substantive evidence of meta-learning.
Three side-by-side panels on a held-out GP task with 8 context points: exact GP posterior with known lengthscale, CNP predictive, Latent NP predictive with 20 sampled latents.
Figure 9. Held-out GP task (ℓ = 1.0, 8 context points). The trained NPs interpolate context well and grow predictive uncertainty away from context regions, qualitatively matching the exact GP posterior with known ℓ.

§5. Metric-learning meta-learning: Prototypical Networks

5.1 The few-shot classification objective

The third lens drops both the gradient-based inner loop and the variational posterior, replacing them with the simplest possible adaptation: compute class prototypes from the support set, classify queries by nearest prototype. The meta-learner’s only learnable component is the embedding network.

The setup is $K$-way $N$-shot classification. Each task $\mathcal{T}$ is a $K$-class classification problem. The support set has $N$ labeled examples per class, for $KN$ total; the query set has $M_q$ examples per class drawn from the same $K$ classes. Writing the support set partitioned by class as $\mathcal{S} = \bigsqcup_{k=1}^{K} \mathcal{S}_k$, the meta-objective is

$$\mathcal{L}_{\mathrm{proto}}(\theta_0) \;=\; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\, \mathbb{E}_{(\mathbf{x}_*, y_*) \in \mathcal{T}^{\mathrm{Q}}}\!\left[-\log p_{\theta_0}(y_* \mid \mathbf{x}_*, \mathcal{S}_\mathcal{T})\right]. \qquad (5.1)$$

Standard few-shot vision benchmarks (Omniglot 5-way 5-shot, miniImageNet) episode the task distribution by sampling $K$ classes from a large catalog. For our notebook we replace this with the 2D Gaussian-cluster family from §1.4: same architecture and objective, CPU-tractable.

5.2 Embeddings and class prototypes

The meta-parameter is an embedding network $\phi_{\theta_0} \colon \mathcal{X} \to \mathbb{R}^d$. The prototype for class $k$ is the mean of the support embeddings for that class:

$$\mathbf{c}_k \;=\; \frac{1}{|\mathcal{S}_k|} \sum_{(\mathbf{x}, k) \in \mathcal{S}_k} \phi_{\theta_0}(\mathbf{x}) \;\in\; \mathbb{R}^d. \qquad (5.2)$$

Two structural notes. First, the prototype depends on the support set, not the query set: the support/query separation again. Second, computing prototypes takes one forward pass plus a class-wise mean, with no inner-loop SGD, no Hessian-vector products, and no variational sampling. It is the cheapest adaptation procedure of the three lenses.

The meta-learner doesn’t learn the prototypes; it learns $\phi_{\theta_0}$. Prototypes are derived per-task quantities. Meta-training over many tasks teaches $\phi_{\theta_0}$ to map the instances of each class to compact, non-overlapping regions of $\mathbb{R}^d$ within any task drawn from $p(\mathcal{T})$.

5.3 Nearest-prototype assignment as softmax over distances

The predictive is a softmax of negative squared distances:

$$p_{\theta_0}(y_* = k \mid \mathbf{x}_*, \mathcal{S}) \;=\; \frac{\exp\!\big(-d(\phi_{\theta_0}(\mathbf{x}_*), \mathbf{c}_k)\big)}{\sum_{k'=1}^{K} \exp\!\big(-d(\phi_{\theta_0}(\mathbf{x}_*), \mathbf{c}_{k'})\big)}, \qquad (5.3)$$

with $d(\phi(\mathbf{x}), \mathbf{c}_k) = \lVert\phi(\mathbf{x}) - \mathbf{c}_k\rVert^2$.

The squared-Euclidean choice has a clean justification through Bregman divergences. A Bregman divergence $d_\psi$ generated by a strictly convex $\psi$ is $d_\psi(u, v) = \psi(u) - \psi(v) - \langle \nabla\psi(v), u - v\rangle$. Banerjee, Merugu, Dhillon, and Ghosh (2005) prove that for any Bregman divergence, the cluster mean is the unique minimizer of the within-cluster sum of divergences: $\mathrm{argmin}_{\mathbf{c}} \sum_i d_\psi(\phi_i, \mathbf{c}) = (1/n)\sum_i \phi_i$. Squared Euclidean distance is the Bregman divergence generated by $\psi(\mathbf{u}) = \lVert\mathbf{u}\rVert^2$ (the choice $\psi(\mathbf{u}) = \tfrac{1}{2}\lVert\mathbf{u}\rVert^2$ gives half of it, with the same minimizer), so the prototype definition (5.2) is not arbitrary: the mean is the optimal class center under squared Euclidean distance.
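A small numerical check of the Banerjee et al. result, sketched in NumPy: for squared Euclidean distance, and for the generalized KL divergence (chosen here only as a second example of a Bregman divergence), randomly perturbing the center away from the mean never lowers the within-cluster divergence. `mean_is_minimizer` and the other helpers are hypothetical.

```python
import numpy as np

def bregman(psi, grad_psi, u, v):
    # d_psi(u, v) = psi(u) - psi(v) - <grad psi(v), u - v>
    return psi(u) - psi(v) - np.dot(grad_psi(v), u - v)

def sq_psi(u):      # generates the squared Euclidean distance ||u - v||^2
    return np.dot(u, u)

def sq_grad(u):
    return 2.0 * u

def kl_psi(u):      # generates the generalized KL divergence (positive inputs only)
    return np.sum(u * np.log(u) - u)

def kl_grad(u):
    return np.log(u)

def total_divergence(points, c, psi, grad_psi):
    # Within-cluster sum of divergences, with the candidate center in the second slot
    return sum(bregman(psi, grad_psi, p, c) for p in points)

def mean_is_minimizer(points, psi, grad_psi, rng, n_trials=200, scale=0.1):
    mean = points.mean(axis=0)
    best = total_divergence(points, mean, psi, grad_psi)
    for _ in range(n_trials):
        c = mean + scale * rng.standard_normal(points.shape[1])
        if total_divergence(points, c, psi, grad_psi) < best:
            return False
    return True

rng = np.random.default_rng(0)
points = rng.uniform(0.5, 2.0, size=(20, 3))  # positive, so the KL generator is defined
print(mean_is_minimizer(points, sq_psi, sq_grad, rng),
      mean_is_minimizer(points, kl_psi, kl_grad, rng))
```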

Cross-entropy training loss per task:

$$\mathcal{L}_{\mathrm{proto}}^\mathcal{T}(\theta_0) \;=\; -\sum_{(\mathbf{x}_*, y_*) \in \mathcal{T}^{\mathrm{Q}}}\!\log\frac{\exp\!\big(-\lVert\phi_{\theta_0}(\mathbf{x}_*) - \mathbf{c}_{y_*}\rVert^2\big)}{\sum_{k'=1}^{K}\!\exp\!\big(-\lVert\phi_{\theta_0}(\mathbf{x}_*) - \mathbf{c}_{k'}\rVert^2\big)}. \qquad (5.4)$$

Gradient flows from cross-entropy through both query and support embeddings (the prototypes depend on support embeddings via (5.2)).

Remark (Prototypical Networks as Gaussian classifiers).

If class-conditional embeddings are isotropic Gaussians $p(\phi \mid y = k) = \mathcal{N}(\mathbf{c}_k, \tfrac{1}{2} I)$, Bayes’ rule with equal class priors gives exactly the softmax-over-squared-distances rule (5.3); with unit covariance $\mathcal{N}(\mathbf{c}_k, I)$ the same argument gives (5.3) with halved distances, a rescaling the embedding can absorb. Prototypical networks are meta-trained Gaussian classifiers in a learned embedding space.
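A quick PyTorch check of the remark, assuming equal class priors: the class posterior under $\mathcal{N}(\mathbf{c}_k, \tfrac{1}{2} I)$ class-conditionals coincides with the softmax over negative squared distances. Both helper names are hypothetical.

```python
import torch
from torch.distributions import Independent, Normal

def proto_probs(z, prototypes):
    # (5.3): softmax over negative squared Euclidean distances, shape (n_query, K)
    d2 = ((z[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=-1)
    return torch.softmax(-d2, dim=-1)

def gaussian_bayes_probs(z, prototypes, var):
    # Bayes' rule with equal priors and class-conditionals N(c_k, var * I)
    class_conditionals = Independent(Normal(prototypes, var ** 0.5), 1)  # batch (K,), event (d,)
    log_lik = class_conditionals.log_prob(z[:, None, :])                 # (n_query, K)
    return torch.softmax(log_lik, dim=-1)  # equal priors cancel in the normalization
```

With `var=1.0` the Gaussian posterior equals `proto_probs` applied to embeddings and prototypes scaled by $1/\sqrt{2}$, which is the rescaling mentioned in the remark.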

5.4 Matching networks, prototypical networks, and contrastive learning

Matching Networks (Vinyals et al. 2016) use cosine attention over all support examples rather than class means: $p(y_* = k) = \sum_i \mathrm{softmax}_i(\cos(\phi(\mathbf{x}_*), \phi(\mathbf{x}_i)))\, \mathbf{1}[y_i = k]$. Snell et al. argue prototypical networks are theoretically cleaner (Bregman justification for the mean); empirically the two are comparable.
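For comparison, a PyTorch sketch of the Matching Networks predictive above, assuming the embeddings have already been computed. It omits the full-context-embedding variant of Vinyals et al., and `matching_net_probs` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def matching_net_probs(query_emb, support_emb, support_y, n_classes):
    # Cosine similarity between every query and every support embedding: (n_q, n_s)
    sims = F.cosine_similarity(query_emb[:, None, :], support_emb[None, :, :], dim=-1)
    attn = torch.softmax(sims, dim=-1)                       # attention over support examples
    onehot = F.one_hot(support_y, n_classes).to(attn.dtype)  # (n_s, K) label indicators
    return attn @ onehot                                     # (n_q, K): attention mass per class
```

Because attention is spread over individual support points rather than class means, the probability of a class is the total attention mass on that class's examples.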

Contrastive learning. The prototypical objective (5.4) is structurally identical to NT-Xent (Sohn 2016) and InfoNCE (van den Oord et al. 2018) — softmax over distances, push same-class together, push different-class apart. Khosla et al.’s Supervised Contrastive Learning (2020) makes this explicit. Prototypical networks are the few-shot specialization of supervised contrastive learning.

This connection matters for §6: the metric lens is contrastive representation learning specialized to few-shot tasks. The hierarchical-Bayes unification ties all three lenses together regardless of surface paradigm.

5.5 2D synthetic prototypical demo

Setup: 5-way classification, class means uniform on a 2D disk of radius 2.5, isotropic Gaussian noise $\sigma = 0.35$, 5 support + 15 query per class. Embedding: 2-layer MLP, hidden 64, embedding 32. Training: 500 meta-iters, 1 task per iter, Adam at learning rate $10^{-3}$.

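```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProtoNet(nn.Module):
    """Embedding network phi_{theta_0}: R^in_dim -> R^emb_dim."""

    def __init__(self, in_dim=2, h_dim=64, emb_dim=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, emb_dim),
        )

    def forward(self, x):
        return self.net(x)

def proto_loss_and_acc(model, support_x, support_y, query_x, query_y, n_classes=5):
    # Embed support and query points with the shared network.
    sup_emb = model(support_x)
    qry_emb = model(query_x)
    # (5.2): each prototype is the mean support embedding of its class.
    prototypes = torch.stack([
        sup_emb[support_y == k].mean(dim=0) for k in range(n_classes)
    ])
    # (5.3): logits are negative squared Euclidean distances to the prototypes.
    dists = torch.cdist(qry_emb, prototypes, p=2) ** 2
    logits = -dists
    # (5.4), averaged rather than summed over queries; query_y holds int64 class indices.
    loss = F.cross_entropy(logits, query_y)
    acc = (logits.argmax(dim=-1) == query_y).float().mean()
    return loss, acc
```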
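A sketch of an episode sampler for the setup above (5-way, means uniform on a disk of radius 2.5, $\sigma = 0.35$, 5 support and 15 query points per class). The function `sample_cluster_task` is hypothetical, not the notebook's code; drawing the radius as $R\sqrt{U}$ is one standard way to make the means uniform over the disk's area.

```python
import math
import torch

def sample_cluster_task(n_classes=5, n_support=5, n_query=15,
                        radius=2.5, sigma=0.35, generator=None):
    # Class means uniform over a disk: radius R * sqrt(U) gives uniform area density.
    u = torch.rand(n_classes, generator=generator)
    angle = 2 * math.pi * torch.rand(n_classes, generator=generator)
    r = radius * torch.sqrt(u)
    means = torch.stack([r * torch.cos(angle), r * torch.sin(angle)], dim=-1)  # (K, 2)

    def draw(n_per_class):
        # Labels 0..K-1, each repeated n_per_class times; isotropic noise around each mean.
        y = torch.arange(n_classes).repeat_interleave(n_per_class)
        x = means[y] + sigma * torch.randn(len(y), 2, generator=generator)
        return x, y

    support_x, support_y = draw(n_support)
    query_x, query_y = draw(n_query)
    return support_x, support_y, query_x, query_y
```

The four outputs follow the argument order of `proto_loss_and_acc`, so one meta-iteration is `loss, acc = proto_loss_and_acc(model, *sample_cluster_task())` followed by an optimizer step. Every call draws fresh class-mean placements, which is what forces the embedding to work across tasks.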

The 2D input is unusually low-dimensional for a few-shot benchmark. Class clusters are already linearly separable in input space, so the embedding has little work to do: even the randomly initialized embedding classifies perfectly. Notebook outputs: query accuracy is $1.00$ at epoch 0 (chance is $1/5 = 0.20$); cross-entropy loss starts at $0.888$ and falls to essentially zero by epoch 499 as the embedding sharpens the softmax confidence.

This is a feature, not a bug, of the demo: the meta-learning story still goes through because the same network, trained on hundreds of tasks with random class-mean placements, produces task-specific prototypes that handle arbitrary held-out task arrangements. The decision-region figure makes this visual.

Two panels: cross-entropy loss decreasing from ~0.9 to ~0 over 500 meta-iters; accuracy at 1.00 throughout.
Figure 10. Prototypical Network training. Accuracy stays at 1.00 from epoch 0 (clusters are linearly separable in input space); loss decreases as the embedding sharpens softmax confidence. The accuracy curve is uninformative at this scale; the loss curve shows the embedding refinement.
Two held-out tasks side-by-side with input-space decision regions, mapped through the embedding network to nearest prototype.
Figure 11. Prototypical Network on two held-out 5-way 5-shot tasks. Same network produces different decision regions per task — that is the meta-learning angle made visual.

§6. The hierarchical-Bayes unifying view

6.1 Tasks as a hierarchical generative model

The three lenses developed in §§2–5 look like different algorithms. They share a deeper structure: each is a posterior approximation in the same hierarchical generative model. Writing that model down — and re-reading each lens as a particular approximation to its central marginal-likelihood integral — is the conceptual unification this section delivers, and the substrate on which the PAC-Bayes generalization theory of §7 will sit.

The hierarchical model has three levels:

$$\theta_0 \;\to\; \theta_{\mathcal{T}_k} \sim p(\theta_{\mathcal{T}_k} \mid \theta_0) \;\to\; (\mathbf{x}_i, y_i)_{i \in \mathcal{T}_k} \sim p(y_i \mid \mathbf{x}_i, \theta_{\mathcal{T}_k}). \qquad (6.1)$$

Meta-parameter $\theta_0 \in \mathbb{R}^D$ sits at the top; task-specific $\theta_{\mathcal{T}_k}$ are drawn from a hyperprior; data are drawn from a likelihood. This is the canonical hierarchical-Bayes setup of formalStatistics: hierarchical-bayes-and-partial-pooling §28, with groups → tasks, group-level parameters → task-specific parameters, population hyperparameter → meta-parameter.

Two functionals of $\theta_0$ matter. The predictive density on a held-out task’s query given context:

$$p(y_T \mid x_T, D_C, \theta_0) \;=\; \int p(y_T \mid x_T, \theta_{\mathcal{T}})\, p(\theta_{\mathcal{T}} \mid \theta_0, D_C)\, d\theta_{\mathcal{T}}. \qquad (6.2)$$

And the marginal likelihood of one task’s full data:

$$p(D_\mathcal{T} \mid \theta_0) \;=\; \int p(D_\mathcal{T} \mid \theta_\mathcal{T})\, p(\theta_\mathcal{T} \mid \theta_0)\, d\theta_\mathcal{T}. \qquad (6.3)$$

Both are intractable. Each lens makes a different approximation: MAML approximates (6.2) by a point mass at the inner-loop MAP (§6.2); Latent NPs approximate (6.3) by an amortized variational lower bound (Theorem 1); Prototypical Networks approximate (6.2) by a Gaussian classifier in embedding space (§6.4).

6.2 Grant et al.: MAML as MAP with a Gaussian hyperprior

Grant, Finn, Levine, Darrell, and Griffiths (2018) prove that MAML’s inner loop computes, approximately, the MAP estimate of $\theta_\mathcal{T}$ under a Gaussian hyperprior centered at $\theta_0$, with prior precision determined implicitly by $(\alpha, N)$; the approximation is tight in a regime we’ll make precise.

Suppose the hyperprior is $p(\theta_\mathcal{T} \mid \theta_0) = \mathcal{N}(\theta_0, \Sigma_p)$. The MAP estimate is

$$\theta_\mathcal{T}^{\mathrm{MAP}} \;=\; \mathrm{argmin}_\theta \left[-\log p(D^{\mathrm{S}} \mid \theta) + \tfrac{1}{2}(\theta - \theta_0)^\top \Sigma_p^{-1} (\theta - \theta_0)\right]. \qquad (6.4)$$

MAML’s inner loop runs $N$ steps of GD on the support negative log-likelihood alone, starting from $\theta_0$. No regularizer appears explicitly. The connection is that early-stopped GD implicitly regularizes toward its initialization (Santos 1996; Yao, Rosasco, Caponnetto 2007). For a suitable $\Sigma_p$, the MAML iterate $\theta_\mathcal{T}^{(N)}$ equals the MAP estimate from (6.4): exactly when the support loss is quadratic (Theorem 2), and approximately otherwise.

Theorem 2 (Implicit regularization of early-stopped GD; 1D Gaussian case).

Suppose $-\log p(D^{\mathrm{S}} \mid \theta) = \tfrac{\lambda}{2}(\theta - \mu_{\mathrm{MLE}})^2 + \mathrm{const}$ with $\lambda > 0$, the hyperprior is Gaussian, $\theta \sim \mathcal{N}(\theta_0, \sigma_p^2)$, and $\theta_0 \neq \mu_{\mathrm{MLE}}$. Then for $0 < \alpha\lambda < 1$ (so that $(1 - \alpha\lambda)^N \in (0, 1)$ and the precision below is positive), the $N$-th GD iterate equals the MAP estimate (6.4) iff

$$\sigma_p^{-2} \;=\; \lambda \cdot \frac{(1 - \alpha\lambda)^N}{1 - (1 - \alpha\lambda)^N}. \qquad (6.5)$$

In the small-$\alpha\lambda$, early-stopping regime $N\alpha\lambda \ll 1$,

$$\sigma_p^{-2} \;=\; \frac{1}{N\alpha} - \frac{N+1}{2N}\,\lambda + O(\alpha\lambda^2) \;\approx\; \frac{1}{N\alpha}, \qquad (6.6)$$

approximately independent of the data curvature.

Proof.

MAP estimate for 1D Gaussian × Gaussian:

$$\theta_{\mathrm{MAP}} \;=\; \frac{\lambda \mu_{\mathrm{MLE}} + \theta_0/\sigma_p^2}{\lambda + 1/\sigma_p^2}. \qquad (\dagger)$$

GD iterate: the gradient of the quadratic at $\theta^{(n)}$ is $\lambda(\theta^{(n)} - \mu_{\mathrm{MLE}})$, so the recursion $\theta^{(n+1)} - \mu_{\mathrm{MLE}} = (1 - \alpha\lambda)(\theta^{(n)} - \mu_{\mathrm{MLE}})$ closes. After $N$ steps from $\theta_0$:

$$\theta^{(N)} \;=\; \mu_{\mathrm{MLE}} + (1 - \alpha\lambda)^N(\theta_0 - \mu_{\mathrm{MLE}}). \qquad (\ddagger)$$

Subtracting $\mu_{\mathrm{MLE}}$ from $(\dagger)$ gives $\theta_{\mathrm{MAP}} - \mu_{\mathrm{MLE}} = (\theta_0 - \mu_{\mathrm{MLE}})/(1 + \lambda\sigma_p^2)$. Setting this equal to $(\ddagger) - \mu_{\mathrm{MLE}}$ and cancelling $\theta_0 - \mu_{\mathrm{MLE}} \neq 0$,

$$(1 - \alpha\lambda)^N (\lambda \sigma_p^2 + 1) \;=\; 1,$$

giving (6.5). For (6.6), write $x = \alpha\lambda$ and expand: $(1 - x)^N = 1 - Nx + O(x^2)$ and $1 - (1 - x)^N = Nx\big(1 - \tfrac{N-1}{2}x + O(x^2)\big)$. Then $\frac{(1-x)^N}{1-(1-x)^N} = \frac{1}{Nx}\big(1 - Nx + \tfrac{N-1}{2}x + O(x^2)\big) = \frac{1}{Nx} - \frac{N+1}{2N} + O(x)$, and multiplying by $\lambda$ gives $\sigma_p^{-2} = 1/(N\alpha) - \tfrac{N+1}{2N}\lambda + O(\alpha\lambda^2)$. The leading term dominates when $N\alpha\lambda \ll 1$; for $N = 1$ the correction is exactly $-\lambda$.

The multivariate generalization: diagonalize the support Hessian $H = \nabla^2[-\log p(D^{\mathrm{S}} \mid \theta_0)]$ and apply the 1D result in each eigendirection. The implicit prior precision matrix has eigenvalues $\lambda_i (1 - \alpha\lambda_i)^N / (1 - (1-\alpha\lambda_i)^N)$, approximately $(N\alpha)^{-1} I - \tfrac{N+1}{2N} H$ in the early-stopping regime.
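A numerical check of Theorem 2 in plain Python: run $N$ GD steps on the 1D quadratic, compute the implied precision from (6.5), and compare the GD iterate with the MAP estimate $(\dagger)$. The specific numbers are arbitrary choices for illustration, and the helper names are hypothetical.

```python
def gd_iterate(theta0, mu_mle, lam, alpha, n_steps):
    # N steps of gradient descent on (lam / 2) * (theta - mu_mle)^2, starting at theta0
    theta = theta0
    for _ in range(n_steps):
        theta = theta - alpha * lam * (theta - mu_mle)
    return theta

def implied_prior_precision(lam, alpha, n_steps):
    # (6.5): sigma_p^{-2} = lam * (1 - alpha*lam)^N / (1 - (1 - alpha*lam)^N)
    r = (1.0 - alpha * lam) ** n_steps
    return lam * r / (1.0 - r)

def map_estimate(theta0, mu_mle, lam, prior_precision):
    # (dagger): precision-weighted average of the MLE and the prior mean theta0
    return (lam * mu_mle + prior_precision * theta0) / (lam + prior_precision)

theta0, mu_mle, lam, alpha, N = 1.5, -0.7, 2.0, 0.01, 5   # N * alpha * lam = 0.1
prec = implied_prior_precision(lam, alpha, N)
print("GD iterate:            ", gd_iterate(theta0, mu_mle, lam, alpha, N))
print("MAP with implied prior:", map_estimate(theta0, mu_mle, lam, prec))
print("implied precision:", prec, "  1/(N alpha):", 1.0 / (N * alpha))
```

The first two printed values agree to floating-point precision, and at $N\alpha\lambda = 0.1$ the implied precision lands within about 6% of $1/(N\alpha) = 20$; shrinking $\alpha$ closes the gap further.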

The reading. MAML’s bilevel optimization meta-learns the mean of a Gaussian hyperprior; the precision is fixed implicitly by $(\alpha, N)$. The outer loop optimizes $\theta_0$ to maximize an approximate marginal likelihood across tasks. The inner-loop output $\theta_\mathcal{T}^{(N)}$ is a MAP estimate, and the MAML predictive is the plug-in predictive at that MAP: a point-mass posterior approximation, the coarsest of the three lens approximations.

Theorem 2 is exact when the support negative log-likelihood is quadratic. For nonlinear models, it holds locally around $\theta_0$ via a second-order Taylor expansion. The approximation is tight when the inner trajectory stays close to $\theta_0$, which is the regime that small $\alpha$ buys us.

6.3 Empirical Bayes and prior estimation from the task distribution

Empirical Bayes (Robbins 1956; Efron and Morris 1973) estimates the prior from data. For meta-learning, that means estimating $\theta_0$ (and implicitly $\Sigma_p$ via $\alpha, N$) from the task distribution. The Type-II MLE objective:

$$\theta_0^{\mathrm{EB}} \;=\; \mathrm{argmax}_{\theta_0}\, \sum_{\mathcal{T}_k \in \mathrm{batch}} \log \int p(D_{\mathcal{T}_k} \mid \theta_{\mathcal{T}_k})\, p(\theta_{\mathcal{T}_k} \mid \theta_0)\, d\theta_{\mathcal{T}_k}. \qquad (6.7)$$

The integral inside the log is intractable. Each lens approximates differently.

MAML approximates by Laplace, truncated at order zero (the MAP point without the Hessian-determinant correction): $\log p(D^{\mathrm{Q}} \mid D^{\mathrm{S}}, \theta_0) \approx \log p(D^{\mathrm{Q}} \mid \theta_\mathcal{T}^{\mathrm{MAP}})$.

Latent NPs approximate by amortized VI: the ELBO of Theorem 1 is a true lower bound on the log-marginal-likelihood (under the model’s implied predictive). Outer loop is SGD on the sum of per-task ELBOs.

Prototypical Networks have a less standard interpretation: the model is a per-class Gaussian in embedding space, and the meta-parameter is the embedding network. The objective is a per-task discriminative likelihood, closer to discriminative training than to the generative formulation of (6.7).

6.4 The three lenses as posterior approximations of varying fidelity

| Lens | Adaptation procedure | Within-task posterior approx. | Hyperprior structure |
|---|---|---|---|
| MAML ($N$-step) | $N$-step SGD from $\theta_0$ on support | Point mass at $\theta_\mathcal{T}^{(N)} \approx \theta_\mathcal{T}^{\mathrm{MAP}}$ | Gaussian, mean $\theta_0$, precision $\approx 1/(N\alpha)$ |
| FOMAML | Same inner loop, drop second-order terms | Same as MAML | Same as MAML |
| Reptile | $N$-step SGD on combined data | Implicit through inner-loop convergence | Implicit through gradient alignment (3.5) |
| Implicit MAML | Solve regularized fixed point | Point mass at exact MAP | Gaussian, mean $\theta_0$, precision $\lambda$ |
| Latent NP | Amortized $q(z \mid D_C)$ | Gaussian in global latent space | Implicit through encoder $q_\phi$ |
| Prototypical | Class-mean prototypes | Gaussian class-conditional (implicit) | Implicit through embedding $\phi_{\theta_0}$ |

Four readings:

  • Fidelity of the within-task posterior: point mass (MAML) → Gaussian variational (Latent NP) → closed-form per-class Gaussian (Prototypical).
  • Adaptation cost runs the opposite way: MAML pays $N$ SGD steps plus bilevel overhead, Latent NP pays one encoder forward pass, Prototypical pays a class-wise mean.
  • Hyperprior representation: implicit via $(\alpha, N)$ for MAML, explicit via $\lambda$ for Implicit MAML, and even more implicit (encoder or embedding output) for Latent NP and Prototypical.
  • Objective: Laplace-truncated marginal likelihood (MAML), ELBO (Latent NP), discriminative cross-entropy (Prototypical).

The unification is conceptual, not algorithmic. The three lenses are not interchangeable; choosing involves real trade-offs. What the unification delivers is a single language — hierarchical Bayes — in which we can state generalization results that apply to all three. §7’s PAC-Bayes meta-bound is the major payoff.


§7. PAC-Bayes generalization theory for meta-learning

7.1 Why single-task PAC-Bayes does not apply directly

Standard PAC-Bayes (the McAllester bound from formalML’s PAC-Bayes Bounds Theorem 1) gives, for any prior PP over hypotheses and any data-dependent posterior QQ, a high-probability bound on the gap between true and empirical risk on a single iid sample:

$$R(Q) \;\leq\; \widehat{R}(Q) + \sqrt{\frac{\mathrm{KL}(Q \,\|\, P) + \log(2\sqrt{n}/\delta)}{2n}}. \qquad (7.1)$$

This is a single-task bound. The natural temptation in meta-learning is to treat the combined dataset ($T$ tasks, $n$ samples each, $nT$ total) as one big iid sample and apply (7.1). That fails for two reasons.

First, the pooled samples are not iid: each within-task sample comes from a task-specific distribution $\mathcal{D}_\mathcal{T}$, and different tasks have different distributions. Second, the right notion of “true risk” for a meta-learner is risk on a new task, not risk on observed samples. The result is two distinct generalization gaps:

  • Within-task gap: empirical risk on $n$ samples of one task vs. true risk on that task.
  • Across-task gap: average true risk over $T$ observed tasks vs. expected true risk over a new task.

Standard PAC-Bayes controls a single gap; a meta-PAC-Bayes bound must control both. Amit and Meir (2018) chain two concentration arguments — one at each scale — and combine through a hierarchical KL.

7.2 The Amit–Meir meta-bound: statement

We have a hyperprior $P_0$ (a fixed distribution over priors $P$) and a hyperposterior $\mathcal{Q}$ (a data-dependent distribution over priors). For each prior $P$ and each task $\mathcal{T}$ with data $S_\mathcal{T}$, the meta-learner produces a within-task posterior $Q_\mathcal{T}^P$.

Framework bridge. The §6 hierarchical-Bayes setup treated $\theta_0$ as a deterministic meta-parameter parameterizing a single hyperprior $p(\theta_\mathcal{T} \mid \theta_0)$. The §7 framework generalizes this: $\mathcal{Q}$ is a distribution over priors $P$, and the deterministic-$\theta_0$ case of §6 corresponds to $\mathcal{Q}$ being a Dirac mass on the specific $P$ induced by $\theta_0$ via §6.2 (e.g., $P = \mathcal{N}(\theta_0, \Sigma_p)$ for that $\theta_0$). The distribution-over-priors framing is what the PAC-Bayes machinery needs to deliver clean bounds: Bayesian meta-learning algorithms (Latent NPs) naturally produce such distributions, while gradient-based methods (MAML) collapse to a Dirac mass and require the Dziugaite–Roy softening discussed in §7.5.

Four risks. Within-task true and empirical:

$$\mathcal{L}_\mathcal{T}(P) = \mathbb{E}_{(\mathbf{x}, y) \sim \mathcal{D}_\mathcal{T}}\, \mathbb{E}_{h \sim Q_\mathcal{T}^P}[\ell(h(\mathbf{x}), y)], \quad \widehat{\mathcal{L}}_\mathcal{T}(P) = \frac{1}{n}\sum_{(\mathbf{x}_i, y_i) \in S_\mathcal{T}} \mathbb{E}_{h \sim Q_\mathcal{T}^P}[\ell(h(\mathbf{x}_i), y_i)].$$

Meta-risks:

$$R_{\mathrm{meta}}(\mathcal{Q}) = \mathbb{E}_{P \sim \mathcal{Q}}\, \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\, \mathcal{L}_\mathcal{T}(P), \qquad \widehat{R}_{\mathrm{meta}}(\mathcal{Q}) = \mathbb{E}_{P \sim \mathcal{Q}}\, \frac{1}{T}\sum_{k=1}^{T}\, \widehat{\mathcal{L}}_{\mathcal{T}_k}(P).$$

The loss $\ell$ is bounded in $[0, 1]$.

Theorem 3 (Amit–Meir meta PAC-Bayes bound).

With probability at least $1 - \delta$ over $T$ tasks $\mathcal{T}_1, \ldots, \mathcal{T}_T \sim p(\mathcal{T})$ with within-task iid samples of size $n$:

$$R_{\mathrm{meta}}(\mathcal{Q}) \;\leq\; \widehat{R}_{\mathrm{meta}}(\mathcal{Q}) \;+\; \sqrt{\frac{\mathrm{KL}(\mathcal{Q} \,\|\, P_0) + \log(4T/\delta)}{2(T-1)}} \;+\; \mathbb{E}_{P \sim \mathcal{Q}}\, \frac{1}{T} \sum_{k=1}^{T} \sqrt{\frac{\mathrm{KL}(Q_{\mathcal{T}_k}^P \,\|\, P) + \log(4Tn/\delta)}{2(n-1)}}. \qquad (7.2)$$

Two square-root concentration terms appear. The across-task term has rate $1/\sqrt{T}$ and depends on $\mathrm{KL}(\mathcal{Q} \| P_0)$. The within-task term has rate $1/\sqrt{n}$ and is averaged over the $T$ observed tasks. Large $T$, small $n$ (few-shot) → within-task term dominates. Small $T$, large $n$ (transfer learning) → across-task term dominates.
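A NumPy sketch that evaluates the right-hand side of (7.2) from given complexity terms. The expectation over $P \sim \mathcal{Q}$ is estimated by averaging over $S$ sampled priors, so the within-task KLs are passed as an $S \times T$ array; how those KLs are computed is left open, `amit_meir_bound` is a hypothetical helper, and the KL values in the example are made up for illustration.

```python
import numpy as np

def amit_meir_bound(emp_meta_risk, kl_hyper, kl_within, n, delta=0.05):
    """Evaluate the RHS of (7.2).

    emp_meta_risk: empirical meta-risk hat R_meta(Q), a number in [0, 1]
    kl_hyper:      KL(Q || P_0)
    kl_within:     array (S, T); entry [s, k] = KL(Q_{T_k}^{P_s} || P_s) for sampled P_s ~ Q
    n:             samples per task
    Returns (bound, across_task_term, within_task_term).
    """
    kl_within = np.asarray(kl_within, dtype=float)
    T = kl_within.shape[1]
    across = np.sqrt((kl_hyper + np.log(4 * T / delta)) / (2 * (T - 1)))
    # Square root per (prior sample, task), then average over tasks and sampled priors
    within = np.mean(np.sqrt((kl_within + np.log(4 * T * n / delta)) / (2 * (n - 1))))
    return emp_meta_risk + across + within, across, within

# Illustrative numbers only: one sampled prior, KL = 5 nats per task, KL(Q || P_0) = 20 nats.
for T in (10, 100, 1000):
    for n in (5, 100):
        bound, across, within = amit_meir_bound(0.1, 20.0, np.full((1, T), 5.0), n)
        print(f"T={T:5d} n={n:4d}  across={across:.3f}  within={within:.3f}  bound={bound:.3f}")
```

Bounds above 1 are vacuous for a loss in $[0, 1]$; sweeping $T$ and $n$ this way shows which of the two square-root terms is the bottleneck for a given regime.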

7.3 The double-sided concentration argument

The proof chains two applications of single-task PAC-Bayes.

Piece A (within-task concentration): McAllester (7.1) applied to within-task data gives high-probability control of $\mathcal{L}_\mathcal{T}(P) - \widehat{\mathcal{L}}_\mathcal{T}(P)$ in terms of $\mathrm{KL}(Q_\mathcal{T}^P \| P)$ and $n$.

Piece B (across-task concentration): treat $\mathcal{L}_\mathcal{T}(P)$ as a $[0,1]$-bounded function of $\mathcal{T}$. PAC-Bayes applied at the meta level, with $P_0$ as prior, $\mathcal{Q}$ as posterior, and the $T$ observed tasks as samples, gives control of $\mathbb{E}_{P} \mathbb{E}_\mathcal{T} \mathcal{L}_\mathcal{T}(P) - \mathbb{E}_{P} (1/T) \sum_k \mathcal{L}_{\mathcal{T}_k}(P)$ in terms of $\mathrm{KL}(\mathcal{Q} \| P_0)$ and $T$.

Piece C (union bound): each piece holds with its own probability of failure; allocate $\delta/2$ to each.

The subtle move is Piece B: applying PAC-Bayes at the meta level, where “samples” are tasks and “hypotheses” are within-task priors. The McAllester machinery is fully general: it bounds the deviation of any $[0,1]$-valued statistic under a posterior from its expectation under iid samples.

7.4 Proof of the Amit–Meir bound

Proof.

Step 1 (within-task, Piece A). Apply McAllester to each $\mathcal{T}_k$ with confidence $\delta_1/T$; union-bound across the $T$ tasks. With probability $\ge 1 - \delta_1$, for every $\mathcal{T}_k$ and $P$,

$$\mathcal{L}_{\mathcal{T}_k}(P) - \widehat{\mathcal{L}}_{\mathcal{T}_k}(P) \;\leq\; \sqrt{\frac{\mathrm{KL}(Q_{\mathcal{T}_k}^P \,\|\, P) + \log(2T\sqrt{n}/\delta_1)}{2(n-1)}}. \qquad (\star)$$

Taking expectation over $P \sim \mathcal{Q}$ and averaging over tasks,

$$\mathbb{E}_{P \sim \mathcal{Q}}\, \frac{1}{T}\sum_k \mathcal{L}_{\mathcal{T}_k}(P) \;\leq\; \widehat{R}_{\mathrm{meta}}(\mathcal{Q}) \;+\; \mathbb{E}_{P \sim \mathcal{Q}}\, \frac{1}{T}\sum_k \sqrt{\frac{\mathrm{KL}(Q_{\mathcal{T}_k}^P \,\|\, P) + \log(2T\sqrt{n}/\delta_1)}{2(n-1)}}. \qquad (\star\star)$$

Step 2 (across-task, Piece B). Treat $\mathcal{L}_\mathcal{T}(P) \in [0, 1]$ as a statistic on tasks. McAllester at the meta level, with $P_0$ as prior, $\mathcal{Q}$ as posterior, and $T$ iid task samples:

$$\mathbb{E}_{P \sim \mathcal{Q}}\, \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\, \mathcal{L}_\mathcal{T}(P) \;\leq\; \mathbb{E}_{P \sim \mathcal{Q}}\, \frac{1}{T}\sum_k \mathcal{L}_{\mathcal{T}_k}(P) \;+\; \sqrt{\frac{\mathrm{KL}(\mathcal{Q} \,\|\, P_0) + \log(2\sqrt{T}/\delta_2)}{2(T-1)}}. \qquad (\star\star\star)$$

The LHS is $R_{\mathrm{meta}}(\mathcal{Q})$; the first RHS term is what $(\star\star)$ controls.

Step 3 (chain and union). Allocate $\delta_1 = \delta_2 = \delta/2$. Both $(\star\star)$ and $(\star\star\star)$ hold with probability $\ge 1 - \delta$. Substitute $(\star\star)$ into $(\star\star\star)$. With $\log(4\sqrt{T}/\delta) \le \log(4T/\delta)$ and $\log(4T\sqrt{n}/\delta) \le \log(4Tn/\delta)$, recover (7.2).

The $T-1$ and $n-1$ denominators come from the Maurer–Pontil constant refinement (pac-bayes-bounds §7.1); for $T, n \ge 10$ the difference from using $T, n$ instead is at the second decimal place.

7.5 Tightness, computability, and the Pentina–Lampert variant

The Pentina–Lampert variant. Pentina and Lampert (2014) give a simpler bound when the meta-learner picks a single shared prior $P^*$ across tasks, rather than a distribution $\mathcal{Q}$. The double-sided structure is the same, but the across-task KL term simplifies. Pentina–Lampert is the natural bound for most practical meta-learning algorithms (MAML, Reptile, prototypical networks); Amit–Meir is the right tool when the algorithm explicitly meta-learns a distribution over priors (Latent NPs).

Tightness for MAML and the point-mass posterior issue. MAML’s within-task “posterior” is a point mass at $\theta_\mathcal{T}^{(N)}$. The KL between a point mass and a continuous prior is infinite, so (7.2) is vacuous. The fix is a soft posterior centered at $\theta_\mathcal{T}^{(N)}$, such as a small-variance Gaussian or a Laplace approximation. Dziugaite and Roy (2017) developed this for single-task PAC-Bayes; Amit–Meir apply it to the meta-setting. The resulting bounds are non-vacuous for moderate task counts.

Computability. For Gaussian-Gaussian KL the closed form is

$$\mathrm{KL}\!\big(\mathcal{N}(\mu_1, \Sigma_1) \,\|\, \mathcal{N}(\mu_2, \Sigma_2)\big) \;=\; \tfrac{1}{2}\!\left[\mathrm{tr}(\Sigma_2^{-1} \Sigma_1) + (\mu_1 - \mu_2)^\top \Sigma_2^{-1}(\mu_1 - \mu_2) - D + \log \frac{\det \Sigma_2}{\det \Sigma_1}\right],$$

where $D$ is the parameter dimension; this is tractable for diagonal $\Sigma_i$. For MAML with Dziugaite–Roy softening, the KL involves $\|\theta_\mathcal{T}^{(N)} - \theta_0\|^2$ scaled by $1/(N\alpha)$, the “distance from initialization” quantity the meta-learning literature tracks empirically.
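A NumPy sketch of the closed form for diagonal covariances, plus one hypothetical way to soften a MAML point estimate: an isotropic Gaussian posterior of variance `post_var` around $\theta_\mathcal{T}^{(N)}$, scored against the implicit prior $\mathcal{N}(\theta_0, N\alpha I)$ from the early-stopping approximation (6.6). The variance choice is an assumption for illustration, not the Dziugaite–Roy procedure itself, and both helper names are hypothetical.

```python
import numpy as np

def kl_diag_gaussians(mu1, var1, mu2, var2):
    # KL(N(mu1, diag(var1)) || N(mu2, diag(var2))): trace + Mahalanobis - D + log-det ratio
    mu1, var1, mu2, var2 = (np.asarray(a, dtype=float) for a in (mu1, var1, mu2, var2))
    return 0.5 * np.sum(var1 / var2 + (mu1 - mu2) ** 2 / var2 - 1.0 + np.log(var2 / var1))

def softened_maml_kl(theta_adapted, theta0, alpha, n_steps, post_var):
    # Hypothetical softening: posterior N(theta_adapted, post_var * I) against the
    # implicit MAML prior N(theta0, N * alpha * I).
    D = len(theta0)
    return kl_diag_gaussians(theta_adapted, np.full(D, post_var),
                             theta0, np.full(D, n_steps * alpha))

theta0 = np.zeros(10)
for shift in (0.01, 0.1):
    theta_N = theta0 + shift   # adapted parameters at growing distance from the initialization
    print(shift, softened_maml_kl(theta_N, theta0, alpha=0.01, n_steps=5, post_var=1e-3))
```

The Mahalanobis term grows as $\|\theta_\mathcal{T}^{(N)} - \theta_0\|^2/(N\alpha)$, so adaptation that travels far from the initialization pays directly in the within-task complexity term of (7.2).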

Two panels: total Amit-Meir bound vs T at three n values; across-task vs within-task decomposition at fixed n=25.
Figure 12. Amit–Meir bound (7.2) visualized. Left: total bound vs T for n ∈ {5, 25, 100}. Right: across-task and within-task term decomposition at n=25. Crossover around T≈60 for these KL values.

§8. Convergence guarantees for MAML

8.1 The smoothness, bounded variance, and Lipschitz Hessian assumptions

§7 told us when MAML generalizes. §8 tells us when MAML’s outer loop converges. The two questions are independent. Convergence theory tackles: under what conditions does meta-gradient descent in (2.3) reach a stationary point of the meta-objective (2.2), and at what rate.

Fallah, Mokhtari, and Ozdaglar (2020) give the first complete analysis of stochastic MAML under non-convex inner loss. The setup uses $N=1$ inner step; the multi-step extension is technically more involved but qualitatively the same. The meta-objective:

$$F(\theta_0) \;=\; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\!\left[ L_\mathcal{T}^{\mathrm{Q}}\!\big(\theta_0 - \alpha \nabla L_\mathcal{T}^{\mathrm{S}}(\theta_0)\big) \right], \qquad (8.1)$$

and the stochastic meta-gradient at $\theta_k$:

$$\tilde{g}_k \;=\; \big(I - \alpha \widetilde{H}_k^{\mathrm{S}}\big)\, \widetilde{g}_k^{\mathrm{Q}}, \qquad (8.2)$$

where $\widetilde{H}_k^{\mathrm{S}}$ is a stochastic Hessian estimate at $\theta_k$ and $\widetilde{g}_k^{\mathrm{Q}}$ is a stochastic gradient at the adapted point. Outer update: $\theta_{k+1} = \theta_k - \beta \tilde{g}_k$.

Definition 1 (Standing assumptions (FMO 2020)).

Assumption 8.1 (Lipschitz gradient). $\lVert\nabla L_\mathcal{T}(\theta) - \nabla L_\mathcal{T}(\theta')\rVert \le L \lVert\theta - \theta'\rVert$.

Assumption 8.2 (Lipschitz Hessian). $\lVert\nabla^2 L_\mathcal{T}(\theta) - \nabla^2 L_\mathcal{T}(\theta')\rVert \le \rho \lVert\theta - \theta'\rVert$.

Assumption 8.3 (Bounded gradient variance). $\mathbb{E}\lVert\widetilde{g} - \nabla L_\mathcal{T}\rVert^2 \le \sigma_g^2$.

Assumption 8.4 (Bounded Hessian variance). $\mathbb{E}\lVert\widetilde{H} - \nabla^2 L_\mathcal{T}\rVert_{\mathrm{op}}^2 \le \sigma_H^2$.

Plus a structural assumption: at each meta-iteration, draw three independent batches per task, one each for $\widetilde{g}^{\mathrm{S}}$, $\widetilde{H}^{\mathrm{S}}$, and $\widetilde{g}^{\mathrm{Q}}$. Independence breaks correlations that would otherwise bias the meta-gradient. Without it, the analysis acquires a bias term $\alpha L \sigma_g^2$ that doesn’t shrink with meta-iterations and prevents convergence to the true stationary point.
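The three-batch structure is easiest to see on a model where every factor in (8.2) is explicit. A NumPy sketch, assuming a least-squares task loss $\tfrac12\,\mathrm{mean}((X\theta - y)^2)$, whose Hessian is $X^\top X / n$; the task generator, batch sizes, and step sizes are illustrative assumptions:

```python
import numpy as np

def ls_grad(theta, X, y):
    """Gradient of the least-squares loss 0.5 * mean((X @ theta - y)**2)."""
    return X.T @ (X @ theta - y) / len(y)

def ls_hess(X):
    """Hessian of the same loss (constant in theta for least squares)."""
    return X.T @ X / X.shape[0]

def maml_meta_grad(theta, batch_gS, batch_HS, batch_gQ, alpha):
    """One-step MAML meta-gradient (8.2) from three (ideally independent) batches.

    batch_gS -> support gradient used to form the adapted point
    batch_HS -> support Hessian estimate
    batch_gQ -> query gradient evaluated at the adapted point
    """
    XgS, ygS = batch_gS
    XHS, _ = batch_HS
    XgQ, ygQ = batch_gQ
    theta_adapted = theta - alpha * ls_grad(theta, XgS, ygS)  # inner SGD step
    g_Q = ls_grad(theta_adapted, XgQ, ygQ)                      # query gradient at adapted point
    H_S = ls_hess(XHS)                                          # support Hessian estimate
    return (np.eye(len(theta)) - alpha * H_S) @ g_Q             # (I - alpha H_S) g_Q

# Hypothetical task: y = X @ w_task + noise, with fresh batches drawn per role.
rng = np.random.default_rng(0)
d, alpha, beta = 3, 0.1, 0.05
w_task = rng.normal(size=d)

def sample_batch(n):
    X = rng.normal(size=(n, d))
    return X, X @ w_task + 0.1 * rng.normal(size=n)

theta = np.zeros(d)
g = maml_meta_grad(theta, sample_batch(10), sample_batch(10), sample_batch(10), alpha)
theta = theta - beta * g  # outer update theta_{k+1} = theta_k - beta * g_k
```

Passing the same full support set as both `batch_gS` and `batch_HS` (and the full query set as `batch_gQ`) makes the function return the exact gradient of the empirical meta-objective, which a finite-difference check can confirm; three separate `sample_batch` calls give each factor independent noise, as the structural assumption requires.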

8.2 The Fallah–Mokhtari–Ozdaglar theorem

Theorem 4 (FMO 2020, MAML convergence).

Under Assumptions 8.1–8.4, $\alpha \le 1/(2L)$, and $\beta = 1/(L_F \sqrt{K})$, where $L_F$ is the meta-objective smoothness constant (Lemma 1 below) and $K$ is the total number of meta-iterations,

$$\min_{0 \le k < K}\, \mathbb{E}\!\left[\lVert\nabla F(\theta_k)\rVert^2\right] \;\le\; \mathcal{O}\!\left(\frac{1}{\sqrt{K}}\right). \qquad (8.3)$$

Constants depend polynomially on $L, \rho, \sigma_g, \sigma_H, \alpha$, and the meta-loss gap $F(\theta_0) - F^*$.

This is the standard non-convex SGD rate: $K = \mathcal{O}(1/\varepsilon^2)$ iterations suffice to reach $\min_k \mathbb{E}\lVert\nabla F(\theta_k)\rVert^2 \le \varepsilon$. What is new and non-trivial is the bilevel setup, whose constants need careful tracking.

8.3 The stochastic-gradient analysis with second-order terms

Lemma 1 (Meta-objective smoothness).

Under Assumptions 8.1–8.2 and $\alpha \le 1/L$, $F$ has an $L_F$-Lipschitz gradient with

$$L_F \;=\; (1 + \alpha L)^2 L + \alpha\, \rho\, G_Q, \qquad (8.4)$$

where $G_Q$ is any uniform bound on $\lVert\nabla L_\mathcal{T}^{\mathrm{Q}}\rVert$.

The $(1+\alpha L)^2 L$ piece is the query-gradient smoothness $L$ amplified twice by the inner map’s Jacobian $I - \alpha \nabla^2 L^{\mathrm{S}}$, whose operator norm is at most $1 + \alpha L$: once through the adapted point and once through the chain-rule prefactor. The $\alpha \rho G_Q$ piece is new: Assumption 8.2 says $\nabla^2 L^{\mathrm{S}}$ changes at rate $\rho$ in $\theta_0$, and the meta-gradient inherits this through the prefactor. Small $\alpha$ keeps $L_F$ small.

Lemma 2 (Meta-gradient variance).

Under Assumptions 8.1–8.4 and independent batches, (8.2) is unbiased for $\nabla F(\theta_k)$ and

$$\mathbb{E}\!\left[\lVert\tilde{g}_k - \nabla F(\theta_k)\rVert^2\right] \;\le\; \sigma_g^2 + \alpha^2 \sigma_H^2 G_Q^2 + \alpha^2 \sigma_g^2 \sigma_H^2. \qquad (8.5)$$

Three terms: base SGD variance, Hessian-driven variance, and a cross-term. The last two are $\mathcal{O}(\alpha^2)$ corrections, another reason small $\alpha$ helps. If the same batch is used for $\widetilde{g}^{\mathrm{S}}$ and $\widetilde{g}^{\mathrm{Q}}$, a bias of order $\alpha L \sigma_g^2$ contaminates the meta-gradient; it is a fixed bias that does not shrink with $K$.

8.4 Proof of Theorem 4

Proof.

By smoothness of FF (Lemma 1):

$$F(\theta_{k+1}) \;\le\; F(\theta_k) - \beta \langle \nabla F(\theta_k), \tilde{g}_k\rangle + \frac{L_F \beta^2}{2}\lVert\tilde{g}_k\rVert^2. \qquad (\dagger)$$

Take conditional expectation given the history $\mathcal{F}_k$ up to $\theta_k$. By unbiasedness (Lemma 2), $\mathbb{E}[\langle \nabla F(\theta_k), \tilde{g}_k\rangle \mid \mathcal{F}_k] = \lVert\nabla F(\theta_k)\rVert^2$. Expand $\lVert\tilde{g}_k\rVert^2 = \lVert\nabla F(\theta_k)\rVert^2 + \lVert\tilde{g}_k - \nabla F(\theta_k)\rVert^2 + 2\langle \nabla F(\theta_k), \tilde{g}_k - \nabla F(\theta_k)\rangle$; the cross term vanishes in conditional expectation. Write $V_k := \mathbb{E}[\lVert\tilde{g}_k - \nabla F(\theta_k)\rVert^2 \mid \mathcal{F}_k]$, which is at most $V$, the right-hand side of (8.5):

$$\mathbb{E}[F(\theta_{k+1}) \mid \mathcal{F}_k] \;\le\; F(\theta_k) - \beta \lVert\nabla F(\theta_k)\rVert^2 + \frac{L_F \beta^2}{2}\!\left(\lVert\nabla F(\theta_k)\rVert^2 + V_k\right).$$

Rearrange:

$$\beta\!\left(1 - \frac{L_F \beta}{2}\right) \lVert\nabla F(\theta_k)\rVert^2 \;\le\; F(\theta_k) - \mathbb{E}[F(\theta_{k+1}) \mid \mathcal{F}_k] + \frac{L_F \beta^2}{2} V_k.$$

Choose $\beta \le 1/L_F$ so the coefficient on the left is at least $\beta/2$. Take total expectations, bound $V_k \le V$, and sum from $k=0$ to $K-1$; the $F$ terms telescope, and $\mathbb{E}[F(\theta_K)] \ge F^*$:

$$\frac{\beta}{2} \sum_{k=0}^{K-1} \mathbb{E}\!\left[\lVert\nabla F(\theta_k)\rVert^2\right] \;\le\; F(\theta_0) - F^* + \frac{L_F \beta^2}{2} K V.$$

Divide by Kβ/2K\beta/2 and bound the minimum by the average:

$$\min_{k} \mathbb{E}\!\left[\lVert\nabla F(\theta_k)\rVert^2\right] \;\le\; \frac{2(F(\theta_0) - F^*)}{K \beta} + L_F \beta V.$$

Choose $\beta = 1/(L_F \sqrt{K})$, which satisfies $\beta \le 1/L_F$, to balance the two terms; both become $\mathcal{O}(1/\sqrt{K})$:

$$\min_{k} \mathbb{E}\!\left[\lVert\nabla F(\theta_k)\rVert^2\right] \;\le\; \frac{1}{\sqrt{K}}\!\left[2 L_F (F(\theta_0) - F^*) + V\right].$$

Proof of Lemma 1. Write $\nabla F(\theta) - \nabla F(\theta') = \mathbb{E}_\mathcal{T}[(\text{A}) + (\text{B})]$, where (A) is the change in $\nabla L^{\mathrm{Q}}$ at the adapted point and (B) is the change in the inner Jacobian. For (A): the adapted point is $(1 + \alpha L)$-Lipschitz in $\theta$ and the Jacobian factor has operator norm at most $1 + \alpha L$, so $\lVert(\text{A})\rVert \le (1 + \alpha L)^2 L \lVert\theta - \theta'\rVert$. For (B): the inner Jacobian changes at rate $\alpha \rho$ and $\nabla L^{\mathrm{Q}}$ is bounded by $G_Q$, so $\lVert(\text{B})\rVert \le \alpha \rho G_Q \lVert\theta - \theta'\rVert$. Adding gives (8.4).


Proof of Lemma 2. Unbiasedness: by independence of the three batches, $\mathbb{E}[\widetilde{H}^{\mathrm{S}} \widetilde{g}^{\mathrm{Q}}] = \nabla^2 L^{\mathrm{S}}(\theta_k)\, \nabla L^{\mathrm{Q}}(\theta^{(1)}_k)$ (the adapted point depends on $\widetilde{g}^{\mathrm{S}}$ but not on the other two). Variance: decompose $\widetilde{H}^{\mathrm{S}} = \nabla^2 L^{\mathrm{S}} + \widetilde{\Delta}_H$ and $\widetilde{g}^{\mathrm{Q}} = \nabla L^{\mathrm{Q}} + \widetilde{\Delta}_g$, expand the product, and apply $\mathrm{Var}(XY) \le \mathrm{Var}(X)\mathbb{E}[Y]^2 + \mathbb{E}[X]^2\mathrm{Var}(Y) + \mathrm{Var}(X)\mathrm{Var}(Y)$ termwise.

8.5 Rate dependence on the inner-loop step size and inner-loop count

Inner-loop step size α\alpha. LFL_F grows polynomially in α\alpha (8.4); variance (8.5) grows quadratically in α\alpha for the second-order pieces. Large α\alpha is bad for convergence. α=0\alpha = 0 degenerates the meta-objective (no inner adaptation). Practical α\alpha is small but positive.

Inner-loop step count $N$. Multi-step MAML inherits an $L_F$ scaling as $(1 + \alpha L)^{2N}$, exponential in $N$. $N = 5$ in §2 is the practical sweet spot; beyond $N \approx 10$, the smoothness constant becomes large enough to require much smaller outer step sizes.
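The dependence on $\alpha$ and $N$ can be read off numerically. A minimal sketch that evaluates (8.4), (8.5), the step size $\beta = 1/(L_F\sqrt{K})$, and the final bound from the proof of Theorem 4; `multistep_growth` only tracks the $(1 + \alpha L)^{2N}$ factor rather than a full multi-step $L_F$, and every numeric constant is an illustrative assumption:

```python
import math

def L_F(alpha, L, rho, G_Q):
    """Meta-objective smoothness constant (8.4), one inner step."""
    return (1 + alpha * L) ** 2 * L + alpha * rho * G_Q

def V_bound(alpha, sigma_g, sigma_H, G_Q):
    """Meta-gradient variance bound (8.5)."""
    return sigma_g**2 + alpha**2 * sigma_H**2 * G_Q**2 + alpha**2 * sigma_g**2 * sigma_H**2

def fmo_bound(K, alpha, L, rho, G_Q, sigma_g, sigma_H, gap):
    """Step size beta = 1/(L_F sqrt K) and the bound (2 L_F gap + V) / sqrt K."""
    LF = L_F(alpha, L, rho, G_Q)
    V = V_bound(alpha, sigma_g, sigma_H, G_Q)
    beta = 1.0 / (LF * math.sqrt(K))
    bound = (2 * LF * gap + V) / math.sqrt(K)
    return beta, bound

def multistep_growth(alpha, L, N):
    """The (1 + alpha L)^(2N) factor by which multi-step MAML's L_F scales."""
    return (1 + alpha * L) ** (2 * N)

# Hypothetical constants, chosen only to show the trends.
consts = dict(L=1.0, rho=1.0, G_Q=1.0, sigma_g=1.0, sigma_H=1.0, gap=1.0)
for alpha in (0.0, 0.01, 0.1, 0.5):
    beta, bound = fmo_bound(K=10_000, alpha=alpha, **consts)
    print(f"alpha={alpha:<5} L_F={L_F(alpha, 1.0, 1.0, 1.0):.3f} beta={beta:.2e} bound={bound:.3f}")
for N in (1, 5, 10, 20):
    print(f"N={N:<3} growth={multistep_growth(0.1, 1.0, N):.2f}")
```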

Mini-batch size effects. Variance terms scale inversely with the mini-batch size $b$. FMO Theorem 4.5 gives a rate of $\mathcal{O}(1/\sqrt{Kb})$ for $K$ iterations with batch size $b$, so a $b$-fold larger batch buys a $\sqrt{b}$ improvement.

Figure 13. Numerical validation of the 1/√K rate. Left: §2's meta-loss curve on log axes. Right: x-axis transformed to 1/√K with a C/√K envelope. The running-minimum curve lies below the envelope at early K and approaches it as K grows, qualitatively matching the FMO prediction.

The qualitative shape matches the theorem’s prediction; the absolute fit is approximate because the theorem bounds $\lVert\nabla F\rVert^2$ while the curve plots meta-loss. The §2 implementation also uses Adam (not constant-step SGD) and shares batches across the three roles (the structural-assumption violation flagged in §8.3). Both choices are widely used in practice and introduce a small empirical bias relative to the FMO theorem’s setup.


§9. Comparative benchmarks and ablations

9.1 Cross-lens comparison on a shared task family

The three lenses developed in §§2–5 each used a different task family — sinusoidal regression for MAML, GP-derived regression for Neural Processes, 2D Gaussian-cluster classification for Prototypical Networks. This was deliberate: each task family is the canonical demonstration for its lens. The §6 unification told us these lenses are conceptually unified despite their algorithmic differences; this section makes the cross-lens story concrete.

The unifying structural claim — every lens implements an approximate posterior over task-specific behavior conditioned on a small support set — predicts three observable consequences:

  • Sample efficiency is a function of meta-trained inductive bias. Each method’s test performance should improve smoothly as the support-set size $K$ grows. §9.2.
  • Inner-loop step count at test time can differ from meta-training for gradient-based methods, with bias toward the meta-training $N$. §9.3.
  • Meta-gradient variance scales as $1/B$ in the task-batch size. §9.4.
Figure 14. Cross-lens synthesis: three methods on their native task families. Each panel shows the meta-trained model adapting to a held-out task. Visually similar 'few-shot adaptation' through different mechanisms.

The figure is visual rather than quantitative: it shows that the three methods produce qualitatively similar “few-shot adaptation” on their respective task families, despite operating through completely different mechanisms. The MAML panel reproduces the §2.5 adaptation evolution at the final $n=5$ inner step. The Latent NP panel overlays the trained NP’s mean and 95% predictive band against the exact GP posterior. The Prototypical Network panel shows input-space decision regions from a held-out task.

9.2 Sample-efficiency curves

Evaluate the trained CNP and Latent NP from §4 on 20 fresh held-out GP tasks, sweeping the context size $K_{\mathrm{C}} \in \{2, 4, 8, 12, 16, 20\}$. Report average test NLL on a 30-target sample. Expected behavior: NLL decreases as $K_{\mathrm{C}}$ grows, plateauing as predictions saturate against the GP’s irreducible noise floor.

Figure 15. Sample efficiency: test NLL vs context size on held-out GP tasks. CNP and Latent NP track each other closely; NLL decreases as context grows and plateaus around K_C ≈ 12–16.

CNP and Latent NP track each other closely, with the Latent NP’s mean (over 20 sampled latents per task) sitting slightly below CNP. Descent is steep for $K_{\mathrm{C}} \in [2, 8]$ and plateaus for $K_{\mathrm{C}} \ge 12$. The transfer to MAML and Prototypical Networks is direct, but the experiment shape differs (varying support size for MAML, varying shots per class for Prototypical Networks). We don’t run those sweeps separately, though we expect qualitatively the same story.

9.3 Inner-loop step-count sensitivity for MAML / FOMAML / Reptile

Inner-loop step count is a meta-training hyperparameter for MAML: at training time we fix $N = 5$, and the meta-gradient flows through exactly $N$ inner steps. Question: what happens at test time if we use a different $N_{\mathrm{test}}$?

Evaluate the three trained $\theta_0$s from §§2–3 on 20 fresh held-out sinusoidal tasks, sweeping $N_{\mathrm{test}} \in \{0, 1, 2, 3, 5, 8, 12\}$. Compute post-adaptation query MSE on a 50-target sample.
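The evaluation protocol is a loop over tasks and test-time step counts. A minimal PyTorch sketch, assuming the model is evaluated functionally via `torch.func.functional_call` and adapted with plain SGD at step size `alpha`; `adapt_and_eval`, `sample_sinusoid_task`, and `sweep_n_test` are hypothetical helper names, the sampler’s amplitude, phase, and input ranges are assumptions, and the untrained network at the end only stands in for a meta-trained $\theta_0$:

```python
import math
import torch
import torch.nn.functional as F
from torch.func import functional_call

def adapt_and_eval(model, theta0, sx, sy, qx, qy, n_steps, alpha):
    """Run n_steps of plain SGD from theta0 on the support set, then return query MSE."""
    params = {k: v.detach().clone().requires_grad_(True) for k, v in theta0.items()}
    for _ in range(n_steps):
        loss = F.mse_loss(functional_call(model, params, (sx,)), sy)
        grads = torch.autograd.grad(loss, list(params.values()))
        params = {k: (p - alpha * g).detach().requires_grad_(True)
                  for (k, p), g in zip(params.items(), grads)}
    with torch.no_grad():
        return F.mse_loss(functional_call(model, params, (qx,)), qy).item()

def sample_sinusoid_task(n_support=5, n_query=50, generator=None):
    """Hypothetical sampler: y = A sin(x + phi), A in [0.1, 5], phi in [0, pi], x in [-5, 5]."""
    A = 0.1 + 4.9 * torch.rand(1, generator=generator)
    phi = math.pi * torch.rand(1, generator=generator)
    x = 10 * torch.rand(n_support + n_query, 1, generator=generator) - 5
    y = A * torch.sin(x + phi)
    return x[:n_support], y[:n_support], x[n_support:], y[n_support:]

def sweep_n_test(model, theta0, n_tasks=20, n_values=(0, 1, 2, 3, 5, 8, 12), alpha=0.01, seed=0):
    """Mean query MSE over n_tasks held-out tasks for each test-time step count."""
    gen = torch.Generator().manual_seed(seed)
    tasks = [sample_sinusoid_task(generator=gen) for _ in range(n_tasks)]
    return {n: sum(adapt_and_eval(model, theta0, *t, n, alpha) for t in tasks) / n_tasks
            for n in n_values}

# Stand-in for a trained theta0: a fresh (untrained) network, for illustration only.
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(1, 40), torch.nn.ReLU(), torch.nn.Linear(40, 1))
theta0 = dict(model.named_parameters())
results = sweep_n_test(model, theta0)
```

Re-sampling the tasks from a fixed seed keeps the same 20 tasks across every $N_{\mathrm{test}}$, so differences between entries of `results` reflect the step count rather than task draws.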

Figure 16. Inner-loop sensitivity: held-out query MSE vs test-time inner-loop step count. All three methods' MSE drops sharply from N=0 to N=5 (meta-training step count) and plateaus.

Pedagogical takeaway: all three methods produce $\theta_0$s that “work well” with their meta-training $N$ but tolerate $N_{\mathrm{test}}$ deviations within a factor of two. This is consistent with the §6.2 hierarchical-Bayes interpretation: the meta-trained $\theta_0$ encodes a prior precision via $(\alpha, N)$, and small changes in effective prior precision give small changes in MAP quality.

9.4 Task-batch size and meta-gradient variance

Task batch size $B$ controls the variance of the empirical meta-gradient estimator. Lemma 2 bounded one task’s meta-gradient variance; averaging over $B$ independent task draws shrinks the variance as $1/B$ by the standard i.i.d. argument.

Verify empirically. At the §2-trained $\theta_0^{\mathrm{MAML}}$, draw 200 fresh sinusoidal tasks and compute per-task FOMAML meta-gradients. (FOMAML rather than full MAML, because second-order gradients for 200 tasks would be too expensive; the $1/B$ scaling structure is the same.) For $B \in \{1, 2, 5, 10, 25, 50\}$, estimate the variance of the $B$-task average from $200/B$ disjoint batches.

Figure 17. Empirical meta-gradient variance vs task batch size B on log-log axes. Empirical points lie close to the theoretical 1/B line of slope -1 over almost two orders of magnitude in B.

In-browser 1D-quadratic analog of the §9.4 experiment: per-task FOMAML meta-gradients at θ₀=0 with random support and query targets. The 1/B scaling identity is rigid because it follows from independence across task draws.
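The 1D-quadratic analog is small enough to write out in full. A NumPy sketch, assuming support and query losses $\tfrac12(\theta - s)^2$ and $\tfrac12(\theta - q)^2$ whose targets share a per-task center $c \sim \mathcal{N}(0, 1)$ plus independent $\mathcal{N}(0, 0.25)$ noise; $\alpha$, the noise scale, and the task count are illustrative, and far more tasks are drawn than the 200 above so the variance estimates are tight:

```python
import numpy as np

def fomaml_grad_1d(theta0, s, q, alpha):
    """FOMAML meta-gradient for L_S = 0.5 (theta - s)^2 and L_Q = 0.5 (theta - q)^2:
    the query gradient evaluated at the adapted point theta0 - alpha * (theta0 - s)."""
    adapted = theta0 - alpha * (theta0 - s)
    return adapted - q

rng = np.random.default_rng(0)
n_tasks, alpha = 200_000, 0.4
c = rng.normal(size=n_tasks)               # hypothetical per-task center
s = c + 0.5 * rng.normal(size=n_tasks)     # support-set target (noisy)
q = c + 0.5 * rng.normal(size=n_tasks)     # query-set target (noisy)
g = fomaml_grad_1d(0.0, s, q, alpha)       # one meta-gradient per task at theta0 = 0

def batch_average_variance(g, B):
    """Variance of the B-task average, estimated from len(g) // B disjoint batches."""
    n = (len(g) // B) * B
    return g[:n].reshape(-1, B).mean(axis=1).var()

variances = {B: batch_average_variance(g, B) for B in (1, 2, 5, 10, 25, 50)}
```

At $\theta_0 = 0$ the per-task meta-gradient is $\alpha s - q = (\alpha - 1)c + \alpha\epsilon_S - \epsilon_Q$, with variance $(\alpha - 1)^2 + 0.25(\alpha^2 + 1) = 0.65$ at $\alpha = 0.4$; `variances[B] * B` should stay near that value for every $B$.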

Practical implication: doubling $B$ halves the meta-gradient variance, which by FMO improves the convergence-rate constant by $\sqrt{2}$. This is the same trade-off as standard mini-batch SGD, lifted to the meta-iteration scale. The PAC-Bayes across-task term of (7.2) also shrinks as variance shrinks.


§10. Computational notes

10.1 torch.autograd.grad(..., create_graph=True) semantics

Setting create_graph=True keeps the computation graph alive for a second backward pass. With it, the gradient itself becomes a differentiable computation — you can take a gradient of a gradient. This is the second-order primitive MAML’s outer loop needs.

The canonical HVP idiom:

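import torch

# params: list of tensors with requires_grad=True; vectors: tensors of matching shapes.
# compute_support_loss, x_s, y_s: the task's support loss and data (defined elsewhere).
loss = compute_support_loss(params, x_s, y_s)
grads = torch.autograd.grad(loss, params, create_graph=True)        # grads stay differentiable
inner_product = sum((g * v).sum() for g, v in zip(grads, vectors))  # scalar <grad L, v>
hvp = torch.autograd.grad(inner_product, params)  # H @ v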

Pitfalls: forgetting create_graph=True makes the second torch.autograd.grad call fail with an obscure RuntimeError about tensors that do not require grad, because the first-order gradients then carry no graph. Memory roughly doubles for the inner loop because the entire forward-and-backward graph is retained; this is rarely binding for $N \le 10$ but can be the binding constraint for long-horizon Implicit MAML.
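The idiom can be checked end to end on a loss whose Hessian is known in closed form. A sketch; the helper name `hvp`, the quadratic loss $\tfrac12 w^\top A w$ (Hessian $A$ for symmetric $A$), and the random matrix are assumptions for illustration:

```python
import torch

def hvp(loss_fn, params, vectors):
    """Hessian-vector product via double backward (the create_graph=True idiom)."""
    loss = loss_fn(params)
    grads = torch.autograd.grad(loss, params, create_graph=True)        # differentiable grads
    inner_product = sum((g * v).sum() for g, v in zip(grads, vectors))  # <grad L, v>
    return torch.autograd.grad(inner_product, params)                   # H @ v

# Check on 0.5 * w^T A w, whose Hessian is A when A is symmetric.
torch.manual_seed(0)
M = torch.randn(4, 4, dtype=torch.float64)
A = M @ M.T                                    # symmetric PSD matrix (illustrative)
w = torch.randn(4, dtype=torch.float64, requires_grad=True)
v = torch.randn(4, dtype=torch.float64)
(Hv,) = hvp(lambda ps: 0.5 * ps[0] @ A @ ps[0], [w], [v])
```

Dropping `create_graph=True` from the first call makes the second `torch.autograd.grad` raise a `RuntimeError`, because the first-order gradients then have no graph to differentiate through.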

The §2 MAML implementation uses torch.func.grad instead — equivalent for second-order use but composes with vmap (next section). The §3 FOMAML implementation deliberately uses plain torch.autograd.grad without create_graph=True because no second-order graph is needed.

10.2 torch.func.vmap + functional_call over task batches

The modern idiom for parallelizing inner-loop adaptation across a task batch:

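import torch
import torch.nn.functional as F
from torch.func import functional_call, vmap, grad

# Hypothetical setup for illustration: a small regression net, 5 inner steps,
# and a batch of B = 4 tasks with 5 support and 10 query points each.
model = torch.nn.Sequential(torch.nn.Linear(1, 40), torch.nn.ReLU(), torch.nn.Linear(40, 1))
N_inner, alpha = 5, 0.01
meta_params = dict(model.named_parameters())
sx, sy, qx, qy = (torch.randn(4, n, 1) for n in (5, 5, 10, 10))

def per_task_loss(params, support_x, support_y, query_x, query_y):
    adapted = params
    for _ in range(N_inner):
        # torch.func.grad keeps the inner gradient differentiable for the outer loop
        g = grad(lambda p: F.mse_loss(
            functional_call(model, p, (support_x,)), support_y))(adapted)
        adapted = {k: p - alpha * g[k] for k, p in adapted.items()}
    pred = functional_call(model, adapted, (query_x,))
    return F.mse_loss(pred, query_y)

# Parameters are shared across tasks (None); task data is batched along axis 0.
batched_loss = vmap(per_task_loss, in_dims=(None, 0, 0, 0, 0))
loss = batched_loss(meta_params, sx, sy, qx, qy).mean()
loss.backward()  # second-order meta-gradient lands in each parameter's .grad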

The in_dims argument is critical: None marks an input shared across tasks, 0 marks one batched along axis 0. Get this wrong and vmap either errors or silently produces wrong results.

Two gotchas and a performance note. First, vmap doesn’t compose cleanly with torch.autograd.grad(..., create_graph=True); use the functional torch.func.grad instead. Second, not every operator supports vmap, so check compatibility if you hit a runtime error. On performance: on CPU, vmap is not always faster than a Python for loop for small batches and small networks; the speedup is 1.5–2× for our $B = 4$ case. The performance argument is much stronger on GPU.

10.3 Buffer hoisting in the meta-loop

The project-wide pattern from formalML’s house-style guide applies. In the outer meta-loop, the natural temptation is to allocate a fresh dict of tensors every iteration. For 100 iterations of 50K-parameter models, these allocations account for the bulk of the GC overhead.

Pattern: allocate scratch buffers once, before the meta-loop; reuse with in-place writes (copy_, add_, mul_, fill_). For our small notebook at 100 meta-iters this is invisible — Python’s allocator is fast. For production meta-training at 10,000+ iters it’s a 20–30% speedup.
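The pattern can be sketched as a single outer step that writes into buffers allocated before the loop. A minimal sketch; `meta_step_hoisted`, the `task_grad_fn` interface, and the toy quadratic tasks are illustrative assumptions, not the notebook’s code:

```python
import torch

def meta_step_hoisted(params, task_grad_fn, tasks, beta, grad_buf):
    """One outer step that reuses preallocated buffers instead of allocating new dicts.

    params:   dict name -> tensor, updated in place
    grad_buf: dict name -> tensor of the same shape, allocated once before the loop
    task_grad_fn(params, task) -> dict of per-task meta-gradients (hypothetical interface)
    """
    for buf in grad_buf.values():
        buf.fill_(0.0)                                     # reset in place
    for task in tasks:
        grads = task_grad_fn(params, task)
        for name, buf in grad_buf.items():
            buf.add_(grads[name], alpha=1.0 / len(tasks))  # running task average, in place
    with torch.no_grad():
        for name, p in params.items():
            p.add_(grad_buf[name], alpha=-beta)            # theta <- theta - beta * avg grad

# Toy usage: per-task loss 0.5 * ||w - t||^2, whose gradient is w - t.
params = {"w": torch.zeros(3)}
grad_buf = {k: torch.empty_like(v) for k, v in params.items()}  # hoisted: allocated once
tasks = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([3.0, 2.0, 1.0])]
toy_grad = lambda ps, t: {"w": ps["w"] - t}
for _ in range(100):
    meta_step_hoisted(params, toy_grad, tasks, beta=0.5, grad_buf=grad_buf)
```

The toy `w` converges to the task average $[2, 2, 2]$, while `grad_buf` keeps the same storage across all 100 iterations.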

10.4 GP backbones via gpytorch and what’s worth re-deriving

The §4 Neural Process demo uses a hand-rolled 5-line NumPy GP posterior rather than gpytorch. The choice is design-driven:

gpytorch is the right call when: you need kernel hyperparameter learning via marginal likelihood; you’re working with non-RBF kernels (Matérn, Spectral Mixture); the GP backbone is integrated into a larger gradient-based training pipeline and you want the GP’s autograd in the same graph as the rest of the model.

Hand-rolled NumPy is the right call when: the kernel is fixed and hyperparameters are known; the GP is a reference posterior rather than a trainable component; you want explicit control over the posterior arithmetic for pedagogical clarity. The §4 demo is exactly this case.
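For concreteness, a hand-rolled posterior of this kind. A sketch assuming 1D inputs, an RBF kernel with fixed lengthscale, signal variance, and noise level (all illustrative values), and a Cholesky solve; `gaussian_nll` shows how the same reference posterior yields a test NLL at each context size, which can serve as a reference for a sweep like §9.2’s:

```python
import numpy as np

def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel on 1D inputs."""
    d = x1[:, None] - x2[None, :]
    return variance * np.exp(-0.5 * (d / lengthscale) ** 2)

def gp_posterior(xc, yc, xt, lengthscale=1.0, variance=1.0, noise=0.1):
    """Exact GP predictive mean and variance of y (noise included) at targets xt."""
    K = rbf(xc, xc, lengthscale, variance) + noise**2 * np.eye(len(xc))
    Ks = rbf(xc, xt, lengthscale, variance)
    Lc = np.linalg.cholesky(K)                           # K = Lc Lc^T
    a = np.linalg.solve(Lc.T, np.linalg.solve(Lc, yc))   # K^{-1} y
    v = np.linalg.solve(Lc, Ks)
    mean = Ks.T @ a
    var = variance - np.sum(v**2, axis=0) + noise**2
    return mean, var

def gaussian_nll(y, mean, var):
    """Average negative log-likelihood of y under independent N(mean, var)."""
    return np.mean(0.5 * np.log(2 * np.pi * var) + 0.5 * (y - mean) ** 2 / var)

# One synthetic task drawn from the same (assumed) GP prior: 20 candidate
# context points and 30 targets.
rng = np.random.default_rng(0)
x = rng.uniform(-2, 2, size=50)
y = rng.multivariate_normal(np.zeros(50), rbf(x, x) + 0.1**2 * np.eye(50))
xt, yt = x[20:], y[20:]
nll = {kc: gaussian_nll(yt, *gp_posterior(x[:kc], y[:kc], xt)) for kc in (2, 4, 8, 12, 16, 20)}
```

Because this predictive distribution is exact under the assumed kernel, its NLL is, in expectation over tasks from the same prior, a floor that a trained CNP or Latent NP can approach but not beat.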

For the ANP variant of §4.4, the natural implementation uses torch.nn.MultiheadAttention rather than gpytorch — different ecosystem, different role.
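A sketch of that cross-attention step with `torch.nn.MultiheadAttention`, assuming target inputs attend to context inputs through a shared linear embedding and the values are the encoder’s per-context representations $r_i$; `CrossAttentionAggregator`, the embedding size, the head count, and the random stand-in tensors are illustrative assumptions:

```python
import torch
import torch.nn as nn

class CrossAttentionAggregator(nn.Module):
    """Targets (queries) attend to contexts (keys); values are per-context representations r_i."""

    def __init__(self, x_dim=1, r_dim=64, num_heads=4):
        super().__init__()
        self.x_embed = nn.Linear(x_dim, r_dim)   # shared embedding for queries and keys
        self.attn = nn.MultiheadAttention(r_dim, num_heads, batch_first=True)

    def forward(self, x_context, r_context, x_target):
        q = self.x_embed(x_target)               # (B, T, r_dim)
        k = self.x_embed(x_context)              # (B, C, r_dim)
        r_star, weights = self.attn(q, k, r_context)  # (B, T, r_dim), (B, T, C), head-averaged
        return r_star, weights

agg = CrossAttentionAggregator()
x_c, x_t = torch.randn(8, 10, 1), torch.randn(8, 30, 1)
r_c = torch.randn(8, 10, 64)                     # stand-in for encoder outputs r_i
r_star, w = agg(x_c, r_c, x_t)
```

Each target gets its own context summary `r_star`, which is what lets an attentive NP keep local-context fidelity that a single mean-pooled representation loses.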


§11. Connections and limits

11.1 What this topic delivered

We developed meta-learning through three lenses — gradient-based (MAML and first-order approximations in §§2–3), Bayesian (Neural Processes in §4), metric-based (Prototypical Networks in §5) — each with algorithmic derivation, working implementation, and held-out validation. We unified the three via hierarchical Bayes (§6), showing that MAML’s inner-loop initialization is a Gaussian-prior mean (Grant et al. 2018, Theorem 2), Latent NPs amortize variational inference over a task-distribution prior, and Prototypical Networks specialize discriminative training to a task-level decomposition. The unification carried two theoretical results: the Amit–Meir PAC-Bayes meta-generalization bound (§7) and the Fallah–Mokhtari–Ozdaglar convergence theorem for MAML (§8).

The topic discharges twelve previously-deferred forward-pointers across formalML — six in PAC-Bayes Bounds, plus one each in Stochastic-Gradient MCMC, Bayesian Neural Networks, Gaussian Processes, Sparse Bayesian Priors, and Variational Inference — plus the formalstatistics §28.10 Ex 14 pointer. The longest-running cross-site dependency in the formalML triad is now closed.

11.2 Vision benchmarks: Omniglot, miniImageNet, and what they need that we skipped

The canonical few-shot vision benchmarks — Omniglot (Lake, Salakhutdinov, Tenenbaum 2015) and miniImageNet (Vinyals et al. 2016) — are out of scope for this notebook. The reason is compute, not pedagogy: the math is the same, but the architectures and data scale push past a 60-second CPU budget by roughly two orders of magnitude.

What’s different: embedding networks become 4–6-layer CNNs with 30K–300K parameters rather than 2-layer MLPs; input dimensionality jumps from 1D/2D to $84 \times 84 \times 3$ pixels (miniImageNet) or $28 \times 28$ (Omniglot); image-augmentation pipelines become standard preprocessing; task families are naturally large (Omniglot 1623 classes, miniImageNet 100). Production meta-training takes 60,000+ meta-iterations at task batch size 32 on a single GPU, or is distributed across 4–8 GPUs.

What vision benchmarks tell us that synthetic demos don’t: robustness to natural data variability; generalization across visual domains; scaling behavior of meta-learning algorithms in a regime where published methods land in narrow accuracy bands and differences come from architectural choices, augmentation, and training schedules rather than from the meta-algorithm itself. These are essential signals for practitioners.

11.3 Meta-optimizers, learned losses, and learned regularizers

The three lenses meta-learn a fixed aspect of the inner-loop algorithm: MAML the initialization, Neural Processes the inference network, Prototypical Networks the embedding network. A broader family meta-learns other aspects of the inner-loop algorithm.

Meta-optimizers (Andrychowicz et al. 2016; Ravi and Larochelle 2017) meta-learn the update rule: replace the SGD step with an LSTM that takes $(\theta^{(n)}, \nabla L)$ and produces the next iterate. The optimizer becomes task-distribution-specific. Expensive and brittle for long inner loops.

Learned losses (Houthooft et al. 2018) meta-learn the inner-loop loss function. The inner loop minimizes a parameterized loss; the outer query loss is unchanged. Useful when the natural training loss is misspecified.

Learned regularizers (Hospedales et al. 2022 §3.3) meta-learn regularization terms or implicit regularization through architecture/inductive bias. Smaller perturbations of the standard recipe.

All three families share the §6 hierarchical-Bayes structure: $\theta_0$ now includes the inner-loop algorithm itself. The PAC-Bayes and convergence theory of §§7–8 extend with technical modifications.

11.4 Continual learning, online meta-learning, and the boundary with reinforcement learning

Continual learning addresses tasks arriving sequentially: the model must learn each without forgetting previous ones and has no access to old data. Naïve fine-tuning exhibits catastrophic forgetting; algorithms like Elastic Weight Consolidation (Kirkpatrick et al. 2017) regularize toward old-task parameters. Meta-learning’s hierarchical structure (per-task $\theta_\mathcal{T}$ adapted from a shared $\theta_0$) is one solution.

Online meta-learning (Finn, Rajeswaran, Kakade, Levine 2019) makes this explicit: the meta-learner updates $\theta_0$ from a stream of tasks rather than a batch, with regret-minimization objectives replacing empirical meta-risk.

The boundary with reinforcement learning is meta-RL. Each task is an MDP; the meta-learner trains a policy that adapts quickly to new MDPs. RL² (Duan, Schulman, Chen, Bartlett, Sutskever, Abbeel 2016) treats meta-RL as a recurrent policy. Wang et al. (2016) and the gradient-based MAML extensions develop the policy-gradient version. The mathematical generalization is direct (hierarchical Bayes still works), but the empirical setup is much harder. Each “task” requires environment rollouts, and meta-RL training is two orders of magnitude more expensive than meta-supervised learning. The convergence theory of §8 has been partially extended (Fallah, Georgiev, Mokhtari, Ozdaglar 2021), but the analyses are more restrictive than in the supervised setting.

The unifying thread: meta-learning, continual learning, and meta-RL all study learning systems facing task distributions rather than data distributions. The §6 hierarchical-Bayes framework is the right lens for all three.

11.5 Open frontiers

The biggest shift since 2020 is the rise of foundation models and the reinterpretation of meta-learning as in-context learning. Large language models (Brown et al. 2020, GPT-3 and subsequent generations) exhibit the ability to adapt to new tasks from in-context examples alone: no gradient adaptation, just prompting. This is meta-learning by another name: the pretrained model has implicitly meta-trained over a vast task distribution during corpus training, and in-context examples function as the support set of §1.2’s hierarchical model.

Theoretical connections are being worked out. Xie, Raghunathan, Liang, Ma (2022) show that in-context learning in transformers can be analyzed as approximate Bayesian inference over a meta-learned prior. Garg, Tsipras, Liang, Valiant (2022) demonstrate that transformers trained on prompts drawn from synthetic function classes learn to in-context-learn linear regression, decision trees, and 2-layer MLPs, with no explicit inner-loop adaptation algorithm. The implication: the explicit meta-learning framework of §§2–5 may be subsumed by implicit meta-learning emerging from sufficiently broad pretraining at scale.

Three open directions. Theory of in-context learning — extending Xie et al. and Garg et al. to broader function classes, sharper PAC-Bayes-style bounds. Scaling laws for meta-learning — how meta-learning compares to standard supervised pretraining as compute, data, and model size scale (Chan, Santoro, Lampinen, Wang, Singh, Richemond, McClelland, Hill 2022). Multimodal meta-learning — across modalities (vision-language, audio-text), where task structure is more heterogeneous and inductive-bias questions are sharper.

The frontier question — and the question this topic implicitly raises — is whether explicit meta-learning algorithms will remain practically relevant as foundation-model-style implicit meta-learning continues to scale. The §6 conceptual unification suggests they should: hierarchical Bayes is a structurally general framework, and the explicit algorithms make different approximation trade-offs that may matter at smaller scales, in specialized domains, or where interpretability of the meta-trained prior matters. A future revision of this topic will likely need a §12 covering in-context learning as a fourth lens.

Connections

  • Latent NPs are amortized Bayesian neural networks specialized to context-conditional prediction. §4's encoder is the inference network; §4.3's ELBO derivation lifts the single-task VI machinery from BNNs to the meta-distribution level. bayesian-neural-networks
  • The §4 NP demo uses GPs as a reference posterior — meta-learning a function-space prior whose ground truth is the exact GP marginal under the data-generating kernel. §4.4's Attentive NP recovers GP-like local-context fidelity through cross-attention. gaussian-processes
  • The Amit–Meir meta-bound (§7) is a double-application of McAllester PAC-Bayes: once at the within-task level, once at the across-task level, chained through a hierarchical KL. The present topic discharges the six forward-pointers from `pac-bayes-bounds` §§4, 5, 12.1. pac-bayes-bounds
  • The Latent NP training objective (4.4) is the meta-learning analog of VI's ELBO — amortized over tasks rather than per-iter from scratch. §6.3 reads MAML as Laplace-truncated marginal likelihood, an alternative to VI's variational lower bound. variational-inference
  • Meta-learning estimates a shared prior across tasks; sparse Bayesian priors estimate a shared sparsity structure. Both are about transferring inductive bias — meta-learning transfers it from task to task within a family, sparse priors transfer it from prior to posterior. sparse-bayesian-priors
  • SG-MCMC's per-iter posterior sampling is the within-task complement to meta-learning's across-task amortization: combine the two and you have probabilistic meta-learning with non-Gaussian within-task posteriors. The two topics share the §6 hierarchical-Bayes vocabulary. stochastic-gradient-mcmc

References & Further Reading