User Guide
sympy_solver_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
10 
11 #include "ast/all.hpp"
13 #include "pybind/pyembed.hpp"
14 #include "symtab/symbol.hpp"
15 #include "utils/logger.hpp"
16 #include "utils/string_utils.hpp"
18 
19 
20 namespace pywrap = nmodl::pybind_wrappers;
21 
22 namespace nmodl {
23 namespace visitor {
24 
26 
28  // clear any previous data
29  expression_statements.clear();
30  eq_system.clear();
31  state_vars_in_block.clear();
32  last_expression_statement = nullptr;
34  eq_system_is_valid = true;
35  conserve_equation.clear();
36  // get set of local block vars & global vars
37  vars = global_vars;
38  if (auto symtab = node->get_statement_block()->get_symbol_table()) {
39  auto localvars = symtab->get_variables_with_properties(NmodlType::local_var);
40  for (const auto& localvar: localvars) {
41  std::string var_name = localvar->get_name();
42  if (localvar->is_array()) {
43  var_name += "[" + std::to_string(localvar->get_length()) + "]";
44  }
45  vars.insert(var_name);
46  }
47  }
48  const auto& fcall_nodes = collect_nodes(*node->get_statement_block(),
49  {ast::AstNodeType::FUNCTION_CALL});
50  for (const auto& call: fcall_nodes) {
51  function_calls.insert(call->get_node_name());
52  }
53 }
54 
56  state_vars.clear();
57  for (const auto& state_var: all_state_vars) {
58  if (state_vars_in_block.find(state_var) != state_vars_in_block.cend()) {
59  state_vars.push_back(state_var);
60  }
61  }
62 
63  // in case we have a SOLVEFOR in the block, we need to set `state_vars` to those instead
64  if (node->is_linear_block()) {
65  const auto& solvefor_vars = dynamic_cast<const ast::LinearBlock*>(node)->get_solvefor();
66  if (!solvefor_vars.empty()) {
67  state_vars.clear();
68  for (const auto& solvefor_var: solvefor_vars) {
69  state_vars.push_back(solvefor_var->get_node_name());
70  }
71  }
72  } else if (node->is_non_linear_block()) {
73  const auto& solvefor_vars = dynamic_cast<const ast::NonLinearBlock*>(node)->get_solvefor();
74  if (!solvefor_vars.empty()) {
75  state_vars.clear();
76  for (const auto& solvefor_var: solvefor_vars) {
77  state_vars.push_back(solvefor_var->get_node_name());
78  }
79  }
80  }
81 }
82 
84  const std::string& new_expr) {
85  auto new_statement = create_statement(new_expr);
86  auto new_expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(new_statement);
87  auto new_bin_expr = std::dynamic_pointer_cast<ast::BinaryExpression>(
88  new_expr_statement->get_expression());
89  expr.set_expression(std::move(new_bin_expr));
90 }
91 
93  /// all ode/kinetic/(non)linear statements (typically) appear in the same statement block
94  /// if this is not the case, for now return an error (and should instead use fallback solver)
95  if (block_with_expression_statements != nullptr &&
97  logger->warn(
98  "SympySolverVisitor :: Coupled equations are appearing in different blocks - not "
99  "supported");
100  eq_system_is_valid = false;
101  }
103 }
104 
105 ast::StatementVector::const_iterator SympySolverVisitor::get_solution_location_iterator(
106  const ast::StatementVector& statements) {
107  // find out where to insert solutions in statement block
108  // returns iterator pointing to the first element after the last (non)linear eq
109  // so if there are no such elements, it returns statements.end()
110  auto it = statements.begin();
111  if (last_expression_statement != nullptr) {
112  while ((it != statements.end()) &&
113  (std::dynamic_pointer_cast<ast::ExpressionStatement>(*it).get() !=
115  logger->debug("SympySolverVisitor :: {} != {}",
116  to_nmodl(*it),
118  ++it;
119  }
120  if (it != statements.end()) {
121  logger->debug("SympySolverVisitor :: {} == {}",
122  to_nmodl(std::dynamic_pointer_cast<ast::ExpressionStatement>(*it)),
124  ++it;
125  }
126  }
127  return it;
128 }
129 
130 /**
131  * Check if provided statement is local variable declaration statement
132  * @param statement AST node representing statement in the MOD file
133  * @return True if statement is local variable declaration else False
134  *
135  * Statement declaration could be wrapped into another statement type like
136  * expression statement and hence we try to look inside if it's really a
137  * variable declaration.
138  */
139 static bool is_local_statement(const std::shared_ptr<ast::Statement>& statement) {
140  if (statement->is_local_list_statement()) {
141  return true;
142  }
143  if (statement->is_expression_statement()) {
144  auto e_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(statement);
145  auto expression = e_statement->get_expression();
146  if (expression->is_local_list_statement()) {
147  return true;
148  }
149  }
150  return false;
151 }
152 
153 std::string& SympySolverVisitor::replaceAll(std::string& context,
154  const std::string& from,
155  const std::string& to) {
156  std::size_t lookHere = 0;
157  std::size_t foundHere{};
158  while ((foundHere = context.find(from, lookHere)) != std::string::npos) {
159  context.replace(foundHere, from.size(), to);
160  lookHere = foundHere + to.size();
161  }
162  return context;
163 }
164 
166  const std::vector<std::string>& original_vector,
167  const std::string& original_string,
168  const std::string& substitution_string) {
169  std::vector<std::string> filtered_vector;
170  for (auto element: original_vector) {
171  std::string filtered_element = replaceAll(element, original_string, substitution_string);
172  filtered_vector.push_back(filtered_element);
173  }
174  return filtered_vector;
175 }
176 
178  const std::vector<std::string>& pre_solve_statements,
179  const std::vector<std::string>& solutions,
180  bool linear) {
181  auto solutions_filtered = filter_string_vector(solutions, "X[", "nmodl_eigen_x[");
182  solutions_filtered = filter_string_vector(solutions_filtered, "dX_[", "nmodl_eigen_dx[");
183  solutions_filtered = filter_string_vector(solutions_filtered, "J[", "nmodl_eigen_j[");
184  solutions_filtered = filter_string_vector(solutions_filtered, "Jm[", "nmodl_eigen_jm[");
185  solutions_filtered = filter_string_vector(solutions_filtered, "F[", "nmodl_eigen_f[");
186 
187  for (const auto& sol: solutions_filtered) {
188  logger->debug("SympySolverVisitor :: -> adding statement: {}", sol);
189  }
190 
191  std::vector<std::string> pre_solve_statements_and_setup_x_eqs = pre_solve_statements;
192  std::vector<std::string> update_statements;
193 
194  for (int i = 0; i < state_vars.size(); i++) {
195  auto eigen_name = fmt::format("nmodl_eigen_x[{}]", i);
196 
197  auto update_state = fmt::format("{} = {}", state_vars[i], eigen_name);
198  update_statements.push_back(update_state);
199  logger->debug("SympySolverVisitor :: update_state: {}", update_state);
200 
201  auto setup_x = fmt::format("{} = {}", eigen_name, state_vars[i]);
202  pre_solve_statements_and_setup_x_eqs.push_back(setup_x);
203  logger->debug("SympySolverVisitor :: setup_x_eigen: {}", setup_x);
204  }
205 
206  visitor::SympyReplaceSolutionsVisitor solution_replacer(
207  pre_solve_statements_and_setup_x_eqs,
208  solutions_filtered,
211  state_vars.size() + 1,
212  "");
214 
215  // split in the various blocks for eigen
216  auto n_state_vars = std::make_shared<ast::Integer>(state_vars.size(), nullptr);
217 
218  const auto& statements = block_with_expression_statements->get_statements();
219 
220  ast::StatementVector variable_statements; // LOCAL //
221  ast::StatementVector initialize_statements; // pre_solve_statements //
222  ast::StatementVector setup_x_statements; // old_x = x, X[0] = x //
223  ast::StatementVector functor_statements; // J[0]_row * X = F[0], additional assignments during
224  // computation //
225  ast::StatementVector finalize_statements; // assignments at the end //
226  std::ptrdiff_t const sr_begin{solution_replacer.replaced_statements_begin()};
227  std::ptrdiff_t const sr_end{solution_replacer.replaced_statements_end()};
228 
229  // initialize and edge case where the system of equations is empty
230  for (size_t idx = 0; idx < statements.size(); ++idx) {
231  auto& s = statements[idx];
232  if (is_local_statement(s)) {
233  variable_statements.push_back(s);
234  } else if (sr_begin == statements.size() || idx < sr_begin) {
235  initialize_statements.push_back(s);
236  }
237  }
238 
239  if (sr_begin != statements.size()) {
240  auto init_begin = statements.begin() + sr_begin;
241  auto init_end = init_begin + static_cast<std::ptrdiff_t>(pre_solve_statements.size());
242  initialize_statements.insert(initialize_statements.end(), init_begin, init_end);
243 
244  auto setup_x_begin = init_end;
245  auto setup_x_end = setup_x_begin + static_cast<std::ptrdiff_t>(state_vars.size());
246  setup_x_statements = ast::StatementVector(setup_x_begin, setup_x_end);
247 
248  auto functor_begin = setup_x_end;
249  auto functor_end = statements.begin() + sr_end;
250  functor_statements = ast::StatementVector(functor_begin, functor_end);
251 
252  auto finalize_begin = functor_end;
253  auto finalize_end = statements.end();
254  finalize_statements = ast::StatementVector(finalize_begin, finalize_end);
255  }
256 
257  const size_t total_statements_size = variable_statements.size() + initialize_statements.size() +
258  setup_x_statements.size() + functor_statements.size() +
259  finalize_statements.size();
260  if (statements.size() != total_statements_size) {
261  logger->error(
262  "SympySolverVisitor :: statement number missmatch ({} =/= {}) during splitting before "
263  "creation of "
264  "eigen "
265  "solver block.",
266  statements.size(),
267  total_statements_size);
268  return;
269  }
270 
271  auto variable_block = std::make_shared<ast::StatementBlock>(std::move(variable_statements));
272  auto initialize_block = std::make_shared<ast::StatementBlock>(std::move(initialize_statements));
273  auto update_state_block = create_statement_block(update_statements);
274  auto finalize_block = std::make_shared<ast::StatementBlock>(std::move(finalize_statements));
275  if (linear) {
276  /// functor and initialize block converge in the same block
277  setup_x_statements.insert(setup_x_statements.end(),
278  functor_statements.begin(),
279  functor_statements.end());
280  auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
281  auto solver_block = std::make_shared<ast::EigenLinearSolverBlock>(n_state_vars,
282  variable_block,
283  initialize_block,
284  setup_x_block,
285  update_state_block,
286  finalize_block);
287  /// replace statement block with solver block as it contains all statements
288  ast::StatementVector solver_block_statements{
289  std::make_shared<ast::ExpressionStatement>(solver_block)};
290  block_with_expression_statements->set_statements(std::move(solver_block_statements));
291  } else {
292  /// create eigen newton solver block
293  auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
294  auto functor_block = std::make_shared<ast::StatementBlock>(std::move(functor_statements));
295  auto solver_block = std::make_shared<ast::EigenNewtonSolverBlock>(n_state_vars,
296  variable_block,
297  initialize_block,
298  setup_x_block,
299  functor_block,
300  update_state_block,
301  finalize_block);
302  /// replace statement block with solver block as it contains all statements
303  ast::StatementVector solver_block_statements{
304  std::make_shared<ast::ExpressionStatement>(solver_block)};
305  block_with_expression_statements->set_statements(std::move(solver_block_statements));
306  }
307 }
308 
309 
311  const std::vector<std::string>& pre_solve_statements) {
312  // construct ordered vector of state vars used in linear system
313  init_state_vars_vector(&node);
314  // call sympy linear solver
315  bool small_system = (eq_system.size() <= SMALL_LINEAR_SYSTEM_MAX_STATES);
317  // this is necessary after we destroy the solver
318  const auto tmp_unique_prefix = suffix_random_string(vars, "tmp");
319 
320  // returns a vector of solutions, i.e. new statements to add to block;
321  // and a vector of new local variables that need to be declared in the block;
322  // may also return a python exception message:
323  auto [solutions, new_local_vars, exception_message] = solver(
324  eq_system, state_vars, vars, small_system, elimination, tmp_unique_prefix, function_calls);
325 
326  if (!exception_message.empty()) {
327  logger->warn(
328  "SympySolverVisitor :: solve_lin_system python exception occured. (--verbose=info)");
329  logger->info(exception_message +
330  "\n (Note: line numbers are of by a few compared to `ode.py`.)");
331  return;
332  }
333  // find out where to insert solutions in statement block
334  if (small_system) {
335  // for small number of state vars, linear solver
336  // directly returns solution by solving symbolically at compile time
337  logger->debug("SympySolverVisitor :: Solving *small* linear system of eqs");
338  // declare new local vars
339  if (!new_local_vars.empty()) {
340  for (const auto& new_local_var: new_local_vars) {
341  logger->debug("SympySolverVisitor :: -> declaring new local variable: {}",
342  new_local_var);
344  }
345  }
346  visitor::SympyReplaceSolutionsVisitor solution_replacer(
347  pre_solve_statements,
348  solutions,
351  1,
352  tmp_unique_prefix);
354  } else {
355  // otherwise it returns a linear matrix system to solve
356  logger->debug("SympySolverVisitor :: Constructing linear newton solve block");
357  construct_eigen_solver_block(pre_solve_statements, solutions, true);
358  }
359 }
360 
362  const ast::Node& node,
363  const std::vector<std::string>& pre_solve_statements) {
364  // construct ordered vector of state vars used in non-linear system
365  init_state_vars_vector(&node);
366 
368  auto [solutions, exception_message] = solver(eq_system, state_vars, vars, function_calls);
369 
370  if (!exception_message.empty()) {
371  logger->warn(
372  "SympySolverVisitor :: solve_non_lin_system python exception. (--verbose=info)");
373  logger->info(exception_message +
374  "\n (Note: line numbers are of by a few compared to `ode.py`.)");
375  return;
376  }
377  logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
378 
379  construct_eigen_solver_block(pre_solve_statements, solutions, false);
380 }
381 
383  if (collect_state_vars) {
384  std::string var_name = node.get_node_name();
385  if (node.get_name()->is_indexed_name()) {
386  auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(node.get_name());
387  var_name +=
388  "[" +
390  std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
391  "]";
392  }
393  // if var_name is a state var, add it to set
394  if (std::find(all_state_vars.cbegin(), all_state_vars.cend(), var_name) !=
395  all_state_vars.cend()) {
396  logger->debug("SympySolverVisitor :: adding state var: {}", var_name);
397  state_vars_in_block.insert(var_name);
398  }
399  }
400 }
401 
402 // Skip visiting CVODE block
404 
406  const auto& lhs = node.get_expression()->get_lhs();
407 
408  if (!lhs->is_var_name()) {
409  logger->warn("SympySolverVisitor :: LHS of differential equation is not a VariableName");
410  return;
411  }
412  auto lhs_name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
413  if ((lhs_name->is_indexed_name() &&
414  !std::dynamic_pointer_cast<ast::IndexedName>(lhs_name)->get_name()->is_prime_name()) ||
415  (!lhs_name->is_indexed_name() && !lhs_name->is_prime_name())) {
416  logger->warn("SympySolverVisitor :: LHS of differential equation is not a PrimeName");
417  return;
418  }
419 
421 
422  const auto node_as_nmodl = to_nmodl_for_sympy(node);
424 
426  auto [solution, exception_message] = (*diffeq_solver)(
427  node_as_nmodl, dt_var, vars, use_pade_approx, function_calls, solve_method);
429  // replace x' = f(x) differential equation
430  // with forwards Euler timestep:
431  // x = x + f(x) * dt
432  logger->debug("SympySolverVisitor :: EULER - solving: {}", node_as_nmodl);
434  // replace x' = f(x) differential equation
435  // with analytic solution for x(t+dt) in terms of x(t)
436  // x = ...
437  logger->debug("SympySolverVisitor :: CNEXP - solving: {}", node_as_nmodl);
438  } else {
439  // for other solver methods: just collect the ODEs & return
440  std::string eq_str = to_nmodl_for_sympy(node);
441  std::string var_name = lhs_name->get_node_name();
442  if (lhs_name->is_indexed_name()) {
443  auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(lhs_name);
444  var_name +=
445  "[" +
447  std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
448  "]";
449  }
450  logger->debug("SympySolverVisitor :: adding ODE system: {}", eq_str);
451  eq_system.push_back(eq_str);
452  logger->debug("SympySolverVisitor :: adding state var: {}", var_name);
453  state_vars_in_block.insert(var_name);
456  return;
457  }
458 
459  // replace ODE with solution in AST
460  logger->debug("SympySolverVisitor :: -> solution: {}", solution);
461 
462  if (!exception_message.empty()) {
463  logger->warn("SympySolverVisitor :: python exception. (--verbose=info)");
464  logger->info(exception_message +
465  "\n (Note: line numbers are of by a few compared to `ode.py`.)");
466  return;
467  }
468 
469  if (!solution.empty()) {
470  replace_diffeq_expression(node, solution);
471  } else {
472  logger->warn("SympySolverVisitor :: solution to differential equation not possible");
473  }
474 }
475 
477  // Replace ODE for state variable on LHS of CONSERVE statement with
478  // algebraic expression on RHS (see p244 of NEURON book)
479  logger->debug("SympySolverVisitor :: CONSERVE statement: {}", to_nmodl(node));
480  expression_statements.insert(&node);
481  std::string conserve_equation_statevar;
482  if (node.get_react()->is_react_var_name()) {
483  conserve_equation_statevar = node.get_react()->get_node_name();
484  }
485  if (std::find(all_state_vars.cbegin(), all_state_vars.cend(), conserve_equation_statevar) ==
486  all_state_vars.cend()) {
487  logger->error(
488  "SympySolverVisitor :: Invalid CONSERVE statement for DERIVATIVE block, LHS should be "
489  "a state variable, instead found: {}. Ignoring CONSERVE statement",
490  to_nmodl(node.get_react()));
491  return;
492  }
493  auto conserve_equation_str = to_nmodl_for_sympy(*node.get_expr());
494  logger->debug("SympySolverVisitor :: --> replace ODE for state var {} with equation {}",
495  conserve_equation_statevar,
496  conserve_equation_str);
497  conserve_equation[conserve_equation_statevar] = conserve_equation_str;
498 }
499 
501  /// clear information from previous block, get global vars + block local vars
502  init_block_data(&node);
503 
504  // get user specified solve method for this block
506 
507  // visit each differential equation:
508  // - for CNEXP or EULER, each equation is independent & is replaced with its solution
509  // - otherwise, each equation is added to eq_system
510  node.visit_children(*this);
511 
512  if (eq_system_is_valid && !eq_system.empty()) {
513  // solve system of ODEs in eq_system
514  logger->debug("SympySolverVisitor :: Solving {} system of ODEs", solve_method);
515 
516  // construct implicit Euler equations from ODEs
517  std::vector<std::string> pre_solve_statements;
518  for (auto& eq: eq_system) {
519  auto split_eq = stringutils::split_string(eq, '=');
520  auto x_prime_split = stringutils::split_string(split_eq[0], '\'');
521  auto x = stringutils::trim(x_prime_split[0]);
522  std::string x_array_index;
523  std::string x_array_index_i;
524  if (x_prime_split.size() > 1 && stringutils::trim(x_prime_split[1]).size() > 2) {
525  x_array_index = stringutils::trim(x_prime_split[1]);
526  x_array_index_i = "_" + x_array_index.substr(1, x_array_index.size() - 2);
527  }
528  std::string state_var_name = x + x_array_index;
529  auto var_eq_pair = conserve_equation.find(state_var_name);
530  if (var_eq_pair != conserve_equation.cend()) {
531  // replace the ODE for this state var with corresponding CONSERVE equation
532  eq = state_var_name + " = " + var_eq_pair->second;
533  logger->debug(
534  "SympySolverVisitor :: -> instead of Euler eq using CONSERVE equation: {} = {}",
535  state_var_name,
536  var_eq_pair->second);
537  } else {
538  // no CONSERVE equation, construct Euler equation
539  auto dxdt = stringutils::trim(split_eq[1]);
540 
541  auto const old_x = [&]() {
542  std::string old_x_name{"old_"};
543  old_x_name.append(x);
544  old_x_name.append(x_array_index_i);
545  return suffix_random_string(vars, old_x_name);
546  }();
547  // declare old_x
548  logger->debug("SympySolverVisitor :: -> declaring new local variable: {}", old_x);
550  // assign old_x = x
551  {
552  std::string expression{old_x};
553  expression.append(" = ");
554  expression.append(x);
555  expression.append(x_array_index);
556  pre_solve_statements.push_back(std::move(expression));
557  }
558  // replace ODE with Euler equation
559  eq = "(";
560  eq.append(x);
561  eq.append(x_array_index);
562  eq.append(" - ");
563  eq.append(old_x);
564  eq.append(") / ");
566  eq.append(" = ");
567  eq.append(dxdt);
568  logger->debug("SympySolverVisitor :: -> constructed Euler eq: {}", eq);
569  }
570  }
571 
574  solve_non_linear_system(node, pre_solve_statements);
575  } else {
576  logger->error("SympySolverVisitor :: Solve method {} not supported", solve_method);
577  }
578  }
579 }
580 
583  std::string lin_eq = to_nmodl_for_sympy(*node.get_lhs());
584  lin_eq += " = ";
585  lin_eq += to_nmodl_for_sympy(*node.get_rhs());
586  eq_system.push_back(lin_eq);
589  logger->debug("SympySolverVisitor :: adding linear eq: {}", lin_eq);
590  collect_state_vars = true;
591  node.visit_children(*this);
592  collect_state_vars = false;
593 }
594 
596  logger->debug("SympySolverVisitor :: found LINEAR block: {}", node.get_node_name());
597 
598  /// clear information from previous block, get global vars + block local vars
599  init_block_data(&node);
600 
601  // collect linear equations
602  node.visit_children(*this);
603 
604  if (eq_system_is_valid && !eq_system.empty()) {
605  solve_linear_system(node);
606  }
607 }
608 
611  std::string non_lin_eq = to_nmodl_for_sympy(*node.get_lhs());
612  non_lin_eq += " = ";
613  non_lin_eq += to_nmodl_for_sympy(*node.get_rhs());
614  eq_system.push_back(non_lin_eq);
617  logger->debug("SympySolverVisitor :: adding non-linear eq: {}", non_lin_eq);
618  collect_state_vars = true;
619  node.visit_children(*this);
620  collect_state_vars = false;
621 }
622 
624  logger->debug("SympySolverVisitor :: found NONLINEAR block: {}", node.get_node_name());
625 
626  /// clear information from previous block, get global vars + block local vars
627  init_block_data(&node);
628 
629  // collect non-linear equations
630  node.visit_children(*this);
631 
632  if (eq_system_is_valid && !eq_system.empty()) {
634  }
635 }
636 
638  auto prev_expression_statement = current_expression_statement;
640  node.visit_children(*this);
641  current_expression_statement = prev_expression_statement;
642 }
643 
645  auto prev_statement_block = current_statement_block;
646  current_statement_block = &node;
647  node.visit_children(*this);
648  current_statement_block = prev_statement_block;
649 }
650 
653 
655 
656  // get list of solve statements with names & methods
657  const auto& solve_block_nodes = collect_nodes(node, {ast::AstNodeType::SOLVE_BLOCK});
658  for (const auto& block: solve_block_nodes) {
659  if (auto block_ptr = std::dynamic_pointer_cast<const ast::SolveBlock>(block)) {
660  const auto& block_name = block_ptr->get_block_name()->get_value()->eval();
661  if (block_ptr->get_method()) {
662  // Note: solve method name is an optional parameter
663  // LINEAR and NONLINEAR blocks do not have solve method specified
664  const auto& solve_method = block_ptr->get_method()->get_value()->eval();
665  logger->debug("SympySolverVisitor :: Found SOLVE statement: using {} for {}",
666  solve_method,
667  block_name);
669  }
670  }
671  }
672 
673  // get set of all state vars
674  all_state_vars.clear();
675  if (auto symtab = node.get_symbol_table()) {
676  auto statevars = symtab->get_variables_with_properties(NmodlType::state_var);
677  for (const auto& v: statevars) {
678  std::string var_name = v->get_name();
679  if (v->is_array()) {
680  for (int i = 0; i < v->get_length(); ++i) {
681  std::string var_name_i = var_name + "[" + std::to_string(i) + "]";
682  all_state_vars.push_back(var_name_i);
683  }
684  } else {
685  all_state_vars.push_back(var_name);
686  }
687  }
688  }
689 
690  node.visit_children(*this);
691 }
692 
693 } // namespace visitor
694 } // namespace nmodl
nmodl::pybind_wrappers::pybind_wrap_api::solve_linear_system
decltype(&call_solve_linear_system) solve_linear_system
Definition: wrapper.hpp:64
nmodl::visitor::SympySolverVisitor::visit_cvode_block
void visit_cvode_block(ast::CvodeBlock &node) override
visit node of type ast::CvodeBlock
Definition: sympy_solver_visitor.cpp:403
nmodl::visitor::SympySolverVisitor::to_nmodl_for_sympy
static std::string to_nmodl_for_sympy(ast::Ast &node)
return NMODL string version of node, excluding any units
Definition: sympy_solver_visitor.hpp:88
nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
Definition: pyembed.hpp:29
nmodl::ast::Node
Base class for all AST node.
Definition: node.hpp:40
nmodl::visitor::SympySolverVisitor::visit_lin_equation
void visit_lin_equation(ast::LinEquation &node) override
visit node of type ast::LinEquation
Definition: sympy_solver_visitor.cpp:581
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:234
nmodl::ast::LinEquation::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:6952
nmodl::visitor::SympySolverVisitor::visit_linear_block
void visit_linear_block(ast::LinearBlock &node) override
visit node of type ast::LinearBlock
Definition: sympy_solver_visitor.cpp:595
nmodl::ast::Conserve::get_react
std::shared_ptr< Expression > get_react() const noexcept
Getter for member variable Conserve::react.
Definition: conserve.hpp:159
symbol.hpp
Implement class to represent a symbol in Symbol Table.
nmodl::visitor::SympyReplaceSolutionsVisitor::replaced_statements_end
int replaced_statements_end() const
idx (in the new statementVector) of the last statement that was added.
Definition: sympy_replace_solutions_visitor.hpp:236
nmodl::ast::DiffEqExpression
Represents differential equation in DERIVATIVE block.
Definition: diff_eq_expression.hpp:38
nmodl::visitor::SympyReplaceSolutionsVisitor
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
Definition: sympy_replace_solutions_visitor.hpp:212
nmodl::ast::Ast::is_non_linear_block
virtual bool is_non_linear_block() const noexcept
Check if the ast node is an instance of ast::NonLinearBlock.
Definition: ast.cpp:136
nmodl::visitor::SympySolverVisitor::replaceAll
static std::string & replaceAll(std::string &context, const std::string &from, const std::string &to)
Function used by SympySolverVisitor::filter_string_vector to replace, in place, every occurrence of the substring from in a std::string with to.
Definition: sympy_solver_visitor.cpp:153
nmodl::visitor::SympySolverVisitor::visit_conserve
void visit_conserve(ast::Conserve &node) override
visit node of type ast::Conserve
Definition: sympy_solver_visitor.cpp:476
nmodl::codegen::naming::CNEXP_METHOD
static constexpr char CNEXP_METHOD[]
cnexp method in nmodl
Definition: codegen_naming.hpp:30
nmodl::ast::NonLinEquation
One equation in a system of equations that collectively make a NONLINEAR block.
Definition: non_lin_equation.hpp:38
nmodl::visitor::SympySolverVisitor::visit_non_linear_block
void visit_non_linear_block(ast::NonLinearBlock &node) override
visit node of type ast::NonLinearBlock
Definition: sympy_solver_visitor.cpp:623
nmodl::ast::StatementBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:3162
nmodl::ast::Conserve
Represent CONSERVE statement in NMODL.
Definition: conserve.hpp:38
nmodl::visitor::SympyReplaceSolutionsVisitor::replaced_statements_begin
int replaced_statements_begin() const
idx (in the new statementVector) of the first statement that was added.
Definition: sympy_replace_solutions_visitor.hpp:231
nmodl::ast::Conserve::get_expr
std::shared_ptr< Expression > get_expr() const noexcept
Getter for member variable Conserve::expr.
Definition: conserve.hpp:168
nmodl::visitor::SympySolverVisitor::solve_method
std::string solve_method
method specified in solve block
Definition: sympy_solver_visitor.hpp:134
nmodl::ast::StatementVector
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:302
nmodl::ast::NonLinearBlock::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:3525
nmodl::visitor::SympyReplaceSolutionsVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Definition: sympy_replace_solutions_visitor.cpp:64
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::ast::StatementBlock::set_statements
void set_statements(StatementVector &&statements)
Setter for member variable StatementBlock::statements (rvalue reference)
Definition: ast.cpp:3222
nmodl::visitor::SympySolverVisitor::filter_string_vector
static std::vector< std::string > filter_string_vector(const std::vector< std::string > &original_vector, const std::string &original_string, const std::string &substitution_string)
Check original_vector for elements that contain a variable named original_string and rename it to sub...
Definition: sympy_solver_visitor.cpp:165
nmodl::visitor::is_local_statement
static bool is_local_statement(const std::shared_ptr< ast::Statement > &statement)
Check if provided statement is local variable declaration statement.
Definition: sympy_solver_visitor.cpp:139
nmodl::ast::CvodeBlock
Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks.
Definition: cvode_block.hpp:38
nmodl::visitor::SympySolverVisitor::all_state_vars
std::vector< std::string > all_state_vars
vector of all state variables (in order specified in STATE block in mod file)
Definition: sympy_solver_visitor.hpp:146
nmodl::visitor::SympySolverVisitor::init_block_data
void init_block_data(ast::Node *node)
clear any data from previous block & get set of block local vars + global vars
Definition: sympy_solver_visitor.cpp:27
nmodl::codegen::naming::NTHREAD_DT_VARIABLE
static constexpr char NTHREAD_DT_VARIABLE[]
dt variable in neuron thread structure
Definition: codegen_naming.hpp:108
nmodl::ast::VarName
Represents a variable.
Definition: var_name.hpp:43
string_utils.hpp
Implement string manipulation functions.
nmodl::visitor::SympySolverVisitor::derivative_block_solve_method
std::unordered_map< std::string, std::string > derivative_block_solve_method
map between derivative block names and associated solver method
Definition: sympy_solver_visitor.hpp:115
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::codegen::naming::SPARSE_METHOD
static constexpr char SPARSE_METHOD[]
sparse method in nmodl
Definition: codegen_naming.hpp:42
sympy_replace_solutions_visitor.hpp
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
nmodl::ast::LinearBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:3376
nmodl::visitor::SympySolverVisitor::vars
std::set< std::string > vars
local variables in current block + globals
Definition: sympy_solver_visitor.hpp:109
codegen_naming.hpp
nmodl::ast::StatementBlock::get_statements
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
Definition: statement_block.hpp:221
nmodl::ast::ExpressionStatement
TODO.
Definition: expression_statement.hpp:38
nmodl::visitor::SympySolverVisitor::global_vars
std::set< std::string > global_vars
global variables
Definition: sympy_solver_visitor.hpp:106
nmodl::visitor::SympySolverVisitor::init_state_vars_vector
void init_state_vars_vector(const ast::Node *node)
construct vector from set of state vars in correct order
Definition: sympy_solver_visitor.cpp:55
nmodl::ast::DerivativeBlock::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:3246
nmodl::visitor::SympySolverVisitor::SMALL_LINEAR_SYSTEM_MAX_STATES
int SMALL_LINEAR_SYSTEM_MAX_STATES
max number of state vars allowed for small system linear solver
Definition: sympy_solver_visitor.hpp:165
nmodl::stringutils::trim
static std::string trim(std::string text)
Definition: string_utils.hpp:63
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::DerivativeBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:3253
nmodl::visitor::SympySolverVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Definition: sympy_solver_visitor.cpp:644
nmodl::visitor::SympySolverVisitor::function_calls
std::set< std::string > function_calls
custom function calls used in ODE block
Definition: sympy_solver_visitor.hpp:112
nmodl::pybind_wrappers::pybind_wrap_api::diffeq_solver
decltype(&call_diffeq_solver) diffeq_solver
Definition: wrapper.hpp:65
nmodl::visitor::SympySolverVisitor::block_with_expression_statements
ast::StatementBlock * block_with_expression_statements
block where expression statements appear (to check there is only one)
Definition: sympy_solver_visitor.hpp:131
nmodl::visitor::SympySolverVisitor::replace_diffeq_expression
static void replace_diffeq_expression(ast::DiffEqExpression &expr, const std::string &new_expr)
replace binary expression with new expression provided as string
Definition: sympy_solver_visitor.cpp:83
nmodl::pybind_wrappers::pybind_wrap_api::solve_nonlinear_system
decltype(&call_solve_nonlinear_system) solve_nonlinear_system
Definition: wrapper.hpp:63
nmodl::visitor::SympySolverVisitor::solve_non_linear_system
void solve_non_linear_system(const ast::Node &node, const std::vector< std::string > &pre_solve_statements={})
solve non-linear system (for "derivimplicit", "sparse" and "NONLINEAR")
Definition: sympy_solver_visitor.cpp:361
nmodl::ast::LinearBlock::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:3366
nmodl::pybind_wrappers::EmbeddedPythonLoader::api
const pybind_wrap_api & api()
Get a pointer to the pybind_wrap_api struct.
Definition: pyembed.cpp:135
nmodl::symtab::syminfo::to_string
std::string to_string(const T &obj)
Definition: symbol_properties.hpp:282
nmodl::ast::Program::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:12906
nmodl::visitor::SympySolverVisitor::visit_var_name
void visit_var_name(ast::VarName &node) override
visit node of type ast::VarName
Definition: sympy_solver_visitor.cpp:382
nmodl::visitor::SympySolverVisitor::visit_expression_statement
void visit_expression_statement(ast::ExpressionStatement &node) override
visit node of type ast::ExpressionStatement
Definition: sympy_solver_visitor.cpp:637
nmodl::visitor::SympySolverVisitor::solve_linear_system
void solve_linear_system(const ast::Node &node, const std::vector< std::string > &pre_solve_statements={})
solve linear system (for "LINEAR")
Definition: sympy_solver_visitor.cpp:310
nmodl::ast::AstNodeType::SOLVE_BLOCK
@ SOLVE_BLOCK
type of ast::SolveBlock
nmodl::ast::LinEquation::get_lhs
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable LinEquation::lhs.
Definition: lin_equation.hpp:159
nmodl::ast::DiffEqExpression::get_expression
std::shared_ptr< BinaryExpression > get_expression() const noexcept
Getter for member variable DiffEqExpression::expression.
Definition: diff_eq_expression.hpp:143
nmodl::visitor::SympySolverVisitor::state_vars_in_block
std::set< std::string > state_vars_in_block
set of state variables used in block
Definition: sympy_solver_visitor.hpp:149
nmodl::ast::NonLinEquation::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:6834
nmodl::visitor::SympyReplaceSolutionsVisitor::ReplacePolicy::GREEDY
@ GREEDY
Replace statements greedily.
nmodl::visitor::SympySolverVisitor::get_solution_location_iterator
ast::StatementVector::const_iterator get_solution_location_iterator(const ast::StatementVector &statements)
return iterator pointing to where solution should be inserted in statement block
Definition: sympy_solver_visitor.cpp:105
nmodl::ast::DerivativeBlock
Represents DERIVATIVE block in the NMODL.
Definition: derivative_block.hpp:49
nmodl::visitor::SympySolverVisitor::use_pade_approx
bool use_pade_approx
optionally replace cnexp solution with (1,1) pade approx
Definition: sympy_solver_visitor.hpp:159
nmodl::visitor::create_statement
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
Definition: visitor_utils.cpp:127
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:206
nmodl::visitor::SympySolverVisitor::collect_state_vars
bool collect_state_vars
true for (non)linear eqs, to identify all state vars used in equations
Definition: sympy_solver_visitor.hpp:143
nmodl::visitor::SympySolverVisitor::conserve_equation
std::unordered_map< std::string, std::string > conserve_equation
map from state vars to the algebraic equation from CONSERVE statement that should replace their ODE,...
Definition: sympy_solver_visitor.hpp:156
nmodl::ast::DiffEqExpression::set_expression
void set_expression(std::shared_ptr< BinaryExpression > &&expression)
Setter for member variable DiffEqExpression::expression (rvalue reference)
Definition: ast.cpp:6700
nmodl::visitor::SympySolverVisitor::state_vars
std::vector< std::string > state_vars
vector of state vars used in block (in same order as all_state_vars)
Definition: sympy_solver_visitor.hpp:152
nmodl::ast::NonLinearBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:3535
nmodl::visitor::SympySolverVisitor::construct_eigen_solver_block
void construct_eigen_solver_block(const std::vector< std::string > &pre_solve_statements, const std::vector< std::string > &solutions, bool linear)
construct solver block
Definition: sympy_solver_visitor.cpp:177
nmodl::stringutils::split_string
static std::vector< std::string > split_string(const std::string &text, char delimiter)
Split a text in a list of words, using a given delimiter character.
Definition: string_utils.hpp:116
nmodl::ast::ExpressionStatement::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:9034
nmodl::ast::NonLinEquation::get_rhs
std::shared_ptr< Expression > get_rhs() const noexcept
Getter for member variable NonLinEquation::rhs.
Definition: non_lin_equation.hpp:168
nmodl::ast::StatementBlock
Represents block encapsulating list of statements.
Definition: statement_block.hpp:53
nmodl::visitor::SympySolverVisitor::visit_non_lin_equation
void visit_non_lin_equation(ast::NonLinEquation &node) override
visit node of type ast::NonLinEquation
Definition: sympy_solver_visitor.cpp:609
nmodl::symtab::syminfo::NmodlType
NmodlType
NMODL variable properties.
Definition: symbol_properties.hpp:116
nmodl::visitor::create_statement_block
std::shared_ptr< StatementBlock > create_statement_block(const std::vector< std::string > &code_statements)
Convert given code statements (in string format) to a corresponding ast::StatementBlock node.
Definition: visitor_utils.cpp:156
nmodl::visitor::SympySolverVisitor::elimination
bool elimination
optionally do CSE (common subexpression elimination) for sparse solver
Definition: sympy_solver_visitor.hpp:162
nmodl::ast::LinearBlock
Represents LINEAR block in the NMODL.
Definition: linear_block.hpp:53
nmodl::visitor::SympySolverVisitor::last_expression_statement
ast::ExpressionStatement * last_expression_statement
last expression statement visited (to know where to insert solutions in statement block)
Definition: sympy_solver_visitor.hpp:125
nmodl::codegen::naming::DERIVIMPLICIT_METHOD
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
Definition: codegen_naming.hpp:24
nmodl::ast::NonLinearBlock
Represents NONLINEAR block in the NMODL.
Definition: non_linear_block.hpp:50
logger.hpp
Implement logger based on spdlog library.
nmodl::visitor::SympySolverVisitor::current_statement_block
ast::StatementBlock * current_statement_block
current statement block being visited
Definition: sympy_solver_visitor.hpp:128
nmodl::visitor::SympySolverVisitor::visit_derivative_block
void visit_derivative_block(ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
Definition: sympy_solver_visitor.cpp:500
nmodl::pybind_wrappers
Definition: pyembed.cpp:25
nmodl::ast::Ast::is_linear_block
virtual bool is_linear_block() const noexcept
Check if the ast node is an instance of ast::LinearBlock.
Definition: ast.cpp:134
nmodl::codegen::naming::EULER_METHOD
static constexpr char EULER_METHOD[]
euler method in nmodl
Definition: codegen_naming.hpp:27
nmodl::ast::LinEquation
One equation in a system of equations that collectively forms a LINEAR block.
Definition: lin_equation.hpp:38
nmodl::visitor::add_local_variable
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
Definition: visitor_utils.cpp:93
nmodl::ast::VarName::get_name
std::shared_ptr< Identifier > get_name() const noexcept
Getter for member variable VarName::name.
Definition: var_name.hpp:164
nmodl::ast::Program::get_symbol_table
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
Definition: program.hpp:153
nmodl::visitor::SympySolverVisitor::expression_statements
std::unordered_set< ast::Statement * > expression_statements
expression statements appearing in the block (these can be of type DiffEqExpression,...
Definition: sympy_solver_visitor.hpp:119
nmodl::ast::NonLinEquation::get_lhs
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable NonLinEquation::lhs.
Definition: non_lin_equation.hpp:159
nmodl::visitor::SympySolverVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: sympy_solver_visitor.cpp:651
nmodl::visitor::SympySolverVisitor::eq_system
std::vector< std::string > eq_system
vector of {ODE, linear eq, non-linear eq} system to solve
Definition: sympy_solver_visitor.hpp:137
nmodl::ast::Ast::get_statement_block
virtual std::shared_ptr< StatementBlock > get_statement_block() const
Return associated statement block for the AST node.
Definition: ast.cpp:32
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::visitor::SympySolverVisitor::eq_system_is_valid
bool eq_system_is_valid
only solve the eq_system system of equations if this is true.
Definition: sympy_solver_visitor.hpp:140
nmodl::visitor::suffix_random_string
std::string suffix_random_string(const std::set< std::string > &vars, const std::string &original_string, const UseNumbersInString use_num)
Return the "original_string" with a random suffix if "original_string" exists in "vars".
Definition: visitor_utils.cpp:33
nmodl::ast::LinEquation::get_rhs
std::shared_ptr< Expression > get_rhs() const noexcept
Getter for member variable LinEquation::rhs.
Definition: lin_equation.hpp:168
nmodl::visitor::SympySolverVisitor::current_expression_statement
ast::ExpressionStatement * current_expression_statement
current expression statement being visited (to track ODEs / (non)lineqs)
Definition: sympy_solver_visitor.hpp:122
nmodl::visitor::SympySolverVisitor::check_expr_statements_in_same_block
void check_expr_statements_in_same_block()
raise error if kinetic/ode/(non)linear statements are spread over multiple blocks
Definition: sympy_solver_visitor.cpp:92
nmodl::visitor::SympyReplaceSolutionsVisitor::ReplacePolicy::VALUE
@ VALUE
Replace statements matching by lhs varName.
sympy_solver_visitor.hpp
Visitor for systems of algebraic and differential equations
nmodl::visitor::get_global_vars
std::set< std::string > get_global_vars(const Program &node)
Return set of strings with the names of all global variables.
Definition: visitor_utils.cpp:171
all.hpp
Auto generated AST classes declaration.
nmodl::ast::VarName::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:1124
pyembed.hpp
nmodl::visitor::SympySolverVisitor::visit_diff_eq_expression
void visit_diff_eq_expression(ast::DiffEqExpression &node) override
visit node of type ast::DiffEqExpression
Definition: sympy_solver_visitor.cpp:405