Automatic Differentiation in MESA
May 11, 2022 15:28 · 3825 words · 18 minute read
Over the last two years I developed an automatic differentiation (auto_diff) module in Fortran to support development of the Modules for Experiments in Stellar Astrophysics (MESA) project. This post gives a brief rundown of how the auto_diff module works and how I built it.
All of the code I mention below lives on GitHub.
What is auto_diff?
Forward-mode automatic differentiation via operator overloading:
- Forward-mode means we calculate the chain rule as we go.
- Each variable in the calculation needs to be able to track derivative information.
- Variables need to know how the chain rule applies to each operation.
- Fortran source files are generated automatically by a Python program.
- This allows robust support for many different functions/operators/derivative configurations.
What does it look like in Fortran?
type :: auto_diff_real_1var_order1
real(dp) :: val
real(dp) :: d1val1
end type auto_diff_real_1var_order1
The types themselves are simple! Here's the 1var_order1 type, which supports 1 independent variable through 1 (first-order) derivative. val stores the value, and d1val1 stores the derivative with respect to the independent variable.
Concretely, we might set up a variable like this:
type(auto_diff_real_1var_order1) :: x
x = 1d0 ! Sets val to 1, zeros out d1val1
x%d1val1 = 1d0 ! Says dx/d(val1) = 1.
! Often used as a shorthand for saying 'x' is the independent variable.
And we might perform operations with these variables:
f = sin(x) ! Now f%val = sin(1), f%d1val1 = cos(1)
f = exp(x) ! Now f%val = e, f%d1val1 = e
f = pow3(x + 1) ! Now f%val = 8, f%d1val1 = 3*(1 + 1)**2 = 12
f = f + x ! Now f%val = 9, f%d1val1 = 13
Note that we do not support assignments like (real) = (auto_diff). Why? Because we don't want to accidentally lose the derivative information, and a real type doesn't have anywhere to put it! So if you want the value you have to use f%val, and if you want the derivative info, that's in f%d1val1.
Other types support more derivatives and more variables. The general pattern is that Nvar_orderM will support all derivatives through $m$-th total order in all combinations of $n$ variables. So for instance auto_diff_real_2var_order2 supports d1val1, d1val2, d1val1_d1val2, d2val1, and d2val2, which are the five combinations of (mixed) partial derivatives of two variables up to total order 2. Note that the mixed ones are always ordered by the val index, not the d index, e.g. d2val1_d1val2 is how you'd write one of the third-order mixed partials.
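Putting that pattern together, the 2var_order2 type presumably looks something like this (a sketch inferred from the naming scheme above; the exact field ordering in the generated source may differ):
type :: auto_diff_real_2var_order2
   real(dp) :: val
   real(dp) :: d1val1
   real(dp) :: d1val2
   real(dp) :: d1val1_d1val2
   real(dp) :: d2val1
   real(dp) :: d2val2
end type auto_diff_real_2var_order2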
How does it work in Fortran?
Behind the scenes are ludicrously large Fortran files, which begin like:
module auto_diff_real_1var_order1_module
use const_def, only: dp, ln10, pi
use utils_lib
use support_functions
use math_lib
implicit none
private
public :: auto_diff_real_1var_order1, &
assignment(=), &
operator(.eq.), &
operator(.ne.), &
operator(.gt.), &
operator(.lt.), &
operator(.le.), &
operator(.ge.), &
make_unop, &
make_binop, &
sign, &
safe_sqrt, &
operator(-), &
exp, &
expm1, &
...
It goes on for a while. This is just exporting all the many (many) operators that get overloaded. Scrolling down we find the implementations of these overloaded routines, like
function expm1_self(x) result(unary)
type(auto_diff_real_1var_order1), intent(in) :: x
type(auto_diff_real_1var_order1) :: unary
unary%val = expm1(x%val)
unary%d1val1 = x%d1val1*exp(x%val)
end function expm1_self
and
function add_self(x, y) result(binary)
type(auto_diff_real_1var_order1), intent(in) :: x
type(auto_diff_real_1var_order1), intent(in) :: y
type(auto_diff_real_1var_order1) :: binary
binary%val = x%val + y%val
binary%d1val1 = x%d1val1 + y%d1val1
end function add_self
The operators are all labelled as either unary or binary. Binary operators generally are named by the types they work with (e.g. add_self adds two auto_diff types, add_self_int adds an auto_diff type and an integer, etc.).
Sometimes the operators are a little inscrutable:
function dim_self(x, y) result(binary)
type(auto_diff_real_1var_order1), intent(in) :: x
type(auto_diff_real_1var_order1), intent(in) :: y
type(auto_diff_real_1var_order1) :: binary
real(dp) :: q0
q0 = x%val - y%val
binary%val = -0.5_dp*y%val + 0.5_dp*x%val + 0.5_dp*Abs(q0)
binary%d1val1 = -0.5_dp*y%d1val1 + 0.5_dp*x%d1val1 + 0.5_dp*(x%d1val1 - y%d1val1)*sgn(q0)
end function dim_self
Here dim(x, y) is the Fortran intrinsic returning max(x - y, 0), and the generated code uses the branch-free identity max(x - y, 0) = (x - y + |x - y|)/2. The reason for this style is that the operators are all auto-generated by Python scripts, in a way that optimizes for (Fortran) runtime speed at all costs.
Derivative Operators
Each auto_diff type additionally has some number of derivative operators, one per independent variable. These work like:
df_dx = differentiate_1(f)
df_dy = differentiate_2(f)
The idea here is that you might want an auto_diff type which itself represents the derivatives of another auto_diff variable (so you can propagate higher-order derivatives through later operations). This is what lets Skye do things like writing the pressure as
p = pow2(den) * differentiate_2(F)
and have p still contain derivative information.
These methods can't fill in higher-order derivatives than exist. In the above example F has a third derivative with respect to rho. p is a derivative of F with respect to rho, so we don't know enough to construct the third derivative of p with respect to rho. This is handled by just zeroing out the derivatives we don't know.
We considered using NaNs instead of zeros, following the philosophy that you should know very clearly when you've mistakenly read a missing entry (silent failure is bad). The problem with using NaNs here is that we want to be able to run MESA with floating point exception (FPE) checking turned on as a way to catch numerical problems, and if we routinely assign NaN to variables that becomes impossible.
Custom Operators
Two functions I want to highlight are make_unop and make_binop:
function make_unary_operator(x, z_val, z_d1x) result(unary)
type(auto_diff_real_1var_order1), intent(in) :: x
real(dp), intent(in) :: z_val
real(dp), intent(in) :: z_d1x
type(auto_diff_real_1var_order1) :: unary
unary%val = z_val
unary%d1val1 = x%d1val1*z_d1x
end function make_unary_operator
function make_binary_operator(x, y, z_val, z_d1x, z_d1y) result(binary)
type(auto_diff_real_1var_order1), intent(in) :: x
type(auto_diff_real_1var_order1), intent(in) :: y
real(dp), intent(in) :: z_val
real(dp), intent(in) :: z_d1x
real(dp), intent(in) :: z_d1y
type(auto_diff_real_1var_order1) :: binary
binary%val = z_val
binary%d1val1 = x%d1val1*z_d1x + y%d1val1*z_d1y
end function make_binary_operator
Let's focus on make_unop. It takes as input an auto_diff variable x along with z_val and z_d1x. The latter two specify a function's value and its first derivative with respect to its argument, both evaluated at the value of x. make_unop then propagates the chain rule, applying that function to x to give a result which inherits derivatives from x. These helper routines are there to perform variable substitutions. The idea is that you might know the function z(independent_variable) but want to instead have z(x) (which has different derivatives, because x may itself be a complicated function of the independent variables). make_binop does the same but for binary operators.
As far as I’m aware these functions only get used in the Skye equation of state in MESA, where we play some tricks with custom operators, but they’re there if you ever need to do a variable substitution.
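For example, suppose you have an external routine, call it my_special_fn (a hypothetical name, not part of MESA), that returns a function's value and first derivative as plain reals. You could graft it onto an auto_diff variable like this (a sketch):
type(auto_diff_real_1var_order1) :: x, z
real(dp) :: f_val, f_deriv

x = 1.5d0 ! Sets val, zeros out d1val1
x%d1val1 = 1d0 ! Marks x as the independent variable

! Evaluate the external function and its first derivative at x%val...
call my_special_fn(x%val, f_val, f_deriv)

! ...then let make_unop run the chain rule, so z inherits
! derivatives from x: z%d1val1 = x%d1val1 * f_deriv.
z = make_unop(x, f_val, f_deriv)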
Array Types
There are two special types that break the mold:
type :: auto_diff_real_star_order1
real(dp) :: val
real(dp) :: d1Array(33)
end type auto_diff_real_star_order1
This is the basic auto_diff type used everywhere in MESA/star. Instead of 33 different independent-variable entries it puts them all in an array. The meaning of the entries is set in MESA/star_data/public/star_data_def.inc, where you'll find
! auto_diff constants for solver variables
! used to access auto_diff_real_star_order1 d1Array
integer, parameter :: i_lnd_m1 = 1
integer, parameter :: i_lnd_00 = 2
integer, parameter :: i_lnd_p1 = 3
integer, parameter :: i_lnT_m1 = 4
integer, parameter :: i_lnT_00 = 5
integer, parameter :: i_lnT_p1 = 6
integer, parameter :: i_w_m1 = 7
integer, parameter :: i_w_00 = 8
integer, parameter :: i_w_p1 = 9
integer, parameter :: i_lnR_m1 = 10
integer, parameter :: i_lnR_00 = 11
integer, parameter :: i_lnR_p1 = 12
...
which tells the solver which indices correspond to which variables in the array. Hence d1Array(5) corresponds to the derivative with respect to lnT in the current cell, d1Array(6) to the derivative with respect to lnT in the next cell, and so on.
If you need to change the number of independent variables, you can do that by (1) updating the entry in the auto_diff config file (both for the star and tdc types), (2) adding new indexing parameters to star_data_def.inc, and (3) adding new helper routines to MESA/star/private/auto_diff_support.f90 to handle your new independent variables.
There are also lots of helper routines in MESA/star/private/auto_diff_support.f90 for manipulating these objects, including ways to shift the indexing so p1 -> 00 (and vice-versa), ways to generate e.g. lnT(k) with the appropriate derivative setup (d1Array(5) == 1 and all other entries zero), etc.
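For flavor, a wrapper in the spirit of those routines might look like this (the name and details are illustrative, not the actual API):
function wrap_lnT_00(s, k) result(lnT)
   use star_def, only: star_info, i_lnT_00 ! assumed host module for these names
   type (star_info), pointer :: s
   integer, intent(in) :: k
   type(auto_diff_real_star_order1) :: lnT
   lnT%val = s%lnT(k) ! value of lnT in cell k
   lnT%d1Array(:) = 0d0 ! no dependence on the other solver variables...
   lnT%d1Array(i_lnT_00) = 1d0 ! ...except d(lnT)/d(lnT_00) = 1
end function wrap_lnT_00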
The other special one is
type :: auto_diff_real_tdc
real(dp) :: val
real(dp) :: d1val1
real(dp) :: d1Array(33)
real(dp) :: d1val1_d1Array(33)
end type auto_diff_real_tdc
This type is only used in the time-dependent convection (TDC) code. It exists because we needed a type with a derivative with respect to one additional variable (the superadiabaticity on a face), including all mixed partial derivatives with the star solver variables.
How does it work in Python?
So how does the Python side generate these files?
Config Files
In MESA/auto_diff/config there are a bunch of files, one per auto_diff type. These are YAML files, and look like:
name: auto_diff_real_2var_order1
orders: [[1,0],[0,1]]
array: False
This says:
Make a type named auto_diff_real_2var_order1. It has to have all partial derivatives up to and including the first derivative with respect to the first variable and the first derivative with respect to the second variable. It does not store derivatives as an array.
Another example:
name: auto_diff_real_2var_order3
orders: [[3,0],[2,1],[1,2],[0,3]]
array: False
which says
Make a type named auto_diff_real_2var_order3. It has to have all partial derivatives up to and including the third derivative with respect to the first variable, the (2,1) mixed partial, the (1,2) mixed partial, and the third derivative with respect to the second variable. It does not store derivatives as an array.
Finally, the star example:
name: auto_diff_real_star_order1
orders: [[1]]
array: True
fixed_length: True
array_length: 33
which says
Make a type named auto_diff_real_star_order1. It stores derivatives as an array of fixed length 33 and has to have all partial derivatives up to and including the first derivative with respect to each component of the array.
Parser
You can regenerate the auto_diff Fortran source by calling python parser.py in the python directory. The parser is reasonably straightforward. It begins by getting the list of config files:
# Imports needed by this snippet
from os import listdir
from os.path import isfile, join

# Get config files
config_path = '../config'
config_files = [f for f in listdir(config_path) if isfile(join(config_path, f)) and '.config' in f]
config_files = [join(config_path, f) for f in config_files]
It then makes two lists. compilation_list holds all the files that make will need to act on, and use_list holds all the modules that need to be imported into the public auto_diff interface.
# compilation_list stores a list of all the fortran files that will need compiling.
# This is used in the makefile.
compilation_list = []
compilation_list.append('support_functions.f90')
# use_list stores a list of all private auto_diff modules that need importing into the public auto_diff module.
use_list = []
use_list.append(tab + 'use support_functions')
We then loop over all config files. For each, we read out the relevant info:
data = load(fi, Loader=Loader)
# gfortran does not (as of September 2021) support variable-length
# arrays in parameterized-derived-types. So stick with fixed-length
# arrays. If this changes in the future you can set fixed_length
# to False and use variable-length arrays as desired.
if data['array'] and data['fixed_length']:
array_length = data['array_length']
else:
array_length = None
Note that we can't do variable-length arrays. The Python side can generate parameterized derived auto_diff types supporting variable-length arrays, but gfortran doesn't actually implement that part of the F2003 spec and so won't compile them. Some versions of ifort worked with this functionality, but I can't remember which. The gfortran bug report is here.
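For reference, the parameterized-derived-type version that gfortran chokes on would look roughly like this (a sketch, not code that ships in MESA):
type :: auto_diff_real_array(n)
   integer, len :: n ! length parameter: the number of independent variables
   real(dp) :: val
   real(dp) :: d1Array(n)
end type auto_diff_real_array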
Then construct the list of all partial derivatives required:
# Read desired highest-order partial derivatives
partials = list(Partial(orders, data['array']) for orders in data['orders'])
This fills in all lower-order derivatives needed to fulfill the requested list of higher-order ones (e.g. if you request a third-order derivative, this adds in a second and a first as well). The Partial data type is defined in partial.py and just has some helper methods for implementing the chain rule.
That done, we build the types and write them out to files, appending them to the compilation_list and use_list:
# Build auto_diff type with those and all lower-order derivatives.
adr = AutoDiffType(data['name'], data['array'], array_length, partials)
out_fi = open('../private/' + data['name'] + '_module.f90', 'w+')
out_fi.write(py_to_fort(make_auto_diff_type(adr, unary_operators, binary_operators, comparison_operators, intrinsics)))
out_fi.close()
compilation_list.append(data['name'] + '.f90')
use_list.append(tab + 'use ' + data['name'] + '_module')
AutoDiffType
The AutoDiffType class lives in auto_diff_type.py. This class is the internal representation of an auto_diff Fortran type on the Python side. It's initialized as
class AutoDiffType:
def __init__(self, name, array, array_length, partials):
'''
Stores a list of partials that is complete, in the sense that there is enough information
to compute the chain rule within that set of partials, and sorted by total order.
'''
So you pass in the partials you want, the type name, and some information about arrays. Now that I look again, it seems that we don't need the array information, because the Partial type already has that. So the array and array_length entries could safely be removed here.
The initialization has a few important pieces. First, we work out the complete set of partial derivatives we need:
# Complete the partials list
partials = set(partials)
complete = False
while not complete:
ps = list(partials)
for p in ps:
partials.update(p.completion_partials())
if len(partials) == len(ps):
complete = True
The routine completion_partials returns any additional partial derivatives that a given partial needs to be able to propagate in the chain rule. For instance $\partial_x\partial_y^2$ in the chain rule needs access to $\partial_x\partial_y$ and $\partial_y^2$, so its completion will return those two. We just keep calling completion_partials till it stops returning new derivatives.
Next, we put these in a sorted order so we can refer to them consistently:
self.partials = sorted(list(partials), key=lambda p: [p.net_order, tuple(-o for o in p.orders)])
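To make the ordering concrete, here's a standalone illustration (mine, not from the repo) of that sort key applied to the 2var_order2 partials, with sum(orders) standing in for p.net_order:
# Each partial is a tuple of (order in val1, order in val2).
partials = [(2, 0), (0, 2), (1, 1), (0, 1), (1, 0)]
key = lambda orders: [sum(orders), tuple(-o for o in orders)]
print(sorted(partials, key=key))
# -> [(1, 0), (0, 1), (2, 0), (1, 1), (0, 2)]
# i.e. d1val1, d1val2, d2val1, d1val1_d1val2, d2val2, which is the
# consistent order the generator then uses everywhere.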
Finally, we construct the sets of partials of unary operators and binary operators out to the maximum order represented. These, too, are needed by the chain rule:
self.unary_partials = sorted(list(Partial((i,), False) for i in range(self.max_order+1)), key=lambda p: [p.net_order, tuple(-o for o in p.orders)])
self.binary_partials = sorted(list(Partial((i,j), False) for i in range(self.max_order+1) for j in range(self.max_order+1) if i+j <= self.max_order), key=lambda p: [p.net_order, tuple(-o for o in p.orders)])
You can think of these as $\partial_x f(x,y)$ and $\partial_y f(x,y)$, which you need to compute the chain rule for $\partial_u (f(x(u),y(u)))$.
The rest of the class specification is full of functions that construct the various operators that appear on the Fortran side. For instance
def specific_unary_operator_function(self, operator_name, operator):
'''
Returns a function which implements the specified unary operator.
'''
function_name = operator_name + '_self'
function_arguments = [('x', self.declare_name(ref='*'), 'in')]
function_result = ('unary', self.declare_name(ref='x'))
function_body, function_declarations = unary_specific_chain_rule(self, operator, fixed_length=self.array_length)
function_body = function_declarations + function_body
# Special case handling for safe_log
if 'safe' in operator_name:
for i in range(len(function_body)):
function_body[i] = function_body[i].replace('log', 'safe_log')
return FortranFunction(function_name, function_arguments, function_result, function_body)
takes as input an operator's name and the operator itself (as a sympy function) and returns a valid Fortran function (as a string) implementing the derivative-propagation logic. Most of this is wrapper logic: all the magic and complicated stuff that goes in the body gets constructed in unary_specific_chain_rule(self, operator, fixed_length=self.array_length) (and there are equivalent functions for binary operators).
chain_rule
The real magic on the Python side all happens in chain_rule.py. That's where functions like unary_specific_chain_rule(self, operator, fixed_length=self.array_length) are defined.
There are four functions in this file. They are each labelled unary or binary, after the kind of operator they represent, and specific or generic. The generic ones are used to write the Custom Operators routines described above, and the specific ones are used to implement actual specific operators like exp and + and so on.
How do they work?
specific
This is a bit complicated.
Everything here uses sympy for calculus and algebra, which means most of what we're doing is setting up lots of sympy variables and manipulating them.
We start by making symbols for the independent variables:
# Construct sympy variables corresponding to the various independent variables.
# These never appear on the Fortran side, but we keep the naming consistent to correspond to the
# names in partial_orders.
# So these are called val1, val2, ..., valN.
indep_symbol_str = ' '.join(auto_diff_type.partials[0].val_name(i) for i in range(auto_diff_type.num_independent))
indep_syms = wrap_element(symbols(indep_symbol_str, real=True))
Then we make symbols for the places we'll store the derivatives (this is where d1val1 comes from!):
# Construct sympy variables corresponding to the various derivatives dx/d(...).
# Note that these variable names correspond to the names we'll use on the Fortran side, so
# we can just directly map sympy expressions to strings and get valid Fortran :-)
# Hence these are called x%d1val1, x%d2val1, ..., x%d1val2, x%d2val2, ..., x%d1val1_d1val2, ...
# The first integer in each 'd_val_' block is the number of derivatives,
# the second is the independent variable those derivatives are with respect to.
x_symbol_str = ' '.join(auto_diff_type.partial_str_in_instance('x', p).replace(':','colon') for p in partials)
x_syms = wrap_element(symbols(x_symbol_str, real=True))
We then represent x as a power series in terms of its partial derivative symbols:
# Construct x as a power series in terms of its partial derivatives (sym) with respect to the independent
# variables (indep).
x = 0
for p,sym in zip(*(partials, x_syms)):
term = sym
for order, indep in zip(*(p.orders, indep_syms)):
term = term * indep ** order / factorial(order)
x = x + term
And then call our operator on x to get z(x):
z = operator(x)
The reason we play around with the power series is so that z has an explicit representation in terms of the partial derivatives of x, which in turn are explicitly represented as individual sympy symbols.
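Here's a self-contained sketch of that trick for one variable at first order (my own illustration, not code from the repo):
from sympy import symbols, sin, diff

# t stands in for the independent variable; x0 and x1 play the
# roles of x%val and x%d1val1.
t, x0, x1 = symbols('t x0 x1', real=True)

x = x0 + x1 * t # x as a power series in its partials
z = sin(x)      # apply the operator to the series

print(z.subs(t, 0))          # sin(x0): the value
print(diff(z, t).subs(t, 0)) # x1*cos(x0), i.e. x%d1val1 * cos(x%val)
Mapping x0 and x1 back to x%val and x%d1val1 gives exactly the pattern visible in routines like expm1_self above.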
With all that done, we actually extract derivatives. This starts with a few lists:
expressions = []
left_hand_names = []
derivatives = []
Here expressions is the list of derivative expressions we'll build, left_hand_names is the corresponding list of names like d1val1 (which appear on the left-hand side in the Fortran code), and derivatives appears to be an unused list that I forgot to delete.
We then iterate over all required partials:
for p in partials:
For each, we construct the left-hand side of the expression:
unary_symbol_str = auto_diff_type.partial_str_in_instance('unary', p).replace(':','colon')
The replace business here is just to make sure we only use valid sympy symbols. There's a lot of that all over this code (string replacements to avoid invalid or reserved sympy symbols, followed by back-replacement at the very end, right before we write the file).
If life were simple, we'd then just ask sympy for the derivative at the right order. But some use cases require auto_diff to support non-differentiable functions like abs and > and min and so on. Those spawn Dirac Deltas when you try to differentiate them, which is awful because (1) we can't do anything with those in any numerical method and (2) they're zero everywhere but on a set of measure zero, so we don't care about them. So we get a bunch of logic that special-cases the Dirac Delta and a few related objects and zeros them out:
d = z
for order, indep in zip(*(p.orders, indep_syms)):
d = diff(d, indep, order)
d = d.replace(DiracDelta, zero_function) # Dirac Delta is only non-zero on a set of measure zero, so we set it to zero.
d = d.replace(sign, sgn) # simplify can do weird things with sign, turning it into the piecewise operator. We want to avoid that so we redirect to a custom sgn function.
d = d.replace(Derivative, zero_function) # Eliminates derivatives of the Dirac Delta and Sign, which are non-zero only at sets of measure zero.
d = d.subs(indep, 0)
This is taking the derivatives one at a time, clearing out garbage as it arises.
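To see where that garbage comes from, here's a quick standalone illustration (mine, not from the repo):
from sympy import symbols, Abs, diff

x = symbols('x', real=True)
print(diff(Abs(x), x))    # -> sign(x)
print(diff(Abs(x), x, 2)) # -> 2*DiracDelta(x), which is zero almost
                          #    everywhere, so the generator zeros it out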
Life would be nice if this were all we had to do, but we want the resulting Fortran code to be fast, so we do some simplifications:
d = simplify(d, measure=weighted_count_ops, force=True, ratio=1)
More on that later.
The rest of the routine is just a bunch of string manipulation to get everything into the right format for Fortran. You can print the results as they accumulate if you’re interested to see how the substitutions gradually turn a pile of algebra into valid Fortran.
generic
The generic ones cheat by calling the specific ones on a dummy function. They start by constructing symbols corresponding to the partial derivatives of some general function z(x) out through the highest order we care about:
# Construct the symbols corresponding to the partial derivatives of z.
# These are d1z, d2z, ..., dNz, giving dz/dx, d^2 / dx^2, and so on.
z_symbol_strs = ['z_' + str(p).replace('val1','x') for p in auto_diff_type.unary_partials]
z_symbol_str = ' '.join(z_symbol_strs)
z_syms = wrap_element(symbols(z_symbol_str, real=True))
We then construct a Taylor series out of these symbols:
def operator(x):
# Construct z as a power series in terms of its partial derivatives (z_syms) with
# respect to the x.
z = sum(sym * x**p.orders[0] / factorial(p.orders[0]) for sym,p in zip(*(z_syms, auto_diff_type.unary_partials)))
return z
Then we call unary_specific_chain_rule to give us the chain rule code for this dummy operator, and that gets everything in terms of the partial derivatives of z(x), which we can then supply as inputs to the custom operator builders.
make_auto_diff_type
This file puts it all together, going over all the functions and all the Fortran boilerplate and doing a bunch of accounting to make sure every function blah gets closed by end function blah and so on. It's super boring.
measure
This is where all the performance optimization happens. We use the built-in sympy function simplify, but with a twist. We don't care how complicated the functions are, we care how fast they are. And moreover, speed is actually set by how many divides and special function calls we have. So we have to tell simplify about all of that. In measure.py we specify crudely how much each function call costs:
# 'basic' here means roughly a one-cycle op.
# 'div' is division, which takes ~30 cycles.
# 'special' is a special function, which takes ~1000 cycles.
# DIRACDELTA and DERIVATIVE get eliminated in post-processing and so are free.
special = 1000
div = 30
basic = 1
weights = {
'SIN': special,
'COS': special,
'TAN': special,
'TANH': special,
'COSH': special,
'SINH': special,
'ASIN': special,
'ACOS': special,
'ATAN': special,
'ATANH': special,
'ACOSH': special,
'ASINH': special,
'EXP': special,
'LOG': special,
'POW': special,
'ADD': basic,
'MUL': basic,
'NEG': basic,
'SUB': basic,
'HEAVISIDE': basic,
'ABS': basic,
'DIV': div,
'SGN': basic,
'POWM1': div,
'SSQRT': special,
'DIRACDELTA': 0,
'DERIVATIVE': 0
}
Then we have a function that goes through a sympy expression, counting function calls and tallying them up:
def weighted_count_ops(expr_original, verbose=False):
def weighted_count_ops(expr_original, verbose=False):
Understanding this code requires a decent amount of knowledge of sympy's API, but suffice it to say that we're crawling an abstract syntax tree and counting instances of functions as we encounter them.
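The core idea is simple enough to sketch using sympy's built-in count_ops (an illustration, not the actual weighted_count_ops):
from sympy import count_ops, Symbol, symbols, sin

# Toy weights in the spirit of the table above.
WEIGHTS = {'SIN': 1000, 'EXP': 1000, 'DIV': 30, 'ADD': 1, 'MUL': 1, 'SUB': 1, 'NEG': 1, 'POW': 1000}

def weighted_count(expr):
    # count_ops(visual=True) returns a symbolic tally like SIN + MUL + ADD;
    # substituting a cycle cost for each op symbol turns it into a number.
    visual = count_ops(expr, visual=True)
    return visual.subs({Symbol(name): cost for name, cost in WEIGHTS.items()})

x, y = symbols('x y')
print(weighted_count(sin(x) + x*y)) # SIN + MUL + ADD -> 1002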
functions
The functions.py file defines all the supported auto_diff functions in sympy language. Not much more to say there.
Helper Methods
There are a bunch of boring helper methods in routine.py (for spitting out valid Fortran routines) and utils.py (for string manipulation and a few performance optimizations like pow(x,N) -> powN(x)).