User Guide
loop_unroll.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <catch2/catch_test_macros.hpp>
9 
10 #include "ast/program.hpp"
11 #include "parser/nmodl_driver.hpp"
19 
20 using namespace nmodl;
21 using namespace visitor;
22 using namespace test;
23 using namespace test_utils;
24 
25 using ast::AstNodeType;
27 
28 //=============================================================================
29 // Loop unroll tests
30 //=============================================================================
31 
32 std::string run_loop_unroll_visitor(const std::string& text) {
34  const auto& ast = driver.parse_string(text);
35 
36  SymtabVisitor().visit_program(*ast);
37  ConstantFolderVisitor().visit_program(*ast);
38  LoopUnrollVisitor().visit_program(*ast);
39  ConstantFolderVisitor().visit_program(*ast);
40 
41  // check that, after visitor rearrangement, parents are still up-to-date
42  CheckParentVisitor().check_ast(*ast);
43 
44  return to_nmodl(ast, {AstNodeType::DEFINE});
45 }
46 
// Catch2 scenario for LoopUnrollVisitor. Each GIVEN feeds an NMODL snippet
// through run_loop_unroll_visitor and compares the printed result with the
// expected NMODL text. reindent_text is applied to both sides so that only
// the content, not the indentation, is compared. The expected outputs never
// contain the DEFINE line because run_loop_unroll_visitor prints the program
// with DEFINE nodes excluded.
47 SCENARIO("Perform loop unrolling of FROM construct", "[visitor][unroll]") {
48  GIVEN("A loop with known iteration space") {
    // Bounds are inclusive: with N = 2, `FROM i=0 TO N` yields i = 0, 1, 2.
    // The second loop checks that bound/index expressions such as
    // `(0+(0+1))`, `(N+2-1)` and `(i+0)` are folded; the KINETIC block checks
    // that reaction statements are unrolled too (i = 1..3).
    // NOTE(review): indices up to 4 exceed the declared size of `x[N]`; the
    // test appears to check only the textual unroll — confirm intentional.
49  std::string input_nmodl = R"(
50  DEFINE N 2
51  PROCEDURE rates() {
52  LOCAL x[N]
53  FROM i=0 TO N {
54  x[i] = x[i] + 11
55  }
56  FROM i=(0+(0+1)) TO (N+2-1) {
57  x[(i+0)] = x[i+1] + 11
58  }
59  }
60  KINETIC state {
61  FROM i=1 TO N+1 {
62  ~ ca[i] <-> ca[i+1] (DFree*frat[i+1]*1(um), DFree*frat[i+1]*1(um))
63  }
64  }
65  )";
    // Each unrolled loop becomes a braced block with one statement per
    // iteration.
66  std::string output_nmodl = R"(
67  PROCEDURE rates() {
68  LOCAL x[N]
69  {
70  x[0] = x[0]+11
71  x[1] = x[1]+11
72  x[2] = x[2]+11
73  }
74  {
75  x[1] = x[2]+11
76  x[2] = x[3]+11
77  x[3] = x[4]+11
78  }
79  }
80 
81  KINETIC state {
82  {
83  ~ ca[1] <-> ca[2] (DFree*frat[2]*1(um), DFree*frat[2]*1(um))
84  ~ ca[2] <-> ca[3] (DFree*frat[3]*1(um), DFree*frat[3]*1(um))
85  ~ ca[3] <-> ca[4] (DFree*frat[4]*1(um), DFree*frat[4]*1(um))
86  }
87  }
88  )";
89  THEN("Loop body gets correctly unrolled") {
90  auto result = run_loop_unroll_visitor(input_nmodl);
91  REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
92  }
93  }
94 
    // Nested loops: with N = 1 the outer loop (i = 0, 1) and the inner loop
    // (j = 1, 2) are both unrolled, giving one inner block per outer
    // iteration.
95  GIVEN("A nested loop") {
96  std::string input_nmodl = R"(
97  DEFINE N 1
98  PROCEDURE rates() {
99  LOCAL x[N]
100  FROM i=0 TO N {
101  FROM j=1 TO N+1 {
102  x[i] = x[i+j] + 1
103  }
104  }
105  }
106  )";
107  std::string output_nmodl = R"(
108  PROCEDURE rates() {
109  LOCAL x[N]
110  {
111  {
112  x[0] = x[1]+1
113  x[0] = x[2]+1
114  }
115  {
116  x[1] = x[2]+1
117  x[1] = x[3]+1
118  }
119  }
120  }
121  )";
122  THEN("Loop get unrolled recursively") {
123  auto result = run_loop_unroll_visitor(input_nmodl);
124  REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
125  }
126  }
127 
128 
    // Partial unrolling. The outer loop has foldable bounds (i = 0, 1) and is
    // unrolled. The inner loop's upper bound `k` is unknown, so it is kept as
    // a FROM loop and `x[0+k]` stays unfolded. The loop containing VERBATIM
    // is kept as-is even though its bounds are known.
129  GIVEN("Loop with verbatim and unknown iteration space") {
130  std::string input_nmodl = R"(
131  DEFINE N 1
132  PROCEDURE rates() {
133  LOCAL x[N]
134  FROM i=((0+0)) TO (((N+0))) {
135  FROM j=1 TO k {
136  x[i] = x[i+k] + 1
137  }
138  }
139  FROM i=0 TO N {
140  VERBATIM ENDVERBATIM
141  }
142  }
143  )";
144  std::string output_nmodl = R"(
145  PROCEDURE rates() {
146  LOCAL x[N]
147  {
148  FROM j = 1 TO k {
149  x[0] = x[0+k]+1
150  }
151  FROM j = 1 TO k {
152  x[1] = x[1+k]+1
153  }
154  }
155  FROM i = 0 TO N {
156  VERBATIM ENDVERBATIM
157  }
158  }
159  )";
160  THEN("Only some loops get unrolled") {
161  auto result = run_loop_unroll_visitor(input_nmodl);
162  REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
163  }
164  }
165 }
test_utils.hpp
nmodl::parser::NmodlDriver
Class that binds all pieces together for parsing an NMODL file.
Definition: nmodl_driver.hpp:67
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:234
nmodl::test_utils::reindent_text
std::string reindent_text(const std::string &text, int indent_level)
Reindent nmodl text for text-to-text comparison.
Definition: test_utils.cpp:53
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
run_loop_unroll_visitor
std::string run_loop_unroll_visitor(const std::string &text)
Definition: loop_unroll.cpp:32
loop_unroll_visitor.hpp
Unroll `FROM` loops in the AST.
constant_folder_visitor.hpp
Perform constant folding of integer/float/double expressions.
nmodl::ast::AstNodeType
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:166
visitor_utils.hpp
Utility functions for visitors implementation.
program.hpp
Auto generated AST classes declaration.
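driver
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28
Note: this cross-reference appears to be a Doxygen mislink. `UnitDriver::parse_string` returns `bool`, which `run_loop_unroll_visitor` could not dereference. The `driver` used in loop_unroll.cpp is presumably a local `nmodl::parser::NmodlDriver`.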
nmodl::parser::UnitDriver::parse_string
bool parse_string(const std::string &input)
Parse units provided as a string (used for testing).
Definition: unit_driver.cpp:40
checkparent_visitor.hpp
Visitor for checking parents of AST nodes.
nmodl_driver.hpp
symtab_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
SCENARIO
SCENARIO("Perform loop unrolling of FROM construct", "[visitor][unroll]")
Definition: loop_unroll.cpp:47