User Guide
constant_folder_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "ast/all.hpp"
11 #include "utils/logger.hpp"
13 
14 
15 namespace nmodl {
16 namespace visitor {
17 
18 /// check if given expression is a number
19 /// note that the DEFINE node is already expanded to integer
20 static inline bool is_number(const std::shared_ptr<ast::Expression>& node) {
21  return node->is_integer() || node->is_double() || node->is_float();
22 }
23 
24 /// get value of a number node
25 /// TODO : eval method can be added to virtual base class
26 static double get_value(const std::shared_ptr<ast::Expression>& node) {
27  if (node->is_integer()) {
28  return std::dynamic_pointer_cast<ast::Integer>(node)->eval();
29  } else if (node->is_float()) {
30  return std::dynamic_pointer_cast<ast::Float>(node)->to_double();
31  } else if (node->is_double()) {
32  return std::dynamic_pointer_cast<ast::Double>(node)->to_double();
33  }
34  throw std::runtime_error("Invalid type passed to is_number()");
35 }
36 
37 /// operators that currently implemented
38 static inline bool supported_operator(ast::BinaryOp op) {
39  return op == ast::BOP_ADDITION || op == ast::BOP_SUBTRACTION || op == ast::BOP_MULTIPLICATION ||
40  op == ast::BOP_DIVISION;
41 }
42 
43 /// Evaluate binary operation
44 /// TODO : add support for other binary operators like ^ (pow)
45 static double compute(double lhs, ast::BinaryOp op, double rhs) {
46  switch (op) {
47  case ast::BOP_ADDITION:
48  return lhs + rhs;
49 
51  return lhs - rhs;
52 
54  return lhs * rhs;
55 
56  case ast::BOP_DIVISION:
57  return lhs / rhs;
58 
59  default:
60  throw std::logic_error("Invalid binary operator in constant folding");
61  }
62 }
63 
64 /**
65  * Visit parenthesis expression and simplify it
66  * @param node AST node representing an expression with parenthesis
67  *
68  * AST could have expression like (1+2). In this case, it has following
69  * form in the AST :
70  *
71  * parenthesis_exp => wrapped_expr => binary_expression => ...
72  *
73  * To make constant folding simple, we can remove intermediate wrapped_expr
74  * and directly replace binary_expression inside parenthesis_exp :
75  *
76  * parenthesis_exp => binary_expression => ...
77  */
79  node.visit_children(*this);
80  auto expr = node.get_expression();
81  if (expr->is_wrapped_expression()) {
82  auto e = std::dynamic_pointer_cast<ast::WrappedExpression>(expr);
83  node.set_expression(e->get_expression());
84  }
85 }
86 
87 /**
88  * Visit wrapped node type and perform constant folding
89  * @param node AST node that wrap other node types
90  *
91  * MOD file has expressions like
92  *
93  * a = 1 + 2
94  * DEFINE NN 10
95  * FROM i=0 TO NN-2 {
96  *
97  * }
98  *
99  * which need to be turned into
100  *
101  * a = 1 + 2
102  * DEFINE NN 10
103  * FROM i=0 TO 8 {
104  *
105  * }
106  */
108  node.visit_children(*this);
109  node.visit_children(*this);
110 
111  /// first expression which is wrapped
112  auto expr = node.get_expression();
113 
114  /// if wrapped expression is parentheses
115  bool is_parentheses = false;
116 
117  /// opposite to visit_paren_expression, we might have
118  /// a = (2+1)
119  /// in this case we can pick inner expression.
120  if (expr->is_paren_expression()) {
121  auto e = std::dynamic_pointer_cast<ast::ParenExpression>(expr);
122  expr = e->get_expression();
123  is_parentheses = true;
124  }
125 
126  /// we want to simplify binary expressions only
127  if (!expr->is_binary_expression()) {
128  /// wrapped expression might be parenthesis expression like (2)
129  /// which we can simplify to 2 to help next evaluations
130  if (is_parentheses) {
131  node.set_expression(std::move(expr));
132  }
133  return;
134  }
135 
136  auto binary_expr = std::dynamic_pointer_cast<ast::BinaryExpression>(expr);
137  auto lhs = binary_expr->get_lhs();
138  auto rhs = binary_expr->get_rhs();
139  auto op = binary_expr->get_op().get_value();
140 
141  /// in case of expression like
142  /// a = 2 + ((1) + (3))
143  /// we are in the innermost expression i.e. ((1) + (3))
144  /// where (1) and (3) are wrapped expression themself. we can
145  /// remove these extra wrapped expressions
146 
147  if (lhs->is_wrapped_expression()) {
148  auto e = std::dynamic_pointer_cast<ast::WrappedExpression>(lhs);
149  lhs = e->get_expression();
150  }
151 
152  if (rhs->is_wrapped_expression()) {
153  auto e = std::dynamic_pointer_cast<ast::WrappedExpression>(rhs);
154  rhs = e->get_expression();
155  }
156 
157  /// once we simplify, lhs and rhs must be numbers for constant folding
158  if (!is_number(lhs) || !is_number(rhs) || !supported_operator(op)) {
159  return;
160  }
161 
162  const std::string& nmodl_before = to_nmodl(binary_expr);
163 
164  /// compute the value of expression
165  auto value = compute(get_value(lhs), op, get_value(rhs));
166 
167  /// if both operands are not integers or floats, result is double
168  if (lhs->is_integer() && rhs->is_integer()) {
169  node.set_expression(std::make_shared<ast::Integer>(static_cast<int>(value), nullptr));
170  } else if (lhs->is_double() || rhs->is_double()) {
171  node.set_expression(std::make_shared<ast::Double>(stringutils::to_string(value)));
172  } else {
173  node.set_expression(std::make_shared<ast::Float>(stringutils::to_string(value)));
174  }
175 
176  const std::string& nmodl_after = to_nmodl(node.get_expression());
177  logger->debug("ConstantFolderVisitor : expression {} folded to {}", nmodl_before, nmodl_after);
178 }
179 
180 } // namespace visitor
181 } // namespace nmodl
nmodl::ast::BinaryOp
BinaryOp
enum Type for binary operators in NMODL
Definition: ast_common.hpp:47
nmodl::visitor::ConstantFolderVisitor::visit_paren_expression
void visit_paren_expression(ast::ParenExpression &node) override
Visit parenthesis expression and simplify it.
Definition: constant_folder_visitor.cpp:78
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:227
nmodl::ast::WrappedExpression::get_expression
std::shared_ptr< Expression > get_expression() const noexcept
Getter for member variable WrappedExpression::expression.
Definition: wrapped_expression.hpp:143
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
constant_folder_visitor.hpp
Perform constant folding of integer/float/double expressions.
nmodl::ast::BOP_SUBTRACTION
@ BOP_SUBTRACTION
Definition: ast_common.hpp:49
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::ast::ParenExpression::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:6415
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::BOP_ADDITION
@ BOP_ADDITION
+
Definition: ast_common.hpp:48
nmodl::ast::BOP_DIVISION
@ BOP_DIVISION
/
Definition: ast_common.hpp:51
nmodl::ast::BOP_MULTIPLICATION
@ BOP_MULTIPLICATION
*
Definition: ast_common.hpp:50
nmodl::visitor::is_number
static bool is_number(const std::shared_ptr< ast::Expression > &node)
check if given expression is a number note that the DEFINE node is already expanded to integer
Definition: constant_folder_visitor.cpp:20
nmodl::ast::WrappedExpression::set_expression
void set_expression(std::shared_ptr< Expression > &&expression)
Setter for member variable WrappedExpression::expression (rvalue reference)
Definition: ast.cpp:13674
nmodl::ast::ParenExpression::get_expression
std::shared_ptr< Expression > get_expression() const noexcept
Getter for member variable ParenExpression::expression.
Definition: paren_expression.hpp:143
nmodl::ast::ParenExpression
TODO.
Definition: paren_expression.hpp:38
logger.hpp
Implement logger based on spdlog library.
nmodl::visitor::get_value
static double get_value(const std::shared_ptr< ast::Expression > &node)
get value of a number node TODO : eval method can be added to virtual base class
Definition: constant_folder_visitor.cpp:26
nmodl::visitor::supported_operator
static bool supported_operator(ast::BinaryOp op)
operators that currently implemented
Definition: constant_folder_visitor.cpp:38
nmodl::visitor::compute
static double compute(double lhs, ast::BinaryOp op, double rhs)
Evaluate binary operation TODO : add support for other binary operators like ^ (pow)
Definition: constant_folder_visitor.cpp:45
nmodl::visitor::ConstantFolderVisitor::visit_wrapped_expression
void visit_wrapped_expression(ast::WrappedExpression &node) override
Visit wrapped node type and perform constant folding.
Definition: constant_folder_visitor.cpp:107
nmodl::ast::WrappedExpression::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:13616
nmodl::stringutils::to_string
std::string to_string(double value, const std::string &format_spec)
Convert double value to string without trailing zeros.
Definition: string_utils.cpp:18
all.hpp
Auto generated AST classes declaration.
nmodl::ast::WrappedExpression
Wrap any other expression type.
Definition: wrapped_expression.hpp:38
nmodl::ast::ParenExpression::set_expression
void set_expression(std::shared_ptr< Expression > &&expression)
Setter for member variable ParenExpression::expression (rvalue reference)
Definition: ast.cpp:6473