This post summarizes my learning notes on the forward mode (or tangent method) for sensitivity analysis of ordinary differential equations (ODEs). Similar to the previous posts on the tangent method for linear and nonlinear equations, it can be extended to find the sensitivity of the ODE solution, or the gradient through a neural ODE, with respect to the initial condition or the parameters. This post is based on the YouTube video: Neural ODEs - Pushforward/Jvp rule.
Consider the nonlinear ODE: \[ \frac{du}{dt} = f(u, \theta) \tag{1}\] which is evaluated at \(t=T\) with initial condition \(u(t=0)=u_0\) as \[ q(\theta,u_0,T) = u(T) \] where \(\theta\in\mathbb{R}^P\) is the parameter; \(u_0\in\mathbb{R}^N\) is the initial condition; \(T\) is the final time; and \(u(T)\in\mathbb{R}^N\) is the final state. The ODE can be solved by any ODE solver, such as the Euler method or Runge-Kutta method.
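To make the setup concrete, here is a minimal sketch of solving the primal ODE with forward Euler in PyTorch; the right-hand side and the step count are placeholders chosen only for illustration.

```python
import torch

# A made-up right-hand side f(u, theta); here a simple linear decay u' = -theta * u.
def f(u, theta):
    return -theta * u

def solve_euler(f, u0, theta, T, n_steps=100):
    """Solve du/dt = f(u, theta) on [0, T] with forward Euler and return u(T)."""
    dt = T / n_steps
    u = u0.clone()
    for _ in range(n_steps):
        u = u + dt * f(u, theta)
    return u

u0 = torch.tensor([1.0, 2.0])
theta = torch.tensor(0.5)
uT = solve_euler(f, u0, theta, T=1.0)   # q(theta, u0, T)
```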
Our task is to forward-propagate the tangent information (a vector) on the inputs \(\dot{\theta}\in\mathbb{R}^P, \dot{u}_0\in\mathbb{R}^{N}, \dot{T}\in\mathbb{R}\) to the output \(\dot{u}(T)\in\mathbb{R}^N\), without unrolling the solver and applying forward-mode AD to each of its operations.
In general, forward-mode AD computes the Jacobian-vector product, i.e., the total derivative of the output contracted with the input tangents: \[ \dot{u}(T) = \underbrace{\frac{\partial q}{\partial \theta}\dot{\theta}}_{\dot{u}(T)_\theta} + \underbrace{\frac{\partial q}{\partial u_0}\dot{u}_0}_{\dot{u}(T)_{u_0}} + \underbrace{\frac{\partial q}{\partial T}\dot{T}}_{\dot{u}(T)_T} \]
We can find the tangent information term by term.
Tangent Condition from parameter \(\theta\)
From \[ u(T) = u_0 + \int_0^T f(u(t), \theta) dt \] take the total derivative with respect to \(\theta\): \[ \frac{d}{d\theta}u(T) = \frac{d}{d\theta}u_0 + \int_0^T \left( \frac{\partial f}{\partial u}\frac{d}{d\theta}u(t) + \frac{\partial f}{\partial \theta} \right) dt \]
Multiply both sides by \(\dot{\theta}\): \[ \underbrace{\frac{d}{d\theta}u(T) \cdot \dot{\theta}}_{\dot{u}(T)_\theta} = \underbrace{\frac{d}{d\theta}u_0 \cdot \dot{\theta}}_{=0} + \int_0^T \left( \frac{\partial f}{\partial u}\underbrace{\frac{d}{d\theta}u(t) \cdot \dot{\theta}}_{\dot{u}(t)_\theta} + \frac{\partial f}{\partial \theta} \cdot \dot{\theta} \right) dt \]
It is assumed that the initial condition \(u_0\) does not depend on \(\theta\). However, this may not hold, for example, when events (state-dependent stopping or switching conditions) are considered.
The above relationship is the solution of the following ODE on state \(\dot{u}(t)_\theta\): \[ \begin{aligned} \frac{d}{dt}\dot{u}(t)_\theta &= \frac{\partial f}{\partial u}\dot{u}(t)_\theta + \frac{\partial f}{\partial \theta}\dot{\theta} \\ \dot{u}(t=0)_\theta &= 0 \end{aligned} \tag{2}\]
By solving the tangent ODE using any ODE solver, we can find the tangent information \(\dot{u}(t)_\theta\). Note that this ODE is linear and inhomogeneous, although the original ODE can be nonlinear.
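As a quick sanity check (a hand-worked example, not from the video): for the scalar ODE \(\frac{du}{dt}=\theta u\) with \(u(0)=u_0\), the solution is \(u(t)=u_0 e^{\theta t}\) and hence \(\frac{du}{d\theta}=t\,u_0 e^{\theta t}\). With \(\dot{\theta}=1\), the tangent ODE Equation 2 becomes \[ \frac{d}{dt}\dot{u}(t)_\theta = \theta\,\dot{u}(t)_\theta + u(t), \qquad \dot{u}(t=0)_\theta = 0 \] whose solution \(\dot{u}(t)_\theta = t\,u_0 e^{\theta t}\) matches the direct derivative.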
Alternative Derivation
In this section, we derive the tangent condition from the perspective of automatic differentiation in deep learning. The derived sensitivity can be used for optimization, local sensitivity analysis, dynamics/control, or neural ODEs.
Consider there is a cost function \(J(u,\theta)\) associated with the state \(u(t)\): \[ J(u,\theta) = \int_0^Tg(u(t), \theta)dt \]
For instance, we can set a quadratic loss \(g(u(t), \theta) = u(t)^TQu(t)\). A cost evaluated at a finite set of time instances is also possible.
The total derivative of the cost function with respect to \(\theta\) is: \[ \frac{dJ}{d\theta} = \int_0^T \frac{d}{d\theta} g(u,\theta)dt = \int_0^T \left( \frac{\partial g}{\partial \theta} + \frac{\partial g}{\partial u}\frac{d u}{d\theta} \right) dt \in \mathbb{R}^{1\times P} \] which is a row vector.
The only term that is difficult to compute is \(\frac{d u}{d\theta}\in\mathbb{R}^{N\times P}\), i.e., \(\frac{du}{d\theta} = [\frac{du}{d\theta_1}, \cdots, \frac{du}{d\theta_P}]\), where each column \(\frac{du}{d\theta_i}\in\mathbb{R}^N\).
Then take the total derivative of Equation 1 with respect to \(\theta\) and exchange the order of differentiation: \[ \frac{d}{d\theta} \frac{du}{dt} = \frac{d}{dt} \frac{du}{d\theta} = \frac{d}{d\theta} f(u, \theta) = \frac{\partial f}{\partial u}\frac{du}{d\theta} + \frac{\partial f}{\partial \theta} \]
Therefore, there are \(P\) linear ODEs to solve for \(\frac{du}{d\theta}\): \[ \begin{aligned} \frac{d}{dt}\frac{du}{d\theta_i} &= \frac{\partial f}{\partial u}\frac{du}{d\theta_i} + \frac{\partial f}{\partial \theta_i} \\ \frac{du}{d\theta_i}(t=0) &= 0 \end{aligned} \]
Once we have \(\frac{du}{d\theta}\), we can find the gradient of the cost function with respect to \(\theta\) by plugging it into the total derivative of the cost function.
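A minimal sketch of this recipe, assuming a hypothetical scalar example \(f(u,\theta)=-\theta u\) with running cost \(g(u,\theta)=u^2\) and hand-coded partial derivatives (all names below are illustrative, not from the video):

```python
import torch

# Hypothetical scalar example: f(u, theta) = -theta * u, g(u, theta) = u^2.
def f(u, theta):         return -theta * u
def df_du(u, theta):     return -theta
def df_dtheta(u, theta): return -u
def dg_du(u, theta):     return 2.0 * u
def dg_dtheta(u, theta): return torch.zeros_like(theta)

def forward_sensitivity(u0, theta, T, n_steps=1000):
    """Forward Euler on the primal ODE, the sensitivity ODE for du/dtheta,
    and the running cost J; returns u(T) and dJ/dtheta."""
    dt = T / n_steps
    u = u0.clone()
    s = torch.zeros_like(u0)       # s = du/dtheta, s(0) = 0
    dJ = torch.zeros_like(theta)   # accumulates dJ/dtheta
    for _ in range(n_steps):
        dJ = dJ + dt * (dg_dtheta(u, theta) + dg_du(u, theta) * s)
        s = s + dt * (df_du(u, theta) * s + df_dtheta(u, theta))
        u = u + dt * f(u, theta)
    return u, dJ

u0, theta = torch.tensor(1.0), torch.tensor(0.5)
uT, dJ_dtheta = forward_sensitivity(u0, theta, T=1.0)
```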
Tangent Condition from initial condition \(u_0\)
Find the total derivative of \(u(T)\) with respect to \(u_0\) on both sides of the integral form \(u(T) = u_0 + \int_0^T f(u(t), \theta) dt\): \[ \frac{d}{du_0}u(T) = \frac{\partial u_0}{\partial u_0} + \int_0^T \frac{\partial f}{\partial u}\frac{d}{du_0}u(t) dt \]
Multiply both sides by \(\dot{u}_0\): \[ \underbrace{\frac{d}{du_0}u(T) \cdot \dot{u}_0}_{\dot{u}(T)_{u_0}} = \underbrace{\frac{\partial u_0}{\partial u_0} \cdot \dot{u}_0}_{=\dot{u}_0} + \int_0^T \frac{\partial f}{\partial u}\underbrace{\frac{d}{du_0}u(t) \cdot \dot{u}_0}_{\dot{u}(t)_{u_0}} dt \] which is again the solution of the following ODE on state \(\dot{u}(t)_{u_0}\): \[ \begin{aligned} \frac{d}{dt}\dot{u}(t)_{u_0} &= \frac{\partial f}{\partial u}\dot{u}(t)_{u_0} \\ \dot{u}(t=0)_{u_0} &= \dot{u}_0 \end{aligned} \] which is a linear and homogeneous ODE.
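As another hand-worked check (not from the video): for the same scalar example \(\frac{du}{dt}=\theta u\), we have \(\frac{d u(T)}{d u_0}=e^{\theta T}\); the homogeneous tangent ODE \(\frac{d}{dt}\dot{u}(t)_{u_0}=\theta\,\dot{u}(t)_{u_0}\) with \(\dot{u}(0)_{u_0}=\dot{u}_0\) indeed gives \(\dot{u}(T)_{u_0}=e^{\theta T}\dot{u}_0\).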
Tangent Condition from final time \(T\)
Find the total derivative of \(u(T)\) with respect to \(T\) on both sides of the integral form: \[ \frac{d}{dT}u(T) = \frac{d}{dT}u_0 + f(u(T), \theta) \] Note that we have \(\frac{d}{dT}\int_0^T f(u(t), \theta) dt = f(u(T), \theta)\).
Multiply both sides by \(\dot{T}\): \[ \underbrace{\frac{d}{dT}u(T) \cdot \dot{T}}_{\dot{u}(T)_T} = \underbrace{\frac{d}{dT}u_0 \cdot \dot{T}}_{=0} + f(u(T), \theta) \cdot \dot{T} \] which is an algebraic equation (only for the final time).
Computation
Note that \(\dot{\theta}, \dot{u}_0, \dot{T}\) are known vectors. Meanwhile, the Jacobians \(\frac{\partial f}{\partial u}\) and \(\frac{\partial f}{\partial \theta}\) can be found by analytical derivation or automatic differentiation. In most cases, the Jacobians do not need to be computed or stored explicitly. Instead, the Jacobian-vector product (JVP) can be computed by forward-mode AD, e.g., by torch.autograd.forward_ad in PyTorch.
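For example, here is a minimal sketch of computing such a JVP with torch.autograd.forward_ad; the right-hand side below is a made-up example.

```python
import torch
import torch.autograd.forward_ad as fwAD

def f(u, theta):
    # made-up nonlinear right-hand side, only for illustration
    return -theta * u + torch.sin(u)

u = torch.randn(3)
theta = torch.tensor(0.5)
v = torch.randn(3)   # tangent vector paired with u

# JVP (df/du) v without ever forming df/du explicitly
with fwAD.dual_level():
    dual_u = fwAD.make_dual(u, v)
    jvp = fwAD.unpack_dual(f(dual_u, theta)).tangent
```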
Because \(\frac{\partial f}{\partial u}\) and \(\frac{\partial f}{\partial \theta}\) are functions of \(u(t)\), their values depend on the solution of the primal ODE. One option is to store the solution of the primal ODE. However, a more compact approach is to simultaneously solve the primal ODE and the tangent ODEs. For example, the augmented ODE system is: \[ \begin{aligned} \frac{du}{dt} &= f(u, \theta) \\ \frac{d}{dt}\dot{u}(t)_{u_0} &= \frac{\partial f}{\partial u}\dot{u}(t)_{u_0} \\ \frac{d}{dt}\dot{u}(t)_\theta &= \frac{\partial f}{\partial u}\dot{u}(t)_\theta + \frac{\partial f}{\partial \theta}\dot{\theta} \end{aligned} \] where the initial conditions are \(u(t=0)=u_0\), \(\dot{u}(t=0)_{u_0}=\dot{u}_0\) and \(\dot{u}(t=0)_\theta=0\). The ODEs can be solved by any ODE solver, such as the Euler method or Runge-Kutta method.
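Below is a minimal sketch of this augmented approach with forward Euler, computing the required JVPs with forward-mode AD; the right-hand side, helper names, and step count are assumptions for illustration only.

```python
import torch
import torch.autograd.forward_ad as fwAD

def f(u, theta):
    # made-up right-hand side; replace with your own model
    return -theta * u

def rhs_and_jvp(u, theta, v_u, v_theta):
    """Return f(u, theta) and the JVP (df/du) v_u + (df/dtheta) v_theta
    in a single forward-mode pass."""
    with fwAD.dual_level():
        dual_u = fwAD.make_dual(u, v_u)
        dual_theta = fwAD.make_dual(theta, v_theta)
        primal, tangent = fwAD.unpack_dual(f(dual_u, dual_theta))
    return primal, tangent

def solve_augmented(u0, theta, u0_dot, theta_dot, T, n_steps=1000):
    """Forward Euler on the augmented system: primal state plus both tangent states."""
    dt = T / n_steps
    u = u0.clone()
    s_u0 = u0_dot.clone()              # tangent from u0, starts at u0_dot
    s_th = torch.zeros_like(u0)        # tangent from theta, starts at 0
    zero_th = torch.zeros_like(theta)
    for _ in range(n_steps):
        fu, jvp_u0 = rhs_and_jvp(u, theta, s_u0, zero_th)    # (df/du) s_u0
        _, jvp_th = rhs_and_jvp(u, theta, s_th, theta_dot)   # (df/du) s_th + (df/dtheta) theta_dot
        u = u + dt * fu
        s_u0 = s_u0 + dt * jvp_u0
        s_th = s_th + dt * jvp_th
    return u, s_u0, s_th

u0, theta = torch.tensor([1.0, 2.0]), torch.tensor(0.5)
uT, uT_dot_u0, uT_dot_theta = solve_augmented(
    u0, theta, u0_dot=torch.ones_like(u0), theta_dot=torch.tensor(1.0), T=1.0)
```

In practice, the two JVP calls per step could be fused (or the tangent states stacked) to avoid evaluating \(f\) twice.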
The tangent method can also be used to find the Jacobian of the output with respect to the parameter \(\theta\) by setting the tangent vector \(\dot{\theta}\) to the unit vectors. Not surprisingly, \(P\) tangent ODE solves are needed to obtain the Jacobian \(\frac{d u(t)}{d \theta}\), as sketched below.
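Continuing the sketch above (reusing the hypothetical `solve_augmented` and `f`, now with a vector-valued \(\theta\)), the Jacobian can be assembled column by column from unit tangent vectors:

```python
theta = torch.tensor([0.3, 0.7])       # P = 2 parameters
u0 = torch.tensor([1.0, 2.0])
P = theta.numel()

cols = []
for i in range(P):
    e_i = torch.zeros(P)
    e_i[i] = 1.0                       # unit tangent vector for theta_i
    _, _, s_th = solve_augmented(u0, theta, torch.zeros_like(u0), e_i, T=1.0)
    cols.append(s_th)

jacobian = torch.stack(cols, dim=1)    # du(T)/dtheta, shape (N, P)
```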
PyTorch Implementation
An example code snippet, including a parallel batched implementation, can be found at my-github.