NMODL SympySolver - cnexp

This notebook describes the implementation of the cnexp part of the SympySolverVisitor, which solves the systems of ODEs defined in DERIVATIVE blocks when these ODEs are linear and independent.

For a higher level overview of the approach to solving ODEs in NMODL, please see the nmodl-odes-overview notebook.

For a more general tutorial on using the NMODL python interface, please see the tutorial notebook.


Implementation

The SympySolverVisitor for solver method cnexp does the following:

  • Get a list of all global-scope variables from the Symbol Table, as well as any local variables in the DERIVATIVE block

  • For each differential equation in DERIVATIVE block:

    • Parse equation into SymPy, giving it the list of variables

    • This gives us a differential equation of the form:

      • \frac{dm}{dt} = f(m, \dots)

      • where the function f depends on m and possibly on other variables, represented by \dots, which we assume do not depend on m or t

    • Solve equation analytically using sympy.dsolve to give a solution of the form:

      • m(t+dt) = g(m(t), dt, \dots)

      • where g is some function that depends on the value of m at time t, the timestep dt, and the other variables (\dots).

    • Replace ODE with analytic solution as C code using sympy.printing.ccode

    • If no solution is found, do nothing: NMODL currently falls back to the legacy CNEXP solver routine (the same one used by mod2c and nocmodl)
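The steps above can be sketched for a single equation using SymPy directly. This is a minimal illustration under stated assumptions, not the SympySolverVisitor implementation. The following are all choices made here for the example:

- the hypothetical helper name `cnexp_solve`;
- parsing with `sympy.sympify` instead of the Symbol Table;
- starting the elapsed time at 0, which is valid only because f is assumed not to depend on t.

```python
import sympy as sp


def cnexp_solve(rhs_str, state="m"):
    """Advance dm/dt = rhs by one timestep dt using sympy.dsolve.

    Hypothetical helper for illustration; it is not part of the NMODL API.
    Returns (expression, C code) for the new value of the state, or None
    if dsolve raises NotImplementedError or returns several solutions.
    """
    t, dt = sp.symbols("t dt")
    m0 = sp.Symbol(state)   # value of the state variable at time t
    x = sp.Function("x")    # state as a function of elapsed time
    # parse the right-hand side; "^" is converted to "**" by default
    rhs = sp.sympify(rhs_str)
    # substitute x(t) for the state symbol to form dx/dt = f(x, ...)
    ode = sp.Eq(x(t).diff(t), rhs.subs(m0, x(t)))
    try:
        # start the elapsed time at 0 with x(0) = m0; this is valid
        # because f is assumed not to depend explicitly on t
        sol = sp.dsolve(ode, x(t), ics={x(0): m0})
    except NotImplementedError:
        return None  # no solution: leave the equation alone
    if isinstance(sol, list):
        return None  # several solution branches: not handled in this sketch
    new_value = sol.rhs.subs(t, dt)  # value after one step of length dt
    return new_value, sp.ccode(new_value)


result = cnexp_solve("(minf-m)/mtau")
if result is not None:
    expr, c_code = result
    print("m =", c_code)
```

For `"(minf-m)/mtau"` this should give an expression equivalent to minf + (m - minf)*exp(-dt/mtau). The exact arrangement of terms in the C code depends on the SymPy version.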


Pade approximant

There is an option, use_pade_approx, which, if enabled, performs the following extra step:

  • Given the analytic solution f(t):

    • Expand the solution in a Taylor series in dt and extract the coefficients a_i

      • f(t + dt) = f(t) + dt f'(t) + dt^2 f''(t) / 2 + \dots = a_0 + a_1 dt + a_2 dt^2 + \dots

    • Construct the (1,1) Pade approximant to the solution using these Taylor coefficients

      • f_{PADE}(t+dt) = (a_0 a_1 + (a_1^2 - a_0 a_2) dt)/(a_1 - a_2 dt)

    • Return this approximate solution (correct to second order in dt) as C code
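The Pade step can be sketched the same way. The helper `pade_1_1` is hypothetical. It uses `sympy.series` to obtain a_0, a_1 and a_2 and then applies the (1,1) formula above unchanged, so it assumes a_1 is nonzero.

```python
import sympy as sp


def pade_1_1(expr, dt):
    """(1,1) Pade approximant of expr in the variable dt.

    Hypothetical helper illustrating the formula above; assumes the
    Taylor coefficient a_1 is nonzero.
    """
    # Taylor expand to second order in dt and drop the O(dt**3) term
    taylor = sp.expand(sp.series(expr, dt, 0, 3).removeO())
    a0, a1, a2 = (taylor.coeff(dt, k) for k in range(3))
    # f_PADE(t + dt) = (a0*a1 + (a1**2 - a0*a2)*dt) / (a1 - a2*dt)
    return sp.simplify((a0 * a1 + (a1**2 - a0 * a2) * dt) / (a1 - a2 * dt))


m, a, dt = sp.symbols("m a dt")
print(pade_1_1(m * sp.exp(a * dt), dt))
```

Applied to m*exp(a*dt), the result is equivalent to m*(2 + a*dt)/(2 - a*dt), up to how SymPy arranges the terms.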

(Replacing the exponential with a Pade approximant here was suggested in sec. 5.2 of https://www.eccomas2016.org/proceedings/pdf/7366.pdf. Since the overall numerical integration scheme in NEURON is only correct to first or second order in dt, it is valid to expand the analytic solution here to the same order, which avoids evaluating the exponential function.)
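The second-order accuracy can be checked numerically with plain Python. Applying the (1,1) formula to e^x (a_0 = 1, a_1 = 1, a_2 = 1/2) gives (2 + x)/(2 - x). The one-step error against `math.exp` then behaves like x^3/12, so halving x should shrink it by a factor approaching 2^3 = 8.

```python
import math


def pade_exp(x):
    """(1,1) Pade approximant of exp(x): (2 + x) / (2 - x)."""
    return (2.0 + x) / (2.0 - x)


def local_error(x):
    """Absolute error of the Pade approximant against math.exp."""
    return abs(math.exp(x) - pade_exp(x))


# The error ratio for x versus x/2 approaches 2**3 = 8 as x shrinks,
# consistent with an O(dt**3) local error.
for x in (0.1, 0.05, 0.025):
    ratio = local_error(x) / local_error(x / 2)
    print(f"x={x}: error={local_error(x):.2e}, ratio={ratio:.2f}")
```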


Implementation Tests

The unit tests may be helpful for understanding what these functions do: SympySolverVisitor tests are located in test/visitor/sympy_solver.cpp, and tests involving cnexp have the tag “[cnexp]”.

Examples

[1]:
%%capture
! pip install nmodl
[2]:
import nmodl.dsl as nmodl


def run_sympy_solver(mod_string, pade=False):
    # parse NMODL file (supplied as a string) into AST
    driver = nmodl.NmodlDriver()
    AST = driver.parse_string(mod_string)
    # run SymtabVisitor to generate Symbol Table
    nmodl.symtab.SymtabVisitor().visit_program(AST)
    # constant folding, inlining & local variable renaming passes
    nmodl.visitor.ConstantFolderVisitor().visit_program(AST)
    nmodl.visitor.InlineVisitor().visit_program(AST)
    nmodl.visitor.LocalVarRenameVisitor().visit_program(AST)
    # run SympySolver visitor
    nmodl.visitor.SympySolverVisitor(use_pade_approx=pade).visit_program(AST)
    # return contents of new DERIVATIVE block as list of strings
    return nmodl.to_nmodl(
        nmodl.visitor.AstLookupVisitor().lookup(
            AST, nmodl.ast.AstNodeType.DERIVATIVE_BLOCK
        )[0]
    ).splitlines()[1:-1]

Ex. 1

Single constant ODE

[3]:
ex1 = """
BREAKPOINT {
    SOLVE states METHOD cnexp
}
DERIVATIVE states {
    m' = 4
}
"""
print("exact solution:\t", run_sympy_solver(ex1, pade=False)[0])
print("pade approx:\t", run_sympy_solver(ex1, pade=True)[0])
exact solution:      m = 4.0*dt+m
pade approx:         m = 4.0*dt+m
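A quick check with the (1,1) formula: for m' = 4 the Taylor series in dt terminates, so a_2 = 0 and the Pade approximant reproduces the exact solution.

```latex
a_0 = m,\quad a_1 = 4,\quad a_2 = 0
\;\Rightarrow\;
f_{PADE}(t+dt) = \frac{4m + 16\,dt}{4} = m + 4\,dt
```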

Ex. 2

Single linear ODE

[4]:
ex2 = """
BREAKPOINT {
    SOLVE states METHOD cnexp
}
DERIVATIVE states {
    m' = a*m
}
"""
print("exact solution:\t", run_sympy_solver(ex2, pade=False)[0])
print("pade approx:\t", run_sympy_solver(ex2, pade=True)[0])
exact solution:      m = m*exp(a*dt)
pade approx:         m = m*(-a*dt-2.0)/(a*dt-2.0)
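Working through the (1,1) formula by hand reproduces the printed Pade result, assuming a and m are nonzero so that the common factor a m can be cancelled. The printed form is the same fraction with numerator and denominator both multiplied by -1.

```latex
a_0 = m,\quad a_1 = a\,m,\quad a_2 = \tfrac{1}{2}a^2 m
\;\Rightarrow\;
f_{PADE}(t+dt) = \frac{a m^2 + \tfrac{1}{2}a^2 m^2\,dt}{a m - \tfrac{1}{2}a^2 m\,dt}
= m\,\frac{2 + a\,dt}{2 - a\,dt}
```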

Ex. 3

Single linear ODE

[5]:
ex3 = """
BREAKPOINT {
    SOLVE states METHOD cnexp
}
DERIVATIVE states {
    m' = (minf-m)/mtau
}
"""
print("exact solution:\t", run_sympy_solver(ex3, pade=False)[0])
print("pade approx:\t", run_sympy_solver(ex3, pade=True)[0])
exact solution:      m = minf-(-m+minf)*exp(-dt/mtau)
pade approx:         m = (-dt*m+2.0*dt*minf+2.0*m*mtau)/(dt+2.0*mtau)
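Here the Taylor coefficients come from expanding the exact solution minf - (minf - m) e^{-dt/mtau} in dt. After cancelling the common factor (minf - m)/(2 mtau^2), assumed nonzero, the (1,1) formula gives the printed expression.

```latex
a_0 = m,\quad a_1 = \frac{\text{minf} - m}{\text{mtau}},\quad
a_2 = -\frac{\text{minf} - m}{2\,\text{mtau}^2}
\;\Rightarrow\;
f_{PADE}(t+dt) = \frac{2\,m\,\text{mtau} + (2\,\text{minf} - m)\,dt}{2\,\text{mtau} + dt}
```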

Ex. 4

Single linear ODE that can be simplified

[6]:
ex4 = """
BREAKPOINT {
    SOLVE states METHOD cnexp
}
DERIVATIVE states {
    m' = (minf-m)/mtau - m/mtau - 2*minf/mtau + 3*m/mtau
}
"""
print("exact solution:\t", run_sympy_solver(ex4, pade=False)[0])
print("pade approx:\t", run_sympy_solver(ex4, pade=True)[0])
exact solution:      m = minf+(m-minf)*exp(dt/mtau)
pade approx:         m = (-dt*m+2.0*dt*minf-2.0*m*mtau)/(dt-2.0*mtau)
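Collecting terms shows that this ODE is Ex. 3 with mtau replaced by -mtau. Making the same substitution in the Ex. 3 Pade output gives exactly the printed Pade result here.

```latex
\frac{dm}{dt} = \frac{(\text{minf} - m) - m - 2\,\text{minf} + 3m}{\text{mtau}}
= \frac{m - \text{minf}}{\text{mtau}}
\;\Rightarrow\;
m(t+dt) = \text{minf} + (m - \text{minf})\,e^{dt/\text{mtau}}
```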

Ex. 5

Single nonlinear ODE with analytic solution

[7]:
ex5 = """
BREAKPOINT {
    SOLVE states METHOD cnexp
}
DERIVATIVE states {
    m' = m^3
}
"""
print("exact solution:\t", run_sympy_solver(ex5, pade=False)[0])
print("pade approx:\t", run_sympy_solver(ex5, pade=True)[0])
exact solution:      m = sqrt(-pow(m, 2)/(2.0*dt*pow(m, 2)-1.0))
pade approx:         m = m*(dt*pow(m, 2)-2.0)/(3.0*dt*pow(m, 2)-2.0)
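Separating variables recovers the square of the printed exact solution. Its positive square root matches the true solution for m > 0, and only while 2 dt m^2 < 1. For the Pade step, a_2 = m''/2 with m'' = 3 m^2 m' = 3 m^5.

```latex
\int_{m(t)}^{m(t+dt)} \frac{dm}{m^3} = dt
\;\Rightarrow\;
\frac{1}{2m(t)^2} - \frac{1}{2m(t+dt)^2} = dt
\;\Rightarrow\;
m(t+dt)^2 = \frac{m^2}{1 - 2\,dt\,m^2}

a_0 = m,\quad a_1 = m^3,\quad a_2 = \tfrac{3}{2}m^5
\;\Rightarrow\;
f_{PADE}(t+dt) = m\,\frac{2 - dt\,m^2}{2 - 3\,dt\,m^2}
```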

Ex. 6

Single nonlinear ODE (more complicated): if SymPy cannot solve an equation, the visitor leaves the DERIVATIVE block unmodified for a later visitor pass to deal with, but in this case an analytic solution is found, as the output shows

[8]:
ex6 = """
BREAKPOINT {
    SOLVE states METHOD cnexp
}
DERIVATIVE states {
    m' = exp(m)^2
}
"""
print("exact solution:\t", run_sympy_solver(ex6, pade=False)[0])
print("pade approx:\t", run_sympy_solver(ex6, pade=True)[0])
exact solution:      m = 0.5*log(-exp(2*m)/(2*dt*exp(2*m)-1))
pade approx:         m = 0.5*(dt*(log(exp(2*m))-2.0)*exp(2*m)-log(exp(2*m)))/(dt*exp(2*m)-1.0)
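Since exp(m)^2 = e^{2m}, separating variables confirms the printed exact solution. That solution is defined only while 2 dt e^{2m} < 1.

```latex
\int_{m(t)}^{m(t+dt)} e^{-2m}\,dm = dt
\;\Rightarrow\;
\frac{e^{-2m(t)} - e^{-2m(t+dt)}}{2} = dt
\;\Rightarrow\;
m(t+dt) = -\tfrac{1}{2}\log\left(e^{-2m} - 2\,dt\right)
= \tfrac{1}{2}\log\frac{e^{2m}}{1 - 2\,dt\,e^{2m}}
```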