User Guide
loop_unroll_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 
11 #include "ast/all.hpp"
12 #include "parser/c11_driver.hpp"
13 #include "utils/logger.hpp"
16 
17 namespace nmodl {
18 namespace visitor {
19 
20 
21 /// return underlying expression wrapped by WrappedExpression
22 static std::shared_ptr<ast::Expression> unwrap(const std::shared_ptr<ast::Expression>& expr) {
23  if (expr && expr->is_wrapped_expression()) {
24  const auto& e = std::dynamic_pointer_cast<ast::WrappedExpression>(expr);
25  return e->get_expression();
26  }
27  return expr;
28 }
29 
30 
31 /**
32  * Unroll given for loop
33  *
34  * \param node From loop node in the AST
35  * \return expression statement representing unrolled loop if successful otherwise \a nullptr
36  */
37 static std::shared_ptr<ast::ExpressionStatement> unroll_for_loop(
38  const std::shared_ptr<ast::FromStatement>& node) {
39  /// loop can be in the form of `FROM i=(0) TO (10)`
40  /// so first unwrap all elements of the loop
41  const auto& from = unwrap(node->get_from());
42  const auto& to = unwrap(node->get_to());
43  const auto& increment = unwrap(node->get_increment());
44 
45  /// we can unroll if iteration space of the loop is known
46  /// after constant folding start, end and increment must be known
47  if (!from->is_integer() || !to->is_integer() ||
48  (increment != nullptr && !increment->is_integer())) {
49  return nullptr;
50  }
51 
52  int start = std::dynamic_pointer_cast<ast::Integer>(from)->eval();
53  int end = std::dynamic_pointer_cast<ast::Integer>(to)->eval();
54  int step = 1;
55  if (increment != nullptr) {
56  step = std::dynamic_pointer_cast<ast::Integer>(increment)->eval();
57  }
58 
59  ast::StatementVector statements;
60  std::string index_var = node->get_node_name();
61  for (int i = start; i <= end; i += step) {
62  /// duplicate loop body and copy all statements to new vector
63  const auto new_block = std::unique_ptr<ast::StatementBlock>(
64  node->get_statement_block()->clone());
65  IndexRemover(index_var, i).visit_statement_block(*new_block);
66  statements.insert(statements.end(),
67  new_block->get_statements().begin(),
68  new_block->get_statements().end());
69  }
70 
71  /// create new statement representing unrolled loop
72  auto block = new ast::StatementBlock(std::move(statements));
73  return std::make_shared<ast::ExpressionStatement>(block);
74 }
75 
76 
77 /**
78  * Parse verbatim blocks and rename variable if it is used.
79  */
81  node.visit_children(*this);
82 
83  const auto& statements = node.get_statements();
84 
85  for (auto iter = statements.begin(); iter != statements.end(); ++iter) {
86  if ((*iter)->is_from_statement()) {
87  const auto& statement = std::dynamic_pointer_cast<ast::FromStatement>((*iter));
88 
89  /// check if any verbatim block exists
90  const auto& verbatim_blocks = collect_nodes(*statement, {ast::AstNodeType::VERBATIM});
91  if (!verbatim_blocks.empty()) {
92  logger->debug("LoopUnrollVisitor : can not unroll because of verbatim block");
93  continue;
94  }
95 
96  /// unroll loop, replace current statement on successfull unroll
97  const auto& new_statement = unroll_for_loop(statement);
98  if (new_statement != nullptr) {
99  node.reset_statement(iter, new_statement);
100 
101  const auto& before = to_nmodl(statement);
102  const auto& after = to_nmodl(new_statement);
103  logger->debug("LoopUnrollVisitor : \n {} \n unrolled to \n {}", before, after);
104  }
105  }
106  }
107 }
108 
109 } // namespace visitor
110 } // namespace nmodl
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:234
nmodl::ast::StatementBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children, i.e. the member nodes of this node, using the given visitor.
Definition: ast.cpp:3162
nmodl::ast::StatementVector
std::vector< std::shared_ptr< Statement > > StatementVector
Definition: ast_decl.hpp:302
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
loop_unroll_visitor.hpp
Unroll for loop in the AST.
nmodl::visitor::IndexRemover
Helper visitor to replace index of array variable with integer.
Definition: index_remover.hpp:28
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::visitor::LoopUnrollVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
Unroll FROM loops in the statement block whose bounds and step are known integer constants.
Definition: loop_unroll_visitor.cpp:80
nmodl::ast::StatementBlock::get_statements
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
Definition: statement_block.hpp:221
visitor_utils.hpp
Utility functions for visitors implementation.
c11_driver.hpp
nmodl::visitor::AstVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Definition: ast_visitor.cpp:124
nmodl::visitor::unroll_for_loop
static std::shared_ptr< ast::ExpressionStatement > unroll_for_loop(const std::shared_ptr< ast::FromStatement > &node)
Unroll given for loop.
Definition: loop_unroll_visitor.cpp:37
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:206
index_remover.hpp
nmodl::ast::StatementBlock
Represents block encapsulating list of statements.
Definition: statement_block.hpp:53
nmodl::ast::StatementBlock::reset_statement
void reset_statement(StatementVector::const_iterator position, Statement *n)
Reset member to statements.
Definition: ast.cpp:3142
logger.hpp
Implement logger based on spdlog library.
nmodl::ast::AstNodeType::VERBATIM
@ VERBATIM
type of ast::Verbatim
nmodl::visitor::unwrap
static std::shared_ptr< ast::Expression > unwrap(const std::shared_ptr< ast::Expression > &expr)
return underlying expression wrapped by WrappedExpression
Definition: loop_unroll_visitor.cpp:22
all.hpp
Auto generated AST classes declaration.