User Guide
steadystate_visitor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
9 
10 #include "ast/all.hpp"
12 #include "utils/logger.hpp"
14 
15 namespace nmodl {
16 namespace visitor {
17 
18 std::shared_ptr<ast::DerivativeBlock> SteadystateVisitor::create_steadystate_block(
19  const std::shared_ptr<ast::SolveBlock>& solve_block,
20  const std::vector<std::shared_ptr<ast::Ast>>& deriv_blocks) {
21  // new block to be returned
22  std::shared_ptr<ast::DerivativeBlock> ss_block;
23 
24  // get method & derivative block
25  const auto solve_block_name = solve_block->get_block_name()->get_value()->eval();
26  const auto steadystate_method = solve_block->get_steadystate()->get_value()->eval();
27 
28  logger->debug("SteadystateVisitor :: Found STEADYSTATE SOLVE statement: using {} for {}",
29  steadystate_method,
30  solve_block_name);
31 
32  ast::DerivativeBlock* deriv_block_ptr = nullptr;
33  for (const auto& block_ptr: deriv_blocks) {
34  auto deriv_block = std::dynamic_pointer_cast<ast::DerivativeBlock>(block_ptr);
35  if (deriv_block->get_node_name() == solve_block_name) {
36  logger->debug("SteadystateVisitor :: -> found corresponding DERIVATIVE block: {}",
37  solve_block_name);
38  deriv_block_ptr = deriv_block.get();
39  break;
40  }
41  }
42 
43  if (deriv_block_ptr != nullptr) {
44  // make a clone of derivative block with "_steadystate" suffix
45  ss_block = std::shared_ptr<ast::DerivativeBlock>(deriv_block_ptr->clone());
46  auto ss_name = ss_block->get_name();
47  ss_name->set_name(ss_name->get_value()->get_value() + "_steadystate");
48  auto ss_name_clone = std::shared_ptr<ast::Name>(ss_name->clone());
49  ss_block->set_name(std::move(ss_name));
50  logger->debug("SteadystateVisitor :: -> adding new DERIVATIVE block: {}",
51  ss_block->get_node_name());
52 
53  std::string new_dt;
54  if (steadystate_method == codegen::naming::SPARSE_METHOD) {
55  new_dt = fmt::format("{:.16g}", STEADYSTATE_SPARSE_DT);
56  } else if (steadystate_method == codegen::naming::DERIVIMPLICIT_METHOD) {
57  new_dt += fmt::format("{:.16g}", STEADYSTATE_DERIVIMPLICIT_DT);
58  } else {
59  logger->warn("SteadystateVisitor :: solve method {} not supported for STEADYSTATE",
60  steadystate_method);
61  return nullptr;
62  }
63 
64  auto statement_block = ss_block->get_statement_block();
65  auto statements = statement_block->get_statements();
66 
67  // add statement for changing the timestep
68  auto update_dt_statement = std::make_shared<ast::UpdateDt>(new ast::Double(new_dt));
69  statements.insert(statements.begin(), update_dt_statement);
70 
71  // replace old set of statements in AST with new one
72  statement_block->set_statements(std::move(statements));
73 
74  // update SOLVE statement:
75  // set name to point to new DERIVATIVE block
76  solve_block->set_block_name(std::move(ss_name_clone));
77  // change from STEADYSTATE to METHOD
78  solve_block->set_method(solve_block->get_steadystate());
79  solve_block->set_steadystate(nullptr);
80  } else {
81  logger->warn("SteadystateVisitor :: Could not find derivative block {} for STEADYSTATE",
82  solve_block_name);
83  return nullptr;
84  }
85  return ss_block;
86 }
87 
89  // get DERIVATIVE blocks
90  const auto& deriv_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK});
91 
92  // get list of STEADYSTATE solve statements with names & methods
93  const auto& solve_block_nodes = collect_nodes(node, {ast::AstNodeType::SOLVE_BLOCK});
94 
95  // create new DERIVATIVE blocks for the STEADYSTATE solves
96  for (const auto& solve_block_ptr: solve_block_nodes) {
97  if (auto solve_block = std::dynamic_pointer_cast<ast::SolveBlock>(solve_block_ptr)) {
98  if (solve_block->get_steadystate()) {
99  auto ss_block = create_steadystate_block(solve_block, deriv_blocks);
100  if (ss_block != nullptr) {
101  node.emplace_back_node(ss_block);
102  }
103  }
104  }
105  }
106 }
107 
108 } // namespace visitor
109 } // namespace nmodl
nmodl::ast::AstNodeType::DERIVATIVE_BLOCK
@ DERIVATIVE_BLOCK
type of ast::DerivativeBlock
nmodl::ast::Double
Represents a double variable.
Definition: double.hpp:53
nmodl::visitor::SteadystateVisitor::STEADYSTATE_DERIVIMPLICIT_DT
const double STEADYSTATE_DERIVIMPLICIT_DT
Definition: steadystate_visitor.hpp:62
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::logger
logger_type logger
Definition: logger.cpp:34
nmodl::codegen::naming::SPARSE_METHOD
static constexpr char SPARSE_METHOD[]
sparse method in nmodl
Definition: codegen_naming.hpp:36
nmodl::visitor::SteadystateVisitor::create_steadystate_block
std::shared_ptr< ast::DerivativeBlock > create_steadystate_block(const std::shared_ptr< ast::SolveBlock > &solve_block, const std::vector< std::shared_ptr< ast::Ast >> &deriv_blocks)
create new steady state derivative block for given solve block
Definition: steadystate_visitor.cpp:18
codegen_naming.hpp
nmodl::visitor::SteadystateVisitor::STEADYSTATE_SPARSE_DT
const double STEADYSTATE_SPARSE_DT
Definition: steadystate_visitor.hpp:60
steadystate_visitor.hpp
Visitor for STEADYSTATE solve statements
visitor_utils.hpp
Utility functions for visitors implementation.
nmodl::ast::DerivativeBlock::clone
DerivativeBlock * clone() const override
Return a copy of the current node.
Definition: derivative_block.hpp:88
nmodl::ast::AstNodeType::SOLVE_BLOCK
@ SOLVE_BLOCK
type of ast::SolveBlock
nmodl::ast::DerivativeBlock
Represents DERIVATIVE block in the NMODL.
Definition: derivative_block.hpp:49
nmodl::visitor::SteadystateVisitor::visit_program
void visit_program(ast::Program &node) override
visit node of type ast::Program
Definition: steadystate_visitor.cpp:88
nmodl::collect_nodes
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
Definition: visitor_utils.cpp:205
nmodl::codegen::naming::DERIVIMPLICIT_METHOD
static constexpr char DERIVIMPLICIT_METHOD[]
derivimplicit method in nmodl
Definition: codegen_naming.hpp:24
logger.hpp
Implement logger based on spdlog library.
nmodl::ast::Program
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
nmodl::ast::Program::emplace_back_node
void emplace_back_node(Node *n)
Add member to blocks by raw pointer.
Definition: ast.cpp:12817
all.hpp
Auto generated AST classes declaration.