User Guide
steadystate.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <catch2/catch_test_macros.hpp>
9 
10 #include "ast/program.hpp"
11 #include "parser/nmodl_driver.hpp"
20 
21 using namespace nmodl;
22 using namespace visitor;
23 using namespace test;
24 using namespace test_utils;
25 
26 using ast::AstNodeType;
28 
29 
30 //=============================================================================
31 // STEADYSTATE visitor tests
32 //=============================================================================
33 
34 std::vector<std::string> run_steadystate_visitor(
35  const std::string& text,
36  const std::vector<AstNodeType>& ret_nodetypes = {AstNodeType::SOLVE_BLOCK,
37  AstNodeType::DERIVATIVE_BLOCK}) {
38  std::vector<std::string> results;
39  // construct AST from text
41  const auto& ast = driver.parse_string(text);
42 
43  // construct symbol table from AST
44  SymtabVisitor().visit_program(*ast);
45 
46  // unroll loops and fold constants
47  ConstantFolderVisitor().visit_program(*ast);
48  LoopUnrollVisitor().visit_program(*ast);
49  ConstantFolderVisitor().visit_program(*ast);
50  SymtabVisitor().visit_program(*ast);
51 
52  // Run kinetic block visitor first, so any kinetic blocks
53  // are converted into derivative blocks
54  KineticBlockVisitor().visit_program(*ast);
55  SymtabVisitor().visit_program(*ast);
56 
57  // run SteadystateVisitor on AST
58  SteadystateVisitor().visit_program(*ast);
59 
60  // run lookup visitor to extract results from AST
61  const auto& res = collect_nodes(*ast, ret_nodetypes);
62  results.reserve(res.size());
63  for (const auto& r: res) {
64  results.push_back(to_nmodl(r));
65  }
66 
67  // check that, after visitor rearrangement, parents are still up-to-date
68  CheckParentVisitor().check_ast(*ast);
69 
70  return results;
71 }
72 
// Checks SteadystateVisitor (via run_steadystate_visitor) for each solve method.
// For every `SOLVE <name> STEADYSTATE <method>` the expected outcome is:
//  - the SOLVE statement becomes `SOLVE <name>_steadystate METHOD <method>`;
//  - the original DERIVATIVE block is kept unchanged (apart from formatting);
//  - a copy named `<name>_steadystate` is added, with a method-specific
//    `dt = ...` assignment prepended to its body.
// Result indices follow the default collection: SOLVE blocks first, then the
// original DERIVATIVE blocks, then the generated *_steadystate blocks.
73 SCENARIO("Solving ODEs with STEADYSTATE solve method", "[visitor][steadystate]") {
    // sparse method: the generated block is expected to set dt = 1000000000
74  GIVEN("STEADYSTATE sparse solve") {
75  std::string nmodl_text = R"(
76  BREAKPOINT {
77  SOLVE states STEADYSTATE sparse
78  }
79  DERIVATIVE states {
80  m' = m + h
81  }
82  )";
83  std::string expected_text1 = R"(
84  DERIVATIVE states {
85  m' = m+h
86  })";
87  std::string expected_text2 = R"(
88  DERIVATIVE states_steadystate {
89  dt = 1000000000
90  m' = m+h
91  })";
        // result[0]: rewritten SOLVE, result[1]: original block, result[2]: generated copy
92  THEN("Construct DERIVATIVE block with steadystate solution & update SOLVE statement") {
93  auto result = run_steadystate_visitor(nmodl_text);
94  REQUIRE(result.size() == 3);
95  REQUIRE(result[0] == "SOLVE states_steadystate METHOD sparse");
96  REQUIRE(reindent_text(result[1]) == reindent_text(expected_text1));
97  REQUIRE(reindent_text(result[2]) == reindent_text(expected_text2));
98  }
99  }
    // derivimplicit method: the generated block is expected to set dt = 1e-09
100  GIVEN("STEADYSTATE derivimplicit solve") {
101  std::string nmodl_text = R"(
102  BREAKPOINT {
103  SOLVE states STEADYSTATE derivimplicit
104  }
105  DERIVATIVE states {
106  m' = m + h
107  }
108  )";
109  std::string expected_text1 = R"(
110  DERIVATIVE states {
111  m' = m+h
112  })";
113  std::string expected_text2 = R"(
114  DERIVATIVE states_steadystate {
115  dt = 1e-09
116  m' = m+h
117  })";
        // same layout as the sparse case; only the METHOD and the dt value differ
118  THEN("Construct DERIVATIVE block with steadystate solution & update SOLVE statement") {
119  auto result = run_steadystate_visitor(nmodl_text);
120  REQUIRE(result.size() == 3);
121  REQUIRE(result[0] == "SOLVE states_steadystate METHOD derivimplicit");
122  REQUIRE(reindent_text(result[1]) == reindent_text(expected_text1));
123  REQUIRE(reindent_text(result[2]) == reindent_text(expected_text2));
124  }
125  }
    // Two independent solves with different methods: each DERIVATIVE block gets its
    // own *_steadystate copy carrying the dt value of its own method. `c` is not
    // declared in this snippet; the test only checks the expression is carried over.
126  GIVEN("two STEADYSTATE solves") {
127  std::string nmodl_text = R"(
128  STATE {
129  Z[3]
130  x
131  }
132  BREAKPOINT {
133  SOLVE states0 STEADYSTATE derivimplicit
134  SOLVE states1 STEADYSTATE sparse
135  }
136  DERIVATIVE states0 {
137  Z'[0] = Z[1] - Z[2]
138  Z'[1] = Z[0] + 2*Z[2]
139  Z'[2] = Z[0]*Z[0] - 3.10
140  }
141  DERIVATIVE states1 {
142  x' = x + c
143  }
144  )";
145  std::string expected_text1 = R"(
146  DERIVATIVE states0 {
147  Z'[0] = Z[1]-Z[2]
148  Z'[1] = Z[0]+2*Z[2]
149  Z'[2] = Z[0]*Z[0]-3.10
150  })";
151  std::string expected_text2 = R"(
152  DERIVATIVE states1 {
153  x' = x+c
154  })";
155  std::string expected_text3 = R"(
156  DERIVATIVE states0_steadystate {
157  dt = 1e-09
158  Z'[0] = Z[1]-Z[2]
159  Z'[1] = Z[0]+2*Z[2]
160  Z'[2] = Z[0]*Z[0]-3.10
161  })";
162  std::string expected_text4 = R"(
163  DERIVATIVE states1_steadystate {
164  dt = 1000000000
165  x' = x+c
166  })";
        // result[0..1]: both SOLVEs, result[2..3]: originals, result[4..5]: generated copies
167  THEN("Construct DERIVATIVE blocks with steadystate solution & update SOLVE statements") {
168  auto result = run_steadystate_visitor(nmodl_text);
169  REQUIRE(result.size() == 6);
170  REQUIRE(result[0] == "SOLVE states0_steadystate METHOD derivimplicit");
171  REQUIRE(result[1] == "SOLVE states1_steadystate METHOD sparse");
172  REQUIRE(reindent_text(result[2]) == reindent_text(expected_text1));
173  REQUIRE(reindent_text(result[3]) == reindent_text(expected_text2));
174  REQUIRE(reindent_text(result[4]) == reindent_text(expected_text3));
175  REQUIRE(reindent_text(result[5]) == reindent_text(expected_text4));
176  }
177  }
178 }
test_utils.hpp
nmodl::parser::NmodlDriver
Class that binds all the pieces together for parsing an NMODL file.
Definition: nmodl_driver.hpp:67
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:234
run_steadystate_visitor
std::vector< std::string > run_steadystate_visitor(const std::string &text, const std::vector< AstNodeType > &ret_nodetypes={AstNodeType::SOLVE_BLOCK, AstNodeType::DERIVATIVE_BLOCK})
Definition: steadystate.cpp:34
nmodl::test_utils::reindent_text
std::string reindent_text(const std::string &text, int indent_level)
Reindent nmodl text for text-to-text comparison.
Definition: test_utils.cpp:53
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
loop_unroll_visitor.hpp
Unroll for loop in the AST.
constant_folder_visitor.hpp
Perform constant folding of integer/float/double expressions.
nmodl::ast::AstNodeType
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:166
kinetic_block_visitor.hpp
Visitor for kinetic block statements
steadystate_visitor.hpp
Visitor for STEADYSTATE solve statements
visitor_utils.hpp
Utility functions for visitors implementation.
program.hpp
Auto generated AST classes declaration.
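driver
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28
Note: this cross-reference is a name collision. The `driver` used in run_steadystate_visitor is presumably the local nmodl::parser::NmodlDriver declared there (listing line 40, missing from this extract), not this global UnitDriver.
nmodl::parser::UnitDriver::parse_string
bool parse_string(const std::string &input)
Parse units provided as a string (used for testing).
Definition: unit_driver.cpp:40
Note: this link is also mis-resolved. run_steadystate_visitor dereferences the result of driver.parse_string(text) and passes it to visit_program() as the AST, so it cannot be this bool-returning UnitDriver method.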
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
Traverse the node recursively and collect all nodes of the given types.
Definition: visitor_utils.cpp:206
checkparent_visitor.hpp
Visitor for checking parents of ast nodes
nmodl_driver.hpp
symtab_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
SCENARIO
SCENARIO("Solving ODEs with STEADYSTATE solve method", "[visitor][steadystate]")
Definition: steadystate.cpp:73