36 using visitor::DefUseAnalyzeVisitor;
38 using visitor::RenameVisitor;
39 using visitor::SymtabVisitor;
40 using visitor::VarUsageVisitor;
52 return "C++ (api-compatibility)";
68 for (
const auto& var: codegen_float_variables) {
69 if (var->get_name() == name) {
72 index += var->get_length();
74 throw std::logic_error(name +
" variable not found");
80 for (
const auto& var: codegen_int_variables) {
81 if (var.symbol->get_name() == name) {
84 index += var.symbol->get_length();
86 throw std::logic_error(name +
" variable not found");
95 const std::string& name = token;
101 if (program_symtab->is_method_defined(token)) {
102 return method_name(token);
110 auto new_name = replace_if_verbatim_variable(name);
111 if (new_name != name) {
112 return get_variable_name(new_name,
false);
120 auto use_instance = !printing_top_verbatim_blocks;
121 return get_variable_name(token, use_instance);
131 auto symbol = program_symtab->lookup_in_scope(name);
132 bool is_constant =
false;
133 if (symbol !=
nullptr) {
135 if (info.is_ion_variable(name)) {
140 else if (symbol->has_any_property(NmodlType::param_assign) &&
141 info.variables_in_verbatim.find(name) == info.variables_in_verbatim.end() &&
142 symbol->get_write_count() == 0) {
245 std::vector<std::shared_ptr<const ast::Ast>> nodes;
253 printer->add_line(
"#pragma omp simd");
254 printer->add_line(
"#pragma ivdep");
260 return info.point_process;
265 if (info.point_process) {
273 if (info.point_process) {
274 printer->add_line(
"shadow_rhs[id] = rhs;");
275 printer->add_line(
"shadow_d[id] = g;");
277 auto rhs_op = operator_for_rhs();
278 auto d_op = operator_for_d();
279 printer->fmt_line(
"vec_rhs[node_id] {} rhs;", rhs_op);
280 printer->fmt_line(
"vec_d[node_id] {} g;", d_op);
286 auto rhs_op = operator_for_rhs();
287 auto d_op = operator_for_d();
288 if (info.point_process) {
289 printer->add_line(
"int node_id = node_index[id];");
290 printer->fmt_line(
"vec_rhs[node_id] {} shadow_rhs[id];", rhs_op);
291 printer->fmt_line(
"vec_d[node_id] {} shadow_d[id];", d_op);
302 printer->add_line(
"#pragma omp atomic update");
332 return optimize_ionvar_copies;
337 printer->add_newline(2);
338 auto args =
"size_t num, size_t size, size_t alignment = 16";
339 printer->fmt_push_block(
"static inline void* mem_alloc({})", args);
340 printer->add_line(
"void* ptr;");
341 printer->add_line(
"posix_memalign(&ptr, alignment, num*size);");
342 printer->add_line(
"memset(ptr, 0, size);");
343 printer->add_line(
"return ptr;");
344 printer->pop_block();
346 printer->add_newline(2);
347 printer->push_block(
"static inline void mem_free(void* ptr)");
348 printer->add_line(
"free(ptr);");
349 printer->pop_block();
354 printer->add_newline(2);
355 printer->push_block(
"static inline void coreneuron_abort()");
356 printer->add_line(
"abort();");
357 printer->pop_block();
367 if (info.top_verbatim_blocks.empty()) {
370 print_namespace_stop();
372 printer->add_newline(2);
373 printer->add_line(
"using namespace coreneuron;");
375 printing_top_verbatim_blocks =
true;
377 for (
const auto& block: info.top_blocks) {
378 if (block->is_verbatim()) {
379 printer->add_newline(2);
380 block->accept(*
this);
384 printing_top_verbatim_blocks =
false;
386 print_namespace_start();
391 if (info.functions.empty() && info.procedures.empty()) {
395 printer->add_newline(2);
396 for (
const auto& node: info.functions) {
397 print_function_declaration(*node, node->get_node_name());
398 printer->add_text(
';');
399 printer->add_newline();
401 for (
const auto& node: info.procedures) {
402 print_function_declaration(*node, node->get_node_name());
403 printer->add_text(
';');
404 printer->add_newline();
412 const auto& table_statements =
collect_nodes(node, {AstNodeType::TABLE_STATEMENT});
414 if (table_statements.size() != 1) {
415 auto message = fmt::format(
"One table statement expected in {} found {}",
417 table_statements.size());
418 throw std::runtime_error(message);
420 return dynamic_cast<const TableStatement*
>(table_statements.front().get());
425 auto symbol = program_symtab->lookup_in_scope(name);
427 throw std::runtime_error(
428 fmt::format(
"CodegenCoreneuronCppVisitor:: {} not found in symbol table!", name));
430 if (symbol->is_array()) {
431 return {
true, symbol->get_length()};
440 auto table_variables = statement->get_table_vars();
441 auto depend_variables = statement->get_depend_vars();
442 const auto& from = statement->get_from();
443 const auto& to = statement->get_to();
445 auto internal_params = internal_method_parameters();
446 auto with = statement->get_with()->eval();
448 auto tmin_name = get_variable_name(
"tmin_" + name);
449 auto mfac_name = get_variable_name(
"mfac_" + name);
450 auto float_type = default_float_data_type();
452 printer->add_newline(2);
453 print_device_method_annotation();
454 printer->fmt_push_block(
"void check_{}({})",
456 get_parameter_str(internal_params));
458 printer->fmt_push_block(
"if ({} == 0)", use_table_var);
459 printer->add_line(
"return;");
460 printer->pop_block();
462 printer->add_line(
"static bool make_table = true;");
463 for (
const auto& variable: depend_variables) {
464 printer->fmt_line(
"static {} save_{};", float_type, variable->get_node_name());
467 for (
const auto& variable: depend_variables) {
468 const auto& var_name = variable->get_node_name();
469 const auto& instance_name = get_variable_name(var_name);
470 printer->fmt_push_block(
"if (save_{} != {})", var_name, instance_name);
471 printer->add_line(
"make_table = true;");
472 printer->pop_block();
475 printer->push_block(
"if (make_table)");
477 printer->add_line(
"make_table = false;");
479 printer->add_indent();
480 printer->add_text(tmin_name,
" = ");
482 printer->add_text(
';');
483 printer->add_newline();
485 printer->add_indent();
486 printer->add_text(
"double tmax = ");
488 printer->add_text(
';');
489 printer->add_newline();
492 printer->fmt_line(
"double dx = (tmax-{}) / {}.;", tmin_name, with);
493 printer->fmt_line(
"{} = 1./dx;", mfac_name);
495 printer->fmt_line(
"double x = {};", tmin_name);
496 printer->fmt_push_block(
"for (std::size_t i = 0; i < {}; x += dx, i++)", with + 1);
497 auto function = method_name(
"f_" + name);
499 printer->fmt_line(
"{}({}, x);",
function, internal_method_arguments());
500 for (
const auto& variable: table_variables) {
501 auto var_name = variable->get_node_name();
502 auto instance_name = get_variable_name(var_name);
503 auto table_name = get_variable_name(
"t_" + var_name);
504 auto [is_array, array_length] = check_if_var_is_array(var_name);
506 for (
int j = 0; j < array_length; j++) {
508 "{}[{}][i] = {}[{}];", table_name, j, instance_name, j);
511 printer->fmt_line(
"{}[i] = {};", table_name, instance_name);
515 auto table_name = get_variable_name(
"t_" + name);
516 printer->fmt_line(
"{}[i] = {}({}, x);",
519 internal_method_arguments());
521 printer->pop_block();
523 for (
const auto& variable: depend_variables) {
524 auto var_name = variable->get_node_name();
525 auto instance_name = get_variable_name(var_name);
526 printer->fmt_line(
"save_{} = {};", var_name, instance_name);
529 printer->pop_block();
531 printer->pop_block();
538 auto table_variables = statement->get_table_vars();
539 auto with = statement->get_with()->eval();
541 auto tmin_name = get_variable_name(
"tmin_" + name);
542 auto mfac_name = get_variable_name(
"mfac_" + name);
543 auto function_name = method_name(
"f_" + name);
545 printer->add_newline(2);
546 print_function_declaration(node, name);
547 printer->push_block();
550 printer->fmt_push_block(
"if ({} == 0)", use_table_var);
552 printer->fmt_line(
"{}({}, {});",
554 internal_method_arguments(),
555 params[0].get()->get_node_name());
556 printer->add_line(
"return 0;");
558 printer->fmt_line(
"return {}({}, {});",
560 internal_method_arguments(),
561 params[0].get()->get_node_name());
563 printer->pop_block();
565 printer->fmt_line(
"double xi = {} * ({} - {});",
567 params[0].get()->get_node_name(),
569 printer->push_block(
"if (isnan(xi))");
571 for (
const auto& var: table_variables) {
572 auto var_name = get_variable_name(var->get_node_name());
573 auto [is_array, array_length] = check_if_var_is_array(var->get_node_name());
575 for (
int j = 0; j < array_length; j++) {
576 printer->fmt_line(
"{}[{}] = xi;", var_name, j);
579 printer->fmt_line(
"{} = xi;", var_name);
582 printer->add_line(
"return 0;");
584 printer->add_line(
"return xi;");
586 printer->pop_block();
588 printer->fmt_push_block(
"if (xi <= 0. || xi >= {}.)", with);
589 printer->fmt_line(
"int index = (xi <= 0.) ? 0 : {};", with);
591 for (
const auto& variable: table_variables) {
592 auto var_name = variable->get_node_name();
593 auto instance_name = get_variable_name(var_name);
594 auto table_name = get_variable_name(
"t_" + var_name);
595 auto [is_array, array_length] = check_if_var_is_array(var_name);
597 for (
int j = 0; j < array_length; j++) {
599 "{}[{}] = {}[{}][index];", instance_name, j, table_name, j);
602 printer->fmt_line(
"{} = {}[index];", instance_name, table_name);
605 printer->add_line(
"return 0;");
607 auto table_name = get_variable_name(
"t_" + name);
608 printer->fmt_line(
"return {}[index];", table_name);
610 printer->pop_block();
612 printer->add_line(
"int i = int(xi);");
613 printer->add_line(
"double theta = xi - double(i);");
615 for (
const auto& var: table_variables) {
616 auto var_name = var->get_node_name();
617 auto instance_name = get_variable_name(var_name);
618 auto table_name = get_variable_name(
"t_" + var_name);
619 auto [is_array, array_length] = check_if_var_is_array(var->get_node_name());
621 for (
size_t j = 0; j < array_length; j++) {
623 "{0}[{1}] = {2}[{1}][i] + theta*({2}[{1}][i+1]-{2}[{1}][i]);",
629 printer->fmt_line(
"{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);",
634 printer->add_line(
"return 0;");
636 auto table_name = get_variable_name(
"t_" + name);
637 printer->fmt_line(
"return {0}[i] + theta * ({0}[i+1] - {0}[i]);", table_name);
640 printer->pop_block();
645 if (info.table_count == 0) {
649 printer->add_newline(2);
650 auto name = method_name(
"check_table_thread");
651 auto parameters = external_method_parameters(
true);
653 printer->fmt_push_block(
"static void {} ({})", name, parameters);
654 printer->add_line(
"setup_instance(nt, ml);");
655 printer->fmt_line(
"auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct());
656 printer->add_line(
"double v = 0;");
658 for (
const auto&
function: info.functions_with_table) {
659 auto method_name_str = method_name(
"check_" + function->get_node_name());
660 auto arguments = internal_method_arguments();
661 printer->fmt_line(
"{}({});", method_name_str, arguments);
664 printer->pop_block();
669 const std::string& name) {
670 printer->add_newline(2);
671 print_function_declaration(node, name);
672 printer->add_text(
" ");
673 printer->push_block();
677 auto type = default_float_data_type();
678 printer->fmt_line(
"{} ret_{} = 0.0;", type, name);
680 printer->fmt_line(
"int ret_{} = 0;", name);
684 printer->fmt_line(
"return ret_{};", name);
685 printer->pop_block();
692 if (info.function_uses_table(name)) {
693 auto new_name =
"f_" + name;
694 print_function_or_procedure(node, new_name);
695 print_table_check_function(node);
696 print_table_replacement_function(node);
698 print_function_or_procedure(node, name);
704 print_function_procedure_helper(node);
712 std::string return_var;
713 if (info.function_uses_table(name)) {
714 return_var =
"ret_f_" + name;
716 return_var =
"ret_" + name;
724 print_function_procedure_helper(node);
731 auto params = internal_method_parameters();
732 for (
const auto& i: p) {
733 params.emplace_back(
"",
"double",
"", i->get_node_name());
735 printer->fmt_line(
"double {}({})", method_name(name), get_parameter_str(params));
736 printer->push_block();
737 printer->fmt_line(
"double _arg[{}];", p.size());
738 for (
size_t i = 0; i < p.size(); ++i) {
739 printer->fmt_line(
"_arg[{}] = {};", i, p[i]->get_node_name());
741 printer->fmt_line(
"return hoc_func_table({}, {}, _arg);",
742 get_variable_name(std::string(
"_ptable_" + name),
true),
744 printer->pop_block();
746 printer->fmt_push_block(
"double table_{}()", method_name(name));
747 printer->fmt_line(
"hoc_spec_table(&{}, {});",
748 get_variable_name(std::string(
"_ptable_" + name)),
750 printer->add_line(
"return 0.;");
751 printer->pop_block();
779 auto model_symbol_table = std::make_shared<symtab::ModelSymbolTable>();
787 auto is_functor_const =
true;
789 for (
const auto& variable: variables) {
790 const auto& chain = v.
analyze(complete_block, variable->get_node_name());
791 is_functor_const = !(chain.eval() == DUState::D || chain.eval() == DUState::LD ||
792 chain.eval() == DUState::CD);
793 if (!is_functor_const) {
798 return is_functor_const;
806 auto float_type = default_float_data_type();
809 const auto functor_name = info.functor_names[&node];
810 printer->fmt_push_block(
"struct {0}", functor_name);
811 printer->add_line(
"NrnThread* nt;");
812 printer->add_line(instance_struct(),
"* inst;");
813 printer->add_line(
"int id, pnodecount;");
814 printer->add_line(
"double v;");
815 printer->add_line(
"const Datum* indexes;");
816 printer->add_line(
"double* data;");
817 printer->add_line(
"ThreadDatum* thread;");
819 if (ion_variable_struct_required()) {
820 print_ion_variable();
824 printer->add_newline();
826 printer->push_block(
"void initialize()");
828 printer->pop_block();
829 printer->add_newline();
832 "{0}(NrnThread* nt, {1}* inst, int id, int pnodecount, double v, const Datum* indexes, "
833 "double* data, ThreadDatum* thread) : "
834 "nt{{nt}}, inst{{inst}}, id{{id}}, pnodecount{{pnodecount}}, v{{v}}, indexes{{indexes}}, "
835 "data{{data}}, thread{{thread}} "
840 printer->add_indent();
846 "void operator()(const Eigen::Matrix<{0}, {1}, 1>& nmodl_eigen_xm, Eigen::Matrix<{0}, {1}, "
847 "1>& nmodl_eigen_fm, "
848 "Eigen::Matrix<{0}, {1}, {1}>& nmodl_eigen_jm) {2}",
851 is_functor_const(variable_block, functor_block) ?
"const " :
"");
852 printer->push_block();
853 printer->fmt_line(
"const {}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
854 printer->fmt_line(
"{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type);
855 printer->fmt_line(
"{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type);
856 print_statement_block(functor_block,
false,
false);
857 printer->pop_block();
858 printer->add_newline();
861 printer->push_block(
"void finalize()");
863 printer->pop_block();
865 printer->pop_block(
";");
872 printer->add_multi_line(R
"CODE(
874 nmodl_eigen_jm.computeInverseWithCheck(nmodl_eigen_jm_inv,invertible);
875 nmodl_eigen_xm = nmodl_eigen_jm_inv*nmodl_eigen_fm;
876 if (!invertible) assert(false && "Singular or ill-conditioned matrix (Eigen::inverse)!");
883 printer->add_line(
"if (!nmodl_eigen_jm.IsRowMajor) nmodl_eigen_jm.transposeInPlace();");
886 printer->fmt_line(
"Eigen::Matrix<int, {}, 1> pivot;", N);
887 printer->fmt_line(
"Eigen::Matrix<{0}, {1}, 1> rowmax;", float_type, N);
891 "if (nmodl::crout::Crout<{0}>({1}, nmodl_eigen_jm.data(), pivot.data(), rowmax.data()) "
892 "< 0) assert(false && \"Singular or ill-conditioned matrix (nmodl::crout)!\");",
898 "nmodl::crout::solveCrout<{0}>({1}, nmodl_eigen_jm.data(), nmodl_eigen_fm.data(), "
899 "nmodl_eigen_xm.data(), pivot.data());",
912 if (ion_variable_struct_required()) {
913 return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
915 return "id, pnodecount, inst, data, indexes, thread, nt, v";
924 params.emplace_back(
"",
"int",
"",
"id");
925 params.emplace_back(
"",
"int",
"",
"pnodecount");
926 params.emplace_back(
"", fmt::format(
"{}*", instance_struct()),
"",
"inst");
927 if (ion_variable_struct_required()) {
928 params.emplace_back(
"",
"IonCurVar&",
"",
"ionvar");
930 params.emplace_back(
"",
"double*",
"",
"data");
931 params.emplace_back(
"const ",
"Datum*",
"",
"indexes");
932 params.emplace_back(
"",
"ThreadDatum*",
"",
"thread");
933 params.emplace_back(
"",
"NrnThread*",
"",
"nt");
934 params.emplace_back(
"",
"double",
"",
"v");
940 return "id, pnodecount, data, indexes, thread, nt, ml, v";
946 return "int id, int pnodecount, double* data, Datum* indexes, "
947 "ThreadDatum* thread, NrnThread* nt, Memb_list* ml, int tml_id";
949 return "int id, int pnodecount, double* data, Datum* indexes, "
950 "ThreadDatum* thread, NrnThread* nt, Memb_list* ml, double v";
955 if (ion_variable_struct_required()) {
956 return "id, pnodecount, ionvar, data, indexes, thread, nt, ml, v";
958 return "id, pnodecount, data, indexes, thread, nt, ml, v";
967 if (ion_variable_struct_required()) {
968 return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
970 return "id, pnodecount, inst, data, indexes, thread, nt, v";
988 if (internal_method_call_encountered) {
989 name = nrn_thread_internal_arguments();
990 internal_method_call_encountered =
false;
992 name = nrn_thread_arguments();
996 name = external_method_parameters();
1009 driver.scan_string(text);
1010 auto tokens =
driver.all_tokens();
1012 for (
size_t i = 0; i < tokens.size(); i++) {
1013 auto token = tokens[i];
1017 if (program_symtab->is_method_defined(token) && tokens[i + 1] ==
"(") {
1018 internal_method_call_encountered =
true;
1020 auto name = process_verbatim_token(token);
1023 name.insert(0, 1,
'&');
1025 if (token ==
"_STRIDE") {
1026 name =
"pnodecount+id";
1035 auto nrn_channel_info_var_name = get_channel_info_var_name();
1042 return fmt::format(
"{}, {}, {}, nullptr, {}, {}, {}, {}, first_pointer_var_index()",
1043 nrn_channel_info_var_name,
1048 nrn_private_constructor,
1049 nrn_private_destructor);
1054 const std::string& concentration,
1057 auto style_var_name = get_variable_name(
"style_" + ion_name);
1059 "nrn_wrote_conc({}_type,"
1063 " nrn_ion_global_map,"
1065 " nt->_ml_list[{}_type]->_nodecount_padded)",
1081 printer->add_newline(2);
1082 print_device_method_annotation();
1083 printer->push_block(
"static inline int first_pointer_var_index()");
1084 printer->fmt_line(
"return {};", info.first_pointer_var_index);
1085 printer->pop_block();
1090 printer->add_newline(2);
1091 print_device_method_annotation();
1092 printer->push_block(
"static inline int first_random_var_index()");
1093 printer->fmt_line(
"return {};", info.first_random_var_index);
1094 printer->pop_block();
1099 printer->add_newline(2);
1100 print_device_method_annotation();
1101 printer->push_block(
"static inline int float_variables_size()");
1102 printer->fmt_line(
"return {};", float_variables_size());
1103 printer->pop_block();
1105 printer->add_newline(2);
1106 print_device_method_annotation();
1107 printer->push_block(
"static inline int int_variables_size()");
1108 printer->fmt_line(
"return {};", int_variables_size());
1109 printer->pop_block();
1114 if (!net_receive_exist()) {
1117 printer->add_newline(2);
1118 print_device_method_annotation();
1119 printer->push_block(
"static inline int num_net_receive_args()");
1120 printer->fmt_line(
"return {};", info.num_net_receive_parameters);
1121 printer->pop_block();
1126 printer->add_newline(2);
1127 print_device_method_annotation();
1128 printer->push_block(
"static inline int get_mech_type()");
1130 printer->fmt_line(
"return {};", get_variable_name(
"mech_type",
false));
1131 printer->pop_block();
1136 printer->add_newline(2);
1137 print_device_method_annotation();
1138 printer->push_block(
"static inline Memb_list* get_memb_list(NrnThread* nt)");
1139 printer->push_block(
"if (!nt->_ml_list)");
1140 printer->add_line(
"return nullptr;");
1141 printer->pop_block();
1142 printer->add_line(
"return nt->_ml_list[get_mech_type()];");
1143 printer->pop_block();
1148 printer->add_newline(2);
1149 printer->push_block(
"namespace coreneuron");
1154 printer->pop_block();
1170 if (info.vectorize && info.derivimplicit_used()) {
1171 int tid = info.derivimplicit_var_thread_id;
1172 int list = info.derivimplicit_list_num;
1175 printer->add_newline(2);
1176 printer->add_line(
"/** thread specific helper routines for derivimplicit */");
1178 printer->add_newline(1);
1179 printer->fmt_push_block(
"static inline int* deriv{}_advance(ThreadDatum* thread)", list);
1180 printer->fmt_line(
"return &(thread[{}].i);", tid);
1181 printer->pop_block();
1182 printer->add_newline();
1184 printer->fmt_push_block(
"static inline int dith{}()", list);
1185 printer->fmt_line(
"return {};", tid+1);
1186 printer->pop_block();
1187 printer->add_newline();
1189 printer->fmt_push_block(
"static inline void** newtonspace{}(ThreadDatum* thread)", list);
1190 printer->fmt_line(
"return &(thread[{}]._pvoid);", tid+2);
1191 printer->pop_block();
1194 if (info.vectorize && !info.thread_variables.empty()) {
1195 printer->add_newline(2);
1196 printer->add_line(
"/** tid for thread variables */");
1197 printer->push_block(
"static inline int thread_var_tid()");
1198 printer->fmt_line(
"return {};", info.thread_var_thread_id);
1199 printer->pop_block();
1202 if (info.vectorize && !info.top_local_variables.empty()) {
1203 printer->add_newline(2);
1204 printer->add_line(
"/** tid for top local tread variables */");
1205 printer->push_block(
"static inline int top_local_var_tid()");
1206 printer->fmt_line(
"return {};", info.top_local_thread_id);
1207 printer->pop_block();
1219 bool use_instance)
const {
1220 auto name = symbol->get_name();
1221 auto dimension = symbol->get_length();
1222 auto position = position_of_float_var(name);
1224 if (symbol->is_array()) {
1226 return fmt::format(
"(inst->{}+id*{})", name, dimension);
1228 return fmt::format(
"(data + {}*pnodecount + id*{})", position, dimension);
1231 return fmt::format(
"inst->{}[id]", name);
1233 return fmt::format(
"data[{}*pnodecount + id]", position);
1239 const std::string& name,
1240 bool use_instance)
const {
1241 auto position = position_of_int_var(name);
1245 return fmt::format(
"inst->{}[{}]", name, position);
1247 return fmt::format(
"indexes[{}]", position);
1251 return fmt::format(
"inst->{}[{}*pnodecount+id]", name, position);
1253 return fmt::format(
"indexes[{}*pnodecount+id]", position);
1256 return fmt::format(
"inst->{}[indexes[{}*pnodecount + id]]", name, position);
1258 auto data = symbol.
is_vdata ?
"_vdata" :
"_data";
1259 return fmt::format(
"nt->{}[indexes[{}*pnodecount + id]]", data, position);
1265 bool use_instance)
const {
1269 return fmt::format(
"{}.{}", global_struct_instance(), symbol->get_name());
1275 bool use_instance)
const {
1276 const std::string& varname = update_if_ion_variable_name(name);
1279 auto symbol_comparator = [&varname](
const SymbolType& sym) {
1280 return varname == sym->get_name();
1284 return varname == var.symbol->get_name();
1289 auto f = std::find_if(codegen_float_variables.begin(),
1290 codegen_float_variables.end(),
1292 if (f != codegen_float_variables.end()) {
1293 return float_variable_name(*f, use_instance);
1298 std::find_if(codegen_int_variables.begin(), codegen_int_variables.end(), index_comparator);
1299 if (i != codegen_int_variables.end()) {
1300 return int_variable_name(*i, varname, use_instance);
1304 auto g = std::find_if(codegen_global_variables.begin(),
1305 codegen_global_variables.end(),
1307 if (g != codegen_global_variables.end()) {
1308 return global_variable_name(*g, use_instance);
1322 std::find_if(info.neuron_global_variables.begin(),
1323 info.neuron_global_variables.end(),
1324 [&varname](
auto const& entry) { return entry.first->get_name() == varname; });
1325 if (iter != info.neuron_global_variables.end()) {
1330 ret.append(varname);
1348 time_t current_time{};
1349 time(¤t_time);
1350 std::string data_time_str{std::ctime(¤t_time)};
1353 printer->add_line(
"/*********************************************************");
1354 printer->add_line(
"Model Name : ", info.mod_suffix);
1355 printer->add_line(
"Filename : ", info.mod_file,
".mod");
1356 printer->add_line(
"NMODL Version : ", nmodl_version());
1357 printer->fmt_line(
"Vectorized : {}", info.vectorize);
1358 printer->fmt_line(
"Threadsafe : {}", info.thread_safe);
1360 printer->add_line(
"Simulator : ", simulator_name());
1361 printer->add_line(
"Backend : ", backend_name());
1362 printer->add_line(
"NMODL Compiler : ", version);
1363 printer->add_line(
"*********************************************************/");
1368 printer->add_newline();
1369 printer->add_multi_line(R
"CODE(
1379 printer->add_newline();
1380 printer->add_multi_line(R
"CODE(
1381 #include <coreneuron/gpu/nrn_acc_manager.hpp>
1382 #include <coreneuron/mechanism/mech/mod2c_core_thread.hpp>
1383 #include <coreneuron/mechanism/register_mech.hpp>
1384 #include <coreneuron/nrnconf.h>
1385 #include <coreneuron/nrniv/nrniv_decl.h>
1386 #include <coreneuron/sim/multicore.hpp>
1387 #include <coreneuron/sim/scopmath/newton_thread.hpp>
1388 #include <coreneuron/utils/ivocvect.hpp>
1389 #include <coreneuron/utils/nrnoc_aux.hpp>
1390 #include <coreneuron/utils/randoms/nrnran123.h>
1392 if (info.eigen_newton_solver_exist) {
1393 printer->add_line(
"#include <newton/newton.hpp>");
1395 if (info.eigen_linear_solver_exist) {
1396 if (std::accumulate(info.state_vars.begin(),
1397 info.state_vars.end(),
1400 return l += variable->get_length();
1402 printer->add_line(
"#include <crout/crout.hpp>");
1404 printer->add_line(
"#include <Eigen/Dense>");
1405 printer->add_line(
"#include <Eigen/LU>");
1412 if (info.primes_size == 0) {
1415 const auto count_prime_variables = [](
auto size,
const SymbolType& symbol) {
1416 return size += symbol->get_length();
1418 const auto prime_variables_by_order_size =
1419 std::accumulate(info.prime_variables_by_order.begin(),
1420 info.prime_variables_by_order.end(),
1422 count_prime_variables);
1423 if (info.primes_size != prime_variables_by_order_size) {
1424 throw std::runtime_error{
1425 fmt::format(
"primes_size = {} differs from prime_variables_by_order.size() = {}, "
1426 "this should not happen.",
1428 info.prime_variables_by_order.size())};
1430 auto const initializer_list = [&](
auto const& primes,
const char* prefix) -> std::string {
1431 if (!print_initializers) {
1434 std::string list{
"{"};
1435 for (
auto iter = primes.begin(); iter != primes.end(); ++iter) {
1436 auto const& prime = *iter;
1437 list.append(
std::to_string(position_of_float_var(prefix + prime->get_name())));
1438 if (std::next(iter) != primes.end()) {
1445 printer->fmt_line(
"int slist1[{}]{};",
1447 initializer_list(info.prime_variables_by_order,
""));
1448 printer->fmt_line(
"int dlist1[{}]{};",
1450 initializer_list(info.prime_variables_by_order,
"D"));
1451 codegen_global_variables.push_back(make_symbol(
"slist1"));
1452 codegen_global_variables.push_back(make_symbol(
"dlist1"));
1454 if (info.derivimplicit_used()) {
1455 auto primes = program_symtab->get_variables_with_properties(NmodlType::prime_name);
1456 printer->fmt_line(
"int slist2[{}]{};", info.primes_size, initializer_list(primes,
""));
1457 codegen_global_variables.push_back(make_symbol(
"slist2"));
1479 const auto value_initialize = print_initializers ?
"{}" :
"";
1481 auto float_type = default_float_data_type();
1482 printer->add_newline(2);
1483 printer->add_line(
"/** all global variables */");
1484 printer->fmt_push_block(
"struct {}", global_struct());
1486 for (
const auto& ion: info.ions) {
1487 auto name = fmt::format(
"{}_type", ion.name);
1488 printer->fmt_line(
"int {}{};", name, value_initialize);
1489 codegen_global_variables.push_back(make_symbol(name));
1492 if (info.point_process) {
1493 printer->fmt_line(
"int point_type{};", value_initialize);
1494 codegen_global_variables.push_back(make_symbol(
"point_type"));
1497 for (
const auto& var: info.state_vars) {
1498 auto name = var->get_name() +
"0";
1499 auto symbol = program_symtab->lookup(name);
1500 if (symbol ==
nullptr) {
1501 printer->fmt_line(
"{} {}{};", float_type, name, value_initialize);
1502 codegen_global_variables.push_back(make_symbol(name));
1510 auto& top_locals = info.top_local_variables;
1511 if (!info.vectorize && !top_locals.empty()) {
1512 for (
const auto& var: top_locals) {
1513 auto name = var->get_name();
1514 auto length = var->get_length();
1515 if (var->is_array()) {
1516 printer->fmt_line(
"{} {}[{}] /* TODO init top-local-array */;",
1521 printer->fmt_line(
"{} {} /* TODO init top-local */;", float_type, name);
1523 codegen_global_variables.push_back(var);
1527 if (!info.thread_variables.empty()) {
1528 printer->fmt_line(
"int thread_data_in_use{};", value_initialize);
1529 printer->fmt_line(
"{} thread_data[{}] /* TODO init thread_data */;",
1531 info.thread_var_data_size);
1532 codegen_global_variables.push_back(make_symbol(
"thread_data_in_use"));
1533 auto symbol = make_symbol(
"thread_data");
1534 symbol->set_as_array(info.thread_var_data_size);
1535 codegen_global_variables.push_back(symbol);
1539 printer->fmt_line(
"int reset{};", value_initialize);
1540 codegen_global_variables.push_back(make_symbol(
"reset"));
1542 printer->fmt_line(
"int mech_type{};", value_initialize);
1543 codegen_global_variables.push_back(make_symbol(
"mech_type"));
1545 for (
const auto& var: info.global_variables) {
1546 auto name = var->get_name();
1547 auto length = var->get_length();
1548 if (var->is_array()) {
1549 printer->fmt_line(
"{} {}[{}] /* TODO init const-array */;", float_type, name, length);
1552 if (
auto const& value_ptr = var->get_value()) {
1555 printer->fmt_line(
"{} {}{};",
1558 print_initializers ? fmt::format(
"{{{:g}}}", value) : std::string{});
1560 codegen_global_variables.push_back(var);
1563 for (
const auto& var: info.constant_variables) {
1564 auto const name = var->get_name();
1565 auto*
const value_ptr = var->get_value().get();
1566 double const value{value_ptr ? *value_ptr : 0};
1567 printer->fmt_line(
"{} {}{};",
1570 print_initializers ? fmt::format(
"{{{:g}}}", value) : std::string{});
1571 codegen_global_variables.push_back(var);
1574 print_sdlists_init(print_initializers);
1576 if (info.table_count > 0) {
1577 printer->fmt_line(
"double usetable{};", print_initializers ?
"{1}" :
"");
1580 for (
const auto& block: info.functions_with_table) {
1581 const auto& name = block->get_node_name();
1582 printer->fmt_line(
"{} tmin_{}{};", float_type, name, value_initialize);
1583 printer->fmt_line(
"{} mfac_{}{};", float_type, name, value_initialize);
1584 codegen_global_variables.push_back(make_symbol(
"tmin_" + name));
1585 codegen_global_variables.push_back(make_symbol(
"mfac_" + name));
1588 for (
const auto& variable: info.table_statement_variables) {
1589 auto const name =
"t_" + variable->get_name();
1590 auto const num_values = variable->get_num_values();
1591 if (variable->is_array()) {
1592 int array_len = variable->get_length();
1594 "{} {}[{}][{}]{};", float_type, name, array_len, num_values, value_initialize);
1596 printer->fmt_line(
"{} {}[{}]{};", float_type, name, num_values, value_initialize);
1598 codegen_global_variables.push_back(make_symbol(name));
1602 for (
const auto& f: info.function_tables) {
1603 printer->fmt_line(
"void* _ptable_{}{{}};", f->get_node_name());
1604 codegen_global_variables.push_back(make_symbol(
"_ptable_" + f->get_node_name()));
1607 if (info.vectorize && info.thread_data_index) {
1608 printer->fmt_line(
"ThreadDatum ext_call_thread[{}]{};",
1609 info.thread_data_index,
1611 codegen_global_variables.push_back(make_symbol(
"ext_call_thread"));
1614 printer->pop_block(
";");
1616 print_global_var_struct_assertions();
1617 print_global_var_struct_decl();
1626 auto variable_printer =
1627 [&](
const std::vector<SymbolType>& variables,
bool if_array,
bool if_vector) {
1628 for (
const auto& variable: variables) {
1629 if (variable->is_array() == if_array) {
1632 auto name = get_variable_name(variable->get_name(),
false);
1633 auto ename = add_escape_quote(variable->get_name() +
"_" + info.mod_suffix);
1634 auto length = variable->get_length();
1636 printer->fmt_line(
"{{{}, {}, {}}},", ename, name, length);
1638 printer->fmt_line(
"{{{}, &{}}},", ename, name);
1644 auto globals = info.global_variables;
1645 auto thread_vars = info.thread_variables;
1647 if (info.table_count > 0) {
1651 printer->add_newline(2);
1652 printer->add_line(
"/** connect global (scalar) variables to hoc -- */");
1653 printer->add_line(
"static DoubScal hoc_scalar_double[] = {");
1654 printer->increase_indent();
1655 variable_printer(globals,
false,
false);
1656 variable_printer(thread_vars,
false,
false);
1657 printer->add_line(
"{nullptr, nullptr}");
1658 printer->decrease_indent();
1659 printer->add_line(
"};");
1661 printer->add_newline(2);
1662 printer->add_line(
"/** connect global (array) variables to hoc -- */");
1663 printer->add_line(
"static DoubVec hoc_vector_double[] = {");
1664 printer->increase_indent();
1665 variable_printer(globals,
true,
true);
1666 variable_printer(thread_vars,
true,
true);
1667 printer->add_line(
"{nullptr, nullptr, 0}");
1668 printer->decrease_indent();
1669 printer->add_line(
"};");
1683 std::string register_type{};
1688 register_type =
"BAType::Before";
1690 dynamic_cast<const ast::BeforeBlock*
>(block)->get_bablock()->get_type()->get_value();
1693 register_type =
"BAType::After";
1695 dynamic_cast<const ast::AfterBlock*
>(block)->get_bablock()->get_type()->get_value();
1701 register_type +=
" + BAType::Breakpoint";
1703 register_type +=
" + BAType::Solve";
1705 register_type +=
" + BAType::Initial";
1707 register_type +=
" + BAType::Step";
1709 throw std::runtime_error(
"Unhandled Before/After type encountered during code generation");
1711 return register_type;
1733 printer->add_newline(2);
1734 printer->add_line(
"/** register channel with the simulator */");
1735 printer->fmt_push_block(
"void _{}_reg()", info.mod_file);
1738 auto suffix = add_escape_quote(info.mod_suffix);
1739 printer->add_newline();
1740 printer->fmt_line(
"int mech_type = nrn_get_mechtype({});", suffix);
1741 printer->fmt_line(
"{} = mech_type;", get_variable_name(
"mech_type",
false));
1742 printer->push_block(
"if (mech_type == -1)");
1743 printer->add_line(
"return;");
1744 printer->pop_block();
1746 printer->add_newline();
1747 printer->add_line(
"_nrn_layout_reg(mech_type, 0);");
1750 const auto mech_arguments = register_mechanism_arguments();
1751 const auto number_of_thread_objects = num_thread_objects();
1752 if (info.point_process) {
1753 printer->fmt_line(
"point_register_mech({}, {}, {}, {});",
1759 number_of_thread_objects);
1761 printer->fmt_line(
"register_mech({}, {});", mech_arguments, number_of_thread_objects);
1762 if (info.constructor_node) {
1763 printer->fmt_line(
"register_constructor({});",
1769 for (
const auto& ion: info.ions) {
1770 printer->fmt_line(
"{} = nrn_get_mechtype({});",
1771 get_variable_name(ion.name +
"_type",
false),
1772 add_escape_quote(ion.name +
"_ion"));
1774 printer->add_newline();
1780 if (info.vectorize && (info.thread_data_index != 0)) {
1782 printer->fmt_line(
"thread_mem_init({});", get_variable_name(
"ext_call_thread",
false));
1785 if (!info.thread_variables.empty()) {
1786 printer->fmt_line(
"{} = 0;", get_variable_name(
"thread_data_in_use"));
1789 if (info.thread_callback_register) {
1790 printer->add_line(
"_nrn_thread_reg0(mech_type, thread_mem_cleanup);");
1791 printer->add_line(
"_nrn_thread_reg1(mech_type, thread_mem_init);");
1794 if (info.emit_table_thread()) {
1795 auto name = method_name(
"check_table_thread");
1796 printer->fmt_line(
"_nrn_thread_table_reg(mech_type, {});", name);
1800 if (info.bbcore_pointer_used) {
1801 printer->add_line(
"hoc_reg_bbcore_read(mech_type, bbcore_read);");
1802 printer->add_line(
"hoc_reg_bbcore_write(mech_type, bbcore_write);");
1807 printer->add_line(
"hoc_register_prop_size(mech_type, float_variables_size(), int_variables_size());");
1811 for (
auto& semantic: info.semantics) {
1813 fmt::format(
"mech_type, {}, {}", semantic.index, add_escape_quote(semantic.name));
1814 printer->fmt_line(
"hoc_register_dparam_semantics({});", args);
1817 if (info.is_watch_used()) {
1819 printer->fmt_line(
"hoc_register_watch_check({}, mech_type);", watch_fun);
1822 if (info.write_concentration) {
1823 printer->add_line(
"nrn_writes_conc(mech_type, 0);");
1827 if (info.net_event_used) {
1828 printer->add_line(
"add_nrn_has_net_event(mech_type);");
1830 if (info.artificial_cell) {
1831 printer->fmt_line(
"add_nrn_artcell(mech_type, {});", info.tqitem_index);
1833 if (net_receive_buffering_required()) {
1834 printer->fmt_line(
"hoc_register_net_receive_buffering({}, mech_type);",
1835 method_name(
"net_buf_receive"));
1837 if (info.num_net_receive_parameters != 0) {
1838 auto net_recv_init_arg =
"nullptr";
1839 if (info.net_receive_initial_node !=
nullptr) {
1840 net_recv_init_arg =
"net_init";
1842 printer->fmt_line(
"set_pnt_receive(mech_type, {}, {}, num_net_receive_args());",
1843 method_name(
"net_receive"),
1846 if (info.for_netcon_used) {
1849 std::find_if(info.semantics.begin(), info.semantics.end(), [](
const IndexSemantics& a) {
1850 return a.name == naming::FOR_NETCON_SEMANTIC;
1852 printer->fmt_line(
"add_nrn_fornetcons(mech_type, {});",
index);
1855 if (info.net_event_used || info.net_send_used) {
1856 printer->add_line(
"hoc_register_net_send_buffering(mech_type);");
1860 for (
size_t i = 0; i < info.before_after_blocks.size(); i++) {
1862 const auto& block = info.before_after_blocks[i];
1864 std::string function_name = method_name(fmt::format(
"nrn_before_after_{}", i));
1865 printer->fmt_line(
"hoc_reg_ba(mech_type, {}, {});", function_name, register_type);
1869 printer->add_line(
"hoc_register_var(hoc_scalar_double, hoc_vector_double, NULL);");
1870 printer->pop_block();
1875 if (!info.thread_callback_register) {
1880 printer->add_newline(2);
1881 printer->add_line(
"/** thread memory allocation callback */");
1882 printer->push_block(
"static void thread_mem_init(ThreadDatum* thread) ");
1884 if (info.vectorize && info.derivimplicit_used()) {
1885 printer->fmt_line(
"thread[dith{}()].pval = nullptr;", info.derivimplicit_list_num);
1887 if (info.vectorize && (info.top_local_thread_size != 0)) {
1888 auto length = info.top_local_thread_size;
1889 auto allocation = fmt::format(
"(double*)mem_alloc({}, sizeof(double))", length);
1890 printer->fmt_line(
"thread[top_local_var_tid()].pval = {};", allocation);
1892 if (info.thread_var_data_size != 0) {
1893 auto length = info.thread_var_data_size;
1894 auto thread_data = get_variable_name(
"thread_data");
1895 auto thread_data_in_use = get_variable_name(
"thread_data_in_use");
1896 auto allocation = fmt::format(
"(double*)mem_alloc({}, sizeof(double))", length);
1897 printer->fmt_push_block(
"if ({})", thread_data_in_use);
1898 printer->fmt_line(
"thread[thread_var_tid()].pval = {};", allocation);
1899 printer->chain_block(
"else");
1900 printer->fmt_line(
"thread[thread_var_tid()].pval = {};", thread_data);
1901 printer->fmt_line(
"{} = 1;", thread_data_in_use);
1902 printer->pop_block();
1904 printer->pop_block();
1905 printer->add_newline(2);
1909 printer->add_line(
"/** thread memory cleanup callback */");
1910 printer->push_block(
"static void thread_mem_cleanup(ThreadDatum* thread) ");
1913 if (info.vectorize && info.derivimplicit_used()) {
1914 int n = info.derivimplicit_list_num;
1915 printer->fmt_line(
"free(thread[dith{}()].pval);", n);
1916 printer->fmt_line(
"nrn_destroy_newtonspace(static_cast<NewtonSpace*>(*newtonspace{}(thread)));", n);
1920 if (info.top_local_thread_size != 0) {
1921 auto line =
"free(thread[top_local_var_tid()].pval);";
1922 printer->add_line(line);
1924 if (info.thread_var_data_size != 0) {
1925 auto thread_data = get_variable_name(
"thread_data");
1926 auto thread_data_in_use = get_variable_name(
"thread_data_in_use");
1927 printer->fmt_push_block(
"if (thread[thread_var_tid()].pval == {})", thread_data);
1928 printer->fmt_line(
"{} = 0;", thread_data_in_use);
1929 printer->chain_block(
"else");
1930 printer->add_line(
"free(thread[thread_var_tid()].pval);");
1931 printer->pop_block();
1933 printer->pop_block();
1938 auto const value_initialize = print_initializers ?
"{}" :
"";
1939 auto int_type = default_int_data_type();
1940 printer->add_newline(2);
1941 printer->add_line(
"/** all mechanism instance variables and global variables */");
1942 printer->fmt_push_block(
"struct {} ", instance_struct());
1944 for (
auto const& [var, type]: info.neuron_global_variables) {
1945 auto const name = var->get_name();
1946 printer->fmt_line(
"{}* {}{};",
1949 print_initializers ? fmt::format(
"{{&coreneuron::{}}}", name)
1952 for (
auto& var: codegen_float_variables) {
1953 const auto& name = var->get_name();
1954 auto type = get_range_var_float_type(var);
1955 auto qualifier = is_constant_variable(name) ?
"const " :
"";
1956 printer->fmt_line(
"{}{}* {}{};", qualifier, type, name, value_initialize);
1958 for (
auto& var: codegen_int_variables) {
1959 const auto& name = var.symbol->get_name();
1960 if (var.is_index || var.is_integer) {
1961 auto qualifier = var.is_constant ?
"const " :
"";
1962 printer->fmt_line(
"{}{}* {}{};", qualifier, int_type, name, value_initialize);
1964 auto qualifier = var.is_constant ?
"const " :
"";
1965 auto type = var.is_vdata ?
"void*" : default_float_data_type();
1966 printer->fmt_line(
"{}{}* {}{};", qualifier, type, name, value_initialize);
1970 printer->fmt_line(
"{}* {}{};",
1973 print_initializers ? fmt::format(
"{{&{}}}", global_struct_instance())
1975 printer->pop_block(
";");
1980 if (!ion_variable_struct_required()) {
1983 printer->add_newline(2);
1984 printer->add_line(
"/** ion write variables */");
1985 printer->push_block(
"struct IonCurVar");
1987 std::string float_type = default_float_data_type();
1988 std::vector<std::string> members;
1990 for (
auto& ion: info.ions) {
1991 for (
auto& var: ion.writes) {
1992 printer->fmt_line(
"{} {};", float_type, var);
1993 members.push_back(var);
1996 for (
auto& var: info.currents) {
1997 if (!info.is_ion_variable(var)) {
1998 printer->fmt_line(
"{} {};", float_type, var);
1999 members.push_back(var);
2003 print_ion_var_constructor(members);
2005 printer->pop_block(
";");
2010 const std::vector<std::string>& members) {
2012 printer->add_newline();
2013 printer->add_indent();
2014 printer->add_text(
"IonCurVar() : ");
2015 for (
int i = 0; i < members.size(); i++) {
2016 printer->fmt_text(
"{}(0)", members[i]);
2017 if (i + 1 < members.size()) {
2018 printer->add_text(
", ");
2021 printer->add_text(
" {}");
2022 printer->add_newline();
2027 printer->add_line(
"IonCurVar ionvar;");
2037 auto type = float_data_type();
2038 printer->add_newline(2);
2039 printer->add_line(
"/** allocate and setup array for range variable */");
2040 printer->fmt_push_block(
"static inline {}* setup_range_variable(double* variable, int n)",
2042 printer->fmt_line(
"{0}* data = ({0}*) mem_alloc(n, sizeof({0}));", type);
2043 printer->push_block(
"for(size_t i = 0; i < n; i++)");
2044 printer->add_line(
"data[i] = variable[i];");
2045 printer->pop_block();
2046 printer->add_line(
"return data;");
2047 printer->pop_block();
2059 auto with = NmodlType::read_ion_var
2060 | NmodlType::write_ion_var
2061 | NmodlType::pointer_var
2062 | NmodlType::bbcore_pointer_var
2063 | NmodlType::extern_neuron_variable;
2065 bool need_default_type = symbol->has_any_property(with);
2066 if (need_default_type) {
2067 return default_float_data_type();
2069 return float_data_type();
2074 if (range_variable_setup_required()) {
2075 print_setup_range_variable();
2078 printer->add_newline();
2079 printer->add_line(
"// Allocate instance structure");
2080 printer->fmt_push_block(
"static void {}(NrnThread* nt, Memb_list* ml, int type)",
2082 printer->add_line(
"assert(!ml->instance);");
2083 printer->add_line(
"assert(!ml->global_variables);");
2084 printer->add_line(
"assert(ml->global_variables_size == 0);");
2085 printer->fmt_line(
"auto* const inst = new {}{{}};", instance_struct());
2086 printer->fmt_line(
"assert(inst->{} == &{});",
2088 global_struct_instance());
2089 printer->add_line(
"ml->instance = inst;");
2091 printer->fmt_line(
"ml->global_variables_size = sizeof({});", global_struct());
2092 printer->pop_block();
2093 printer->add_newline();
2095 auto const cast_inst_and_assert_validity = [&]() {
2096 printer->fmt_line(
"auto* const inst = static_cast<{}*>(ml->instance);", instance_struct());
2097 printer->add_line(
"assert(inst);");
2099 printer->fmt_line(
"assert(inst->{} == &{});",
2101 global_struct_instance());
2103 printer->fmt_line(
"assert(ml->global_variables_size == sizeof({}));", global_struct());
2108 print_instance_struct_transfer_routine_declarations();
2110 printer->add_line(
"// Deallocate the instance structure");
2111 printer->fmt_push_block(
"static void {}(NrnThread* nt, Memb_list* ml, int type)",
2113 cast_inst_and_assert_validity();
2116 if (info.random_variables.size()) {
2117 printer->add_line(
"int pnodecount = ml->_nodecount_padded;");
2118 printer->add_line(
"int nodecount = ml->nodecount;");
2119 printer->add_line(
"Datum* indexes = ml->pdata;");
2120 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
2121 for (
const auto& var: info.random_variables) {
2122 const auto& name = get_variable_name(var->get_name());
2123 printer->fmt_line(
"nrnran123_deletestream((nrnran123_State*){});", name);
2125 printer->pop_block();
2127 print_instance_struct_delete_from_device();
2128 printer->add_multi_line(R
"CODE(
2130 ml->instance = nullptr;
2131 ml->global_variables = nullptr;
2132 ml->global_variables_size = 0;
2134 printer->pop_block();
2135 printer->add_newline();
2138 printer->add_line("/** initialize mechanism instance variables */");
2139 printer->push_block(
"static inline void setup_instance(NrnThread* nt, Memb_list* ml)");
2140 cast_inst_and_assert_validity();
2143 printer->add_line(
"int pnodecount = ml->_nodecount_padded;");
2144 stride =
"*pnodecount";
2146 printer->add_line(
"Datum* indexes = ml->pdata;");
2148 auto const float_type = default_float_data_type();
2152 for (
auto const& [var, type]: info.neuron_global_variables) {
2153 ptr_members.push_back(var->get_name());
2155 ptr_members.reserve(ptr_members.size() + codegen_float_variables.size() +
2156 codegen_int_variables.size());
2157 for (
auto& var: codegen_float_variables) {
2158 auto name = var->get_name();
2159 auto range_var_type = get_range_var_float_type(var);
2160 if (float_type == range_var_type) {
2161 auto const variable = fmt::format(
"ml->data+{}{}",
id, stride);
2162 printer->fmt_line(
"inst->{} = {};", name, variable);
2165 printer->fmt_line(
"inst->{} = setup_range_variable(ml->data+{}{}, pnodecount);",
2170 ptr_members.push_back(std::move(name));
2171 id += var->get_length();
2174 for (
auto& var: codegen_int_variables) {
2175 auto name = var.symbol->get_name();
2176 auto const variable = [&var]() {
2177 if (var.is_index || var.is_integer) {
2179 }
else if (var.is_vdata) {
2180 return "nt->_vdata";
2185 printer->fmt_line(
"inst->{} = {};", name, variable);
2186 ptr_members.push_back(std::move(name));
2188 print_instance_struct_copy_to_device();
2189 printer->pop_block();
2190 printer->add_newline();
2192 print_instance_struct_transfer_routines(ptr_members);
2197 if (info.artificial_cell) {
2198 printer->add_line(
"double v = 0.0;");
2200 printer->add_line(
"int node_id = node_index[id];");
2201 printer->add_line(
"double v = voltage[node_id];");
2205 if (ion_variable_struct_required()) {
2206 printer->add_line(
"IonCurVar ionvar;");
2211 for (
auto& statement: read_statements) {
2212 printer->add_line(statement);
2216 for (
auto& var: info.state_vars) {
2217 auto name = var->get_name();
2218 if (!info.is_ionic_conc(name)) {
2219 auto lhs = get_variable_name(name);
2220 auto rhs = get_variable_name(name +
"0");
2221 if (var->is_array()) {
2222 for (
int i = 0; i < var->get_length(); ++i) {
2223 printer->fmt_line(
"{}[{}] = {};", lhs, i, rhs);
2226 printer->fmt_line(
"{} = {};", lhs, rhs);
2232 if (node !=
nullptr) {
2234 print_statement_block(*block,
false,
false);
2239 for (
auto& statement: write_statements) {
2241 printer->add_line(text);
2248 const std::string& function_name) {
2250 if (function_name.empty()) {
2251 method = compute_method_name(type);
2253 method = function_name;
2255 auto args =
"NrnThread* nt, Memb_list* ml, int type";
2259 args =
"NrnThread* nt, Memb_list* ml";
2262 print_global_method_annotation();
2263 printer->fmt_push_block(
"void {}({})", method, args);
2267 print_kernel_data_present_annotation_block_begin();
2271 printer->add_line(
"#ifndef CORENEURON_BUILD");
2273 printer->add_multi_line(R
"CODE(
2274 int nodecount = ml->nodecount;
2275 int pnodecount = ml->_nodecount_padded;
2276 const int* node_index = ml->nodeindices;
2277 double* data = ml->data;
2278 const double* voltage = nt->_actual_v;
2282 printer->add_line(
"double* vec_rhs = nt->_actual_rhs;");
2283 printer->add_line(
"double* vec_d = nt->_actual_d;");
2284 print_rhs_d_shadow_variables();
2286 printer->add_line(
"Datum* indexes = ml->pdata;");
2287 printer->add_line(
"ThreadDatum* thread = ml->_thread;");
2290 printer->add_newline();
2291 printer->add_line(
"setup_instance(nt, ml);");
2293 printer->fmt_line(
"auto* const inst = static_cast<{}*>(ml->instance);", instance_struct());
2294 printer->add_newline(1);
2298 printer->add_newline(2);
2299 printer->add_line(
"/** initialize channel */");
2302 if (info.derivimplicit_used()) {
2303 printer->add_newline();
2304 int nequation = info.num_equations;
2305 int list_num = info.derivimplicit_list_num;
2307 printer->fmt_line(
"int& deriv_advance_flag = *deriv{}_advance(thread);", list_num);
2308 printer->add_line(
"deriv_advance_flag = 0;");
2309 print_deriv_advance_flag_transfer_to_device();
2310 printer->fmt_line(
"auto ns = newtonspace{}(thread);", list_num);
2311 printer->fmt_line(
"auto& th = thread[dith{}()];", list_num);
2312 printer->push_block(
"if (*ns == nullptr)");
2313 printer->fmt_line(
"int vec_size = 2*{}*pnodecount*sizeof(double);", nequation);
2314 printer->fmt_line(
"double* vec = makevector(vec_size);", nequation);
2315 printer->fmt_line(
"th.pval = vec;", list_num);
2316 printer->fmt_line(
"*ns = nrn_cons_newtonspace({}, pnodecount);", nequation);
2317 print_newtonspace_transfer_to_device();
2318 printer->pop_block();
2325 print_global_variable_device_update_annotation();
2327 if (skip_init_check) {
2328 printer->push_block(
"if (_nrn_skip_initmodel == 0)");
2331 if (!info.changed_dt.empty()) {
2332 printer->fmt_line(
"double _save_prev_dt = {};",
2334 printer->fmt_line(
"{} = {};",
2337 print_dt_update_to_device();
2341 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
2343 if (info.net_receive_node !=
nullptr) {
2344 printer->fmt_line(
"{} = -1e20;", get_variable_name(
"tsave"));
2347 print_initial_block(info.initial_node);
2348 printer->pop_block();
2350 if (!info.changed_dt.empty()) {
2352 print_dt_update_to_device();
2355 printer->pop_block();
2357 if (info.derivimplicit_used()) {
2358 printer->add_line(
"deriv_advance_flag = 1;");
2359 print_deriv_advance_flag_transfer_to_device();
2362 if (info.net_send_used && !info.artificial_cell) {
2363 print_send_event_move();
2366 print_kernel_data_present_annotation_block_end();
2367 if (skip_init_check) {
2368 printer->pop_block();
2374 std::string ba_type;
2375 std::shared_ptr<ast::BABlock> ba_block;
2385 std::string ba_block_type = ba_block->get_type()->eval();
2388 std::string function_name = method_name(fmt::format(
"nrn_before_after_{}", block_id));
2391 printer->add_newline(2);
2392 printer->fmt_line(
"/** {} of block type {} # {} */", ba_type, ba_block_type, block_id);
2396 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
2398 printer->add_line(
"int node_id = node_index[id];");
2399 printer->add_line(
"double v = voltage[node_id];");
2404 for (
auto& statement: read_statements) {
2405 printer->add_line(statement);
2409 printer->add_indent();
2410 print_statement_block(*ba_block->get_statement_block());
2411 printer->add_newline();
2415 for (
auto& statement: write_statements) {
2417 printer->add_line(text);
2421 printer->pop_block();
2422 printer->pop_block();
2423 print_kernel_data_present_annotation_block_end();
2427 printer->add_newline(2);
2429 if (info.constructor_node !=
nullptr) {
2430 const auto& block = info.constructor_node->get_statement_block();
2431 print_statement_block(*block,
false,
false);
2433 printer->add_line(
"#endif");
2434 printer->pop_block();
2439 printer->add_newline(2);
2441 if (info.destructor_node !=
nullptr) {
2442 const auto& block = info.destructor_node->get_statement_block();
2443 print_statement_block(*block,
false,
false);
2445 printer->add_line(
"#endif");
2446 printer->pop_block();
2451 for (
const auto& functor_name: info.functor_names) {
2452 printer->add_newline(2);
2453 print_functor_definition(*functor_name.first);
2459 printer->add_newline(2);
2461 printer->fmt_push_block(
"static void {}(double* data, Datum* indexes, int type)", method);
2462 printer->add_line(
"// do nothing");
2463 printer->pop_block();
2472 if (info.watch_statements.empty()) {
2476 printer->add_newline(2);
2477 auto inst = fmt::format(
"{}* inst", instance_struct());
2479 printer->fmt_push_block(
2480 "static void nrn_watch_activate({}, int id, int pnodecount, int watch_id, "
2481 "double v, bool &watch_remove)",
2485 printer->push_block(
"if (watch_remove == false)");
2486 for (
int i = 0; i < info.watch_count; i++) {
2487 auto name = get_variable_name(fmt::format(
"watch{}", i + 1));
2488 printer->fmt_line(
"{} = 0;", name);
2490 printer->add_line(
"watch_remove = true;");
2491 printer->pop_block();
2497 for (
int i = 0; i < info.watch_statements.size(); i++) {
2498 auto statement = info.watch_statements[i];
2499 printer->fmt_push_block(
"if (watch_id == {})", i);
2501 auto varname = get_variable_name(fmt::format(
"watch{}", i + 1));
2502 printer->add_indent();
2503 printer->fmt_text(
"{} = 2 + (", varname);
2504 auto watch = statement->get_statements().front();
2505 watch->get_expression()->visit_children(*
this);
2506 printer->add_text(
");");
2507 printer->add_newline();
2509 printer->pop_block();
2511 printer->pop_block();
2520 if (info.watch_statements.empty()) {
2524 printer->add_newline(2);
2525 printer->add_line(
"/** routine to check watch activation */");
2534 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
2536 if (info.is_voltage_used_by_watch_statements()) {
2537 printer->add_line(
"int node_id = node_index[id];");
2538 printer->add_line(
"double v = voltage[node_id];");
2543 printer->add_line(
"bool watch_untriggered = true;");
2545 for (
int i = 0; i < info.watch_statements.size(); i++) {
2546 auto statement = info.watch_statements[i];
2547 const auto& watch = statement->get_statements().front();
2548 const auto& varname = get_variable_name(fmt::format(
"watch{}", i + 1));
2551 printer->fmt_push_block(
"if ({}&2 && watch_untriggered)", varname);
2554 printer->add_indent();
2555 printer->add_text(
"if (");
2556 watch->get_expression()->accept(*
this);
2557 printer->add_text(
") {");
2558 printer->add_newline();
2559 printer->increase_indent();
2562 printer->fmt_push_block(
"if (({}&1) == 0)", varname);
2564 printer->add_line(
"watch_untriggered = false;");
2566 const auto& tqitem = get_variable_name(
"tqitem");
2567 const auto& point_process = get_variable_name(
"point_process");
2568 printer->add_indent();
2569 printer->add_text(
"net_send_buffering(");
2570 const auto& t = get_variable_name(
"t");
2571 printer->fmt_text(
"nt, ml->_net_send_buffer, 0, {}, -1, {}, {}+0.0, ",
2575 watch->get_value()->accept(*
this);
2576 printer->add_text(
");");
2577 printer->add_newline();
2578 printer->pop_block();
2580 printer->add_line(varname,
" = 3;");
2584 printer->decrease_indent();
2585 printer->push_block(
"} else");
2586 printer->add_line(varname,
" = 2;");
2587 printer->pop_block();
2590 printer->pop_block();
2594 printer->pop_block();
2595 print_send_event_move();
2596 print_kernel_data_present_annotation_block_end();
2597 printer->pop_block();
2602 bool need_mech_inst) {
2603 printer->add_multi_line(R
"CODE(
2604 int tid = pnt->_tid;
2605 int id = pnt->_i_instance;
2610 printer->add_line(
"NrnThread* nt = nrn_threads + tid;");
2611 printer->add_line(
"Memb_list* ml = nt->_ml_list[pnt->_type];");
2614 print_kernel_data_present_annotation_block_begin();
2617 printer->add_multi_line(R
"CODE(
2618 int nodecount = ml->nodecount;
2619 int pnodecount = ml->_nodecount_padded;
2620 double* data = ml->data;
2621 double* weights = nt->weights;
2622 Datum* indexes = ml->pdata;
2623 ThreadDatum* thread = ml->_thread;
2625 if (need_mech_inst) {
2626 printer->fmt_line(
"auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct());
2630 print_net_init_acc_serial_annotation_block_begin();
2634 auto parameters = info.net_receive_node->get_parameters();
2635 if (!parameters.empty()) {
2637 printer->add_newline();
2638 for (
auto& parameter: parameters) {
2639 auto name = parameter->get_node_name();
2642 printer->fmt_line(
"double* {} = weights + weight_index + {};", name, i);
2654 const auto& tqitem = get_variable_name(
"tqitem");
2655 std::string weight_index =
"weight_index";
2656 std::string pnt =
"pnt";
2660 if (!printing_net_receive && !printing_net_init) {
2662 auto var = get_variable_name(
"point_process");
2663 if (info.artificial_cell) {
2664 pnt =
"(Point_process*)" + var;
2670 if (info.artificial_cell) {
2671 printer->fmt_text(
"artcell_net_send(&{}, {}, {}, nt->_t+", tqitem, weight_index, pnt);
2673 const auto& point_process = get_variable_name(
"point_process");
2674 const auto& t = get_variable_name(
"t");
2675 printer->add_text(
"net_send_buffering(");
2676 printer->fmt_text(
"nt, ml->_net_send_buffer, 0, {}, {}, {}, {}+", tqitem, weight_index, point_process, t);
2679 print_vector_elements(arguments,
", ");
2680 printer->add_text(
')');
2685 if (!printing_net_receive && !printing_net_init) {
2686 throw std::runtime_error(
"Error : net_move only allowed in NET_RECEIVE block");
2690 const auto& tqitem = get_variable_name(
"tqitem");
2691 std::string weight_index =
"-1";
2692 std::string pnt =
"pnt";
2696 if (info.artificial_cell) {
2697 printer->fmt_text(
"artcell_net_move(&{}, {}, ", tqitem, pnt);
2698 print_vector_elements(arguments,
", ");
2699 printer->add_text(
")");
2701 const auto& point_process = get_variable_name(
"point_process");
2702 printer->add_text(
"net_send_buffering(");
2703 printer->fmt_text(
"nt, ml->_net_send_buffer, 2, {}, {}, {}, ", tqitem, weight_index, point_process);
2704 print_vector_elements(arguments,
", ");
2705 printer->add_text(
", 0.0");
2706 printer->add_text(
")");
2713 if (info.artificial_cell) {
2714 printer->add_text(
"net_event(pnt, ");
2715 print_vector_elements(arguments,
", ");
2717 const auto& point_process = get_variable_name(
"point_process");
2718 printer->add_text(
"net_send_buffering(");
2719 printer->fmt_text(
"nt, ml->_net_send_buffer, 1, -1, -1, {}, ", point_process);
2720 print_vector_elements(arguments,
", ");
2721 printer->add_text(
", 0.0");
2723 printer->add_text(
")");
2752 for (
auto& parameter: parameters) {
2753 const auto& name = parameter->get_node_name();
2764 const auto node = info.net_receive_initial_node;
2765 if (node ==
nullptr) {
2772 printing_net_init =
true;
2773 auto args =
"Point_process* pnt, int weight_index, double flag";
2774 printer->add_newline(2);
2775 printer->add_line(
"/** initialize block for net receive */");
2776 printer->fmt_push_block(
"static void net_init({})", args);
2777 auto block = node->get_statement_block().get();
2778 if (block->get_statements().empty()) {
2779 printer->add_line(
"// do nothing");
2781 print_net_receive_common_code(*node);
2782 print_statement_block(*block,
false,
false);
2783 if (node->is_initial_block()) {
2784 print_net_init_acc_serial_annotation_block_end();
2785 print_kernel_data_present_annotation_block_end();
2786 printer->add_line(
"auto& nsb = ml->_net_send_buffer;");
2787 print_net_send_buf_update_to_host();
2790 printer->pop_block();
2791 printing_net_init =
false;
2796 printer->add_newline();
2797 printer->add_line(
"NetSendBuffer_t* nsb = ml->_net_send_buffer;");
2798 print_net_send_buf_update_to_host();
2799 printer->push_block(
"for (int i=0; i < nsb->_cnt; i++)");
2800 printer->add_multi_line(R
"CODE(
2801 int type = nsb->_sendtype[i];
2803 double t = nsb->_nsb_t[i];
2804 double flag = nsb->_nsb_flag[i];
2805 int vdata_index = nsb->_vdata_index[i];
2806 int weight_index = nsb->_weight_index[i];
2807 int point_index = nsb->_pnt_index[i];
2808 net_sem_from_gpu(type, vdata_index, weight_index, tid, point_index, t, flag);
2810 printer->pop_block();
2811 printer->add_line("nsb->_cnt = 0;");
2812 print_net_send_buf_count_update_to_device();
2817 return fmt::format(
"void {}(NrnThread* nt)", method_name(
"net_buf_receive"));
2822 printer->add_line(
"Memb_list* ml = get_memb_list(nt);");
2823 printer->push_block(
"if (!ml)");
2824 printer->add_line(
"return;");
2825 printer->pop_block();
2826 printer->add_newline();
2831 printer->add_line(
"int count = nrb->_displ_cnt;");
2833 printer->push_block(
"for (int i = 0; i < count; i++)");
2838 printer->pop_block();
2843 if (!net_receive_required() || info.artificial_cell) {
2846 printer->add_newline(2);
2847 printer->push_block(net_receive_buffering_declaration());
2849 print_get_memb_list();
2851 const auto& net_receive = method_name(
"net_receive_kernel");
2853 print_kernel_data_present_annotation_block_begin();
2855 printer->add_line(
"NetReceiveBuffer_t* nrb = ml->_net_receive_buffer;");
2856 if (need_mech_inst) {
2857 printer->fmt_line(
"auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct());
2859 print_net_receive_loop_begin();
2860 printer->add_line(
"int start = nrb->_displ[i];");
2861 printer->add_line(
"int end = nrb->_displ[i+1];");
2862 printer->push_block(
"for (int j = start; j < end; j++)");
2863 printer->add_multi_line(R
"CODE(
2864 int index = nrb->_nrb_index[j];
2865 int offset = nrb->_pnt_index[index];
2866 double t = nrb->_nrb_t[index];
2867 int weight_index = nrb->_weight_index[index];
2868 double flag = nrb->_nrb_flag[index];
2869 Point_process* point_process = nt->pntprocs + offset;
2871 printer->add_line(net_receive, "(t, point_process, inst, nt, ml, weight_index, flag);");
2872 printer->pop_block();
2873 print_net_receive_loop_end();
2875 print_device_stream_wait();
2876 printer->add_line(
"nrb->_displ_cnt = 0;");
2877 printer->add_line(
"nrb->_cnt = 0;");
2879 if (info.net_send_used || info.net_event_used) {
2880 print_send_event_move();
2883 print_kernel_data_present_annotation_block_end();
2884 printer->pop_block();
2889 printer->add_line(
"i = nsb->_cnt++;");
2894 printer->push_block(
"if (i >= nsb->_size)");
2895 printer->add_line(
"nsb->grow();");
2896 printer->pop_block();
2901 if (!net_send_buffer_required()) {
2905 printer->add_newline(2);
2906 print_device_method_annotation();
2908 "const NrnThread* nt, NetSendBuffer_t* nsb, int type, int vdata_index, "
2909 "int weight_index, int point_index, double t, double flag";
2910 printer->fmt_push_block(
"static inline void net_send_buffering({})", args);
2911 printer->add_line(
"int i = 0;");
2912 print_net_send_buffering_cnt_update();
2913 print_net_send_buffering_grow();
2914 printer->push_block(
"if (i < nsb->_size)");
2915 printer->add_multi_line(R
"CODE(
2916 nsb->_sendtype[i] = type;
2917 nsb->_vdata_index[i] = vdata_index;
2918 nsb->_weight_index[i] = weight_index;
2919 nsb->_pnt_index[i] = point_index;
2921 nsb->_nsb_flag[i] = flag;
2923 printer->pop_block();
2924 printer->pop_block();
2929 if (!net_receive_required()) {
2933 printing_net_receive =
true;
2934 const auto node = info.net_receive_node;
2941 if (!info.artificial_cell) {
2942 name = method_name(
"net_receive_kernel");
2943 params.emplace_back(
"",
"double",
"",
"t");
2944 params.emplace_back(
"",
"Point_process*",
"",
"pnt");
2945 params.emplace_back(
"", fmt::format(
"{}*", instance_struct()),
2947 params.emplace_back(
"",
"NrnThread*",
"",
"nt");
2948 params.emplace_back(
"",
"Memb_list*",
"",
"ml");
2949 params.emplace_back(
"",
"int",
"",
"weight_index");
2950 params.emplace_back(
"",
"double",
"",
"flag");
2952 name = method_name(
"net_receive");
2953 params.emplace_back(
"",
"Point_process*",
"",
"pnt");
2954 params.emplace_back(
"",
"int",
"",
"weight_index");
2955 params.emplace_back(
"",
"double",
"",
"flag");
2958 printer->add_newline(2);
2959 printer->fmt_push_block(
"static inline void {}({})", name, get_parameter_str(params));
2960 print_net_receive_common_code(*node, info.artificial_cell);
2961 if (info.artificial_cell) {
2962 printer->add_line(
"double t = nt->_t;");
2968 printer->add_line(
"int node_id = ml->nodeindices[id];");
2969 printer->add_line(
"v = nt->_actual_v[node_id];");
2972 printer->fmt_line(
"{} = t;", get_variable_name(
"tsave"));
2974 if (info.is_watch_used()) {
2975 printer->add_line(
"bool watch_remove = false;");
2978 printer->add_indent();
2979 node->get_statement_block()->accept(*
this);
2980 printer->add_newline();
2981 printer->pop_block();
2983 printing_net_receive =
false;
2988 if (!net_receive_required()) {
2992 printing_net_receive =
true;
2993 if (!info.artificial_cell) {
2994 const auto& name = method_name(
"net_receive");
2996 params.emplace_back(
"",
"Point_process*",
"",
"pnt");
2997 params.emplace_back(
"",
"int",
"",
"weight_index");
2998 params.emplace_back(
"",
"double",
"",
"flag");
2999 printer->add_newline(2);
3000 printer->fmt_push_block(
"static void {}({})", name, get_parameter_str(params));
3001 printer->add_line(
"NrnThread* nt = nrn_threads + pnt->_tid;");
3002 printer->add_line(
"Memb_list* ml = get_memb_list(nt);");
3003 printer->add_line(
"NetReceiveBuffer_t* nrb = ml->_net_receive_buffer;");
3004 printer->push_block(
"if (nrb->_cnt >= nrb->_size)");
3005 printer->add_line(
"realloc_net_receive_buffer(nt, ml);");
3006 printer->pop_block();
3007 printer->add_multi_line(R
"CODE(
3009 nrb->_pnt_index[id] = pnt-nt->pntprocs;
3010 nrb->_weight_index[id] = weight_index;
3011 nrb->_nrb_t[id] = nt->_t;
3012 nrb->_nrb_flag[id] = flag;
3015 printer->pop_block();
3017 printing_net_receive = false;
3029 auto ext_args = external_method_arguments();
3030 auto ext_params = external_method_parameters();
3031 auto suffix = info.mod_suffix;
3032 auto list_num = info.derivimplicit_list_num;
3034 auto primes_size = info.primes_size;
3035 auto stride =
"*pnodecount+id";
3037 printer->add_newline(2);
3039 printer->push_block(
"namespace");
3040 printer->fmt_push_block(
"struct _newton_{}_{}", block_name, info.mod_suffix);
3041 printer->fmt_push_block(
"int operator()({}) const", external_method_parameters());
3042 auto const instance = fmt::format(
"auto* const inst = static_cast<{0}*>(ml->instance);",
3044 auto const slist1 = fmt::format(
"auto const& slist{} = {};",
3046 get_variable_name(fmt::format(
"slist{}", list_num)));
3047 auto const slist2 = fmt::format(
"auto& slist{} = {};",
3049 get_variable_name(fmt::format(
"slist{}", list_num + 1)));
3050 auto const dlist1 = fmt::format(
"auto const& dlist{} = {};",
3052 get_variable_name(fmt::format(
"dlist{}", list_num)));
3053 auto const dlist2 = fmt::format(
3054 "double* dlist{} = static_cast<double*>(thread[dith{}()].pval) + ({}*pnodecount);",
3058 printer->add_line(instance);
3059 if (ion_variable_struct_required()) {
3060 print_ion_variable();
3062 printer->fmt_line(
"double* savstate{} = static_cast<double*>(thread[dith{}()].pval);",
3065 printer->add_line(slist1);
3066 printer->add_line(dlist1);
3067 printer->add_line(dlist2);
3071 printer->add_line(
"int counter = -1;");
3072 printer->fmt_push_block(
"for (int i=0; i<{}; i++)", info.num_primes);
3073 printer->fmt_push_block(
"if (*deriv{}_advance(thread))", list_num);
3075 "dlist{0}[(++counter){1}] = "
3076 "data[dlist{2}[i]{1}]-(data[slist{2}[i]{1}]-savstate{2}[i{1}])/nt->_dt;",
3080 printer->chain_block(
"else");
3081 printer->fmt_line(
"dlist{0}[(++counter){1}] = data[slist{2}[i]{1}]-savstate{2}[i{1}];",
3085 printer->pop_block();
3086 printer->pop_block();
3087 printer->add_line(
"return 0;");
3088 printer->pop_block();
3089 printer->pop_block(
";");
3090 printer->pop_block();
3091 printer->add_newline();
3092 printer->fmt_push_block(
"int {}_{}({})", block_name, suffix, ext_params);
3093 printer->add_line(instance);
3094 printer->fmt_line(
"double* savstate{} = (double*) thread[dith{}()].pval;", list_num, list_num);
3095 printer->add_line(slist1);
3096 printer->add_line(slist2);
3097 printer->add_line(dlist2);
3098 printer->fmt_push_block(
"for (int i=0; i<{}; i++)", info.num_primes);
3099 printer->fmt_line(
"savstate{}[i{}] = data[slist{}[i]{}];", list_num, stride, list_num, stride);
3100 printer->pop_block();
3102 "int reset = nrn_newton_thread(static_cast<NewtonSpace*>(*newtonspace{}(thread)), {}, "
3103 "slist{}, _newton_{}_{}{{}}, dlist{}, {});",
3111 printer->add_line(
"return reset;");
3112 printer->pop_block();
3113 printer->add_newline(2);
3128 if (!nrn_state_required()) {
3132 printer->add_newline(2);
3133 printer->add_line(
"/** update state */");
3135 print_channel_iteration_block_parallel_hint(
BlockType::State, info.nrn_state_block);
3136 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
3138 printer->add_line(
"int node_id = node_index[id];");
3139 printer->add_line(
"double v = voltage[node_id];");
3146 if (ion_variable_struct_required()) {
3147 print_ion_variable();
3151 for (
auto& statement: read_statements) {
3152 printer->add_line(statement);
3155 if (info.nrn_state_block) {
3156 info.nrn_state_block->visit_children(*
this);
3159 if (info.currents.empty() && info.breakpoint_node !=
nullptr) {
3160 auto block = info.breakpoint_node->get_statement_block();
3161 print_statement_block(*block,
false,
false);
3165 for (
auto& statement: write_statements) {
3166 const auto& text = process_shadow_update_statement(statement,
BlockType::State);
3167 printer->add_line(text);
3169 printer->pop_block();
3171 print_kernel_data_present_annotation_block_end();
3173 printer->pop_block();
3183 const auto& args = internal_method_parameters();
3185 printer->add_newline(2);
3186 print_device_method_annotation();
3187 printer->fmt_push_block(
"inline double nrn_current_{}({})",
3189 get_parameter_str(args));
3190 printer->add_line(
"double current = 0.0;");
3191 print_statement_block(*block,
false,
false);
3192 for (
auto& current: info.currents) {
3193 const auto& name = get_variable_name(current);
3194 printer->fmt_line(
"current += {};", name);
3196 printer->add_line(
"return current;");
3197 printer->pop_block();
3203 print_statement_block(*block,
false,
false);
3204 if (!info.currents.empty()) {
3206 for (
const auto& current: info.currents) {
3207 auto var = breakpoint_current(current);
3208 sum += get_variable_name(var);
3209 if (¤t != &info.currents.back()) {
3213 printer->fmt_line(
"double rhs = {};", sum);
3217 for (
const auto& conductance: info.conductances) {
3218 auto var = breakpoint_current(conductance.variable);
3219 sum += get_variable_name(var);
3220 if (&conductance != &info.conductances.back()) {
3224 printer->fmt_line(
"double g = {};", sum);
3226 for (
const auto& conductance: info.conductances) {
3227 if (!conductance.ion.empty()) {
3229 const auto& rhs = get_variable_name(conductance.variable);
3232 printer->add_line(text);
3239 printer->fmt_line(
"double g = nrn_current_{}({}+0.001);",
3241 internal_method_arguments());
3242 for (
auto& ion: info.ions) {
3243 for (
auto& var: ion.writes) {
3244 if (ion.is_ionic_current(var)) {
3245 const auto& name = get_variable_name(var);
3246 printer->fmt_line(
"double di{} = {};", ion.name, name);
3250 printer->fmt_line(
"double rhs = nrn_current_{}({});",
3252 internal_method_arguments());
3253 printer->add_line(
"g = (g-rhs)/0.001;");
3254 for (
auto& ion: info.ions) {
3255 for (
auto& var: ion.writes) {
3256 if (ion.is_ionic_current(var)) {
3258 auto rhs = fmt::format(
"(di{}-{})/0.001", ion.name, get_variable_name(var));
3259 if (info.point_process) {
3261 rhs += fmt::format(
"*1.e2/{}", area);
3265 printer->add_line(text);
3273 printer->add_line(
"int node_id = node_index[id];");
3274 printer->add_line(
"double v = voltage[node_id];");
3276 if (ion_variable_struct_required()) {
3277 print_ion_variable();
3281 for (
auto& statement: read_statements) {
3282 printer->add_line(statement);
3285 if (info.conductances.empty()) {
3286 print_nrn_cur_non_conductance_kernel();
3288 print_nrn_cur_conductance_kernel(node);
3292 for (
auto& statement: write_statements) {
3294 printer->add_line(text);
3297 if (info.point_process) {
3299 printer->fmt_line(
"double mfactor = 1.e2/{};", area);
3300 printer->add_line(
"g = g*mfactor;");
3301 printer->add_line(
"rhs = rhs*mfactor;");
3309 if (!info.electrode_current) {
3313 auto rhs_op = operator_for_rhs();
3314 auto d_op = operator_for_d();
3315 if (info.point_process) {
3316 rhs =
"shadow_rhs[id]";
3323 printer->push_block(
"if (nt->nrn_fast_imem)");
3324 if (nrn_cur_reduction_loop_required()) {
3325 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
3326 printer->add_line(
"int node_id = node_index[id];");
3328 printer->fmt_line(
"nt->nrn_fast_imem->nrn_sav_rhs[node_id] {} {};", rhs_op, rhs);
3329 printer->fmt_line(
"nt->nrn_fast_imem->nrn_sav_d[node_id] {} {};", d_op, d);
3330 if (nrn_cur_reduction_loop_required()) {
3331 printer->pop_block();
3333 printer->pop_block();
3338 if (!nrn_cur_required()) {
3342 if (info.conductances.empty()) {
3343 print_nrn_current(*info.breakpoint_node);
3346 printer->add_newline(2);
3347 printer->add_line(
"/** update current */");
3350 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
3351 print_nrn_cur_kernel(*info.breakpoint_node);
3352 print_nrn_cur_matrix_shadow_update();
3353 if (!nrn_cur_reduction_loop_required()) {
3354 print_fast_imem_calculation();
3356 printer->pop_block();
3358 if (nrn_cur_reduction_loop_required()) {
3359 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
3360 print_nrn_cur_matrix_shadow_reduction();
3361 printer->pop_block();
3362 print_fast_imem_calculation();
3365 print_kernel_data_present_annotation_block_end();
3366 printer->pop_block();
3375 print_standard_includes();
3376 print_backend_includes();
3377 print_coreneuron_includes();
3382 print_namespace_start();
3383 print_backend_namespace_start();
3388 print_backend_namespace_stop();
3389 print_namespace_stop();
3394 print_first_pointer_var_index_getter();
3395 print_first_random_var_index_getter();
3396 print_net_receive_arg_size_getter();
3397 print_thread_getters();
3398 print_num_variable_getter();
3399 print_mech_type_getter();
3400 print_memb_list_getter();
3405 print_mechanism_global_var_structure(print_initializers);
3406 print_mechanism_range_var_structure(print_initializers);
3407 print_ion_var_structure();
3412 if (!info.vectorize) {
3415 printer->add_multi_line(R
"CODE(
3417 inst->v_unused[id] = v;
3424 printer->add_multi_line(R
"CODE(
3426 inst->g_unused[id] = g;
3433 print_top_verbatim_blocks();
3434 for (
const auto& procedure: info.procedures) {
3435 print_procedure(*procedure);
3437 for (
const auto&
function: info.functions) {
3438 print_function(*
function);
3440 for (
const auto&
function: info.function_tables) {
3441 print_function_tables(*
function);
3443 for (
size_t i = 0; i < info.before_after_blocks.size(); i++) {
3444 print_before_after_block(info.before_after_blocks[i], i);
3446 for (
const auto& callback: info.derivimplicit_callbacks) {
3447 const auto& block = *callback->get_node_to_solve();
3448 print_derivimplicit_kernel(block);
3450 print_net_send_buffering();
3452 print_watch_activate();
3453 print_watch_check();
3454 print_net_receive_kernel();
3455 print_net_receive();
3456 print_net_receive_buffering();
3464 print_backend_info();
3465 print_headers_include();
3466 print_namespace_begin();
3467 print_nmodl_constants();
3468 print_prcellstate_macros();
3469 print_mechanism_info();
3470 print_data_structures(
true);
3471 print_global_variables_for_hoc();
3472 print_common_getters();
3473 print_memory_allocation_routine();
3474 print_abort_routine();
3475 print_thread_memory_callbacks();
3476 print_instance_variable_setup();
3478 print_nrn_constructor();
3479 print_nrn_destructor();
3480 print_function_prototypes();
3481 print_functors_definitions();
3482 print_compute_functions();
3483 print_check_table_thread_function();
3484 print_mechanism_register();
3485 print_namespace_end();
3495 printer->fmt_line(
"{}_{}({});",
3498 external_method_arguments());
3505 printer->add_newline();
3507 auto float_type = default_float_data_type();
3509 printer->fmt_line(
"Eigen::Matrix<{}, {}, 1> nmodl_eigen_xm;", float_type, N);
3510 printer->fmt_line(
"{}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
3515 printer->add_line(
"// call newton solver");
3516 printer->fmt_line(
"{} newton_functor(nt, inst, id, pnodecount, v, indexes, data, thread);",
3517 info.functor_names[&node]);
3518 printer->add_line(
"newton_functor.initialize();");
3520 "int newton_iterations = nmodl::newton::newton_solver(nmodl_eigen_xm, newton_functor);");
3522 "if (newton_iterations < 0) assert(false && \"Newton solver did not converge!\");");
3526 printer->add_line(
"newton_functor.finalize();");
3532 printer->add_newline();
3534 const std::string float_type = default_float_data_type();
3536 printer->fmt_line(
"Eigen::Matrix<{0}, {1}, 1> nmodl_eigen_xm, nmodl_eigen_fm;", float_type, N);
3537 printer->fmt_line(
"Eigen::Matrix<{0}, {1}, {1}> nmodl_eigen_jm;", float_type, N);
3539 printer->fmt_line(
"Eigen::Matrix<{0}, {1}, {1}> nmodl_eigen_jm_inv;", float_type, N);
3540 printer->fmt_line(
"{}* nmodl_eigen_x = nmodl_eigen_xm.data();", float_type);
3541 printer->fmt_line(
"{}* nmodl_eigen_j = nmodl_eigen_jm.data();", float_type);
3542 printer->fmt_line(
"{}* nmodl_eigen_f = nmodl_eigen_fm.data();", float_type);
3547 printer->add_newline();
3548 print_eigen_linear_solver(float_type, N);
3549 printer->add_newline();
3564 for (
size_t i_arg = 0; i_arg < args.size(); ++i_arg) {
3568 const auto& new_name = fmt::format(
"weights[{} + nt->_fornetcon_weight_perm[i]]", i_arg);
3569 v.
set(old_name, new_name);
3570 statement_block->accept(v);
3574 std::find_if(info.semantics.begin(), info.semantics.end(), [](
const IndexSemantics& a) {
3575 return a.name == naming::FOR_NETCON_SEMANTIC;
3578 printer->fmt_text(
"const size_t offset = {}*pnodecount + id;",
index);
3579 printer->add_newline();
3581 "const size_t for_netcon_start = nt->_fornetcon_perm_indices[indexes[offset]];");
3583 "const size_t for_netcon_end = nt->_fornetcon_perm_indices[indexes[offset] + 1];");
3585 printer->add_line(
"for (auto i = for_netcon_start; i < for_netcon_end; ++i) {");
3586 printer->increase_indent();
3587 print_statement_block(*statement_block,
false,
false);
3588 printer->decrease_indent();
3590 printer->add_line(
"}");
3595 printer->add_text(fmt::format(
"nrn_watch_activate(inst, id, pnodecount, {}, v, watch_remove)",
3596 current_watch_statement++));