User Guide
inline_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "ast/all.hpp"
11 #include "ast/ast_decl.hpp"
12 #include "parser/c11_driver.hpp"
13 #include "utils/logger.hpp"
17 
18 
19 namespace nmodl {
20 namespace visitor {
21 
22 using namespace ast;
23 
25  bool to_inline = true;
26  const auto& statements = block.get_statements();
27  for (const auto& statement: statements) {
28  /// inlining is disabled if function/procedure contains table or lag statement
29  if (statement->is_table_statement() || statement->is_lag_statement()) {
30  to_inline = false;
31  break;
32  }
33  // verbatim blocks with return statement are not safe to inline
34  // especially for net_receive block
35  if (statement->is_verbatim()) {
36  const auto node = dynamic_cast<const Verbatim*>(statement.get());
37  assert(node);
38  auto text = node->get_statement()->eval();
40  driver.scan_string(text);
41  if (driver.has_token("return")) {
42  to_inline = false;
43  break;
44  }
45  }
46  }
47  return to_inline;
48 }
49 
50 void InlineVisitor::add_return_variable(StatementBlock& block, std::string& varname) {
51  auto lhs = new Name(new String(varname));
52  auto rhs = new Integer(0, nullptr);
53  auto expression = new BinaryExpression(lhs, BinaryOperator(BOP_ASSIGN), rhs);
54  auto statement = std::make_shared<ExpressionStatement>(expression);
55  block.emplace_back_statement(statement);
56 }
57 
58 /** We can replace statement if the entire statement itself is a function call.
59  * In this case we check if:
60  * - given statement is expression statement
61  * - if expression is wrapped expression
62  * - if wrapped expression is a function call
63  *
64  * \todo Add method to ast itself to simplify this implementation
65  */
66 bool InlineVisitor::can_replace_statement(const std::shared_ptr<Statement>& statement) {
67  if (!statement->is_expression_statement()) {
68  return false;
69  }
70 
71  bool to_replace = false;
72  auto es = dynamic_cast<ExpressionStatement*>(statement.get());
73  assert(es);
74  auto e = es->get_expression();
75  if (e->is_wrapped_expression()) {
76  auto wrapped_expression = dynamic_cast<WrappedExpression*>(e.get());
77  assert(wrapped_expression);
78  if (wrapped_expression->get_expression()->is_function_call()) {
79  // if caller is external function (i.e. neuron function) don't replace it
80  const auto& function_call = std::static_pointer_cast<FunctionCall>(
81  wrapped_expression->get_expression());
82  const auto& function_name = function_call->get_node_name();
83  const auto& symbol = program_symtab->lookup_in_scope(function_name);
84  to_replace = !symbol->is_external_variable();
85  }
86  }
87  return to_replace;
88 }
89 
91  const ArgumentVector& callee_parameters,
92  const ExpressionVector& caller_expressions) {
93  /// nothing to inline if no arguments for function call
94  if (caller_expressions.empty()) {
95  return;
96  }
97 
98  size_t counter = 0;
99  const auto& statements = inlined_block.get_statements();
100 
101  for (const auto& argument: callee_parameters) {
102  auto name = argument->get_name()->clone();
103  auto old_name = name->get_node_name();
104  auto new_name = get_new_name(old_name, "in", inlined_variables);
105  name->set_name(new_name);
106 
107  /// for argument add new variable to local statement
108  add_local_variable(inlined_block, name);
109 
110  /// variables in cloned block needs to be renamed
111  RenameVisitor visitor(old_name, new_name);
112  inlined_block.visit_children(visitor);
113 
114  auto lhs = new VarName(name->clone(), nullptr, nullptr);
115  auto rhs = caller_expressions.at(counter)->clone();
116 
117  /// create assignment statement and insert after the local variables
118  auto expression = new BinaryExpression(lhs, BinaryOperator(ast::BOP_ASSIGN), rhs);
119  auto statement = std::make_shared<ExpressionStatement>(expression);
120  inlined_block.insert_statement(statements.begin() +
121  static_cast<std::ptrdiff_t>(counter + 1ul),
122  statement);
123  counter++;
124  }
125 }
126 
127 
129  ast::FunctionCall& node,
130  ast::StatementBlock& caller) {
131  const auto& function_name = callee.get_node_name();
132 
133  /// do nothing if we can't inline given procedure/function
134  if (!can_inline_block(*callee.get_statement_block())) {
135  logger->warn("Can not inline function call to {}", function_name);
136  return false;
137  }
138 
139  /// make sure to rename conflicting local variable in caller block
140  /// because in case of procedure inlining they can conflict with
141  /// global variables used in procedure being inlined
143  v.visit_statement_block(caller);
144 
145  const auto& caller_arguments = node.get_arguments();
146  std::string new_varname = get_new_name(function_name, "in", inlined_variables);
147 
148  /// check if caller statement could be replaced
149  bool to_replace = can_replace_statement(caller_statement);
150 
151  /// need to add local variable for function calls or for procedure call if it is part of
152  /// expression (standalone procedure calls don't need return statement)
153  if (callee.is_function_block() || !to_replace) {
154  /// create new variable which will be used for returning value from inlined block
155  auto name = new ast::Name(new ast::String(new_varname));
156  ModToken tok;
157  name->set_token(tok);
158 
159  auto local_list_statement = get_local_list_statement(caller);
160  /// each block should already have local statement
161  if (local_list_statement == nullptr) {
162  throw std::logic_error("got local statement as nullptr");
163  }
164  local_list_statement->emplace_back_local_var(std::make_shared<ast::LocalVar>(name));
165  }
166 
167  /// get a copy of function/procedure body
168  auto inlined_block = std::unique_ptr<ast::StatementBlock>(
169  callee.get_statement_block()->clone());
170 
171  /// function definition has function name as return value. we have to rename
172  /// it with new variable name
173  RenameVisitor visitor(function_name, new_varname);
174  inlined_block->visit_children(visitor);
175 
176  inlined_block->set_symbol_table(nullptr);
177 
178  /// each argument is added as new assignment statement
179  inline_arguments(*inlined_block, callee.get_parameters(), caller_arguments);
180 
181  /// to return value from procedure we have to add new variable
182  if (callee.is_procedure_block() && !to_replace) {
183  add_return_variable(*inlined_block, new_varname);
184  }
185 
186  auto statement = new ast::ExpressionStatement(std::move(inlined_block));
187 
188  if (to_replace) {
189  replaced_statements[caller_statement] = statement;
190  } else {
191  inlined_statements[caller_statement].push_back(
192  std::shared_ptr<ast::ExpressionStatement>(statement));
193  }
194 
195  /// variable name which will replace the function call that we just inlined
196  replaced_fun_calls[&node] = new_varname;
197  return true;
198 }
199 
200 
202  /// argument can be function call itself
203  node.visit_children(*this);
204 
205  const auto& function_name = node.get_name()->get_node_name();
206  auto symbol = program_symtab->lookup_in_scope(function_name);
207 
208  /// nothing to do if called function is not defined or it's external
209  if (symbol == nullptr || symbol->is_external_variable()) {
210  return;
211  }
212 
213  auto nodes = symbol->get_nodes_by_type(
214  {AstNodeType::FUNCTION_BLOCK, AstNodeType::PROCEDURE_BLOCK});
215  if (nodes.empty()) {
216  throw std::runtime_error("symbol table doesn't have ast node for " + function_name);
217  }
218  auto f_block = nodes.front();
219 
220  /// first inline called function
221  f_block->visit_children(*this);
222 
223  bool inlined = false;
224 
225  auto block = dynamic_cast<ast::Block*>(f_block);
226  assert(block);
227  inlined = inline_function_call(*block, node, *caller_block);
228 
229  if (inlined) {
230  symbol->mark_inlined();
231  }
232 }
233 
235  /** While inlining we have to add new statements before call site.
236  * In order to return result we also have to add new local variable
237  * to the caller block. Hence we have to keep track of caller block,
238  * caller block's symbol table and caller statement.
239  */
240  caller_block = &node;
241  statementblock_stack.push(&node);
242 
243  /** Add empty local statement at the begining of block if already doesn't exist.
244  * Why? When we iterate over statements and inline function call, we have to add
245  * local variable to return the result. As we can't modify vector while iterating,
246  * we pre-add local statement without any local variables. If inlining pass doesn't
247  * add any variable then this statement will be removed.
248  */
249  add_local_statement(node);
250 
251  const auto& statements = node.get_statements();
252 
253  for (const auto& statement: statements) {
254  caller_statement = statement;
255  statement_stack.push(statement);
256  caller_statement->visit_children(*this);
257  statement_stack.pop();
258  }
259 
260  /// each block should already have local statement
261  auto local_list_statement = get_local_list_statement(node);
262  if (local_list_statement->get_variables().empty()) {
263  node.erase_statement(statements.begin());
264  }
265 
266  /// check if any statement is candidate for replacement due to inlining
267  /// this is typicall case of procedure calls
268  for (auto it = statements.begin(); it < statements.end(); ++it) {
269  const auto& statement = *it;
270  if (replaced_statements.find(statement) != replaced_statements.end()) {
271  node.reset_statement(it, replaced_statements[statement]);
272  }
273  }
274 
275  /// all statements from inlining needs to be added before the caller statement
276  for (auto& element: inlined_statements) {
277  auto it = std::find(statements.begin(), statements.end(), element.first);
278  if (it != statements.end()) {
279  node.insert_statement(it, element.second, element.second.begin(), element.second.end());
280  element.second.clear();
281  }
282  }
283 
284  /** Restore the caller context : consider call chain A() -> B() -> none.
285  * When we finishes processing B's statements, even we pop the stack,
286  * caller_* variables still point to B's context. Hence first we have
287  * to pop elements and then use top() to get context of A(). We have to
288  * check for non-empty() stack because if there is only A() -> none then
289  * stack is already empty.
290  */
291  statementblock_stack.pop();
292 
293  if (!statement_stack.empty()) {
294  caller_statement = statement_stack.top();
295  }
296  if (!statementblock_stack.empty()) {
297  caller_block = statementblock_stack.top();
298  }
299 }
300 
301 /** Visit all wrapped expressions which can contain function calls.
302  * If a function call is replaced then the wrapped expression is
303  * also replaced with new variable node from the inlining result.
304  */
306  node.visit_children(*this);
307  const auto& e = node.get_expression();
308  if (e->is_function_call()) {
309  auto expression = dynamic_cast<FunctionCall*>(e.get());
310  // if node is inlined, replace it with corresponding variable name
311  // and remove entry from the bookkeeping map
312  if (replaced_fun_calls.find(expression) != replaced_fun_calls.end()) {
313  auto var = replaced_fun_calls[expression];
314  node.set_expression(std::make_shared<Name>(new String(var)));
315  replaced_fun_calls.erase(expression);
316  }
317  }
318 }
319 
321  program_symtab = node.get_symbol_table();
322  if (program_symtab == nullptr) {
323  throw std::runtime_error("Program node doesn't have symbol table");
324  }
325  node.visit_children(*this);
326 }
327 
328 } // namespace visitor
329 } // namespace nmodl
nmodl::ast::Verbatim
Represents a C code block.
Definition: verbatim.hpp:38
nmodl::ast::WrappedExpression::get_expression
std::shared_ptr< Expression > get_expression() const noexcept
Getter for member variable WrappedExpression::expression.
Definition: wrapped_expression.hpp:143
nmodl::visitor::InlineVisitor::inline_arguments
void inline_arguments(ast::StatementBlock &inlined_block, const ast::ArgumentVector &callee_parameters, const ast::ExpressionVector &caller_expressions)
add assignment statements into given statement block to inline arguments
Definition: inline_visitor.cpp:90
nmodl::ast::StatementBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:3158
nmodl::ast::FunctionCall::get_name
std::shared_ptr< Name > get_name() const noexcept
Getter for member variable FunctionCall::name.
Definition: function_call.hpp:157
ast_decl.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::visitor::get_local_list_statement
std::shared_ptr< ast::LocalListStatement > get_local_list_statement(const StatementBlock &node)
Return pointer to local statement in the given block, otherwise nullptr.
Definition: visitor_utils.cpp:73
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::visitor::InlineVisitor::can_replace_statement
bool can_replace_statement(const std::shared_ptr< ast::Statement > &statement)
true if statement can be replaced with inlined body this is possible for standalone function/procedur...
Definition: inline_visitor.cpp:66
nmodl::ast::StatementBlock::insert_statement
StatementVector::const_iterator insert_statement(StatementVector::const_iterator position, const std::shared_ptr< Statement > &n)
Insert member to statements.
Definition: ast.cpp:3130
nmodl::ast::Ast::is_procedure_block
virtual bool is_procedure_block() const noexcept
Check if the ast node is an instance of ast::ProcedureBlock.
Definition: ast.cpp:144
nmodl::ast::VarName
Represents a variable.
Definition: var_name.hpp:43
nmodl::ast::Integer
Represents an integer variable.
Definition: integer.hpp:49
nmodl::ast::BOP_ASSIGN
@ BOP_ASSIGN
=
Definition: ast_common.hpp:59
nmodl::ast::Ast::is_function_block
virtual bool is_function_block() const noexcept
Check if the ast node is an instance of ast::FunctionBlock.
Definition: ast.cpp:142
nmodl::visitor::InlineVisitor::inline_function_call
bool inline_function_call(ast::Block &callee, ast::FunctionCall &node, ast::StatementBlock &caller)
inline function/procedure into caller block
Definition: inline_visitor.cpp:128
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::parser::CDriver
Class that binds all pieces together for parsing C verbatim blocks.
Definition: c11_driver.hpp:37
nmodl::visitor::InlineVisitor::can_inline_block
static bool can_inline_block(const ast::StatementBlock &block)
true if given statement block can be inlined
Definition: inline_visitor.cpp:24
nmodl::ast::StatementBlock::get_statements
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
Definition: statement_block.hpp:221
nmodl::ast::ExpressionStatement
TODO.
Definition: expression_statement.hpp:38
nmodl::ast::Block
Base class for all block scoped nodes.
Definition: block.hpp:41
nmodl::visitor::InlineVisitor::visit_function_call
void visit_function_call(ast::FunctionCall &node) override
visit node of type ast::FunctionCall
Definition: inline_visitor.cpp:201
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::ArgumentVector
std::vector< std::shared_ptr< Argument > > ArgumentVector
Definition: ast_decl.hpp:312
driver
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28
nmodl::ast::StatementBlock::erase_statement
StatementVector::const_iterator erase_statement(StatementVector::const_iterator first)
Erase member from statements.
Definition: ast.cpp:3092
nmodl::ast::FunctionCall
TODO.
Definition: function_call.hpp:38
c11_driver.hpp
nmodl::ast::Program::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:12902
nmodl::visitor::InlineVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: inline_visitor.cpp:320
nmodl::ast::FunctionCall::get_arguments
const ExpressionVector & get_arguments() const noexcept
Getter for member variable FunctionCall::arguments.
Definition: function_call.hpp:166
local_var_rename_visitor.hpp
Visitor to rename local variables conflicting with global scope
nmodl::ast::WrappedExpression::set_expression
void set_expression(std::shared_ptr< Expression > &&expression)
Setter for member variable WrappedExpression::expression (rvalue reference)
Definition: ast.cpp:13674
nmodl::ast::StatementBlock::emplace_back_statement
void emplace_back_statement(Statement *n)
Add member to statements by raw pointer.
Definition: ast.cpp:3073
nmodl::visitor::RenameVisitor
Blindly rename given variable to new name
Definition: rename_visitor.hpp:43
nmodl::visitor::add_local_statement
void add_local_statement(StatementBlock &node)
Add an empty local statement to the given block if one doesn't already exist.
Definition: visitor_utils.cpp:83
nmodl::visitor::get_new_name
std::string get_new_name(const std::string &name, const std::string &suffix, std::map< std::string, int > &variables)
Return new name variable by appending _suffix_COUNT where COUNT is number of times the given variable...
Definition: visitor_utils.cpp:61
nmodl::ast::StatementBlock
Represents block encapsulating list of statements.
Definition: statement_block.hpp:53
nmodl::ast::ExpressionVector
std::vector< std::shared_ptr< Expression > > ExpressionVector
Definition: ast_decl.hpp:299
nmodl::visitor::InlineVisitor::add_return_variable
static void add_return_variable(ast::StatementBlock &block, std::string &varname)
add assignment statement at end of block (to use as a return statement in case of procedure blocks)
Definition: inline_visitor.cpp:50
nmodl::ast::StatementBlock::reset_statement
void reset_statement(StatementVector::const_iterator position, Statement *n)
Reset member to statements.
Definition: ast.cpp:3138
nmodl::visitor::InlineVisitor::visit_wrapped_expression
void visit_wrapped_expression(ast::WrappedExpression &node) override
Visit all wrapped expressions which can contain function calls.
Definition: inline_visitor.cpp:305
nmodl::ast::BinaryOperator
Operator used in ast::BinaryExpression.
Definition: binary_operator.hpp:38
nmodl::ast::FunctionCall::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:7068
logger.hpp
Implement logger based on spdlog library.
nmodl::visitor::add_local_variable
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
Definition: visitor_utils.cpp:92
nmodl::ast::Program::get_symbol_table
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
Definition: program.hpp:153
inline_visitor.hpp
Visitor to inline local procedure and function calls
nmodl::ast::Name
Represents a name.
Definition: name.hpp:44
nmodl::ast::Ast::get_statement_block
virtual std::shared_ptr< StatementBlock > get_statement_block() const
Return associated statement block for the AST node.
Definition: ast.cpp:32
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
rename_visitor.hpp
Blindly rename given variable to new name
nmodl::ast::WrappedExpression::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:13616
nmodl::visitor::LocalVarRenameVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
rename conflicting variables in the statement block and all of its children
Definition: local_var_rename_visitor.cpp:23
nmodl::ast::Block::get_parameters
virtual const ArgumentVector & get_parameters() const
Definition: block.hpp:50
nmodl::ast::BinaryExpression
Represents binary expression in the NMODL.
Definition: binary_expression.hpp:52
nmodl::visitor::LocalVarRenameVisitor
Visitor to rename local variables conflicting with global scope
Definition: local_var_rename_visitor.hpp:62
nmodl::ModToken
Represent token returned by scanner.
Definition: modtoken.hpp:50
nmodl::ast::Ast::get_node_name
virtual std::string get_node_name() const
Return name of the node.
Definition: ast.cpp:28
nmodl::ast::String
Represents a string.
Definition: string.hpp:52
all.hpp
Auto generated AST classes declaration.
nmodl::ast::WrappedExpression
Wrap any other expression type.
Definition: wrapped_expression.hpp:38
nmodl::visitor::InlineVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Definition: inline_visitor.cpp:234