User Guide
cvode_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "ast/all.hpp"
11 #include "lexer/token_mapping.hpp"
12 #include "pybind/pyembed.hpp"
13 #include "utils/logger.hpp"
15 #include <optional>
16 #include <regex>
17 #include <utility>
18 
19 namespace pywrap = nmodl::pybind_wrappers;
20 
21 namespace nmodl {
22 namespace visitor {
23 
24 static int get_index(const ast::IndexedName& node) {
25  return std::stoi(to_nmodl(node.get_length()));
26 }
27 
29  auto conserve_equations = collect_nodes(node, {ast::AstNodeType::CONSERVE});
30  if (!conserve_equations.empty()) {
31  std::unordered_set<ast::Statement*> eqs;
32  for (const auto& item: conserve_equations) {
33  eqs.insert(std::dynamic_pointer_cast<ast::Statement>(item).get());
34  }
35  node.erase_statement(eqs);
36  }
37 }
38 
39 // remove units from CVODE block so sympy can parse it properly
40 static void remove_units(ast::BinaryExpression& node) {
41  // matches either an int or a float, followed by any (including zero)
42  // number of spaces, followed by an expression in parentheses, that only
43  // has letters of the alphabet
44  std::regex unit_pattern(R"((\d+\.?\d*|\.\d+)\s*\([a-zA-Z]+\))");
45  auto rhs_string = to_nmodl(node.get_rhs());
46  auto rhs_string_no_units = fmt::format("{} = {}",
47  to_nmodl(node.get_lhs()),
48  std::regex_replace(rhs_string, unit_pattern, "$1"));
49  logger->debug("CvodeVisitor :: removing units from statement {}", to_nmodl(node));
50  logger->debug("CvodeVisitor :: result: {}", rhs_string_no_units);
51  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
52  create_statement(rhs_string_no_units));
53  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
54  expr_statement->get_expression());
55  node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
56 }
57 
58 static std::pair<std::string, std::optional<int>> parse_independent_var(
59  std::shared_ptr<ast::Identifier> node) {
60  auto variable = std::make_pair(node->get_node_name(), std::optional<int>());
61  if (node->is_indexed_name()) {
62  variable.second = std::optional<int>(
63  get_index(*std::dynamic_pointer_cast<const ast::IndexedName>(node)));
64  }
65  return variable;
66 }
67 
68 /// set of all indexed variables not equal to ``ignored_name``
69 static std::unordered_set<std::string> get_indexed_variables(const ast::Expression& node,
70  const std::string& ignored_name) {
71  std::unordered_set<std::string> indexed_variables;
72  // all of the "reserved" vars
73  auto reserved_symbols = get_external_functions();
74  // all indexed vars
75  auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME});
76  for (const auto& var: indexed_vars) {
77  const auto& varname = var->get_node_name();
78  // skip if it's a reserved var
79  auto varname_not_reserved =
80  std::none_of(reserved_symbols.begin(),
81  reserved_symbols.end(),
82  [&varname](const auto item) { return varname == item; });
83  if (indexed_variables.count(varname) == 0 && varname != ignored_name &&
84  varname_not_reserved) {
85  indexed_variables.insert(varname);
86  }
87  }
88  return indexed_variables;
89 }
90 
91 static std::string cvode_set_lhs(ast::BinaryExpression& node) {
92  const auto& lhs = node.get_lhs();
93 
94  auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
95 
96  std::string varname;
97  if (name->is_prime_name()) {
98  varname = "D" + name->get_node_name();
99  node.set_lhs(std::make_shared<ast::Name>(new ast::String(varname)));
100  } else if (name->is_indexed_name()) {
101  auto nodes = collect_nodes(*name, {ast::AstNodeType::PRIME_NAME});
102  // make sure the LHS isn't just a plain indexed var
103  if (!nodes.empty()) {
104  varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\'');
105  auto statement = fmt::format("{} = {}", varname, varname);
106  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
107  create_statement(statement));
108  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
109  expr_statement->get_expression());
110  node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
111  }
112  }
113  return varname;
114 }
115 
116 
118  protected:
121  public:
124  node.visit_children(*this);
125  in_differential_equation = false;
126  }
127 };
128 
130  public:
132  program_symtab = symtab;
133  }
134 
136  const auto& lhs = node.get_lhs();
137 
138  if (!in_differential_equation || !lhs->is_var_name()) {
139  return;
140  }
141 
142  auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
143  auto varname = cvode_set_lhs(node);
144 
145  if (program_symtab->lookup(varname) == nullptr) {
146  auto symbol = std::make_shared<symtab::Symbol>(varname, ModToken());
147  symbol->set_original_name(name->get_node_name());
148  program_symtab->insert(symbol);
149  }
150  }
151 };
152 
154  public:
155  explicit StiffVisitor(symtab::SymbolTable* symtab) {
156  program_symtab = symtab;
157  }
158 
160  const auto& lhs = node.get_lhs();
161 
162  if (!in_differential_equation || !lhs->is_var_name()) {
163  return;
164  }
165 
166  auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
167  auto varname = cvode_set_lhs(node);
168 
169  if (program_symtab->lookup(varname) == nullptr) {
170  auto symbol = std::make_shared<symtab::Symbol>(varname, ModToken());
171  symbol->set_original_name(name->get_node_name());
172  program_symtab->insert(symbol);
173  }
174 
175  remove_units(node);
176 
177  auto rhs = node.get_rhs();
178 
179  // all indexed variables (need special treatment in SymPy)
180  auto indexed_variables = get_indexed_variables(*rhs, name->get_node_name());
182  auto [jacobian, exception_message] =
183  diff2c(to_nmodl(*rhs), parse_independent_var(name), indexed_variables);
184  if (!exception_message.empty()) {
185  logger->warn("CvodeVisitor :: python exception: {}", exception_message);
186  }
187  // NOTE: LHS can be anything here, the equality is to keep `create_statement` from
188  // complaining, we discard the LHS later
189  auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian);
190  logger->debug("CvodeVisitor :: replacing statement {} with {}", to_nmodl(node), statement);
191  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
192  create_statement(statement));
193  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
194  expr_statement->get_expression());
195  node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
196  }
197 };
198 
199 static std::shared_ptr<ast::DerivativeBlock> get_derivative_block(ast::Program& node) {
200  auto derivative_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK});
201  if (derivative_blocks.empty()) {
202  return nullptr;
203  }
204 
205  // steady state adds a DERIVATIVE block with a `_steadystate` suffix
206  auto not_steadystate = [](const auto& item) {
207  auto name = std::dynamic_pointer_cast<const ast::DerivativeBlock>(item)->get_node_name();
208  return !stringutils::ends_with(name, "_steadystate");
209  };
210  decltype(derivative_blocks) derivative_blocks_copy;
211  std::copy_if(derivative_blocks.begin(),
212  derivative_blocks.end(),
213  std::back_inserter(derivative_blocks_copy),
214  not_steadystate);
215  if (derivative_blocks_copy.size() > 1) {
216  auto message = "CvodeVisitor :: cannot have multiple DERIVATIVE blocks";
217  logger->error(message);
218  throw std::runtime_error(message);
219  }
220 
221  return std::dynamic_pointer_cast<ast::DerivativeBlock>(derivative_blocks_copy[0]);
222 }
223 
224 
226  auto derivative_block = get_derivative_block(node);
227  if (derivative_block == nullptr) {
228  return;
229  }
230 
231  auto non_stiff_block = derivative_block->get_statement_block()->clone();
232  remove_conserve_statements(*non_stiff_block);
233 
234  auto stiff_block = derivative_block->get_statement_block()->clone();
235  remove_conserve_statements(*stiff_block);
236 
237  NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block);
239  auto prime_vars = collect_nodes(*derivative_block, {ast::AstNodeType::PRIME_NAME});
241  derivative_block->get_name(),
242  std::shared_ptr<ast::Integer>(new ast::Integer(prime_vars.size(), nullptr)),
243  std::shared_ptr<ast::StatementBlock>(non_stiff_block),
244  std::shared_ptr<ast::StatementBlock>(stiff_block)));
245 }
246 
247 } // namespace visitor
248 } // namespace nmodl
nmodl::ast::IndexedName::get_length
std::shared_ptr< Expression > get_length() const noexcept
Getter for member variable IndexedName::length.
Definition: indexed_name.hpp:176
nmodl::visitor::CvodeHelperVisitor::visit_diff_eq_expression
void visit_diff_eq_expression(ast::DiffEqExpression &node)
visit node of type ast::DiffEqExpression
Definition: cvode_visitor.cpp:122
nmodl::ast::AstNodeType::INDEXED_NAME
@ INDEXED_NAME
type of ast::IndexedName
nmodl::ast::AstNodeType::DERIVATIVE_BLOCK
@ DERIVATIVE_BLOCK
type of ast::DerivativeBlock
nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
Definition: pyembed.hpp:29
nmodl::get_external_functions
std::vector< std::string > get_external_functions()
Return functions that can be used in the NMODL.
Definition: token_mapping.cpp:308
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:234
nmodl::visitor::get_index
static int get_index(const ast::IndexedName &node)
Definition: cvode_visitor.cpp:24
nmodl::ast::DiffEqExpression
Represents differential equation in DERIVATIVE block.
Definition: diff_eq_expression.hpp:38
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::ast::CvodeBlock
Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks.
Definition: cvode_block.hpp:38
token_mapping.hpp
Map different tokens from lexer to token types.
nmodl::ast::AstNodeType::PRIME_NAME
@ PRIME_NAME
type of ast::PrimeName
nmodl::ast::Integer
Represents an integer variable.
Definition: integer.hpp:49
nmodl::ast::AstNodeType::CONSERVE
@ CONSERVE
type of ast::Conserve
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::visitor::get_indexed_variables
static std::unordered_set< std::string > get_indexed_variables(const ast::Expression &node, const std::string &ignored_name)
set of all indexed variables not equal to ignored_name
Definition: cvode_visitor.cpp:69
nmodl::visitor::NonStiffVisitor::NonStiffVisitor
NonStiffVisitor(symtab::SymbolTable *symtab)
Definition: cvode_visitor.cpp:131
nmodl::visitor::CvodeHelperVisitor
Definition: cvode_visitor.cpp:117
nmodl::ast::BinaryExpression::get_rhs
std::shared_ptr< Expression > get_rhs() const noexcept
Getter for member variable BinaryExpression::rhs.
Definition: binary_expression.hpp:179
nmodl::visitor::NonStiffVisitor::visit_binary_expression
void visit_binary_expression(ast::BinaryExpression &node)
visit node of type ast::BinaryExpression
Definition: cvode_visitor.cpp:135
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::BinaryExpression::set_rhs
void set_rhs(std::shared_ptr< Expression > &&rhs)
Setter for member variable BinaryExpression::rhs (rvalue reference)
Definition: ast.cpp:6616
nmodl::ast::IndexedName
Represents specific element of an array variable.
Definition: indexed_name.hpp:48
nmodl::ast::StatementBlock::erase_statement
StatementVector::const_iterator erase_statement(StatementVector::const_iterator first)
Erase a member from statements.
Definition: ast.cpp:3096
nmodl::ast::BinaryExpression::set_lhs
void set_lhs(std::shared_ptr< Expression > &&lhs)
Setter for member variable BinaryExpression::lhs (rvalue reference)
Definition: ast.cpp:6590
nmodl::ast::DiffEqExpression::visit_children
void visit_children(visitor::Visitor &v) override
Visit the children of this node with the given visitor.
Definition: ast.cpp:6642
nmodl::visitor::StiffVisitor::StiffVisitor
StiffVisitor(symtab::SymbolTable *symtab)
Definition: cvode_visitor.cpp:155
nmodl::pybind_wrappers::EmbeddedPythonLoader::api
const pybind_wrap_api & api()
Get a pointer to the pybind_wrap_api struct.
Definition: pyembed.cpp:135
nmodl::visitor::remove_conserve_statements
static void remove_conserve_statements(ast::StatementBlock &node)
Definition: cvode_visitor.cpp:28
nmodl::visitor::CvodeHelperVisitor::program_symtab
symtab::SymbolTable * program_symtab
Definition: cvode_visitor.cpp:119
nmodl::visitor::AstVisitor::visit_statement_block
void visit_statement_block(ast::StatementBlock &node) override
visit node of type ast::StatementBlock
Definition: ast_visitor.cpp:124
nmodl::visitor::CvodeHelperVisitor::in_differential_equation
bool in_differential_equation
Definition: cvode_visitor.cpp:120
nmodl::visitor::AstVisitor
Concrete visitor for all AST classes.
Definition: ast_visitor.hpp:37
nmodl::visitor::StiffVisitor::visit_binary_expression
void visit_binary_expression(ast::BinaryExpression &node)
visit node of type ast::BinaryExpression
Definition: cvode_visitor.cpp:159
nmodl::symtab::SymbolTable
Represent symbol table for a NMODL block.
Definition: symbol_table.hpp:57
nmodl::visitor::CvodeVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: cvode_visitor.cpp:225
nmodl::symtab::SymbolTable::insert
void insert(const std::shared_ptr< Symbol > &symbol)
Definition: symbol_table.hpp:178
nmodl::visitor::parse_independent_var
static std::pair< std::string, std::optional< int > > parse_independent_var(std::shared_ptr< ast::Identifier > node)
Definition: cvode_visitor.cpp:58
nmodl::visitor::create_statement
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
Definition: visitor_utils.cpp:127
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:206
nmodl::visitor::remove_units
static void remove_units(ast::BinaryExpression &node)
Definition: cvode_visitor.cpp:40
nmodl::ast::StatementBlock
Represents block encapsulating list of statements.
Definition: statement_block.hpp:53
nmodl::visitor::get_derivative_block
static std::shared_ptr< ast::DerivativeBlock > get_derivative_block(ast::Program &node)
Definition: cvode_visitor.cpp:199
nmodl::stringutils::remove_character
static std::string remove_character(std::string text, const char c)
Remove all occurrences of a given character in a text.
Definition: string_utils.hpp:73
logger.hpp
Implement logger based on spdlog library.
nmodl::pybind_wrappers
Definition: pyembed.cpp:25
nmodl::ast::Program::get_symbol_table
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
Definition: program.hpp:153
nmodl::ast::BinaryExpression::get_lhs
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable BinaryExpression::lhs.
Definition: binary_expression.hpp:161
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::ast::Expression
Base class for all expressions in the NMODL.
Definition: expression.hpp:43
cvode_visitor.hpp
Visitor used for generating the necessary AST nodes for CVODE.
nmodl::ast::Program::emplace_back_node
void emplace_back_node(Node *n)
Add member to blocks by raw pointer.
Definition: ast.cpp:12821
nmodl::symtab::SymbolTable::lookup
std::shared_ptr< Symbol > lookup(const std::string &name) const
Check whether a symbol with the given name exists in the current table (but not in its parents).
Definition: symbol_table.hpp:199
nmodl::ast::BinaryExpression
Represents binary expression in the NMODL.
Definition: binary_expression.hpp:52
nmodl::stringutils::ends_with
static bool ends_with(const std::string &haystack, const std::string &needle)
Check if haystack ends with needle.
Definition: string_utils.hpp:135
nmodl::ModToken
Represent token returned by scanner.
Definition: modtoken.hpp:50
nmodl::visitor::cvode_set_lhs
static std::string cvode_set_lhs(ast::BinaryExpression &node)
Definition: cvode_visitor.cpp:91
nmodl::pybind_wrappers::pybind_wrap_api::diff2c
decltype(&call_diff2c) diff2c
Definition: wrapper.hpp:67
nmodl::ast::String
Represents a string.
Definition: string.hpp:52
all.hpp
Auto generated AST classes declaration.
nmodl::visitor::StiffVisitor
Definition: cvode_visitor.cpp:153
pyembed.hpp
nmodl::visitor::NonStiffVisitor
Definition: cvode_visitor.cpp:129