User Guide
visitor_utils.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include "visitor_utils.hpp"
9 
10 #include <map>
11 #include <memory>
12 #include <nlohmann/json.hpp>
13 #include <string>
14 
15 #include "ast/all.hpp"
17 #include "parser/nmodl_driver.hpp"
18 #include "utils/string_utils.hpp"
22 
23 #include "utils/fmt.h"
24 
25 namespace nmodl {
26 namespace visitor {
27 
28 using namespace ast;
30 
32 
33 std::string suffix_random_string(const std::set<std::string>& vars,
34  const std::string& original_string,
35  const UseNumbersInString use_num) {
36  // If the "original_string" is not in the set of the variables to check then
37  // return the "original_string" without suffix
38  if (vars.find(original_string) == vars.end()) {
39  return original_string;
40  }
41  std::string new_string = original_string;
42  auto& singleton_random_string_class = nmodl::utils::SingletonRandomString<4>::instance();
43  // Check if there is a variable defined in the mod file and, if yes, try to use
44  // a different string in the form "original_string"_"random_string"
45  // If there is already a "random_string" assigned to the "originl_string" return it
46  if (singleton_random_string_class.random_string_exists(original_string)) {
47  const auto random_suffix = "_" +
48  singleton_random_string_class.get_random_string(original_string);
49  new_string = original_string + random_suffix;
50  } else {
51  // Check if the "random_string" already exists in the set of variables and if it does try
52  // to find another random string to add as suffix
53  while (vars.find(new_string) != vars.end()) {
54  const auto random_suffix =
55  "_" + singleton_random_string_class.reset_random_string(original_string, use_num);
56  new_string = original_string + random_suffix;
57  }
58  }
59  return new_string;
60 }
61 
/**
 * Return a new variable name of the form "<name>_<suffix>_<COUNT>".
 *
 * COUNT is the number of times a new name was previously requested for
 * "name" through the given "variables" map; the stored counter is incremented
 * on every call.
 *
 * \param name      Base name of the variable
 * \param suffix    Suffix inserted between the base name and the counter
 * \param variables Per-name counters; an entry starting at 0 is created the
 *                  first time "name" is seen
 * \return The newly generated name
 */
std::string get_new_name(const std::string& name,
                         const std::string& suffix,
                         std::map<std::string, int>& variables) {
    // operator[] value-initialises the counter to 0 for a name seen for the first time
    const int counter = variables[name]++;

    // std::to_string is locale independent, so the generated identifier can never
    // pick up digit grouping characters from the global locale (unlike a stream)
    return name + '_' + suffix + '_' + std::to_string(counter);
}
73 
74 std::shared_ptr<ast::LocalListStatement> get_local_list_statement(const StatementBlock& node) {
75  const auto& statements = node.get_statements();
76  for (const auto& statement: statements) {
77  if (statement->is_local_list_statement()) {
78  return std::static_pointer_cast<LocalListStatement>(statement);
79  }
80  }
81  return nullptr;
82 }
83 
85  auto variables = get_local_list_statement(node);
86  const auto& statements = node.get_statements();
87  if (variables == nullptr) {
88  auto statement = std::make_shared<LocalListStatement>(LocalVarVector());
89  node.insert_statement(statements.begin(), statement);
90  }
91 }
92 
94  add_local_statement(node);
95 
96  auto local_list_statement = get_local_list_statement(node);
97  /// each block should already have local statement
98  if (local_list_statement == nullptr) {
99  throw std::logic_error("no local statement");
100  }
101  auto var = std::make_shared<LocalVar>(varname);
102  local_list_statement->emplace_back_local_var(var);
103 
104  return var.get();
105 }
106 
107 LocalVar* add_local_variable(StatementBlock& node, const std::string& varname) {
108  auto name = new Name(new String(varname));
109  return add_local_variable(node, name);
110 }
111 
112 LocalVar* add_local_variable(StatementBlock& node, const std::string& varname, int dim) {
113  auto name = new IndexedName(new Name(new String(varname)), new Integer(dim, nullptr));
114  return add_local_variable(node, name);
115 }
116 
117 /**
118  * Convert given code statement (in string format) to corresponding ast node
119  *
120  * We create dummy nmodl procedure containing given code statement and then
121  * parse it using NMODL parser. As there will be only one block with single
122  * statement, we return first statement.
123  *
124  * \todo Need to revisit this during code generation passes to make sure
125  * if all statements can be part of procedure block.
126  */
127 std::shared_ptr<Statement> create_statement(const std::string& code_statement) {
129  auto nmodl_text = "PROCEDURE dummy() { " + code_statement + " }";
130  auto ast = driver.parse_string(nmodl_text);
131  auto procedure = std::dynamic_pointer_cast<ProcedureBlock>(ast->get_blocks().front());
132  auto statement = std::shared_ptr<Statement>(
133  procedure->get_statement_block()->get_statements()[0]->clone());
134  return statement;
135 }
136 
137 std::vector<std::shared_ptr<Statement>> create_statements(
138  const std::vector<std::string>::const_iterator& code_statements_beg,
139  const std::vector<std::string>::const_iterator& code_statements_end) {
140  std::vector<std::shared_ptr<Statement>> statements;
141  statements.reserve(code_statements_end - code_statements_beg);
142  std::transform(code_statements_beg,
143  code_statements_end,
144  std::back_inserter(statements),
145  [](const std::string& s) { return create_statement(s); });
146  return statements;
147 }
148 
149 /**
150  * Convert given code statements (in string format) to corresponding ast node
151  *
152  * We create dummy nmodl procedure containing the given code statements and
153  * then parse it using NMODL parser. As there will be only one block, we
154  * return a clone of the statement block of that procedure.
155  */
156 std::shared_ptr<StatementBlock> create_statement_block(
157  const std::vector<std::string>& code_statements) {
159  std::string nmodl_text = "PROCEDURE dummy() {\n";
160  for (auto& statement: code_statements) {
161  nmodl_text += statement + "\n";
162  }
163  nmodl_text += "}";
164  auto ast = driver.parse_string(nmodl_text);
165  auto procedure = std::dynamic_pointer_cast<ProcedureBlock>(ast->get_blocks().front());
166  auto statement_block = std::shared_ptr<StatementBlock>(
167  procedure->get_statement_block()->clone());
168  return statement_block;
169 }
170 
171 std::set<std::string> get_global_vars(const Program& node) {
172  std::set<std::string> vars;
173  if (auto* symtab = node.get_symbol_table()) {
174  // NB: local_var included here as locals can be declared at global scope
175  NmodlType property = NmodlType::global_var | NmodlType::local_var | NmodlType::range_var |
176  NmodlType::param_assign | NmodlType::extern_var |
177  NmodlType::prime_name | NmodlType::assigned_definition |
178  NmodlType::read_ion_var | NmodlType::write_ion_var |
179  NmodlType::nonspecific_cur_var | NmodlType::electrode_cur_var |
180  NmodlType::constant_var | NmodlType::extern_neuron_variable |
181  NmodlType::state_var | NmodlType::factor_def;
182  for (const auto& globalvar: symtab->get_variables_with_properties(property)) {
183  std::string var_name = globalvar->get_name();
184  if (globalvar->is_array()) {
185  var_name += "[" + std::to_string(globalvar->get_length()) + "]";
186  }
187  vars.insert(var_name);
188  }
189  }
190  return vars;
191 }
192 
193 
194 bool calls_function(const ast::Ast& node, const std::string& name) {
195  const auto& function_calls = collect_nodes(node, {ast::AstNodeType::FUNCTION_CALL});
196  return std::any_of(
197  function_calls.begin(),
198  function_calls.end(),
199  [&name](const std::shared_ptr<const ast::Ast>& f) {
200  return std::dynamic_pointer_cast<const ast::FunctionCall>(f)->get_node_name() == name;
201  });
202 }
203 
204 } // namespace visitor
205 
206 std::vector<std::shared_ptr<const ast::Ast>> collect_nodes(
207  const ast::Ast& node,
208  const std::vector<ast::AstNodeType>& types) {
210  return visitor.lookup(node, types);
211 }
212 
213 std::vector<std::shared_ptr<ast::Ast>> collect_nodes(ast::Ast& node,
214  const std::vector<ast::AstNodeType>& types) {
216  return visitor.lookup(node, types);
217 }
218 
219 bool node_exists(const ast::Ast& node, ast::AstNodeType ast_type) {
220  const auto blocks = collect_nodes(node, {ast_type});
221  return !blocks.empty();
222 }
223 
224 bool solver_exists(const ast::Ast& node, const std::string& name) {
225  const auto solve_blocks = collect_nodes(node, {ast::AstNodeType::SOLVE_BLOCK});
226  return std::any_of(solve_blocks.begin(), solve_blocks.end(), [&name](auto const& solve_block) {
227  assert(solve_block);
228  const auto& method = dynamic_cast<ast::SolveBlock const&>(*solve_block).get_method();
229  return method && method->get_node_name() == name;
230  });
231 }
232 
233 
234 std::string to_nmodl(const ast::Ast& node, const std::set<ast::AstNodeType>& exclude_types) {
235  std::stringstream stream;
236  visitor::NmodlPrintVisitor v(stream, exclude_types);
237  node.accept(v);
238  return stream.str();
239 }
240 
241 
242 std::string to_json(const ast::Ast& node, bool compact, bool expand, bool add_nmodl) {
243  std::stringstream stream;
244  visitor::JSONVisitor v(stream);
245  v.compact_json(compact);
246  v.add_nmodl(add_nmodl);
247  v.expand_keys(expand);
248  node.accept(v);
249  v.flush();
250  return stream.str();
251 }
252 
253 std::string statement_dependencies_key(const std::shared_ptr<ast::Expression>& lhs) {
254  if (!lhs->is_var_name()) {
255  return "";
256  }
257 
258  const auto& lhs_var_name = std::dynamic_pointer_cast<ast::VarName>(lhs);
259  return get_full_var_name(*lhs_var_name);
260 }
261 
262 std::pair<std::string, std::unordered_set<std::string>> statement_dependencies(
263  const std::shared_ptr<ast::Expression>& lhs,
264  const std::shared_ptr<ast::Expression>& rhs) {
265  std::string key = statement_dependencies_key(lhs);
266  std::unordered_set<std::string> out;
267  if (!lhs->is_var_name()) {
268  return {key, out};
269  }
270 
271  visitor::AstLookupVisitor lookup_visitor;
272  lookup_visitor.lookup(*rhs, ast::AstNodeType::VAR_NAME);
273  auto rhs_nodes = lookup_visitor.get_nodes();
274  std::for_each(rhs_nodes.begin(),
275  rhs_nodes.end(),
276  [&out](const std::shared_ptr<ast::Ast>& node) { out.emplace(to_nmodl(node)); });
277 
278 
279  return {key, out};
280 }
281 
282 std::string get_indexed_name(const ast::IndexedName& node) {
283  return fmt::format("{}[{}]", node.get_node_name(), to_nmodl(node.get_length()));
284 }
285 
286 std::string get_full_var_name(const ast::VarName& node) {
287  std::string full_var_name;
288  if (node.get_name()->is_indexed_name()) {
289  auto index_name_node = std::dynamic_pointer_cast<ast::IndexedName>(node.get_name());
290  full_var_name = get_indexed_name(*index_name_node);
291  } else {
292  full_var_name = node.get_node_name();
293  }
294  return full_var_name;
295 }
296 
297 bool is_random_construct_function(const std::string& name) {
298  return codegen::naming::RANDOM_FUNCTIONS_MAPPING.count(name) != 0;
299 }
300 
301 bool is_nrn_pointing(const std::string& name) {
303 }
304 
305 
306 } // namespace nmodl
nmodl::ast::IndexedName::get_length
std::shared_ptr< Expression > get_length() const noexcept
Getter for member variable IndexedName::length.
Definition: indexed_name.hpp:176
nmodl::visitor::MetaAstLookupVisitor::get_nodes
const nodes_t & get_nodes() const noexcept
Definition: lookup_visitor.hpp:70
nmodl::parser::NmodlDriver
Class that binds all pieces together for parsing nmodl file.
Definition: nmodl_driver.hpp:67
nmodl::get_indexed_name
std::string get_indexed_name(const ast::IndexedName &node)
Given an IndexedName node, return the name with index.
Definition: visitor_utils.cpp:282
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:234
nmodl::ast::Identifier
Base class for all identifiers.
Definition: identifier.hpp:41
nmodl::ast::Ast
Base class for all Abstract Syntax Tree node types.
Definition: ast.hpp:69
nmodl::visitor::create_statements
std::vector< std::shared_ptr< Statement > > create_statements(const std::vector< std::string >::const_iterator &code_statements_beg, const std::vector< std::string >::const_iterator &code_statements_end)
Same as for create_statement but for vectors of strings.
Definition: visitor_utils.cpp:137
nmodl::visitor::calls_function
bool calls_function(const ast::Ast &node, const std::string &name)
Checks whether block contains a call to a particular function.
Definition: visitor_utils.cpp:194
nmodl::ast::IndexedName::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:1004
nmodl::visitor::get_local_list_statement
std::shared_ptr< ast::LocalListStatement > get_local_list_statement(const StatementBlock &node)
Return pointer to local statement in the given block, otherwise nullptr.
Definition: visitor_utils.cpp:74
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::to_json
std::string to_json(const ast::Ast &node, bool compact, bool expand, bool add_nmodl)
Given AST node, return the JSON string representation.
Definition: visitor_utils.cpp:242
nmodl::ast::StatementBlock::insert_statement
StatementVector::const_iterator insert_statement(StatementVector::const_iterator position, const std::shared_ptr< Statement > &n)
Insert member to statements.
Definition: ast.cpp:3134
nmodl::ast::VarName
Represents a variable.
Definition: var_name.hpp:43
nmodl::ast::Integer
Represents an integer variable.
Definition: integer.hpp:49
string_utils.hpp
Implement string manipulation functions.
nmodl::is_random_construct_function
bool is_random_construct_function(const std::string &name)
Is the given name one of the functions for the RANDOM construct.
Definition: visitor_utils.cpp:297
nmodl::statement_dependencies
std::pair< std::string, std::unordered_set< std::string > > statement_dependencies(const std::shared_ptr< ast::Expression > &lhs, const std::shared_ptr< ast::Expression > &rhs)
If lhs and rhs combined represent an assignment (we assume to have an "=" in between them) we extract...
Definition: visitor_utils.cpp:262
nmodl::ast::LocalVarVector
std::vector< std::shared_ptr< LocalVar > > LocalVarVector
Definition: ast_decl.hpp:356
nmodl::ast::AstNodeType
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:166
nmodl::visitor::JSONVisitor::flush
JSONVisitor & flush()
Definition: json_visitor.hpp:60
codegen_naming.hpp
nmodl::ast::StatementBlock::get_statements
const StatementVector & get_statements() const noexcept
Getter for member variable StatementBlock::statements.
Definition: statement_block.hpp:221
nmodl::solver_exists
bool solver_exists(const ast::Ast &node, const std::string &name)
Whether or not a solver of type name exists in the AST.
Definition: visitor_utils.cpp:224
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::IndexedName
Represents specific element of an array variable.
Definition: indexed_name.hpp:48
driver
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28
nmodl::utils::UseNumbersInString
UseNumbersInString
Enum to wrap bool variable to select if random string should have numbers or not.
Definition: common_utils.hpp:51
nmodl::visitor::MetaAstLookupVisitor::lookup
const nodes_t & lookup(ast_t &node)
nmodl::statement_dependencies_key
std::string statement_dependencies_key(const std::shared_ptr< ast::Expression > &lhs)
The result.first of statement_dependencies.
Definition: visitor_utils.cpp:253
nmodl::visitor::NmodlPrintVisitor
Visitor for printing AST back to NMODL
Definition: nmodl_visitor.hpp:44
nmodl::is_nrn_pointing
bool is_nrn_pointing(const std::string &name)
Is given name nrn_pointing.
Definition: visitor_utils.cpp:301
nmodl::symtab::syminfo::to_string
std::string to_string(const T &obj)
Definition: symbol_properties.hpp:282
nmodl::ast::AstNodeType::SOLVE_BLOCK
@ SOLVE_BLOCK
type of ast::SolveBlock
nmodl::node_exists
bool node_exists(const ast::Ast &node, ast::AstNodeType ast_type)
Whether a node of type ast_type exists as a subnode of node.
Definition: visitor_utils.cpp:219
nmodl::ast::AstNodeType::FUNCTION_CALL
@ FUNCTION_CALL
type of ast::FunctionCall
nmodl::get_full_var_name
std::string get_full_var_name(const ast::VarName &node)
Given a VarName node, return the full var name including index.
Definition: visitor_utils.cpp:286
nmodl::codegen::naming::RANDOM_FUNCTIONS_MAPPING
static std::unordered_map< std::string, std::string > RANDOM_FUNCTIONS_MAPPING
Definition: codegen_naming.hpp:241
nmodl::parser::UnitDriver::parse_string
bool parse_string(const std::string &input)
parse Units provided as string (used for testing)
Definition: unit_driver.cpp:40
nmodl::ast::Ast::accept
virtual void accept(visitor::Visitor &v)=0
Accept (or visit) the AST node using current visitor.
nmodl::visitor::create_statement
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
Definition: visitor_utils.cpp:127
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:206
nmodl::ast::LocalVar
TODO.
Definition: local_var.hpp:38
nmodl::visitor::JSONVisitor::compact_json
JSONVisitor & compact_json(bool flag)
Definition: json_visitor.hpp:65
nmodl::visitor::JSONVisitor::expand_keys
JSONVisitor & expand_keys(bool flag)
Definition: json_visitor.hpp:75
nmodl::visitor::add_local_statement
void add_local_statement(StatementBlock &node)
Add an empty local statement to the given block if one doesn't already exist.
Definition: visitor_utils.cpp:84
nmodl::visitor::JSONVisitor
Visitor for printing AST in JSON format
Definition: json_visitor.hpp:37
nmodl::visitor::get_new_name
std::string get_new_name(const std::string &name, const std::string &suffix, std::map< std::string, int > &variables)
Return new name variable by appending _suffix_COUNT where COUNT is number of times the given variable...
Definition: visitor_utils.cpp:62
nmodl::ast::StatementBlock
Represents block encapsulating list of statements.
Definition: statement_block.hpp:53
nmodl::symtab::syminfo::NmodlType
NmodlType
NMODL variable properties.
Definition: symbol_properties.hpp:116
nmodl::utils::SingletonRandomString::instance
static SingletonRandomString & instance()
Function to instantiate the SingletonRandomString class.
Definition: common_utils.hpp:75
nmodl::visitor::create_statement_block
std::shared_ptr< StatementBlock > create_statement_block(const std::vector< std::string > &code_statements)
Convert given code statement (in string format) to corresponding ast node.
Definition: visitor_utils.cpp:156
nmodl::visitor::add_local_variable
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
Definition: visitor_utils.cpp:93
nmodl::ast::VarName::get_name
std::shared_ptr< Identifier > get_name() const noexcept
Getter for member variable VarName::name.
Definition: var_name.hpp:164
nmodl::ast::Program::get_symbol_table
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
Definition: program.hpp:153
nmodl::codegen::naming::NRN_POINTING_METHOD
static constexpr char NRN_POINTING_METHOD[]
nrn_pointing function in nmodl
Definition: codegen_naming.hpp:54
nmodl_driver.hpp
nmodl::ast::Name
Represents a name.
Definition: name.hpp:44
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::visitor::suffix_random_string
std::string suffix_random_string(const std::set< std::string > &vars, const std::string &original_string, const UseNumbersInString use_num)
Return the "original_string" with a random suffix if "original_string" exists in "vars".
Definition: visitor_utils.cpp:33
lookup_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
fmt.h
json_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::ast::String
Represents a string.
Definition: string.hpp:52
nmodl::visitor::get_global_vars
std::set< std::string > get_global_vars(const Program &node)
Return set of strings with the names of all global variables.
Definition: visitor_utils.cpp:171
all.hpp
Auto generated AST classes declaration.
nmodl::ast::VarName::get_node_name
std::string get_node_name() const override
Return name of the node.
Definition: ast.cpp:1124
nmodl::visitor::JSONVisitor::add_nmodl
JSONVisitor & add_nmodl(bool flag)
Definition: json_visitor.hpp:70
nmodl_visitor.hpp
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::ast::AstNodeType::VAR_NAME
@ VAR_NAME
type of ast::VarName
nmodl::visitor::MetaAstLookupVisitor
Definition: lookup_visitor.hpp:34