User Guide
loop_unroll_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "ast/all.hpp"
11 #include "parser/c11_driver.hpp"
12 #include "utils/logger.hpp"
14 
15 
16 namespace nmodl {
17 namespace visitor {
18 
19 /**
20  * \class IndexRemover
21  * \brief Helper visitor to replace index of array variable with integer
22  *
23  * When loop is unrolled, the index variable like `i` :
24  *
25  * ca[i] <-> ca[i+1]
26  *
27  * has type `Name` in the AST. This needs to be replaced with `Integer`
28  * for optimizations like constant folding. This pass look at name and
29  * binary expressions under index variables.
30  */
31 class IndexRemover: public AstVisitor {
32  private:
33  /// index variable name
34  std::string index;
35 
36  /// integer value of index variable
37  int value;
38 
39  /// true if we are visiting index variable
40  bool under_indexed_name = false;
41 
42  public:
43  IndexRemover(std::string index, int value)
44  : index(std::move(index))
45  , value(value) {}
46 
47  /// if expression we are visiting is `Name` then return new `Integer` node
48  std::shared_ptr<ast::Expression> replace_for_name(
49  const std::shared_ptr<ast::Expression>& node) const {
50  if (node->is_name()) {
51  auto name = std::dynamic_pointer_cast<ast::Name>(node);
52  if (name->get_node_name() == index) {
53  return std::make_shared<ast::Integer>(value, nullptr);
54  }
55  }
56  return node;
57  }
58 
60  node.visit_children(*this);
61  if (under_indexed_name) {
62  /// first recursively replaces children
63  /// replace lhs & rhs if they have matching index variable
64  auto lhs = replace_for_name(node.get_lhs());
65  auto rhs = replace_for_name(node.get_rhs());
66  node.set_lhs(std::move(lhs));
67  node.set_rhs(std::move(rhs));
68  }
69  }
70 
71  void visit_indexed_name(ast::IndexedName& node) override {
72  under_indexed_name = true;
73  node.visit_children(*this);
74  /// once all children are replaced, do the same for index
75  auto length = replace_for_name(node.get_length());
76  node.set_length(std::move(length));
77  under_indexed_name = false;
78  }
79 };
80 
81 
82 /// return underlying expression wrapped by WrappedExpression
83 static std::shared_ptr<ast::Expression> unwrap(const std::shared_ptr<ast::Expression>& expr) {
84  if (expr && expr->is_wrapped_expression()) {
85  const auto& e = std::dynamic_pointer_cast<ast::WrappedExpression>(expr);
86  return e->get_expression();
87  }
88  return expr;
89 }
90 
91 
92 /**
93  * Unroll given for loop
94  *
95  * \param node From loop node in the AST
96  * \return expression statement representing unrolled loop if successful otherwise \a nullptr
97  */
98 static std::shared_ptr<ast::ExpressionStatement> unroll_for_loop(
99  const std::shared_ptr<ast::FromStatement>& node) {
100  /// loop can be in the form of `FROM i=(0) TO (10)`
101  /// so first unwrap all elements of the loop
102  const auto& from = unwrap(node->get_from());
103  const auto& to = unwrap(node->get_to());
104  const auto& increment = unwrap(node->get_increment());
105 
106  /// we can unroll if iteration space of the loop is known
107  /// after constant folding start, end and increment must be known
108  if (!from->is_integer() || !to->is_integer() ||
109  (increment != nullptr && !increment->is_integer())) {
110  return nullptr;
111  }
112 
113  int start = std::dynamic_pointer_cast<ast::Integer>(from)->eval();
114  int end = std::dynamic_pointer_cast<ast::Integer>(to)->eval();
115  int step = 1;
116  if (increment != nullptr) {
117  step = std::dynamic_pointer_cast<ast::Integer>(increment)->eval();
118  }
119 
120  ast::StatementVector statements;
121  std::string index_var = node->get_node_name();
122  for (int i = start; i <= end; i += step) {
123  /// duplicate loop body and copy all statements to new vector
124  const auto new_block = std::unique_ptr<ast::StatementBlock>(
125  node->get_statement_block()->clone());
126  IndexRemover(index_var, i).visit_statement_block(*new_block);
127  statements.insert(statements.end(),
128  new_block->get_statements().begin(),
129  new_block->get_statements().end());
130  }
131 
132  /// create new statement representing unrolled loop
133  auto block = new ast::StatementBlock(std::move(statements));
134  return std::make_shared<ast::ExpressionStatement>(block);
135 }
136 
137 
138 /**
139  * Parse verbatim blocks and rename variable if it is used.
140  */
142  node.visit_children(*this);
143 
144  const auto& statements = node.get_statements();
145 
146  for (auto iter = statements.begin(); iter != statements.end(); ++iter) {
147  if ((*iter)->is_from_statement()) {
148  const auto& statement = std::dynamic_pointer_cast<ast::FromStatement>((*iter));
149 
150  /// check if any verbatim block exists
151  const auto& verbatim_blocks = collect_nodes(*statement, {ast::AstNodeType::VERBATIM});
152  if (!verbatim_blocks.empty()) {
153  logger->debug("LoopUnrollVisitor : can not unroll because of verbatim block");
154  continue;
155  }
156 
157  /// unroll loop, replace current statement on successfull unroll
158  const auto& new_statement = unroll_for_loop(statement);
159  if (new_statement != nullptr) {
160  node.reset_statement(iter, new_statement);
161 
162  const auto& before = to_nmodl(statement);
163  const auto& after = to_nmodl(new_statement);
164  logger->debug("LoopUnrollVisitor : \n {} \n unrolled to \n {}", before, after);
165  }
166  }
167  }
168 }
169 
170 } // namespace visitor
171 } // namespace nmodl
nmodl::ast::IndexedName::get_length
std::shared_ptr< Expression > get_length() const noexcept
Getter for member variable IndexedName::length.
Definition: indexed_name.hpp:176
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:227
nmodl::visitor::IndexRemover::index
std::string index
index variable name
Definition: loop_unroll_visitor.cpp:34
nmodl::ast::StatementBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:3158
nmodl::visitor::IndexRemover::under_indexed_name
bool under_indexed_name
true if we are visiting index variable
Definition: loop_unroll_visitor.cpp:40
nmodl::ast::StatementVector
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:298
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
loop_unroll_visitor.hpp
Unroll for loop in the AST.
nmodl::visitor::IndexRemover
Helper visitor to replace index of array variable with integer.
Definition: loop_unroll_visitor.cpp:31
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::ast::BinaryExpression::get_rhs
std::shared_ptr< Expression > get_rhs() const noexcept
Getter for member variable BinaryExpression::rhs.
Definition: binary_expression.hpp:179
nmodl::visitor::LoopUnrollVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
Unroll FROM loops in the statement block, skipping loops that contain a VERBATIM block.
Definition: loop_unroll_visitor.cpp:141
nmodl::ast::StatementBlock::get_statements
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
Definition: statement_block.hpp:221
nmodl::visitor::IndexRemover::replace_for_name
std::shared_ptr< ast::Expression > replace_for_name(const std::shared_ptr< ast::Expression > &node) const
if expression we are visiting is Name then return new Integer node
Definition: loop_unroll_visitor.cpp:48
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::BinaryExpression::set_rhs
void set_rhs(std::shared_ptr< Expression > &&rhs)
Setter for member variable BinaryExpression::rhs (rvalue reference)
Definition: ast.cpp:6612
nmodl::ast::IndexedName
Represents specific element of an array variable.
Definition: indexed_name.hpp:48
nmodl::ast::BinaryExpression::set_lhs
void set_lhs(std::shared_ptr< Expression > &&lhs)
Setter for member variable BinaryExpression::lhs (rvalue reference)
Definition: ast.cpp:6586
c11_driver.hpp
nmodl::visitor::AstVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Definition: ast_visitor.cpp:124
nmodl::visitor::AstVisitor
Concrete visitor for all AST classes.
Definition: ast_visitor.hpp:37
nmodl::visitor::IndexRemover::visit_binary_expression
void visit_binary_expression(ast::BinaryExpression &node) override
visit node of type ast::BinaryExpression
Definition: loop_unroll_visitor.cpp:59
nmodl::visitor::unroll_for_loop
static std::shared_ptr< ast::ExpressionStatement > unroll_for_loop(const std::shared_ptr< ast::FromStatement > &node)
Unroll given for loop.
Definition: loop_unroll_visitor.cpp:98
nmodl::ast::IndexedName::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:1007
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:205
nmodl::visitor::IndexRemover::value
int value
integer value of index variable
Definition: loop_unroll_visitor.cpp:37
nmodl::ast::BinaryExpression::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:6505
nmodl::ast::StatementBlock
Represents block encapsulating list of statements.
Definition: statement_block.hpp:53
nmodl::ast::StatementBlock::reset_statement
void reset_statement(StatementVector::const_iterator position, Statement *n)
Reset member to statements.
Definition: ast.cpp:3138
logger.hpp
Implement logger based on spdlog library.
nmodl::ast::AstNodeType::VERBATIM
@ VERBATIM
type of ast::Verbatim
nmodl::visitor::IndexRemover::visit_indexed_name
void visit_indexed_name(ast::IndexedName &node) override
visit node of type ast::IndexedName
Definition: loop_unroll_visitor.cpp:71
nmodl::visitor::unwrap
static std::shared_ptr< ast::Expression > unwrap(const std::shared_ptr< ast::Expression > &expr)
return underlying expression wrapped by WrappedExpression
Definition: loop_unroll_visitor.cpp:83
nmodl::ast::BinaryExpression::get_lhs
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable BinaryExpression::lhs.
Definition: binary_expression.hpp:161
nmodl::ast::BinaryExpression
Represents binary expression in the NMODL.
Definition: binary_expression.hpp:52
all.hpp
Auto generated AST classes declaration.
nmodl::ast::IndexedName::set_length
void set_length(std::shared_ptr< Expression > &&length)
Setter for member variable IndexedName::length (rvalue reference)
Definition: ast.cpp:1096
nmodl::visitor::IndexRemover::IndexRemover
IndexRemover(std::string index, int value)
Definition: loop_unroll_visitor.cpp:43