User Guide
sympy_solver.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <catch2/catch_test_macros.hpp>
9 #include <catch2/matchers/catch_matchers_string.hpp>
10 
11 #include <pybind11/embed.h>
12 #include <pybind11/stl.h>
13 
14 #include "ast/program.hpp"
16 #include "parser/nmodl_driver.hpp"
28 
29 
30 using namespace nmodl;
31 using namespace codegen;
32 using namespace visitor;
33 using namespace test;
34 using namespace test_utils;
35 
36 using Catch::Matchers::ContainsSubstring; // ContainsSubstring in newer Catch2
37 
39 
40 using ast::AstNodeType;
42 
43 
44 //=============================================================================
45 // SympySolver visitor tests
46 //=============================================================================
47 
48 std::vector<std::string> run_sympy_solver_visitor(
49  const std::string& text,
50  bool pade = false,
51  bool cse = false,
52  AstNodeType ret_nodetype = AstNodeType::DIFF_EQ_EXPRESSION,
53  bool kinetic = false) {
54  std::vector<std::string> results;
55 
56  // construct AST from text
58  const auto& ast = driver.parse_string(text);
59 
60  // construct symbol table from AST
61  SymtabVisitor().visit_program(*ast);
62 
63  // unroll loops and fold constants
64  ConstantFolderVisitor().visit_program(*ast);
65  LoopUnrollVisitor().visit_program(*ast);
66  ConstantFolderVisitor().visit_program(*ast);
67  SymtabVisitor().visit_program(*ast);
68 
69  if (kinetic) {
70  KineticBlockVisitor().visit_program(*ast);
71  }
72 
73  // run SympySolver on AST
74  SympySolverVisitor(pade, cse).visit_program(*ast);
75 
76  // check that, after visitor rearrangement, parents are still up-to-date
77  CheckParentVisitor().check_ast(*ast);
78 
79  // run lookup visitor to extract results from AST
80  for (const auto& eq: collect_nodes(*ast, {ret_nodetype})) {
81  results.push_back(to_nmodl(eq));
82  }
83 
84  return results;
85 }
86 
/**
 * \brief Check that a printed variable list (e.g. a LOCAL statement) contains no duplicates.
 *
 * Commas are removed and the remaining text is split on whitespace, so "LOCAL a, b" yields the
 * tokens "LOCAL", "a" and "b". Runs of whitespace and trailing newlines do not produce spurious
 * empty or newline-suffixed tokens.
 *
 * \param result printed list; taken by value because it is modified locally
 * \return true if every token occurs at most once
 */
bool is_unique_vars(std::string result) {
    result.erase(std::remove(result.begin(), result.end(), ','), result.end());
    std::stringstream ss(result);
    std::string token;

    std::unordered_set<std::string> old_vars;

    // operator>> skips any amount of whitespace, unlike getline(ss, token, ' '), which returned
    // an empty token for every extra space and kept '\n' attached to the last name
    while (ss >> token) {
        if (!old_vars.insert(token).second) {
            return false;
        }
    }
    return true;
}
102 
103 
104 /**
105  * \brief Compare nmodl blocks that contain systems of equations (i.e. derivative, linear, etc.)
106  *
107  * This is basically and advanced string == string comparison where we detect the (various) systems
108  * of equations and check if they are equivalent. Implemented mostly in python since we need a call
109  * to sympy to simplify the equations.
110  *
111  * - compare_systems_of_eq The core of the code. \p result_dict and \p expected_dict are
112  * dictionaries that represent the systems of equations in this way:
113  *
114  * a = b*x + c -> result_dict['a'] = 'b*x + c'
115  *
116  * where the variable \p a become a key \p k of the dictionary.
117  *
118  * In there we go over all the equations in \p result_dict and \p expected_dict and check that
119  * result_dict[k] - expected_dict[k] simplifies to 0.
120  *
121  * - sanitize is to transform the equations in something treatable by sympy (i.e. pow(dt, 3) ->
122  * dt**3
123  * - reduce back-substitution of the temporary variables
124  *
125  * \p require_fail requires that the equations are different. Used only for unit-test this function
126  *
127  * \warning do not use this method when there are tmp variables not in the form: tmp_<number>
128  */
129 void compare_blocks(const std::string& result,
130  const std::string& expected,
131  const bool require_fail = false) {
132  using namespace pybind11::literals;
133 
134  auto locals =
135  pybind11::dict("result"_a = result, "expected"_a = expected, "is_equal"_a = false);
136  pybind11::exec(R"(
137  # Comments are in the doxygen for better highlighting
138  def compare_blocks(result, expected):
139 
140  def sanitize(s):
141  import re
142  d = {'\[(\d+)\]':'_\\1', 'pow\((\w+), ?(\d+)\)':'\\1**\\2', 'beta': 'beta_var', 'gamma': 'gamma_var'}
143  out = s
144  for key, val in d.items():
145  out = re.sub(key, val, out)
146  return out
147 
148  def compare_systems_of_eq(result_dict, expected_dict):
149  from sympy.parsing.sympy_parser import parse_expr
150  try:
151  for k, v in result_dict.items():
152  if parse_expr(f'simplify(({v})-({expected_dict[k]}))'):
153  return False
154  except KeyError:
155  return False
156 
157  result_dict.clear()
158  expected_dict.clear()
159  return True
160 
161  def reduce(s):
162  max_tmp = -1
163  d = {}
164 
165  sout = ""
166  # split of sout and a dict with the tmp variables
167  for line in s.split('\n'):
168  line_split = line.lstrip().split('=')
169 
170  if len(line_split) == 2 and line_split[0].startswith('tmp_'):
171  # back-substitution of tmp variables in tmp variables
172  tmp_var = line_split[0].strip()
173  if tmp_var in d:
174  continue
175 
176  max_tmp = max(max_tmp, int(tmp_var[4:]))
177  for k, v in d.items():
178  line_split[1] = line_split[1].replace(k, f'({v})')
179  d[tmp_var] = line_split[1]
180  elif 'LOCAL' in line:
181  sout += line.split('tmp_0')[0] + '\n'
182  else:
183  sout += line + '\n'
184 
185  # Back-substitution of the tmps
186  # so that we do not replace tmp_11 with (tmp_1)1
187  for j in range(max_tmp, -1, -1):
188  k = f'tmp_{j}'
189  sout = sout.replace(k, f'({d[k]})')
190 
191  return sout
192 
193  result = reduce(sanitize(result)).split('\n')
194  expected = reduce(sanitize(expected)).split('\n')
195 
196  if len(result) != len(expected):
197  return False
198 
199  result_dict = {}
200  expected_dict = {}
201  for token1, token2 in zip(result, expected):
202  if token1 == token2:
203  if not compare_systems_of_eq(result_dict, expected_dict):
204  return False
205  continue
206 
207  eq1 = token1.split('=')
208  eq2 = token2.split('=')
209  if len(eq1) == 2 and len(eq2) == 2:
210  result_dict[eq1[0]] = eq1[1]
211  expected_dict[eq2[0]] = eq2[1]
212  continue
213 
214  return False
215  return compare_systems_of_eq(result_dict, expected_dict)
216 
217  is_equal = compare_blocks(result, expected))",
218  pybind11::globals(),
219  locals);
220 
221  // Error log
222  if (require_fail == locals["is_equal"].cast<bool>()) {
223  if (require_fail) {
224  REQUIRE(result != expected);
225  } else {
226  REQUIRE(result == expected);
227  }
228  } else { // so that we signal to ctest that an assert was performed
229  REQUIRE(true);
230  }
231 }
232 
233 
235  // construct symbol table from AST
236  SymtabVisitor v_symtab;
237  v_symtab.visit_program(node);
238 
239  // run SympySolver on AST several times
240  SympySolverVisitor v_sympy1;
241  v_sympy1.visit_program(node);
242  v_sympy1.visit_program(node);
243 
244  // also use a second instance of SympySolver
245  SympySolverVisitor v_sympy2;
246  v_sympy2.visit_program(node);
247  v_sympy1.visit_program(node);
248  v_sympy2.visit_program(node);
249 }
250 
251 
252 std::string ast_to_string(ast::Program& node) {
253  std::stringstream stream;
254  NmodlPrintVisitor(stream).visit_program(node);
255  return stream.str();
256 }
257 
// Self-tests for compare_blocks: equivalent inputs must compare equal, and the two
// "Blocks are different" cases pass require_fail = true so that a mismatch is required.
258 SCENARIO("Check compare_blocks in sympy unit tests", "[visitor][sympy]") {
259  GIVEN("Empty strings") {
260  THEN("Strings are equal") {
261  compare_blocks("", "");
262  }
263  }
264  GIVEN("Equivalent equation") {
265  THEN("Strings are equal") {
266  compare_blocks("a = 3*b + c", "a = 2*b + b + c");
267  }
268  }
269  GIVEN("Equivalent systems of equations") {
270  std::string result = R"(
271  x = 3*b + c
272  y = 2*a + b)";
273  std::string expected = R"(
274  x = b+2*b + c
275  y = 2*a + 2*b-b)";
276  THEN("Systems of equations are equal") {
277  compare_blocks(result, expected);
278  }
279  }
// Exercises sanitize (A[0] -> A_0, pow(a, 2) -> a**2) and the back-substitution of the
// tmp_<n> definitions that appear only in the expected block.
280  GIVEN("Equivalent systems of equations with brackets") {
281  std::string result = R"(
282  DERIVATIVE {
283  A[0] = 3*b + c
284  y = pow(a, 3) + b
285  })";
286  std::string expected = R"(
287  DERIVATIVE {
288  tmp_0 = a + c
289  tmp_1 = tmp_0 - a
290  A[0] = b+2*b + tmp_1
291  y = pow(a, 2)*a + 2*b-b
292  })";
293  THEN("Blocks are equal") {
294  compare_blocks(result, expected);
295  }
296  }
// The right-hand sides below are pairwise equivalent, so this case relies on a whitespace
// difference between the two blocks. NOTE(review): that difference may not have survived in
// this listing; confirm against the original file.
297  GIVEN("Different systems of equations (additional space)") {
298  std::string result = R"(
299  DERIVATIVE {
300  x = 3*b + c
301  y = 2*a + b
302  })";
303  std::string expected = R"(
304  DERIVATIVE {
305  x = b+2*b + c
306  y = 2*a + 2*b-b
307  })";
308  THEN("Blocks are different") {
309  compare_blocks(result, expected, true);
310  }
311  }
// tmp_1 reduces to -c, so x = 3*b - c differs from the expected x = 3*b + c.
312  GIVEN("Different systems of equations") {
313  std::string result = R"(
314  DERIVATIVE {
315  tmp_0 = a - c
316  tmp_1 = tmp_0 - a
317  x = 3*b + tmp_1
318  y = 2*a + b
319  })";
320  std::string expected = R"(
321  DERIVATIVE {
322  x = b+2*b + c
323  y = 2*a + 2*b-b
324  })";
325  THEN("Blocks are different") {
326  compare_blocks(result, expected, true);
327  }
328  }
329 }
330 
// User-declared LOCAL variables (tmp, tmp_0) must not clash with the variables the solver adds.
// Each case runs the solver with pade = true and cse = true and checks that the first printed
// LOCAL statement has no duplicate names.
331 SCENARIO("Check local vars name-clash prevention", "[visitor][sympy]") {
332  GIVEN("LOCAL tmp") {
333  std::string nmodl_text = R"(
334  STATE {
335  x y
336  }
337  BREAKPOINT {
338  SOLVE states METHOD sparse
339  }
340  DERIVATIVE states {
341  LOCAL tmp, b
342  x' = tmp + b
343  y' = tmp + b
344  })";
345  THEN("There are no duplicate vars in LOCAL") {
346  auto result =
347  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::LOCAL_LIST_STATEMENT);
348  REQUIRE(!result.empty());
349  REQUIRE(is_unique_vars(result[0]));
350  }
351  }
// Same check with a user LOCAL named like a generated temporary (tmp_0).
352  GIVEN("LOCAL tmp_0") {
353  std::string nmodl_text = R"(
354  STATE {
355  x y
356  }
357  BREAKPOINT {
358  SOLVE states METHOD sparse
359  }
360  DERIVATIVE states {
361  LOCAL tmp_0, b
362  x' = tmp_0 + b
363  y' = tmp_0 + b
364  })";
365  THEN("There are no duplicate vars in LOCAL") {
366  auto result =
367  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::LOCAL_LIST_STATEMENT);
368  REQUIRE(!result.empty());
369  REQUIRE(is_unique_vars(result[0]));
370  }
371  }
372 }
373 
374 SCENARIO("Solve ODEs with cnexp or euler method using SympySolverVisitor",
375  "[visitor][sympy][cnexp][euler]") {
376  GIVEN("Derivative block without ODE, solver method cnexp") {
377  std::string nmodl_text = R"(
378  BREAKPOINT {
379  SOLVE states METHOD cnexp
380  }
381  DERIVATIVE states {
382  m = m + h
383  }
384  )";
385  THEN("No ODEs found - do nothing") {
386  auto result = run_sympy_solver_visitor(nmodl_text);
387  REQUIRE(result.empty());
388  }
389  }
390  GIVEN("Derivative block with ODES, solver method is euler") {
391  std::string nmodl_text = R"(
392  BREAKPOINT {
393  SOLVE states METHOD euler
394  }
395  DERIVATIVE states {
396  m' = (mInf-m)/mTau
397  h' = (hInf-h)/hTau
398  z = a*b + c
399  }
400  )";
401  THEN("Construct forwards Euler solutions") {
402  auto result = run_sympy_solver_visitor(nmodl_text);
403  REQUIRE(result.size() == 2);
404  REQUIRE(result[0] == "m = (-dt*(m-mInf)+m*mTau)/mTau");
405  REQUIRE(result[1] == "h = (-dt*(h-hInf)+h*hTau)/hTau");
406  }
407  }
408  GIVEN("Derivative block with calling external functions passes sympy") {
409  std::string nmodl_text = R"(
410  BREAKPOINT {
411  SOLVE states METHOD euler
412  }
413  DERIVATIVE states {
414  m' = sawtooth(m)
415  n' = sin(n)
416  p' = my_user_func(p)
417  }
418  )";
419  THEN("Construct forward Euler interpreting external functions as symbols") {
420  auto result = run_sympy_solver_visitor(nmodl_text);
421  REQUIRE(result.size() == 3);
422  REQUIRE(result[0] == "m = dt*sawtooth(m)+m");
423  REQUIRE(result[1] == "n = dt*sin(n)+n");
424  REQUIRE(result[2] == "p = dt*my_user_func(p)+p");
425  }
426  }
427  GIVEN("Derivative block with ODE, 1 state var in array, solver method euler") {
428  std::string nmodl_text = R"(
429  STATE {
430  m[1]
431  }
432  BREAKPOINT {
433  SOLVE states METHOD euler
434  }
435  DERIVATIVE states {
436  m'[0] = (mInf-m[0])/mTau
437  }
438  )";
439  THEN("Construct forwards Euler solutions") {
440  auto result = run_sympy_solver_visitor(nmodl_text);
441  REQUIRE(result.size() == 1);
442  REQUIRE(result[0] == "m[0] = (dt*(mInf-m[0])+mTau*m[0])/mTau");
443  }
444  }
445  GIVEN("Derivative block with ODE, 1 state var in array, solver method cnexp") {
446  std::string nmodl_text = R"(
447  STATE {
448  m[1]
449  }
450  BREAKPOINT {
451  SOLVE states METHOD cnexp
452  }
453  DERIVATIVE states {
454  m'[0] = (mInf-m[0])/mTau
455  }
456  )";
457  THEN("Construct forwards Euler solutions") {
458  auto result = run_sympy_solver_visitor(nmodl_text);
459  REQUIRE(result.size() == 1);
460  REQUIRE(result[0] == "m[0] = mInf-(mInf-m[0])*exp(-dt/mTau)");
461  }
462  }
463  GIVEN("Derivative block with linear ODES, solver method cnexp") {
464  std::string nmodl_text = R"(
465  BREAKPOINT {
466  SOLVE states METHOD cnexp
467  }
468  DERIVATIVE states {
469  m' = (mInf-m)/mTau
470  z = a*b + c
471  h' = hInf/hTau - h/hTau
472  }
473  )";
474  THEN("Integrate equations analytically") {
475  auto result = run_sympy_solver_visitor(nmodl_text);
476  REQUIRE(result.size() == 2);
477  REQUIRE(result[0] == "m = mInf-(-m+mInf)*exp(-dt/mTau)");
478  REQUIRE(result[1] == "h = hInf-(-h+hInf)*exp(-dt/hTau)");
479  }
480  }
481  GIVEN("Derivative block including non-linear but solvable ODES, solver method cnexp") {
482  std::string nmodl_text = R"(
483  BREAKPOINT {
484  SOLVE states METHOD cnexp
485  }
486  DERIVATIVE states {
487  m' = (mInf-m)/mTau
488  h' = c2 * h*h
489  }
490  )";
491  THEN("Integrate equations analytically") {
492  auto result = run_sympy_solver_visitor(nmodl_text);
493  REQUIRE(result.size() == 2);
494  REQUIRE(result[0] == "m = mInf-(-m+mInf)*exp(-dt/mTau)");
495  REQUIRE(result[1] == "h = -h/(c2*dt*h-1.0)");
496  }
497  }
498  GIVEN("Derivative block including array of 2 state vars, solver method cnexp") {
499  std::string nmodl_text = R"(
500  BREAKPOINT {
501  SOLVE states METHOD cnexp
502  }
503  STATE {
504  X[2]
505  }
506  DERIVATIVE states {
507  X'[0] = (mInf-X[0])/mTau
508  X'[1] = c2 * X[1]*X[1]
509  }
510  )";
511  THEN("Integrate equations analytically") {
512  auto result = run_sympy_solver_visitor(nmodl_text);
513  REQUIRE(result.size() == 2);
514  REQUIRE(result[0] == "X[0] = mInf-(mInf-X[0])*exp(-dt/mTau)");
515  REQUIRE(result[1] == "X[1] = -X[1]/(c2*dt*X[1]-1.0)");
516  }
517  }
518  GIVEN("Derivative block including loop over array vars, solver method cnexp") {
519  std::string nmodl_text = R"(
520  DEFINE N 3
521  BREAKPOINT {
522  SOLVE states METHOD cnexp
523  }
524  ASSIGNED {
525  mTau[N]
526  }
527  STATE {
528  X[N]
529  }
530  DERIVATIVE states {
531  FROM i=0 TO N-1 {
532  X'[i] = (mInf-X[i])/mTau[i]
533  }
534  }
535  )";
536  THEN("Integrate equations analytically") {
537  auto result = run_sympy_solver_visitor(nmodl_text);
538  REQUIRE(result.size() == 3);
539  REQUIRE(result[0] == "X[0] = mInf-(mInf-X[0])*exp(-dt/mTau[0])");
540  REQUIRE(result[1] == "X[1] = mInf-(mInf-X[1])*exp(-dt/mTau[1])");
541  REQUIRE(result[2] == "X[2] = mInf-(mInf-X[2])*exp(-dt/mTau[2])");
542  }
543  }
544  GIVEN("Derivative block including loop over array vars, solver method euler") {
545  std::string nmodl_text = R"(
546  DEFINE N 3
547  BREAKPOINT {
548  SOLVE states METHOD euler
549  }
550  ASSIGNED {
551  mTau[N]
552  }
553  STATE {
554  X[N]
555  }
556  DERIVATIVE states {
557  FROM i=0 TO N-1 {
558  X'[i] = (mInf-X[i])/mTau[i]
559  }
560  }
561  )";
562  THEN("Integrate equations analytically") {
563  auto result = run_sympy_solver_visitor(nmodl_text);
564  REQUIRE(result.size() == 3);
565  REQUIRE(result[0] == "X[0] = (dt*(mInf-X[0])+X[0]*mTau[0])/mTau[0]");
566  REQUIRE(result[1] == "X[1] = (dt*(mInf-X[1])+X[1]*mTau[1])/mTau[1]");
567  REQUIRE(result[2] == "X[2] = (dt*(mInf-X[2])+X[2]*mTau[2])/mTau[2]");
568  }
569  }
570  GIVEN("Derivative block including ODES that can't currently be solved, solver method cnexp") {
571  std::string nmodl_text = R"(
572  BREAKPOINT {
573  SOLVE states METHOD cnexp
574  }
575  DERIVATIVE states {
576  z' = a/z + b/z/z
577  h' = c2 * h*h
578  x' = a
579  y' = c3 * y*y*y
580  }
581  )";
582  THEN("Integrate equations analytically where possible, otherwise leave untouched") {
583  auto result = run_sympy_solver_visitor(nmodl_text);
584  REQUIRE(result.size() == 4);
585  /// sympy 1.9 able to solve ode but not older versions
586  REQUIRE((result[0] == "z' = a/z+b/z/z" ||
587  result[0] ==
588  "z = (0.5*pow(a, 2)*pow(z, 2)-a*b*z+pow(b, 2)*log(a*z+b))/pow(a, 3)"));
589  REQUIRE(result[1] == "h = -h/(c2*dt*h-1.0)");
590  REQUIRE(result[2] == "x = a*dt+x");
591  /// sympy 1.4 able to solve ode but not older versions
592  REQUIRE((result[3] == "y' = c3*y*y*y" ||
593  result[3] == "y = sqrt(-pow(y, 2)/(2.0*c3*dt*pow(y, 2)-1.0))"));
594  }
595  }
596  GIVEN("Derivative block with cnexp solver method, AST after SympySolver pass") {
597  std::string nmodl_text = R"(
598  BREAKPOINT {
599  SOLVE states METHOD cnexp
600  }
601  DERIVATIVE states {
602  m' = (mInf-m)/mTau
603  }
604  )";
605  // construct AST from text
607  auto ast = driver.parse_string(nmodl_text);
608 
609  // construct symbol table from AST
610  SymtabVisitor().visit_program(*ast);
611 
612  // run SympySolver on AST
613  SympySolverVisitor().visit_program(*ast);
614 
615  std::string AST_string = ast_to_string(*ast);
616 
617  THEN("More SympySolver passes do nothing to the AST and don't throw") {
618  REQUIRE_NOTHROW(run_sympy_visitor_passes(*ast));
619  REQUIRE(AST_string == ast_to_string(*ast));
620  }
621  }
622 }
623 
624 SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
625  "[visitor][sympy][derivimplicit]") {
626  GIVEN("Derivative block with derivimplicit solver method and conditional block") {
627  std::string nmodl_text = R"(
628  STATE {
629  m
630  }
631  BREAKPOINT {
632  SOLVE states METHOD derivimplicit
633  }
634  DERIVATIVE states {
635  IF (mInf == 1) {
636  mInf = mInf+1
637  }
638  m' = (mInf-m)/mTau
639  }
640  )";
641  std::string expected_result = R"(
642  DERIVATIVE states {
643  EIGEN_NEWTON_SOLVE[1]{
644  LOCAL old_m
645  }{
646  IF (mInf == 1) {
647  mInf = mInf+1
648  }
649  old_m = m
650  }{
651  nmodl_eigen_x[0] = m
652  }{
653  nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+mInf)+mTau*(-nmodl_eigen_x[0]+old_m))/(dt*mTau)
654  nmodl_eigen_j[0] = (-dt-mTau)/(dt*mTau)
655  }{
656  m = nmodl_eigen_x[0]
657  }{
658  }
659  })";
660  THEN("SympySolver correctly inserts ode to block") {
661  CAPTURE(nmodl_text);
662  auto result =
663  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
664  compare_blocks(result[0], reindent_text(expected_result));
665  }
666  }
667 
668  GIVEN("Derivative block, sparse, print in order") {
669  std::string nmodl_text = R"(
670  STATE {
671  x y
672  }
673  BREAKPOINT {
674  SOLVE states METHOD sparse
675  }
676  DERIVATIVE states {
677  LOCAL a, b
678  y' = a
679  x' = b
680  })";
681  std::string expected_result = R"(
682  DERIVATIVE states {
683  EIGEN_NEWTON_SOLVE[2]{
684  LOCAL a, b, old_y, old_x
685  }{
686  old_y = y
687  old_x = x
688  }{
689  nmodl_eigen_x[0] = x
690  nmodl_eigen_x[1] = y
691  }{
692  nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_y)/dt
693  nmodl_eigen_j[0] = 0
694  nmodl_eigen_j[2] = -1/dt
695  nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_x)/dt
696  nmodl_eigen_j[1] = -1/dt
697  nmodl_eigen_j[3] = 0
698  }{
699  x = nmodl_eigen_x[0]
700  y = nmodl_eigen_x[1]
701  }{
702  }
703  })";
704 
705  THEN("Construct & solve linear system for backwards Euler") {
706  auto result =
707  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
708 
709  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
710  }
711  }
712  GIVEN("Derivative block, sparse, print in order, vectors") {
713  std::string nmodl_text = R"(
714  STATE {
715  M[2]
716  }
717  BREAKPOINT {
718  SOLVE states METHOD sparse
719  }
720  DERIVATIVE states {
721  LOCAL a, b
722  M'[1] = a
723  M'[0] = b
724  })";
725  std::string expected_result = R"(
726  DERIVATIVE states {
727  EIGEN_NEWTON_SOLVE[2]{
728  LOCAL a, b, old_M_1, old_M_0
729  }{
730  old_M_1 = M[1]
731  old_M_0 = M[0]
732  }{
733  nmodl_eigen_x[0] = M[0]
734  nmodl_eigen_x[1] = M[1]
735  }{
736  nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_M_1)/dt
737  nmodl_eigen_j[0] = 0
738  nmodl_eigen_j[2] = -1/dt
739  nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_M_0)/dt
740  nmodl_eigen_j[1] = -1/dt
741  nmodl_eigen_j[3] = 0
742  }{
743  M[0] = nmodl_eigen_x[0]
744  M[1] = nmodl_eigen_x[1]
745  }{
746  }
747  })";
748 
749  THEN("Construct & solve linear system for backwards Euler") {
750  auto result =
751  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
752 
753  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
754  }
755  }
756  GIVEN("Derivative block, sparse, derivatives mixed with local variable reassignment") {
757  std::string nmodl_text = R"(
758  STATE {
759  x y
760  }
761  BREAKPOINT {
762  SOLVE states METHOD sparse
763  }
764  DERIVATIVE states {
765  LOCAL a, b
766  x' = a
767  b = b + 1
768  y' = b
769  })";
770  std::string expected_result = R"(
771  DERIVATIVE states {
772  EIGEN_NEWTON_SOLVE[2]{
773  LOCAL a, b, old_x, old_y
774  }{
775  old_x = x
776  old_y = y
777  }{
778  nmodl_eigen_x[0] = x
779  nmodl_eigen_x[1] = y
780  }{
781  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
782  nmodl_eigen_j[0] = -1/dt
783  nmodl_eigen_j[2] = 0
784  b = b+1
785  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
786  nmodl_eigen_j[1] = 0
787  nmodl_eigen_j[3] = -1/dt
788  }{
789  x = nmodl_eigen_x[0]
790  y = nmodl_eigen_x[1]
791  }{
792  }
793  })";
794 
795  THEN("Construct & solve linear system for backwards Euler") {
796  auto result =
797  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
798 
799  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
800  }
801  }
802  GIVEN(
803  "Throw exception during derivative variable reassignment interleaved in the differential "
804  "equation set") {
805  std::string nmodl_text = R"(
806  STATE {
807  x y
808  }
809  BREAKPOINT {
810  SOLVE states METHOD sparse
811  }
812  DERIVATIVE states {
813  LOCAL a, b
814  x' = a
815  x = x + 1
816  y' = b + x
817  })";
818 
819  THEN(
820  "Throw an error because state variable assignments are not allowed inside the system "
821  "of differential "
822  "equations") {
823  REQUIRE_THROWS_WITH(
824  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK),
825  Catch::Matchers::ContainsSubstring(
826  "State variable assignment(s) interleaved in system of "
827  "equations/differential equations") &&
828  Catch::Matchers::StartsWith("SympyReplaceSolutionsVisitor"));
829  }
830  }
831  GIVEN("Derivative block in control flow block") {
832  std::string nmodl_text = R"(
833  STATE {
834  x y
835  }
836  BREAKPOINT {
837  SOLVE states METHOD sparse
838  }
839  DERIVATIVE states {
840  LOCAL a, b
841  if (a == 1) {
842  x' = a
843  y' = b
844  }
845  })";
846  std::string expected_result = R"(
847  DERIVATIVE states {
848  LOCAL a, b
849  IF (a == 1) {
850  EIGEN_NEWTON_SOLVE[2]{
851  LOCAL old_x, old_y
852  }{
853  old_x = x
854  old_y = y
855  }{
856  nmodl_eigen_x[0] = x
857  nmodl_eigen_x[1] = y
858  }{
859  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
860  nmodl_eigen_j[0] = -1/dt
861  nmodl_eigen_j[2] = 0
862  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
863  nmodl_eigen_j[1] = 0
864  nmodl_eigen_j[3] = -1/dt
865  }{
866  x = nmodl_eigen_x[0]
867  y = nmodl_eigen_x[1]
868  }{
869  }
870  }
871  })";
872 
873  THEN("Construct & solve linear system for backwards Euler") {
874  auto result =
875  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
876 
877  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
878  }
879  }
880  GIVEN(
881  "Derivative block, sparse, coupled derivatives mixed with reassignment and control flow "
882  "block") {
883  std::string nmodl_text = R"(
884  STATE {
885  x y
886  }
887  BREAKPOINT {
888  SOLVE states METHOD sparse
889  }
890  DERIVATIVE states {
891  LOCAL a, b
892  x' = a * y+b
893  if (b == 1) {
894  a = a + 1
895  }
896  y' = x + a*y
897  })";
898  std::string expected_result = R"(
899  DERIVATIVE states {
900  EIGEN_NEWTON_SOLVE[2]{
901  LOCAL a, b, old_x, old_y
902  }{
903  old_x = x
904  old_y = y
905  }{
906  nmodl_eigen_x[0] = x
907  nmodl_eigen_x[1] = y
908  }{
909  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
910  nmodl_eigen_j[0] = -1/dt
911  nmodl_eigen_j[2] = a
912  IF (b == 1) {
913  a = a+1
914  }
915  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
916  nmodl_eigen_j[1] = 1.0
917  nmodl_eigen_j[3] = a-1/dt
918  }{
919  x = nmodl_eigen_x[0]
920  y = nmodl_eigen_x[1]
921  }{
922  }
923  })";
924  std::string expected_result_cse = R"(
925  DERIVATIVE states {
926  EIGEN_NEWTON_SOLVE[2]{
927  LOCAL a, b, old_x, old_y
928  }{
929  old_x = x
930  old_y = y
931  }{
932  nmodl_eigen_x[0] = x
933  nmodl_eigen_x[1] = y
934  }{
935  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
936  nmodl_eigen_j[0] = -1/dt
937  nmodl_eigen_j[2] = a
938  IF (b == 1) {
939  a = a+1
940  }
941  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
942  nmodl_eigen_j[1] = 1.0
943  nmodl_eigen_j[3] = a-1/dt
944  }{
945  x = nmodl_eigen_x[0]
946  y = nmodl_eigen_x[1]
947  }{
948  }
949  })";
950 
951  THEN("Construct & solve linear system for backwards Euler") {
952  auto result =
953  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
954  auto result_cse =
955  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::DERIVATIVE_BLOCK);
956 
957  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
958  compare_blocks(reindent_text(result_cse[0]), reindent_text(expected_result_cse));
959  }
960  }
961 
962  GIVEN("Derivative block of coupled & linear ODES, solver method sparse") {
963  std::string nmodl_text = R"(
964  STATE {
965  x y z
966  }
967  BREAKPOINT {
968  SOLVE states METHOD sparse
969  }
970  DERIVATIVE states {
971  LOCAL a, b, c, d, h
972  x' = a*z + b*h
973  y' = c + 2*x
974  z' = d*z - y
975  }
976  )";
977  std::string expected_result = R"(
978  DERIVATIVE states {
979  EIGEN_NEWTON_SOLVE[3]{
980  LOCAL a, b, c, d, h, old_x, old_y, old_z
981  }{
982  old_x = x
983  old_y = y
984  old_z = z
985  }{
986  nmodl_eigen_x[0] = x
987  nmodl_eigen_x[1] = y
988  nmodl_eigen_x[2] = z
989  }{
990  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
991  nmodl_eigen_j[0] = -1/dt
992  nmodl_eigen_j[3] = 0
993  nmodl_eigen_j[6] = a
994  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
995  nmodl_eigen_j[1] = 2.0
996  nmodl_eigen_j[4] = -1/dt
997  nmodl_eigen_j[7] = 0
998  nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
999  nmodl_eigen_j[2] = 0
1000  nmodl_eigen_j[5] = -1.0
1001  nmodl_eigen_j[8] = d-1/dt
1002  }{
1003  x = nmodl_eigen_x[0]
1004  y = nmodl_eigen_x[1]
1005  z = nmodl_eigen_x[2]
1006  }{
1007  }
1008  })";
1009  std::string expected_cse_result = R"(
1010  DERIVATIVE states {
1011  EIGEN_NEWTON_SOLVE[3]{
1012  LOCAL a, b, c, d, h, old_x, old_y, old_z
1013  }{
1014  old_x = x
1015  old_y = y
1016  old_z = z
1017  }{
1018  nmodl_eigen_x[0] = x
1019  nmodl_eigen_x[1] = y
1020  nmodl_eigen_x[2] = z
1021  }{
1022  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
1023  nmodl_eigen_j[0] = -1/dt
1024  nmodl_eigen_j[3] = 0
1025  nmodl_eigen_j[6] = a
1026  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
1027  nmodl_eigen_j[1] = 2.0
1028  nmodl_eigen_j[4] = -1/dt
1029  nmodl_eigen_j[7] = 0
1030  nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1031  nmodl_eigen_j[2] = 0
1032  nmodl_eigen_j[5] = -1.0
1033  nmodl_eigen_j[8] = d-1/dt
1034  }{
1035  x = nmodl_eigen_x[0]
1036  y = nmodl_eigen_x[1]
1037  z = nmodl_eigen_x[2]
1038  }{
1039  }
1040  })";
1041 
1042  THEN("Construct & solve linear system for backwards Euler") {
1043  auto result =
1044  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1045  auto result_cse =
1046  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::DERIVATIVE_BLOCK);
1047 
1048  compare_blocks(result[0], reindent_text(expected_result));
1049  compare_blocks(result_cse[0], reindent_text(expected_cse_result));
1050  }
1051  }
1052  GIVEN("Derivative block including ODES with sparse method (from nmodl paper)") {
1053  std::string nmodl_text = R"(
1054  STATE {
1055  mc m
1056  }
1057  BREAKPOINT {
1058  SOLVE scheme1 METHOD sparse
1059  }
1060  DERIVATIVE scheme1 {
1061  mc' = -a*mc + b*m
1062  m' = a*mc - b*m
1063  }
1064  )";
1065  std::string expected_result = R"(
1066  DERIVATIVE scheme1 {
1067  EIGEN_NEWTON_SOLVE[2]{
1068  LOCAL old_mc, old_m
1069  }{
1070  old_mc = mc
1071  old_m = m
1072  }{
1073  nmodl_eigen_x[0] = mc
1074  nmodl_eigen_x[1] = m
1075  }{
1076  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1077  nmodl_eigen_j[0] = -a-1/dt
1078  nmodl_eigen_j[2] = b
1079  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1080  nmodl_eigen_j[1] = a
1081  nmodl_eigen_j[3] = -b-1/dt
1082  }{
1083  mc = nmodl_eigen_x[0]
1084  m = nmodl_eigen_x[1]
1085  }{
1086  }
1087  })";
1088  THEN("Construct & solve linear system") {
1089  CAPTURE(nmodl_text);
1090  auto result =
1091  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1092  compare_blocks(result[0], reindent_text(expected_result));
1093  }
1094  }
1095  GIVEN("Derivative block with ODES with sparse method, CONSERVE statement of form m = ...") {
1096  std::string nmodl_text = R"(
1097  STATE {
1098  mc m
1099  }
1100  BREAKPOINT {
1101  SOLVE scheme1 METHOD sparse
1102  }
1103  DERIVATIVE scheme1 {
1104  mc' = -a*mc + b*m
1105  m' = a*mc - b*m
1106  CONSERVE m = 1 - mc
1107  }
1108  )";
1109  std::string expected_result = R"(
1110  DERIVATIVE scheme1 {
1111  EIGEN_NEWTON_SOLVE[2]{
1112  LOCAL old_mc
1113  }{
1114  old_mc = mc
1115  }{
1116  nmodl_eigen_x[0] = mc
1117  nmodl_eigen_x[1] = m
1118  }{
1119  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1120  nmodl_eigen_j[0] = -a-1/dt
1121  nmodl_eigen_j[2] = b
1122  nmodl_eigen_f[1] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]+1.0
1123  nmodl_eigen_j[1] = -1.0
1124  nmodl_eigen_j[3] = -1.0
1125  }{
1126  mc = nmodl_eigen_x[0]
1127  m = nmodl_eigen_x[1]
1128  }{
1129  }
1130  })";
1131  THEN("Construct & solve linear system, replace ODE for m with rhs of CONSERVE statement") {
1132  CAPTURE(nmodl_text);
1133  auto result =
1134  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1135  compare_blocks(result[0], reindent_text(expected_result));
1136  }
1137  }
1138  GIVEN(
1139  "Derivative block with ODES with sparse method, invalid CONSERVE statement of form m + mc "
1140  "= ...") {
1141  std::string nmodl_text = R"(
1142  STATE {
1143  mc m
1144  }
1145  BREAKPOINT {
1146  SOLVE scheme1 METHOD sparse
1147  }
1148  DERIVATIVE scheme1 {
1149  mc' = -a*mc + b*m
1150  m' = a*mc - b*m
1151  CONSERVE m + mc = 1
1152  }
1153  )";
1154  std::string expected_result = R"(
1155  DERIVATIVE scheme1 {
1156  EIGEN_NEWTON_SOLVE[2]{
1157  LOCAL old_mc, old_m
1158  }{
1159  old_mc = mc
1160  old_m = m
1161  }{
1162  nmodl_eigen_x[0] = mc
1163  nmodl_eigen_x[1] = m
1164  }{
1165  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1166  nmodl_eigen_j[0] = -a-1/dt
1167  nmodl_eigen_j[2] = b
1168  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1169  nmodl_eigen_j[1] = a
1170  nmodl_eigen_j[3] = -b-1/dt
1171  }{
1172  mc = nmodl_eigen_x[0]
1173  m = nmodl_eigen_x[1]
1174  }{
1175  }
1176  })";
1177  THEN("Construct & solve linear system, ignore invalid CONSERVE statement") {
1178  CAPTURE(nmodl_text);
1179  auto result =
1180  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1181  compare_blocks(result[0], reindent_text(expected_result));
1182  }
1183  }
1184  GIVEN("Derivative block with ODES with sparse method, two CONSERVE statements") {
1185  std::string nmodl_text = R"(
1186  STATE {
1187  c1 o1 o2 p0 p1
1188  }
1189  BREAKPOINT {
1190  SOLVE ihkin METHOD sparse
1191  }
1192  DERIVATIVE ihkin {
1193  LOCAL alpha, beta, k3p, k4, k1ca, k2
1194  evaluate_fct(v, cai)
1195  CONSERVE p1 = 1-p0
1196  CONSERVE o2 = 1-c1-o1
1197  c1' = (-1*(alpha*c1-beta*o1))
1198  o1' = (1*(alpha*c1-beta*o1))+(-1*(k3p*o1-k4*o2))
1199  o2' = (1*(k3p*o1-k4*o2))
1200  p0' = (-1*(k1ca*p0-k2*p1))
1201  p1' = (1*(k1ca*p0-k2*p1))
1202  })";
1203  std::string expected_result = R"(
1204  DERIVATIVE ihkin {
1205  EIGEN_NEWTON_SOLVE[5]{
1206  LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0
1207  }{
1208  evaluate_fct(v, cai)
1209  old_c1 = c1
1210  old_o1 = o1
1211  old_p0 = p0
1212  }{
1213  nmodl_eigen_x[0] = c1
1214  nmodl_eigen_x[1] = o1
1215  nmodl_eigen_x[2] = o2
1216  nmodl_eigen_x[3] = p0
1217  nmodl_eigen_x[4] = p1
1218  }{
1219  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*alpha+nmodl_eigen_x[1]*beta)+old_c1)/dt
1220  nmodl_eigen_j[0] = -alpha-1/dt
1221  nmodl_eigen_j[5] = beta
1222  nmodl_eigen_j[10] = 0
1223  nmodl_eigen_j[15] = 0
1224  nmodl_eigen_j[20] = 0
1225  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*alpha-nmodl_eigen_x[1]*beta-nmodl_eigen_x[1]*k3p+nmodl_eigen_x[2]*k4)+old_o1)/dt
1226  nmodl_eigen_j[1] = alpha
1227  nmodl_eigen_j[6] = -beta-k3p-1/dt
1228  nmodl_eigen_j[11] = k4
1229  nmodl_eigen_j[16] = 0
1230  nmodl_eigen_j[21] = 0
1231  nmodl_eigen_f[2] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]+1.0
1232  nmodl_eigen_j[2] = -1.0
1233  nmodl_eigen_j[7] = -1.0
1234  nmodl_eigen_j[12] = -1.0
1235  nmodl_eigen_j[17] = 0
1236  nmodl_eigen_j[22] = 0
1237  nmodl_eigen_f[3] = (-nmodl_eigen_x[3]+dt*(-nmodl_eigen_x[3]*k1ca+nmodl_eigen_x[4]*k2)+old_p0)/dt
1238  nmodl_eigen_j[3] = 0
1239  nmodl_eigen_j[8] = 0
1240  nmodl_eigen_j[13] = 0
1241  nmodl_eigen_j[18] = -k1ca-1/dt
1242  nmodl_eigen_j[23] = k2
1243  nmodl_eigen_f[4] = -nmodl_eigen_x[3]-nmodl_eigen_x[4]+1.0
1244  nmodl_eigen_j[4] = 0
1245  nmodl_eigen_j[9] = 0
1246  nmodl_eigen_j[14] = 0
1247  nmodl_eigen_j[19] = -1.0
1248  nmodl_eigen_j[24] = -1.0
1249  }{
1250  c1 = nmodl_eigen_x[0]
1251  o1 = nmodl_eigen_x[1]
1252  o2 = nmodl_eigen_x[2]
1253  p0 = nmodl_eigen_x[3]
1254  p1 = nmodl_eigen_x[4]
1255  }{
1256  }
1257  })";
1258  THEN(
1259  "Construct & solve linear system, replacing ODEs for p1 and o2 with CONSERVE statement "
1260  "algebraic relations") {
1261  CAPTURE(nmodl_text);
1262  auto result =
1263  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1264  compare_blocks(result[0], reindent_text(expected_result));
1265  }
1266  }
1267  GIVEN("Derivative block including ODES with sparse method - single var in array") {
1268  std::string nmodl_text = R"(
1269  STATE {
1270  W[1]
1271  }
1272  ASSIGNED {
1273  A[2]
1274  B[1]
1275  }
1276  BREAKPOINT {
1277  SOLVE scheme1 METHOD sparse
1278  }
1279  DERIVATIVE scheme1 {
1280  W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1281  }
1282  )";
1283  std::string expected_result = R"(
1284  DERIVATIVE scheme1 {
1285  EIGEN_NEWTON_SOLVE[1]{
1286  LOCAL old_W_0
1287  }{
1288  old_W_0 = W[0]
1289  }{
1290  nmodl_eigen_x[0] = W[0]
1291  }{
1292  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1293  nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1294  }{
1295  W[0] = nmodl_eigen_x[0]
1296  }{
1297  }
1298  })";
1299  THEN("Construct & solver linear system") {
1300  CAPTURE(nmodl_text);
1301  auto result =
1302  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1303  compare_blocks(result[0], reindent_text(expected_result));
1304  }
1305  }
1306  GIVEN("Derivative block including ODES with sparse method - array vars") {
1307  std::string nmodl_text = R"(
1308  STATE {
1309  M[2]
1310  }
1311  ASSIGNED {
1312  A[2]
1313  B[2]
1314  }
1315  BREAKPOINT {
1316  SOLVE scheme1 METHOD sparse
1317  }
1318  DERIVATIVE scheme1 {
1319  M'[0] = -A[0]*M[0] + B[0]*M[1]
1320  M'[1] = A[1]*M[0] - B[1]*M[1]
1321  }
1322  )";
1323  std::string expected_result = R"(
1324  DERIVATIVE scheme1 {
1325  EIGEN_NEWTON_SOLVE[2]{
1326  LOCAL old_M_0, old_M_1
1327  }{
1328  old_M_0 = M[0]
1329  old_M_1 = M[1]
1330  }{
1331  nmodl_eigen_x[0] = M[0]
1332  nmodl_eigen_x[1] = M[1]
1333  }{
1334  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[1]*B[0])+old_M_0)/dt
1335  nmodl_eigen_j[0] = -A[0]-1/dt
1336  nmodl_eigen_j[2] = B[0]
1337  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*A[1]-nmodl_eigen_x[1]*B[1])+old_M_1)/dt
1338  nmodl_eigen_j[1] = A[1]
1339  nmodl_eigen_j[3] = -B[1]-1/dt
1340  }{
1341  M[0] = nmodl_eigen_x[0]
1342  M[1] = nmodl_eigen_x[1]
1343  }{
1344  }
1345  })";
1346  THEN("Construct & solver linear system") {
1347  CAPTURE(nmodl_text);
1348  auto result =
1349  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1350  compare_blocks(result[0], reindent_text(expected_result));
1351  }
1352  }
1353  GIVEN("Derivative block including ODES with derivimplicit method - single var in array") {
1354  std::string nmodl_text = R"(
1355  STATE {
1356  W[1]
1357  }
1358  ASSIGNED {
1359  A[2]
1360  B[1]
1361  }
1362  BREAKPOINT {
1363  SOLVE scheme1 METHOD derivimplicit
1364  }
1365  DERIVATIVE scheme1 {
1366  W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1367  }
1368  )";
1369  std::string expected_result = R"(
1370  DERIVATIVE scheme1 {
1371  EIGEN_NEWTON_SOLVE[1]{
1372  LOCAL old_W_0
1373  }{
1374  old_W_0 = W[0]
1375  }{
1376  nmodl_eigen_x[0] = W[0]
1377  }{
1378  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1379  nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1380  }{
1381  W[0] = nmodl_eigen_x[0]
1382  }{
1383  }
1384  })";
1385  THEN("Construct newton solve block") {
1386  CAPTURE(nmodl_text);
1387  auto result =
1388  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1389  compare_blocks(result[0], reindent_text(expected_result));
1390  }
1391  }
1392  GIVEN("Derivative block including ODES with derivimplicit method") {
1393  std::string nmodl_text = R"(
1394  STATE {
1395  m h n
1396  }
1397  BREAKPOINT {
1398  SOLVE states METHOD derivimplicit
1399  }
1400  DERIVATIVE states {
1401  rates(v)
1402  m' = (minf-m)/mtau - 3*h
1403  h' = (hinf-h)/htau + m*m
1404  n' = (ninf-n)/ntau
1405  }
1406  )";
1407  /// new derivative block with EigenNewtonSolverBlock node
1408  std::string expected_result = R"(
1409  DERIVATIVE states {
1410  EIGEN_NEWTON_SOLVE[3]{
1411  LOCAL old_m, old_h, old_n
1412  }{
1413  rates(v)
1414  old_m = m
1415  old_h = h
1416  old_n = n
1417  }{
1418  nmodl_eigen_x[0] = m
1419  nmodl_eigen_x[1] = h
1420  nmodl_eigen_x[2] = n
1421  }{
1422  nmodl_eigen_f[0] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt-3.0*nmodl_eigen_x[1]+minf/mtau+old_m/dt
1423  nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1424  nmodl_eigen_j[3] = -3.0
1425  nmodl_eigen_j[6] = 0
1426  nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1427  nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1428  nmodl_eigen_j[4] = (-dt-htau)/(dt*htau)
1429  nmodl_eigen_j[7] = 0
1430  nmodl_eigen_f[2] = (dt*(-nmodl_eigen_x[2]+ninf)+ntau*(-nmodl_eigen_x[2]+old_n))/(dt*ntau)
1431  nmodl_eigen_j[2] = 0
1432  nmodl_eigen_j[5] = 0
1433  nmodl_eigen_j[8] = (-dt-ntau)/(dt*ntau)
1434  }{
1435  m = nmodl_eigen_x[0]
1436  h = nmodl_eigen_x[1]
1437  n = nmodl_eigen_x[2]
1438  }{
1439  }
1440  })";
1441  THEN("Construct newton solve block") {
1442  CAPTURE(nmodl_text);
1443  auto result =
1444  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1445  compare_blocks(result[0], reindent_text(expected_result));
1446  }
1447  }
1448  GIVEN("Multiple derivative blocks each with derivimplicit method") {
1449  std::string nmodl_text = R"(
1450  STATE {
1451  m h
1452  }
1453  BREAKPOINT {
1454  SOLVE states1 METHOD derivimplicit
1455  SOLVE states2 METHOD derivimplicit
1456  }
1457 
1458  DERIVATIVE states1 {
1459  m' = (minf-m)/mtau
1460  h' = (hinf-h)/htau + m*m
1461  }
1462 
1463  DERIVATIVE states2 {
1464  h' = (hinf-h)/htau + m*m
1465  m' = (minf-m)/mtau + h
1466  }
1467  )";
1468  /// EigenNewtonSolverBlock in each derivative block
1469  std::string expected_result_0 = R"(
1470  DERIVATIVE states1 {
1471  EIGEN_NEWTON_SOLVE[2]{
1472  LOCAL old_m, old_h
1473  }{
1474  old_m = m
1475  old_h = h
1476  }{
1477  nmodl_eigen_x[0] = m
1478  nmodl_eigen_x[1] = h
1479  }{
1480  nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+minf)+mtau*(-nmodl_eigen_x[0]+old_m))/(dt*mtau)
1481  nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1482  nmodl_eigen_j[2] = 0
1483  nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau- nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1484  nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1485  nmodl_eigen_j[3] = (-dt-htau)/(dt*htau)
1486  }{
1487  m = nmodl_eigen_x[0]
1488  h = nmodl_eigen_x[1]
1489  }{
1490  }
1491  })";
1492  std::string expected_result_1 = R"(
1493  DERIVATIVE states2 {
1494  EIGEN_NEWTON_SOLVE[2]{
1495  LOCAL old_h, old_m
1496  }{
1497  old_h = h
1498  old_m = m
1499  }{
1500  nmodl_eigen_x[0] = m
1501  nmodl_eigen_x[1] = h
1502  }{
1503  nmodl_eigen_f[0] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1504  nmodl_eigen_j[0] = 2.0*nmodl_eigen_x[0]
1505  nmodl_eigen_j[2] = (-dt-htau)/(dt*htau)
1506  nmodl_eigen_f[1] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt+nmodl_eigen_x[1]+minf/mtau+old_m/dt
1507  nmodl_eigen_j[1] = (-dt-mtau)/(dt*mtau)
1508  nmodl_eigen_j[3] = 1.0
1509  }{
1510  m = nmodl_eigen_x[0]
1511  h = nmodl_eigen_x[1]
1512  }{
1513  }
1514  })";
1515  THEN("Construct newton solve block") {
1516  auto result =
1517  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1518  CAPTURE(nmodl_text);
1519  compare_blocks(result[0], reindent_text(expected_result_0));
1520  compare_blocks(result[1], reindent_text(expected_result_1));
1521  }
1522  }
1523 }
1524 
1525 
1526 //=============================================================================
1527 // LINEAR solve block tests
1528 //=============================================================================
1529 
1530 SCENARIO("LINEAR solve block (SympySolver Visitor)", "[sympy][linear]") {
1531  GIVEN("1 state-var symbolic LINEAR solve block") {
1532  std::string nmodl_text = R"(
1533  STATE {
1534  x
1535  }
1536  LINEAR lin {
1537  ~ 2*a*x = 1
1538  })";
1539  std::string expected_text = R"(
1540  LINEAR lin {
1541  x = 0.5/a
1542  })";
1543  THEN("solve analytically") {
1544  auto result =
1545  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1546  REQUIRE(reindent_text(result[0]) == reindent_text(expected_text));
1547  }
1548  }
1549  GIVEN("2 state-var LINEAR solve block") {
1550  std::string nmodl_text = R"(
1551  STATE {
1552  x y
1553  }
1554  LINEAR lin {
1555  ~ x + 4*y = 5*a
1556  ~ x - y = 0
1557  })";
1558  std::string expected_text = R"(
1559  LINEAR lin {
1560  x = a
1561  y = a
1562  })";
1563  THEN("solve analytically") {
1564  auto result =
1565  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1566  REQUIRE(reindent_text(result[0]) == reindent_text(expected_text));
1567  }
1568  }
1569  GIVEN("Linear block, print in order, vectors") {
1570  std::string nmodl_text = R"(
1571  STATE {
1572  M[2]
1573  }
1574  LINEAR lin {
1575  ~ M[1] = M[0] + 1
1576  ~ M[0] = 2
1577  })";
1578  std::string expected_result = R"(
1579  LINEAR lin {
1580  M[1] = 3.0
1581  M[0] = 2.0
1582  })";
1583 
1584  THEN("Construct & solve linear system") {
1585  auto result =
1586  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1587 
1588  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1589  }
1590  }
1591  GIVEN("Linear block, by value replacement, interleaved") {
1592  std::string nmodl_text = R"(
1593  STATE {
1594  x y
1595  }
1596  LINEAR lin {
1597  LOCAL a
1598  a = 0
1599  ~ x = y + a
1600  a = 1
1601  ~ y = a
1602  a = 2
1603  })";
1604  std::string expected_result = R"(
1605  LINEAR lin {
1606  LOCAL a
1607  a = 0
1608  x = 2.0*a
1609  a = 1
1610  y = a
1611  a = 2
1612  })";
1613 
1614  THEN("Construct & solve linear system") {
1615  auto result =
1616  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1617 
1618  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1619  }
1620  }
1621  GIVEN("Linear block in control flow block") {
1622  std::string nmodl_text = R"(
1623  STATE {
1624  x y
1625  }
1626  LINEAR lin {
1627  LOCAL a
1628  if (a == 1) {
1629  ~ x = y + a
1630  ~ y = a
1631  }
1632  })";
1633  std::string expected_result = R"(
1634  LINEAR lin {
1635  LOCAL a
1636  IF (a == 1) {
1637  x = 2.0*a
1638  y = a
1639  }
1640  })";
1641 
1642  THEN("Construct & solve linear system") {
1643  auto result =
1644  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1645 
1646  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1647  }
1648  }
1649  GIVEN("Linear block, linear equations mixed with control flow blocks and reassignments") {
1650  std::string nmodl_text = R"(
1651  STATE {
1652  x y
1653  }
1654  LINEAR lin {
1655  LOCAL a
1656  ~ x = y + a
1657  if (a == 1) {
1658  a = a + 1
1659  x = a + 1
1660  }
1661  ~ y = a
1662  })";
1663  std::string expected_result = R"(
1664  LINEAR lin {
1665  LOCAL a
1666  x = 2.0*a
1667  IF (a == 1) {
1668  a = a+1
1669  x = a+1
1670  }
1671  y = a
1672  })";
1673 
1674  THEN("Construct & solve linear system") {
1675  auto result =
1676  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1677 
1678  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1679  }
1680  }
1681  GIVEN("4 state-var LINEAR solve block") {
1682  std::string nmodl_text = R"(
1683  STATE {
1684  w x y z
1685  }
1686  LINEAR lin {
1687  ~ w + z/3.2 = -2.0*y
1688  ~ x + 4*c*y = -5.343*a
1689  ~ a + x/b + z - y = 0.842*b*b
1690  ~ x + 1.3*y - 0.1*z/(a*a*b) = 1.43543/c
1691  })";
1692  std::string expected_text = R"(
1693  LINEAR lin {
1694  EIGEN_LINEAR_SOLVE[4]{
1695  }{
1696  }{
1697  nmodl_eigen_x[0] = w
1698  nmodl_eigen_x[1] = x
1699  nmodl_eigen_x[2] = y
1700  nmodl_eigen_x[3] = z
1701  nmodl_eigen_f[0] = 0
1702  nmodl_eigen_f[1] = 5.343*a
1703  nmodl_eigen_f[2] = a-0.84199999999999997*pow(b, 2)
1704  nmodl_eigen_f[3] = -1.43543/c
1705  nmodl_eigen_j[0] = -1.0
1706  nmodl_eigen_j[4] = 0
1707  nmodl_eigen_j[8] = -2.0
1708  nmodl_eigen_j[12] = -0.3125
1709  nmodl_eigen_j[1] = 0
1710  nmodl_eigen_j[5] = -1.0
1711  nmodl_eigen_j[9] = -4.0*c
1712  nmodl_eigen_j[13] = 0
1713  nmodl_eigen_j[2] = 0
1714  nmodl_eigen_j[6] = -1/b
1715  nmodl_eigen_j[10] = 1.0
1716  nmodl_eigen_j[14] = -1.0
1717  nmodl_eigen_j[3] = 0
1718  nmodl_eigen_j[7] = -1.0
1719  nmodl_eigen_j[11] = -1.3
1720  nmodl_eigen_j[15] = 0.10000000000000001/(pow(a, 2)*b)
1721  }{
1722  w = nmodl_eigen_x[0]
1723  x = nmodl_eigen_x[1]
1724  y = nmodl_eigen_x[2]
1725  z = nmodl_eigen_x[3]
1726  }{
1727  }
1728  })";
1729  THEN("return matrix system to solve") {
1730  auto result =
1731  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1732  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1733  }
1734  }
1735 
1736  GIVEN("LINEAR solve block with an explicit SOLVEFOR statement") {
1737  std::string nmodl_text = R"(
1738  STATE {
1739  x
1740  y
1741  z
1742  }
1743  LINEAR lin SOLVEFOR x, y {
1744  ~ 3 * x = v - y
1745  ~ x = z * y - 5
1746  })";
1747  std::string expected_text = R"(
1748  LINEAR lin SOLVEFOR x,y{
1749  y = (v+15.0)/(3.0*z+1.0)
1750  x = (v*z-5.0)/(3.0*z+1.0)
1751  })";
1752  THEN("solve analytically") {
1753  auto result =
1754  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1755  REQUIRE(reindent_text(result[0]) == reindent_text(expected_text));
1756  }
1757  }
1758 }
1759 
1760 //=============================================================================
1761 // NONLINEAR solve block tests
1762 //=============================================================================
1763 
1764 SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][sympy][nonlinear]") {
1765  GIVEN("1 state-var numeric NONLINEAR solve block") {
1766  std::string nmodl_text = R"(
1767  STATE {
1768  x
1769  }
1770  NONLINEAR nonlin {
1771  ~ x = 5
1772  })";
1773  std::string expected_text = R"(
1774  NONLINEAR nonlin {
1775  EIGEN_NEWTON_SOLVE[1]{
1776  }{
1777  }{
1778  nmodl_eigen_x[0] = x
1779  }{
1780  nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0]
1781  nmodl_eigen_j[0] = -1.0
1782  }{
1783  x = nmodl_eigen_x[0]
1784  }{
1785  }
1786  })";
1787 
1788  THEN("return F & J for newton solver") {
1789  auto result =
1790  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::NON_LINEAR_BLOCK);
1791  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1792  }
1793  }
1794  GIVEN("array state-var numeric NONLINEAR solve block") {
1795  std::string nmodl_text = R"(
1796  STATE {
1797  s[3]
1798  }
1799  NONLINEAR nonlin {
1800  ~ s[0] = 1
1801  ~ s[1] = 3
1802  ~ s[2] + s[1] = s[0]
1803  })";
1804  std::string expected_text = R"(
1805  NONLINEAR nonlin {
1806  EIGEN_NEWTON_SOLVE[3]{
1807  }{
1808  }{
1809  nmodl_eigen_x[0] = s[0]
1810  nmodl_eigen_x[1] = s[1]
1811  nmodl_eigen_x[2] = s[2]
1812  }{
1813  nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0]
1814  nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1]
1815  nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]
1816  nmodl_eigen_j[0] = -1.0
1817  nmodl_eigen_j[3] = 0
1818  nmodl_eigen_j[6] = 0
1819  nmodl_eigen_j[1] = 0
1820  nmodl_eigen_j[4] = -1.0
1821  nmodl_eigen_j[7] = 0
1822  nmodl_eigen_j[2] = 1.0
1823  nmodl_eigen_j[5] = -1.0
1824  nmodl_eigen_j[8] = -1.0
1825  }{
1826  s[0] = nmodl_eigen_x[0]
1827  s[1] = nmodl_eigen_x[1]
1828  s[2] = nmodl_eigen_x[2]
1829  }{
1830  }
1831  })";
1832  THEN("return F & J for newton solver") {
1833  auto result =
1834  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::NON_LINEAR_BLOCK);
1835  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1836  }
1837  }
1838 }
1839 SCENARIO("Solve KINETIC block using SympySolver Visitor", "[visitor][solver][sympy][kinetic]") {
1840  GIVEN("KINETIC block with not inlined function should work") {
1841  std::string nmodl_text = R"(
1842  BREAKPOINT {
1843  SOLVE kstates METHOD sparse
1844  }
1845  STATE {
1846  C1
1847  C2
1848  }
1849  FUNCTION alfa(v(mV)) {
1850  alfa = v
1851  }
1852  KINETIC kstates {
1853  ~ C1 <-> C2 (alfa(v), alfa(v))
1854  })";
1855  std::string expected_text = R"(
1856  DERIVATIVE kstates {
1857  EIGEN_NEWTON_SOLVE[2]{
1858  LOCAL kf0_, kb0_, old_C1, old_C2
1859  }{
1860  kb0_ = alfa(v)
1861  kf0_ = alfa(v)
1862  old_C1 = C1
1863  old_C2 = C2
1864  }{
1865  nmodl_eigen_x[0] = C1
1866  nmodl_eigen_x[1] = C2
1867  }{
1868  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1869  nmodl_eigen_j[0] = -kf0_-1/dt
1870  nmodl_eigen_j[2] = kb0_
1871  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1872  nmodl_eigen_j[1] = kf0_
1873  nmodl_eigen_j[3] = -kb0_-1/dt
1874  }{
1875  C1 = nmodl_eigen_x[0]
1876  C2 = nmodl_eigen_x[1]
1877  }{
1878  }
1879  })";
1880  THEN("Run Kinetic and Sympy Visitor") {
1881  std::vector<std::string> result;
1882  REQUIRE_NOTHROW(result = run_sympy_solver_visitor(
1883  nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK, true));
1884  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1885  }
1886  }
1887  GIVEN("Protected names in Sympy are respected") {
1888  std::string nmodl_text = R"(
1889  BREAKPOINT {
1890  SOLVE kstates METHOD sparse
1891  }
1892  STATE {
1893  C1
1894  C2
1895  }
1896  FUNCTION beta(v(mV)) {
1897  beta = v
1898  }
1899  FUNCTION lowergamma(v(mV)) {
1900  lowergamma = v
1901  }
1902  KINETIC kstates {
1903  ~ C1 <-> C2 (beta(v), lowergamma(v))
1904  })";
1905  std::string expected_text = R"(
1906  DERIVATIVE kstates {
1907  EIGEN_NEWTON_SOLVE[2]{
1908  LOCAL kf0_, kb0_, old_C1, old_C2
1909  }{
1910  kf0_ = beta(v)
1911  kb0_ = lowergamma(v)
1912  old_C1 = C1
1913  old_C2 = C2
1914  }{
1915  nmodl_eigen_x[0] = C1
1916  nmodl_eigen_x[1] = C2
1917  }{
1918  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1919  nmodl_eigen_j[0] = -kf0_-1/dt
1920  nmodl_eigen_j[2] = kb0_
1921  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1922  nmodl_eigen_j[1] = kf0_
1923  nmodl_eigen_j[3] = -kb0_-1/dt
1924  }{
1925  C1 = nmodl_eigen_x[0]
1926  C2 = nmodl_eigen_x[1]
1927  }{
1928  }
1929  })";
1930  THEN("Run Kinetic and Sympy Visitor") {
1931  std::vector<std::string> result;
1932  REQUIRE_NOTHROW(result = run_sympy_solver_visitor(
1933  nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK, true));
1934  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1935  }
1936  }
1937 }
test_utils.hpp
nmodl::parser::NmodlDriver
Class that binds all pieces together for parsing nmodl file.
Definition: nmodl_driver.hpp:67
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:234
solve_block_visitor.hpp
Replace solve block statements with actual solution node in the AST.
ast_to_string
std::string ast_to_string(ast::Program &node)
Definition: sympy_solver.cpp:252
nmodl::test_utils::reindent_text
std::string reindent_text(const std::string &text, int indent_level)
Reindent nmodl text for text-to-text comparison.
Definition: test_utils.cpp:53
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
run_sympy_visitor_passes
void run_sympy_visitor_passes(ast::Program &node)
Definition: sympy_solver.cpp:234
loop_unroll_visitor.hpp
Unroll for loop in the AST.
constant_folder_visitor.hpp
Perform constant folding of integer/float/double expressions.
compare_blocks
void compare_blocks(const std::string &result, const std::string &expected, const bool require_fail=false)
Compare NMODL blocks that contain systems of equations (e.g. DERIVATIVE, LINEAR and NONLINEAR blocks).
Definition: sympy_solver.cpp:129
nmodl::ast::AstNodeType
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:166
kinetic_block_visitor.hpp
Visitor for kinetic block statements
program.hpp
Auto generated AST classes declaration.
run_sympy_solver_visitor
std::vector< std::string > run_sympy_solver_visitor(const std::string &text, bool pade=false, bool cse=false, AstNodeType ret_nodetype=AstNodeType::DIFF_EQ_EXPRESSION, bool kinetic=false)
Definition: sympy_solver.cpp:48
driver
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28
codegen_coreneuron_cpp_visitor.hpp
Visitor for printing C++ code compatible with legacy api of CoreNEURON
nmodl::parser::UnitDriver::parse_string
bool parse_string(const std::string &input)
Parse units provided as a string (used for testing).
Definition: unit_driver.cpp:40
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
Traverse the node recursively and collect all nodes of the given types.
Definition: visitor_utils.cpp:206
neuron_solve_visitor.hpp
Visitor that solves ODEs using old solvers of NEURON
checkparent_visitor.hpp
Visitor for checking parents of ast nodes
inline_visitor.hpp
Visitor to inline local procedure and function calls
nmodl_driver.hpp
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
SCENARIO
SCENARIO("Check compare_blocks in sympy unit tests", "[visitor][sympy]")
Definition: sympy_solver.cpp:258
is_unique_vars
bool is_unique_vars(std::string result)
Definition: sympy_solver.cpp:88
symtab_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
sympy_solver_visitor.hpp
Visitor for systems of algebraic and differential equations
nmodl_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.