User Guide
sympy_conductance_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include <algorithm>
11 
12 #include "ast/all.hpp"
13 #include "pybind/pyembed.hpp"
14 #include "symtab/symbol.hpp"
15 #include "utils/logger.hpp"
17 
18 namespace pywrap = nmodl::pybind_wrappers;
19 
20 namespace nmodl {
21 namespace visitor {
22 
23 using ast::AstNodeType;
24 using ast::BinaryOp;
26 
27 /**
28  * Analyse breakpoint block to check if it is safe to insert CONDUCTANCE statements
29  *
30  * Most mod files have simple breakpoint blocks without any control flow
31  * statements. SympyConductanceVisitor simply collects all the statements in the
32  * breakpoint block and inserts conductance statements. If there are control flow
33  * statements such as an `IF { a } ELSE { b }` block with conflicting current statements
34  * inside the IF and ELSE branches, or a VERBATIM block, then the resulting CONDUCTANCE
35  * statements may be incorrect. For now, the simple approach is to not generate
36  * CONDUCTANCE statements if the block contains any IF statement or VERBATIM block.
37  *
38  * @param node Ast node for breakpoint block
39  * @return true if it is safe to insert conductance statements otherwise false
40  */
42  return collect_nodes(node, {AstNodeType::IF_STATEMENT, AstNodeType::VERBATIM}).empty();
43 }
44 
45 
46 // Generate statement strings to be added to BREAKPOINT section
47 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
49  ast::BreakpointBlock& node) {
50  std::vector<std::string> statements;
51 
52  // instead of passing all variables in the symbol table, find the variables
53  // that are used in the current block and pass only those to sympy;
54  // a name could be a parameter or a unit, so only keep it if it exists in all_vars
55  const auto& names_in_block = collect_nodes(node, {AstNodeType::NAME});
56  string_set used_names_in_block;
57  for (const auto& name: names_in_block) {
58  auto varname = name->get_node_name();
59  if (all_vars.find(varname) != all_vars.end()) {
60  used_names_in_block.insert(varname);
61  }
62  }
63 
64  // iterate over binary expression lhs's from breakpoint
65  for (const auto& lhs_str: ordered_binary_exprs_lhs) {
66  // look for a current name that matches lhs of expr (current write name)
67  auto it = i_name.find(lhs_str);
68  if (it != i_name.end()) {
69  std::string i_name_str = it->second;
70  // SymPy needs the current expression & all previous expressions
71  std::vector<std::string> expressions(ordered_binary_exprs.begin(),
72  ordered_binary_exprs.begin() +
73  static_cast<std::ptrdiff_t>(
74  binary_expr_index[lhs_str]) +
75  1);
76  // differentiate dI/dV
77  auto analytic_diff =
79  analytic_diff->expressions = expressions;
80  analytic_diff->used_names_in_block = used_names_in_block;
81  (*analytic_diff)();
82  auto dIdV = analytic_diff->solution;
83  auto exception_message = analytic_diff->exception_message;
85  if (!exception_message.empty()) {
86  logger->warn("SympyConductance :: python exception: {}", exception_message);
87  }
88  if (dIdV.empty()) {
89  logger->warn(
90  "SympyConductance :: analytic differentiation of ionic current "
91  "not possible");
92  } else {
93  std::string g_var = dIdV;
94  // if conductance g_var is not an existing variable, need to generate
95  // a new variable name, declare it, and assign it
96  if (all_vars.find(g_var) == all_vars.end()) {
97  // generate new variable name
98  std::map<std::string, int> var_map;
99  for (auto const& v: all_vars) {
100  var_map[v] = 0;
101  }
102  g_var = get_new_name("g", i_name_str, var_map);
103  // declare it
104  add_local_variable(*node.get_statement_block(), g_var);
105  // assign dIdV to it
106  std::string statement_str = g_var;
107  statement_str.append(" = ").append(dIdV);
108  statements.insert(statements.begin(), statement_str);
109  logger->debug("SympyConductance :: Adding BREAKPOINT statement: {}",
110  statement_str);
111  }
112  std::string statement_str = "CONDUCTANCE " + g_var;
113  if (!i_name_str.empty()) {
114  statement_str += " USEION " + i_name_str;
115  }
116  statements.push_back(statement_str);
117  logger->debug("SympyConductance :: Adding BREAKPOINT statement: {}", statement_str);
118  }
119  }
120  }
121  return statements;
122 }
123 
125  // only want binary expressions from breakpoint block
126  if (!under_breakpoint_block) {
127  return;
128  }
129  // only want binary expressions of form x = ...
130  if (node.get_lhs()->is_var_name() && (node.get_op().get_value() == BinaryOp::BOP_ASSIGN)) {
131  auto lhs_str =
132  std::dynamic_pointer_cast<ast::VarName>(node.get_lhs())->get_name()->get_node_name();
133  binary_expr_index[lhs_str] = ordered_binary_exprs.size();
134  ordered_binary_exprs.push_back(to_nmodl_for_sympy(node));
135  ordered_binary_exprs_lhs.push_back(lhs_str);
136  }
137 }
138 
140  // add NONSPECIFIC_CURRENT statements to i_name map between write vars and names
141  // note that they don't have an ion name, so we set it to ""
143  for (const auto& ns_curr_ast: nonspecific_nodes) {
144  logger->debug("SympyConductance :: Found NONSPECIFIC_CURRENT statement");
145  for (const auto& write_name:
146  std::dynamic_pointer_cast<const ast::Nonspecific>(ns_curr_ast)->get_currents()) {
147  const std::string& curr_write = write_name->get_node_name();
148  logger->debug("SympyConductance :: -> Adding non-specific current write name: {}",
149  curr_write);
150  i_name[curr_write] = "";
151  }
152  }
153  }
154 }
155 
158 }
159 
160 
162  // add USEION statements to i_name map between write vars and names
163  for (const auto& useion_ast: use_ion_nodes) {
164  const auto& ion = std::dynamic_pointer_cast<const ast::Useion>(useion_ast);
165  const std::string& ion_name = ion->get_node_name();
166  logger->debug("SympyConductance :: Found USEION statement {}", to_nmodl_for_sympy(*ion));
167  if (i_ignore.find(ion_name) != i_ignore.end()) {
168  logger->debug("SympyConductance :: -> Ignoring ion current name: {}", ion_name);
169  } else {
170  for (const auto& w: ion->get_writelist()) {
171  std::string ion_write = w->get_node_name();
172  logger->debug(
173  "SympyConductance :: -> Adding ion write name: {} for ion current name: {}",
174  ion_write,
175  ion_name);
176  i_name[ion_write] = ion_name;
177  }
178  }
179  }
180 }
181 
183  // find existing CONDUCTANCE statements - do not want to alter them
184  // so keep a set of ion names i_ignore that we should ignore later
185  logger->debug("SympyConductance :: Found existing CONDUCTANCE statement: {}",
186  to_nmodl_for_sympy(node));
187  if (const auto& ion = node.get_ion()) {
188  logger->debug("SympyConductance :: -> Ignoring ion current name: {}", ion->get_node_name());
189  i_ignore.insert(ion->get_node_name());
190  } else {
191  logger->debug("SympyConductance :: -> Ignoring all non-specific currents");
193  }
194 };
195 
197  // return if it's not safe to insert conductance statements
198  if (!conductance_statement_possible(node)) {
199  logger->warn("SympyConductance :: Unsafe to insert CONDUCTANCE statement");
200  return;
201  }
202 
203  // add any breakpoint local variables to vars
204  if (auto* symtab = node.get_statement_block()->get_symbol_table()) {
205  for (const auto& localvar: symtab->get_variables_with_properties(NmodlType::local_var)) {
206  all_vars.insert(localvar->get_name());
207  }
208  }
209  // visit BREAKPOINT block statements
210  under_breakpoint_block = true;
211  node.visit_children(*this);
212  under_breakpoint_block = false;
213 
214  // lookup USEION and NONSPECIFIC statements from NEURON block
217 
218  // add new CONDUCTANCE statements to BREAKPOINT
219  auto new_statements = generate_statement_strings(node);
220  if (!new_statements.empty()) {
221  // get a copy of existing BREAKPOINT statements
222  auto brkpnt_statements = node.get_statement_block()->get_statements();
223  // insert new CONDUCTANCE statements at top of BREAKPOINT
224  // or just below LOCAL statement if it exists
225  auto insertion_point = brkpnt_statements.begin();
226  while ((*insertion_point)->is_local_list_statement()) {
227  ++insertion_point;
228  }
229  for (const auto& statement_str: new_statements) {
230  insertion_point = brkpnt_statements.insert(insertion_point,
231  create_statement(statement_str));
232  }
233  // replace old set of BREAKPOINT statements in AST with new one
234  node.get_statement_block()->set_statements(std::move(brkpnt_statements));
235  }
236 }
237 
239  all_vars = get_global_vars(node);
240  const auto& program = node;
241  use_ion_nodes = collect_nodes(program, {AstNodeType::USEION});
242  nonspecific_nodes = collect_nodes(program, {AstNodeType::NONSPECIFIC});
243 
244  node.visit_children(*this);
245 }
246 
247 } // namespace visitor
248 } // namespace nmodl
nmodl::ast::BinaryOp
BinaryOp
enum Type for binary operators in NMODL
Definition: ast_common.hpp:47
nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance
static EmbeddedPythonLoader & get_instance()
Construct (if not already done) and get the only instance of this class.
Definition: pyembed.hpp:141
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:227
symbol.hpp
Implement class to represent a symbol in Symbol Table.
nmodl::ast::Ast
Base class for all Abstract Syntax Tree node types.
Definition: ast.hpp:69
nmodl::ast::ConductanceHint::get_ion
std::shared_ptr< Name > get_ion() const noexcept
Getter for member variable ConductanceHint::ion.
Definition: conductance_hint.hpp:176
nmodl::visitor::SympyConductanceVisitor::generate_statement_strings
std::vector< std::string > generate_statement_strings(ast::BreakpointBlock &node)
Definition: sympy_conductance_visitor.cpp:48
nmodl::ast::ConductanceHint
Represents CONDUCTANCE statement in NMODL.
Definition: conductance_hint.hpp:46
nmodl::visitor::SympyConductanceVisitor::i_name
string_map i_name
map between current write names and ion names
Definition: sympy_conductance_visitor.hpp:69
nmodl::visitor::SympyConductanceVisitor::under_breakpoint_block
bool under_breakpoint_block
true while visiting breakpoint block
Definition: sympy_conductance_visitor.hpp:60
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::pybind_wrappers::EmbeddedPythonLoader::api
const pybind_wrap_api * api()
Get a pointer to the pybind_wrap_api struct.
Definition: pyembed.cpp:87
nmodl::visitor::SympyConductanceVisitor::all_vars
string_set all_vars
set of all variables for SymPy
Definition: sympy_conductance_visitor.hpp:63
nmodl::visitor::SympyConductanceVisitor::visit_binary_expression
void visit_binary_expression(ast::BinaryExpression &node) override
visit node of type ast::BinaryExpression
Definition: sympy_conductance_visitor.cpp:124
nmodl::visitor::SympyConductanceVisitor::ordered_binary_exprs
std::vector< std::string > ordered_binary_exprs
list of binary expressions, in the order they appear in the breakpoint block
Definition: sympy_conductance_visitor.hpp:74
nmodl::ast::BOP_ASSIGN
@ BOP_ASSIGN
=
Definition: ast_common.hpp:59
nmodl::visitor::SympyConductanceVisitor::to_nmodl_for_sympy
static std::string to_nmodl_for_sympy(const ast::Ast &node)
Definition: sympy_conductance_visitor.cpp:156
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::visitor::SympyConductanceVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: sympy_conductance_visitor.cpp:238
nmodl::ast::BreakpointBlock::get_statement_block
std::shared_ptr< StatementBlock > get_statement_block() const noexcept override
Getter for member variable BreakpointBlock::statement_block.
Definition: breakpoint_block.hpp:188
nmodl::visitor::SympyConductanceVisitor::use_ion_nodes
std::vector< std::shared_ptr< const ast::Ast > > use_ion_nodes
use ion ast nodes
Definition: sympy_conductance_visitor.hpp:83
nmodl::ast::AstNodeType
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:164
nmodl::visitor::SympyConductanceVisitor::visit_breakpoint_block
void visit_breakpoint_block(ast::BreakpointBlock &node) override
visit node of type ast::BreakpointBlock
Definition: sympy_conductance_visitor.cpp:196
nmodl::visitor::SympyConductanceVisitor::lookup_nonspecific_statements
void lookup_nonspecific_statements()
Definition: sympy_conductance_visitor.cpp:139
nmodl::visitor::SympyConductanceVisitor::string_set
std::set< std::string > string_set
Definition: sympy_conductance_visitor.hpp:56
visitor_utils.hpp
Utility functions for visitors implementation.
sympy_conductance_visitor.hpp
Visitor for generating CONDUCTANCE statements for ions
nmodl::ast::BreakpointBlock
Represents a BREAKPOINT block in NMODL.
Definition: breakpoint_block.hpp:53
nmodl::pybind_wrappers::pybind_wrap_api::create_ads_executor
decltype(&create_ads_executor_func) create_ads_executor
Definition: pyembed.hpp:119
nmodl::visitor::SympyConductanceVisitor::binary_expr_index
std::map< std::string, std::size_t > binary_expr_index
map from the lhs of a binary expression to the index of that expression in ordered_binary_exprs
Definition: sympy_conductance_visitor.hpp:80
nmodl::ast::Program::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:12902
nmodl::ast::AstNodeType::UNIT_DEF
@ UNIT_DEF
type of ast::UnitDef
nmodl::pybind_wrappers::pybind_wrap_api::destroy_ads_executor
decltype(&destroy_ads_executor_func) destroy_ads_executor
Definition: pyembed.hpp:123
nmodl::ast::BinaryExpression::get_op
const BinaryOperator & get_op() const noexcept
Getter for member variable BinaryExpression::op.
Definition: binary_expression.hpp:170
nmodl::visitor::create_statement
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
Definition: visitor_utils.cpp:126
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:205
nmodl::ast::AstNodeType::UNIT
@ UNIT
type of ast::Unit
nmodl::visitor::get_new_name
std::string get_new_name(const std::string &name, const std::string &suffix, std::map< std::string, int > &variables)
Return new name variable by appending _suffix_COUNT where COUNT is number of times the given variable...
Definition: visitor_utils.cpp:61
nmodl::symtab::syminfo::NmodlType
NmodlType
NMODL variable properties.
Definition: symbol_properties.hpp:116
nmodl::ast::BreakpointBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:4642
nmodl::visitor::conductance_statement_possible
static bool conductance_statement_possible(const ast::BreakpointBlock &node)
Analyse breakpoint block to check if it is safe to insert CONDUCTANCE statements.
Definition: sympy_conductance_visitor.cpp:41
nmodl::ast::BinaryOperator::get_value
BinaryOp get_value() const noexcept
Getter for member variable BinaryOperator::value.
Definition: binary_operator.hpp:143
logger.hpp
Implement logger based on spdlog library.
nmodl::visitor::SympyConductanceVisitor::ordered_binary_exprs_lhs
std::vector< std::string > ordered_binary_exprs_lhs
lhs names of the binary expressions in the breakpoint block, in the same order as ordered_binary_exprs
Definition: sympy_conductance_visitor.hpp:77
nmodl::visitor::SympyConductanceVisitor::NONSPECIFIC_CONDUCTANCE_ALREADY_EXISTS
bool NONSPECIFIC_CONDUCTANCE_ALREADY_EXISTS
Definition: sympy_conductance_visitor.hpp:71
nmodl::pybind_wrappers
Definition: pyembed.cpp:20
nmodl::visitor::add_local_variable
LocalVar * add_local_variable(StatementBlock &node, Identifier *varname)
Definition: visitor_utils.cpp:92
nmodl::ast::BinaryExpression::get_lhs
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable BinaryExpression::lhs.
Definition: binary_expression.hpp:161
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::visitor::SympyConductanceVisitor::i_ignore
string_set i_ignore
set of ion names whose currents are ignored because they already have a CONDUCTANCE statement
Definition: sympy_conductance_visitor.hpp:66
nmodl::visitor::SympyConductanceVisitor::nonspecific_nodes
std::vector< std::shared_ptr< const ast::Ast > > nonspecific_nodes
non-specific current (NONSPECIFIC_CURRENT) ast nodes
Definition: sympy_conductance_visitor.hpp:86
nmodl::visitor::SympyConductanceVisitor::visit_conductance_hint
void visit_conductance_hint(ast::ConductanceHint &node) override
visit node of type ast::ConductanceHint
Definition: sympy_conductance_visitor.cpp:182
nmodl::ast::BinaryExpression
Represents binary expression in the NMODL.
Definition: binary_expression.hpp:52
nmodl::visitor::SympyConductanceVisitor::lookup_useion_statements
void lookup_useion_statements()
Definition: sympy_conductance_visitor.cpp:161
nmodl::visitor::get_global_vars
std::set< std::string > get_global_vars(const Program &node)
Return set of strings with the names of all global variables.
Definition: visitor_utils.cpp:170
all.hpp
Auto generated AST classes declaration.
pyembed.hpp