User Guide
solve_block_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include <cassert>
11 #include <fmt/format.h>
12 
13 #include "ast/all.hpp"
15 #include "visitor_utils.hpp"
16 
17 namespace nmodl {
18 namespace visitor {
19 
21  in_breakpoint_block = true;
22  node.visit_children(*this);
23  in_breakpoint_block = false;
24 }
25 
26 /// check if given node contains sympy solution
27 static bool has_sympy_solution(const ast::Ast& node) {
29 }
30 
31 /**
32  * Create solution expression node that will be used for solve block
33  * \param solve_block solve block used to describe node to solve and method
34  * \return solution expression that will be used to replace the solve block
35  *
36  * Depending on the solver used, the solve block is converted to a solve expression statement
37  * that will be used to replace the solve block. Note that the blocks are cloned instead of
38  * being shared via shared_ptr because DerivimplicitCallback currently contains the whole node
39  * instead of just a pointer.
40  */
42  ast::SolveBlock& solve_block) {
43  /// find out the block that is going to be solved
44  const auto& block_name = solve_block.get_block_name()->get_node_name();
45  const auto& solve_node_symbol = symtab->lookup(block_name);
46  if (solve_node_symbol == nullptr) {
47  throw std::runtime_error(
48  fmt::format("SolveBlockVisitor :: cannot find the block '{}' to solve it", block_name));
49  }
50  auto node_to_solve = solve_node_symbol->get_nodes().front();
51 
52  /// in case of the derivimplicit method, if the neuron solver is used (i.e. not sympy) then
53  /// the solution is not in place and we have to create a callback to the newton solver
54  const auto& method = solve_block.get_method();
55  std::string solve_method = method ? method->get_node_name() : "";
56  if (solve_method == codegen::naming::DERIVIMPLICIT_METHOD &&
57  !has_sympy_solution(*node_to_solve)) {
58  /// typically derivimplicit is used for derivative block only
59  assert(node_to_solve->get_node_type() == ast::AstNodeType::DERIVATIVE_BLOCK);
60  auto derivative_block = dynamic_cast<ast::DerivativeBlock*>(node_to_solve);
61  auto callback_expr = new ast::DerivimplicitCallback(derivative_block->clone());
62  return new ast::SolutionExpression(solve_block.clone(), callback_expr);
63  }
64 
65  auto block_to_solve = node_to_solve->get_statement_block();
66  return new ast::SolutionExpression(solve_block.clone(), block_to_solve->clone());
67 }
68 
69 /**
70  * Replace solve blocks with solution expression
71  * @param node Ast node for SOLVE statement in the mod file
72  */
74  node.visit_children(*this);
75  if (node.get_expression()->is_solve_block()) {
76  auto solve_block = dynamic_cast<ast::SolveBlock*>(node.get_expression().get());
77  auto sol_expr = create_solution_expression(*solve_block);
78  if (in_breakpoint_block) {
79  nrn_state_solve_statements.emplace_back(new ast::ExpressionStatement(sol_expr));
80  } else {
81  node.set_expression(std::shared_ptr<ast::SolutionExpression>(sol_expr));
82  }
83  }
84 }
85 
87  symtab = node.get_symbol_table();
88  node.visit_children(*this);
89  /// add new node NrnState with solve blocks from breakpoint block
90  if (!nrn_state_solve_statements.empty()) {
91  auto nrn_state = new ast::NrnStateBlock(nrn_state_solve_statements);
92  node.emplace_back_node(nrn_state);
93  }
94 }
95 
96 } // namespace visitor
97 } // namespace nmodl
nmodl::ast::ExpressionStatement::get_expression
std::shared_ptr< Expression > get_expression() const noexcept
Getter for member variable ExpressionStatement::expression.
Definition: expression_statement.hpp:143
nmodl::ast::AstNodeType::DERIVATIVE_BLOCK
@ DERIVATIVE_BLOCK
type of ast::DerivativeBlock
nmodl::ast::SolveBlock
TODO.
Definition: solve_block.hpp:38
nmodl::visitor::SolveBlockVisitor::symtab
symtab::SymbolTable * symtab
Definition: solve_block_visitor.hpp:37
nmodl::ast::Ast
Base class for all Abstract Syntax Tree node types.
Definition: ast.hpp:69
solve_block_visitor.hpp
Replace solve block statements with actual solution node in the AST.
nmodl::ast::AstNodeType::EIGEN_NEWTON_SOLVER_BLOCK
@ EIGEN_NEWTON_SOLVER_BLOCK
type of ast::EigenNewtonSolverBlock
nmodl::visitor::SolveBlockVisitor::create_solution_expression
ast::SolutionExpression * create_solution_expression(ast::SolveBlock &solve_block)
Create solution expression node that will be used for solve block.
Definition: solve_block_visitor.cpp:41
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
codegen_naming.hpp
nmodl::ast::ExpressionStatement
TODO.
Definition: expression_statement.hpp:38
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::BreakpointBlock
Represents a BREAKPOINT block in NMODL.
Definition: breakpoint_block.hpp:53
nmodl::ast::SolveBlock::get_block_name
std::shared_ptr< Name > get_block_name() const noexcept
Getter for member variable SolveBlock::block_name.
Definition: solve_block.hpp:177
nmodl::ast::Program::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:12902
nmodl::visitor::SolveBlockVisitor::visit_breakpoint_block
void visit_breakpoint_block(ast::BreakpointBlock &node) override
visit node of type ast::BreakpointBlock
Definition: solve_block_visitor.cpp:20
nmodl::visitor::has_sympy_solution
static bool has_sympy_solution(const ast::Ast &node)
check if given node contains sympy solution
Definition: solve_block_visitor.cpp:27
nmodl::ast::DerivativeBlock
Represents DERIVATIVE block in the NMODL.
Definition: derivative_block.hpp:49
nmodl::ast::SolveBlock::get_method
std::shared_ptr< Name > get_method() const noexcept
Getter for member variable SolveBlock::method.
Definition: solve_block.hpp:186
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:205
nmodl::visitor::SolveBlockVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: solve_block_visitor.cpp:86
nmodl::ast::ExpressionStatement::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:9030
nmodl::ast::NrnStateBlock
Represents the coreneuron nrn_state callback function.
Definition: nrn_state_block.hpp:39
nmodl::ast::BreakpointBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit children i.e.
Definition: ast.cpp:4642
nmodl::ast::DerivimplicitCallback
Represent a callback to NEURON's derivimplicit solver.
Definition: derivimplicit_callback.hpp:38
nmodl::ast::SolveBlock::clone
SolveBlock * clone() const override
Return a copy of the current node.
Definition: solve_block.hpp:79
nmodl::codegen::naming::DERIVIMPLICIT_METHOD
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
Definition: codegen_naming.hpp:24
nmodl::ast::SolutionExpression
Represent solution of a block in the AST.
Definition: solution_expression.hpp:38
nmodl::ast::Program::get_symbol_table
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
Definition: program.hpp:153
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::ast::Program::emplace_back_node
void emplace_back_node(Node *n)
Add member to blocks by raw pointer.
Definition: ast.cpp:12817
nmodl::symtab::SymbolTable::lookup
std::shared_ptr< Symbol > lookup(const std::string &name) const
check if a symbol with the given name exists in the current table (but not in parents)
Definition: symbol_table.hpp:199
nmodl::visitor::SolveBlockVisitor::nrn_state_solve_statements
ast::StatementVector nrn_state_solve_statements
solve expression statements for NrnState block
Definition: solve_block_visitor.hpp:42
nmodl::visitor::SolveBlockVisitor::in_breakpoint_block
bool in_breakpoint_block
Definition: solve_block_visitor.hpp:39
nmodl::visitor::SolveBlockVisitor::visit_expression_statement
void visit_expression_statement(ast::ExpressionStatement &node) override
Replace solve blocks with solution expression.
Definition: solve_block_visitor.cpp:73
all.hpp
Auto generated AST classes declaration.
nmodl::ast::ExpressionStatement::set_expression
void set_expression(std::shared_ptr< Expression > &&expression)
Setter for member variable ExpressionStatement::expression (rvalue reference)
Definition: ast.cpp:9088