8 #include <catch2/catch_test_macros.hpp>
9 #include <catch2/matchers/catch_matchers_string.hpp>
11 #include <pybind11/embed.h>
12 #include <pybind11/stl.h>
30 using namespace nmodl;
31 using namespace codegen;
32 using namespace visitor;
34 using namespace test_utils;
36 using Catch::Matchers::ContainsSubstring;
49 const std::string& text,
52 AstNodeType ret_nodetype = AstNodeType::DIFF_EQ_EXPRESSION,
53 bool kinetic =
false) {
54 std::vector<std::string> results;
61 SymtabVisitor().visit_program(*ast);
64 ConstantFolderVisitor().visit_program(*ast);
65 LoopUnrollVisitor().visit_program(*ast);
66 ConstantFolderVisitor().visit_program(*ast);
67 SymtabVisitor().visit_program(*ast);
70 KineticBlockVisitor().visit_program(*ast);
74 SympySolverVisitor(pade, cse).visit_program(*ast);
77 CheckParentVisitor().check_ast(*ast);
89 result.erase(std::remove(result.begin(), result.end(),
','), result.end());
90 std::stringstream ss(result);
93 std::unordered_set<std::string> old_vars;
95 while (getline(ss, token,
' ')) {
96 if (!old_vars.insert(token).second) {
130 const std::string& expected,
131 const bool require_fail =
false) {
132 using namespace pybind11::literals;
135 pybind11::dict(
"result"_a = result,
"expected"_a = expected,
"is_equal"_a =
false);
137 # Comments are in the doxygen for better highlighting
138 def compare_blocks(result, expected):
142 d = {'\[(\d+)\]':'_\\1', 'pow\((\w+), ?(\d+)\)':'\\1**\\2', 'beta': 'beta_var', 'gamma': 'gamma_var'}
144 for key, val in d.items():
145 out = re.sub(key, val, out)
148 def compare_systems_of_eq(result_dict, expected_dict):
149 from sympy.parsing.sympy_parser import parse_expr
151 for k, v in result_dict.items():
152 if parse_expr(f'simplify(({v})-({expected_dict[k]}))'):
158 expected_dict.clear()
166 # split of sout and a dict with the tmp variables
167 for line in s.split('\n'):
168 line_split = line.lstrip().split('=')
170 if len(line_split) == 2 and line_split[0].startswith('tmp_'):
171 # back-substitution of tmp variables in tmp variables
172 tmp_var = line_split[0].strip()
176 max_tmp = max(max_tmp, int(tmp_var[4:]))
177 for k, v in d.items():
178 line_split[1] = line_split[1].replace(k, f'({v})')
179 d[tmp_var] = line_split[1]
180 elif 'LOCAL' in line:
181 sout += line.split('tmp_0')[0] + '\n'
185 # Back-substitution of the tmps
186 # so that we do not replace tmp_11 with (tmp_1)1
187 for j in range(max_tmp, -1, -1):
189 sout = sout.replace(k, f'({d[k]})')
193 result = reduce(sanitize(result)).split('\n')
194 expected = reduce(sanitize(expected)).split('\n')
196 if len(result) != len(expected):
201 for token1, token2 in zip(result, expected):
203 if not compare_systems_of_eq(result_dict, expected_dict):
207 eq1 = token1.split('=')
208 eq2 = token2.split('=')
209 if len(eq1) == 2 and len(eq2) == 2:
210 result_dict[eq1[0]] = eq1[1]
211 expected_dict[eq2[0]] = eq2[1]
215 return compare_systems_of_eq(result_dict, expected_dict)
217 is_equal = compare_blocks(result, expected))",
222 if (require_fail == locals[
"is_equal"].cast<bool>()) {
224 REQUIRE(result != expected);
226 REQUIRE(result == expected);
236 SymtabVisitor v_symtab;
237 v_symtab.visit_program(node);
240 SympySolverVisitor v_sympy1;
241 v_sympy1.visit_program(node);
242 v_sympy1.visit_program(node);
245 SympySolverVisitor v_sympy2;
246 v_sympy2.visit_program(node);
247 v_sympy1.visit_program(node);
248 v_sympy2.visit_program(node);
253 std::stringstream stream;
254 NmodlPrintVisitor(stream).visit_program(node);
258 SCENARIO(
"Check compare_blocks in sympy unit tests",
"[visitor][sympy]") {
259 GIVEN(
"Empty strings") {
260 THEN(
"Strings are equal") {
264 GIVEN(
"Equivalent equation") {
265 THEN(
"Strings are equal") {
269 GIVEN(
"Equivalent systems of equations") {
270 std::string result = R
"(
273 std::string expected = R"(
276 THEN("Systems of equations are equal") {
280 GIVEN(
"Equivalent systems of equations with brackets") {
281 std::string result = R
"(
286 std::string expected = R"(
291 y = pow(a, 2)*a + 2*b-b
293 THEN("Blocks are equal") {
297 GIVEN(
"Different systems of equations (additional space)") {
298 std::string result = R
"(
303 std::string expected = R"(
308 THEN("Blocks are different") {
312 GIVEN(
"Different systems of equations") {
313 std::string result = R
"(
320 std::string expected = R"(
325 THEN("Blocks are different") {
331 SCENARIO(
"Check local vars name-clash prevention",
"[visitor][sympy]") {
333 std::string nmodl_text = R
"(
338 SOLVE states METHOD sparse
345 THEN("There are no duplicate vars in LOCAL") {
348 REQUIRE(!result.empty());
352 GIVEN(
"LOCAL tmp_0") {
353 std::string nmodl_text = R
"(
358 SOLVE states METHOD sparse
365 THEN("There are no duplicate vars in LOCAL") {
368 REQUIRE(!result.empty());
374 SCENARIO(
"Solve ODEs with cnexp or euler method using SympySolverVisitor",
375 "[visitor][sympy][cnexp][euler]") {
376 GIVEN(
"Derivative block without ODE, solver method cnexp") {
377 std::string nmodl_text = R
"(
379 SOLVE states METHOD cnexp
385 THEN("No ODEs found - do nothing") {
387 REQUIRE(result.empty());
390 GIVEN(
"Derivative block with ODES, solver method is euler") {
391 std::string nmodl_text = R
"(
393 SOLVE states METHOD euler
401 THEN("Construct forwards Euler solutions") {
403 REQUIRE(result.size() == 2);
404 REQUIRE(result[0] ==
"m = (-dt*(m-mInf)+m*mTau)/mTau");
405 REQUIRE(result[1] ==
"h = (-dt*(h-hInf)+h*hTau)/hTau");
408 GIVEN(
"Derivative block with calling external functions passes sympy") {
409 std::string nmodl_text = R
"(
411 SOLVE states METHOD euler
419 THEN("Construct forward Euler interpreting external functions as symbols") {
421 REQUIRE(result.size() == 3);
422 REQUIRE(result[0] ==
"m = dt*sawtooth(m)+m");
423 REQUIRE(result[1] ==
"n = dt*sin(n)+n");
424 REQUIRE(result[2] ==
"p = dt*my_user_func(p)+p");
427 GIVEN(
"Derivative block with ODE, 1 state var in array, solver method euler") {
428 std::string nmodl_text = R
"(
433 SOLVE states METHOD euler
436 m'[0] = (mInf-m[0])/mTau
439 THEN("Construct forwards Euler solutions") {
441 REQUIRE(result.size() == 1);
442 REQUIRE(result[0] ==
"m[0] = (dt*(mInf-m[0])+mTau*m[0])/mTau");
445 GIVEN(
"Derivative block with ODE, 1 state var in array, solver method cnexp") {
446 std::string nmodl_text = R
"(
451 SOLVE states METHOD cnexp
454 m'[0] = (mInf-m[0])/mTau
457 THEN("Construct forwards Euler solutions") {
459 REQUIRE(result.size() == 1);
460 REQUIRE(result[0] ==
"m[0] = mInf-(mInf-m[0])*exp(-dt/mTau)");
463 GIVEN(
"Derivative block with linear ODES, solver method cnexp") {
464 std::string nmodl_text = R
"(
466 SOLVE states METHOD cnexp
471 h' = hInf/hTau - h/hTau
474 THEN("Integrate equations analytically") {
476 REQUIRE(result.size() == 2);
477 REQUIRE(result[0] ==
"m = mInf-(-m+mInf)*exp(-dt/mTau)");
478 REQUIRE(result[1] ==
"h = hInf-(-h+hInf)*exp(-dt/hTau)");
481 GIVEN(
"Derivative block including non-linear but solvable ODES, solver method cnexp") {
482 std::string nmodl_text = R
"(
484 SOLVE states METHOD cnexp
491 THEN("Integrate equations analytically") {
493 REQUIRE(result.size() == 2);
494 REQUIRE(result[0] ==
"m = mInf-(-m+mInf)*exp(-dt/mTau)");
495 REQUIRE(result[1] ==
"h = -h/(c2*dt*h-1.0)");
498 GIVEN(
"Derivative block including array of 2 state vars, solver method cnexp") {
499 std::string nmodl_text = R
"(
501 SOLVE states METHOD cnexp
507 X'[0] = (mInf-X[0])/mTau
508 X'[1] = c2 * X[1]*X[1]
511 THEN("Integrate equations analytically") {
513 REQUIRE(result.size() == 2);
514 REQUIRE(result[0] ==
"X[0] = mInf-(mInf-X[0])*exp(-dt/mTau)");
515 REQUIRE(result[1] ==
"X[1] = -X[1]/(c2*dt*X[1]-1.0)");
518 GIVEN(
"Derivative block including loop over array vars, solver method cnexp") {
519 std::string nmodl_text = R
"(
522 SOLVE states METHOD cnexp
532 X'[i] = (mInf-X[i])/mTau[i]
536 THEN("Integrate equations analytically") {
538 REQUIRE(result.size() == 3);
539 REQUIRE(result[0] ==
"X[0] = mInf-(mInf-X[0])*exp(-dt/mTau[0])");
540 REQUIRE(result[1] ==
"X[1] = mInf-(mInf-X[1])*exp(-dt/mTau[1])");
541 REQUIRE(result[2] ==
"X[2] = mInf-(mInf-X[2])*exp(-dt/mTau[2])");
544 GIVEN(
"Derivative block including loop over array vars, solver method euler") {
545 std::string nmodl_text = R
"(
548 SOLVE states METHOD euler
558 X'[i] = (mInf-X[i])/mTau[i]
562 THEN("Integrate equations analytically") {
564 REQUIRE(result.size() == 3);
565 REQUIRE(result[0] ==
"X[0] = (dt*(mInf-X[0])+X[0]*mTau[0])/mTau[0]");
566 REQUIRE(result[1] ==
"X[1] = (dt*(mInf-X[1])+X[1]*mTau[1])/mTau[1]");
567 REQUIRE(result[2] ==
"X[2] = (dt*(mInf-X[2])+X[2]*mTau[2])/mTau[2]");
570 GIVEN(
"Derivative block including ODES that can't currently be solved, solver method cnexp") {
571 std::string nmodl_text = R
"(
573 SOLVE states METHOD cnexp
582 THEN("Integrate equations analytically where possible, otherwise leave untouched") {
584 REQUIRE(result.size() == 4);
586 REQUIRE((result[0] ==
"z' = a/z+b/z/z" ||
588 "z = (0.5*pow(a, 2)*pow(z, 2)-a*b*z+pow(b, 2)*log(a*z+b))/pow(a, 3)"));
589 REQUIRE(result[1] ==
"h = -h/(c2*dt*h-1.0)");
590 REQUIRE(result[2] ==
"x = a*dt+x");
592 REQUIRE((result[3] ==
"y' = c3*y*y*y" ||
593 result[3] ==
"y = sqrt(-pow(y, 2)/(2.0*c3*dt*pow(y, 2)-1.0))"));
596 GIVEN(
"Derivative block with cnexp solver method, AST after SympySolver pass") {
597 std::string nmodl_text = R
"(
599 SOLVE states METHOD cnexp
610 SymtabVisitor().visit_program(*ast);
613 SympySolverVisitor().visit_program(*ast);
617 THEN(
"More SympySolver passes do nothing to the AST and don't throw") {
624 SCENARIO(
"Solve ODEs with derivimplicit method using SympySolverVisitor",
625 "[visitor][sympy][derivimplicit]") {
626 GIVEN(
"Derivative block with derivimplicit solver method and conditional block") {
627 std::string nmodl_text = R
"(
632 SOLVE states METHOD derivimplicit
641 std::string expected_result = R"(
643 EIGEN_NEWTON_SOLVE[1]{
653 nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+mInf)+mTau*(-nmodl_eigen_x[0]+old_m))/(dt*mTau)
654 nmodl_eigen_j[0] = (-dt-mTau)/(dt*mTau)
660 THEN("SympySolver correctly inserts ode to block") {
668 GIVEN(
"Derivative block, sparse, print in order") {
669 std::string nmodl_text = R
"(
674 SOLVE states METHOD sparse
681 std::string expected_result = R"(
683 EIGEN_NEWTON_SOLVE[2]{
684 LOCAL a, b, old_y, old_x
692 nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_y)/dt
694 nmodl_eigen_j[2] = -1/dt
695 nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_x)/dt
696 nmodl_eigen_j[1] = -1/dt
705 THEN("Construct & solve linear system for backwards Euler") {
712 GIVEN(
"Derivative block, sparse, print in order, vectors") {
713 std::string nmodl_text = R
"(
718 SOLVE states METHOD sparse
725 std::string expected_result = R"(
727 EIGEN_NEWTON_SOLVE[2]{
728 LOCAL a, b, old_M_1, old_M_0
733 nmodl_eigen_x[0] = M[0]
734 nmodl_eigen_x[1] = M[1]
736 nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_M_1)/dt
738 nmodl_eigen_j[2] = -1/dt
739 nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_M_0)/dt
740 nmodl_eigen_j[1] = -1/dt
743 M[0] = nmodl_eigen_x[0]
744 M[1] = nmodl_eigen_x[1]
749 THEN("Construct & solve linear system for backwards Euler") {
756 GIVEN(
"Derivative block, sparse, derivatives mixed with local variable reassignment") {
757 std::string nmodl_text = R
"(
762 SOLVE states METHOD sparse
770 std::string expected_result = R"(
772 EIGEN_NEWTON_SOLVE[2]{
773 LOCAL a, b, old_x, old_y
781 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
782 nmodl_eigen_j[0] = -1/dt
785 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
787 nmodl_eigen_j[3] = -1/dt
795 THEN("Construct & solve linear system for backwards Euler") {
803 "Throw exception during derivative variable reassignment interleaved in the differential "
805 std::string nmodl_text = R
"(
810 SOLVE states METHOD sparse
820 "Throw an error because state variable assignments are not allowed inside the system "
825 Catch::Matchers::ContainsSubstring(
826 "State variable assignment(s) interleaved in system of "
827 "equations/differential equations") &&
828 Catch::Matchers::StartsWith(
"SympyReplaceSolutionsVisitor"));
831 GIVEN(
"Derivative block in control flow block") {
832 std::string nmodl_text = R
"(
837 SOLVE states METHOD sparse
846 std::string expected_result = R"(
850 EIGEN_NEWTON_SOLVE[2]{
859 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
860 nmodl_eigen_j[0] = -1/dt
862 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
864 nmodl_eigen_j[3] = -1/dt
873 THEN("Construct & solve linear system for backwards Euler") {
881 "Derivative block, sparse, coupled derivatives mixed with reassignment and control flow "
883 std::string nmodl_text = R
"(
888 SOLVE states METHOD sparse
898 std::string expected_result = R"(
900 EIGEN_NEWTON_SOLVE[2]{
901 LOCAL a, b, old_x, old_y
909 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
910 nmodl_eigen_j[0] = -1/dt
915 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
916 nmodl_eigen_j[1] = 1.0
917 nmodl_eigen_j[3] = a-1/dt
924 std::string expected_result_cse = R"(
926 EIGEN_NEWTON_SOLVE[2]{
927 LOCAL a, b, old_x, old_y
935 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
936 nmodl_eigen_j[0] = -1/dt
941 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
942 nmodl_eigen_j[1] = 1.0
943 nmodl_eigen_j[3] = a-1/dt
951 THEN("Construct & solve linear system for backwards Euler") {
962 GIVEN(
"Derivative block of coupled & linear ODES, solver method sparse") {
963 std::string nmodl_text = R
"(
968 SOLVE states METHOD sparse
977 std::string expected_result = R"(
979 EIGEN_NEWTON_SOLVE[3]{
980 LOCAL a, b, c, d, h, old_x, old_y, old_z
990 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
991 nmodl_eigen_j[0] = -1/dt
994 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
995 nmodl_eigen_j[1] = 2.0
996 nmodl_eigen_j[4] = -1/dt
998 nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1000 nmodl_eigen_j[5] = -1.0
1001 nmodl_eigen_j[8] = d-1/dt
1003 x = nmodl_eigen_x[0]
1004 y = nmodl_eigen_x[1]
1005 z = nmodl_eigen_x[2]
1009 std::string expected_cse_result = R"(
1011 EIGEN_NEWTON_SOLVE[3]{
1012 LOCAL a, b, c, d, h, old_x, old_y, old_z
1018 nmodl_eigen_x[0] = x
1019 nmodl_eigen_x[1] = y
1020 nmodl_eigen_x[2] = z
1022 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
1023 nmodl_eigen_j[0] = -1/dt
1024 nmodl_eigen_j[3] = 0
1025 nmodl_eigen_j[6] = a
1026 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
1027 nmodl_eigen_j[1] = 2.0
1028 nmodl_eigen_j[4] = -1/dt
1029 nmodl_eigen_j[7] = 0
1030 nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1031 nmodl_eigen_j[2] = 0
1032 nmodl_eigen_j[5] = -1.0
1033 nmodl_eigen_j[8] = d-1/dt
1035 x = nmodl_eigen_x[0]
1036 y = nmodl_eigen_x[1]
1037 z = nmodl_eigen_x[2]
1042 THEN("Construct & solve linear system for backwards Euler") {
1052 GIVEN(
"Derivative block including ODES with sparse method (from nmodl paper)") {
1053 std::string nmodl_text = R
"(
1058 SOLVE scheme1 METHOD sparse
1060 DERIVATIVE scheme1 {
1065 std::string expected_result = R"(
1066 DERIVATIVE scheme1 {
1067 EIGEN_NEWTON_SOLVE[2]{
1073 nmodl_eigen_x[0] = mc
1074 nmodl_eigen_x[1] = m
1076 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1077 nmodl_eigen_j[0] = -a-1/dt
1078 nmodl_eigen_j[2] = b
1079 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1080 nmodl_eigen_j[1] = a
1081 nmodl_eigen_j[3] = -b-1/dt
1083 mc = nmodl_eigen_x[0]
1084 m = nmodl_eigen_x[1]
1088 THEN("Construct & solve linear system") {
1089 CAPTURE(nmodl_text);
1095 GIVEN(
"Derivative block with ODES with sparse method, CONSERVE statement of form m = ...") {
1096 std::string nmodl_text = R
"(
1101 SOLVE scheme1 METHOD sparse
1103 DERIVATIVE scheme1 {
1109 std::string expected_result = R"(
1110 DERIVATIVE scheme1 {
1111 EIGEN_NEWTON_SOLVE[2]{
1116 nmodl_eigen_x[0] = mc
1117 nmodl_eigen_x[1] = m
1119 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1120 nmodl_eigen_j[0] = -a-1/dt
1121 nmodl_eigen_j[2] = b
1122 nmodl_eigen_f[1] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]+1.0
1123 nmodl_eigen_j[1] = -1.0
1124 nmodl_eigen_j[3] = -1.0
1126 mc = nmodl_eigen_x[0]
1127 m = nmodl_eigen_x[1]
1131 THEN("Construct & solve linear system, replace ODE for m with rhs of CONSERVE statement") {
1132 CAPTURE(nmodl_text);
1139 "Derivative block with ODES with sparse method, invalid CONSERVE statement of form m + mc "
1141 std::string nmodl_text = R
"(
1146 SOLVE scheme1 METHOD sparse
1148 DERIVATIVE scheme1 {
1154 std::string expected_result = R"(
1155 DERIVATIVE scheme1 {
1156 EIGEN_NEWTON_SOLVE[2]{
1162 nmodl_eigen_x[0] = mc
1163 nmodl_eigen_x[1] = m
1165 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1166 nmodl_eigen_j[0] = -a-1/dt
1167 nmodl_eigen_j[2] = b
1168 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1169 nmodl_eigen_j[1] = a
1170 nmodl_eigen_j[3] = -b-1/dt
1172 mc = nmodl_eigen_x[0]
1173 m = nmodl_eigen_x[1]
1177 THEN("Construct & solve linear system, ignore invalid CONSERVE statement") {
1178 CAPTURE(nmodl_text);
1184 GIVEN(
"Derivative block with ODES with sparse method, two CONSERVE statements") {
1185 std::string nmodl_text = R
"(
1190 SOLVE ihkin METHOD sparse
1193 LOCAL alpha, beta, k3p, k4, k1ca, k2
1194 evaluate_fct(v, cai)
1196 CONSERVE o2 = 1-c1-o1
1197 c1' = (-1*(alpha*c1-beta*o1))
1198 o1' = (1*(alpha*c1-beta*o1))+(-1*(k3p*o1-k4*o2))
1199 o2' = (1*(k3p*o1-k4*o2))
1200 p0' = (-1*(k1ca*p0-k2*p1))
1201 p1' = (1*(k1ca*p0-k2*p1))
1203 std::string expected_result = R"(
1205 EIGEN_NEWTON_SOLVE[5]{
1206 LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0
1208 evaluate_fct(v, cai)
1213 nmodl_eigen_x[0] = c1
1214 nmodl_eigen_x[1] = o1
1215 nmodl_eigen_x[2] = o2
1216 nmodl_eigen_x[3] = p0
1217 nmodl_eigen_x[4] = p1
1219 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*alpha+nmodl_eigen_x[1]*beta)+old_c1)/dt
1220 nmodl_eigen_j[0] = -alpha-1/dt
1221 nmodl_eigen_j[5] = beta
1222 nmodl_eigen_j[10] = 0
1223 nmodl_eigen_j[15] = 0
1224 nmodl_eigen_j[20] = 0
1225 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*alpha-nmodl_eigen_x[1]*beta-nmodl_eigen_x[1]*k3p+nmodl_eigen_x[2]*k4)+old_o1)/dt
1226 nmodl_eigen_j[1] = alpha
1227 nmodl_eigen_j[6] = -beta-k3p-1/dt
1228 nmodl_eigen_j[11] = k4
1229 nmodl_eigen_j[16] = 0
1230 nmodl_eigen_j[21] = 0
1231 nmodl_eigen_f[2] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]+1.0
1232 nmodl_eigen_j[2] = -1.0
1233 nmodl_eigen_j[7] = -1.0
1234 nmodl_eigen_j[12] = -1.0
1235 nmodl_eigen_j[17] = 0
1236 nmodl_eigen_j[22] = 0
1237 nmodl_eigen_f[3] = (-nmodl_eigen_x[3]+dt*(-nmodl_eigen_x[3]*k1ca+nmodl_eigen_x[4]*k2)+old_p0)/dt
1238 nmodl_eigen_j[3] = 0
1239 nmodl_eigen_j[8] = 0
1240 nmodl_eigen_j[13] = 0
1241 nmodl_eigen_j[18] = -k1ca-1/dt
1242 nmodl_eigen_j[23] = k2
1243 nmodl_eigen_f[4] = -nmodl_eigen_x[3]-nmodl_eigen_x[4]+1.0
1244 nmodl_eigen_j[4] = 0
1245 nmodl_eigen_j[9] = 0
1246 nmodl_eigen_j[14] = 0
1247 nmodl_eigen_j[19] = -1.0
1248 nmodl_eigen_j[24] = -1.0
1250 c1 = nmodl_eigen_x[0]
1251 o1 = nmodl_eigen_x[1]
1252 o2 = nmodl_eigen_x[2]
1253 p0 = nmodl_eigen_x[3]
1254 p1 = nmodl_eigen_x[4]
1259 "Construct & solve linear system, replacing ODEs for p1 and o2 with CONSERVE statement "
1260 "algebraic relations") {
1261 CAPTURE(nmodl_text);
1267 GIVEN(
"Derivative block including ODES with sparse method - single var in array") {
1268 std::string nmodl_text = R
"(
1277 SOLVE scheme1 METHOD sparse
1279 DERIVATIVE scheme1 {
1280 W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1283 std::string expected_result = R"(
1284 DERIVATIVE scheme1 {
1285 EIGEN_NEWTON_SOLVE[1]{
1290 nmodl_eigen_x[0] = W[0]
1292 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1293 nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1295 W[0] = nmodl_eigen_x[0]
1299 THEN("Construct & solver linear system") {
1300 CAPTURE(nmodl_text);
1306 GIVEN(
"Derivative block including ODES with sparse method - array vars") {
1307 std::string nmodl_text = R
"(
1316 SOLVE scheme1 METHOD sparse
1318 DERIVATIVE scheme1 {
1319 M'[0] = -A[0]*M[0] + B[0]*M[1]
1320 M'[1] = A[1]*M[0] - B[1]*M[1]
1323 std::string expected_result = R"(
1324 DERIVATIVE scheme1 {
1325 EIGEN_NEWTON_SOLVE[2]{
1326 LOCAL old_M_0, old_M_1
1331 nmodl_eigen_x[0] = M[0]
1332 nmodl_eigen_x[1] = M[1]
1334 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[1]*B[0])+old_M_0)/dt
1335 nmodl_eigen_j[0] = -A[0]-1/dt
1336 nmodl_eigen_j[2] = B[0]
1337 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*A[1]-nmodl_eigen_x[1]*B[1])+old_M_1)/dt
1338 nmodl_eigen_j[1] = A[1]
1339 nmodl_eigen_j[3] = -B[1]-1/dt
1341 M[0] = nmodl_eigen_x[0]
1342 M[1] = nmodl_eigen_x[1]
1346 THEN("Construct & solver linear system") {
1347 CAPTURE(nmodl_text);
1353 GIVEN(
"Derivative block including ODES with derivimplicit method - single var in array") {
1354 std::string nmodl_text = R
"(
1363 SOLVE scheme1 METHOD derivimplicit
1365 DERIVATIVE scheme1 {
1366 W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1369 std::string expected_result = R"(
1370 DERIVATIVE scheme1 {
1371 EIGEN_NEWTON_SOLVE[1]{
1376 nmodl_eigen_x[0] = W[0]
1378 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1379 nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1381 W[0] = nmodl_eigen_x[0]
1385 THEN("Construct newton solve block") {
1386 CAPTURE(nmodl_text);
1392 GIVEN(
"Derivative block including ODES with derivimplicit method") {
1393 std::string nmodl_text = R
"(
1398 SOLVE states METHOD derivimplicit
1402 m' = (minf-m)/mtau - 3*h
1403 h' = (hinf-h)/htau + m*m
1408 std::string expected_result = R
"(
1410 EIGEN_NEWTON_SOLVE[3]{
1411 LOCAL old_m, old_h, old_n
1418 nmodl_eigen_x[0] = m
1419 nmodl_eigen_x[1] = h
1420 nmodl_eigen_x[2] = n
1422 nmodl_eigen_f[0] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt-3.0*nmodl_eigen_x[1]+minf/mtau+old_m/dt
1423 nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1424 nmodl_eigen_j[3] = -3.0
1425 nmodl_eigen_j[6] = 0
1426 nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1427 nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1428 nmodl_eigen_j[4] = (-dt-htau)/(dt*htau)
1429 nmodl_eigen_j[7] = 0
1430 nmodl_eigen_f[2] = (dt*(-nmodl_eigen_x[2]+ninf)+ntau*(-nmodl_eigen_x[2]+old_n))/(dt*ntau)
1431 nmodl_eigen_j[2] = 0
1432 nmodl_eigen_j[5] = 0
1433 nmodl_eigen_j[8] = (-dt-ntau)/(dt*ntau)
1435 m = nmodl_eigen_x[0]
1436 h = nmodl_eigen_x[1]
1437 n = nmodl_eigen_x[2]
1441 THEN("Construct newton solve block") {
1442 CAPTURE(nmodl_text);
1448 GIVEN(
"Multiple derivative blocks each with derivimplicit method") {
1449 std::string nmodl_text = R
"(
1454 SOLVE states1 METHOD derivimplicit
1455 SOLVE states2 METHOD derivimplicit
1458 DERIVATIVE states1 {
1460 h' = (hinf-h)/htau + m*m
1463 DERIVATIVE states2 {
1464 h' = (hinf-h)/htau + m*m
1465 m' = (minf-m)/mtau + h
1469 std::string expected_result_0 = R
"(
1470 DERIVATIVE states1 {
1471 EIGEN_NEWTON_SOLVE[2]{
1477 nmodl_eigen_x[0] = m
1478 nmodl_eigen_x[1] = h
1480 nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+minf)+mtau*(-nmodl_eigen_x[0]+old_m))/(dt*mtau)
1481 nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1482 nmodl_eigen_j[2] = 0
1483 nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau- nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1484 nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1485 nmodl_eigen_j[3] = (-dt-htau)/(dt*htau)
1487 m = nmodl_eigen_x[0]
1488 h = nmodl_eigen_x[1]
1492 std::string expected_result_1 = R"(
1493 DERIVATIVE states2 {
1494 EIGEN_NEWTON_SOLVE[2]{
1500 nmodl_eigen_x[0] = m
1501 nmodl_eigen_x[1] = h
1503 nmodl_eigen_f[0] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1504 nmodl_eigen_j[0] = 2.0*nmodl_eigen_x[0]
1505 nmodl_eigen_j[2] = (-dt-htau)/(dt*htau)
1506 nmodl_eigen_f[1] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt+nmodl_eigen_x[1]+minf/mtau+old_m/dt
1507 nmodl_eigen_j[1] = (-dt-mtau)/(dt*mtau)
1508 nmodl_eigen_j[3] = 1.0
1510 m = nmodl_eigen_x[0]
1511 h = nmodl_eigen_x[1]
1515 THEN("Construct newton solve block") {
1518 CAPTURE(nmodl_text);
1530 SCENARIO(
"LINEAR solve block (SympySolver Visitor)",
"[sympy][linear]") {
1531 GIVEN(
"1 state-var symbolic LINEAR solve block") {
1532 std::string nmodl_text = R
"(
1539 std::string expected_text = R"(
1543 THEN("solve analytically") {
1549 GIVEN(
"2 state-var LINEAR solve block") {
1550 std::string nmodl_text = R
"(
1558 std::string expected_text = R"(
1563 THEN("solve analytically") {
1569 GIVEN(
"Linear block, print in order, vectors") {
1570 std::string nmodl_text = R
"(
1578 std::string expected_result = R"(
1584 THEN("Construct & solve linear system") {
1591 GIVEN(
"Linear block, by value replacement, interleaved") {
1592 std::string nmodl_text = R
"(
1604 std::string expected_result = R"(
1614 THEN("Construct & solve linear system") {
1621 GIVEN(
"Linear block in control flow block") {
1622 std::string nmodl_text = R
"(
1633 std::string expected_result = R"(
1642 THEN("Construct & solve linear system") {
1649 GIVEN(
"Linear block, linear equations mixed with control flow blocks and reassignments") {
1650 std::string nmodl_text = R
"(
1663 std::string expected_result = R"(
1674 THEN("Construct & solve linear system") {
1681 GIVEN(
"4 state-var LINEAR solve block") {
1682 std::string nmodl_text = R
"(
1687 ~ w + z/3.2 = -2.0*y
1688 ~ x + 4*c*y = -5.343*a
1689 ~ a + x/b + z - y = 0.842*b*b
1690 ~ x + 1.3*y - 0.1*z/(a*a*b) = 1.43543/c
1692 std::string expected_text = R"(
1694 EIGEN_LINEAR_SOLVE[4]{
1697 nmodl_eigen_x[0] = w
1698 nmodl_eigen_x[1] = x
1699 nmodl_eigen_x[2] = y
1700 nmodl_eigen_x[3] = z
1701 nmodl_eigen_f[0] = 0
1702 nmodl_eigen_f[1] = 5.343*a
1703 nmodl_eigen_f[2] = a-0.84199999999999997*pow(b, 2)
1704 nmodl_eigen_f[3] = -1.43543/c
1705 nmodl_eigen_j[0] = -1.0
1706 nmodl_eigen_j[4] = 0
1707 nmodl_eigen_j[8] = -2.0
1708 nmodl_eigen_j[12] = -0.3125
1709 nmodl_eigen_j[1] = 0
1710 nmodl_eigen_j[5] = -1.0
1711 nmodl_eigen_j[9] = -4.0*c
1712 nmodl_eigen_j[13] = 0
1713 nmodl_eigen_j[2] = 0
1714 nmodl_eigen_j[6] = -1/b
1715 nmodl_eigen_j[10] = 1.0
1716 nmodl_eigen_j[14] = -1.0
1717 nmodl_eigen_j[3] = 0
1718 nmodl_eigen_j[7] = -1.0
1719 nmodl_eigen_j[11] = -1.3
1720 nmodl_eigen_j[15] = 0.10000000000000001/(pow(a, 2)*b)
1722 w = nmodl_eigen_x[0]
1723 x = nmodl_eigen_x[1]
1724 y = nmodl_eigen_x[2]
1725 z = nmodl_eigen_x[3]
1729 THEN("return matrix system to solve") {
1736 GIVEN(
"LINEAR solve block with an explicit SOLVEFOR statement") {
1737 std::string nmodl_text = R
"(
1743 LINEAR lin SOLVEFOR x, y {
1747 std::string expected_text = R"(
1748 LINEAR lin SOLVEFOR x,y{
1749 y = (v+15.0)/(3.0*z+1.0)
1750 x = (v*z-5.0)/(3.0*z+1.0)
1752 THEN("solve analytically") {
1764 SCENARIO(
"Solve NONLINEAR block using SympySolver Visitor",
"[visitor][solver][sympy][nonlinear]") {
1765 GIVEN(
"1 state-var numeric NONLINEAR solve block") {
1766 std::string nmodl_text = R
"(
1773 std::string expected_text = R"(
1775 EIGEN_NEWTON_SOLVE[1]{
1778 nmodl_eigen_x[0] = x
1780 nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0]
1781 nmodl_eigen_j[0] = -1.0
1783 x = nmodl_eigen_x[0]
1788 THEN("return F & J for newton solver") {
1794 GIVEN(
"array state-var numeric NONLINEAR solve block") {
1795 std::string nmodl_text = R
"(
1802 ~ s[2] + s[1] = s[0]
1804 std::string expected_text = R"(
1806 EIGEN_NEWTON_SOLVE[3]{
1809 nmodl_eigen_x[0] = s[0]
1810 nmodl_eigen_x[1] = s[1]
1811 nmodl_eigen_x[2] = s[2]
1813 nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0]
1814 nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1]
1815 nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]
1816 nmodl_eigen_j[0] = -1.0
1817 nmodl_eigen_j[3] = 0
1818 nmodl_eigen_j[6] = 0
1819 nmodl_eigen_j[1] = 0
1820 nmodl_eigen_j[4] = -1.0
1821 nmodl_eigen_j[7] = 0
1822 nmodl_eigen_j[2] = 1.0
1823 nmodl_eigen_j[5] = -1.0
1824 nmodl_eigen_j[8] = -1.0
1826 s[0] = nmodl_eigen_x[0]
1827 s[1] = nmodl_eigen_x[1]
1828 s[2] = nmodl_eigen_x[2]
1832 THEN("return F & J for newton solver") {
1839 SCENARIO(
"Solve KINETIC block using SympySolver Visitor",
"[visitor][solver][sympy][kinetic]") {
1840 GIVEN(
"KINETIC block with not inlined function should work") {
1841 std::string nmodl_text = R
"(
1843 SOLVE kstates METHOD sparse
1849 FUNCTION alfa(v(mV)) {
1853 ~ C1 <-> C2 (alfa(v), alfa(v))
1855 std::string expected_text = R"(
1856 DERIVATIVE kstates {
1857 EIGEN_NEWTON_SOLVE[2]{
1858 LOCAL kf0_, kb0_, old_C1, old_C2
1865 nmodl_eigen_x[0] = C1
1866 nmodl_eigen_x[1] = C2
1868 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1869 nmodl_eigen_j[0] = -kf0_-1/dt
1870 nmodl_eigen_j[2] = kb0_
1871 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1872 nmodl_eigen_j[1] = kf0_
1873 nmodl_eigen_j[3] = -kb0_-1/dt
1875 C1 = nmodl_eigen_x[0]
1876 C2 = nmodl_eigen_x[1]
1880 THEN("Run Kinetic and Sympy Visitor") {
1881 std::vector<std::string> result;
1883 nmodl_text,
false,
false, AstNodeType::DERIVATIVE_BLOCK,
true));
1887 GIVEN(
"Protected names in Sympy are respected") {
1888 std::string nmodl_text = R
"(
1890 SOLVE kstates METHOD sparse
1896 FUNCTION beta(v(mV)) {
1899 FUNCTION lowergamma(v(mV)) {
1903 ~ C1 <-> C2 (beta(v), lowergamma(v))
1905 std::string expected_text = R"(
1906 DERIVATIVE kstates {
1907 EIGEN_NEWTON_SOLVE[2]{
1908 LOCAL kf0_, kb0_, old_C1, old_C2
1911 kb0_ = lowergamma(v)
1915 nmodl_eigen_x[0] = C1
1916 nmodl_eigen_x[1] = C2
1918 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1919 nmodl_eigen_j[0] = -kf0_-1/dt
1920 nmodl_eigen_j[2] = kb0_
1921 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1922 nmodl_eigen_j[1] = kf0_
1923 nmodl_eigen_j[3] = -kb0_-1/dt
1925 C1 = nmodl_eigen_x[0]
1926 C2 = nmodl_eigen_x[1]
1930 THEN("Run Kinetic and Sympy Visitor") {
1931 std::vector<std::string> result;
1933 nmodl_text,
false,
false, AstNodeType::DERIVATIVE_BLOCK,
true));