|
User Guide
|
Go to the documentation of this file.
39 auto localvars = symtab->get_variables_with_properties(NmodlType::local_var);
40 for (
const auto& localvar: localvars) {
41 std::string var_name = localvar->get_name();
42 if (localvar->is_array()) {
45 vars.insert(var_name);
49 {ast::AstNodeType::FUNCTION_CALL});
50 for (
const auto& call: fcall_nodes) {
65 const std::string& new_expr) {
67 auto new_expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(new_statement);
68 auto new_bin_expr = std::dynamic_pointer_cast<ast::BinaryExpression>(
69 new_expr_statement->get_expression());
79 "SympySolverVisitor :: Coupled equations are appearing in different blocks - not "
91 auto it = statements.begin();
93 while ((it != statements.end()) &&
94 (std::dynamic_pointer_cast<ast::ExpressionStatement>(*it).get() !=
96 logger->debug(
"SympySolverVisitor :: {} != {}",
101 if (it != statements.end()) {
102 logger->debug(
"SympySolverVisitor :: {} == {}",
103 to_nmodl(std::dynamic_pointer_cast<ast::ExpressionStatement>(*it)),
121 if (statement->is_local_list_statement()) {
124 if (statement->is_expression_statement()) {
125 auto e_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(statement);
126 auto expression = e_statement->get_expression();
127 if (expression->is_local_list_statement()) {
135 const std::string& from,
136 const std::string& to) {
137 std::size_t lookHere = 0;
138 std::size_t foundHere{};
139 while ((foundHere = context.find(from, lookHere)) != std::string::npos) {
140 context.replace(foundHere, from.size(), to);
141 lookHere = foundHere + to.size();
147 const std::vector<std::string>& original_vector,
148 const std::string& original_string,
149 const std::string& substitution_string) {
150 std::vector<std::string> filtered_vector;
151 for (
auto element: original_vector) {
152 std::string filtered_element =
replaceAll(element, original_string, substitution_string);
153 filtered_vector.push_back(filtered_element);
155 return filtered_vector;
159 const std::vector<std::string>& pre_solve_statements,
160 const std::vector<std::string>& solutions,
167 for (
const auto& sol: solutions_filtered) {
168 logger->debug(
"SympySolverVisitor :: -> adding statement: {}", sol);
171 std::vector<std::string> pre_solve_statements_and_setup_x_eqs(pre_solve_statements);
172 std::vector<std::string> update_statements;
177 pre_solve_statements_and_setup_x_eqs.push_back(setup_x);
178 update_statements.push_back(update_state);
179 logger->debug(
"SympySolverVisitor :: setup_x_eigen: {}", setup_x);
180 logger->debug(
"SympySolverVisitor :: update_state: {}", update_state);
184 pre_solve_statements_and_setup_x_eqs,
193 auto n_state_vars = std::make_shared<ast::Integer>(
state_vars.size(),
nullptr);
207 for (
size_t idx = 0; idx < statements.size(); ++idx) {
208 auto& s = statements[idx];
210 variable_statements.push_back(s);
211 }
else if (sr_begin == statements.size() || idx < sr_begin) {
212 initialize_statements.push_back(s);
216 if (sr_begin != statements.size()) {
217 initialize_statements.insert(initialize_statements.end(),
218 statements.begin() + sr_begin,
219 statements.begin() + sr_begin +
220 static_cast<std::ptrdiff_t
>(pre_solve_statements.size()));
222 statements.begin() + sr_begin +
223 static_cast<std::ptrdiff_t
>(pre_solve_statements.size()),
224 statements.begin() + sr_begin +
225 static_cast<std::ptrdiff_t
>(pre_solve_statements.size() +
state_vars.size()));
227 statements.begin() + sr_begin +
228 static_cast<std::ptrdiff_t
>(pre_solve_statements.size() +
state_vars.size()),
229 statements.begin() + sr_end);
233 const size_t total_statements_size = variable_statements.size() + initialize_statements.size() +
234 setup_x_statements.size() + functor_statements.size() +
235 finalize_statements.size();
236 if (statements.size() != total_statements_size) {
238 "SympySolverVisitor :: statement number missmatch ({} =/= {}) during splitting before "
243 total_statements_size);
247 auto variable_block = std::make_shared<ast::StatementBlock>(std::move(variable_statements));
248 auto initialize_block = std::make_shared<ast::StatementBlock>(std::move(initialize_statements));
250 auto finalize_block = std::make_shared<ast::StatementBlock>(std::move(finalize_statements));
253 setup_x_statements.insert(setup_x_statements.end(),
254 functor_statements.begin(),
255 functor_statements.end());
256 auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
257 auto solver_block = std::make_shared<ast::EigenLinearSolverBlock>(n_state_vars,
265 std::make_shared<ast::ExpressionStatement>(solver_block)};
269 auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
270 auto functor_block = std::make_shared<ast::StatementBlock>(std::move(functor_statements));
271 auto solver_block = std::make_shared<ast::EigenNewtonSolverBlock>(n_state_vars,
280 std::make_shared<ast::ExpressionStatement>(solver_block)};
295 solver->small_system = small_system;
299 solver->tmp_unique_prefix = tmp_unique_prefix;
303 auto solutions = solver->solutions;
305 auto new_local_vars = solver->new_local_vars;
307 auto exception_message = solver->exception_message;
310 if (!exception_message.empty()) {
311 logger->warn(
"SympySolverVisitor :: solve_lin_system python exception: " +
319 logger->debug(
"SympySolverVisitor :: Solving *small* linear system of eqs");
321 if (!new_local_vars.empty()) {
322 for (
const auto& new_local_var: new_local_vars) {
323 logger->debug(
"SympySolverVisitor :: -> declaring new local variable: {}",
329 pre_solve_statements,
338 logger->debug(
"SympySolverVisitor :: Constructing linear newton solve block");
344 const std::vector<std::string>& pre_solve_statements) {
356 auto solutions = solver->solutions;
358 auto exception_message = solver->exception_message;
360 if (!exception_message.empty()) {
361 logger->warn(
"SympySolverVisitor :: solve_non_lin_system python exception: " +
365 logger->debug(
"SympySolverVisitor :: Constructing eigen newton solve block");
372 if (node.
get_name()->is_indexed_name()) {
373 auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(node.
get_name());
377 std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
383 logger->debug(
"SympySolverVisitor :: adding state var: {}", var_name);
392 if (!lhs->is_var_name()) {
393 logger->warn(
"SympySolverVisitor :: LHS of differential equation is not a VariableName");
396 auto lhs_name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
397 if ((lhs_name->is_indexed_name() &&
398 !std::dynamic_pointer_cast<ast::IndexedName>(lhs_name)->get_name()->is_prime_name()) ||
399 (!lhs_name->is_indexed_name() && !lhs_name->is_prime_name())) {
400 logger->warn(
"SympySolverVisitor :: LHS of differential equation is not a PrimeName");
413 diffeq_solver->node_as_nmodl = node_as_nmodl;
415 diffeq_solver->vars =
vars;
424 logger->debug(
"SympySolverVisitor :: EULER - solving: {}", node_as_nmodl);
429 logger->debug(
"SympySolverVisitor :: CNEXP - solving: {}", node_as_nmodl);
433 std::string var_name = lhs_name->get_node_name();
434 if (lhs_name->is_indexed_name()) {
435 auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(lhs_name);
439 std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
442 logger->debug(
"SympySolverVisitor :: adding ODE system: {}", eq_str);
444 logger->debug(
"SympySolverVisitor :: adding state var: {}", var_name);
452 auto solution = diffeq_solver->solution;
453 logger->debug(
"SympySolverVisitor :: -> solution: {}", solution);
455 auto exception_message = diffeq_solver->exception_message;
456 if (!exception_message.empty()) {
457 logger->warn(
"SympySolverVisitor :: python exception: " + exception_message);
461 if (!solution.empty()) {
464 logger->warn(
"SympySolverVisitor :: solution to differential equation not possible");
471 logger->debug(
"SympySolverVisitor :: CONSERVE statement: {}",
to_nmodl(node));
473 std::string conserve_equation_statevar;
474 if (node.
get_react()->is_react_var_name()) {
475 conserve_equation_statevar = node.
get_react()->get_node_name();
480 "SympySolverVisitor :: Invalid CONSERVE statement for DERIVATIVE block, LHS should be "
481 "a state variable, instead found: {}. Ignoring CONSERVE statement",
486 logger->debug(
"SympySolverVisitor :: --> replace ODE for state var {} with equation {}",
487 conserve_equation_statevar,
488 conserve_equation_str);
509 std::vector<std::string> pre_solve_statements;
514 std::string x_array_index;
515 std::string x_array_index_i;
516 if (x_prime_split.size() > 1 &&
stringutils::trim(x_prime_split[1]).size() > 2) {
518 x_array_index_i =
"_" + x_array_index.substr(1, x_array_index.size() - 2);
520 std::string state_var_name = x + x_array_index;
524 eq = state_var_name +
" = " + var_eq_pair->second;
526 "SympySolverVisitor :: -> instead of Euler eq using CONSERVE equation: {} = {}",
528 var_eq_pair->second);
533 auto const old_x = [&]() {
534 std::string old_x_name{
"old_"};
535 old_x_name.append(x);
536 old_x_name.append(x_array_index_i);
540 logger->debug(
"SympySolverVisitor :: -> declaring new local variable: {}", old_x);
544 std::string expression{old_x};
545 expression.append(
" = ");
546 expression.append(x);
547 expression.append(x_array_index);
548 pre_solve_statements.push_back(std::move(expression));
552 eq.append(x_array_index);
560 logger->debug(
"SympySolverVisitor :: -> constructed Euler eq: {}", eq);
581 logger->debug(
"SympySolverVisitor :: adding linear eq: {}", lin_eq);
609 logger->debug(
"SympySolverVisitor :: adding non-linear eq: {}", non_lin_eq);
650 for (
const auto& block: solve_block_nodes) {
651 if (
auto block_ptr = std::dynamic_pointer_cast<const ast::SolveBlock>(block)) {
652 const auto& block_name = block_ptr->get_block_name()->get_value()->eval();
653 if (block_ptr->get_method()) {
656 const auto&
solve_method = block_ptr->get_method()->get_value()->eval();
657 logger->debug(
"SympySolverVisitor :: Found SOLVE statement: using {} for {}",
668 auto statevars = symtab->get_variables_with_properties(NmodlType::state_var);
669 for (
const auto& v: statevars) {
670 std::string var_name = v->get_name();
672 for (
int i = 0; i < v->get_length(); ++i) {
673 std::string var_name_i = var_name +
"[" +
std::to_string(i) +
"]";
static std::string to_nmodl_for_sympy(ast::Ast &node)
return NMODL string version of node, excluding any units
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
Base class for all AST node.
void visit_lin_equation(ast::LinEquation &node) override
visit node of type ast::LinEquation
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
void visit_children(visitor::Visitor &v) override
visit children i.e.
void visit_linear_block(ast::LinearBlock &node) override
visit node of type ast::LinearBlock
std::shared_ptr< Expression > get_react() const noexcept
Getter for member variable Conserve::react.
Implement class to represent a symbol in Symbol Table.
int replaced_statements_end() const
idx (in the new statementVector) of the last statement that was added.
Represents differential equation in DERIVATIVE block.
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
static std::string & replaceAll(std::string &context, const std::string &from, const std::string &to)
Helper used by filter_string_vector: replaces every occurrence of `from` in `context` with `to` in place (e.g. the name X with X_operator) and returns `context`.
void visit_conserve(ast::Conserve &node) override
visit node of type ast::Conserve
static constexpr char CNEXP_METHOD[]
cnexp method in nmodl
void visit_non_linear_block(ast::NonLinearBlock &node) override
visit node of type ast::NonLinearBlock
decltype(&create_nsls_executor_func) create_nsls_executor
void visit_children(visitor::Visitor &v) override
visit children i.e.
Represent CONSERVE statement in NMODL.
int replaced_statements_begin() const
idx (in the new statementVector) of the first statement that was added.
std::shared_ptr< Expression > get_expr() const noexcept
Getter for member variable Conserve::expr.
std::string solve_method
method specified in solve block
std::vector< std::shared_ptr< Statement > > StatementVector
std::string get_node_name() const override
Return name of the node.
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
encapsulates code generation backend implementations
void set_statements(StatementVector &&statements)
Setter for member variable StatementBlock::statements (rvalue reference)
const pybind_wrap_api * api()
Get a pointer to the pybind_wrap_api struct.
static std::vector< std::string > filter_string_vector(const std::vector< std::string > &original_vector, const std::string &original_string, const std::string &substitution_string)
Return a copy of original_vector in which every occurrence of original_string in each element is replaced with substitution_string (plain substring replacement, so partial matches are replaced too).
static bool is_local_statement(const std::shared_ptr< ast::Statement > &statement)
Check if provided statement is local variable declaration statement.
std::vector< std::string > all_state_vars
vector of all state variables (in order specified in STATE block in mod file)
void init_block_data(ast::Node *node)
clear any data from previous block & get set of block local vars + global vars
static constexpr char NTHREAD_DT_VARIABLE[]
dt variable in neuron thread structure
Implement string manipulation functions.
std::unordered_map< std::string, std::string > derivative_block_solve_method
map between derivative block names and associated solver method
void solve_non_linear_system(const std::vector< std::string > &pre_solve_statements={})
solve non-linear system (for "derivimplicit", "sparse" and "NONLINEAR")
static constexpr char SPARSE_METHOD[]
sparse method in nmodl
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
void visit_children(visitor::Visitor &v) override
visit children i.e.
std::set< std::string > vars
local variables in current block + globals
decltype(&destroy_des_executor_func) destroy_des_executor
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
std::set< std::string > global_vars
global variables
std::string get_node_name() const override
Return name of the node.
int SMALL_LINEAR_SYSTEM_MAX_STATES
max number of state vars allowed for small system linear solver
static std::string trim(std::string text)
Utility functions for visitors implementation.
void visit_children(visitor::Visitor &v) override
visit children i.e.
std::shared_ptr< Expression > get_left_linxpression() const noexcept
Getter for member variable LinEquation::left_linxpression.
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
std::set< std::string > function_calls
custom function calls used in ODE block
ast::StatementBlock * block_with_expression_statements
block where expression statements appear (to check there is only one)
static void replace_diffeq_expression(ast::DiffEqExpression &expr, const std::string &new_expr)
replace binary expression with new expression provided as string
std::string get_node_name() const override
Return name of the node.
std::string to_string(const T &obj)
void visit_children(visitor::Visitor &v) override
visit children i.e.
void visit_var_name(ast::VarName &node) override
visit node of type ast::VarName
void visit_expression_statement(ast::ExpressionStatement &node) override
visit node of type ast::ExpressionStatement
@ SOLVE_BLOCK
type of ast::SolveBlock
std::shared_ptr< BinaryExpression > get_expression() const noexcept
Getter for member variable DiffEqExpression::expression.
std::set< std::string > state_vars_in_block
set of state variables used in block
void visit_children(visitor::Visitor &v) override
visit children i.e.
@ GREEDY
Replace statements greedily.
ast::StatementVector::const_iterator get_solution_location_iterator(const ast::StatementVector &statements)
return iterator pointing to where solution should be inserted in statement block
Represents DERIVATIVE block in the NMODL.
bool use_pade_approx
optionally replace cnexp solution with (1,1) pade approx
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
void init_state_vars_vector()
construct vector from set of state vars in correct order
bool collect_state_vars
true for (non)linear eqs, to identify all state vars used in equations
std::unordered_map< std::string, std::string > conserve_equation
map from state vars to the algebraic equation from CONSERVE statement that should replace their ODE,...
void set_expression(std::shared_ptr< BinaryExpression > &&expression)
Setter for member variable DiffEqExpression::expression (rvalue reference)
std::vector< std::string > state_vars
vector of state vars used in block (in same order as all_state_vars)
void visit_children(visitor::Visitor &v) override
visit children i.e.
void construct_eigen_solver_block(const std::vector< std::string > &pre_solve_statements, const std::vector< std::string > &solutions, bool linear)
construct solver block
static std::vector< std::string > split_string(const std::string &text, char delimiter)
Split a text in a list of words, using a given delimiter character.
decltype(&create_des_executor_func) create_des_executor
void visit_children(visitor::Visitor &v) override
visit children i.e.
std::shared_ptr< Expression > get_rhs() const noexcept
Getter for member variable NonLinEquation::rhs.
Represents block encapsulating list of statements.
void visit_non_lin_equation(ast::NonLinEquation &node) override
visit node of type ast::NonLinEquation
NmodlType
NMODL variable properties.
std::shared_ptr< StatementBlock > create_statement_block(const std::vector< std::string > &code_statements)
Convert the given code statements (in string format) to a corresponding ast::StatementBlock node.
decltype(&create_sls_executor_func) create_sls_executor
bool elimination
optionally do CSE (common subexpression elimination) for sparse solver
Represents LINEAR block in the NMODL.
ast::ExpressionStatement * last_expression_statement
last expression statement visited (to know where to insert solutions in statement block)
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
Represents NONLINEAR block in the NMODL.
Implement logger based on spdlog library.
ast::StatementBlock * current_statement_block
current statement block being visited
void visit_derivative_block(ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
static constexpr char EULER_METHOD[]
euler method in nmodl
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
std::shared_ptr< Identifier > get_name() const noexcept
Getter for member variable VarName::name.
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
std::unordered_set< ast::Statement * > expression_statements
expression statements appearing in the block (these can be of type DiffEqExpression,...
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable NonLinEquation::lhs.
void visit_program(ast::Program &node) override
visit node of type ast::Program
std::vector< std::string > eq_system
vector of {ODE, linear eq, non-linear eq} system to solve
virtual std::shared_ptr< StatementBlock > get_statement_block() const
Return associated statement block for the AST node.
Represents top level AST node for whole NMODL input.
decltype(&destroy_nsls_executor_func) destroy_nsls_executor
bool eq_system_is_valid
only solve eq_system system of equations if this is true:
void solve_linear_system(const std::vector< std::string > &pre_solve_statements={})
solve linear system (for "LINEAR")
std::string suffix_random_string(const std::set< std::string > &vars, const std::string &original_string, const UseNumbersInString use_num)
Return the "original_string" with a random suffix if "original_string" exists in "vars".
ast::ExpressionStatement * current_expression_statement
current expression statement being visited (to track ODEs / (non)lineqs)
void check_expr_statements_in_same_block()
raise error if kinetic/ode/(non)linear statements are spread over multiple blocks
decltype(&destroy_sls_executor_func) destroy_sls_executor
@ VALUE
Replace statements matching by lhs varName.
std::shared_ptr< Expression > get_linxpression() const noexcept
Getter for member variable LinEquation::linxpression.
Visitor for systems of algebraic and differential equations
std::set< std::string > get_global_vars(const Program &node)
Return set of strings with the names of all global variables.
Auto generated AST classes declaration.
std::string get_node_name() const override
Return name of the node.
void visit_diff_eq_expression(ast::DiffEqExpression &node) override
visit node of type ast::DiffEqExpression