Tangent and Adjoint Sensitivity Analysis of Nonlinear Equations

Auto Differentiation
Implicit Function
Author

Wangkun Xu

Published

June 5, 2024

Modified

June 18, 2024

This post extends the previous post on linear systems to find the parameter sensitivities of nonlinear systems. It is largely based on two YouTube videos: Adjoint Sensitivities of a Non-Linear system of equations | Full Derivation and Lagrangian Perspective on the Derivation of Adjoint Sensitivities of Nonlinear Systems. Another reference is Deep Implicit Layers.

Settings

Consider a nonlinear system of equations \[ f(x, \theta) = 0 \] where \(x \in \mathbb{R}^N\) is the state variable and \(\theta \in \mathbb{R}^P\) is the parameter. Nonlinear equation solvers such as Newton's method can be used to find \(x\) given \(\theta\), so the solution \(x = x(\theta)\) is an implicit function of \(\theta\). Assume there is a scalar loss function \(J(x,\theta)\); our goal is to find its sensitivity, i.e., the total derivative of \(J\) with respect to \(\theta\): \(\frac{d J}{d \theta}\).
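To make the setting concrete, here is a minimal NumPy sketch of Newton's method on a small toy system. The system itself, \(f(x, \theta) = x^3 + \theta_0 x + C x - \theta_{1:3}\) with \(N = 2\) and \(P = 3\), and the helper names `f`, `df_dx` and `newton_solve` are hypothetical choices for illustration, not taken from the post. Each iteration solves the linear system \(\frac{\partial f}{\partial x}\,\Delta x = -f\) and updates \(x\).

```python
import numpy as np

# Hypothetical toy system (assumed, not from the original post): N = 2 states, P = 3 parameters
#   f(x, theta) = x**3 + theta[0] * x + C @ x - theta[1:3]
C = np.array([[0.0, 1.0],
              [-0.5, 0.0]])

def f(x, theta):
    """Residual of the nonlinear system."""
    return x**3 + theta[0] * x + C @ x - theta[1:3]

def df_dx(x, theta):
    """Jacobian of f with respect to x, shape (N, N)."""
    return np.diag(3.0 * x**2) + theta[0] * np.eye(x.size) + C

def newton_solve(theta, x0, tol=1e-12, max_iter=50):
    """Solve f(x, theta) = 0 for x with a plain (undamped) Newton iteration."""
    x = np.array(x0, dtype=float)
    for _ in range(max_iter):
        r = f(x, theta)
        if np.linalg.norm(r) < tol:
            break
        # Newton step: solve (df/dx) dx = -f, then update x
        x = x + np.linalg.solve(df_dx(x, theta), -r)
    return x

theta = np.array([2.0, 1.0, -0.5])
x_star = newton_solve(theta, x0=np.zeros(2))
print(x_star, np.linalg.norm(f(x_star, theta)))
```

For \(\theta_0 > 0.25\) the Jacobian of this toy system is always nonsingular, which keeps the undamped iteration well behaved; a real solver would usually add a line search or damping.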

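The total derivative of \(J\) is \[ \frac{d J}{d \theta} = \frac{\partial J}{\partial x} \frac{d x}{d \theta} + \frac{\partial J}{\partial \theta} \tag{1}\] where \(\frac{\partial J}{\partial x} \in \mathbb{R}^{1 \times N}\) is a row vector and \(\frac{d x}{d \theta} \in \mathbb{R}^{N \times P}\) is unknown. Taking the total derivative of \(f\) with respect to \(\theta\), which vanishes because \(f(x(\theta), \theta) = 0\) holds for every \(\theta\), gives \[ \frac{d f}{d \theta} = \frac{\partial f}{\partial x} \frac{d x}{d \theta} + \frac{\partial f}{\partial \theta} = 0 \tag{2}\]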

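Therefore, \(\frac{d x}{d \theta}\) can be obtained from the linear equation above: assuming \(\frac{\partial f}{\partial x}\) is invertible, \(\frac{d x}{d \theta} = -\left( \frac{\partial f}{\partial x} \right)^{-1} \frac{\partial f}{\partial \theta}\). Substituting this into Equation 1 gives \[ \frac{d J}{d \theta} = -\frac{\partial J}{\partial x} \left( \frac{\partial f}{\partial x} \right)^{-1} \frac{\partial f}{\partial \theta} + \frac{\partial J}{\partial \theta} \tag{3}\]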

Tangent Sensitivity Analysis

In the tangent (forward) method, the term \(\left( \frac{\partial f}{\partial x} \right)^{-1} \frac{\partial f}{\partial \theta}\) is computed by solving the batch of linear systems in Equation 2 directly, i.e., Equation 3 is evaluated from right to left. Denote the \(i\)-th column of \(\frac{\partial f}{\partial \theta}\) as \(g_i\); then \(P\) linear systems need to be solved: \[ \frac{\partial f}{\partial x} \frac{dx}{d\theta_i} = -g_i, \quad i = 1, \dots, P \] so the cost grows with the number of parameters \(P\).
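A minimal NumPy sketch of the tangent method, assuming a hypothetical toy problem \(f(x, \theta) = x^3 + \theta_0 x + C x - \theta_{1:3}\) (\(N = 2\), \(P = 3\)) with loss \(J = \tfrac12 \|x\|^2 + 0.1\|\theta\|^2\); all helper names are illustrative. Here \(\frac{\partial f}{\partial \theta}\) has \(x\) as its first column and \(-I\) as the remaining columns. Passing the whole \(N \times P\) right-hand side to `np.linalg.solve` solves the \(P\) systems \(\frac{\partial f}{\partial x} \frac{dx}{d\theta_i} = -g_i\) in one call, reusing one factorization of the Jacobian.

```python
import numpy as np

# Hypothetical toy problem (assumed for illustration):
#   f(x, theta) = x**3 + theta[0] * x + C @ x - theta[1:3],   N = 2, P = 3
#   J(x, theta) = 0.5 * ||x||^2 + 0.1 * ||theta||^2
C = np.array([[0.0, 1.0], [-0.5, 0.0]])

def f(x, theta):
    return x**3 + theta[0] * x + C @ x - theta[1:3]

def df_dx(x, theta):                 # (N, N)
    return np.diag(3.0 * x**2) + theta[0] * np.eye(x.size) + C

def df_dtheta(x, theta):             # (N, P): column 0 is x, columns 1..2 are -I
    return np.column_stack([x, -np.eye(x.size)])

def loss(x, theta):
    return 0.5 * x @ x + 0.1 * theta @ theta

def newton_solve(theta, tol=1e-12, max_iter=50):
    x = np.zeros(2)
    for _ in range(max_iter):
        r = f(x, theta)
        if np.linalg.norm(r) < tol:
            break
        x = x + np.linalg.solve(df_dx(x, theta), -r)
    return x

def tangent_gradient(theta):
    x = newton_solve(theta)
    # Equation 2: solve (df/dx) dx/dtheta_i = -g_i for all P columns at once
    dx_dtheta = np.linalg.solve(df_dx(x, theta), -df_dtheta(x, theta))   # (N, P)
    # Equation 1: dJ/dtheta = (dJ/dx) dx/dtheta + partial dJ/dtheta
    dJ_dx = x                        # gradient of 0.5 * ||x||^2
    dJ_dtheta_partial = 0.2 * theta  # gradient of 0.1 * ||theta||^2
    return dJ_dx @ dx_dtheta + dJ_dtheta_partial

theta = np.array([2.0, 1.0, -0.5])
print(tangent_gradient(theta))
```

The intermediate `dx_dtheta` is the full \(N \times P\) sensitivity matrix, which is what makes the tangent method attractive when \(P\) is small or when \(\frac{dx}{d\theta}\) itself is needed.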

Adjoint Sensitivity Analysis

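Alternatively, Equation 3 can be evaluated from left to right by first computing the row vector \(\lambda^T = -\frac{\partial J}{\partial x} \left( \frac{\partial f}{\partial x} \right)^{-1}\) and then multiplying by \(\frac{\partial f}{\partial \theta}\). Transposing gives the adjoint linear system \[ \left( \frac{\partial f}{\partial x} \right)^T \lambda = -\left(\frac{\partial J}{\partial x}\right)^T \tag{4}\] which can be solved by LU decomposition or by an iterative method (the conjugate gradient method requires \(\frac{\partial f}{\partial x}\) to be symmetric positive definite; GMRES handles the general case). The sensitivity is then \[ \frac{d J}{d \theta} = \lambda^T \frac{\partial f}{\partial \theta} + \frac{\partial J}{\partial \theta} \] With an iterative solver, the Jacobian \(\frac{\partial f}{\partial x}\) does not need to be formed explicitly: the products \(\left( \frac{\partial f}{\partial x} \right)^T v\) can be computed as vector-Jacobian products (VJPs) by reverse-mode automatic differentiation.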
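A minimal NumPy sketch of the adjoint method, assuming a hypothetical toy problem \(f(x, \theta) = x^3 + \theta_0 x + C x - \theta_{1:3}\) (\(N = 2\), \(P = 3\)) with loss \(J = \tfrac12 \|x\|^2 + 0.1\|\theta\|^2\); the helper names are illustrative only. The matrix \(C\) is deliberately non-symmetric so that the transpose in Equation 4 matters. For clarity the Jacobian is formed explicitly and solved with a dense LU; at large \(N\) one would instead feed VJPs (for example from `jax.vjp`) to an iterative solver.

```python
import numpy as np

# Hypothetical toy problem (assumed for illustration):
#   f(x, theta) = x**3 + theta[0] * x + C @ x - theta[1:3],   N = 2, P = 3
#   J(x, theta) = 0.5 * ||x||^2 + 0.1 * ||theta||^2
C = np.array([[0.0, 1.0], [-0.5, 0.0]])

def f(x, theta):
    return x**3 + theta[0] * x + C @ x - theta[1:3]

def df_dx(x, theta):                   # (N, N), not symmetric because C is not
    return np.diag(3.0 * x**2) + theta[0] * np.eye(x.size) + C

def df_dtheta(x, theta):               # (N, P): column 0 is x, columns 1..2 are -I
    return np.column_stack([x, -np.eye(x.size)])

def loss(x, theta):
    return 0.5 * x @ x + 0.1 * theta @ theta

def newton_solve(theta, tol=1e-12, max_iter=50):
    x = np.zeros(2)
    for _ in range(max_iter):
        r = f(x, theta)
        if np.linalg.norm(r) < tol:
            break
        x = x + np.linalg.solve(df_dx(x, theta), -r)
    return x

def adjoint_gradient(theta):
    x = newton_solve(theta)
    dJ_dx = x                          # (dJ/dx)^T for J = 0.5 * ||x||^2 + ...
    dJ_dtheta_partial = 0.2 * theta    # partial derivative of J w.r.t. theta
    # Equation 4: one N x N solve with the transposed Jacobian, independent of P
    lam = np.linalg.solve(df_dx(x, theta).T, -dJ_dx)
    # dJ/dtheta = lambda^T df/dtheta + partial dJ/dtheta
    return lam @ df_dtheta(x, theta) + dJ_dtheta_partial

theta = np.array([2.0, 1.0, -0.5])
print(adjoint_gradient(theta))
```

The only linear solve involves an \(N \times N\) matrix and a single right-hand side, so adding parameters only widens the final product \(\lambda^T \frac{\partial f}{\partial \theta}\).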

In the adjoint method, only one linear system (Equation 4) needs to be solved, regardless of the number of parameters \(P\). Note that although the original system is nonlinear, the adjoint system is linear in \(\lambda\).

Alternative Derivation using Lagrangian

As in the linear case, the adjoint sensitivity analysis for a nonlinear system can also be derived from the Lagrangian perspective.

Consider the equality-constrained optimization problem: \[ \min_{x} J(x, \theta) \quad \text{s.t.} \quad f(x, \theta) = 0 \]

The Lagrangian is \[ \mathcal{L}(x, \lambda, \theta) = J(x, \theta) + \lambda^T f(x, \theta) \]

Take the total derivative of \(\mathcal{L}\) with respect to \(\theta\): \[ \frac{d \mathcal{L}}{d \theta} = \frac{\partial J}{\partial \theta} + \frac{\partial J}{\partial x} \frac{dx}{d\theta} + \lambda^T \left(\frac{\partial f}{\partial \theta} + \frac{\partial f}{\partial x} \frac{dx}{d\theta}\right) = \frac{\partial J}{\partial \theta} + \lambda^T\frac{\partial f}{\partial \theta} + \left(\frac{\partial J}{\partial x} + \lambda^T \frac{\partial f}{\partial x}\right) \frac{dx}{d\theta} \]

Because the equality constraint is always satisfied (\(x\) is solved from \(f(x, \theta) = 0\)), the term \(\lambda^T f\) vanishes and \(\mathcal{L} = J\) for any choice of \(\lambda\), so the dual variable \(\lambda\) can be set arbitrarily. We choose it to make the coefficient of \(\frac{dx}{d\theta}\) zero, so that this costly term never appears in the final expression: \[ \frac{\partial J}{\partial x} + \lambda^T \frac{\partial f}{\partial x} = 0 \] which, after transposition, is the adjoint equation (Equation 4).

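Finally, because \(f(x,\theta)=0\) holds for every \(\theta\), we have \(\frac{df}{d\theta} = 0\) and \[ \frac{d \mathcal{L}}{d \theta} = \frac{d J}{d \theta} + \lambda^T\frac{d f}{d \theta} = \frac{d J}{d\theta} \] Combining this with the expansion above, in which the \(\frac{dx}{d\theta}\) term has been eliminated, gives \[ \frac{d J}{d \theta} = \frac{\partial J}{\partial \theta} + \lambda^T\frac{\partial f}{\partial \theta} \]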

Relation to Linear System

The linear system \(A(\theta) x = b(\theta)\) can be rewritten in residual form as \[ f(x, \theta) = b(\theta) - A(\theta) x = 0 \] with \(\frac{\partial f}{\partial x} = -A\) and \[ \frac{\partial f}{\partial \theta} = \frac{\partial f}{\partial b}\frac{db}{d\theta} + \frac{\partial f}{\partial A}\frac{dA}{d\theta} = \frac{db}{d\theta} - \frac{dA}{d\theta}x \] which is in general nonzero.

Therefore, the total derivative can be written as \[ \frac{df}{d\theta} = \frac{\partial f}{\partial x} \frac{dx}{d\theta} + \frac{\partial f}{\partial \theta} = -A \frac{dx}{d\theta} + \frac{db}{d\theta} - \frac{dA}{d\theta}x = 0 \] which recovers the derivation for the linear system.
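Substituting \(\frac{\partial f}{\partial x} = -A\) into Equation 4 gives the adjoint equation for the linear case, \(A^T \lambda = \left(\frac{\partial J}{\partial x}\right)^T\), and the gradient is \(\frac{dJ}{d\theta} = \frac{\partial J}{\partial \theta} + \lambda^T\left(\frac{db}{d\theta} - \frac{dA}{d\theta}x\right)\). The following NumPy sketch checks this on a hypothetical parametrization \(A(\theta) = A_0 + \theta_0 A_1\), \(b(\theta) = b_0 + \theta_1 b_1\) with random matrices and loss \(J = \tfrac12\|x\|^2\) (so \(\frac{\partial J}{\partial \theta} = 0\)); none of these choices come from the post.

```python
import numpy as np

# Hypothetical parametrization (assumed): A(theta) = A0 + theta[0] * A1, b(theta) = b0 + theta[1] * b1
# Loss J(x) = 0.5 * ||x||^2, so (dJ/dx)^T = x and the partial dJ/dtheta is zero.
rng = np.random.default_rng(0)
N = 4
A0 = rng.standard_normal((N, N)) + N * np.eye(N)   # diagonal shift keeps A well conditioned
A1 = rng.standard_normal((N, N))
b0 = rng.standard_normal(N)
b1 = rng.standard_normal(N)

def A(theta):
    return A0 + theta[0] * A1

def b(theta):
    return b0 + theta[1] * b1

def loss(theta):
    x = np.linalg.solve(A(theta), b(theta))
    return 0.5 * x @ x

def adjoint_gradient_linear(theta):
    x = np.linalg.solve(A(theta), b(theta))
    # df/dtheta = db/dtheta - (dA/dtheta) x, one column per parameter (P = 2)
    df_dtheta = np.column_stack([-A1 @ x, b1])
    # Adjoint equation with df/dx = -A:  A^T lambda = (dJ/dx)^T = x
    lam = np.linalg.solve(A(theta).T, x)
    return lam @ df_dtheta            # plus partial dJ/dtheta, which is zero here

theta = np.array([0.3, -1.2])
print(adjoint_gradient_linear(theta))
```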

Implementation

An example implementation of a batched version of adjoint sensitivity analysis is available here.
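The linked implementation is not reproduced here. As a rough illustration only, the following NumPy sketch batches the adjoint computation over \(B\) parameter vectors, assuming the hypothetical toy system \(f(x, \theta) = x^3 + \theta_0 x + C x - \theta_{1:3}\) (\(N = 2\), \(P = 3\)) and loss \(J = \tfrac12 \|x\|^2 + 0.1\|\theta\|^2\) for each batch entry. Every array carries a leading batch axis, and `np.linalg.solve` is applied to a stack of \(N \times N\) matrices so that Newton steps and adjoint solves run for all entries at once. The helper names are illustrative assumptions.

```python
import numpy as np

# Hypothetical batched sketch (assumed toy problem, not the linked implementation):
#   f(x, theta) = x**3 + theta[0] * x + C @ x - theta[1:3]   for each batch entry
#   J(x, theta) = 0.5 * ||x||^2 + 0.1 * ||theta||^2           for each batch entry
# x has shape (B, N) and theta has shape (B, P), with N = 2 and P = 3.
C = np.array([[0.0, 1.0], [-0.5, 0.0]])

def f(x, theta):                                    # -> (B, N)
    return x**3 + theta[:, :1] * x + x @ C.T - theta[:, 1:3]

def df_dx(x, theta):                                # -> (B, N, N)
    eye = np.eye(x.shape[1])
    return (3.0 * x**2)[:, :, None] * eye + theta[:, 0, None, None] * eye + C

def df_dtheta(x, theta):                            # -> (B, N, P)
    B, N = x.shape
    minus_eye = np.broadcast_to(-np.eye(N), (B, N, N))
    return np.concatenate([x[:, :, None], minus_eye], axis=2)

def batched_solve(A, rhs):                          # A: (B, N, N), rhs: (B, N) -> (B, N)
    # A trailing axis of size 1 makes np.linalg.solve treat rhs as a stack of column vectors
    return np.linalg.solve(A, rhs[:, :, None])[:, :, 0]

def newton_solve(theta, tol=1e-12, max_iter=50):
    x = np.zeros((theta.shape[0], 2))
    for _ in range(max_iter):
        r = f(x, theta)
        if np.max(np.linalg.norm(r, axis=1)) < tol:
            break
        x = x + batched_solve(df_dx(x, theta), -r)  # one Newton step per batch entry
    return x

def adjoint_gradient(theta):                        # -> (B, P)
    x = newton_solve(theta)
    jac_T = np.swapaxes(df_dx(x, theta), 1, 2)      # (df/dx)^T for every batch entry
    lam = batched_solve(jac_T, -x)                  # Equation 4 with (dJ/dx)^T = x
    # lambda^T df/dtheta + partial dJ/dtheta, contracted per batch entry
    return np.einsum('bn,bnp->bp', lam, df_dtheta(x, theta)) + 0.2 * theta

thetas = np.array([[2.0, 1.0, -0.5],
                   [3.0, 0.2, 0.4],
                   [1.5, -1.0, 2.0]])
print(adjoint_gradient(thetas))
```

Keeping an explicit trailing axis on the right-hand side avoids ambiguity in how `np.linalg.solve` broadcasts a stack of vectors across NumPy versions.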