User Guide
neuron_solve_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "ast/all.hpp"
12 #include "parser/diffeq_driver.hpp"
13 #include "symtab/symbol.hpp"
14 #include "utils/logger.hpp"
16 
17 
18 namespace nmodl {
19 namespace visitor {
20 
22  auto name = node.get_block_name()->get_node_name();
23  const auto& method = node.get_method();
24  solve_method = method ? method->get_value()->eval() : "";
25  solve_blocks[name] = solve_method;
26 }
27 
28 
30  derivative_block_name = node.get_name()->get_node_name();
31  derivative_block = true;
32  node.visit_children(*this);
33  derivative_block = false;
35  const auto& statement_block = node.get_statement_block();
36  for (auto& e: euler_solution_expressions) {
37  statement_block->emplace_back_statement(e);
38  }
39  }
40 }
41 
42 
44  differential_equation = true;
45  node.visit_children(*this);
46  differential_equation = false;
47 }
48 
49 
51  const auto& lhs = node.get_lhs();
52 
53  /// we have to only solve odes under derivative block where lhs is variable
54  if (!derivative_block || !differential_equation || !lhs->is_var_name()) {
55  return;
56  }
57 
59  auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
60 
61  if (name->is_prime_name()) {
62  auto equation = to_nmodl(node);
64  std::string solution;
65  /// check if ode can be solved with cnexp method
66  if (parser::DiffeqDriver::cnexp_possible(equation, solution)) {
67  auto statement = create_statement(solution);
68  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
69  statement);
70  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
71  expr_statement->get_expression());
72  node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
73  node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
74  } else {
75  logger->warn("NeuronSolveVisitor :: cnexp solver not possible for {}",
76  to_nmodl(node));
77  }
79  // computation of the derivative in place
80  {
81  std::string solution = parser::DiffeqDriver::solve(equation, solve_method);
82  auto statement = create_statement(solution);
83  auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
84  statement);
85  const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
86  expr_statement->get_expression());
87  node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
88  node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
89  }
90 
91  // create a new statement to compute the value based on the derivative
92  // this statement will be pushed at the end of the derivative block
93  {
94  std::string n = name->get_node_name();
95  auto statement = create_statement(fmt::format("{} = {} + dt * D{}", n, n, n));
96  euler_solution_expressions.emplace_back(statement);
97  }
99  auto varname = "D" + name->get_node_name();
100  node.set_lhs(std::make_shared<ast::Name>(new ast::String(varname)));
101  if (program_symtab->lookup(varname) == nullptr) {
102  auto symbol = std::make_shared<symtab::Symbol>(varname, ModToken());
103  symbol->set_original_name(name->get_node_name());
104  symbol->created_from_state();
105  program_symtab->insert(symbol);
106  }
107  } else {
108  logger->error("NeuronSolveVisitor :: solver method '{}' not supported", solve_method);
109  }
110  }
111 }
112 
115  node.visit_children(*this);
116 }
117 
118 } // namespace visitor
119 } // namespace nmodl
nmodl::ast::SolveBlock
TODO.
Definition: solve_block.hpp:38
nmodl::to_nmodl
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Definition: visitor_utils.cpp:227
nmodl::visitor::NeuronSolveVisitor::visit_binary_expression
void visit_binary_expression(ast::BinaryExpression &node) override
visit node of type ast::BinaryExpression
Definition: neuron_solve_visitor.cpp:50
nmodl::visitor::NeuronSolveVisitor::euler_solution_expressions
std::vector< std::shared_ptr< ast::Statement > > euler_solution_expressions
Definition: neuron_solve_visitor.hpp:62
symbol.hpp
Implements a class representing a symbol in the symbol table.
nmodl::ast::DiffEqExpression
Represents differential equation in DERIVATIVE block.
Definition: diff_eq_expression.hpp:38
nmodl::visitor::NeuronSolveVisitor::visit_solve_block
void visit_solve_block(ast::SolveBlock &node) override
visit node of type ast::SolveBlock
Definition: neuron_solve_visitor.cpp:21
nmodl::parser::DiffeqDriver::cnexp_possible
static bool cnexp_possible(const std::string &equation, std::string &solution)
check whether the given equation can be solved using the cnexp method
Definition: diffeq_driver.cpp:64
nmodl::visitor::NeuronSolveVisitor::visit_diff_eq_expression
void visit_diff_eq_expression(ast::DiffEqExpression &node) override
visit node of type ast::DiffEqExpression
Definition: neuron_solve_visitor.cpp:43
nmodl::codegen::naming::CNEXP_METHOD
static constexpr char CNEXP_METHOD[]
cnexp method in nmodl
Definition: codegen_naming.hpp:30
nmodl::ast::DerivativeBlock::get_name
std::shared_ptr< Name > get_name() const noexcept
Getter for member variable DerivativeBlock::name.
Definition: derivative_block.hpp:198
nmodl::visitor::NeuronSolveVisitor::derivative_block
bool derivative_block
true while visiting a derivative block
Definition: neuron_solve_visitor.hpp:57
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::visitor::NeuronSolveVisitor::differential_equation
bool differential_equation
true while visiting differential equation
Definition: neuron_solve_visitor.hpp:45
nmodl::visitor::NeuronSolveVisitor::program_symtab
symtab::SymbolTable * program_symtab
global symbol table
Definition: neuron_solve_visitor.hpp:48
codegen_naming.hpp
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::BinaryExpression::set_rhs
void set_rhs(std::shared_ptr< Expression > &&rhs)
Setter for member variable BinaryExpression::rhs (rvalue reference)
Definition: ast.cpp:6612
nmodl::visitor::NeuronSolveVisitor::derivative_block_name
std::string derivative_block_name
name of the derivative block currently being visited
Definition: neuron_solve_visitor.hpp:60
nmodl::ast::DerivativeBlock::visit_children
void visit_children(visitor::Visitor &v) override
visit the children of this node using the provided visitor v
Definition: ast.cpp:3249
nmodl::ast::BinaryExpression::set_lhs
void set_lhs(std::shared_ptr< Expression > &&lhs)
Setter for member variable BinaryExpression::lhs (rvalue reference)
Definition: ast.cpp:6586
nmodl::visitor::NeuronSolveVisitor::solve_method
std::string solve_method
method specified in solve block
Definition: neuron_solve_visitor.hpp:54
nmodl::ast::SolveBlock::get_block_name
std::shared_ptr< Name > get_block_name() const noexcept
Getter for member variable SolveBlock::block_name.
Definition: solve_block.hpp:177
nmodl::ast::DiffEqExpression::visit_children
void visit_children(visitor::Visitor &v) override
visit the children of this node using the provided visitor v
Definition: ast.cpp:6638
nmodl::ast::DerivativeBlock::get_statement_block
std::shared_ptr< StatementBlock > get_statement_block() const noexcept override
Getter for member variable DerivativeBlock::statement_block.
Definition: derivative_block.hpp:207
nmodl::parser::DiffeqDriver::solve
static std::string solve(const std::string &equation, std::string method, bool debug=false)
solve equation using provided method
Definition: diffeq_driver.cpp:38
nmodl::visitor::NeuronSolveVisitor::visit_derivative_block
void visit_derivative_block(ast::DerivativeBlock &node) override
visit node of type ast::DerivativeBlock
Definition: neuron_solve_visitor.cpp:29
nmodl::visitor::NeuronSolveVisitor::solve_blocks
std::map< std::string, std::string > solve_blocks
a map holding solve block names and methods
Definition: neuron_solve_visitor.hpp:51
nmodl::ast::Program::visit_children
void visit_children(visitor::Visitor &v) override
visit the children of this node using the provided visitor v
Definition: ast.cpp:12902
nmodl::ast::DerivativeBlock
Represents DERIVATIVE block in the NMODL.
Definition: derivative_block.hpp:49
nmodl::ast::SolveBlock::get_method
std::shared_ptr< Name > get_method() const noexcept
Getter for member variable SolveBlock::method.
Definition: solve_block.hpp:186
diffeq_driver.hpp
nmodl::symtab::SymbolTable::insert
void insert(const std::shared_ptr< Symbol > &symbol)
Definition: symbol_table.hpp:178
nmodl::visitor::create_statement
std::shared_ptr< Statement > create_statement(const std::string &code_statement)
Convert given code statement (in string format) to corresponding ast node.
Definition: visitor_utils.cpp:126
nmodl::visitor::NeuronSolveVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: neuron_solve_visitor.cpp:113
nmodl::codegen::naming::DERIVIMPLICIT_METHOD
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
Definition: codegen_naming.hpp:24
logger.hpp
Implements a logger based on the spdlog library.
nmodl::codegen::naming::EULER_METHOD
static constexpr char EULER_METHOD[]
euler method in nmodl
Definition: codegen_naming.hpp:27
nmodl::ast::Program::get_symbol_table
symtab::SymbolTable * get_symbol_table() const override
Return associated symbol table for the current ast node.
Definition: program.hpp:153
neuron_solve_visitor.hpp
Visitor that solves ODEs using old solvers of NEURON
nmodl::ast::BinaryExpression::get_lhs
std::shared_ptr< Expression > get_lhs() const noexcept
Getter for member variable BinaryExpression::lhs.
Definition: binary_expression.hpp:161
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::symtab::SymbolTable::lookup
std::shared_ptr< Symbol > lookup(const std::string &name) const
check whether a symbol with the given name exists in the current table (parent tables are not searched)
Definition: symbol_table.hpp:199
nmodl::ast::BinaryExpression
Represents binary expression in the NMODL.
Definition: binary_expression.hpp:52
nmodl::ModToken
Represent token returned by scanner.
Definition: modtoken.hpp:50
nmodl::ast::String
Represents a string.
Definition: string.hpp:52
all.hpp
Auto generated AST classes declaration.