User Guide
sympy_conductance_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include <algorithm>
11 
12 #include "ast/all.hpp"
13 #include "pybind/pyembed.hpp"
14 #include "symtab/symbol.hpp"
15 #include "utils/logger.hpp"
17 
18 namespace pywrap = nmodl::pybind_wrappers;
19 
20 namespace nmodl {
21 namespace visitor {
22 
23 using ast::AstNodeType;
24 using ast::BinaryOp;
26 
27 /**
28  * Analyse breakpoint block to check if it is safe to insert CONDUCTANCE statements
29  *
30  * Most mod files have simple breakpoint blocks without any control flow
31  * statements. SympyConductanceVisitor just collects all the statements in the
32  * breakpoint block and inserts conductance statements. If there are control flow
33  * statements, such as an `IF { a } ELSE { b }` block with conflicting current
34  * statements inside the IF and ELSE branches, or a VERBATIM block, then the
35  * resulting CONDUCTANCE statements may be incorrect. For now the simple approach is
36  * to not generate CONDUCTANCE statements if the block contains any IF/ELSE statement or VERBATIM block.
37  *
38  * @param node Ast node for breakpoint block
39  * @return true if it is safe to insert conductance statements otherwise false
40  */
42  return collect_nodes(node, {AstNodeType::IF_STATEMENT, AstNodeType::VERBATIM}).empty();
43 }
44 
45 
46 // Generate statement strings to be added to BREAKPOINT section
47 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
49  ast::BreakpointBlock& node) {
50  std::vector<std::string> statements;
51 
52  // instead of passing all variables in the symbol table find out variables
53  // that are used in the current block and then pass to sympy
54  // name could be a parameter or a unit, so check that it exists in the set of known variables (all_vars)
55  const auto& names_in_block = collect_nodes(node, {AstNodeType::NAME});
56  string_set used_names_in_block;
57  for (const auto& name: names_in_block) {
58  auto varname = name->get_node_name();
59  if (all_vars.find(varname) != all_vars.end()) {
60  used_names_in_block.insert(varname);
61  }
62  }
63 
64  // iterate over binary expression lhs's from breakpoint
65  for (const auto& lhs_str: ordered_binary_exprs_lhs) {
66  // look for a current name that matches lhs of expr (current write name)
67  auto it = i_name.find(lhs_str);
68  if (it != i_name.end()) {
69  std::string i_name_str = it->second;
70  // SymPy needs the current expression & all previous expressions
71  std::vector<std::string> expressions(ordered_binary_exprs.begin(),
72  ordered_binary_exprs.begin() +
73  static_cast<std::ptrdiff_t>(
74  binary_expr_index[lhs_str]) +
75  1);
76  // differentiate dI/dV
78  auto [dIdV, exception_message] = analytic_diff(expressions, used_names_in_block);
79  if (!exception_message.empty()) {
80  logger->warn("SympyConductance :: python exception: {}", exception_message);
81  }
82  if (dIdV.empty()) {
83  logger->warn(
84  "SympyConductance :: analytic differentiation of ionic current "
85  "not possible");
86  } else {
87  std::string g_var = dIdV;
88  // if conductance g_var is not an existing variable, need to generate
89  // a new variable name, declare it, and assign it
90  if (all_vars.find(g_var) == all_vars.end()) {
91  // generate new variable name
92  std::map<std::string, int> var_map;
93  for (auto const& v: all_vars) {
94  var_map[v] = 0;
95  }
96  g_var = get_new_name("g", i_name_str, var_map);
97  // declare it
99  // assign dIdV to it
100  std::string statement_str = g_var;
101  statement_str.append(" = ").append(dIdV);
102  statements.insert(statements.begin(), statement_str);
103  logger->debug("SympyConductance :: Adding BREAKPOINT statement: {}",
104  statement_str);
105  }
106  std::string statement_str = "CONDUCTANCE " + g_var;
107  if (!i_name_str.empty()) {
108  statement_str += " USEION " + i_name_str;
109  }
110  statements.push_back(statement_str);
111  logger->debug("SympyConductance :: Adding BREAKPOINT statement: {}", statement_str);
112  }
113  }
114  }
115  return statements;
116 }
117 
119  // only want binary expressions from breakpoint block
120  if (!under_breakpoint_block) {
121  return;
122  }
123  // only want binary expressions of form x = ...
124  if (node.get_lhs()->is_var_name() && (node.get_op().get_value() == BinaryOp::BOP_ASSIGN)) {
125  auto lhs_str =
126  std::dynamic_pointer_cast<ast::VarName>(node.get_lhs())->get_name()->get_node_name();
127  binary_expr_index[lhs_str] = ordered_binary_exprs.size();
128  ordered_binary_exprs.push_back(to_nmodl_for_sympy(node));
129  ordered_binary_exprs_lhs.push_back(lhs_str);
130  }
131 }
132 
134  // add NONSPECIFIC_CURRENT statements to i_name map between write vars and names
135  // note that they don't have an ion name, so we set it to ""
137  for (const auto& ns_curr_ast: nonspecific_nodes) {
138  logger->debug("SympyConductance :: Found NONSPECIFIC_CURRENT statement");
139  for (const auto& write_name:
140  std::dynamic_pointer_cast<const ast::Nonspecific>(ns_curr_ast)->get_currents()) {
141  const std::string& curr_write = write_name->get_node_name();
142  logger->debug("SympyConductance :: -> Adding non-specific current write name: {}",
143  curr_write);
144  i_name[curr_write] = "";
145  }
146  }
147  }
148 }
149 
152 }
153 
154 
156  // add USEION statements to i_name map between write vars and names
157  for (const auto& useion_ast: use_ion_nodes) {
158  const auto& ion = std::dynamic_pointer_cast<const ast::Useion>(useion_ast);
159  const std::string& ion_name = ion->get_node_name();
160  logger->debug("SympyConductance :: Found USEION statement {}", to_nmodl_for_sympy(*ion));
161  if (i_ignore.find(ion_name) != i_ignore.end()) {
162  logger->debug("SympyConductance :: -> Ignoring ion current name: {}", ion_name);
163  } else {
164  for (const auto& w: ion->get_writelist()) {
165  std::string ion_write = w->get_node_name();
166  logger->debug(
167  "SympyConductance :: -> Adding ion write name: {} for ion current name: {}",
168  ion_write,
169  ion_name);
170  i_name[ion_write] = ion_name;
171  }
172  }
173  }
174 }
175 
177  // find existing CONDUCTANCE statements - do not want to alter them
178  // so keep a set of ion names i_ignore that we should ignore later
179  logger->debug("SympyConductance :: Found existing CONDUCTANCE statement: {}",
180  to_nmodl_for_sympy(node));
181  if (const auto& ion = node.get_ion()) {
182  logger->debug("SympyConductance :: -> Ignoring ion current name: {}", ion->get_node_name());
183  i_ignore.insert(ion->get_node_name());
184  } else {
185  logger->debug("SympyConductance :: -> Ignoring all non-specific currents");
187  }
188 };
189 
191  // return if it's not safe to insert conductance statements
192  if (!conductance_statement_possible(node)) {
193  logger->warn("SympyConductance :: Unsafe to insert CONDUCTANCE statement");
194  return;
195  }
196 
197  // add any breakpoint local variables to vars
198  if (auto* symtab = node.get_statement_block()->get_symbol_table()) {
199  for (const auto& localvar: symtab->get_variables_with_properties(NmodlType::local_var)) {
200  all_vars.insert(localvar->get_name());
201  }
202  }
203  // visit BREAKPOINT block statements
204  under_breakpoint_block = true;
205  node.visit_children(*this);
206  under_breakpoint_block = false;
207 
208  // lookup USEION and NONSPECIFIC statements from NEURON block
211 
212  // add new CONDUCTANCE statements to BREAKPOINT
213  auto new_statements = generate_statement_strings(node);
214  if (!new_statements.empty()) {
215  // get a copy of existing BREAKPOINT statements
216  auto brkpnt_statements = node.get_statement_block()->get_statements();
217  // insert new CONDUCTANCE statements at top of BREAKPOINT
218  // or just below LOCAL statement if it exists
219  auto insertion_point = brkpnt_statements.begin();
220  while ((*insertion_point)->is_local_list_statement()) {
221  ++insertion_point;
222  }
223  for (const auto& statement_str: new_statements) {
224  insertion_point = brkpnt_statements.insert(insertion_point,
225  create_statement(statement_str));
226  }
227  // replace old set of BREAKPOINT statements in AST with new one
228  node.get_statement_block()->set_statements(std::move(brkpnt_statements));
229  }
230 }
231 
233  all_vars = get_global_vars(node);
234  const auto& program = node;
235  use_ion_nodes = collect_nodes(program, {AstNodeType::USEION});
236  nonspecific_nodes = collect_nodes(program, {AstNodeType::NONSPECIFIC});
237 
238  node.visit_children(*this);
239 }
240 
241 } // namespace visitor
242 } // namespace nmodl
nmodl::ast::BinaryOp
BinaryOp
enum Type for binary operators in NMODL
Definition: ast_common.hpp:47
nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
Definition: pyembed.hpp:29
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:234
symbol.hpp
Implement class to represent a symbol in Symbol Table.
nmodl::ast::Ast
Base class for all Abstract Syntax Tree node types.
Definition: ast.hpp:69
nmodl::ast::ConductanceHint::get_ion
std::shared_ptr< Name > get_ion() const noexcept
Getter for member variable ConductanceHint::ion.
Definition: conductance_hint.hpp:176
nmodl::visitor::SympyConductanceVisitor::generate_statement_strings
std::vector< std::string > generate_statement_strings(ast::BreakpointBlock &node)
Definition: sympy_conductance_visitor.cpp:48
nmodl::ast::ConductanceHint
Represents CONDUCTANCE statement in NMODL.
Definition: conductance_hint.hpp:46
nmodl::visitor::SympyConductanceVisitor::i_name
string_map i_name
map between current write names and ion names
Definition: sympy_conductance_visitor.hpp:70
nmodl::visitor::SympyConductanceVisitor::under_breakpoint_block
bool under_breakpoint_block
true while visiting breakpoint block
Definition: sympy_conductance_visitor.hpp:61
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::visitor::SympyConductanceVisitor::all_vars
string_set all_vars
set of all variables for SymPy
Definition: sympy_conductance_visitor.hpp:64
nmodl::visitor::SympyConductanceVisitor::visit_binary_expression
void visit_binary_expression(ast::BinaryExpression &node) override
visit node of type ast::BinaryExpression
Definition: sympy_conductance_visitor.cpp:118
nmodl::visitor::SympyConductanceVisitor::ordered_binary_exprs
std::vector< std::string > ordered_binary_exprs
ordered list of the binary assignment expressions (`x = ...`) found in the BREAKPOINT block
Definition: sympy_conductance_visitor.hpp:75
nmodl::ast::BOP_ASSIGN
@ BOP_ASSIGN
=
Definition: ast_common.hpp:59
nmodl::visitor::SympyConductanceVisitor::to_nmodl_for_sympy
static std::string to_nmodl_for_sympy(const ast::Ast &node)
Definition: sympy_conductance_visitor.cpp:150
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::visitor::SympyConductanceVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: sympy_conductance_visitor.cpp:232
nmodl::ast::BreakpointBlock::get_statement_block
std::shared_ptr< StatementBlock > get_statement_block() const noexcept override
Getter for member variable BreakpointBlock::statement_block.
Definition: breakpoint_block.hpp:188
nmodl::visitor::SympyConductanceVisitor::use_ion_nodes
std::vector< std::shared_ptr< const ast::Ast > > use_ion_nodes
USEION AST nodes collected from the program
Definition: sympy_conductance_visitor.hpp:84
nmodl::ast::AstNodeType
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:166
nmodl::visitor::SympyConductanceVisitor::visit_breakpoint_block
void visit_breakpoint_block(ast::BreakpointBlock &node) override
visit node of type ast::BreakpointBlock
Definition: sympy_conductance_visitor.cpp:190
nmodl::visitor::SympyConductanceVisitor::lookup_nonspecific_statements
void lookup_nonspecific_statements()
Definition: sympy_conductance_visitor.cpp:133
nmodl::visitor::SympyConductanceVisitor::string_set
std::set< std::string > string_set
Definition: sympy_conductance_visitor.hpp:57
visitor_utils.hpp
Utility functions for visitors implementation.
sympy_conductance_visitor.hpp
Visitor for generating CONDUCTANCE statements for ions
nmodl::ast::BreakpointBlock
Represents a BREAKPOINT block in NMODL.
Definition: breakpoint_block.hpp:53
nmodl::visitor::SympyConductanceVisitor::binary_expr_index
std::map< std::string, std::size_t > binary_expr_index
map from the LHS name of a binary assignment expression to the index of its (last) occurrence in ordered_binary_exprs
Definition: sympy_conductance_visitor.hpp:81
nmodl::pybind_wrappers::EmbeddedPythonLoader::api
const pybind_wrap_api & api()
Get a pointer to the pybind_wrap_api struct.
Definition: pyembed.cpp:135
nmodl::ast::Program::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:12906
nmodl::ast::AstNodeType::UNIT_DEF
@ UNIT_DEF
type of ast::UnitDef
nmodl::ast::BinaryExpression::get_op
const BinaryOperator & get_op() const noexcept
Getter for member variable BinaryExpression::op.
Definition: binary_expression.hpp:170
nmodl::visitor::create_statement
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
Definition: visitor_utils.cpp:127
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:206
nmodl::ast::AstNodeType::UNIT
@ UNIT
type of ast::Unit
nmodl::visitor::get_new_name
std::string get_new_name(const std::string &name, const std::string &suffix, std::map< std::string, int > &variables)
Return new name variable by appending _suffix_COUNT where COUNT is number of times the given variable...
Definition: visitor_utils.cpp:62
nmodl::symtab::syminfo::NmodlType
NmodlType
NMODL variable properties.
Definition: symbol_properties.hpp:116
nmodl::pybind_wrappers::pybind_wrap_api::analytic_diff
decltype(&call_analytic_diff) analytic_diff
Definition: wrapper.hpp:66
nmodl::ast::BreakpointBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:4646
nmodl::visitor::conductance_statement_possible
static bool conductance_statement_possible(const ast::BreakpointBlock &node)
Analyse breakpoint block to check if it is safe to insert CONDUCTANCE statements.
Definition: sympy_conductance_visitor.cpp:41
nmodl::ast::BinaryOperator::get_value
BinaryOp get_value() const noexcept
Getter for member variable BinaryOperator::value.
Definition: binary_operator.hpp:143
logger.hpp
Implement logger based on spdlog library.
nmodl::visitor::SympyConductanceVisitor::ordered_binary_exprs_lhs
std::vector< std::string > ordered_binary_exprs_lhs
ordered list of the LHS variable names of the binary assignment expressions in ordered_binary_exprs
Definition: sympy_conductance_visitor.hpp:78
nmodl::visitor::SympyConductanceVisitor::NONSPECIFIC_CONDUCTANCE_ALREADY_EXISTS
bool NONSPECIFIC_CONDUCTANCE_ALREADY_EXISTS
Definition: sympy_conductance_visitor.hpp:72
nmodl::pybind_wrappers
Definition: pyembed.cpp:25
nmodl::visitor::add_local_variable
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
Definition: visitor_utils.cpp:93
nmodl::ast::BinaryExpression::get_lhs
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable BinaryExpression::lhs.
Definition: binary_expression.hpp:161
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::visitor::SympyConductanceVisitor::i_ignore
string_set i_ignore
set of ion names whose currents are ignored because a CONDUCTANCE statement already exists for them
Definition: sympy_conductance_visitor.hpp:67
nmodl::visitor::SympyConductanceVisitor::nonspecific_nodes
std::vector< std::shared_ptr< const ast::Ast > > nonspecific_nodes
NONSPECIFIC_CURRENT AST nodes collected from the program
Definition: sympy_conductance_visitor.hpp:87
nmodl::visitor::SympyConductanceVisitor::visit_conductance_hint
void visit_conductance_hint(ast::ConductanceHint &node) override
visit node of type ast::ConductanceHint
Definition: sympy_conductance_visitor.cpp:176
nmodl::ast::BinaryExpression
Represents binary expression in the NMODL.
Definition: binary_expression.hpp:52
nmodl::visitor::SympyConductanceVisitor::lookup_useion_statements
void lookup_useion_statements()
Definition: sympy_conductance_visitor.cpp:155
nmodl::visitor::get_global_vars
std::set< std::string > get_global_vars(const Program &node)
Return set of strings with the names of all global variables.
Definition: visitor_utils.cpp:171
all.hpp
Auto generated AST classes declaration.
pyembed.hpp