User Guide
sympy_solver_visitor.hpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #pragma once
9 
10 /**
11  * \file
12  * \brief \copybrief nmodl::visitor::SympySolverVisitor
13  */
14 
15 #include <pybind11/embed.h>
16 #include <pybind11/stl.h>
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <vector>
20 
21 #include "ast/ast.hpp"
22 #include "symtab/symbol.hpp"
23 #include "visitors/ast_visitor.hpp"
25 
26 namespace nmodl {
27 namespace visitor {
28 
29 /**
30  * @addtogroup visitor_classes
31  * @{
32  */
33 
34 /**
35  * \class SympySolverVisitor
36  * \brief %Visitor for systems of algebraic and differential equations
37  *
38  * For DERIVATIVE block, solver method `cnexp`:
39  * - replace each ODE with its analytic solution
40  * - optionally using the `(1,1)` order Pade approximant in dt
41  *
42  * For `DERIVATIVE` block, solver method `euler`:
43  * - replace each ODE with forwards Euler timestep
44  *
45  * For `DERIVATIVE` block, solver method `sparse` and `derivimplicit`:
46  * - construct backwards Euler timestep non-linear system
47  * - return function F and its Jacobian J to be solved by newton solver
48  *
49  * For `LINEAR` blocks:
50  * - for small systems: solve linear system of algebraic equations by
51  * Gaussian elimination, replace equations with solutions
52  * - for large systems: return matrix and vector of linear system
53  * to be solved by e.g. LU factorization
54  *
55  * For `NON_LINEAR` blocks:
56  * - return function F and its Jacobian J to be solved by newton solver
57  */
59  private:
60  /// clear any data from previous block & get set of block local vars + global vars
61  void init_block_data(ast::Node* node);
62 
63  /// construct vector from set of state vars in correct order
65 
66  /// replace binary expression with new expression provided as string
67  static void replace_diffeq_expression(ast::DiffEqExpression& expr, const std::string& new_expr);
68 
69  /// raise error if kinetic/ode/(non)linear statements are spread over multiple blocks
71 
72  /// return iterator pointing to where solution should be inserted in statement block
73  ast::StatementVector::const_iterator get_solution_location_iterator(
74  const ast::StatementVector& statements);
75 
76  /// construct solver block
77  void construct_eigen_solver_block(const std::vector<std::string>& pre_solve_statements,
78  const std::vector<std::string>& solutions,
79  bool linear);
80 
81  /// solve linear system (for "LINEAR")
82  void solve_linear_system(const std::vector<std::string>& pre_solve_statements = {});
83 
84  /// solve non-linear system (for "derivimplicit", "sparse" and "NONLINEAR")
85  void solve_non_linear_system(const std::vector<std::string>& pre_solve_statements = {});
86 
87  /// return NMODL string version of node, excluding any units
88  static std::string to_nmodl_for_sympy(ast::Ast& node) {
90  }
91 
92  /// Function used by SympySolverVisitor::filter_X to replace the name X in a std::string
93  /// to X_operator
94  static std::string& replaceAll(std::string& context,
95  const std::string& from,
96  const std::string& to);
97 
98  /// Check original_vector for elements that contain a variable named original_string and
99  /// rename it to substitution_string
100  static std::vector<std::string> filter_string_vector(
101  const std::vector<std::string>& original_vector,
102  const std::string& original_string,
103  const std::string& substitution_string);
104 
105  /// global variables
106  std::set<std::string> global_vars;
107 
108  /// local variables in current block + globals
109  std::set<std::string> vars;
110 
111  /// custom function calls used in ODE block
112  std::set<std::string> function_calls;
113 
114  /// map between derivative block names and associated solver method
115  std::unordered_map<std::string, std::string> derivative_block_solve_method{};
116 
117  /// expression statements appearing in the block
118  /// (these can be of type DiffEqExpression, LinEquation or NonLinEquation)
119  std::unordered_set<ast::Statement*> expression_statements;
120 
121  /// current expression statement being visited (to track ODEs / (non)lineqs)
123 
124  /// last expression statement visited (to know where to insert solutions in statement block)
126 
127  /// current statement block being visited
129 
130  /// block where expression statements appear (to check there is only one)
132 
133  /// method specified in solve block
134  std::string solve_method;
135 
136  /// vector of {ODE, linear eq, non-linear eq} system to solve
137  std::vector<std::string> eq_system;
138 
139  /// only solve eq_system system of equations if this is true:
140  bool eq_system_is_valid = true;
141 
142  /// true for (non)linear eqs, to identify all state vars used in equations
143  bool collect_state_vars = false;
144 
145  /// vector of all state variables (in order specified in STATE block in mod file)
146  std::vector<std::string> all_state_vars;
147 
148  /// set of state variables used in block
149  std::set<std::string> state_vars_in_block;
150 
151  /// vector of state vars used *in block* (in same order as all_state_vars)
152  std::vector<std::string> state_vars;
153 
154  /// map from state vars to the algebraic equation from CONSERVE statement that should replace
155  /// their ODE, if any
156  std::unordered_map<std::string, std::string> conserve_equation;
157 
158  /// optionally replace cnexp solution with (1,1) pade approx
160 
161  /// optionally do CSE (common subexpression elimination) for sparse solver
163 
164  /// max number of state vars allowed for small system linear solver
166 
167  public:
168  explicit SympySolverVisitor(bool use_pade_approx = false,
169  bool elimination = true,
174 
175  void visit_var_name(ast::VarName& node) override;
176  void visit_diff_eq_expression(ast::DiffEqExpression& node) override;
177  void visit_conserve(ast::Conserve& node) override;
178  void visit_derivative_block(ast::DerivativeBlock& node) override;
179  void visit_lin_equation(ast::LinEquation& node) override;
180  void visit_linear_block(ast::LinearBlock& node) override;
181  void visit_non_lin_equation(ast::NonLinEquation& node) override;
182  void visit_non_linear_block(ast::NonLinearBlock& node) override;
184  void visit_statement_block(ast::StatementBlock& node) override;
185  void visit_program(ast::Program& node) override;
186 };
187 
188 /** @} */ // end of visitor_classes
189 
190 } // namespace visitor
191 } // namespace nmodl
nmodl::visitor::SympySolverVisitor::to_nmodl_for_sympy
static std::string to_nmodl_for_sympy(ast::Ast &node)
return NMODL string version of node, excluding any units
Definition: sympy_solver_visitor.hpp:88
nmodl::ast::Node
Base class for all AST nodes.
Definition: node.hpp:40
nmodl::visitor::SympySolverVisitor::visit_lin_equation
void visit_lin_equation(ast::LinEquation &node) override
visit node of type ast::LinEquation
Definition: sympy_solver_visitor.cpp:573
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:227
nmodl::visitor::SympySolverVisitor::visit_linear_block
void visit_linear_block(ast::LinearBlock &node) override
visit node of type ast::LinearBlock
Definition: sympy_solver_visitor.cpp:587
symbol.hpp
Implement class to represent a symbol in Symbol Table.
nmodl::ast::DiffEqExpression
Represents differential equation in DERIVATIVE block.
Definition: diff_eq_expression.hpp:38
nmodl::ast::Ast
Base class for all Abstract Syntax Tree node types.
Definition: ast.hpp:69
nmodl::visitor::SympySolverVisitor
Visitor for systems of algebraic and differential equations
Definition: sympy_solver_visitor.hpp:58
nmodl::visitor::SympySolverVisitor::replaceAll
static std::string & replaceAll(std::string &context, const std::string &from, const std::string &to)
Function used by SympySolverVisitor::filter_X to replace the name X in a std::string with X_operator.
Definition: sympy_solver_visitor.cpp:134
nmodl::visitor::SympySolverVisitor::visit_conserve
void visit_conserve(ast::Conserve &node) override
visit node of type ast::Conserve
Definition: sympy_solver_visitor.cpp:468
nmodl::ast::NonLinEquation
TODO.
Definition: non_lin_equation.hpp:38
nmodl::visitor::SympySolverVisitor::visit_non_linear_block
void visit_non_linear_block(ast::NonLinearBlock &node) override
visit node of type ast::NonLinearBlock
Definition: sympy_solver_visitor.cpp:615
nmodl::ast::Conserve
Represent CONSERVE statement in NMODL.
Definition: conserve.hpp:38
nmodl::visitor::SympySolverVisitor::solve_method
std::string solve_method
method specified in solve block
Definition: sympy_solver_visitor.hpp:134
nmodl::ast::StatementVector
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:298
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::visitor::SympySolverVisitor::filter_string_vector
static std::vector< std::string > filter_string_vector(const std::vector< std::string > &original_vector, const std::string &original_string, const std::string &substitution_string)
Check original_vector for elements that contain a variable named original_string and rename it to substitution_string.
Definition: sympy_solver_visitor.cpp:146
nmodl::visitor::SympySolverVisitor::all_state_vars
std::vector< std::string > all_state_vars
vector of all state variables (in order specified in STATE block in mod file)
Definition: sympy_solver_visitor.hpp:146
nmodl::visitor::SympySolverVisitor::init_block_data
void init_block_data(ast::Node *node)
clear any data from previous block & get set of block local vars + global vars
Definition: sympy_solver_visitor.cpp:27
nmodl::ast::VarName
Represents a variable.
Definition: var_name.hpp:43
nmodl::visitor::SympySolverVisitor::derivative_block_solve_method
std::unordered_map< std::string, std::string > derivative_block_solve_method
map between derivative block names and associated solver method
Definition: sympy_solver_visitor.hpp:115
nmodl::visitor::SympySolverVisitor::solve_non_linear_system
void solve_non_linear_system(const std::vector< std::string > &pre_solve_statements={})
solve non-linear system (for "derivimplicit", "sparse" and "NONLINEAR")
Definition: sympy_solver_visitor.cpp:343
nmodl::visitor::SympySolverVisitor::vars
std::set< std::string > vars
local variables in current block + globals
Definition: sympy_solver_visitor.hpp:109
nmodl::ast::ExpressionStatement
TODO.
Definition: expression_statement.hpp:38
nmodl::visitor::SympySolverVisitor::global_vars
std::set< std::string > global_vars
global variables
Definition: sympy_solver_visitor.hpp:106
nmodl::visitor::SympySolverVisitor::SMALL_LINEAR_SYSTEM_MAX_STATES
int SMALL_LINEAR_SYSTEM_MAX_STATES
max number of state vars allowed for small system linear solver
Definition: sympy_solver_visitor.hpp:165
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::visitor::SympySolverVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Definition: sympy_solver_visitor.cpp:636
nmodl::visitor::SympySolverVisitor::function_calls
std::set< std::string > function_calls
custom function calls used in ODE block
Definition: sympy_solver_visitor.hpp:112
nmodl::visitor::SympySolverVisitor::block_with_expression_statements
ast::StatementBlock * block_with_expression_statements
block where expression statements appear (to check there is only one)
Definition: sympy_solver_visitor.hpp:131
nmodl::visitor::SympySolverVisitor::replace_diffeq_expression
static void replace_diffeq_expression(ast::DiffEqExpression &expr, const std::string &new_expr)
replace binary expression with new expression provided as string
Definition: sympy_solver_visitor.cpp:64
nmodl::visitor::SympySolverVisitor::visit_var_name
void visit_var_name(ast::VarName &node) override
visit node of type ast::VarName
Definition: sympy_solver_visitor.cpp:369
nmodl::visitor::SympySolverVisitor::visit_expression_statement
void visit_expression_statement(ast::ExpressionStatement &node) override
visit node of type ast::ExpressionStatement
Definition: sympy_solver_visitor.cpp:629
nmodl::visitor::AstVisitor
Concrete visitor for all AST classes.
Definition: ast_visitor.hpp:37
nmodl::visitor::SympySolverVisitor::state_vars_in_block
std::set< std::string > state_vars_in_block
set of state variables used in block
Definition: sympy_solver_visitor.hpp:149
nmodl::ast::AstNodeType::UNIT_DEF
@ UNIT_DEF
type of ast::UnitDef
nmodl::visitor::SympySolverVisitor::get_solution_location_iterator
ast::StatementVector::const_iterator get_solution_location_iterator(const ast::StatementVector &statements)
return iterator pointing to where solution should be inserted in statement block
Definition: sympy_solver_visitor.cpp:86
nmodl::ast::DerivativeBlock
Represents DERIVATIVE block in the NMODL.
Definition: derivative_block.hpp:49
nmodl::visitor::SympySolverVisitor::use_pade_approx
bool use_pade_approx
optionally replace cnexp solution with (1,1) pade approx
Definition: sympy_solver_visitor.hpp:159
nmodl::visitor::SympySolverVisitor::init_state_vars_vector
void init_state_vars_vector()
construct vector from set of state vars in correct order
Definition: sympy_solver_visitor.cpp:55
nmodl::visitor::SympySolverVisitor::collect_state_vars
bool collect_state_vars
true for (non)linear eqs, to identify all state vars used in equations
Definition: sympy_solver_visitor.hpp:143
ast.hpp
Auto generated AST classes declaration.
nmodl::visitor::SympySolverVisitor::conserve_equation
std::unordered_map< std::string, std::string > conserve_equation
map from state vars to the algebraic equation from CONSERVE statement that should replace their ODE, if any
Definition: sympy_solver_visitor.hpp:156
nmodl::visitor::SympySolverVisitor::state_vars
std::vector< std::string > state_vars
vector of state vars used in block (in same order as all_state_vars)
Definition: sympy_solver_visitor.hpp:152
nmodl::visitor::SympySolverVisitor::construct_eigen_solver_block
void construct_eigen_solver_block(const std::vector< std::string > &pre_solve_statements, const std::vector< std::string > &solutions, bool linear)
construct solver block
Definition: sympy_solver_visitor.cpp:158
nmodl::ast::AstNodeType::UNIT
@ UNIT
type of ast::Unit
nmodl::ast::StatementBlock
Represents block encapsulating list of statements.
Definition: statement_block.hpp:53
nmodl::visitor::SympySolverVisitor::visit_non_lin_equation
void visit_non_lin_equation(ast::NonLinEquation &node) override
visit node of type ast::NonLinEquation
Definition: sympy_solver_visitor.cpp:601
nmodl::visitor::SympySolverVisitor::elimination
bool elimination
optionally do CSE (common subexpression elimination) for sparse solver
Definition: sympy_solver_visitor.hpp:162
nmodl::ast::LinearBlock
Represents LINEAR block in the NMODL.
Definition: linear_block.hpp:53
nmodl::visitor::SympySolverVisitor::last_expression_statement
ast::ExpressionStatement * last_expression_statement
last expression statement visited (to know where to insert solutions in statement block)
Definition: sympy_solver_visitor.hpp:125
nmodl::ast::NonLinearBlock
Represents NONLINEAR block in the NMODL.
Definition: non_linear_block.hpp:50
nmodl::visitor::SympySolverVisitor::current_statement_block
ast::StatementBlock * current_statement_block
current statement block being visited
Definition: sympy_solver_visitor.hpp:128
nmodl::visitor::SympySolverVisitor::SympySolverVisitor
SympySolverVisitor(bool use_pade_approx=false, bool elimination=true, int SMALL_LINEAR_SYSTEM_MAX_STATES=3)
Definition: sympy_solver_visitor.hpp:168
nmodl::visitor::SympySolverVisitor::visit_derivative_block
void visit_derivative_block(ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
Definition: sympy_solver_visitor.cpp:492
nmodl::ast::LinEquation
TODO.
Definition: lin_equation.hpp:38
nmodl::visitor::SympySolverVisitor::expression_statements
std::unordered_set< ast::Statement * > expression_statements
expression statements appearing in the block (these can be of type DiffEqExpression, LinEquation or NonLinEquation)
Definition: sympy_solver_visitor.hpp:119
nmodl::visitor::SympySolverVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: sympy_solver_visitor.cpp:643
nmodl::visitor::SympySolverVisitor::eq_system
std::vector< std::string > eq_system
vector of {ODE, linear eq, non-linear eq} system to solve
Definition: sympy_solver_visitor.hpp:137
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::visitor::SympySolverVisitor::eq_system_is_valid
bool eq_system_is_valid
only solve the equations in eq_system if this is true
Definition: sympy_solver_visitor.hpp:140
nmodl::visitor::SympySolverVisitor::solve_linear_system
void solve_linear_system(const std::vector< std::string > &pre_solve_statements={})
solve linear system (for "LINEAR")
Definition: sympy_solver_visitor.cpp:286
nmodl::visitor::SympySolverVisitor::current_expression_statement
ast::ExpressionStatement * current_expression_statement
current expression statement being visited (to track ODEs / (non)lineqs)
Definition: sympy_solver_visitor.hpp:122
nmodl::visitor::SympySolverVisitor::check_expr_statements_in_same_block
void check_expr_statements_in_same_block()
raise error if kinetic/ode/(non)linear statements are spread over multiple blocks
Definition: sympy_solver_visitor.cpp:73
nmodl::visitor::SympySolverVisitor::visit_diff_eq_expression
void visit_diff_eq_expression(ast::DiffEqExpression &node) override
visit node of type ast::DiffEqExpression
Definition: sympy_solver_visitor.cpp:389
ast_visitor.hpp
Concrete visitor for all AST classes.