30 if (!conserve_equations.empty()) {
31 std::unordered_set<ast::Statement*> eqs;
32 for (
const auto& item: conserve_equations) {
33 eqs.insert(std::dynamic_pointer_cast<ast::Statement>(item).get());
44 std::regex unit_pattern(R
"((\d+\.?\d*|\.\d+)\s*\([a-zA-Z]+\))");
46 auto rhs_string_no_units = fmt::format(
"{} = {}",
48 std::regex_replace(rhs_string, unit_pattern,
"$1"));
49 logger->debug(
"CvodeVisitor :: removing units from statement {}",
to_nmodl(node));
50 logger->debug(
"CvodeVisitor :: result: {}", rhs_string_no_units);
51 auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
53 const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
54 expr_statement->get_expression());
55 node.
set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
59 std::shared_ptr<ast::Identifier> node) {
60 auto variable = std::make_pair(node->get_node_name(), std::optional<int>());
61 if (node->is_indexed_name()) {
62 variable.second = std::optional<int>(
63 get_index(*std::dynamic_pointer_cast<const ast::IndexedName>(node)));
70 const std::string& ignored_name) {
71 std::unordered_set<std::string> indexed_variables;
76 for (
const auto& var: indexed_vars) {
77 const auto& varname = var->get_node_name();
79 auto varname_not_reserved =
80 std::none_of(reserved_symbols.begin(),
81 reserved_symbols.end(),
82 [&varname](
const auto item) { return varname == item; });
83 if (indexed_variables.count(varname) == 0 && varname != ignored_name &&
84 varname_not_reserved) {
85 indexed_variables.insert(varname);
88 return indexed_variables;
92 const auto& lhs = node.
get_lhs();
94 auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
97 if (name->is_prime_name()) {
98 varname =
"D" + name->get_node_name();
100 }
else if (name->is_indexed_name()) {
103 if (!nodes.empty()) {
105 auto statement = fmt::format(
"{} = {}", varname, varname);
106 auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
108 const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
109 expr_statement->get_expression());
110 node.
set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
136 const auto& lhs = node.
get_lhs();
142 auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
146 auto symbol = std::make_shared<symtab::Symbol>(varname,
ModToken());
147 symbol->set_original_name(name->get_node_name());
160 const auto& lhs = node.
get_lhs();
166 auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();
170 auto symbol = std::make_shared<symtab::Symbol>(varname,
ModToken());
171 symbol->set_original_name(name->get_node_name());
182 auto [jacobian, exception_message] =
184 if (!exception_message.empty()) {
185 logger->warn(
"CvodeVisitor :: python exception: {}", exception_message);
189 auto statement = fmt::format(
"{} = {} / (1 - dt * ({}))", varname, varname, jacobian);
190 logger->debug(
"CvodeVisitor :: replacing statement {} with {}",
to_nmodl(node), statement);
191 auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
193 const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
194 expr_statement->get_expression());
195 node.
set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
201 if (derivative_blocks.empty()) {
206 auto not_steadystate = [](
const auto& item) {
207 auto name = std::dynamic_pointer_cast<const ast::DerivativeBlock>(item)->get_node_name();
210 decltype(derivative_blocks) derivative_blocks_copy;
211 std::copy_if(derivative_blocks.begin(),
212 derivative_blocks.end(),
213 std::back_inserter(derivative_blocks_copy),
215 if (derivative_blocks_copy.size() > 1) {
216 auto message =
"CvodeVisitor :: cannot have multiple DERIVATIVE blocks";
218 throw std::runtime_error(message);
221 return std::dynamic_pointer_cast<ast::DerivativeBlock>(derivative_blocks_copy[0]);
227 if (derivative_block ==
nullptr) {
231 auto non_stiff_block = derivative_block->get_statement_block()->clone();
234 auto stiff_block = derivative_block->get_statement_block()->clone();
241 derivative_block->get_name(),
242 std::shared_ptr<ast::Integer>(
new ast::Integer(prime_vars.size(),
nullptr)),
243 std::shared_ptr<ast::StatementBlock>(non_stiff_block),
244 std::shared_ptr<ast::StatementBlock>(stiff_block)));