Tangent and Adjoint Sensitivity Analysis of Linear Equations

Auto Differentiation
Implicit Function
Author

Wangkun Xu

Published

June 3, 2024

Modified

June 4, 2024

This post contains my learning notes on the YouTube video: Adjoint Equation of a Linear System of Equations - by implicit derivative.

All credits go to the author of the video.

Settings

Consider a linear system of equations \[ A(\theta) x = b(\theta) \tag{1}\] with a loss function \(J(x; \theta)\), where \(\theta\in\mathbb{R}^P\) (for example, the weights of a neural network); \(A\in\mathbb{R}^{M\times N}\); \(x\in\mathbb{R}^N\); \(b\in\mathbb{R}^M\); \(J(x;\theta): \mathbb{R}^N \times \mathbb{R}^P \rightarrow \mathbb{R}\). Throughout this post we assume \(M = N\) and that \(A\) is invertible, since the derivations below use \(A^{-1}\).

Our goal is to find the total derivative \(\frac{d J}{d \theta}\). This gradient can be useful for:

  1. Gradient-based optimization.
  2. Local sensitivity analysis of linear equations.

Note that \(x\) depends on \(\theta\) through \(A\) and \(b\). The total derivative \(\frac{d J}{d \theta}\) can be computed using the chain rule: \[ \frac{d J}{d \theta} = \frac{\partial J}{\partial x} \frac{d x}{d \theta} + \frac{\partial J}{\partial \theta} \tag{2}\] where we use the Jacobian convention, so that \(\frac{\partial J}{\partial x}\) is a row vector of size \(1\times N\) and \(\frac{d x}{d \theta}\) is a matrix of size \(N\times P\). The term \(\frac{d x}{d \theta}\) is the one that is difficult to compute directly.

Tangent Sensitivity Analysis

Take the total derivative of (Equation 1) with respect to \(\theta\) and apply the product rule: \[ \frac{d}{d\theta} (A x) = \frac{d A}{d\theta} x + A \frac{d x}{d\theta} = \frac{d b}{d\theta} \]

Rearranging, the unknown \(\frac{d x}{d \theta}\) can be found by solving the following linear system

\[ A\cdot\frac{d x}{d \theta} = \frac{db}{d\theta} - \frac{d A}{d\theta} x \] or, written explicitly, \[ \frac{d x}{d \theta} = A^{-1} \left(\frac{db}{d\theta} - \frac{d A}{d\theta} x\right) \tag{3}\]

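Note that \(\frac{d A}{d \theta}\in\mathbb{R}^{N\times N\times P}\) is a third-order tensor, so strictly speaking \(\frac{d A}{d\theta} x\) is not an ordinary matrix-vector product. Here it is shorthand for contracting over the column index of \(A\), which gives an \(N\times P\) matrix whose \(i\)-th column is \(\frac{d A}{d\theta_i} x\).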
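
The contraction can be made concrete with a small NumPy sketch. The sizes and the random arrays below are placeholders chosen for illustration only, and the index layout `dA[i, j, p]` (row, column, parameter) is an assumption about how the tensor is stored.

```python
import numpy as np

N, P = 2, 3
rng = np.random.default_rng(0)

# Placeholder data: dA[i, j, p] stands for d A_ij / d theta_p
dA = rng.standard_normal((N, N, P))
x = rng.standard_normal(N)

# Contract over the column index j of A; the result has shape (N, P)
dA_x = np.einsum("ijp,j->ip", dA, x)

# Equivalent view: one ordinary matrix-vector product per parameter
dA_x_loop = np.stack([dA[:, :, p] @ x for p in range(P)], axis=1)
```

Both forms produce the same \(N\times P\) matrix; the `einsum` version simply avoids the explicit loop over parameters.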

Equation 3 is therefore a batch of \(P\) linear systems that share the same matrix \(A\). Let \(\theta_i\) be the \(i\)-th element of \(\theta\), \[ A\cdot\frac{d x}{d \theta_i} = \frac{db}{d\theta_i} - \underbrace{\frac{d A}{d\theta_i}}_{N\times N} x, \quad i=1,\dots,P \]

We can view \(\frac{d x}{d \theta_i}\) as the tangent of \(x\), i.e., \(\dot{x}_i = \frac{d x}{d \theta_i}\). Solving the above systems then follows the idea of tangent sensitivity analysis (or forward-mode AD). The matrix-vector products \(\frac{d A}{d\theta_i} x\) can be computed efficiently using Jacobian-vector products (JVPs). However, this method is less efficient when \(P\) is large, because it requires one linear solve per parameter.
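
A minimal sketch of the tangent method on a toy problem is given below. The \(2\times 2\) system with \(P=3\) parameters, and the names `A_of`, `b_of`, `DA`, `DB`, and `tangent_dx_dtheta`, are hypothetical choices made for this illustration and do not come from the video. The Jacobians are written by hand here; in practice they could come from automatic differentiation. The result is compared against central finite differences.

```python
import numpy as np

def A_of(theta):
    # Hypothetical parameterised matrix A(theta), N = 2, P = 3
    return np.array([[2.0 + theta[0], theta[1]],
                     [theta[1], 3.0 + theta[2]]])

def b_of(theta):
    # Hypothetical parameterised right-hand side b(theta)
    return np.array([theta[0], 1.0 + theta[2]])

# Hand-derived Jacobians of the toy A and b (constant for this example)
DA = np.zeros((2, 2, 3))      # DA[i, j, p] = d A_ij / d theta_p
DA[0, 0, 0] = 1.0
DA[0, 1, 1] = 1.0
DA[1, 0, 1] = 1.0
DA[1, 1, 2] = 1.0
DB = np.zeros((2, 3))         # DB[i, p] = d b_i / d theta_p
DB[0, 0] = 1.0
DB[1, 2] = 1.0

def tangent_dx_dtheta(theta):
    A = A_of(theta)
    x = np.linalg.solve(A, b_of(theta))            # Equation 1
    rhs = DB - np.einsum("ijp,j->ip", DA, x)       # one column per theta_i
    dx = np.linalg.solve(A, rhs)                   # P solves with the same A
    return x, dx

theta = np.array([0.5, -0.3, 0.8])
x, dx = tangent_dx_dtheta(theta)

# Central finite differences as an independent check
eps = 1e-6
dx_fd = np.zeros_like(dx)
for i in range(3):
    e = np.zeros(3)
    e[i] = eps
    x_plus = np.linalg.solve(A_of(theta + e), b_of(theta + e))
    x_minus = np.linalg.solve(A_of(theta - e), b_of(theta - e))
    dx_fd[:, i] = (x_plus - x_minus) / (2 * eps)
```

Passing the whole \(N\times P\) right-hand side to `np.linalg.solve` handles all \(P\) systems in one call, but the cost still grows with the number of columns, which is the scaling issue noted above.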

Adjoint Sensitivity Analysis

Plugging (Equation 3) into (Equation 2), we have \[ \frac{d J}{d \theta} = \underbrace{\frac{\partial J}{\partial x} A^{-1}}_{\lambda^T:1\times N} \left(\frac{db}{d\theta} - \frac{d A}{d\theta} x\right) + \frac{\partial J}{\partial \theta} \]

Now, instead of solving \(P\) linear systems as in the tangent method, note that the term \(\frac{\partial J}{\partial x} A^{-1}\) is a row vector of size \(1\times N\), which we denote by \(\lambda^T\). Transposing \(\lambda^T A = \frac{\partial J}{\partial x}\) shows that \(\lambda\) can be obtained by solving a single linear system: \[ A^T \lambda = \left(\frac{\partial J}{\partial x}\right)^T \tag{4}\]

Then the gradient \(\frac{d J}{d \theta}\) can be computed as \[ \frac{d J}{d \theta} = \lambda^T \left(\frac{db}{d\theta} - \frac{d A}{d\theta} x\right) + \frac{\partial J}{\partial \theta} \]

This is the idea of adjoint sensitivity analysis (or reverse-mode AD). It can be considered as assigning the adjoint of \(x\) as \(\bar{x} = \frac{\partial J}{\partial x}\).

The adjoint sensitivity therefore requires only two linear solves, regardless of \(P\): (Equation 1) for \(x\) and (Equation 4) for \(\lambda\). This is more efficient than the tangent method when \(P\) is large and the loss function is a scalar.

Also note that the Jacobians \(\frac{d A}{d\theta}\), \(\frac{d b}{d\theta}\), \(\frac{\partial J}{\partial \theta}\), and \(\frac{\partial J}{\partial x}\) can be computed by automatic differentiation or derived analytically.
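
The adjoint recipe can be sketched in NumPy as follows. The toy system, the loss \(J = \tfrac{1}{2}x^T x + 0.05\,\theta^T\theta\), and all function names are assumptions made for this illustration, not part of the video. The Jacobians are derived by hand for this toy problem, and the gradient is checked against central finite differences of \(J\).

```python
import numpy as np

def A_of(theta):
    # Hypothetical parameterised matrix A(theta), N = 2, P = 3
    return np.array([[2.0 + theta[0], theta[1]],
                     [theta[1], 3.0 + theta[2]]])

def b_of(theta):
    # Hypothetical parameterised right-hand side b(theta)
    return np.array([theta[0], 1.0 + theta[2]])

# Hand-derived Jacobians of the toy A and b
DA = np.zeros((2, 2, 3))      # DA[i, j, p] = d A_ij / d theta_p
DA[0, 0, 0] = 1.0
DA[0, 1, 1] = 1.0
DA[1, 0, 1] = 1.0
DA[1, 1, 2] = 1.0
DB = np.zeros((2, 3))         # DB[i, p] = d b_i / d theta_p
DB[0, 0] = 1.0
DB[1, 2] = 1.0

def J_of_theta(theta):
    # Assumed loss: J(x, theta) = 0.5 * x.x + 0.05 * theta.theta
    x = np.linalg.solve(A_of(theta), b_of(theta))
    return 0.5 * x @ x + 0.05 * theta @ theta

def adjoint_grad(theta):
    A = A_of(theta)
    x = np.linalg.solve(A, b_of(theta))          # forward solve (Equation 1)
    dJ_dx = x                                    # partial J / partial x for this loss
    dJ_dtheta = 0.1 * theta                      # partial J / partial theta for this loss
    lam = np.linalg.solve(A.T, dJ_dx)            # adjoint solve (Equation 4)
    rhs = DB - np.einsum("ijp,j->ip", DA, x)     # shape (N, P), no solve needed
    return lam @ rhs + dJ_dtheta

theta = np.array([0.5, -0.3, 0.8])
grad = adjoint_grad(theta)

# Central finite differences of J as an independent check
eps = 1e-6
grad_fd = np.zeros(3)
for i in range(3):
    e = np.zeros(3)
    e[i] = eps
    grad_fd[i] = (J_of_theta(theta + e) - J_of_theta(theta - e)) / (2 * eps)
```

Only two calls to `np.linalg.solve` appear in `adjoint_grad`, one with \(A\) and one with \(A^T\), no matter how many parameters \(\theta\) has.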

Alternative Derivation Using Lagrange Multiplier

This part is based on the YouTube video: Adjoint Sensitivities of a Linear System of Equations - derived using the Lagrangian.

Since sensitivity analysis can be used directly for perturbation analysis in an optimization problem, e.g., to find how a small change in \(\theta\) affects the objective function \(J(x)\), we can consider the following optimization problem \[ \min_{x} J(x, \theta) \quad \text{s.t.} \quad A(\theta) x = b(\theta) \]

The equality constraint can be regarded as the KKT condition of another convex optimization problem, i.e., the original problem is actually a bi-level optimization problem. The Lagrangian of the above problem is \[ \mathcal{L}(x, \theta, \lambda) = J(x, \theta) + \lambda^T (b(\theta) - A(\theta) x) \] where \(\lambda\) is the Lagrange multiplier. Treating \(x\) as a function of \(\theta\), the total derivative with respect to \(\theta\) is \[ \frac{d\mathcal{L}}{d\theta} = \frac{\partial J}{\partial x} \frac{d x}{d \theta} + \frac{\partial J}{\partial \theta} + \lambda^T \left(\frac{db}{d\theta} - \frac{d A}{d\theta} x - A(\theta)\frac{dx}{d\theta}\right) \]

Again, \(\frac{d A}{d\theta}\) is a third-order tensor, so \(\frac{d A}{d\theta} x\) is a slight abuse of notation. The difficult term is \(\frac{dx}{d\theta}\). Collecting the terms that multiply \(\frac{dx}{d\theta}\), we have \[ \frac{d\mathcal{L}}{d\theta} = \frac{\partial J}{\partial \theta} + \lambda^T \left(\frac{db}{d\theta} - \frac{d A}{d\theta} x\right) + \underbrace{\left(\frac{\partial J}{\partial x} - \lambda^T A\right)}_{\rightarrow 0} \frac{dx}{d\theta} \]

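Note that because \(x\) is obtained as the solution to \(Ax = b\), the equality constraint is always satisfied, so the term \(\lambda^T (b - Ax)\) vanishes for any value of \(\lambda\). We are therefore free to choose \(\lambda\) such that the coefficient of \(\frac{dx}{d\theta}\) is zero: \(\frac{\partial J}{\partial x} - \lambda^T A = 0\). Transposing gives \(A^T \lambda = \left(\frac{\partial J}{\partial x}\right)^T\), which is the same adjoint system as in the previous derivation (Equation 4).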
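
The freedom in choosing \(\lambda\) can be checked numerically. The sketch below reuses a hypothetical \(2\times 2\) toy system and the assumed loss \(J = \tfrac{1}{2}x^T x + 0.05\,\theta^T\theta\) (neither is from the video). It evaluates the total derivative of the Lagrangian with the true \(\frac{dx}{d\theta}\), once for a random \(\lambda\) and once for the adjoint \(\lambda\).

```python
import numpy as np

theta = np.array([0.5, -0.3, 0.8])

# Hypothetical toy system A(theta) x = b(theta), N = 2, P = 3
A = np.array([[2.0 + theta[0], theta[1]],
              [theta[1], 3.0 + theta[2]]])
b = np.array([theta[0], 1.0 + theta[2]])

# Hand-derived Jacobians of the toy A and b
DA = np.zeros((2, 2, 3))      # DA[i, j, p] = d A_ij / d theta_p
DA[0, 0, 0] = 1.0
DA[0, 1, 1] = 1.0
DA[1, 0, 1] = 1.0
DA[1, 1, 2] = 1.0
DB = np.zeros((2, 3))         # DB[i, p] = d b_i / d theta_p
DB[0, 0] = 1.0
DB[1, 2] = 1.0

x = np.linalg.solve(A, b)
dJ_dx = x                     # for the assumed loss
dJ_dtheta = 0.1 * theta       # for the assumed loss
rhs = DB - np.einsum("ijp,j->ip", DA, x)
dx = np.linalg.solve(A, rhs)  # true dx/dtheta from the tangent method

def dL_dtheta(lam):
    # Total derivative of the Lagrangian, keeping the dx/dtheta terms
    return dJ_dx @ dx + dJ_dtheta + lam @ (rhs - A @ dx)

rng = np.random.default_rng(0)
dL_random = dL_dtheta(rng.standard_normal(2))   # arbitrary multiplier

lam_adj = np.linalg.solve(A.T, dJ_dx)           # adjoint multiplier (Equation 4)
dL_adjoint = dL_dtheta(lam_adj)

coeff = dJ_dx - lam_adj @ A                     # coefficient of dx/dtheta
grad_without_dx = lam_adj @ rhs + dJ_dtheta     # formula that never needs dx/dtheta
```

Any \(\lambda\) gives the same value when \(\frac{dx}{d\theta}\) is available, but only the adjoint choice makes `coeff` vanish, which is what allows the gradient to be evaluated without \(\frac{dx}{d\theta}\).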

Plugging this \(\lambda\) into the total derivative of the Lagrangian, we have \[ \frac{d \mathcal{L}}{d \theta} = \frac{d J}{d \theta} = \lambda^T \left(\frac{db}{d\theta} - \frac{d A}{d\theta} x\right) + \frac{\partial J}{\partial \theta} \] where the first equality holds because \(Ax=b\), so that \(\mathcal{L} = J\) for every \(\theta\).
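
As a quick sanity check, consider a hypothetical scalar case (not taken from the video) with \(N = P = 1\), \(A(\theta) = \theta\), \(b(\theta) = 1\), and \(J(x) = \frac{1}{2}x^2\). Solving directly gives \[ x = \frac{1}{\theta}, \qquad J = \frac{1}{2\theta^2}, \qquad \frac{dJ}{d\theta} = -\frac{1}{\theta^3} \] The adjoint system \(A^T\lambda = \left(\frac{\partial J}{\partial x}\right)^T\) becomes \(\theta\lambda = x\), so \(\lambda = \frac{1}{\theta^2}\). With \(\frac{db}{d\theta} = 0\), \(\frac{dA}{d\theta} = 1\), and \(\frac{\partial J}{\partial\theta} = 0\), the adjoint formula gives \[ \frac{dJ}{d\theta} = \lambda\left(\frac{db}{d\theta} - \frac{dA}{d\theta}x\right) + \frac{\partial J}{\partial\theta} = \frac{1}{\theta^2}\left(0 - \frac{1}{\theta}\right) + 0 = -\frac{1}{\theta^3} \] which agrees with the direct computation.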