User Guide
sympy_solver_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
10 
11 #include "ast/all.hpp"
13 #include "pybind/pyembed.hpp"
14 #include "symtab/symbol.hpp"
15 #include "utils/logger.hpp"
16 #include "utils/string_utils.hpp"
18 
19 
20 namespace pywrap = nmodl::pybind_wrappers;
21 
22 namespace nmodl {
23 namespace visitor {
24 
26 
28  // clear any previous data
29  expression_statements.clear();
30  eq_system.clear();
31  state_vars_in_block.clear();
32  last_expression_statement = nullptr;
34  eq_system_is_valid = true;
35  conserve_equation.clear();
36  // get set of local block vars & global vars
37  vars = global_vars;
38  if (auto symtab = node->get_statement_block()->get_symbol_table()) {
39  auto localvars = symtab->get_variables_with_properties(NmodlType::local_var);
40  for (const auto& localvar: localvars) {
41  std::string var_name = localvar->get_name();
42  if (localvar->is_array()) {
43  var_name += "[" + std::to_string(localvar->get_length()) + "]";
44  }
45  vars.insert(var_name);
46  }
47  }
48  const auto& fcall_nodes = collect_nodes(*node->get_statement_block(),
49  {ast::AstNodeType::FUNCTION_CALL});
50  for (const auto& call: fcall_nodes) {
51  function_calls.insert(call->get_node_name());
52  }
53 }
54 
56  state_vars.clear();
57  for (const auto& state_var: all_state_vars) {
58  if (state_vars_in_block.find(state_var) != state_vars_in_block.cend()) {
59  state_vars.push_back(state_var);
60  }
61  }
62 }
63 
65  const std::string& new_expr) {
66  auto new_statement = create_statement(new_expr);
67  auto new_expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(new_statement);
68  auto new_bin_expr = std::dynamic_pointer_cast<ast::BinaryExpression>(
69  new_expr_statement->get_expression());
70  expr.set_expression(std::move(new_bin_expr));
71 }
72 
74  /// all ode/kinetic/(non)linear statements (typically) appear in the same statement block
75  /// if this is not the case, for now return an error (and should instead use fallback solver)
76  if (block_with_expression_statements != nullptr &&
78  logger->warn(
79  "SympySolverVisitor :: Coupled equations are appearing in different blocks - not "
80  "supported");
81  eq_system_is_valid = false;
82  }
84 }
85 
86 ast::StatementVector::const_iterator SympySolverVisitor::get_solution_location_iterator(
87  const ast::StatementVector& statements) {
88  // find out where to insert solutions in statement block
89  // returns iterator pointing to the first element after the last (non)linear eq
90  // so if there are no such elements, it returns statements.end()
91  auto it = statements.begin();
92  if (last_expression_statement != nullptr) {
93  while ((it != statements.end()) &&
94  (std::dynamic_pointer_cast<ast::ExpressionStatement>(*it).get() !=
96  logger->debug("SympySolverVisitor :: {} != {}",
97  to_nmodl(*it),
99  ++it;
100  }
101  if (it != statements.end()) {
102  logger->debug("SympySolverVisitor :: {} == {}",
103  to_nmodl(std::dynamic_pointer_cast<ast::ExpressionStatement>(*it)),
105  ++it;
106  }
107  }
108  return it;
109 }
110 
111 /**
112  * Check if provided statement is local variable declaration statement
113  * @param statement AST node representing statement in the MOD file
114  * @return True if statement is local variable declaration else False
115  *
116  * Statement declaration could be wrapped into another statement type like
117  * expression statement and hence we try to look inside if it's really a
118  * variable declaration.
119  */
120 static bool is_local_statement(const std::shared_ptr<ast::Statement>& statement) {
121  if (statement->is_local_list_statement()) {
122  return true;
123  }
124  if (statement->is_expression_statement()) {
125  auto e_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(statement);
126  auto expression = e_statement->get_expression();
127  if (expression->is_local_list_statement()) {
128  return true;
129  }
130  }
131  return false;
132 }
133 
134 std::string& SympySolverVisitor::replaceAll(std::string& context,
135  const std::string& from,
136  const std::string& to) {
137  std::size_t lookHere = 0;
138  std::size_t foundHere{};
139  while ((foundHere = context.find(from, lookHere)) != std::string::npos) {
140  context.replace(foundHere, from.size(), to);
141  lookHere = foundHere + to.size();
142  }
143  return context;
144 }
145 
147  const std::vector<std::string>& original_vector,
148  const std::string& original_string,
149  const std::string& substitution_string) {
150  std::vector<std::string> filtered_vector;
151  for (auto element: original_vector) {
152  std::string filtered_element = replaceAll(element, original_string, substitution_string);
153  filtered_vector.push_back(filtered_element);
154  }
155  return filtered_vector;
156 }
157 
159  const std::vector<std::string>& pre_solve_statements,
160  const std::vector<std::string>& solutions,
161  bool linear) {
162  auto solutions_filtered = filter_string_vector(solutions, "X[", "nmodl_eigen_x[");
163  solutions_filtered = filter_string_vector(solutions_filtered, "J[", "nmodl_eigen_j[");
164  solutions_filtered = filter_string_vector(solutions_filtered, "Jm[", "nmodl_eigen_jm[");
165  solutions_filtered = filter_string_vector(solutions_filtered, "F[", "nmodl_eigen_f[");
166 
167  for (const auto& sol: solutions_filtered) {
168  logger->debug("SympySolverVisitor :: -> adding statement: {}", sol);
169  }
170 
171  std::vector<std::string> pre_solve_statements_and_setup_x_eqs(pre_solve_statements);
172  std::vector<std::string> update_statements;
173  for (int i = 0; i < state_vars.size(); i++) {
174  auto update_state = state_vars[i] + " = nmodl_eigen_x[" + std::to_string(i) + "]";
175  auto setup_x = "nmodl_eigen_x[" + std::to_string(i) + "] = " + state_vars[i];
176 
177  pre_solve_statements_and_setup_x_eqs.push_back(setup_x);
178  update_statements.push_back(update_state);
179  logger->debug("SympySolverVisitor :: setup_x_eigen: {}", setup_x);
180  logger->debug("SympySolverVisitor :: update_state: {}", update_state);
181  }
182 
183  visitor::SympyReplaceSolutionsVisitor solution_replacer(
184  pre_solve_statements_and_setup_x_eqs,
185  solutions_filtered,
188  state_vars.size() + 1,
189  "");
191 
192  // split in the various blocks for eigen
193  auto n_state_vars = std::make_shared<ast::Integer>(state_vars.size(), nullptr);
194 
195  const auto& statements = block_with_expression_statements->get_statements();
196 
197  ast::StatementVector variable_statements; // LOCAL //
198  ast::StatementVector initialize_statements; // pre_solve_statements //
199  ast::StatementVector setup_x_statements; // old_x = x, X[0] = x //
200  ast::StatementVector functor_statements; // J[0]_row * X = F[0], additional assignments during
201  // computation //
202  ast::StatementVector finalize_statements; // assignments at the end //
203  std::ptrdiff_t const sr_begin{solution_replacer.replaced_statements_begin()};
204  std::ptrdiff_t const sr_end{solution_replacer.replaced_statements_end()};
205 
206  // initialize and edge case where the system of equations is empty
207  for (size_t idx = 0; idx < statements.size(); ++idx) {
208  auto& s = statements[idx];
209  if (is_local_statement(s)) {
210  variable_statements.push_back(s);
211  } else if (sr_begin == statements.size() || idx < sr_begin) {
212  initialize_statements.push_back(s);
213  }
214  }
215 
216  if (sr_begin != statements.size()) {
217  initialize_statements.insert(initialize_statements.end(),
218  statements.begin() + sr_begin,
219  statements.begin() + sr_begin +
220  static_cast<std::ptrdiff_t>(pre_solve_statements.size()));
221  setup_x_statements = ast::StatementVector(
222  statements.begin() + sr_begin +
223  static_cast<std::ptrdiff_t>(pre_solve_statements.size()),
224  statements.begin() + sr_begin +
225  static_cast<std::ptrdiff_t>(pre_solve_statements.size() + state_vars.size()));
226  functor_statements = ast::StatementVector(
227  statements.begin() + sr_begin +
228  static_cast<std::ptrdiff_t>(pre_solve_statements.size() + state_vars.size()),
229  statements.begin() + sr_end);
230  finalize_statements = ast::StatementVector(statements.begin() + sr_end, statements.end());
231  }
232 
233  const size_t total_statements_size = variable_statements.size() + initialize_statements.size() +
234  setup_x_statements.size() + functor_statements.size() +
235  finalize_statements.size();
236  if (statements.size() != total_statements_size) {
237  logger->error(
238  "SympySolverVisitor :: statement number missmatch ({} =/= {}) during splitting before "
239  "creation of "
240  "eigen "
241  "solver block.",
242  statements.size(),
243  total_statements_size);
244  return;
245  }
246 
247  auto variable_block = std::make_shared<ast::StatementBlock>(std::move(variable_statements));
248  auto initialize_block = std::make_shared<ast::StatementBlock>(std::move(initialize_statements));
249  auto update_state_block = create_statement_block(update_statements);
250  auto finalize_block = std::make_shared<ast::StatementBlock>(std::move(finalize_statements));
251  if (linear) {
252  /// functor and initialize block converge in the same block
253  setup_x_statements.insert(setup_x_statements.end(),
254  functor_statements.begin(),
255  functor_statements.end());
256  auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
257  auto solver_block = std::make_shared<ast::EigenLinearSolverBlock>(n_state_vars,
258  variable_block,
259  initialize_block,
260  setup_x_block,
261  update_state_block,
262  finalize_block);
263  /// replace statement block with solver block as it contains all statements
264  ast::StatementVector solver_block_statements{
265  std::make_shared<ast::ExpressionStatement>(solver_block)};
266  block_with_expression_statements->set_statements(std::move(solver_block_statements));
267  } else {
268  /// create eigen newton solver block
269  auto setup_x_block = std::make_shared<ast::StatementBlock>(std::move(setup_x_statements));
270  auto functor_block = std::make_shared<ast::StatementBlock>(std::move(functor_statements));
271  auto solver_block = std::make_shared<ast::EigenNewtonSolverBlock>(n_state_vars,
272  variable_block,
273  initialize_block,
274  setup_x_block,
275  functor_block,
276  update_state_block,
277  finalize_block);
278  /// replace statement block with solver block as it contains all statements
279  ast::StatementVector solver_block_statements{
280  std::make_shared<ast::ExpressionStatement>(solver_block)};
281  block_with_expression_statements->set_statements(std::move(solver_block_statements));
282  }
283 }
284 
285 
286 void SympySolverVisitor::solve_linear_system(const std::vector<std::string>& pre_solve_statements) {
287  // construct ordered vector of state vars used in linear system
289  // call sympy linear solver
290  bool small_system = (eq_system.size() <= SMALL_LINEAR_SYSTEM_MAX_STATES);
292  solver->eq_system = eq_system;
293  solver->state_vars = state_vars;
294  solver->vars = vars;
295  solver->small_system = small_system;
296  solver->elimination = elimination;
297  // this is necessary after we destroy the solver
298  const auto tmp_unique_prefix = suffix_random_string(vars, "tmp");
299  solver->tmp_unique_prefix = tmp_unique_prefix;
300  solver->function_calls = function_calls;
301  (*solver)();
302  // returns a vector of solutions, i.e. new statements to add to block:
303  auto solutions = solver->solutions;
304  // and a vector of new local variables that need to be declared in the block:
305  auto new_local_vars = solver->new_local_vars;
306  // may also return a python exception message:
307  auto exception_message = solver->exception_message;
308  // destroy solver
310  if (!exception_message.empty()) {
311  logger->warn("SympySolverVisitor :: solve_lin_system python exception: " +
312  exception_message);
313  return;
314  }
315  // find out where to insert solutions in statement block
316  if (small_system) {
317  // for small number of state vars, linear solver
318  // directly returns solution by solving symbolically at compile time
319  logger->debug("SympySolverVisitor :: Solving *small* linear system of eqs");
320  // declare new local vars
321  if (!new_local_vars.empty()) {
322  for (const auto& new_local_var: new_local_vars) {
323  logger->debug("SympySolverVisitor :: -> declaring new local variable: {}",
324  new_local_var);
326  }
327  }
328  visitor::SympyReplaceSolutionsVisitor solution_replacer(
329  pre_solve_statements,
330  solutions,
333  1,
334  tmp_unique_prefix);
336  } else {
337  // otherwise it returns a linear matrix system to solve
338  logger->debug("SympySolverVisitor :: Constructing linear newton solve block");
339  construct_eigen_solver_block(pre_solve_statements, solutions, true);
340  }
341 }
342 
344  const std::vector<std::string>& pre_solve_statements) {
345  // construct ordered vector of state vars used in non-linear system
347  // call sympy non-linear solver
348 
350  solver->eq_system = eq_system;
351  solver->state_vars = state_vars;
352  solver->vars = vars;
353  solver->function_calls = function_calls;
354  (*solver)();
355  // returns a vector of solutions, i.e. new statements to add to block:
356  auto solutions = solver->solutions;
357  // may also return a python exception message:
358  auto exception_message = solver->exception_message;
360  if (!exception_message.empty()) {
361  logger->warn("SympySolverVisitor :: solve_non_lin_system python exception: " +
362  exception_message);
363  return;
364  }
365  logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
366  construct_eigen_solver_block(pre_solve_statements, solutions, false);
367 }
368 
370  if (collect_state_vars) {
371  std::string var_name = node.get_node_name();
372  if (node.get_name()->is_indexed_name()) {
373  auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(node.get_name());
374  var_name +=
375  "[" +
377  std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
378  "]";
379  }
380  // if var_name is a state var, add it to set
381  if (std::find(all_state_vars.cbegin(), all_state_vars.cend(), var_name) !=
382  all_state_vars.cend()) {
383  logger->debug("SympySolverVisitor :: adding state var: {}", var_name);
384  state_vars_in_block.insert(var_name);
385  }
386  }
387 }
388 
390  const auto& lhs = node.get_expression()->get_lhs();
391 
392  if (!lhs->is_var_name()) {
393  logger->warn("SympySolverVisitor :: LHS of differential equation is not a VariableName");
394  return;
395  }
396  auto lhs_name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
397  if ((lhs_name->is_indexed_name() &&
398  !std::dynamic_pointer_cast<ast::IndexedName>(lhs_name)->get_name()->is_prime_name()) ||
399  (!lhs_name->is_indexed_name() && !lhs_name->is_prime_name())) {
400  logger->warn("SympySolverVisitor :: LHS of differential equation is not a PrimeName");
401  return;
402  }
403 
405 
406  const auto node_as_nmodl = to_nmodl_for_sympy(node);
407  const auto deleter = [](nmodl::pybind_wrappers::DiffeqSolverExecutor* ptr) {
409  };
410  std::unique_ptr<nmodl::pybind_wrappers::DiffeqSolverExecutor, decltype(deleter)> diffeq_solver{
412 
413  diffeq_solver->node_as_nmodl = node_as_nmodl;
414  diffeq_solver->dt_var = codegen::naming::NTHREAD_DT_VARIABLE;
415  diffeq_solver->vars = vars;
416  diffeq_solver->use_pade_approx = use_pade_approx;
417  diffeq_solver->function_calls = function_calls;
418  diffeq_solver->method = solve_method;
419  (*diffeq_solver)();
421  // replace x' = f(x) differential equation
422  // with forwards Euler timestep:
423  // x = x + f(x) * dt
424  logger->debug("SympySolverVisitor :: EULER - solving: {}", node_as_nmodl);
426  // replace x' = f(x) differential equation
427  // with analytic solution for x(t+dt) in terms of x(t)
428  // x = ...
429  logger->debug("SympySolverVisitor :: CNEXP - solving: {}", node_as_nmodl);
430  } else {
431  // for other solver methods: just collect the ODEs & return
432  std::string eq_str = to_nmodl_for_sympy(node);
433  std::string var_name = lhs_name->get_node_name();
434  if (lhs_name->is_indexed_name()) {
435  auto index_name = std::dynamic_pointer_cast<ast::IndexedName>(lhs_name);
436  var_name +=
437  "[" +
439  std::dynamic_pointer_cast<ast::Integer>(index_name->get_length())->eval()) +
440  "]";
441  }
442  logger->debug("SympySolverVisitor :: adding ODE system: {}", eq_str);
443  eq_system.push_back(eq_str);
444  logger->debug("SympySolverVisitor :: adding state var: {}", var_name);
445  state_vars_in_block.insert(var_name);
448  return;
449  }
450 
451  // replace ODE with solution in AST
452  auto solution = diffeq_solver->solution;
453  logger->debug("SympySolverVisitor :: -> solution: {}", solution);
454 
455  auto exception_message = diffeq_solver->exception_message;
456  if (!exception_message.empty()) {
457  logger->warn("SympySolverVisitor :: python exception: " + exception_message);
458  return;
459  }
460 
461  if (!solution.empty()) {
462  replace_diffeq_expression(node, solution);
463  } else {
464  logger->warn("SympySolverVisitor :: solution to differential equation not possible");
465  }
466 }
467 
469  // Replace ODE for state variable on LHS of CONSERVE statement with
470  // algebraic expression on RHS (see p244 of NEURON book)
471  logger->debug("SympySolverVisitor :: CONSERVE statement: {}", to_nmodl(node));
472  expression_statements.insert(&node);
473  std::string conserve_equation_statevar;
474  if (node.get_react()->is_react_var_name()) {
475  conserve_equation_statevar = node.get_react()->get_node_name();
476  }
477  if (std::find(all_state_vars.cbegin(), all_state_vars.cend(), conserve_equation_statevar) ==
478  all_state_vars.cend()) {
479  logger->error(
480  "SympySolverVisitor :: Invalid CONSERVE statement for DERIVATIVE block, LHS should be "
481  "a state variable, instead found: {}. Ignoring CONSERVE statement",
482  to_nmodl(node.get_react()));
483  return;
484  }
485  auto conserve_equation_str = to_nmodl_for_sympy(*node.get_expr());
486  logger->debug("SympySolverVisitor :: --> replace ODE for state var {} with equation {}",
487  conserve_equation_statevar,
488  conserve_equation_str);
489  conserve_equation[conserve_equation_statevar] = conserve_equation_str;
490 }
491 
493  /// clear information from previous block, get global vars + block local vars
494  init_block_data(&node);
495 
496  // get user specified solve method for this block
498 
499  // visit each differential equation:
500  // - for CNEXP or EULER, each equation is independent & is replaced with its solution
501  // - otherwise, each equation is added to eq_system
502  node.visit_children(*this);
503 
504  if (eq_system_is_valid && !eq_system.empty()) {
505  // solve system of ODEs in eq_system
506  logger->debug("SympySolverVisitor :: Solving {} system of ODEs", solve_method);
507 
508  // construct implicit Euler equations from ODEs
509  std::vector<std::string> pre_solve_statements;
510  for (auto& eq: eq_system) {
511  auto split_eq = stringutils::split_string(eq, '=');
512  auto x_prime_split = stringutils::split_string(split_eq[0], '\'');
513  auto x = stringutils::trim(x_prime_split[0]);
514  std::string x_array_index;
515  std::string x_array_index_i;
516  if (x_prime_split.size() > 1 && stringutils::trim(x_prime_split[1]).size() > 2) {
517  x_array_index = stringutils::trim(x_prime_split[1]);
518  x_array_index_i = "_" + x_array_index.substr(1, x_array_index.size() - 2);
519  }
520  std::string state_var_name = x + x_array_index;
521  auto var_eq_pair = conserve_equation.find(state_var_name);
522  if (var_eq_pair != conserve_equation.cend()) {
523  // replace the ODE for this state var with corresponding CONSERVE equation
524  eq = state_var_name + " = " + var_eq_pair->second;
525  logger->debug(
526  "SympySolverVisitor :: -> instead of Euler eq using CONSERVE equation: {} = {}",
527  state_var_name,
528  var_eq_pair->second);
529  } else {
530  // no CONSERVE equation, construct Euler equation
531  auto dxdt = stringutils::trim(split_eq[1]);
532 
533  auto const old_x = [&]() {
534  std::string old_x_name{"old_"};
535  old_x_name.append(x);
536  old_x_name.append(x_array_index_i);
537  return suffix_random_string(vars, old_x_name);
538  }();
539  // declare old_x
540  logger->debug("SympySolverVisitor :: -> declaring new local variable: {}", old_x);
542  // assign old_x = x
543  {
544  std::string expression{old_x};
545  expression.append(" = ");
546  expression.append(x);
547  expression.append(x_array_index);
548  pre_solve_statements.push_back(std::move(expression));
549  }
550  // replace ODE with Euler equation
551  eq = x;
552  eq.append(x_array_index);
553  eq.append(" = ");
554  eq.append(old_x);
555  eq.append(" + ");
557  eq.append(" * (");
558  eq.append(dxdt);
559  eq.append(")");
560  logger->debug("SympySolverVisitor :: -> constructed Euler eq: {}", eq);
561  }
562  }
563 
566  solve_non_linear_system(pre_solve_statements);
567  } else {
568  logger->error("SympySolverVisitor :: Solve method {} not supported", solve_method);
569  }
570  }
571 }
572 
575  std::string lin_eq = to_nmodl_for_sympy(*node.get_left_linxpression());
576  lin_eq += " = ";
577  lin_eq += to_nmodl_for_sympy(*node.get_linxpression());
578  eq_system.push_back(lin_eq);
581  logger->debug("SympySolverVisitor :: adding linear eq: {}", lin_eq);
582  collect_state_vars = true;
583  node.visit_children(*this);
584  collect_state_vars = false;
585 }
586 
588  logger->debug("SympySolverVisitor :: found LINEAR block: {}", node.get_node_name());
589 
590  /// clear information from previous block, get global vars + block local vars
591  init_block_data(&node);
592 
593  // collect linear equations
594  node.visit_children(*this);
595 
596  if (eq_system_is_valid && !eq_system.empty()) {
598  }
599 }
600 
603  std::string non_lin_eq = to_nmodl_for_sympy(*node.get_lhs());
604  non_lin_eq += " = ";
605  non_lin_eq += to_nmodl_for_sympy(*node.get_rhs());
606  eq_system.push_back(non_lin_eq);
609  logger->debug("SympySolverVisitor :: adding non-linear eq: {}", non_lin_eq);
610  collect_state_vars = true;
611  node.visit_children(*this);
612  collect_state_vars = false;
613 }
614 
616  logger->debug("SympySolverVisitor :: found NONLINEAR block: {}", node.get_node_name());
617 
618  /// clear information from previous block, get global vars + block local vars
619  init_block_data(&node);
620 
621  // collect non-linear equations
622  node.visit_children(*this);
623 
624  if (eq_system_is_valid && !eq_system.empty()) {
626  }
627 }
628 
630  auto prev_expression_statement = current_expression_statement;
632  node.visit_children(*this);
633  current_expression_statement = prev_expression_statement;
634 }
635 
637  auto prev_statement_block = current_statement_block;
638  current_statement_block = &node;
639  node.visit_children(*this);
640  current_statement_block = prev_statement_block;
641 }
642 
645 
647 
648  // get list of solve statements with names & methods
649  const auto& solve_block_nodes = collect_nodes(node, {ast::AstNodeType::SOLVE_BLOCK});
650  for (const auto& block: solve_block_nodes) {
651  if (auto block_ptr = std::dynamic_pointer_cast<const ast::SolveBlock>(block)) {
652  const auto& block_name = block_ptr->get_block_name()->get_value()->eval();
653  if (block_ptr->get_method()) {
654  // Note: solve method name is an optional parameter
655  // LINEAR and NONLINEAR blocks do not have solve method specified
656  const auto& solve_method = block_ptr->get_method()->get_value()->eval();
657  logger->debug("SympySolverVisitor :: Found SOLVE statement: using {} for {}",
658  solve_method,
659  block_name);
661  }
662  }
663  }
664 
665  // get set of all state vars
666  all_state_vars.clear();
667  if (auto symtab = node.get_symbol_table()) {
668  auto statevars = symtab->get_variables_with_properties(NmodlType::state_var);
669  for (const auto& v: statevars) {
670  std::string var_name = v->get_name();
671  if (v->is_array()) {
672  for (int i = 0; i < v->get_length(); ++i) {
673  std::string var_name_i = var_name + "[" + std::to_string(i) + "]";
674  all_state_vars.push_back(var_name_i);
675  }
676  } else {
677  all_state_vars.push_back(var_name);
678  }
679  }
680  }
681 
682  node.visit_children(*this);
683 }
684 
685 } // namespace visitor
686 } // namespace nmodl
nmodl::visitor::SympySolverVisitor::to_nmodl_for_sympy
static std::string to_nmodl_for_sympy(ast::Ast &node)
return NMODL string version of node, excluding any units
Definition: sympy_solver_visitor.hpp:88
nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
Definition: pyembed.hpp:141
nmodl::ast::Node
Base class for all AST node.
Definition: node.hpp:40
nmodl::visitor::SympySolverVisitor::visit_lin_equation
void visit_lin_equation(ast::LinEquation &node) override
visit node of type ast::LinEquation
Definition: sympy_solver_visitor.cpp:573
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:227
nmodl::ast::LinEquation::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:6948
nmodl::visitor::SympySolverVisitor::visit_linear_block
void visit_linear_block(ast::LinearBlock &node) override
visit node of type ast::LinearBlock
Definition: sympy_solver_visitor.cpp:587
nmodl::ast::Conserve::get_react
std::shared_ptr< Expression > get_react() const noexcept
Getter for member variable Conserve::react.
Definition: conserve.hpp:159
symbol.hpp
Implement class to represent a symbol in Symbol Table.
nmodl::visitor::SympyReplaceSolutionsVisitor::replaced_statements_end
int replaced_statements_end() const
idx (in the new statementVector) of the last statement that was added.
Definition: sympy_replace_solutions_visitor.hpp:236
nmodl::ast::DiffEqExpression
Represents differential equation in DERIVATIVE block.
Definition: diff_eq_expression.hpp:38
nmodl::visitor::SympyReplaceSolutionsVisitor
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
Definition: sympy_replace_solutions_visitor.hpp:212
nmodl::visitor::SympySolverVisitor::replaceAll
static std::string & replaceAll(std::string &context, const std::string &from, const std::string &to)
Function used by SympySolverVisitor::filter_string_vector to replace every occurrence of `from` in a std::string with `to` (in place).
Definition: sympy_solver_visitor.cpp:134
nmodl::visitor::SympySolverVisitor::visit_conserve
void visit_conserve(ast::Conserve &node) override
visit node of type ast::Conserve
Definition: sympy_solver_visitor.cpp:468
nmodl::codegen::naming::CNEXP_METHOD
static constexpr char CNEXP_METHOD[]
cnexp method in nmodl
Definition: codegen_naming.hpp:30
nmodl::ast::NonLinEquation
Represents a non-linear equation with `lhs` and `rhs` expressions, as collected from NONLINEAR blocks.
Definition: non_lin_equation.hpp:38
nmodl::visitor::SympySolverVisitor::visit_non_linear_block
void visit_non_linear_block(ast::NonLinearBlock &node) override
visit node of type ast::NonLinearBlock
Definition: sympy_solver_visitor.cpp:615
nmodl::pybind_wrappers::pybind_wrap_api::create_nsls_executor
decltype(&create_nsls_executor_func) create_nsls_executor
Definition: pyembed.hpp:117
nmodl::ast::StatementBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:3158
nmodl::ast::Conserve
Represent CONSERVE statement in NMODL.
Definition: conserve.hpp:38
nmodl::visitor::SympyReplaceSolutionsVisitor::replaced_statements_begin
int replaced_statements_begin() const
idx (in the new statementVector) of the first statement that was added.
Definition: sympy_replace_solutions_visitor.hpp:231
nmodl::ast::Conserve::get_expr
std::shared_ptr< Expression > get_expr() const noexcept
Getter for member variable Conserve::expr.
Definition: conserve.hpp:168
nmodl::visitor::SympySolverVisitor::solve_method
std::string solve_method
method specified in solve block
Definition: sympy_solver_visitor.hpp:134
nmodl::ast::StatementVector
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:298
nmodl::ast::NonLinearBlock::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:3521
nmodl::visitor::SympyReplaceSolutionsVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Definition: sympy_replace_solutions_visitor.cpp:64
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::ast::StatementBlock::set_statements
void set_statements(StatementVector &&statements)
Setter for member variable StatementBlock::statements (rvalue reference)
Definition: ast.cpp:3218
nmodl::pybind_wrappers::EmbeddedPythonLoader::api
const pybind_wrap_api * api()
Get a pointer to the pybind_wrap_api struct.
Definition: pyembed.cpp:87
nmodl::visitor::SympySolverVisitor::filter_string_vector
static std::vector< std::string > filter_string_vector(const std::vector< std::string > &original_vector, const std::string &original_string, const std::string &substitution_string)
Check original_vector for elements that contain a variable named original_string and rename it to substitution_string.
Definition: sympy_solver_visitor.cpp:146
nmodl::visitor::is_local_statement
static bool is_local_statement(const std::shared_ptr< ast::Statement > &statement)
Check if provided statement is local variable declaration statement.
Definition: sympy_solver_visitor.cpp:120
nmodl::visitor::SympySolverVisitor::all_state_vars
std::vector< std::string > all_state_vars
vector of all state variables (in order specified in STATE block in mod file)
Definition: sympy_solver_visitor.hpp:146
nmodl::visitor::SympySolverVisitor::init_block_data
void init_block_data(ast::Node *node)
clear any data from previous block & get set of block local vars + global vars
Definition: sympy_solver_visitor.cpp:27
nmodl::codegen::naming::NTHREAD_DT_VARIABLE
static constexpr char NTHREAD_DT_VARIABLE[]
dt variable in neuron thread structure
Definition: codegen_naming.hpp:99
nmodl::ast::VarName
Represents a variable.
Definition: var_name.hpp:43
string_utils.hpp
Implement string manipulation functions.
nmodl::visitor::SympySolverVisitor::derivative_block_solve_method
std::unordered_map< std::string, std::string > derivative_block_solve_method
map between derivative block names and associated solver method
Definition: sympy_solver_visitor.hpp:115
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::visitor::SympySolverVisitor::solve_non_linear_system
void solve_non_linear_system(const std::vector< std::string > &pre_solve_statements={})
solve non-linear system (for "derivimplicit", "sparse" and "NONLINEAR")
Definition: sympy_solver_visitor.cpp:343
nmodl::codegen::naming::SPARSE_METHOD
static constexpr char SPARSE_METHOD[]
sparse method in nmodl
Definition: codegen_naming.hpp:36
sympy_replace_solutions_visitor.hpp
Replace statements in node with pre_solve_statements, tmp_statements, and solutions.
nmodl::ast::LinearBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:3372
nmodl::visitor::SympySolverVisitor::vars
std::set< std::string > vars
local variables in current block + globals
Definition: sympy_solver_visitor.hpp:109
nmodl::pybind_wrappers::pybind_wrap_api::destroy_des_executor
decltype(&destroy_des_executor_func) destroy_des_executor
Definition: pyembed.hpp:122
codegen_naming.hpp
nmodl::ast::StatementBlock::get_statements
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
Definition: statement_block.hpp:221
nmodl::ast::ExpressionStatement
TODO.
Definition: expression_statement.hpp:38
nmodl::visitor::SympySolverVisitor::global_vars
std::set< std::string > global_vars
global variables
Definition: sympy_solver_visitor.hpp:106
nmodl::ast::DerivativeBlock::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:3242
nmodl::visitor::SympySolverVisitor::SMALL_LINEAR_SYSTEM_MAX_STATES
int SMALL_LINEAR_SYSTEM_MAX_STATES
max number of state vars allowed for small system linear solver
Definition: sympy_solver_visitor.hpp:165
nmodl::stringutils::trim
static std::string trim(std::string text)
Definition: string_utils.hpp:63
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::DerivativeBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e. member variables of current node using provided visitor
Definition: ast.cpp:3249
nmodl::ast::LinEquation::get_left_linxpression
std::shared_ptr< Expression > get_left_linxpression() const noexcept
Getter for member variable LinEquation::left_linxpression.
Definition: lin_equation.hpp:159
nmodl::visitor::SympySolverVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Definition: sympy_solver_visitor.cpp:636
nmodl::visitor::SympySolverVisitor::function_calls
std::set< std::string > function_calls
custom function calls used in ODE block
Definition: sympy_solver_visitor.hpp:112
nmodl::visitor::SympySolverVisitor::block_with_expression_statements
ast::StatementBlock * block_with_expression_statements
block where expression statements appear (to check there is only one)
Definition: sympy_solver_visitor.hpp:131
nmodl::visitor::SympySolverVisitor::replace_diffeq_expression
static void replace_diffeq_expression(ast::DiffEqExpression &expr, const std::string &new_expr)
replace binary expression with new expression provided as string
Definition: sympy_solver_visitor.cpp:64
nmodl::ast::LinearBlock::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:3362
nmodl::symtab::syminfo::to_string
std::string to_string(const T &obj)
Definition: symbol_properties.hpp:279
nmodl::ast::Program::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e. member variables of current node using provided visitor
Definition: ast.cpp:12902
nmodl::visitor::SympySolverVisitor::visit_var_name
void visit_var_name(ast::VarName &node) override
visit node of type ast::VarName
Definition: sympy_solver_visitor.cpp:369
nmodl::visitor::SympySolverVisitor::visit_expression_statement
void visit_expression_statement(ast::ExpressionStatement &node) override
visit node of type ast::ExpressionStatement
Definition: sympy_solver_visitor.cpp:629
nmodl::ast::AstNodeType::SOLVE_BLOCK
@ SOLVE_BLOCK
type of ast::SolveBlock
nmodl::ast::DiffEqExpression::get_expression
std::shared_ptr< BinaryExpression > get_expression() const noexcept
Getter for member variable DiffEqExpression::expression.
Definition: diff_eq_expression.hpp:143
nmodl::visitor::SympySolverVisitor::state_vars_in_block
std::set< std::string > state_vars_in_block
set of state variables used in block
Definition: sympy_solver_visitor.hpp:149
nmodl::ast::NonLinEquation::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e. member variables of current node using provided visitor
Definition: ast.cpp:6830
nmodl::visitor::SympyReplaceSolutionsVisitor::ReplacePolicy::GREEDY
@ GREEDY
Replace statements greedily.
nmodl::visitor::SympySolverVisitor::get_solution_location_iterator
ast::StatementVector::const_iterator get_solution_location_iterator(const ast::StatementVector &statements)
return iterator pointing to where solution should be inserted in statement block
Definition: sympy_solver_visitor.cpp:86
nmodl::ast::DerivativeBlock
Represents DERIVATIVE block in the NMODL.
Definition: derivative_block.hpp:49
nmodl::visitor::SympySolverVisitor::use_pade_approx
bool use_pade_approx
optionally replace cnexp solution with (1,1) pade approx
Definition: sympy_solver_visitor.hpp:159
nmodl::visitor::create_statement
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
Definition: visitor_utils.cpp:126
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:205
nmodl::visitor::SympySolverVisitor::init_state_vars_vector
void init_state_vars_vector()
construct vector from set of state vars in correct order
Definition: sympy_solver_visitor.cpp:55
nmodl::visitor::SympySolverVisitor::collect_state_vars
bool collect_state_vars
true for (non)linear eqs, to identify all state vars used in equations
Definition: sympy_solver_visitor.hpp:143
nmodl::visitor::SympySolverVisitor::conserve_equation
std::unordered_map< std::string, std::string > conserve_equation
map from state vars to the algebraic equation from CONSERVE statement that should replace their ODE,...
Definition: sympy_solver_visitor.hpp:156
nmodl::ast::DiffEqExpression::set_expression
void set_expression(std::shared_ptr< BinaryExpression > &&expression)
Setter for member variable DiffEqExpression::expression (rvalue reference)
Definition: ast.cpp:6696
nmodl::visitor::SympySolverVisitor::state_vars
std::vector< std::string > state_vars
vector of state vars used in block (in same order as all_state_vars)
Definition: sympy_solver_visitor.hpp:152
nmodl::ast::NonLinearBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e. member variables of current node using provided visitor
Definition: ast.cpp:3531
nmodl::visitor::SympySolverVisitor::construct_eigen_solver_block
void construct_eigen_solver_block(const std::vector< std::string > &pre_solve_statements, const std::vector< std::string > &solutions, bool linear)
construct solver block
Definition: sympy_solver_visitor.cpp:158
nmodl::stringutils::split_string
static std::vector< std::string > split_string(const std::string &text, char delimiter)
Split a text in a list of words, using a given delimiter character.
Definition: string_utils.hpp:116
nmodl::pybind_wrappers::pybind_wrap_api::create_des_executor
decltype(&create_des_executor_func) create_des_executor
Definition: pyembed.hpp:118
nmodl::ast::ExpressionStatement::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e. member variables of current node using provided visitor
Definition: ast.cpp:9030
nmodl::ast::NonLinEquation::get_rhs
std::shared_ptr< Expression > get_rhs() const noexcept
Getter for member variable NonLinEquation::rhs.
Definition: non_lin_equation.hpp:168
nmodl::ast::StatementBlock
Represents block encapsulating list of statements.
Definition: statement_block.hpp:53
nmodl::visitor::SympySolverVisitor::visit_non_lin_equation
void visit_non_lin_equation(ast::NonLinEquation &node) override
visit node of type ast::NonLinEquation
Definition: sympy_solver_visitor.cpp:601
nmodl::symtab::syminfo::NmodlType
NmodlType
NMODL variable properties.
Definition: symbol_properties.hpp:116
nmodl::visitor::create_statement_block
std::shared_ptr< StatementBlock > create_statement_block(const std::vector< std::string > &code_statements)
Convert given code statements (in string format) to a corresponding ast::StatementBlock node.
Definition: visitor_utils.cpp:155
nmodl::pybind_wrappers::pybind_wrap_api::create_sls_executor
decltype(&create_sls_executor_func) create_sls_executor
Definition: pyembed.hpp:116
nmodl::visitor::SympySolverVisitor::elimination
bool elimination
optionally do CSE (common subexpression elimination) for sparse solver
Definition: sympy_solver_visitor.hpp:162
nmodl::ast::LinearBlock
Represents LINEAR block in the NMODL.
Definition: linear_block.hpp:53
nmodl::visitor::SympySolverVisitor::last_expression_statement
ast::ExpressionStatement * last_expression_statement
last expression statement visited (to know where to insert solutions in statement block)
Definition: sympy_solver_visitor.hpp:125
nmodl::codegen::naming::DERIVIMPLICIT_METHOD
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
Definition: codegen_naming.hpp:24
nmodl::ast::NonLinearBlock
Represents NONLINEAR block in the NMODL.
Definition: non_linear_block.hpp:50
logger.hpp
Implement logger based on spdlog library.
nmodl::visitor::SympySolverVisitor::current_statement_block
ast::StatementBlock * current_statement_block
current statement block being visited
Definition: sympy_solver_visitor.hpp:128
nmodl::visitor::SympySolverVisitor::visit_derivative_block
void visit_derivative_block(ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
Definition: sympy_solver_visitor.cpp:492
nmodl::pybind_wrappers
Definition: pyembed.cpp:20
nmodl::codegen::naming::EULER_METHOD
static constexpr char EULER_METHOD[]
euler method in nmodl
Definition: codegen_naming.hpp:27
nmodl::ast::LinEquation
TODO.
Definition: lin_equation.hpp:38
nmodl::visitor::add_local_variable
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
Definition: visitor_utils.cpp:92
nmodl::ast::VarName::get_name
std::shared_ptr< Identifier > get_name() const noexcept
Getter for member variable VarName::name.
Definition: var_name.hpp:164
nmodl::ast::Program::get_symbol_table
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
Definition: program.hpp:153
nmodl::visitor::SympySolverVisitor::expression_statements
std::unordered_set< ast::Statement * > expression_statements
expression statements appearing in the block (these can be of type DiffEqExpression,...
Definition: sympy_solver_visitor.hpp:119
nmodl::ast::NonLinEquation::get_lhs
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable NonLinEquation::lhs.
Definition: non_lin_equation.hpp:159
nmodl::visitor::SympySolverVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: sympy_solver_visitor.cpp:643
nmodl::visitor::SympySolverVisitor::eq_system
std::vector< std::string > eq_system
vector of {ODE, linear eq, non-linear eq} system to solve
Definition: sympy_solver_visitor.hpp:137
nmodl::ast::Ast::get_statement_block
virtual std::shared_ptr< StatementBlock > get_statement_block() const
Return associated statement block for the AST node.
Definition: ast.cpp:32
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::pybind_wrappers::pybind_wrap_api::destroy_nsls_executor
decltype(&destroy_nsls_executor_func) destroy_nsls_executor
Definition: pyembed.hpp:121
nmodl::visitor::SympySolverVisitor::eq_system_is_valid
bool eq_system_is_valid
only solve the system of equations stored in eq_system if this is true:
Definition: sympy_solver_visitor.hpp:140
nmodl::visitor::SympySolverVisitor::solve_linear_system
void solve_linear_system(const std::vector< std::string > &pre_solve_statements={})
solve linear system (for "LINEAR")
Definition: sympy_solver_visitor.cpp:286
nmodl::visitor::suffix_random_string
std::string suffix_random_string(const std::set< std::string > &vars, const std::string &original_string, const UseNumbersInString use_num)
Return the "original_string" with a random suffix if "original_string" exists in "vars".
Definition: visitor_utils.cpp:32
nmodl::visitor::SympySolverVisitor::current_expression_statement
ast::ExpressionStatement * current_expression_statement
current expression statement being visited (to track ODEs / (non)lineqs)
Definition: sympy_solver_visitor.hpp:122
nmodl::visitor::SympySolverVisitor::check_expr_statements_in_same_block
void check_expr_statements_in_same_block()
raise error if kinetic/ode/(non)linear statements are spread over multiple blocks
Definition: sympy_solver_visitor.cpp:73
nmodl::pybind_wrappers::pybind_wrap_api::destroy_sls_executor
decltype(&destroy_sls_executor_func) destroy_sls_executor
Definition: pyembed.hpp:120
nmodl::visitor::SympyReplaceSolutionsVisitor::ReplacePolicy::VALUE
@ VALUE
Replace statements matching by lhs varName.
nmodl::ast::LinEquation::get_linxpression
std::shared_ptr< Expression > get_linxpression() const noexcept
Getter for member variable LinEquation::linxpression.
Definition: lin_equation.hpp:168
sympy_solver_visitor.hpp
Visitor for systems of algebraic and differential equations
nmodl::visitor::get_global_vars
std::set< std::string > get_global_vars(const Program &node)
Return set of strings with the names of all global variables.
Definition: visitor_utils.cpp:170
all.hpp
Auto generated AST classes declaration.
nmodl::ast::VarName::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:1120
nmodl::pybind_wrappers::DiffeqSolverExecutor
Definition: pyembed.hpp:67
pyembed.hpp
nmodl::visitor::SympySolverVisitor::visit_diff_eq_expression
void visit_diff_eq_expression(ast::DiffEqExpression &node) override
visit node of type ast::DiffEqExpression
Definition: sympy_solver_visitor.cpp:389