GSoC 2019: Gradient Descent for ODEs (But This Time, In Theano)



A little while ago, I wrote a post on doing gradient descent for ODEs. In that post, I used autograd to do the automatic differentiation. While neat, it was really a way for me to get familiar with some math that I was to use for GSoC. After taking some time to learn more about theano, I’ve reimplemented the blog post, this time using theano to perform the automatic differentiation. If you’ve read the previous post, then skip right to the code.

Gradient descent usually isn’t used to fit Ordinary Differential Equations (ODEs) to data (at least, that isn’t how the Applied Mathematics departments of which I have been a part have done it). Nevertheless, that doesn’t mean that it can’t be done. For some of my recent GSoC work, I’ve been investigating how to compute gradients of solutions to ODEs without access to the solution’s analytical form. In this blog post, I describe how these gradients can be computed and how they can be used to fit ODEs to synchronous data with gradient descent.

Up To Speed With ODEs

I realize not everyone might have studied ODEs. Here is everything you need to know:

A differential equation relates an unknown function $y \in \mathbb{R}^n$ to its own derivative through a function $f: \mathbb{R}^n \times \mathbb{R} \times \mathbb{R}^m \rightarrow \mathbb{R}^n$, which also depends on time $t \in \mathbb{R}$ and possibly a set of parameters $\theta \in \mathbb{R}^m$. We usually write ODEs as

\[y' = f(y,t,\theta) \quad y(t_0) = y_0\]

Here, we refer to the vector $y$ as “the system”, since the ODE above really defines a system of equations. The problem is usually equipped with an initial state of the system $y(t_0) = y_0$ from which the system evolves forward in $t$. Analytic solutions to ODEs are often very hard, if not impossible, to find, so most of the time we just approximate the solution numerically. How this is done doesn’t matter here, because numerical integration is not the point of this post. If you’re interested, look up the class of Runge-Kutta methods.
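To make "numerically approximate the solution" a bit more concrete, here is a minimal sketch of the classical fourth-order Runge-Kutta method. The names `rk4_step`, `integrate`, and `decay` are my own for illustration, and the test problem $y' = -\theta y$ is chosen only because its exact solution $e^{-\theta t}$ is known. The rest of this post uses scipy's integrator instead.

```python
import math

def rk4_step(f, y, t, h, theta):
    """Advance y' = f(y, t, theta) by one classical Runge-Kutta step of size h."""
    k1 = f(y, t, theta)
    k2 = f(y + 0.5 * h * k1, t + 0.5 * h, theta)
    k3 = f(y + 0.5 * h * k2, t + 0.5 * h, theta)
    k4 = f(y + h * k3, t + h, theta)
    return y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

def integrate(f, y0, t0, t1, theta, n_steps=100):
    """Integrate from t0 to t1 using n_steps equal RK4 steps."""
    h = (t1 - t0) / n_steps
    y, t = y0, t0
    for _ in range(n_steps):
        y = rk4_step(f, y, t, h, theta)
        t += h
    return y

def decay(y, t, theta):
    """Toy ODE (an assumption for this example): exponential decay."""
    return -theta * y

y_end = integrate(decay, 1.0, 0.0, 1.0, theta=2.0)
exact = math.exp(-2.0)
print(y_end, exact)
```

The two printed numbers should agree to many decimal places, which is all we need from an integrator here.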

Computing Gradients for ODEs

In this section, I’m going to be using derivative notation rather than $\nabla$ for gradients. I think it is less ambiguous.

If we want to fit an ODE model to data by minimizing some loss function $\mathcal{L}$, then gradient descent looks like

\[\theta_{n+1} = \theta_n - \alpha \dfrac{\partial \mathcal{L}}{\partial \theta}\]
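As a quick illustration of the update rule, here is a toy sketch that has nothing to do with ODEs yet. The loss $(\theta - 3)^2$ and the names `loss_grad`, `theta`, and `alpha` are assumptions made up for this example.

```python
def loss_grad(theta):
    """Gradient of the toy loss (theta - 3)^2 with respect to theta."""
    return 2.0 * (theta - 3.0)

theta = 0.0   # starting guess
alpha = 0.1   # learning rate

# Repeatedly step against the gradient
for _ in range(200):
    theta = theta - alpha * loss_grad(theta)

print(theta)
```

Each step shrinks the distance to the minimizer at $\theta = 3$ by a constant factor, so the iterate settles there quickly.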

In order to compute the gradient of the loss, we need the gradient of the solution, $y$, with respect to $\theta$. The gradient of the solution is the hard part here because it cannot be computed (a) analytically (because analytic solutions are hard AF), or (b) through automatic differentiation without differentiating through the numerical integration of our ODE (which seems computationally wasteful).

Thankfully, years of research into ODEs have yielded a way to do this (and it is not the adjoint method. Surprise! You thought I was going to say the adjoint method, didn’t you?). Forward mode sensitivity analysis calculates gradients by extending the ODE system to include the following equations:

\[\dfrac{d}{dt}\left( \dfrac{\partial y}{\partial \theta} \right) = \mathcal{J}_f \dfrac{\partial y}{\partial \theta} +\dfrac{\partial f}{\partial \theta}\]

Here, $\mathcal{J}_f$ is the Jacobian of $f$ with respect to $y$. The forward sensitivity system is just another differential equation (see how it relates the derivative of the unknown $\partial y / \partial \theta$ to itself?)! In order to compute the gradient of $y$ with respect to $\theta$ at time $t_i$, we compute

\[\dfrac{\partial y}{\partial \theta} = \int_{t_0}^{t_i} \left( \mathcal{J}_f \dfrac{\partial y}{\partial \theta} + \dfrac{\partial f}{\partial \theta} \right) \, dt\]

I know this looks scary, but since forward mode sensitivities are just ODEs, we actually just get this from what we can consider to be a black box

\[\dfrac{\partial y}{\partial \theta} = \operatorname{BlackBox}(f(y,t,\theta), t_0, y_0, \theta)\]
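Here is a small sketch of what the black box does, on a problem where we can check the answer. I'm assuming the toy ODE $y' = -\theta y$ with $y(0) = 1$, for which $\mathcal{J}_f = -\theta$ and $\partial f / \partial \theta = -y$, so the exact sensitivity is $-t e^{-\theta t}$. The function names and the hand-rolled Runge-Kutta integrator are my own choices for illustration.

```python
import numpy as np

def augmented_rhs(z, t, theta):
    """Toy ODE y' = -theta * y together with its sensitivity s = dy/dtheta."""
    y, s = z
    dydt = -theta * y
    # J_f * s + df/dtheta, with J_f = -theta and df/dtheta = -y
    dsdt = -theta * s - y
    return np.array([dydt, dsdt])

def rk4_solve(rhs, z0, t_grid, theta):
    """Fixed-step classical Runge-Kutta integration; returns the final state."""
    z = np.array(z0, dtype=float)
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        h = t1 - t0
        k1 = rhs(z, t0, theta)
        k2 = rhs(z + 0.5 * h * k1, t0 + 0.5 * h, theta)
        k3 = rhs(z + 0.5 * h * k2, t0 + 0.5 * h, theta)
        k4 = rhs(z + h * k3, t1, theta)
        z = z + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return z

theta = 2.0
t_grid = np.linspace(0.0, 1.0, 101)

# The sensitivity starts at 0 because y(0) does not depend on theta
y_end, s_end = rk4_solve(augmented_rhs, [1.0, 0.0], t_grid, theta)

exact_y = np.exp(-theta)    # y(1) = exp(-theta)
exact_s = -np.exp(-theta)   # dy/dtheta at t = 1 is -t * exp(-theta * t)
print(abs(y_end - exact_y), abs(s_end - exact_s))
```

Both printed errors should be tiny: integrating the sensitivity equation alongside the state recovers the gradient without ever writing down the solution.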

So now that we have our gradient in hand, we can use the chain rule to write

\[\dfrac{\partial \mathcal{L}}{\partial \theta} =\dfrac{\partial \mathcal{L}}{\partial y} \dfrac{\partial y}{\partial \theta}\]

We can use automatic differentiation to compute $\dfrac{\partial \mathcal{L}}{\partial y}$.
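To see the chain rule at work, here is a sketch on a toy problem where everything is available in closed form. I'm assuming the solution $y(t) = e^{-\theta t}$ and a squared-error loss, both picked purely for illustration, and checking the chain-rule gradient against a central finite difference.

```python
import numpy as np

t = np.linspace(0.0, 2.0, 21)
theta_true, theta = 1.5, 1.0
y_obs = np.exp(-theta_true * t)   # noiseless "observations"

def solution(theta):
    """Closed-form solution of y' = -theta * y with y(0) = 1."""
    return np.exp(-theta * t)

def loss(theta):
    """Sum of squared errors between solution and observations."""
    return np.sum((solution(theta) - y_obs) ** 2)

y = solution(theta)
dL_dy = 2.0 * (y - y_obs)      # derivative of the loss w.r.t. each y(t_i)
dy_dtheta = -t * y             # sensitivity of the solution at each t_i
dL_dtheta = np.sum(dL_dy * dy_dtheta)   # chain rule, summed over time points

# Central finite difference as an independent check
eps = 1e-6
fd = (loss(theta + eps) - loss(theta - eps)) / (2 * eps)
print(dL_dtheta, fd)
```

The product-and-sum in `dL_dtheta` is the same pattern used in the gradient step later in this post.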

OK, so that is some math (interesting to me, maybe not so much to you). Let’s actually implement this in python.

Gradient Descent for the SIR Model

The SIR model is a set of differential equations which govern how a disease spreads through a homogeneously mixed closed population. I could write an entire thesis on this model and its various extensions (in fact, I have), so I’ll let you read about those in your free time.

The system, shown below, is parameterized by a single parameter:

\[\dfrac{dS}{dt} = -\theta SI \quad S(0) = 0.99\] \[\dfrac{dI}{dt} = \theta SI - I \quad I(0) = 0.01\]
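For reference, here is what the forward sensitivity equations look like for this system when worked out by hand. The notation $S_\theta = \partial S / \partial \theta$ and $I_\theta = \partial I / \partial \theta$ is mine. The Jacobian of the right-hand side with respect to $(S, I)$ has rows $(-\theta I, \; -\theta S)$ and $(\theta I, \; \theta S - 1)$, and the partial derivatives with respect to $\theta$ are $-SI$ and $SI$, so

\[\dfrac{dS_\theta}{dt} = -\theta I S_\theta - \theta S I_\theta - SI \quad S_\theta(0) = 0\] \[\dfrac{dI_\theta}{dt} = \theta I S_\theta + (\theta S - 1) I_\theta + SI \quad I_\theta(0) = 0\]

The initial sensitivities are zero because the initial conditions do not depend on $\theta$.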

Let’s define the system, the appropriate derivatives, generate some observations and fit $\theta$ using gradient descent. Here is what you’ll need to get started:

import theano
import theano.tensor as tt
import numpy as np
import matplotlib.pyplot as plt
import scipy.integrate

Let’s then define the ODE system

def f(y, t, theta):
    """This is the ODE system.

    The function can act on either numpy arrays or theano TensorVariables.

    Args:
        y (vector): system state
        t (float): current time (optional)
        theta (vector): parameters of the ODEs

    Returns:
        dydt (list): result of the ODEs
    """
    return [
        -theta*y[0]*y[1],       #= dS/dt
        theta*y[0]*y[1] - y[1]  #= dI/dt
    ]

and create a computation graph with theano.

#Define the differential Equation

#Present state of the system
y = tt.dvector('y')
#Parameter: Basic reproductive ratio
p = tt.dscalar('p')
#Present state of the gradients: starts at 0 unless the parameter is the initial condition
dydp = tt.dvector('dydp')

f_tensor = tt.stack(f(y, None, p))

#Now compute gradients
#Use Rop to compute the Jacobian vector product
Jdfdy = tt.Rop(f_tensor, y, dydp)

grad_f = tt.jacobian(f_tensor, p)
#This is the time derivative of dydp
ddt_dydp = Jdfdy + grad_f

#Compile the system as a theano function
#Inputs:
#    y - array of length 2 representing current state of the system (S and I)
#    dydp - array of length 2 representing current state of the gradient (dS/dp and dI/dp)
#    p - scalar parameter
#Outputs: time derivatives of the state and of the gradient
system = theano.function(
    inputs=[y, dydp, p],
    outputs=[f_tensor, ddt_dydp]
)

Next, we’ll define the augmented system (that is, the ODE plus the sensitivities).

#create a function to spit out derivatives
def ODESYS(Y,t,p):
    """Augmented system: the ODE plus its forward sensitivities.

    Args:
        Y (vector): current state and gradient
        t (scalar): current time
        p (scalar): parameter

    Returns:
        derivatives (vector): derivatives of state and gradient
    """
    dydt, dydp = system(Y[:2], Y[2:], p)
    return np.concatenate([dydt, dydp])
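If you want to sanity check an augmented system like this one without theano, here is a numpy-only sketch. It assumes hand-derived sensitivity equations for the SIR model and a simple fixed-step Runge-Kutta integrator (neither is what the code in this post uses), and compares the integrated sensitivities against a central finite difference in the parameter. The names `sir_augmented` and `rk4_solve` and the value `p = 3.0` are my own.

```python
import numpy as np

def sir_augmented(Y, t, p):
    """SIR right-hand side plus hand-derived forward sensitivities."""
    S, I, dS, dI = Y
    dSdt = -p * S * I
    dIdt = p * S * I - I
    # Jacobian-vector product plus partial derivative w.r.t. p
    ddS = -p * I * dS - p * S * dI - S * I
    ddI = p * I * dS + (p * S - 1.0) * dI + S * I
    return np.array([dSdt, dIdt, ddS, ddI])

def rk4_solve(rhs, Y0, t_grid, p):
    """Fixed-step classical Runge-Kutta; returns the state at every grid point."""
    Y = np.array(Y0, dtype=float)
    out = [Y.copy()]
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        h = t1 - t0
        k1 = rhs(Y, t0, p)
        k2 = rhs(Y + 0.5 * h * k1, t0 + 0.5 * h, p)
        k3 = rhs(Y + 0.5 * h * k2, t0 + 0.5 * h, p)
        k4 = rhs(Y + h * k3, t1, p)
        Y = Y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        out.append(Y.copy())
    return np.array(out)

Y0 = [0.99, 0.01, 0.0, 0.0]
t_grid = np.linspace(0.0, 5.0, 501)
p, eps = 3.0, 1e-5

sol = rk4_solve(sir_augmented, Y0, t_grid, p)

# Finite-difference estimate of d(S, I)/dp from two nearby solves
plus = rk4_solve(sir_augmented, Y0, t_grid, p + eps)[:, :2]
minus = rk4_solve(sir_augmented, Y0, t_grid, p - eps)[:, :2]
fd = (plus - minus) / (2 * eps)

max_gap = np.max(np.abs(sol[:, 2:] - fd))
print(max_gap)
```

A small `max_gap` suggests the sensitivity columns really are the derivatives of the state columns with respect to the parameter.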

We’ll minimize the $L_2$ norm of the error at each time point, summed over all time points. This is done in theano.

#Tensor for observed data
t_y_obs = tt.dmatrix('y_obs')
#Tensor for predictions
t_y_pred = tt.dmatrix('y_pred')

#Define error and cost
err = t_y_obs - t_y_pred
Cost = err.norm(2,axis = 1).sum()
Cost_gradient = theano.tensor.grad(Cost,t_y_pred)

C = theano.function(inputs = [t_y_obs, t_y_pred], outputs=Cost)
del_C = theano.function(inputs = [t_y_obs, t_y_pred], outputs=Cost_gradient)
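For intuition about what the cost gradient looks like, here is a numpy sketch of the same cost with its gradient written out by hand. For a row with error $e = y_{obs} - y_{pred}$, the derivative of $\lVert e \rVert_2$ with respect to $y_{pred}$ is $-e / \lVert e \rVert_2$. The function names and the random test data are assumptions for this example, and the finite-difference loop is only there as a check.

```python
import numpy as np

def cost(y_obs, y_pred):
    """Sum over rows of the L2 norm of the error."""
    return np.sum(np.linalg.norm(y_obs - y_pred, axis=1))

def cost_gradient(y_obs, y_pred):
    """Gradient of the cost with respect to y_pred (same shape as y_pred)."""
    err = y_obs - y_pred
    norms = np.linalg.norm(err, axis=1, keepdims=True)
    return -err / norms

rng = np.random.default_rng(0)
y_obs = rng.normal(size=(5, 2))
y_pred = rng.normal(size=(5, 2))

analytic = cost_gradient(y_obs, y_pred)

# Central finite differences, one entry of y_pred at a time
numeric = np.zeros_like(y_pred)
eps = 1e-6
for idx in np.ndindex(*y_pred.shape):
    bump = np.zeros_like(y_pred)
    bump[idx] = eps
    numeric[idx] = (cost(y_obs, y_pred + bump) - cost(y_obs, y_pred - bump)) / (2 * eps)

max_diff = np.max(np.abs(analytic - numeric))
print(max_diff)
```

Note that this gradient is undefined for any row where the prediction matches the observation exactly, since the norm there is zero.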

Create some observations from which to fit

#Initial Condition
Y0 = np.array([0.99,0.01, 0.0, 0.0])
#Space to compute solutions
t_dense = np.linspace(0,5,101)
#True param value
p_true = 8

y_hat_theano = scipy.integrate.odeint(ODESYS, y0=Y0, t=t_dense, args=(p_true,))

S_obs,I_obs, *_ = y_hat_theano.T + np.random.normal(0,0.1,size = y_hat_theano.T.shape)
y_obs = np.c_[S_obs,I_obs]

Perform Gradient Descent

p_gd = 1.1
learning_rate = 0.01

num_steps = 1000

prev_cost = float('inf')

tol = 1e-5

while True:

    #Evaluate solution
    y_hat_theano = scipy.integrate.odeint(ODESYS, y0=Y0, t=t_dense, args=(p_gd,))

    #Splice out the numerical solution and numerical gradients
    y_pred = y_hat_theano[:,:2]
    gradients = y_hat_theano[:,2:]

    #Perform the gradient step
    p_gd-= learning_rate*np.sum(del_C(y_obs,y_pred)*gradients)

    #Has the loss stopped changing?
    cost = C(y_obs,y_pred)

    #Keep going while the loss is still changing; stop once the change falls below tol
    if abs(cost - prev_cost)<tol:
        break

    prev_cost = cost

And lastly, compare our fitted curves to the true curves

plt.figure(dpi = 120)

plt.scatter(t_dense,y_obs[:,0], alpha = 0.5, marker = '.')
plt.scatter(t_dense,y_obs[:,1],alpha = 0.5, marker = '.')

plt.plot(t_dense,y_hat_theano[:,0], color = 'k', zorder = 100)
plt.plot(t_dense,y_hat_theano[:,1], color = 'k', zorder = 100)

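#True curves: assumed here to be the solution at the true parameter value
y_hat_analytical = scipy.integrate.odeint(ODESYS, y0=Y0, t=t_dense, args=(p_true,))

plt.plot(t_dense,y_hat_analytical[:,0], linewidth = 3)
plt.plot(t_dense,y_hat_analytical[:,1], linewidth = 3)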


Here is our result!