{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### NMODL SympySolver - cnexp\n", "\n", "This notebook describes the implementation of the `cnexp` part of the `SympySolverVisitor`, which solves the systems of ODEs defined in `DERIVATIVE` blocks when these ODEs are *linear* and *independent*.\n", "\n", "For a higher-level overview of the approach to solving ODEs in NMODL, please see the [nmodl-odes-overview](nmodl-odes-overview.ipynb) notebook.\n", "\n", "For a more general tutorial on using the NMODL Python interface, please see the [tutorial notebook](nmodl-python-tutorial.ipynb).\n", "\n", "***\n", "\n", "#### Implementation\n", "The `SympySolverVisitor` for solver method `cnexp` does the following:\n", "\n", "* Get the list of all global scope variables from the Symbol Table, as well as any local variables in the DERIVATIVE block\n", "* For each differential equation in the DERIVATIVE block:\n", "    * Parse the equation into SymPy, giving it the list of variables\n", "        * This gives us a differential equation of the form:\n", "            * $\\frac{dm}{dt} = f(m, \\dots)$\n", "            * where the function $f$ depends on $m$, as well as possibly other variables represented by $\\dots$, which we assume do not depend on $m$ or $t$\n", "    * Solve the equation analytically using [sympy.dsolve](https://docs.sympy.org/latest/modules/solvers/ode.html) to give a solution of the form:\n", "        * $m(t+dt) = g(m(t), dt, \\dots)$\n", "        * where $g$ is some function that depends on the value of $m$ at time $t$, the timestep $dt$, and the other variables ($\\dots$).\n", "    * Replace the ODE with the analytic solution as C code using [sympy.printing.ccode](https://docs.sympy.org/latest/_modules/sympy/printing/ccode.html)\n", "    * If no solution is found, do nothing; NMODL then currently falls back to the legacy CNEXP solver routine (same as mod2c or nocmodl)\n", "\n", "***\n", "\n", "#### Pade approximant\n", "There is an option `use_pade_approx` which, if enabled, performs the following extra step:\n", 
"\n", "* Given the analytic solution $f(t)$:\n", "    * Expand the solution in a Taylor series in `dt`, extract the coefficients $a_i$\n", "        * $f(t + dt) = f(t) + dt f'(t) + dt^2 f''(t) / 2 + \\dots = a_0 + a_1 dt + a_2 dt^2 + \\dots$\n", "    * Construct the (1,1) Pade approximant to the solution using these Taylor coefficients\n", "        * $f_{PADE}(t+dt) = (a_0 a_1 + (a_1^2 - a_0 a_2) dt)/(a_1 - a_2 dt)$\n", "    * Return this approximate solution (correct to second order in $dt$) as C code\n", "\n", "(Replacing the exponential with a Pade approximant here was suggested in sec. 5.2 of [this paper](https://www.eccomas2016.org/proceedings/pdf/7366.pdf). Since the overall numerical integration scheme in NEURON is only correct to first or second order in $dt$, it is valid to expand the analytic solution here to the same order and so avoid evaluating the exponential function.)\n", "\n", "***\n", "\n", "#### Implementation Tests\n", "The unit tests may be helpful to understand what these functions are doing:\n", " - `SympySolverVisitor` tests are located in [test/visitor/sympy_solver.cpp](https://github.com/BlueBrain/nmodl/blob/master/test/visitor/sympy_solver.cpp), and tests involving `cnexp` have the tag \"`[cnexp]`\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Examples" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "! 
pip install nmodl" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import nmodl.dsl as nmodl\n", "\n", "\n", "def run_sympy_solver(mod_string, pade=False):\n", "    # parse NMODL file (supplied as a string) into AST\n", "    driver = nmodl.NmodlDriver()\n", "    AST = driver.parse_string(mod_string)\n", "    # run SymtabVisitor to generate Symbol Table\n", "    nmodl.symtab.SymtabVisitor().visit_program(AST)\n", "    # constant folding, inlining & local variable renaming passes\n", "    nmodl.visitor.ConstantFolderVisitor().visit_program(AST)\n", "    nmodl.visitor.InlineVisitor().visit_program(AST)\n", "    nmodl.visitor.LocalVarRenameVisitor().visit_program(AST)\n", "    # run SympySolver visitor\n", "    nmodl.visitor.SympySolverVisitor(use_pade_approx=pade).visit_program(AST)\n", "    # return contents of new DERIVATIVE block as list of strings\n", "    return nmodl.to_nmodl(\n", "        nmodl.visitor.AstLookupVisitor().lookup(\n", "            AST, nmodl.ast.AstNodeType.DERIVATIVE_BLOCK\n", "        )[0]\n", "    ).splitlines()[1:-1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Ex. 1\n", "Single constant ODE" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "exact solution:\t m = 4*dt+m\n", "pade approx:\t m = 4*dt+m\n" ] } ], "source": [ "ex1 = \"\"\"\n", "BREAKPOINT {\n", " SOLVE states METHOD cnexp\n", "}\n", "DERIVATIVE states {\n", " m' = 4\n", "}\n", "\"\"\"\n", "print(\"exact solution:\\t\", run_sympy_solver(ex1, pade=False)[0])\n", "print(\"pade approx:\\t\", run_sympy_solver(ex1, pade=True)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Ex. 
2\n", "Single linear ODE" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "exact solution:\t m = m*exp(a*dt)\n", "pade approx:\t m = -m*(a*dt+2)/(a*dt-2)\n" ] } ], "source": [ "ex2 = \"\"\"\n", "BREAKPOINT {\n", " SOLVE states METHOD cnexp\n", "}\n", "DERIVATIVE states {\n", " m' = a*m\n", "}\n", "\"\"\"\n", "print(\"exact solution:\\t\", run_sympy_solver(ex2, pade=False)[0])\n", "print(\"pade approx:\\t\", run_sympy_solver(ex2, pade=True)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Ex. 3\n", "Single linear ODE" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "exact solution:\t m = minf-(-m+minf)*exp(-dt/mtau)\n", "pade approx:\t m = (-dt*m+2*dt*minf+2*m*mtau)/(dt+2*mtau)\n" ] } ], "source": [ "ex3 = \"\"\"\n", "BREAKPOINT {\n", " SOLVE states METHOD cnexp\n", "}\n", "DERIVATIVE states {\n", " m' = (minf-m)/mtau\n", "}\n", "\"\"\"\n", "print(\"exact solution:\\t\", run_sympy_solver(ex3, pade=False)[0])\n", "print(\"pade approx:\\t\", run_sympy_solver(ex3, pade=True)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Ex. 4\n", "Single linear ODE that can be simplified" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "exact solution:\t m = minf+(m-minf)*exp(dt/mtau)\n", "pade approx:\t m = (-dt*m+2*dt*minf-2*m*mtau)/(dt-2*mtau)\n" ] } ], "source": [ "ex4 = \"\"\"\n", "BREAKPOINT {\n", " SOLVE states METHOD cnexp\n", "}\n", "DERIVATIVE states {\n", " m' = (minf-m)/mtau - m/mtau - 2*minf/mtau + 3*m/mtau\n", "}\n", "\"\"\"\n", "print(\"exact solution:\\t\", run_sympy_solver(ex4, pade=False)[0])\n", "print(\"pade approx:\\t\", run_sympy_solver(ex4, pade=True)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Ex. 
5\n", "Single nonlinear ODE with analytic solution" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "exact solution:\t m = sqrt(-pow(m, 2)/(2*dt*pow(m, 2)-1))\n", "pade approx:\t m = pow(m, 2)*(dt*pow(m, 2)-2)*pow(pow(m, 2), -0.5)/(3*dt*pow(m, 2)-2)\n" ] } ], "source": [ "ex5 = \"\"\"\n", "BREAKPOINT {\n", " SOLVE states METHOD cnexp\n", "}\n", "DERIVATIVE states {\n", " m' = m^3\n", "}\n", "\"\"\"\n", "print(\"exact solution:\\t\", run_sympy_solver(ex5, pade=False)[0])\n", "print(\"pade approx:\\t\", run_sympy_solver(ex5, pade=True)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Ex. 6\n", "Single nonlinear ODE that is more complicated and not yet supported: the visitor leaves the DERIVATIVE block unmodified, so the ODE is printed back unchanged and a later visitor pass has to deal with it" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "exact solution:\t m' = exp(m)^2\n", "pade approx:\t m' = exp(m)^2\n" ] } ], "source": [ "ex6 = \"\"\"\n", "BREAKPOINT {\n", " SOLVE states METHOD cnexp\n", "}\n", "DERIVATIVE states {\n", " m' = exp(m)^2\n", "}\n", "\"\"\"\n", "print(\"exact solution:\\t\", run_sympy_solver(ex6, pade=False)[0])\n", "print(\"pade approx:\\t\", run_sympy_solver(ex6, pade=True)[0])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 }