Deep Implicit Layers: Fixed-Point Iteration

Categories: Auto Differentiation, Implicit Function
Author: Wangkun Xu
Published: June 11, 2024
Modified: June 11, 2024

Introduction

In previous posts, I summarized the mathematical background of the adjoint method for linear and nonlinear systems. This post revisits them from the perspective of deep implicit layers. The main reference is the Deep Implicit Layers tutorial.

The nonlinear equation \(g(x,z)=0\) can be viewed as more than a simple algebraic equation. As in the sensitivity analysis of the equation, we can define a neural network layer implicitly, through a function of the parameter \(x\) and the solution \(z\). The gradient needed for backpropagation can then be found with the same sensitivity-analysis ideas as in the previous posts. This means that we can

  1. encode an implicit layer with physical meaning as part of a neural network;
  2. regard the entire neural network, or part of it, as an implicit model. For example, a ResNet can be viewed as a discretized NeuralODE, and an infinitely deep, weight-tied feedforward network can be viewed as a deep equilibrium model.

Implicit layers for different types of equations.

| Equation type                              | Neural network              |
|--------------------------------------------|-----------------------------|
| Algebraic equation (fixed-point iteration) | Deep equilibrium model      |
| Ordinary differential equation             | NeuralODE                   |
| Convex optimization                        | Differentiable convex layer |

The benefits of using implicit layers are:

  1. The solution can be found by an off-the-shelf solver, independently of how the layer is defined. For example, a fixed-point equation can be solved by Newton’s method, an ODE by an ODE solver such as Euler’s method, and a convex optimization problem by a convex optimization solver such as ADMM.
  2. Because the solution procedure is separate from the layer, it does not need to be recorded on the computational graph (although it can be unrolled on the graph if desired). This keeps the memory cost independent of the number of solver iterations and can improve numerical stability.
  3. Because the forward pass of an implicit layer runs a solution procedure that is usually iterative (thus applying the nonlinearity repeatedly), an implicit layer can be more expressive than an explicit layer with the same parameters.

Fixed Point Iteration

The fixed-point equation \[ z^{\star}=\tanh \left(W z^{\star}+x\right) \] can be written as the nonlinear (root-finding) equation \[ g(x, z)=z-\tanh \left(W z+x\right)=0 \] where \(W\in\mathbb{R}^{n\times n}\) and \(x, z\in\mathbb{R}^{n}\).
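The fixed-point form also suggests the plain fixed-point (Picard) iteration \(z := \tanh(Wz + x)\). Below is a minimal single-sample sketch. As an assumption of this example only, \(W\) is rescaled to spectral norm 0.5, which makes the map a contraction because tanh is 1-Lipschitz; the helpers g and picard are illustrative names, not part of the post's code.

```python
import torch

torch.manual_seed(0)
n = 5
W = torch.randn(n, n, dtype=torch.double)
# Assumption for this sketch: rescale W to spectral norm 0.5 so the map is a contraction
W = 0.5 * W / torch.linalg.matrix_norm(W, ord=2)
x = torch.randn(n, dtype=torch.double)

def g(x, z, W):
    """Residual g(x, z) = z - tanh(W z + x); zero at the fixed point."""
    return z - torch.tanh(W @ z + x)

def picard(x, W, tol=1e-12, max_iter=500):
    """Plain fixed-point iteration z <- tanh(W z + x), started from z = 0."""
    z = torch.zeros_like(x)
    for k in range(max_iter):
        z_next = torch.tanh(W @ z + x)
        if torch.norm(z_next - z) < tol:
            return z_next, k + 1
        z = z_next
    return z, max_iter

z_star, iters = picard(x, W)
print(iters, torch.norm(g(x, z_star, W)).item())
```

Under the contraction assumption the error shrinks by at least a factor of 0.5 per iteration (linear convergence), so a few dozen iterations are needed; Newton’s method converges much faster near the solution.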

Forward Pass

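For a given \(x\), Newton’s method solves the equation iteratively, \[ z:= z - \left(\frac{\partial g}{\partial z}\right)^{-1} g(x,z) \] where, in practice, the inverse is never formed explicitly: each step solves a linear system with \(\frac{\partial g}{\partial z}\).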

In general, the partial Jacobian of \(g\) with respect to \(z\) can be found by an automatic differentiation tool such as torch.func.jacfwd(). For this simple case, it can also be found analytically as \[ \frac{\partial g}{\partial z} = I - \text{diag}(\tanh'(Wz+x))W \] where \(\tanh'(u) = 1-\tanh^2(u)\) is applied elementwise.
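A minimal single-sample sketch of Newton’s method with the analytical Jacobian, checked against torch.func.jacfwd(). As an assumption of this sketch (not the post's setup), \(W\) is rescaled to spectral norm 0.5 so that \(\frac{\partial g}{\partial z}\) is always invertible and the fixed point is unique; g and dg_dz_analytic are hypothetical helper names.

```python
import torch

torch.manual_seed(0)
n = 5
W = torch.randn(n, n, dtype=torch.double)
# Assumption for this sketch: spectral norm 0.5, so dg/dz is always invertible
W = 0.5 * W / torch.linalg.matrix_norm(W, ord=2)
x = torch.randn(n, dtype=torch.double)

def g(z, x):
    return z - torch.tanh(W @ z + x)

def dg_dz_analytic(z, x):
    # I - diag(tanh'(Wz + x)) W, with tanh'(u) = 1 - tanh(u)^2
    d = 1 - torch.tanh(W @ z + x) ** 2
    return torch.eye(n, dtype=z.dtype) - d[:, None] * W

def newton(x, tol=1e-12, max_iter=50):
    z = torch.tanh(x)  # same initial guess as the layer in the post
    for _ in range(max_iter):
        r = g(z, x)
        if torch.norm(r) < tol:
            break
        # solve a linear system instead of forming the inverse
        z = z - torch.linalg.solve(dg_dz_analytic(z, x), r)
    return z

z_star = newton(x)
J_auto = torch.func.jacfwd(g, argnums=0)(z_star, x)
print(torch.norm(g(z_star, x)).item())
print(torch.allclose(J_auto, dg_dz_analytic(z_star, x)))
```

The automatic and analytical Jacobians agree; the analytical form avoids a Jacobian evaluation per step, while jacfwd() works for any \(g\).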

Note that the forward pass is not recorded on the computational graph (in PyTorch, it runs under torch.no_grad()).

Backward Pass

We can use the same technique as in the nonlinear adjoint method to perform reverse-mode differentiation, i.e. backpropagation. As a quick review, let \(z(x)\) denote the solution for a given \(x\). Since \(g(x, z(x)) = 0\) holds for every \(x\), the total derivative of \(g\) with respect to \(x\) vanishes: \[ \frac{d g}{d x} = \frac{\partial g}{\partial x} + \frac{\partial g}{\partial z} \frac{d z}{d x} = 0 \]

The Jacobian \(\frac{d z}{d x}\) can be found as \[ \frac{d z}{d x} = -\left(\frac{\partial g}{\partial z}\right)^{-1} \frac{\partial g}{\partial x} \]

Note that the term \(\frac{\partial g}{\partial z}\) has already been computed in the forward pass. For example, we can directly reuse the Jacobian (or its factorization) from the last iteration of Newton’s method, which is evaluated at an iterate very close to the solution.

Tip

This is a common observation across implicit-layer applications: reverse-mode differentiation comes almost for ‘free’, because the Jacobian it needs is a by-product of the forward solve.
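One way to make the reuse concrete is to keep an LU factorization of \(\frac{\partial g}{\partial z}\) from the Newton loop and apply it to the transposed system in the backward pass, so no new factorization is needed there. The sketch below assumes a single sample, \(W\) rescaled to spectral norm 0.5, and the loss \(\ell(z)=\sum_i z_i^2\). It relies on torch.linalg.lu_factor() and torch.linalg.lu_solve(), whose adjoint=True option solves the transposed system for a real matrix. Unlike the post's layer, it factors at the top of each iteration, so the stored factors correspond to the final iterate.

```python
import torch

torch.manual_seed(0)
n = 5
W = torch.randn(n, n, dtype=torch.double)
W = 0.5 * W / torch.linalg.matrix_norm(W, ord=2)  # assumption: spectral norm 0.5
x = torch.randn(n, dtype=torch.double)

def g(z):
    return z - torch.tanh(W @ z + x)

def dg_dz(z):
    d = 1 - torch.tanh(W @ z + x) ** 2
    return torch.eye(n, dtype=z.dtype) - d[:, None] * W

# Forward: Newton's method, factoring dg/dz once per iterate
z = torch.tanh(x)
for _ in range(50):
    LU, piv = torch.linalg.lu_factor(dg_dz(z))
    r = g(z)
    if torch.norm(r) < 1e-12:
        break  # LU, piv now hold the factors of dg/dz at the final z
    z = z - torch.linalg.lu_solve(LU, piv, r.unsqueeze(-1)).squeeze(-1)

# Backward: reuse the factors for the transposed system (dg/dz)^T v = dl/dz
dl_dz = 2 * z  # gradient of the loss l(z) = sum(z**2)
v = torch.linalg.lu_solve(LU, piv, dl_dz.unsqueeze(-1), adjoint=True).squeeze(-1)

# Reference: a fresh dense solve with the transposed Jacobian
v_ref = torch.linalg.solve(dg_dz(z).T, dl_dz)
print(torch.allclose(v, v_ref))
```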

Computing the full Jacobian \(\frac{d z}{d x}\) directly is not efficient: as in the tangent method, it requires one linear solve per component of \(x\), so the number of linear systems equals the number of parameters \(x\).
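The cost is easy to see in a sketch that forms the full Jacobian for one sample. For this particular \(g\), \(\frac{\partial g}{\partial x} = -\text{diag}(\tanh'(Wz+x))\), so \(\frac{dz}{dx}\) comes from one solve with \(n\) right-hand sides, i.e. one linear system per component of \(x\). The sketch assumes \(W\) rescaled to spectral norm 0.5 and, for brevity, uses Picard iteration instead of Newton’s method as the solver (solve_fixed_point is a hypothetical helper); it checks the result against central finite differences of the solver output.

```python
import torch

torch.manual_seed(0)
n = 5
W = torch.randn(n, n, dtype=torch.double)
W = 0.5 * W / torch.linalg.matrix_norm(W, ord=2)  # assumption: spectral norm 0.5
x = torch.randn(n, dtype=torch.double)

def solve_fixed_point(x, iters=200):
    # Picard iteration; converges because ||W||_2 = 0.5 (assumption of this sketch)
    z = torch.zeros_like(x)
    for _ in range(iters):
        z = torch.tanh(W @ z + x)
    return z

z = solve_fixed_point(x)
d = 1 - torch.tanh(W @ z + x) ** 2          # tanh'(Wz + x)
dg_dz = torch.eye(n, dtype=x.dtype) - d[:, None] * W
dg_dx = -torch.diag(d)
dz_dx = -torch.linalg.solve(dg_dz, dg_dx)   # n right-hand sides: one per component of x

# Central finite differences, one column per component of x
eps = 1e-6
E = torch.eye(n, dtype=x.dtype)
dz_dx_fd = torch.stack(
    [(solve_fixed_point(x + eps * E[j]) - solve_fixed_point(x - eps * E[j])) / (2 * eps)
     for j in range(n)], dim=1)
print(torch.max(torch.abs(dz_dx - dz_dx_fd)).item())
```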

Modern deep learning models are trained by reverse-mode differentiation, which is more efficient than forward-mode differentiation when a scalar loss depends on many inputs. Let \(\ell(\cdot)\) be the scalar loss function; the gradient with respect to \(x\) is \[ \frac{d \ell}{d x} = \frac{d \ell}{d z} \frac{d z}{d x} = \frac{d \ell}{d z} \left(-\left(\frac{\partial g}{\partial z}\right)^{-1} \frac{\partial g}{\partial x}\right) \]

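Again, the first two factors form a vector-Jacobian product, which we compute as the new adjoint of \(z\), denoted \(\dot{z}\), by solving \[ \left(\frac{\partial g}{\partial z}\right)^T \dot{z} = \left(\frac{d \ell}{d z}\right)^T \tag{1}\] so that \(\frac{d \ell}{d x} = -\dot{z}^T \frac{\partial g}{\partial x}\). Only a single linear system needs to be solved, whatever the dimension of \(x\).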
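A quick numerical check of Equation 1: the adjoint route needs a single solve with \(\left(\frac{\partial g}{\partial z}\right)^T\), yet gives the same \(\frac{d\ell}{dx}\) as forming the full Jacobian first. As in the other sketches, this assumes one sample, \(W\) rescaled to spectral norm 0.5, Picard iteration for the forward solve, and \(\ell(z) = \sum_i z_i^2\); the variable z_adj stands for \(\dot z\).

```python
import torch

torch.manual_seed(0)
n = 5
W = torch.randn(n, n, dtype=torch.double)
W = 0.5 * W / torch.linalg.matrix_norm(W, ord=2)  # assumption: spectral norm 0.5
x = torch.randn(n, dtype=torch.double)

z = torch.zeros(n, dtype=torch.double)
for _ in range(200):                     # Picard iteration to the fixed point
    z = torch.tanh(W @ z + x)

d = 1 - torch.tanh(W @ z + x) ** 2       # tanh'(Wz + x)
dg_dz = torch.eye(n, dtype=x.dtype) - d[:, None] * W
dg_dx = -torch.diag(d)
dl_dz = 2 * z                            # loss l(z) = sum(z**2)

# Adjoint route: a single linear solve (Equation 1), then a vector-Jacobian product
z_adj = torch.linalg.solve(dg_dz.T, dl_dz)
dl_dx_adjoint = -z_adj @ dg_dx

# Reference: form the full Jacobian dz/dx first (n right-hand sides)
dz_dx = -torch.linalg.solve(dg_dz, dg_dx)
dl_dx_full = dl_dz @ dz_dx
print(torch.allclose(dl_dx_adjoint, dl_dx_full))
```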

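Then we need to re-engage \(x\) on the computational graph. This can be done by computing \[ z:= z - g(x,z) \] while treating the \(z\) on the right-hand side as a constant (it comes from the forward solve, outside the tape). At the solution \(g(x,z)=0\), so the value of \(z\) is unchanged, but its Jacobian with respect to \(x\) now supplies the last factor \(-\frac{\partial g}{\partial x}\) of \(\frac{d\ell}{dx}\):

\[ \frac{\partial}{\partial x}\left(z - g(x,z)\right) = -\frac{\partial g}{\partial x} \]

The same holds for any parameter of \(g\), such as \(W\), whose gradient is obtained through \(-\frac{\partial g}{\partial W}\).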

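(Where the minus sign appears depends on how the adjoint system is defined. For example, one can instead put \(-\left(\frac{d \ell}{d z}\right)^T\) on the right-hand side of Equation 1, as implicit_model_test does below, and drop the minus sign from the last factor.)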
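The re-engagement step can be checked in isolation. In this sketch (assumptions: a single sample, \(W\) rescaled to spectral norm 0.5, and a hypothetical helper reengage), \(z\) is a constant computed under torch.no_grad(), and torch.func.jacrev() confirms that the Jacobian of \(z - g(x,z)\) with respect to \(x\) equals \(-\frac{\partial g}{\partial x} = \text{diag}(\tanh'(Wz+x))\).

```python
import torch

torch.manual_seed(0)
n = 5
W = torch.randn(n, n, dtype=torch.double)
W = 0.5 * W / torch.linalg.matrix_norm(W, ord=2)  # assumption: spectral norm 0.5
x = torch.randn(n, dtype=torch.double)

with torch.no_grad():                    # forward solve outside the tape
    z = torch.zeros(n, dtype=torch.double)
    for _ in range(200):
        z = torch.tanh(W @ z + x)

def reengage(x):
    # z is a constant here; only x carries gradient
    g = z - torch.tanh(W @ z + x)
    return z - g

J = torch.func.jacrev(reengage)(x)
d = 1 - torch.tanh(W @ z + x) ** 2
print(torch.allclose(J, torch.diag(d)))  # -dg/dx = diag(tanh'(Wz + x))
```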

To sum up, the forward and backward passes of the fixed-point layer are:

  1. Forward pass: solve the nonlinear equation \(g(x,z)=0\) with an off-the-shelf solver. This happens outside the automatic differentiation tape.
  2. On the automatic differentiation tape, re-engage \(x\) on the computational graph by \(z:= z - g(x,z)\).
  3. Replace the gradient of the above \(z\) by the solution to Equation 1, using the register_hook() method in PyTorch.
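The three steps above can be sketched compactly for a single sample. Assumptions not in the post: Picard iteration replaces Newton’s method in step 1, \(W\) is rescaled to spectral norm 0.5, the loss is \(\sum_i z_i^2\), and implicit_forward is a hypothetical helper. The sketch compares the resulting gradients with those from backpropagating through 200 unrolled iterations, which should agree once the iteration has converged.

```python
import torch

torch.manual_seed(0)
n = 5
W0 = torch.randn(n, n, dtype=torch.double)
W0 = 0.5 * W0 / torch.linalg.matrix_norm(W0, ord=2)  # assumption: spectral norm 0.5

def implicit_forward(x, W):
    # 1. Forward pass: solve g(x, z) = 0 outside the tape (Picard iteration here)
    with torch.no_grad():
        z = torch.zeros_like(x)
        for _ in range(200):
            z = torch.tanh(W @ z + x)
        d = 1 - torch.tanh(W @ z + x) ** 2
        J = torch.eye(x.shape[0], dtype=x.dtype) - d[:, None] * W  # dg/dz at z*
    # 2. Re-engage x and W on the tape: z - g(x, z), with z a constant
    z_out = z - (z - torch.tanh(W @ z + x))
    # 3. Replace the incoming gradient dl/dz by the solution of Equation 1
    z_out.register_hook(lambda grad: torch.linalg.solve(J.T, grad))
    return z_out

x = torch.randn(n, dtype=torch.double, requires_grad=True)
W = W0.clone().requires_grad_(True)
torch.sum(implicit_forward(x, W) ** 2).backward()

# Reference: backpropagate through 200 unrolled iterations (memory grows with depth)
x_ref = x.detach().clone().requires_grad_(True)
W_ref = W0.clone().requires_grad_(True)
z_ref = torch.zeros(n, dtype=torch.double)
for _ in range(200):
    z_ref = torch.tanh(W_ref @ z_ref + x_ref)
torch.sum(z_ref ** 2).backward()
print(torch.allclose(x.grad, x_ref.grad), torch.allclose(W.grad, W_ref.grad))
```

The implicit version stores only one step on the tape, whereas the unrolled reference keeps all 200 iterations.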

PyTorch Implementation

Here, I implement a simple fixed-point iteration layer in PyTorch and compare its gradient with the one computed by the nonlinear adjoint method from the previous post.

import torch
from torch import nn

def loss_fn(z):
    # squared norm of each sample, averaged over the batch
    return torch.sum(z**2, dim = -1).mean()

class FixedPointLayer(torch.nn.Module):
    def __init__(self, W, tol = 1e-4, max_iter = 1):
        super().__init__()
        self.W = torch.nn.Parameter(W, requires_grad = True)
        self.tol = tol
        self.max_iter = max_iter
        # batched residual g(z, x) and batched Jacobian dg/dz, implemented by vmap
        self.implicit_model = torch.vmap(self.implicit_model_)
        self.jac_batched = torch.vmap(torch.func.jacfwd(self.implicit_model_, argnums = 0))

    def implicit_model_(self, z, x):
        return z - torch.tanh(self.W @ z + x)
    
    def newton_step(self, z, x, g):
        J = self.jac_batched(z, x)
        z = z - torch.linalg.solve(J, g)
        return z, J

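    def forward(self, x):
        self.iteration = 0
        J = None
        with torch.no_grad():
            # forward solve outside the autograd tape, starting from tanh(x)
            z = torch.tanh(x)
            while self.iteration < self.max_iter:
                g = self.implicit_model(z, x)
                self.err = torch.norm(g)

                if self.err < self.tol:
                    break

                # newton's method; keep the last Jacobian dg/dz for the backward pass
                z, J = self.newton_step(z, x, g)
                self.iteration += 1

            # no Newton step was taken: compute dg/dz at the initial guess
            if J is None:
                J = self.jac_batched(z, x)

        # re-engage the autograd tape: the gradient of z - g(x, z) is -dg/dx (and -dg/dW)
        z = z - self.implicit_model(z, x)
        # replace dl/dz by the adjoint solving (dg/dz)^T v = dl/dz, i.e. Equation 1
        z.register_hook(lambda grad : torch.linalg.solve(J.transpose(1,2), grad))

        return z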

def implicit_model(W, x, z):
    # the g function
    return z - torch.tanh(W @ z + x)


def implicit_model_test(W, x, z):

    if x.dim() == 1:
        # single sample case
        print('using the implicit model on one sample')
        z_ = z.clone().detach()
        x_ = x.clone().detach()
        
        dl_dz = torch.func.grad(loss_fn)(z_)
        df_dW, df_dz = torch.func.jacfwd(implicit_model, argnums = (0,2))(W, x_, z_)
        
        adjoint_variable = torch.linalg.solve(df_dz.T, -dl_dz)
        
        dl_dW = torch.einsum('i,ikl->kl', adjoint_variable, df_dW)
    
    else:
        print('using the implicit model on all samples')
        z = z.clone().detach()
        x = x.clone().detach()
        
        dl_dz = torch.func.grad(loss_fn)(z)

        jacfwd_batched = torch.vmap(torch.func.jacfwd(implicit_model, argnums = (0,2)), in_dims = (None, 0, 0))
        df_dW, df_dz = jacfwd_batched(W, x, z)

        adjoint_variable = torch.linalg.solve(df_dz.transpose(1,2), -dl_dz)

        dl_dW = torch.einsum('bi,bikl->kl', adjoint_variable, df_dW)
    
    print('dl_dz', dl_dz.shape)
    print('df_dW', df_dW.shape)
    print('df_dz', df_dz.shape)
    print('adjoint_variable', adjoint_variable.shape)
    print('dl_dW', dl_dW)

# main script
torch.random.manual_seed(0)

batch_size = 10
n = 5
W = torch.randn(n,n).double() * 0.5
x = torch.randn(batch_size,n).double().requires_grad_()

print('using the model')
model = FixedPointLayer(W, tol=1e-10, max_iter = 50).double()

# check with the numerical gradient
torch.autograd.gradcheck(model, x, check_undefined_grad=False, raise_exception=True)

z = model(x)
loss = loss_fn(z)
loss.backward()
print(model.W.grad)

# implicit model method
implicit_model_test(W, x[0], z[0])
implicit_model_test(W, x, z)

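In the above example, torch.vmap() is used for the batched evaluation of the implicit function and its Jacobian, and torch.func.jacfwd() is used for the Jacobian calculation. Note that the Jacobian \(\frac{\partial g}{\partial z}\) is found by automatic differentiation instead of analytically. torch.linalg.solve() solves the linear systems, and torch.einsum() performs the tensor contractions that form dl_dW. The batched call implicit_model_test(W, x, z) uses the same batch-mean loss as the layer, so its dl_dW should agree with model.W.grad up to the solver tolerance; the single-sample call uses the loss of one sample only, so its values differ.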

In detail, in

jacfwd_batched = torch.vmap(torch.func.jacfwd(implicit_model, argnums = (0,2)), in_dims = (None, 0, 0))

torch.func.jacfwd(implicit_model, argnums = (0,2)) computes the Jacobians of the implicit model with respect to \(W\) and \(z\) (its first and third arguments). torch.vmap() then maps this function over the samples of the batch, and in_dims = (None, 0, 0) specifies that x and z are batched along their first dimension while W is shared by all samples.
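To see what in_dims does, the sketch below (random inputs, assumed for illustration) prints the shapes returned by the batched Jacobian and compares each slice with an unbatched torch.func.jacfwd() call on one sample.

```python
import torch

def implicit_model(W, x, z):
    # the g function, same as in the post
    return z - torch.tanh(W @ z + x)

B, n = 10, 5
W = torch.randn(n, n, dtype=torch.double)
x = torch.randn(B, n, dtype=torch.double)
z = torch.randn(B, n, dtype=torch.double)

jacfwd_batched = torch.vmap(torch.func.jacfwd(implicit_model, argnums=(0, 2)),
                            in_dims=(None, 0, 0))
df_dW, df_dz = jacfwd_batched(W, x, z)
print(df_dW.shape)  # torch.Size([10, 5, 5, 5]): per sample, d g_i / d W_kl
print(df_dz.shape)  # torch.Size([10, 5, 5]): per sample, d g_i / d z_j

# Equivalent per-sample loop, for comparison
for b in range(B):
    dW_b, dz_b = torch.func.jacfwd(implicit_model, argnums=(0, 2))(W, x[b], z[b])
    assert torch.allclose(dW_b, df_dW[b]) and torch.allclose(dz_b, df_dz[b])
```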