GSoC 2019: Gradient Descent for ODEs (But This Time, In Theano)

A little while ago, I wrote a post on doing gradient descent for ODEs. In that post, I used autograd to do the automatic differentiation. While neat, it was really a way for me to get familiar with some math that I was to use for GSoC. After taking some time to learn more about theano, I’ve reimplemented the blog post, this time using theano to perform the automatic differentiation. If you’ve read the previous post, then skip right to the code.

Gradient descent usually isn’t used to fit Ordinary Differential Equations (ODEs) to data (at least, that isn’t how the Applied Mathematics departments of which I have been a part have done it). Nevertheless, that doesn’t mean that it can’t be done. For some of my recent GSoC work, I’ve been investigating how to compute gradients of solutions to ODEs without access to the solution’s analytical form. In this blog post, I describe how these gradients can be computed and how they can be used to fit ODEs to synchronous data with gradient descent.

Up To Speed With ODEs

I realize not everyone might have studied ODEs. Here is everything you need to know:

A differential equation relates an unknown function $y \in \mathbb{R}^n$ to its own derivative through a function $f: \mathbb{R}^n \times \mathbb{R} \times \mathbb{R}^m \rightarrow \mathbb{R}^n$, which also depends on time $t \in \mathbb{R}$ and possibly a set of parameters $\theta \in \mathbb{R}^m$. We usually write ODEs as

$$ \dfrac{dy}{dt} = f(y, t, \theta) $$

Here, we refer to the vector $y$ as “the system”, since the ODE above really defines a system of equations. The problem is usually equipped with an initial state of the system $y(t_0) = y_0$ from which the system evolves forward in $t$. Solutions to ODEs in analytic form are often very hard, if not impossible, to obtain, so most of the time we just numerically approximate the solution. It doesn’t matter how this is done because numerical integration is not the point of this post. If you’re interested, look up the class of Runge-Kutta methods.
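For the curious, here is a minimal sketch of one member of that class, the classical fourth-order Runge-Kutta method, applied to a toy decay problem. The names `rk4`, `decay`, `t_grid` and `y_num` are made up for this illustration and are not used anywhere else in the post, which relies on scipy's solver instead.

```python
import numpy as np

def rk4(f, y0, t, theta):
    """Integrate dy/dt = f(y, t, theta) over the time grid t with classical RK4."""
    y = np.zeros((len(t), len(y0)))
    y[0] = y0
    for i in range(len(t) - 1):
        h = t[i + 1] - t[i]
        # Four slope evaluations per step
        k1 = np.asarray(f(y[i], t[i], theta))
        k2 = np.asarray(f(y[i] + 0.5 * h * k1, t[i] + 0.5 * h, theta))
        k3 = np.asarray(f(y[i] + 0.5 * h * k2, t[i] + 0.5 * h, theta))
        k4 = np.asarray(f(y[i] + h * k3, t[i] + h, theta))
        # Weighted average of the slopes advances the solution
        y[i + 1] = y[i] + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return y

# Toy problem (an assumption for illustration): dy/dt = -theta * y,
# whose exact solution is y0 * exp(-theta * t)
def decay(y, t, theta):
    return [-theta * y[0]]

t_grid = np.linspace(0, 1, 51)
y_num = rk4(decay, np.array([1.0]), t_grid, 2.0)

# Absolute error at t = 1 against the exact solution
print(abs(y_num[-1, 0] - np.exp(-2.0)))
```

A fixed-step scheme like this is fine for a smooth toy problem; adaptive solvers such as the one used later in the post are usually the safer choice.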

Computing Gradients for ODEs

In this section, I’m going to be using derivative notation rather than $\nabla$ for gradients. I think it is less ambiguous.

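If we want to fit an ODE model to data by minimizing some loss function $\mathcal{L}$, then gradient descent looks like

$$ \theta_{n+1} = \theta_n - \alpha \dfrac{\partial \mathcal{L}}{\partial \theta} $$

where $\alpha$ is the learning rate.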

In order to compute the gradient of the loss, we need the gradient of the solution, $y$, with respect to $\theta$. The gradient of the solution is the hard part here because it cannot be computed (a) analytically (because analytic solutions are hard AF), or (b) through automatic differentiation without differentiating through the numerical integration of our ODE (which seems computationally wasteful).

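Thankfully, years of research into ODEs have yielded a way to do this (that is not the adjoint method. Surprise! You thought I was going to say the adjoint method, didn’t you?). Forward mode sensitivity analysis calculates gradients by extending the ODE system to include the following equations:

$$ \dfrac{d}{dt}\left( \dfrac{\partial y}{\partial \theta} \right) = \mathcal{J} \dfrac{\partial y}{\partial \theta} + \dfrac{\partial f}{\partial \theta} $$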

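Here, $\mathcal{J}$ is the Jacobian of $f$ with respect to $y$. The forward sensitivity analysis is just another differential equation (see how it relates the derivative of the unknown $\partial y / \partial \theta$ to itself?)! In order to compute the gradient of $y$ with respect to $\theta$ at time $t_i$, we compute

$$ \dfrac{\partial y}{\partial \theta} \Bigg|_{t = t_i} = \int_{t_0}^{t_i} \left( \mathcal{J} \dfrac{\partial y}{\partial \theta} + \dfrac{\partial f}{\partial \theta} \right) dt $$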

I know this looks scary, but since forward mode sensitivities are just ODEs, we actually just get this from what we can consider to be a black box: a numerical ODE solver applied to the system extended with the sensitivity equations.
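To make this concrete, here is a minimal sketch of forward sensitivities on a toy problem where the answer is known in closed form. For the decay equation $dy/dt = -\theta y$ with $y(0) = 1$, the solution is $e^{-\theta t}$, so $\partial y / \partial \theta = -t e^{-\theta t}$. The toy ODE, the fixed-step integrator and all of the names below are assumptions made for illustration; any solver could play the role of the black box.

```python
import numpy as np

def augmented_rhs(Y, t, theta):
    """Decay ODE dy/dt = -theta*y together with its sensitivity s = dy/dtheta."""
    y, s = Y
    dydt = -theta * y        # f(y, t, theta)
    dsdt = -theta * s - y    # J*s + df/dtheta, with J = -theta and df/dtheta = -y
    return np.array([dydt, dsdt])

def integrate(rhs, Y0, t, theta):
    """Classical RK4 on a fixed grid, standing in for the black-box solver."""
    Y = np.zeros((len(t), len(Y0)))
    Y[0] = Y0
    for i in range(len(t) - 1):
        h = t[i + 1] - t[i]
        k1 = rhs(Y[i], t[i], theta)
        k2 = rhs(Y[i] + 0.5 * h * k1, t[i] + 0.5 * h, theta)
        k3 = rhs(Y[i] + 0.5 * h * k2, t[i] + 0.5 * h, theta)
        k4 = rhs(Y[i] + h * k3, t[i] + h, theta)
        Y[i + 1] = Y[i] + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return Y

t_toy = np.linspace(0, 2, 201)
theta_toy = 1.5

# The sensitivity starts at 0 because y(0) does not depend on theta
sol = integrate(augmented_rhs, np.array([1.0, 0.0]), t_toy, theta_toy)
y_toy, dy_dtheta = sol[:, 0], sol[:, 1]

# Largest deviation from the closed-form sensitivity -t * exp(-theta * t)
print(np.max(np.abs(dy_dtheta - (-t_toy * np.exp(-theta_toy * t_toy)))))
```

The solver never needs to know that half of the state is a gradient; it just integrates a slightly bigger system.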

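So now that we have our gradient in hand, we can use the chain rule to write

$$ \dfrac{\partial \mathcal{L}}{\partial \theta} = \dfrac{\partial \mathcal{L}}{\partial y} \dfrac{\partial y}{\partial \theta} $$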

We can use automatic differentiation to compute $\dfrac{\partial \mathcal{L}}{\partial y}$.

OK, so that is some math (interesting to me, maybe not so much to you). Let’s actually implement this in Python.

Gradient Descent for the SIR Model

The SIR model is a set of differential equations which govern how a disease spreads through a homogeneously mixed closed population. I could write an entire thesis on this model and its various extensions (in fact, I have), so I’ll let you read about those in your free time.

The system, shown below, is parameterized by a single parameter, $\theta$:

$$ \dfrac{dS}{dt} = -\theta S I \qquad \dfrac{dI}{dt} = \theta S I - I $$

Let’s define the system, the appropriate derivatives, generate some observations and fit $\theta$ using gradient descent. Here is what you’ll need to get started:

import theano
import theano.tensor as tt
import numpy as np
import matplotlib.pyplot as plt
import scipy.integrate

Let’s then define the ODE system

def f(y, t, theta):
    """This is the ODE system.

    The function can act on either numpy arrays or theano TensorVariables

    Args:
        y (vector): system state
        t (float): current time (optional)
        theta (vector): parameters of the ODEs

    Returns:
        dydt (list): result of the ODEs
    """
    return [
        -theta*y[0]*y[1],        #= dS/dt
        theta*y[0]*y[1] - y[1],  #= dI/dt
    ]

and create a computation graph with theano.

#Define the differential Equation

#Present state of the system
y = tt.dvector('y')
#Parameter: Basic reproductive ratio
p = tt.dscalar('p')
#Present state of the gradients: these start at 0 unless the parameter is the initial condition
dydp = tt.dvector('dydp')

f_tensor = tt.stack(f(y, None, p))

#Now compute gradients
#Use Rop to compute the Jacobian vector product
Jdfdy = tt.Rop(f_tensor, y, dydp)

grad_f = tt.jacobian(f_tensor, p)
#This is the time derivative of dydp
ddt_dydp = Jdfdy + grad_f


#Compile the system as a theano function
#Args:
    #y - array of length 2 representing current state of the system (S and I)
    #dydp - array of length 2 representing current state of the gradient (dS/dp and dI/dp)
system = theano.function(
    inputs=[y, dydp, p],
    outputs=[f_tensor, ddt_dydp]
)
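If you'd like to see what the `Rop` and `jacobian` calls are doing, the same right-hand side can be written out by hand for this particular system. The sketch below is plain numpy and is not part of the theano pipeline; `sir_sensitivity_rhs` is a name I made up for illustration, and the Jacobian and parameter derivative are derived by hand from the function `f` defined earlier.

```python
import numpy as np

def sir_sensitivity_rhs(y, dydp, p):
    """Hand-derived dy/dt and d/dt(dy/dp) for the SIR system.

    y is [S, I], dydp is [dS/dp, dI/dp], p is the scalar parameter.
    """
    S, I = y
    dydt = np.array([-p * S * I, p * S * I - I])
    # Jacobian of f with respect to y = (S, I)
    J = np.array([[-p * I, -p * S],
                  [ p * I,  p * S - 1.0]])
    # Partial derivative of f with respect to the parameter p
    dfdp = np.array([-S * I, S * I])
    # Forward sensitivity equation: J @ (dy/dp) + df/dp
    return dydt, J @ dydp + dfdp

# Evaluate at the initial condition used later in the post, with zero sensitivities
dydt0, ddt_dydp0 = sir_sensitivity_rhs(np.array([0.99, 0.01]), np.zeros(2), 8.0)
print(dydt0, ddt_dydp0)
```

Writing this by hand is manageable for two equations and one parameter. For anything bigger, letting theano build the Jacobian-vector product is much less error prone.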

Next, we’ll define the augmented system (that is, the ODE plus the sensitivities).

#create a function to spit out derivatives
def ODESYS(Y,t,p):
    """
    Args:
        Y (vector): current state and gradient, ordered [S, I, dS/dp, dI/dp]
        t (scalar): current time
        p (scalar): parameter of the ODEs

    Returns:
        derivatives (vector): time derivatives of state and gradient
    """
    #system returns dy/dt and the time derivative of dy/dp
    dydt, ddt_dydp = system(Y[:2], Y[2:], p)
    return np.concatenate([dydt, ddt_dydp])

We’ll minimize the $L_2$ norm of the error, summed over the time points. This is done in theano.

#Tensor for observed data
t_y_obs = tt.dmatrix('y_obs')
#Tensor for predictions
t_y_pred = tt.dmatrix('y_pred')

#Define error and cost
err = t_y_obs - t_y_pred
Cost = err.norm(2,axis = 1).sum()
Cost_gradient = theano.tensor.grad(Cost,t_y_pred)

C = theano.function(inputs = [t_y_obs, t_y_pred], outputs=Cost)
del_C = theano.function(inputs = [t_y_obs, t_y_pred], outputs=Cost_gradient)
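For reference, this particular cost is simple enough to differentiate by hand. The numpy sketch below is an illustration only (the function name `cost_and_gradient` and the small example arrays are made up); it computes the same sum of row-wise norms and its gradient with respect to the predictions.

```python
import numpy as np

def cost_and_gradient(y_obs, y_pred):
    """Sum of row-wise L2 norms of the error, and its gradient w.r.t. y_pred."""
    err = y_obs - y_pred
    row_norms = np.linalg.norm(err, axis=1)
    cost = row_norms.sum()
    # d||e||/de = e/||e|| and de/dy_pred = -1, so the gradient is -e/||e||.
    # This is undefined for any row where the error is exactly zero.
    grad = -err / row_norms[:, None]
    return cost, grad

# Small made-up example with two time points
obs_example = np.array([[3.0, 4.0], [0.0, 0.0]])
pred_example = np.array([[0.0, 0.0], [1.0, 0.0]])
cost_example, grad_example = cost_and_gradient(obs_example, pred_example)
print(cost_example)
print(grad_example)
```

Each row of the gradient is a unit vector, which is worth keeping in mind when picking a learning rate: the size of the step does not shrink as the fit improves the way it would with a squared error.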

Create some observations from which to fit:

#Seed the random number generator so the noise is reproducible
np.random.seed(19920908)
#Initial Condition: S, I, dS/dp, dI/dp
Y0 = np.array([0.99,0.01, 0.0, 0.0])
#Space to compute solutions
t_dense = np.linspace(0,5,101)
#True param value
p_true = 8

y_hat_theano = scipy.integrate.odeint(ODESYS, y0=Y0, t=t_dense, args=(p_true,))

S_obs,I_obs, *_ = y_hat_theano.T + np.random.normal(0,0.1,size = y_hat_theano.T.shape)
y_obs = np.c_[S_obs,I_obs]

Perform Gradient Descent


p_gd = 1.1
learning_rate = 0.01

num_steps = 1000

prev_cost = float('inf')

tol = 1e-5

#Take at most num_steps gradient steps so the loop is guaranteed to end
for step in range(num_steps):


    #Evaluate solution
    y_hat_theano = scipy.integrate.odeint(ODESYS, y0=Y0, t=t_dense, args=(p_gd,))

    #Splice out the numerical solution and numerical gradients
    y_pred = y_hat_theano[:,:2]
    gradients = y_hat_theano[:,2:]

    #Perform the gradient step
    p_gd-= learning_rate*np.sum(del_C(y_obs,y_pred)*gradients)

    #Has the loss changed a large amount?
    cost = C(y_obs,y_pred)

    #If so, keep going.  Stop when the loss stops shrinking
    if abs(cost - prev_cost)<tol:
        break

    prev_cost = cost


And lastly, compare our fitted curves to the true curves:

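#Solve the system at the true parameter value for comparison
y_true = scipy.integrate.odeint(ODESYS, y0=Y0, t=t_dense, args=(p_true,))

plt.figure(dpi = 120)

#Noisy observations
plt.scatter(t_dense,y_obs[:,0], alpha = 0.5, marker = '.', label = 'S observed')
plt.scatter(t_dense,y_obs[:,1],alpha = 0.5, marker = '.', label = 'I observed')

#Fitted curves from the last gradient descent iteration
plt.plot(t_dense,y_hat_theano[:,0], color = 'k', zorder = 100, label = 'S fitted')
plt.plot(t_dense,y_hat_theano[:,1], color = 'k', zorder = 100, label = 'I fitted')

#True curves
plt.plot(t_dense,y_true[:,0], linewidth = 3, label = 'S true')
plt.plot(t_dense,y_true[:,1], linewidth = 3, label = 'I true')

plt.legend()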

Here is our result!