![Logo](logo.png) |
User Guide
|
Go to the documentation of this file.
39 auto localvars = symtab->get_variables_with_properties(NmodlType::local_var);
40 for (
const auto& localvar: localvars) {
41 std::string var_name = localvar->get_name();
42 if (localvar->is_array()) {
45 vars.insert(var_name);
49 {ast::AstNodeType::FUNCTION_CALL});
50 for (
const auto& call: fcall_nodes) {
65 const auto& solvefor_vars =
dynamic_cast<const ast::LinearBlock*
>(node)->get_solvefor();
66 if (!solvefor_vars.empty()) {
68 for (
const auto& solvefor_var: solvefor_vars) {
69 state_vars.push_back(solvefor_var->get_node_name());
74 if (!solvefor_vars.empty()) {
76 for (
const auto& solvefor_var: solvefor_vars) {
77 state_vars.push_back(solvefor_var->get_node_name());
84 const std::string& new_expr) {
86 auto new_expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(new_statement);
87 auto new_bin_expr = std::dynamic_pointer_cast<ast::BinaryExpression>(
88 new_expr_statement->get_expression());
98 "SympySolverVisitor :: Coupled equations are appearing in different blocks - not "
110 auto it = statements.begin();
112 while ((it != statements.end()) &&
113 (std::dynamic_pointer_cast<ast::ExpressionStatement>(*it).get() !=
115 logger->debug(
"SympySolverVisitor :: {} != {}",
120 if (it != statements.end()) {
121 logger->debug(
"SympySolverVisitor :: {} == {}",
122 to_nmodl(std::dynamic_pointer_cast<ast::ExpressionStatement>(*it)),
140 if (statement->is_local_list_statement()) {
143 if (statement->is_expression_statement()) {
144 auto e_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(statement);
145 auto expression = e_statement->get_expression();
146 if (expression->is_local_list_statement()) {
154 const std::string& from,
155 const std::string& to) {
156 std::size_t lookHere = 0;
157 std::size_t foundHere{};
158 while ((foundHere = context.find(from, lookHere)) != std::string::npos) {
159 context.replace(foundHere, from.size(), to);
160 lookHere = foundHere + to.size();
166 const std::vector<std::string>& original_vector,
167 const std::string& original_string,
168 const std::string& substitution_string) {
169 std::vector<std::string> filtered_vector;
170 for (
auto element: original_vector) {
171 std::string filtered_element =
replaceAll(element, original_string, substitution_string);
172 filtered_vector.push_back(filtered_element);
174 return filtered_vector;
178 const std::vector<std::string>& pre_solve_statements,
179 const std::vector<std::string>& solutions,
187 for (
const auto& sol: solutions_filtered) {
188 logger->debug(
"SympySolverVisitor :: -> adding statement: {}", sol);
191 std::vector<std::string> pre_solve_statements_and_setup_x_eqs = pre_solve_statements;
192 std::vector<std::string> update_statements;
195 auto eigen_name = fmt::format(
"nmodl_eigen_x[{}]", i);
197 auto update_state = fmt::format(
"{} = {}",
state_vars[i], eigen_name);
198 update_statements.push_back(update_state);
199 logger->debug(
"SympySolverVisitor :: update_state: {}", update_state);
201 auto setup_x = fmt::format(
"{} = {}", eigen_name,
state_vars[i]);
202 pre_solve_statements_and_setup_x_eqs.push_back(setup_x);
203 logger->debug(
"SympySolverVisitor :: setup_x_eigen: {}", setup_x);
207 pre_solve_statements_and_setup_x_eqs,
216 auto n_state_vars = std::make_shared<ast::Integer>(
state_vars.size(),
nullptr);
230 for (
size_t idx = 0; idx < statements.size(); ++idx) {
231 auto& s = statements[idx];
233 variable_statements.push_back(s);
234 }
else if (sr_begin == statements.size() || idx < sr_begin) {
235 initialize_statements.push_back(s);
239 if (sr_begin != statements.size()) {
240 auto init_begin = statements.begin() + sr_begin;
241 auto init_end = init_begin +
static_cast<std::ptrdiff_t
>(pre_solve_statements.size());
242 initialize_statements.insert(initialize_statements.end(), init_begin, init_end);
244 auto setup_x_begin = init_end;
245 auto setup_x_end = setup_x_begin +
static_cast<std::ptrdiff_t
>(
state_vars.size());
248 auto functor_begin = setup_x_end;
249 auto functor_end = statements.begin() + sr_end;
252 auto finalize_begin = functor_end;
253 auto finalize_end = statements.end();
257 const size_t total_statements_size = variable_statements.size() + initialize_statements.size() +
258 setup_x_statements.size() + functor_statements.size() +
259 finalize_statements.size();
260 if (statements.size() != total_statements_size) {
262 "SympySolverVisitor :: statement number missmatch ({} =/= {}) during splitting before "
267 total_statements_size);
271 auto variable_block = std::make_shared<ast::StatementBlock>(std::move(variable_statements));
272 auto initialize_block = std::make_shared<ast::StatementBlock>(std::move(initialize_statements));
274 auto finalize_block = std::make_shared<ast::StatementBlock>(std::move(finalize_statements));
277 setup_x_statements.insert(setup_x_statements.end(),
278 functor_statements.begin(),
279 functor_statements.end());
280 auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
281 auto solver_block = std::make_shared<ast::EigenLinearSolverBlock>(n_state_vars,
289 std::make_shared<ast::ExpressionStatement>(solver_block)};
293 auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
294 auto functor_block = std::make_shared<ast::StatementBlock>(std::move(functor_statements));
295 auto solver_block = std::make_shared<ast::EigenNewtonSolverBlock>(n_state_vars,
304 std::make_shared<ast::ExpressionStatement>(solver_block)};
311 const std::vector<std::string>& pre_solve_statements) {
323 auto [solutions, new_local_vars, exception_message] = solver(
326 if (!exception_message.empty()) {
328 "SympySolverVisitor :: solve_lin_system python exception occured. (--verbose=info)");
329 logger->info(exception_message +
330 "\n (Note: line numbers are of by a few compared to `ode.py`.)");
337 logger->debug(
"SympySolverVisitor :: Solving *small* linear system of eqs");
339 if (!new_local_vars.empty()) {
340 for (
const auto& new_local_var: new_local_vars) {
341 logger->debug(
"SympySolverVisitor :: -> declaring new local variable: {}",
347 pre_solve_statements,
356 logger->debug(
"SympySolverVisitor :: Constructing linear newton solve block");
363 const std::vector<std::string>& pre_solve_statements) {
370 if (!exception_message.empty()) {
372 "SympySolverVisitor :: solve_non_lin_system python exception. (--verbose=info)");
373 logger->info(exception_message +
374 "\n (Note: line numbers are of by a few compared to `ode.py`.)");
377 logger->debug(
"SympySolverVisitor :: Constructing eigen newton solve block");
385 if (node.
get_name()->is_indexed_name()) {
386 auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(node.
get_name());
390 std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
396 logger->debug(
"SympySolverVisitor :: adding state var: {}", var_name);
408 if (!lhs->is_var_name()) {
409 logger->warn(
"SympySolverVisitor :: LHS of differential equation is not a VariableName");
412 auto lhs_name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
413 if ((lhs_name->is_indexed_name() &&
414 !std::dynamic_pointer_cast<ast::IndexedName>(lhs_name)->get_name()->is_prime_name()) ||
415 (!lhs_name->is_indexed_name() && !lhs_name->is_prime_name())) {
416 logger->warn(
"SympySolverVisitor :: LHS of differential equation is not a PrimeName");
426 auto [solution, exception_message] = (*diffeq_solver)(
432 logger->debug(
"SympySolverVisitor :: EULER - solving: {}", node_as_nmodl);
437 logger->debug(
"SympySolverVisitor :: CNEXP - solving: {}", node_as_nmodl);
441 std::string var_name = lhs_name->get_node_name();
442 if (lhs_name->is_indexed_name()) {
443 auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(lhs_name);
447 std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
450 logger->debug(
"SympySolverVisitor :: adding ODE system: {}", eq_str);
452 logger->debug(
"SympySolverVisitor :: adding state var: {}", var_name);
460 logger->debug(
"SympySolverVisitor :: -> solution: {}", solution);
462 if (!exception_message.empty()) {
463 logger->warn(
"SympySolverVisitor :: python exception. (--verbose=info)");
464 logger->info(exception_message +
465 "\n (Note: line numbers are of by a few compared to `ode.py`.)");
469 if (!solution.empty()) {
472 logger->warn(
"SympySolverVisitor :: solution to differential equation not possible");
479 logger->debug(
"SympySolverVisitor :: CONSERVE statement: {}",
to_nmodl(node));
481 std::string conserve_equation_statevar;
482 if (node.
get_react()->is_react_var_name()) {
483 conserve_equation_statevar = node.
get_react()->get_node_name();
488 "SympySolverVisitor :: Invalid CONSERVE statement for DERIVATIVE block, LHS should be "
489 "a state variable, instead found: {}. Ignoring CONSERVE statement",
494 logger->debug(
"SympySolverVisitor :: --> replace ODE for state var {} with equation {}",
495 conserve_equation_statevar,
496 conserve_equation_str);
517 std::vector<std::string> pre_solve_statements;
522 std::string x_array_index;
523 std::string x_array_index_i;
524 if (x_prime_split.size() > 1 &&
stringutils::trim(x_prime_split[1]).size() > 2) {
526 x_array_index_i =
"_" + x_array_index.substr(1, x_array_index.size() - 2);
528 std::string state_var_name = x + x_array_index;
532 eq = state_var_name +
" = " + var_eq_pair->second;
534 "SympySolverVisitor :: -> instead of Euler eq using CONSERVE equation: {} = {}",
536 var_eq_pair->second);
541 auto const old_x = [&]() {
542 std::string old_x_name{
"old_"};
543 old_x_name.append(x);
544 old_x_name.append(x_array_index_i);
548 logger->debug(
"SympySolverVisitor :: -> declaring new local variable: {}", old_x);
552 std::string expression{old_x};
553 expression.append(
" = ");
554 expression.append(x);
555 expression.append(x_array_index);
556 pre_solve_statements.push_back(std::move(expression));
561 eq.append(x_array_index);
568 logger->debug(
"SympySolverVisitor :: -> constructed Euler eq: {}", eq);
589 logger->debug(
"SympySolverVisitor :: adding linear eq: {}", lin_eq);
617 logger->debug(
"SympySolverVisitor :: adding non-linear eq: {}", non_lin_eq);
658 for (
const auto& block: solve_block_nodes) {
659 if (
auto block_ptr = std::dynamic_pointer_cast<const ast::SolveBlock>(block)) {
660 const auto& block_name = block_ptr->get_block_name()->get_value()->eval();
661 if (block_ptr->get_method()) {
664 const auto&
solve_method = block_ptr->get_method()->get_value()->eval();
665 logger->debug(
"SympySolverVisitor :: Found SOLVE statement: using {} for {}",
676 auto statevars = symtab->get_variables_with_properties(NmodlType::state_var);
677 for (
const auto& v: statevars) {
678 std::string var_name = v->get_name();
680 for (
int i = 0; i < v->get_length(); ++i) {
681 std::string var_name_i = var_name +
"[" +
std::to_string(i) +
"]";
decltype(&call_solve_linear_system) solve_linear_system
void visit_cvode_block(ast::CvodeBlock &node) override
visit node of type ast::CvodeBlock
static std::string to_nmodl_for_sympy(ast::Ast &node)
return NMODL string version of node, excluding any units
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
Base class for all AST node.
void visit_lin_equation(ast::LinEquation &node) override
visit node of type ast::LinEquation
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
void visit_children(visitor::Visitor &v) override
visit children i.e.
void visit_linear_block(ast::LinearBlock &node) override
visit node of type ast::LinearBlock
std::shared_ptr< Expression > get_react() const noexcept
Getter for member variable Conserve::react.
Implement class to represent a symbol in Symbol Table.
int replaced_statements_end() const
idx (in the new statementVector) of the last statement that was added.
Represents differential equation in DERIVATIVE block.
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
virtual bool is_non_linear_block() const noexcept
Check if the ast node is an instance of ast::NonLinearBlock.
static std::string & replaceAll(std::string &context, const std::string &from, const std::string &to)
Function used by SympySolverVisitor::filter_X to replace every occurrence of the name X in a std::string with X_operator.
void visit_conserve(ast::Conserve &node) override
visit node of type ast::Conserve
static constexpr char CNEXP_METHOD[]
cnexp method in nmodl
One equation in a system of equations that collectively make a NONLINEAR block.
void visit_non_linear_block(ast::NonLinearBlock &node) override
visit node of type ast::NonLinearBlock
void visit_children(visitor::Visitor &v) override
visit children i.e.
Represent CONSERVE statement in NMODL.
int replaced_statements_begin() const
idx (in the new statementVector) of the first statement that was added.
std::shared_ptr< Expression > get_expr() const noexcept
Getter for member variable Conserve::expr.
std::string solve_method
method specified in solve block
std::vector< std::shared_ptr< Statement > > StatementVector
std::string get_node_name() const override
Return name of the node.
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
encapsulates code generation backend implementations
void set_statements(StatementVector &&statements)
Setter for member variable StatementBlock::statements (rvalue reference)
static std::vector< std::string > filter_string_vector(const std::vector< std::string > &original_vector, const std::string &original_string, const std::string &substitution_string)
Check original_vector for elements that contain a variable named original_string and rename it to sub...
static bool is_local_statement(const std::shared_ptr< ast::Statement > &statement)
Check if provided statement is local variable declaration statement.
Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks.
std::vector< std::string > all_state_vars
vector of all state variables (in order specified in STATE block in mod file)
void init_block_data(ast::Node *node)
clear any data from previous block & get set of block local vars + global vars
static constexpr char NTHREAD_DT_VARIABLE[]
dt variable in neuron thread structure
Implement string manipulation functions.
std::unordered_map< std::string, std::string > derivative_block_solve_method
map between derivative block names and associated solver method
static constexpr char SPARSE_METHOD[]
sparse method in nmodl
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
void visit_children(visitor::Visitor &v) override
visit children i.e.
std::set< std::string > vars
local variables in current block + globals
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
std::set< std::string > global_vars
global variables
void init_state_vars_vector(const ast::Node *node)
construct vector from set of state vars in correct order
std::string get_node_name() const override
Return name of the node.
int SMALL_LINEAR_SYSTEM_MAX_STATES
max number of state vars allowed for small system linear solver
static std::string trim(std::string text)
Utility functions for visitors implementation.
void visit_children(visitor::Visitor &v) override
visit children i.e.
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
std::set< std::string > function_calls
custom function calls used in ODE block
decltype(&call_diffeq_solver) diffeq_solver
ast::StatementBlock * block_with_expression_statements
block where expression statements appear (to check there is only one)
static void replace_diffeq_expression(ast::DiffEqExpression &expr, const std::string &new_expr)
replace binary expression with new expression provided as string
decltype(&call_solve_nonlinear_system) solve_nonlinear_system
void solve_non_linear_system(const ast::Node &node, const std::vector< std::string > &pre_solve_statements={})
solve non-linear system (for "derivimplicit", "sparse" and "NONLINEAR")
std::string get_node_name() const override
Return name of the node.
const pybind_wrap_api & api()
Get a pointer to the pybind_wrap_api struct.
std::string to_string(const T &obj)
void visit_children(visitor::Visitor &v) override
visit children i.e.
void visit_var_name(ast::VarName &node) override
visit node of type ast::VarName
void visit_expression_statement(ast::ExpressionStatement &node) override
visit node of type ast::ExpressionStatement
void solve_linear_system(const ast::Node &node, const std::vector< std::string > &pre_solve_statements={})
solve linear system (for "LINEAR")
@ SOLVE_BLOCK
type of ast::SolveBlock
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable LinEquation::lhs.
std::shared_ptr< BinaryExpression > get_expression() const noexcept
Getter for member variable DiffEqExpression::expression.
std::set< std::string > state_vars_in_block
set of state variables used in block
void visit_children(visitor::Visitor &v) override
visit children i.e.
@ GREEDY
Replace statements greedily.
ast::StatementVector::const_iterator get_solution_location_iterator(const ast::StatementVector &statements)
return iterator pointing to where solution should be inserted in statement block
Represents DERIVATIVE block in the NMODL.
bool use_pade_approx
optionally replace cnexp solution with (1,1) pade approx
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
bool collect_state_vars
true for (non)linear eqs, to identify all state vars used in equations
std::unordered_map< std::string, std::string > conserve_equation
map from state vars to the algebraic equation from CONSERVE statement that should replace their ODE,...
void set_expression(std::shared_ptr< BinaryExpression > &&expression)
Setter for member variable DiffEqExpression::expression (rvalue reference)
std::vector< std::string > state_vars
vector of state vars used in block (in same order as all_state_vars)
void visit_children(visitor::Visitor &v) override
visit children i.e.
void construct_eigen_solver_block(const std::vector< std::string > &pre_solve_statements, const std::vector< std::string > &solutions, bool linear)
construct solver block
static std::vector< std::string > split_string(const std::string &text, char delimiter)
Split a text in a list of words, using a given delimiter character.
void visit_children(visitor::Visitor &v) override
visit children i.e.
std::shared_ptr< Expression > get_rhs() const noexcept
Getter for member variable NonLinEquation::rhs.
Represents block encapsulating list of statements.
void visit_non_lin_equation(ast::NonLinEquation &node) override
visit node of type ast::NonLinEquation
NmodlType
NMODL variable properties.
std::shared_ptr< StatementBlock > create_statement_block(const std::vector< std::string > &code_statements)
Convert given code statement (in string format) to corresponding ast node.
bool elimination
optionally do CSE (common subexpression elimination) for sparse solver
Represents LINEAR block in the NMODL.
ast::ExpressionStatement * last_expression_statement
last expression statement visited (to know where to insert solutions in statement block)
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
Represents NONLINEAR block in the NMODL.
Implement logger based on spdlog library.
ast::StatementBlock * current_statement_block
current statement block being visited
void visit_derivative_block(ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
virtual bool is_linear_block() const noexcept
Check if the ast node is an instance of ast::LinearBlock.
static constexpr char EULER_METHOD[]
euler method in nmodl
One equation in a system of equations that collectively form a LINEAR block.
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
std::shared_ptr< Identifier > get_name() const noexcept
Getter for member variable VarName::name.
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
std::unordered_set< ast::Statement * > expression_statements
expression statements appearing in the block (these can be of type DiffEqExpression,...
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable NonLinEquation::lhs.
void visit_program(ast::Program &node) override
visit node of type ast::Program
std::vector< std::string > eq_system
vector of {ODE, linear eq, non-linear eq} system to solve
virtual std::shared_ptr< StatementBlock > get_statement_block() const
Return associated statement block for the AST node.
Represents top level AST node for whole NMODL input.
bool eq_system_is_valid
only solve eq_system system of equations if this is true:
std::string suffix_random_string(const std::set< std::string > &vars, const std::string &original_string, const UseNumbersInString use_num)
Return the "original_string" with a random suffix if "original_string" exists in "vars".
std::shared_ptr< Expression > get_rhs() const noexcept
Getter for member variable LinEquation::rhs.
ast::ExpressionStatement * current_expression_statement
current expression statement being visited (to track ODEs / (non)lineqs)
void check_expr_statements_in_same_block()
raise error if kinetic/ode/(non)linear statements are spread over multiple blocks
@ VALUE
Replace statements matching by lhs varName.
Visitor for systems of algebraic and differential equations
std::set< std::string > get_global_vars(const Program &node)
Return set of strings with the names of all global variables.
Auto generated AST classes declaration.
std::string get_node_name() const override
Return name of the node.
void visit_diff_eq_expression(ast::DiffEqExpression &node) override
visit node of type ast::DiffEqExpression