User Guide
visitor_utils.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include "visitor_utils.hpp"
9 
10 #include <map>
11 #include <memory>
12 #include <string>
13 
14 #include "ast/all.hpp"
16 #include "parser/nmodl_driver.hpp"
17 #include "utils/string_utils.hpp"
21 
22 #include <fmt/format.h>
23 
24 namespace nmodl {
25 namespace visitor {
26 
27 using namespace ast;
29 
31 
32 std::string suffix_random_string(const std::set<std::string>& vars,
33  const std::string& original_string,
34  const UseNumbersInString use_num) {
35  // If the "original_string" is not in the set of the variables to check then
36  // return the "original_string" without suffix
37  if (vars.find(original_string) == vars.end()) {
38  return original_string;
39  }
40  std::string new_string = original_string;
41  auto& singleton_random_string_class = nmodl::utils::SingletonRandomString<4>::instance();
42  // Check if there is a variable defined in the mod file and, if yes, try to use
43  // a different string in the form "original_string"_"random_string"
44  // If there is already a "random_string" assigned to the "originl_string" return it
45  if (singleton_random_string_class.random_string_exists(original_string)) {
46  const auto random_suffix = "_" +
47  singleton_random_string_class.get_random_string(original_string);
48  new_string = original_string + random_suffix;
49  } else {
50  // Check if the "random_string" already exists in the set of variables and if it does try
51  // to find another random string to add as suffix
52  while (vars.find(new_string) != vars.end()) {
53  const auto random_suffix =
54  "_" + singleton_random_string_class.reset_random_string(original_string, use_num);
55  new_string = original_string + random_suffix;
56  }
57  }
58  return new_string;
59 }
60 
/**
 * Return a fresh variable name of the form `<name>_<suffix>_<count>`.
 *
 * `variables` maps each base name to the number of names generated for it so far.
 * The counter for `name` starts at 0 when `name` is not yet present, and it is
 * incremented on every call. Successive calls for the same base therefore yield
 * `<name>_<suffix>_0`, `<name>_<suffix>_1`, ...
 *
 * @param name      base variable name
 * @param suffix    text inserted between the base name and the counter
 * @param variables per-name counters; updated in place
 * @return the generated name
 */
std::string get_new_name(const std::string& name,
                         const std::string& suffix,
                         std::map<std::string, int>& variables) {
    // emplace() only inserts the zero counter if `name` has not been seen before.
    auto& counter = variables.emplace(name, 0).first->second;
    const int current = counter++;
    // std::to_string is locale-independent, unlike formatting through an ostream
    // (which follows the global locale and could insert digit-grouping characters).
    return name + '_' + suffix + '_' + std::to_string(current);
}
72 
73 std::shared_ptr<ast::LocalListStatement> get_local_list_statement(const StatementBlock& node) {
74  const auto& statements = node.get_statements();
75  for (const auto& statement: statements) {
76  if (statement->is_local_list_statement()) {
77  return std::static_pointer_cast<LocalListStatement>(statement);
78  }
79  }
80  return nullptr;
81 }
82 
// Ensure `node` has a LOCAL statement. If get_local_list_statement() finds none, an
// empty LocalListStatement is inserted at the front of the block's statements;
// otherwise the block is left unchanged.
// NOTE(review): the signature line (original line 83) is missing from this listing;
// the index below gives it as `void add_local_statement(StatementBlock& node)`.
84  auto variables = get_local_list_statement(node);
85  const auto& statements = node.get_statements();
86  if (variables == nullptr) {
// Insert at begin() so the new LOCAL statement precedes every existing statement.
87  auto statement = std::make_shared<LocalListStatement>(LocalVarVector());
88  node.insert_statement(statements.begin(), statement);
89  }
90 }
91 
// Append a new LOCAL variable named by `varname` to the LOCAL statement of `node`.
// If the block has no LOCAL statement yet, add_local_statement() creates an empty one
// first. Returns a raw pointer to the newly created LocalVar.
// NOTE(review): the signature line (original line 92) is missing from this listing;
// the index below gives it as
// `LocalVar* add_local_variable(StatementBlock& node, Identifier* varname)`.
// NOTE(review): the returned pointer is non-owning. Presumably the LocalVar stays alive
// because the LOCAL statement keeps the shared_ptr passed to emplace_back_local_var() —
// confirm.
93  add_local_statement(node);
94 
95  auto local_list_statement = get_local_list_statement(node);
96  /// each block should already have local statement
// Defensive check: add_local_statement() above should have guaranteed a LOCAL statement.
// NOTE(review): if this throws, `varname` has not yet been handed to a LocalVar. The
// string overloads in this file pass freshly `new`-ed identifiers, which would then
// leak — confirm whether that matters.
97  if (local_list_statement == nullptr) {
98  throw std::logic_error("no local statement");
99  }
100  auto var = std::make_shared<LocalVar>(varname);
101  local_list_statement->emplace_back_local_var(var);
102 
103  return var.get();
104 }
105 
106 LocalVar* add_local_variable(StatementBlock& node, const std::string& varname) {
107  auto name = new Name(new String(varname));
108  return add_local_variable(node, name);
109 }
110 
111 LocalVar* add_local_variable(StatementBlock& node, const std::string& varname, int dim) {
112  auto name = new IndexedName(new Name(new String(varname)), new Integer(dim, nullptr));
113  return add_local_variable(node, name);
114 }
115 
116 /**
117  * Convert given code statement (in string format) to corresponding ast node
118  *
119  * We create dummy nmodl procedure containing given code statement and then
120  * parse it using NMODL parser. As there will be only one block with single
121  * statement, we return first statement.
122  *
 * @param code_statement NMODL source text for a single statement
 * @return a new shared_ptr owning a clone() of the first statement in the parsed
 *         dummy procedure's statement block
 *
 * NOTE(review): nothing here checks that the parsed procedure body is non-empty before
 * indexing `[0]` (e.g. for an empty `code_statement`). The dynamic_pointer_cast result
 * is also dereferenced without a null check; both rely on the input being exactly one
 * well-formed statement.
 *
123  * \todo Need to revisit this during code generation passes to make sure
124  * if all statements can be part of procedure block.
125  */
126 std::shared_ptr<Statement> create_statement(const std::string& code_statement) {
    // NOTE(review): original line 127, which declares `driver`, is missing from this
    // listing. Given the "parser/nmodl_driver.hpp" include it is presumably a
    // parser::NmodlDriver — confirm against the source.
128  auto nmodl_text = "PROCEDURE dummy() { " + code_statement + " }";
129  auto ast = driver.parse_string(nmodl_text);
130  auto procedure = std::dynamic_pointer_cast<ProcedureBlock>(ast->get_blocks().front());
    // Clone so the returned statement is a separate copy of the node from the parsed AST.
131  auto statement = std::shared_ptr<Statement>(
132  procedure->get_statement_block()->get_statements()[0]->clone());
133  return statement;
134 }
135 
136 std::vector<std::shared_ptr<Statement>> create_statements(
137  const std::vector<std::string>::const_iterator& code_statements_beg,
138  const std::vector<std::string>::const_iterator& code_statements_end) {
139  std::vector<std::shared_ptr<Statement>> statements;
140  statements.reserve(code_statements_end - code_statements_beg);
141  std::transform(code_statements_beg,
142  code_statements_end,
143  std::back_inserter(statements),
144  [](const std::string& s) { return create_statement(s); });
145  return statements;
146 }
147 
148 /**
149  * Convert given code statements (in string format) to a statement block
150  *
151  * We create a dummy nmodl procedure whose body contains each given code
152  * statement on its own line, parse it using the NMODL parser and return a
153  * clone of the procedure's whole statement block (not just its first statement).
 *
 * @param code_statements NMODL source lines, placed in the procedure body in order
 * @return a new shared_ptr owning a clone() of the parsed procedure's StatementBlock
 *
 * NOTE(review): the dynamic_pointer_cast result is dereferenced without a null check;
 * this relies on the first parsed block being the dummy PROCEDURE built here.
154  */
155 std::shared_ptr<StatementBlock> create_statement_block(
156  const std::vector<std::string>& code_statements) {
    // NOTE(review): original line 157, which declares `driver`, is missing from this
    // listing. Given the "parser/nmodl_driver.hpp" include it is presumably a
    // parser::NmodlDriver — confirm against the source.
158  std::string nmodl_text = "PROCEDURE dummy() {\n";
159  for (auto& statement: code_statements) {
160  nmodl_text += statement + "\n";
161  }
162  nmodl_text += "}";
163  auto ast = driver.parse_string(nmodl_text);
164  auto procedure = std::dynamic_pointer_cast<ProcedureBlock>(ast->get_blocks().front());
165  auto statement_block = std::shared_ptr<StatementBlock>(
166  procedure->get_statement_block()->clone());
167  return statement_block;
168 }
169 
170 std::set<std::string> get_global_vars(const Program& node) {
171  std::set<std::string> vars;
172  if (auto* symtab = node.get_symbol_table()) {
173  // NB: local_var included here as locals can be declared at global scope
174  NmodlType property = NmodlType::global_var | NmodlType::local_var | NmodlType::range_var |
175  NmodlType::param_assign | NmodlType::extern_var |
176  NmodlType::prime_name | NmodlType::assigned_definition |
177  NmodlType::read_ion_var | NmodlType::write_ion_var |
178  NmodlType::nonspecific_cur_var | NmodlType::electrode_cur_var |
179  NmodlType::constant_var | NmodlType::extern_neuron_variable |
180  NmodlType::state_var | NmodlType::factor_def;
181  for (const auto& globalvar: symtab->get_variables_with_properties(property)) {
182  std::string var_name = globalvar->get_name();
183  if (globalvar->is_array()) {
184  var_name += "[" + std::to_string(globalvar->get_length()) + "]";
185  }
186  vars.insert(var_name);
187  }
188  }
189  return vars;
190 }
191 
192 
193 bool calls_function(const ast::Ast& node, const std::string& name) {
194  const auto& function_calls = collect_nodes(node, {ast::AstNodeType::FUNCTION_CALL});
195  return std::any_of(
196  function_calls.begin(),
197  function_calls.end(),
198  [&name](const std::shared_ptr<const ast::Ast>& f) {
199  return std::dynamic_pointer_cast<const ast::FunctionCall>(f)->get_node_name() == name;
200  });
201 }
202 
203 } // namespace visitor
204 
// Traverse `node` recursively and collect (as const nodes) every node whose type is in
// `types`, by delegating to a lookup visitor's lookup().
// NOTE(review): original line 208, which declares the local `visitor`, is missing from
// this listing. It is presumably a const instantiation of MetaAstLookupVisitor — confirm.
// That local also shadows the nmodl::visitor namespace inside this function.
205 std::vector<std::shared_ptr<const ast::Ast>> collect_nodes(
206  const ast::Ast& node,
207  const std::vector<ast::AstNodeType>& types) {
209  return visitor.lookup(node, types);
210 }
211 
// Non-const counterpart of the overload above: traverse `node` recursively and collect
// (as mutable nodes) every node whose type is in `types`, by delegating to a lookup
// visitor's lookup().
// NOTE(review): original line 214, which declares the local `visitor`, is missing from
// this listing. It is presumably a visitor::AstLookupVisitor (the type used in
// statement_dependencies) — confirm. The local shadows the nmodl::visitor namespace.
212 std::vector<std::shared_ptr<ast::Ast>> collect_nodes(ast::Ast& node,
213  const std::vector<ast::AstNodeType>& types) {
215  return visitor.lookup(node, types);
216 }
217 
218 bool sparse_solver_exists(const ast::Ast& node) {
219  const auto solve_blocks = collect_nodes(node, {ast::AstNodeType::SOLVE_BLOCK});
220  return std::any_of(solve_blocks.begin(), solve_blocks.end(), [](auto const& solve_block) {
221  assert(solve_block);
222  const auto& method = dynamic_cast<ast::SolveBlock const&>(*solve_block).get_method();
223  return method && method->get_node_name() == "sparse";
224  });
225 }
226 
227 std::string to_nmodl(const ast::Ast& node, const std::set<ast::AstNodeType>& exclude_types) {
228  std::stringstream stream;
229  visitor::NmodlPrintVisitor v(stream, exclude_types);
230  node.accept(v);
231  return stream.str();
232 }
233 
234 
235 std::string to_json(const ast::Ast& node, bool compact, bool expand, bool add_nmodl) {
236  std::stringstream stream;
237  visitor::JSONVisitor v(stream);
238  v.compact_json(compact);
239  v.add_nmodl(add_nmodl);
240  v.expand_keys(expand);
241  node.accept(v);
242  v.flush();
243  return stream.str();
244 }
245 
246 std::pair<std::string, std::unordered_set<std::string>> statement_dependencies(
247  const std::shared_ptr<ast::Expression>& lhs,
248  const std::shared_ptr<ast::Expression>& rhs) {
249  std::string key;
250  std::unordered_set<std::string> out;
251 
252  if (!lhs->is_var_name()) {
253  return {key, out};
254  }
255 
256  const auto& lhs_var_name = std::dynamic_pointer_cast<ast::VarName>(lhs);
257  key = get_full_var_name(*lhs_var_name);
258 
259  visitor::AstLookupVisitor lookup_visitor;
260  lookup_visitor.lookup(*rhs, ast::AstNodeType::VAR_NAME);
261  auto rhs_nodes = lookup_visitor.get_nodes();
262  std::for_each(rhs_nodes.begin(),
263  rhs_nodes.end(),
264  [&out](const std::shared_ptr<ast::Ast>& node) { out.emplace(to_nmodl(node)); });
265 
266 
267  return {key, out};
268 }
269 
270 std::string get_indexed_name(const ast::IndexedName& node) {
271  return fmt::format("{}[{}]", node.get_node_name(), to_nmodl(node.get_length()));
272 }
273 
274 std::string get_full_var_name(const ast::VarName& node) {
275  std::string full_var_name;
276  if (node.get_name()->is_indexed_name()) {
277  auto index_name_node = std::dynamic_pointer_cast<ast::IndexedName>(node.get_name());
278  full_var_name = get_indexed_name(*index_name_node);
279  } else {
280  full_var_name = node.get_node_name();
281  }
282  return full_var_name;
283 }
284 
285 bool is_random_construct_function(const std::string& name) {
286  return codegen::naming::RANDOM_FUNCTIONS_MAPPING.count(name) != 0;
287 }
288 
289 } // namespace nmodl
nmodl::ast::IndexedName::get_length
std::shared_ptr< Expression > get_length() const noexcept
Getter for member variable IndexedName::length.
Definition: indexed_name.hpp:176
nmodl::visitor::MetaAstLookupVisitor::get_nodes
const nodes_t & get_nodes() const noexcept
Definition: lookup_visitor.hpp:70
nmodl::parser::NmodlDriver
Class that binds all pieces together for parsing nmodl file.
Definition: nmodl_driver.hpp:63
nmodl::get_indexed_name
std::string get_indexed_name(const ast::IndexedName &node)
Given an IndexedName node, return the name with its index.
Definition: visitor_utils.cpp:270
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:227
nmodl::ast::Identifier
Base class for all identifiers.
Definition: identifier.hpp:41
nmodl::ast::Ast
Base class for all Abstract Syntax Tree node types.
Definition: ast.hpp:69
nmodl::visitor::create_statements
std::vector< std::shared_ptr< Statement > > create_statements(const std::vector< std::string >::const_iterator &code_statements_beg, const std::vector< std::string >::const_iterator &code_statements_end)
Same as for create_statement but for vectors of strings.
Definition: visitor_utils.cpp:136
nmodl::visitor::calls_function
bool calls_function(const ast::Ast &node, const std::string &name)
Checks whether block contains a call to a particular function.
Definition: visitor_utils.cpp:193
nmodl::ast::IndexedName::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:1000
nmodl::visitor::get_local_list_statement
std::shared_ptr< ast::LocalListStatement > get_local_list_statement(const StatementBlock &node)
Return pointer to local statement in the given block, otherwise nullptr.
Definition: visitor_utils.cpp:73
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::to_json
std::string to_json(const ast::Ast &node, bool compact, bool expand, bool add_nmodl)
Given AST node, return the JSON string representation.
Definition: visitor_utils.cpp:235
nmodl::ast::StatementBlock::insert_statement
StatementVector::const_iterator insert_statement(StatementVector::const_iterator position, const std::shared_ptr< Statement > &n)
Insert member to statements.
Definition: ast.cpp:3130
nmodl::ast::VarName
Represents a variable.
Definition: var_name.hpp:43
nmodl::ast::Integer
Represents an integer variable.
Definition: integer.hpp:49
string_utils.hpp
Implement string manipulation functions.
nmodl::is_random_construct_function
bool is_random_construct_function(const std::string &name)
Return whether the given name is one of the functions of the RANDOM construct.
Definition: visitor_utils.cpp:285
nmodl::statement_dependencies
std::pair< std::string, std::unordered_set< std::string > > statement_dependencies(const std::shared_ptr< ast::Expression > &lhs, const std::shared_ptr< ast::Expression > &rhs)
If lhs and rhs combined represent an assignment (we assume an "=" between them), we extract...
Definition: visitor_utils.cpp:246
nmodl::ast::LocalVarVector
std::vector< std::shared_ptr< LocalVar > > LocalVarVector
Definition: ast_decl.hpp:352
nmodl::visitor::JSONVisitor::flush
JSONVisitor & flush()
Definition: json_visitor.hpp:60
codegen_naming.hpp
nmodl::ast::StatementBlock::get_statements
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
Definition: statement_block.hpp:221
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::IndexedName
Represents specific element of an array variable.
Definition: indexed_name.hpp:48
driver
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28
nmodl::utils::UseNumbersInString
UseNumbersInString
Enum to wrap bool variable to select if random string should have numbers or not.
Definition: common_utils.hpp:51
nmodl::visitor::MetaAstLookupVisitor::lookup
const nodes_t & lookup(ast_t &node)
nmodl::visitor::NmodlPrintVisitor
Visitor for printing AST back to NMODL
Definition: nmodl_visitor.hpp:44
nmodl::symtab::syminfo::to_string
std::string to_string(const T &obj)
Definition: symbol_properties.hpp:279
nmodl::ast::AstNodeType::SOLVE_BLOCK
@ SOLVE_BLOCK
type of ast::SolveBlock
nmodl::ast::AstNodeType::FUNCTION_CALL
@ FUNCTION_CALL
type of ast::FunctionCall
nmodl::get_full_var_name
std::string get_full_var_name(const ast::VarName &node)
Given a VarName node, return the full var name including index.
Definition: visitor_utils.cpp:274
nmodl::codegen::naming::RANDOM_FUNCTIONS_MAPPING
static std::unordered_map< std::string, std::string > RANDOM_FUNCTIONS_MAPPING
Definition: codegen_naming.hpp:192
nmodl::parser::UnitDriver::parse_string
bool parse_string(const std::string &input)
Parse units provided as a string (used for testing).
Definition: unit_driver.cpp:40
nmodl::ast::Ast::accept
virtual void accept(visitor::Visitor &v)=0
Accept (or visit) the AST node using current visitor.
nmodl::visitor::create_statement
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
Definition: visitor_utils.cpp:126
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
Traverse node recursively and collect nodes of the given types.
Definition: visitor_utils.cpp:205
nmodl::ast::LocalVar
TODO.
Definition: local_var.hpp:38
nmodl::visitor::JSONVisitor::compact_json
JSONVisitor & compact_json(bool flag)
Definition: json_visitor.hpp:65
nmodl::visitor::JSONVisitor::expand_keys
JSONVisitor & expand_keys(bool flag)
Definition: json_visitor.hpp:75
nmodl::visitor::add_local_statement
void add_local_statement(StatementBlock &node)
Add an empty local statement to the given block if one doesn't already exist.
Definition: visitor_utils.cpp:83
nmodl::visitor::JSONVisitor
Visitor for printing AST in JSON format
Definition: json_visitor.hpp:37
nmodl::visitor::get_new_name
std::string get_new_name(const std::string &name, const std::string &suffix, std::map< std::string, int > &variables)
Return a new variable name by appending _suffix_COUNT, where COUNT is the number of times the given variable...
Definition: visitor_utils.cpp:61
nmodl::ast::StatementBlock
Represents block encapsulating list of statements.
Definition: statement_block.hpp:53
nmodl::symtab::syminfo::NmodlType
NmodlType
NMODL variable properties.
Definition: symbol_properties.hpp:116
nmodl::utils::SingletonRandomString::instance
static SingletonRandomString & instance()
Function to instantiate the SingletonRandomString class.
Definition: common_utils.hpp:75
nmodl::visitor::create_statement_block
std::shared_ptr< StatementBlock > create_statement_block(const std::vector< std::string > &code_statements)
Convert the given code statements (in string format) to a corresponding statement block.
Definition: visitor_utils.cpp:155
nmodl::visitor::add_local_variable
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
Definition: visitor_utils.cpp:92
nmodl::ast::VarName::get_name
std::shared_ptr< Identifier > get_name() const noexcept
Getter for member variable VarName::name.
Definition: var_name.hpp:164
nmodl::ast::Program::get_symbol_table
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
Definition: program.hpp:153
nmodl_driver.hpp
nmodl::ast::Name
Represents a name.
Definition: name.hpp:44
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::sparse_solver_exists
bool sparse_solver_exists(const ast::Ast &node)
Definition: visitor_utils.cpp:218
nmodl::visitor::suffix_random_string
std::string suffix_random_string(const std::set< std::string > &vars, const std::string &original_string, const UseNumbersInString use_num)
Return the "original_string" with a random suffix if "original_string" exists in "vars".
Definition: visitor_utils.cpp:32
lookup_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
json_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::ast::String
Represents a string.
Definition: string.hpp:52
nmodl::visitor::get_global_vars
std::set< std::string > get_global_vars(const Program &node)
Return set of strings with the names of all global variables.
Definition: visitor_utils.cpp:170
all.hpp
Auto generated AST classes declaration.
nmodl::ast::VarName::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:1120
nmodl::visitor::JSONVisitor::add_nmodl
JSONVisitor & add_nmodl(bool flag)
Definition: json_visitor.hpp:70
nmodl_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::ast::AstNodeType::VAR_NAME
@ VAR_NAME
type of ast::VarName
nmodl::visitor::MetaAstLookupVisitor
Definition: lookup_visitor.hpp:34