24 #include "solver/solver.hpp"
38 using visitor::DefUseAnalyzeVisitor;
40 using visitor::RenameVisitor;
41 using visitor::SymtabVisitor;
42 using visitor::VarUsageVisitor;
54 return "C++ (api-compatibility)";
63 return info.vectorize;
86 const std::string& name = token;
92 if (program_symtab->is_method_defined(token)) {
93 return method_name(token);
101 auto new_name = replace_if_verbatim_variable(name);
102 if (new_name != name) {
103 new_name = get_variable_name(new_name,
false);
106 new_name.insert(0, 1,
'&');
117 auto use_instance = !printing_top_verbatim_blocks;
118 return get_variable_name(token, use_instance);
123 printer->add_line(
"// VERBATIM");
124 const auto& result = process_verbatim_text(text);
127 for (
const auto& statement: statements) {
129 if (trimed_stmt.find_first_not_of(
' ') != std::string::npos) {
130 printer->add_line(trimed_stmt);
133 printer->add_line(
"// ENDVERBATIM");
143 auto symbol = program_symtab->lookup_in_scope(name);
144 bool is_constant =
false;
145 if (symbol !=
nullptr) {
147 if (info.is_ion_variable(name)) {
152 else if (symbol->has_any_property(NmodlType::param_assign) &&
153 info.variables_in_verbatim.find(name) == info.variables_in_verbatim.end() &&
154 symbol->get_write_count() == 0) {
237 return info.point_process;
242 if (info.point_process) {
250 if (info.point_process) {
251 printer->add_line(
"shadow_rhs[id] = rhs;");
252 printer->add_line(
"shadow_d[id] = g;");
254 auto rhs_op = operator_for_rhs();
255 auto d_op = operator_for_d();
256 printer->fmt_line(
"vec_rhs[node_id] {} rhs;", rhs_op);
257 printer->fmt_line(
"vec_d[node_id] {} g;", d_op);
263 auto rhs_op = operator_for_rhs();
264 auto d_op = operator_for_d();
265 if (info.point_process) {
266 printer->add_line(
"int node_id = node_index[id];");
267 printer->fmt_line(
"vec_rhs[node_id] {} shadow_rhs[id];", rhs_op);
268 printer->fmt_line(
"vec_d[node_id] {} shadow_d[id];", d_op);
279 printer->add_line(
"#pragma omp atomic update");
294 return optimize_ionvar_copies;
299 printer->add_newline(2);
300 auto args =
"size_t num, size_t size, size_t alignment = 64";
301 printer->fmt_push_block(
"static inline void* mem_alloc({})", args);
303 "size_t aligned_size = ((num*size + alignment - 1) / alignment) * alignment;");
304 printer->add_line(
"void* ptr = aligned_alloc(alignment, aligned_size);");
305 printer->add_line(
"memset(ptr, 0, aligned_size);");
306 printer->add_line(
"return ptr;");
307 printer->pop_block();
309 printer->add_newline(2);
310 printer->push_block(
"static inline void mem_free(void* ptr)");
311 printer->add_line(
"free(ptr);");
312 printer->pop_block();
317 printer->add_newline(2);
318 printer->push_block(
"static inline void coreneuron_abort()");
319 printer->add_line(
"abort();");
320 printer->pop_block();
330 if (info.functions.empty() && info.procedures.empty()) {
334 printer->add_newline(2);
335 for (
const auto& node: info.functions) {
336 print_function_declaration(*node, node->get_node_name());
337 printer->add_text(
';');
338 printer->add_newline();
340 for (
const auto& node: info.procedures) {
341 print_function_declaration(*node, node->get_node_name());
342 printer->add_text(
';');
343 printer->add_newline();
349 if (info.table_count == 0) {
353 printer->add_newline(2);
354 auto name = method_name(
"check_table_thread");
355 auto parameters = get_parameter_str(external_method_parameters(
true));
357 printer->fmt_push_block(
"static void {} ({})", name, parameters);
358 printer->add_line(
"setup_instance(nt, ml);");
359 printer->fmt_line(
"auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct());
360 printer->add_line(
"double v = 0;");
362 for (
const auto&
function: info.functions_with_table) {
363 auto method_name_str = table_update_function_name(function->get_node_name());
364 auto arguments = internal_method_arguments();
365 printer->fmt_line(
"{}({});", method_name_str, arguments);
368 printer->pop_block();
374 const std::string& name,
375 const std::unordered_set<CppObjectSpecifier>& specifiers) {
376 printer->add_newline(2);
377 print_function_declaration(node, name, specifiers);
378 printer->add_text(
" ");
379 printer->push_block();
383 auto type = default_float_data_type();
384 printer->fmt_line(
"{} ret_{} = 0.0;", type, name);
386 printer->fmt_line(
"int ret_{} = 0;", name);
390 printer->fmt_line(
"return ret_{};", name);
391 printer->pop_block();
398 if (info.function_uses_table(name)) {
399 auto new_name =
"f_" + name;
400 print_function_or_procedure(node, new_name);
401 print_table_check_function(node);
402 print_table_replacement_function(node);
404 print_function_or_procedure(node, name);
416 if (info.net_send_used) {
417 if (info.artificial_cell) {
421 variables.back().is_constant =
true;
423 info.tqitem_index =
static_cast<int>(variables.size() - 1);
428 std::vector<IndexVariableInfo>& variables) {
430 if (info.artificial_cell) {
434 variables.back().is_constant =
true;
439 return get_arg_str(internal_method_parameters());
448 {
"",
"int",
"",
"pnodecount"},
449 {
"", fmt::format(
"{}*", instance_struct()),
"",
"inst"}};
450 if (ion_variable_struct_required()) {
451 params.emplace_back(
"",
"IonCurVar&",
"",
"ionvar");
453 ParamVector other_params = {{
"",
"double*",
"",
"data"},
454 {
"const ",
"Datum*",
"",
"indexes"},
455 {
"",
"ThreadDatum*",
"",
"thread"},
456 {
"",
"NrnThread*",
"",
"nt"},
457 {
"",
"double",
"",
"v"}};
458 params.insert(params.end(), other_params.begin(), other_params.end());
464 return get_arg_str(external_method_parameters());
469 bool table) noexcept {
471 {
"",
"int",
"",
"pnodecount"},
472 {
"",
"double*",
"",
"data"},
473 {
"",
"Datum*",
"",
"indexes"},
474 {
"",
"ThreadDatum*",
"",
"thread"},
475 {
"",
"NrnThread*",
"",
"nt"},
476 {
"",
"Memb_list*",
"",
"ml"}};
478 args.emplace_back(
"",
"int",
"",
"tml_id");
480 args.emplace_back(
"",
"double",
"",
"v");
487 if (ion_variable_struct_required()) {
488 return "id, pnodecount, ionvar, data, indexes, thread, nt, ml, v";
490 return "id, pnodecount, data, indexes, thread, nt, ml, v";
499 return get_arg_str(internal_method_parameters());
502 std::pair<CodegenCoreneuronCppVisitor::ParamVector, CodegenCoreneuronCppVisitor::ParamVector>
504 auto params = internal_method_parameters();
506 params.emplace_back(
"",
"double",
"", i->get_node_name());
508 return {params, internal_method_parameters()};
526 if (internal_method_call_encountered) {
527 name = nrn_thread_internal_arguments();
528 internal_method_call_encountered =
false;
530 name = nrn_thread_arguments();
534 name = get_parameter_str(external_method_parameters());
548 auto tokens =
driver.all_tokens();
550 for (
size_t i = 0; i < tokens.size(); i++) {
551 auto token = tokens[i];
555 if (program_symtab->is_method_defined(token) && tokens[i + 1] ==
"(") {
556 internal_method_call_encountered =
true;
558 result += process_verbatim_token(token);
565 auto nrn_channel_info_var_name = get_channel_info_var_name();
572 return fmt::format(
"{}, {}, {}, nullptr, {}, {}, {}, {}, first_pointer_var_index()",
573 nrn_channel_info_var_name,
578 nrn_private_constructor,
579 nrn_private_destructor);
584 std::vector<ShadowUseStatement>& statements,
586 const std::string& concentration) {
594 throw std::logic_error(fmt::format(
"codegen error for {} ion", ion.
name));
596 auto ion_type_name = fmt::format(
"{}_type", ion.
name);
597 auto lhs = fmt::format(
"int {}", ion_type_name);
599 auto rhs = get_variable_name(ion_type_name);
602 auto ion_name = ion.
name;
604 auto style_var_name = get_variable_name(
"style_" + ion_name);
605 auto statement = fmt::format(
606 "nrn_wrote_conc({}_type,"
610 " nrn_ion_global_map,"
612 " nt->_ml_list[{}_type]->_nodecount_padded)",
630 printer->add_newline(2);
631 printer->push_block(
"static inline int first_pointer_var_index()");
632 printer->fmt_line(
"return {};", info.first_pointer_var_index);
633 printer->pop_block();
638 printer->add_newline(2);
639 printer->push_block(
"static inline int first_random_var_index()");
640 printer->fmt_line(
"return {};", info.first_random_var_index);
641 printer->pop_block();
646 printer->add_newline(2);
647 printer->push_block(
"static inline int float_variables_size()");
648 printer->fmt_line(
"return {};", float_variables_size());
649 printer->pop_block();
651 printer->add_newline(2);
652 printer->push_block(
"static inline int int_variables_size()");
653 printer->fmt_line(
"return {};", int_variables_size());
654 printer->pop_block();
659 if (!net_receive_exist()) {
662 printer->add_newline(2);
663 printer->push_block(
"static inline int num_net_receive_args()");
664 printer->fmt_line(
"return {};", info.num_net_receive_parameters);
665 printer->pop_block();
670 printer->add_newline(2);
671 printer->push_block(
"static inline int get_mech_type()");
673 printer->fmt_line(
"return {};", get_variable_name(
"mech_type",
false));
674 printer->pop_block();
679 printer->add_newline(2);
680 printer->push_block(
"static inline Memb_list* get_memb_list(NrnThread* nt)");
681 printer->push_block(
"if (!nt->_ml_list)");
682 printer->add_line(
"return nullptr;");
683 printer->pop_block();
684 printer->add_line(
"return nt->_ml_list[get_mech_type()];");
685 printer->pop_block();
705 if (info.vectorize && info.derivimplicit_used()) {
706 int tid = info.derivimplicit_var_thread_id;
707 int list = info.derivimplicit_list_num;
710 printer->add_newline(2);
711 printer->add_line(
"/** thread specific helper routines for derivimplicit */");
713 printer->add_newline(1);
714 printer->fmt_push_block(
"static inline int* deriv{}_advance(ThreadDatum* thread)", list);
715 printer->fmt_line(
"return &(thread[{}].i);", tid);
716 printer->pop_block();
717 printer->add_newline();
719 printer->fmt_push_block(
"static inline int dith{}()", list);
720 printer->fmt_line(
"return {};", tid+1);
721 printer->pop_block();
722 printer->add_newline();
724 printer->fmt_push_block(
"static inline void** newtonspace{}(ThreadDatum* thread)", list);
725 printer->fmt_line(
"return &(thread[{}]._pvoid);", tid+2);
726 printer->pop_block();
729 if (info.vectorize && !info.thread_variables.empty()) {
730 printer->add_newline(2);
731 printer->add_line(
"/** tid for thread variables */");
732 printer->push_block(
"static inline int thread_var_tid()");
733 printer->fmt_line(
"return {};", info.thread_var_thread_id);
734 printer->pop_block();
737 if (info.vectorize && !info.top_local_variables.empty()) {
738 printer->add_newline(2);
739 printer->add_line(
"/** tid for top local tread variables */");
740 printer->push_block(
"static inline int top_local_var_tid()");
741 printer->fmt_line(
"return {};", info.top_local_thread_id);
742 printer->pop_block();
754 bool use_instance)
const {
755 auto name = symbol->get_name();
756 auto dimension = symbol->get_length();
757 auto position = position_of_float_var(name);
758 if (symbol->is_array()) {
760 return fmt::format(
"(inst->{}+id*{})", name, dimension);
762 return fmt::format(
"(data + {}*pnodecount + id*{})", position, dimension);
765 return fmt::format(
"inst->{}[id]", name);
767 return fmt::format(
"data[{}*pnodecount + id]", position);
772 const std::string& name,
773 bool use_instance)
const {
774 auto position = position_of_int_var(name);
778 return fmt::format(
"inst->{}[{}]", name, position);
780 return fmt::format(
"indexes[{}]", position);
784 return fmt::format(
"inst->{}[{}*pnodecount+id]", name, position);
786 return fmt::format(
"indexes[{}*pnodecount+id]", position);
789 return fmt::format(
"inst->{}[indexes[{}*pnodecount + id]]", name, position);
791 auto data = symbol.
is_vdata ?
"_vdata" :
"_data";
792 return fmt::format(
"nt->{}[indexes[{}*pnodecount + id]]", data, position);
798 bool use_instance)
const {
802 return fmt::format(
"{}.{}", global_struct_instance(), symbol->get_name());
808 bool use_instance)
const {
809 const std::string& varname = update_if_ion_variable_name(name);
812 auto symbol_comparator = [&varname](
const SymbolType& sym) {
813 return varname == sym->get_name();
817 return varname == var.symbol->get_name();
822 auto f = std::find_if(codegen_float_variables.begin(),
823 codegen_float_variables.end(),
825 if (f != codegen_float_variables.end()) {
826 return float_variable_name(*f, use_instance);
831 std::find_if(codegen_int_variables.begin(), codegen_int_variables.end(), index_comparator);
832 if (i != codegen_int_variables.end()) {
833 auto full_name = int_variable_name(*i, varname, use_instance);
834 auto pos = position_of_int_var(varname);
837 return "(nrnran123_State*) " + full_name;
843 auto g = std::find_if(codegen_global_variables.begin(),
844 codegen_global_variables.end(),
846 if (g != codegen_global_variables.end()) {
847 return global_variable_name(*g, use_instance);
861 std::find_if(info.neuron_global_variables.begin(),
862 info.neuron_global_variables.end(),
863 [&varname](
auto const& entry) { return entry.first->get_name() == varname; });
864 if (iter != info.neuron_global_variables.end()) {
887 printer->add_newline();
888 printer->add_multi_line(R
"CODE(
898 printer->add_newline();
899 printer->add_multi_line(R
"CODE(
900 #include <coreneuron/gpu/nrn_acc_manager.hpp>
901 #include <coreneuron/mechanism/mech/mod2c_core_thread.hpp>
902 #include <coreneuron/mechanism/register_mech.hpp>
903 #include <coreneuron/nrnconf.h>
904 #include <coreneuron/nrniv/nrniv_decl.h>
905 #include <coreneuron/sim/multicore.hpp>
906 #include <coreneuron/sim/scopmath/newton_thread.hpp>
907 #include <coreneuron/utils/ivocvect.hpp>
908 #include <coreneuron/utils/nrnoc_aux.hpp>
909 #include <coreneuron/utils/randoms/nrnran123.h>
911 if (info.eigen_newton_solver_exist) {
912 printer->add_multi_line(nmodl::solvers::newton_hpp);
914 if (info.eigen_linear_solver_exist) {
915 if (std::accumulate(info.state_vars.begin(),
916 info.state_vars.end(),
919 return l += variable->get_length();
921 printer->add_multi_line(nmodl::solvers::crout_hpp);
923 printer->add_line(
"#include <Eigen/Dense>");
924 printer->add_line(
"#include <Eigen/LU>");
931 if (info.primes_size == 0) {
934 const auto count_prime_variables = [](
auto size,
const SymbolType& symbol) {
935 return size += symbol->get_length();
937 const auto prime_variables_by_order_size =
938 std::accumulate(info.prime_variables_by_order.begin(),
939 info.prime_variables_by_order.end(),
941 count_prime_variables);
942 if (info.primes_size != prime_variables_by_order_size) {
943 throw std::runtime_error{
944 fmt::format(
"primes_size = {} differs from prime_variables_by_order.size() = {}, "
945 "this should not happen.",
947 info.prime_variables_by_order.size())};
949 auto const initializer_list = [&](
auto const& primes,
const char* prefix) -> std::string {
950 if (!print_initializers) {
953 std::string list{
"{"};
954 for (
auto iter = primes.begin(); iter != primes.end(); ++iter) {
955 auto const& prime = *iter;
956 list.append(
std::to_string(position_of_float_var(prefix + prime->get_name())));
957 if (std::next(iter) != primes.end()) {
964 printer->fmt_line(
"int slist1[{}]{};",
966 initializer_list(info.prime_variables_by_order,
""));
967 printer->fmt_line(
"int dlist1[{}]{};",
969 initializer_list(info.prime_variables_by_order,
"D"));
970 codegen_global_variables.push_back(make_symbol(
"slist1"));
971 codegen_global_variables.push_back(make_symbol(
"dlist1"));
973 if (info.derivimplicit_used()) {
974 auto primes = program_symtab->get_variables_with_properties(NmodlType::prime_name);
975 printer->fmt_line(
"int slist2[{}]{};", info.primes_size, initializer_list(primes,
""));
976 codegen_global_variables.push_back(make_symbol(
"slist2"));
983 {
"", fmt::format(
"{}*", instance_struct()),
"",
"inst"},
984 {
"",
"int",
"",
"id"},
985 {
"",
"int",
"",
"pnodecount"},
986 {
"",
"double",
"",
"v"},
987 {
"const ",
"Datum*",
"",
"indexes"},
988 {
"",
"double*",
"",
"data"},
989 {
"",
"ThreadDatum*",
"",
"thread"}};
1010 const auto value_initialize = print_initializers ?
"{}" :
"";
1012 auto float_type = default_float_data_type();
1013 printer->add_newline(2);
1014 printer->add_line(
"/** all global variables */");
1015 printer->fmt_push_block(
"struct {}", global_struct());
1017 for (
const auto& ion: info.ions) {
1018 auto name = fmt::format(
"{}_type", ion.name);
1019 printer->fmt_line(
"int {}{};", name, value_initialize);
1020 codegen_global_variables.push_back(make_symbol(name));
1023 if (info.point_process) {
1024 printer->fmt_line(
"int point_type{};", value_initialize);
1025 codegen_global_variables.push_back(make_symbol(
"point_type"));
1028 for (
const auto& var: info.state_vars) {
1029 auto name = var->get_name() +
"0";
1030 auto symbol = program_symtab->lookup(name);
1031 if (symbol ==
nullptr) {
1032 printer->fmt_line(
"{} {}{};", float_type, name, value_initialize);
1033 codegen_global_variables.push_back(make_symbol(name));
1041 auto& top_locals = info.top_local_variables;
1042 if (!info.vectorize && !top_locals.empty()) {
1043 for (
const auto& var: top_locals) {
1044 auto name = var->get_name();
1045 auto length = var->get_length();
1046 if (var->is_array()) {
1047 printer->fmt_line(
"{} {}[{}] /* TODO init top-local-array */;",
1052 printer->fmt_line(
"{} {} /* TODO init top-local */;", float_type, name);
1054 codegen_global_variables.push_back(var);
1058 if (!info.thread_variables.empty()) {
1059 printer->fmt_line(
"int thread_data_in_use{};", value_initialize);
1060 printer->fmt_line(
"{} thread_data[{}] /* TODO init thread_data */;",
1062 info.thread_var_data_size);
1063 codegen_global_variables.push_back(make_symbol(
"thread_data_in_use"));
1064 auto symbol = make_symbol(
"thread_data");
1065 symbol->set_as_array(info.thread_var_data_size);
1066 codegen_global_variables.push_back(symbol);
1070 printer->fmt_line(
"int reset{};", value_initialize);
1071 codegen_global_variables.push_back(make_symbol(
"reset"));
1073 printer->fmt_line(
"int mech_type{};", value_initialize);
1074 codegen_global_variables.push_back(make_symbol(
"mech_type"));
1076 for (
const auto& var: info.global_variables) {
1077 auto name = var->get_name();
1078 auto length = var->get_length();
1079 if (var->is_array()) {
1080 printer->fmt_line(
"{} {}[{}] /* TODO init const-array */;", float_type, name, length);
1083 if (
auto const& value_ptr = var->get_value()) {
1086 printer->fmt_line(
"{} {}{};",
1089 print_initializers ? fmt::format(
"{{{:g}}}", value) : std::string{});
1091 codegen_global_variables.push_back(var);
1094 for (
const auto& var: info.constant_variables) {
1095 auto const name = var->get_name();
1096 auto*
const value_ptr = var->get_value().get();
1097 double const value{value_ptr ? *value_ptr : 0};
1098 printer->fmt_line(
"{} {}{};",
1101 print_initializers ? fmt::format(
"{{{:g}}}", value) : std::string{});
1102 codegen_global_variables.push_back(var);
1105 print_sdlists_init(print_initializers);
1107 if (info.table_count > 0) {
1108 printer->fmt_line(
"double usetable{};", print_initializers ?
"{1}" :
"");
1111 for (
const auto& block: info.functions_with_table) {
1112 const auto& name = block->get_node_name();
1113 printer->fmt_line(
"{} tmin_{}{};", float_type, name, value_initialize);
1114 printer->fmt_line(
"{} mfac_{}{};", float_type, name, value_initialize);
1115 codegen_global_variables.push_back(make_symbol(
"tmin_" + name));
1116 codegen_global_variables.push_back(make_symbol(
"mfac_" + name));
1119 for (
const auto& variable: info.table_statement_variables) {
1120 auto const name =
"t_" + variable->get_name();
1121 auto const num_values = variable->get_num_values();
1122 if (variable->is_array()) {
1123 int array_len = variable->get_length();
1125 "{} {}[{}][{}]{};", float_type, name, array_len, num_values, value_initialize);
1127 printer->fmt_line(
"{} {}[{}]{};", float_type, name, num_values, value_initialize);
1129 codegen_global_variables.push_back(make_symbol(name));
1133 print_global_struct_function_table_ptrs();
1135 if (info.vectorize && info.thread_data_index) {
1136 printer->fmt_line(
"ThreadDatum ext_call_thread[{}]{};",
1137 info.thread_data_index,
1139 codegen_global_variables.push_back(make_symbol(
"ext_call_thread"));
1142 printer->pop_block(
";");
1144 print_global_var_struct_assertions();
1145 print_global_var_struct_decl();
1154 auto variable_printer =
1155 [&](
const std::vector<SymbolType>& variables,
bool if_array,
bool if_vector) {
1156 for (
const auto& variable: variables) {
1157 if (variable->is_array() == if_array) {
1160 auto name = get_variable_name(variable->get_name(),
false);
1161 auto ename = add_escape_quote(variable->get_name() +
"_" + info.mod_suffix);
1162 auto length = variable->get_length();
1164 printer->fmt_line(
"{{{}, {}, {}}},", ename, name, length);
1166 printer->fmt_line(
"{{{}, &{}}},", ename, name);
1172 auto globals = info.global_variables;
1173 auto thread_vars = info.thread_variables;
1175 if (info.table_count > 0) {
1179 printer->add_newline(2);
1180 printer->add_line(
"/** connect global (scalar) variables to hoc -- */");
1181 printer->add_line(
"static DoubScal hoc_scalar_double[] = {");
1182 printer->increase_indent();
1183 variable_printer(globals,
false,
false);
1184 variable_printer(thread_vars,
false,
false);
1185 printer->add_line(
"{nullptr, nullptr}");
1186 printer->decrease_indent();
1187 printer->add_line(
"};");
1189 printer->add_newline(2);
1190 printer->add_line(
"/** connect global (array) variables to hoc -- */");
1191 printer->add_line(
"static DoubVec hoc_vector_double[] = {");
1192 printer->increase_indent();
1193 variable_printer(globals,
true,
true);
1194 variable_printer(thread_vars,
true,
true);
1195 printer->add_line(
"{nullptr, nullptr, 0}");
1196 printer->decrease_indent();
1197 printer->add_line(
"};");
1211 std::string register_type{};
1216 register_type =
"BAType::Before";
1218 dynamic_cast<const ast::BeforeBlock*
>(block)->get_bablock()->get_type()->get_value();
1221 register_type =
"BAType::After";
1223 dynamic_cast<const ast::AfterBlock*
>(block)->get_bablock()->get_type()->get_value();
1229 register_type +=
" + BAType::Breakpoint";
1231 register_type +=
" + BAType::Solve";
1233 register_type +=
" + BAType::Initial";
1235 register_type +=
" + BAType::Step";
1237 throw std::runtime_error(
"Unhandled Before/After type encountered during code generation");
1239 return register_type;
1261 printer->add_newline(2);
1262 printer->add_line(
"/** register channel with the simulator */");
1263 printer->fmt_push_block(
"void _{}_reg()", info.mod_file);
1266 auto suffix = add_escape_quote(info.mod_suffix);
1267 printer->add_newline();
1268 printer->fmt_line(
"int mech_type = nrn_get_mechtype({});", suffix);
1269 printer->fmt_line(
"{} = mech_type;", get_variable_name(
"mech_type",
false));
1270 printer->push_block(
"if (mech_type == -1)");
1271 printer->add_line(
"return;");
1272 printer->pop_block();
1274 printer->add_newline();
1275 printer->add_line(
"_nrn_layout_reg(mech_type, 0);");
1278 const auto mech_arguments = register_mechanism_arguments();
1279 const auto number_of_thread_objects = num_thread_objects();
1280 if (info.point_process) {
1281 printer->fmt_line(
"point_register_mech({}, {}, {}, {});",
1287 number_of_thread_objects);
1289 printer->fmt_line(
"register_mech({}, {});", mech_arguments, number_of_thread_objects);
1290 if (info.constructor_node) {
1291 printer->fmt_line(
"register_constructor({});",
1297 for (
const auto& ion: info.ions) {
1298 printer->fmt_line(
"{} = nrn_get_mechtype({});",
1299 get_variable_name(ion.name +
"_type",
false),
1300 add_escape_quote(ion.name +
"_ion"));
1302 printer->add_newline();
1308 if (info.vectorize && (info.thread_data_index != 0)) {
1310 printer->fmt_line(
"thread_mem_init({});", get_variable_name(
"ext_call_thread",
false));
1313 if (!info.thread_variables.empty()) {
1314 printer->fmt_line(
"{} = 0;", get_variable_name(
"thread_data_in_use"));
1317 if (info.thread_callback_register) {
1318 printer->add_line(
"_nrn_thread_reg0(mech_type, thread_mem_cleanup);");
1319 printer->add_line(
"_nrn_thread_reg1(mech_type, thread_mem_init);");
1322 if (info.emit_table_thread()) {
1323 auto name = method_name(
"check_table_thread");
1324 printer->fmt_line(
"_nrn_thread_table_reg(mech_type, {});", name);
1328 if (info.bbcore_pointer_used) {
1329 printer->add_line(
"hoc_reg_bbcore_read(mech_type, bbcore_read);");
1330 printer->add_line(
"hoc_reg_bbcore_write(mech_type, bbcore_write);");
1335 printer->add_line(
"hoc_register_prop_size(mech_type, float_variables_size(), int_variables_size());");
1339 for (
auto& semantic: info.semantics) {
1341 fmt::format(
"mech_type, {}, {}", semantic.index, add_escape_quote(semantic.name));
1342 printer->fmt_line(
"hoc_register_dparam_semantics({});", args);
1345 if (info.is_watch_used()) {
1347 printer->fmt_line(
"hoc_register_watch_check({}, mech_type);", watch_fun);
1350 if (info.write_concentration) {
1351 printer->add_line(
"nrn_writes_conc(mech_type, 0);");
1355 if (info.net_event_used) {
1356 printer->add_line(
"add_nrn_has_net_event(mech_type);");
1358 if (info.artificial_cell) {
1359 printer->fmt_line(
"add_nrn_artcell(mech_type, {});", info.tqitem_index);
1361 if (net_receive_buffering_required()) {
1362 printer->fmt_line(
"hoc_register_net_receive_buffering({}, mech_type);",
1363 method_name(
"net_buf_receive"));
1365 if (info.num_net_receive_parameters != 0) {
1366 auto net_recv_init_arg =
"nullptr";
1367 if (info.net_receive_initial_node !=
nullptr) {
1368 net_recv_init_arg =
"net_init";
1370 printer->fmt_line(
"set_pnt_receive(mech_type, {}, {}, num_net_receive_args());",
1371 method_name(
"net_receive"),
1374 if (info.for_netcon_used) {
1376 printer->fmt_line(
"add_nrn_fornetcons(mech_type, {});",
index);
1379 if (info.net_event_used || info.net_send_used) {
1380 printer->add_line(
"hoc_register_net_send_buffering(mech_type);");
1384 for (
size_t i = 0; i < info.before_after_blocks.size(); i++) {
1386 const auto& block = info.before_after_blocks[i];
1388 std::string function_name = method_name(fmt::format(
"nrn_before_after_{}", i));
1389 printer->fmt_line(
"hoc_reg_ba(mech_type, {}, {});", function_name, register_type);
1393 printer->add_line(
"hoc_register_var(hoc_scalar_double, hoc_vector_double, NULL);");
1394 printer->pop_block();
1399 if (!info.thread_callback_register) {
1404 printer->add_newline(2);
1405 printer->add_line(
"/** thread memory allocation callback */");
1406 printer->push_block(
"static void thread_mem_init(ThreadDatum* thread) ");
1408 if (info.vectorize && info.derivimplicit_used()) {
1409 printer->fmt_line(
"thread[dith{}()].pval = nullptr;", info.derivimplicit_list_num);
1411 if (info.vectorize && (info.top_local_thread_size != 0)) {
1412 auto length = info.top_local_thread_size;
1413 auto allocation = fmt::format(
"(double*)mem_alloc({}, sizeof(double))", length);
1414 printer->fmt_line(
"thread[top_local_var_tid()].pval = {};", allocation);
1416 if (info.thread_var_data_size != 0) {
1417 auto length = info.thread_var_data_size;
1418 auto thread_data = get_variable_name(
"thread_data");
1419 auto thread_data_in_use = get_variable_name(
"thread_data_in_use");
1420 auto allocation = fmt::format(
"(double*)mem_alloc({}, sizeof(double))", length);
1421 printer->fmt_push_block(
"if ({})", thread_data_in_use);
1422 printer->fmt_line(
"thread[thread_var_tid()].pval = {};", allocation);
1423 printer->chain_block(
"else");
1424 printer->fmt_line(
"thread[thread_var_tid()].pval = {};", thread_data);
1425 printer->fmt_line(
"{} = 1;", thread_data_in_use);
1426 printer->pop_block();
1428 printer->pop_block();
1429 printer->add_newline(2);
1433 printer->add_line(
"/** thread memory cleanup callback */");
1434 printer->push_block(
"static void thread_mem_cleanup(ThreadDatum* thread) ");
1437 if (info.vectorize && info.derivimplicit_used()) {
1438 int n = info.derivimplicit_list_num;
1439 printer->fmt_line(
"free(thread[dith{}()].pval);", n);
1440 printer->fmt_line(
"nrn_destroy_newtonspace(static_cast<NewtonSpace*>(*newtonspace{}(thread)));", n);
1444 if (info.top_local_thread_size != 0) {
1445 auto line =
"free(thread[top_local_var_tid()].pval);";
1446 printer->add_line(line);
1448 if (info.thread_var_data_size != 0) {
1449 auto thread_data = get_variable_name(
"thread_data");
1450 auto thread_data_in_use = get_variable_name(
"thread_data_in_use");
1451 printer->fmt_push_block(
"if (thread[thread_var_tid()].pval == {})", thread_data);
1452 printer->fmt_line(
"{} = 0;", thread_data_in_use);
1453 printer->chain_block(
"else");
1454 printer->add_line(
"free(thread[thread_var_tid()].pval);");
1455 printer->pop_block();
1457 printer->pop_block();
1462 auto const value_initialize = print_initializers ?
"{}" :
"";
1463 auto int_type = default_int_data_type();
1464 printer->add_newline(2);
1465 printer->add_line(
"/** all mechanism instance variables and global variables */");
1466 printer->fmt_push_block(
"struct {} ", instance_struct());
1468 for (
auto const& [var, type]: info.neuron_global_variables) {
1469 auto const name = var->get_name();
1470 printer->fmt_line(
"{}* {}{};",
1473 print_initializers ? fmt::format(
"{{&coreneuron::{}}}", name)
1476 for (
auto& var: codegen_float_variables) {
1477 const auto& name = var->get_name();
1478 auto type = get_range_var_float_type(var);
1479 auto qualifier = is_constant_variable(name) ?
"const " :
"";
1480 printer->fmt_line(
"{}{}* {}{};", qualifier, type, name, value_initialize);
1482 for (
auto& var: codegen_int_variables) {
1483 const auto& name = var.symbol->get_name();
1484 if (var.is_index || var.is_integer) {
1485 auto qualifier = var.is_constant ?
"const " :
"";
1486 printer->fmt_line(
"{}{}* {}{};", qualifier, int_type, name, value_initialize);
1488 auto qualifier = var.is_constant ?
"const " :
"";
1489 auto type = var.is_vdata ?
"void*" : default_float_data_type();
1490 printer->fmt_line(
"{}{}* {}{};", qualifier, type, name, value_initialize);
1494 printer->fmt_line(
"{}* {}{};",
1497 print_initializers ? fmt::format(
"{{&{}}}", global_struct_instance())
1499 printer->pop_block(
";");
1504 if (!ion_variable_struct_required()) {
1507 printer->add_newline(2);
1508 printer->add_line(
"/** ion write variables */");
1509 printer->push_block(
"struct IonCurVar");
1511 std::string float_type = default_float_data_type();
1512 std::vector<std::string> members;
1514 for (
auto& ion: info.ions) {
1515 for (
auto& var: ion.writes) {
1516 printer->fmt_line(
"{} {};", float_type, var);
1517 members.push_back(var);
1520 for (
auto& var: info.currents) {
1521 if (!info.is_ion_variable(var)) {
1522 printer->fmt_line(
"{} {};", float_type, var);
1523 members.push_back(var);
1527 print_ion_var_constructor(members);
1529 printer->pop_block(
";");
1534 const std::vector<std::string>& members) {
1536 printer->add_newline();
1537 printer->add_indent();
1538 printer->add_text(
"IonCurVar() : ");
1539 for (
int i = 0; i < members.size(); i++) {
1540 printer->fmt_text(
"{}(0)", members[i]);
1541 if (i + 1 < members.size()) {
1542 printer->add_text(
", ");
1545 printer->add_text(
" {}");
1546 printer->add_newline();
1551 printer->add_line(
"IonCurVar ionvar;");
1561 auto type = float_data_type();
1562 printer->add_newline(2);
1563 printer->add_line(
"/** allocate and setup array for range variable */");
1564 printer->fmt_push_block(
"static inline {}* setup_range_variable(double* variable, int n)",
1566 printer->fmt_line(
"{0}* data = ({0}*) mem_alloc(n, sizeof({0}));", type);
1567 printer->push_block(
"for(size_t i = 0; i < n; i++)");
1568 printer->add_line(
"data[i] = variable[i];");
1569 printer->pop_block();
1570 printer->add_line(
"return data;");
1571 printer->pop_block();
1583 auto with = NmodlType::read_ion_var
1584 | NmodlType::write_ion_var
1585 | NmodlType::pointer_var
1586 | NmodlType::bbcore_pointer_var
1587 | NmodlType::extern_neuron_variable;
1589 bool need_default_type = symbol->has_any_property(with);
1590 if (need_default_type) {
1591 return default_float_data_type();
1593 return float_data_type();
1598 if (range_variable_setup_required()) {
1599 print_setup_range_variable();
1602 printer->add_newline();
1603 printer->add_line(
"// Allocate instance structure");
1604 printer->fmt_push_block(
"static void {}(NrnThread* nt, Memb_list* ml, int type)",
1606 printer->add_line(
"assert(!ml->instance);");
1607 printer->add_line(
"assert(!ml->global_variables);");
1608 printer->add_line(
"assert(ml->global_variables_size == 0);");
1609 printer->fmt_line(
"auto* const inst = new {}{{}};", instance_struct());
1610 printer->fmt_line(
"assert(inst->{} == &{});",
1612 global_struct_instance());
1613 printer->add_line(
"ml->instance = inst;");
1615 printer->fmt_line(
"ml->global_variables_size = sizeof({});", global_struct());
1616 printer->pop_block();
1617 printer->add_newline();
1619 auto const cast_inst_and_assert_validity = [&]() {
1620 printer->fmt_line(
"auto* const inst = static_cast<{}*>(ml->instance);", instance_struct());
1621 printer->add_line(
"assert(inst);");
1623 printer->fmt_line(
"assert(inst->{} == &{});",
1625 global_struct_instance());
1627 printer->fmt_line(
"assert(ml->global_variables_size == sizeof({}));", global_struct());
1632 print_instance_struct_transfer_routine_declarations();
1634 printer->add_line(
"// Deallocate the instance structure");
1635 printer->fmt_push_block(
"static void {}(NrnThread* nt, Memb_list* ml, int type)",
1637 cast_inst_and_assert_validity();
1640 if (info.random_variables.size()) {
1641 printer->add_line(
"int pnodecount = ml->_nodecount_padded;");
1642 printer->add_line(
"int nodecount = ml->nodecount;");
1643 printer->add_line(
"Datum* indexes = ml->pdata;");
1644 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
1645 for (
const auto& var: info.random_variables) {
1646 const auto& name = get_variable_name(var->get_name());
1647 printer->fmt_line(
"nrnran123_deletestream({});", name);
1649 printer->pop_block();
1651 print_instance_struct_delete_from_device();
1652 printer->add_multi_line(R
"CODE(
1654 ml->instance = nullptr;
1655 ml->global_variables = nullptr;
1656 ml->global_variables_size = 0;
1658 printer->pop_block();
1659 printer->add_newline();
1662 printer->add_line("/** initialize mechanism instance variables */");
1663 printer->push_block(
"static inline void setup_instance(NrnThread* nt, Memb_list* ml)");
1664 cast_inst_and_assert_validity();
1667 printer->add_line(
"int pnodecount = ml->_nodecount_padded;");
1668 stride =
"*pnodecount";
1670 printer->add_line(
"Datum* indexes = ml->pdata;");
1672 auto const float_type = default_float_data_type();
1676 for (
auto const& [var, type]: info.neuron_global_variables) {
1677 ptr_members.push_back(var->get_name());
1679 ptr_members.reserve(ptr_members.size() + codegen_float_variables.size() +
1680 codegen_int_variables.size());
1681 for (
auto& var: codegen_float_variables) {
1682 auto name = var->get_name();
1683 auto range_var_type = get_range_var_float_type(var);
1684 if (float_type == range_var_type) {
1685 auto const variable = fmt::format(
"ml->data+{}{}",
id, stride);
1686 printer->fmt_line(
"inst->{} = {};", name, variable);
1689 printer->fmt_line(
"inst->{} = setup_range_variable(ml->data+{}{}, pnodecount);",
1694 ptr_members.push_back(std::move(name));
1695 id += var->get_length();
1698 for (
auto& var: codegen_int_variables) {
1699 auto name = var.symbol->get_name();
1700 auto const variable = [&var]() {
1701 if (var.is_index || var.is_integer) {
1703 }
else if (var.is_vdata) {
1704 return "nt->_vdata";
1709 printer->fmt_line(
"inst->{} = {};", name, variable);
1710 ptr_members.push_back(std::move(name));
1712 print_instance_struct_copy_to_device();
1713 printer->pop_block();
1714 printer->add_newline();
1716 print_instance_struct_transfer_routines(ptr_members);
1721 if (info.artificial_cell) {
1722 printer->add_line(
"double v = 0.0;");
1724 printer->add_line(
"int node_id = node_index[id];");
1725 printer->add_line(
"double v = voltage[node_id];");
1729 if (ion_variable_struct_required()) {
1730 printer->add_line(
"IonCurVar ionvar;");
1735 for (
auto& statement: read_statements) {
1736 printer->add_line(statement);
1739 print_rename_state_vars();
1742 if (node !=
nullptr) {
1744 print_statement_block(*block,
false,
false);
1749 for (
auto& statement: write_statements) {
1751 printer->add_line(text);
1758 const std::string& function_name) {
1760 if (function_name.empty()) {
1761 method = compute_method_name(type);
1763 method = function_name;
1765 auto args =
"NrnThread* nt, Memb_list* ml, int type";
1769 args =
"NrnThread* nt, Memb_list* ml";
1772 print_global_method_annotation();
1773 printer->fmt_push_block(
"void {}({})", method, args);
1777 print_kernel_data_present_annotation_block_begin();
1781 printer->add_line(
"#ifndef CORENEURON_BUILD");
1783 printer->add_multi_line(R
"CODE(
1784 int nodecount = ml->nodecount;
1785 int pnodecount = ml->_nodecount_padded;
1786 const int* node_index = ml->nodeindices;
1787 double* data = ml->data;
1788 const double* voltage = nt->_actual_v;
1792 printer->add_line(
"double* vec_rhs = nt->_actual_rhs;");
1793 printer->add_line(
"double* vec_d = nt->_actual_d;");
1794 print_rhs_d_shadow_variables();
1796 printer->add_line(
"Datum* indexes = ml->pdata;");
1797 printer->add_line(
"ThreadDatum* thread = ml->_thread;");
1800 printer->add_newline();
1801 printer->add_line(
"setup_instance(nt, ml);");
1803 printer->fmt_line(
"auto* const inst = static_cast<{}*>(ml->instance);", instance_struct());
1804 printer->add_newline(1);
1808 printer->add_newline(2);
1809 printer->add_line(
"/** initialize channel */");
1812 if (info.derivimplicit_used()) {
1813 printer->add_newline();
1814 int nequation = info.num_equations;
1815 int list_num = info.derivimplicit_list_num;
1817 printer->fmt_line(
"int& deriv_advance_flag = *deriv{}_advance(thread);", list_num);
1818 printer->add_line(
"deriv_advance_flag = 0;");
1819 print_deriv_advance_flag_transfer_to_device();
1820 printer->fmt_line(
"auto ns = newtonspace{}(thread);", list_num);
1821 printer->fmt_line(
"auto& th = thread[dith{}()];", list_num);
1822 printer->push_block(
"if (*ns == nullptr)");
1823 printer->fmt_line(
"int vec_size = 2*{}*pnodecount*sizeof(double);", nequation);
1824 printer->fmt_line(
"double* vec = makevector(vec_size);", nequation);
1825 printer->fmt_line(
"th.pval = vec;", list_num);
1826 printer->fmt_line(
"*ns = nrn_cons_newtonspace({}, pnodecount);", nequation);
1827 print_newtonspace_transfer_to_device();
1828 printer->pop_block();
1835 print_global_variable_device_update_annotation();
1837 if (skip_init_check) {
1838 printer->push_block(
"if (_nrn_skip_initmodel == 0)");
1841 if (!info.changed_dt.empty()) {
1842 printer->fmt_line(
"double _save_prev_dt = {};",
1844 printer->fmt_line(
"{} = {};",
1847 print_dt_update_to_device();
1851 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
1853 if (info.net_receive_node !=
nullptr) {
1854 printer->fmt_line(
"{} = -1e20;", get_variable_name(
"tsave"));
1857 print_initial_block(info.initial_node);
1858 printer->pop_block();
1860 if (!info.changed_dt.empty()) {
1862 print_dt_update_to_device();
1865 printer->pop_block();
1867 if (info.derivimplicit_used()) {
1868 printer->add_line(
"deriv_advance_flag = 1;");
1869 print_deriv_advance_flag_transfer_to_device();
1872 if (info.net_send_used && !info.artificial_cell) {
1873 print_send_event_move();
1876 print_kernel_data_present_annotation_block_end();
1877 if (skip_init_check) {
1878 printer->pop_block();
1884 std::string ba_type;
1885 std::shared_ptr<ast::BABlock> ba_block;
1895 std::string ba_block_type = ba_block->get_type()->eval();
1898 std::string function_name = method_name(fmt::format(
"nrn_before_after_{}", block_id));
1901 printer->add_newline(2);
1902 printer->fmt_line(
"/** {} of block type {} # {} */", ba_type, ba_block_type, block_id);
1906 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
1908 printer->add_line(
"int node_id = node_index[id];");
1909 printer->add_line(
"double v = voltage[node_id];");
1914 for (
auto& statement: read_statements) {
1915 printer->add_line(statement);
1919 printer->add_indent();
1920 print_statement_block(*ba_block->get_statement_block());
1921 printer->add_newline();
1925 for (
auto& statement: write_statements) {
1927 printer->add_line(text);
1931 printer->pop_block();
1932 printer->pop_block();
1933 print_kernel_data_present_annotation_block_end();
1937 printer->add_newline(2);
1939 if (info.constructor_node !=
nullptr) {
1940 const auto& block = info.constructor_node->get_statement_block();
1941 print_statement_block(*block,
false,
false);
1943 printer->add_line(
"#endif");
1944 printer->pop_block();
1949 printer->add_newline(2);
1951 if (info.destructor_node !=
nullptr) {
1952 const auto& block = info.destructor_node->get_statement_block();
1953 print_statement_block(*block,
false,
false);
1955 printer->add_line(
"#endif");
1956 printer->pop_block();
1961 printer->add_newline(2);
1963 printer->fmt_push_block(
"static void {}(double* data, Datum* indexes, int type)", method);
1964 printer->add_line(
"// do nothing");
1965 printer->pop_block();
1974 if (info.watch_statements.empty()) {
1978 printer->add_newline(2);
1979 auto inst = fmt::format(
"{}* inst", instance_struct());
1981 printer->fmt_push_block(
1982 "static void nrn_watch_activate({}, int id, int pnodecount, int watch_id, "
1983 "double v, bool &watch_remove)",
1987 printer->push_block(
"if (watch_remove == false)");
1988 for (
int i = 0; i < info.watch_count; i++) {
1989 auto name = get_variable_name(fmt::format(
"watch{}", i + 1));
1990 printer->fmt_line(
"{} = 0;", name);
1992 printer->add_line(
"watch_remove = true;");
1993 printer->pop_block();
1999 for (
int i = 0; i < info.watch_statements.size(); i++) {
2000 auto statement = info.watch_statements[i];
2001 printer->fmt_push_block(
"if (watch_id == {})", i);
2003 auto varname = get_variable_name(fmt::format(
"watch{}", i + 1));
2004 printer->add_indent();
2005 printer->fmt_text(
"{} = 2 + (", varname);
2006 auto watch = statement->get_statements().front();
2007 watch->get_expression()->visit_children(*
this);
2008 printer->add_text(
");");
2009 printer->add_newline();
2011 printer->pop_block();
2013 printer->pop_block();
2022 if (info.watch_statements.empty()) {
2026 printer->add_newline(2);
2027 printer->add_line(
"/** routine to check watch activation */");
2036 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
2038 if (info.is_voltage_used_by_watch_statements()) {
2039 printer->add_line(
"int node_id = node_index[id];");
2040 printer->add_line(
"double v = voltage[node_id];");
2045 printer->add_line(
"bool watch_untriggered = true;");
2047 for (
int i = 0; i < info.watch_statements.size(); i++) {
2048 auto statement = info.watch_statements[i];
2049 const auto& watch = statement->get_statements().front();
2050 const auto& varname = get_variable_name(fmt::format(
"watch{}", i + 1));
2053 printer->fmt_push_block(
"if ({}&2 && watch_untriggered)", varname);
2056 printer->add_indent();
2057 printer->add_text(
"if (");
2058 watch->get_expression()->accept(*
this);
2059 printer->add_text(
") {");
2060 printer->add_newline();
2061 printer->increase_indent();
2064 printer->fmt_push_block(
"if (({}&1) == 0)", varname);
2066 printer->add_line(
"watch_untriggered = false;");
2068 const auto& tqitem = get_variable_name(
"tqitem");
2069 const auto& point_process = get_variable_name(
"point_process");
2070 printer->add_indent();
2071 printer->add_text(
"net_send_buffering(");
2072 const auto& t = get_variable_name(
"t");
2073 printer->fmt_text(
"nt, ml->_net_send_buffer, 0, {}, -1, {}, {}+0.0, ",
2077 watch->get_value()->accept(*
this);
2078 printer->add_text(
");");
2079 printer->add_newline();
2080 printer->pop_block();
2082 printer->add_line(varname,
" = 3;");
2086 printer->decrease_indent();
2087 printer->push_block(
"} else");
2088 printer->add_line(varname,
" = 2;");
2089 printer->pop_block();
2092 printer->pop_block();
2096 printer->pop_block();
2097 print_send_event_move();
2098 print_kernel_data_present_annotation_block_end();
2099 printer->pop_block();
2104 bool need_mech_inst) {
2105 printer->add_multi_line(R
"CODE(
2106 int tid = pnt->_tid;
2107 int id = pnt->_i_instance;
2112 printer->add_line(
"NrnThread* nt = nrn_threads + tid;");
2113 printer->add_line(
"Memb_list* ml = nt->_ml_list[pnt->_type];");
2116 print_kernel_data_present_annotation_block_begin();
2119 printer->add_multi_line(R
"CODE(
2120 int nodecount = ml->nodecount;
2121 int pnodecount = ml->_nodecount_padded;
2122 double* data = ml->data;
2123 double* weights = nt->weights;
2124 Datum* indexes = ml->pdata;
2125 ThreadDatum* thread = ml->_thread;
2127 if (need_mech_inst) {
2128 printer->fmt_line(
"auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct());
2132 print_net_init_acc_serial_annotation_block_begin();
2136 auto parameters = info.net_receive_node->get_parameters();
2137 if (!parameters.empty()) {
2139 printer->add_newline();
2140 for (
auto& parameter: parameters) {
2141 auto name = parameter->get_node_name();
2144 printer->fmt_line(
"double* {} = weights + weight_index + {};", name, i);
2156 const auto& tqitem = get_variable_name(
"tqitem");
2157 std::string weight_index =
"weight_index";
2158 std::string pnt =
"pnt";
2162 if (!printing_net_receive && !printing_net_init) {
2164 auto var = get_variable_name(
"point_process");
2165 if (info.artificial_cell) {
2166 pnt =
"(Point_process*)" + var;
2172 if (info.artificial_cell) {
2173 printer->fmt_text(
"artcell_net_send(&{}, {}, {}, nt->_t+", tqitem, weight_index, pnt);
2175 const auto& point_process = get_variable_name(
"point_process");
2176 const auto& t = get_variable_name(
"t");
2177 printer->add_text(
"net_send_buffering(");
2178 printer->fmt_text(
"nt, ml->_net_send_buffer, 0, {}, {}, {}, {}+", tqitem, weight_index, point_process, t);
2181 print_vector_elements(arguments,
", ");
2182 printer->add_text(
')');
2187 if (!printing_net_receive && !printing_net_init) {
2188 throw std::runtime_error(
"Error : net_move only allowed in NET_RECEIVE block");
2192 const auto& tqitem = get_variable_name(
"tqitem");
2193 std::string weight_index =
"-1";
2194 std::string pnt =
"pnt";
2198 if (info.artificial_cell) {
2199 printer->fmt_text(
"artcell_net_move(&{}, {}, ", tqitem, pnt);
2200 print_vector_elements(arguments,
", ");
2201 printer->add_text(
")");
2203 const auto& point_process = get_variable_name(
"point_process");
2204 printer->add_text(
"net_send_buffering(");
2205 printer->fmt_text(
"nt, ml->_net_send_buffer, 2, {}, {}, {}, ", tqitem, weight_index, point_process);
2206 print_vector_elements(arguments,
", ");
2207 printer->add_text(
", 0.0");
2208 printer->add_text(
")");
2215 if (info.artificial_cell) {
2216 printer->add_text(
"net_event(pnt, ");
2217 print_vector_elements(arguments,
", ");
2219 const auto& point_process = get_variable_name(
"point_process");
2220 printer->add_text(
"net_send_buffering(");
2221 printer->fmt_text(
"nt, ml->_net_send_buffer, 1, -1, -1, {}, ", point_process);
2222 print_vector_elements(arguments,
", ");
2223 printer->add_text(
", 0.0");
2225 printer->add_text(
")");
2231 printer->add_text(method_name(name),
'(');
2233 printer->add_text(internal_method_arguments());
2234 if (!arguments.empty()) {
2235 printer->add_text(
", ");
2238 print_vector_elements(arguments,
", ");
2239 printer->add_text(
')');
2268 for (
auto& parameter: parameters) {
2269 const auto& name = parameter->get_node_name();
2280 const auto node = info.net_receive_initial_node;
2281 if (node ==
nullptr) {
2288 printing_net_init =
true;
2289 auto args =
"Point_process* pnt, int weight_index, double flag";
2290 printer->add_newline(2);
2291 printer->add_line(
"/** initialize block for net receive */");
2292 printer->fmt_push_block(
"static void net_init({})", args);
2293 auto block = node->get_statement_block().get();
2294 if (block->get_statements().empty()) {
2295 printer->add_line(
"// do nothing");
2297 print_net_receive_common_code(*node);
2298 print_statement_block(*block,
false,
false);
2299 if (node->is_initial_block()) {
2300 print_net_init_acc_serial_annotation_block_end();
2301 print_kernel_data_present_annotation_block_end();
2302 printer->add_line(
"auto& nsb = ml->_net_send_buffer;");
2303 print_net_send_buf_update_to_host();
2306 printer->pop_block();
2307 printing_net_init =
false;
2312 printer->add_newline();
2313 printer->add_line(
"NetSendBuffer_t* nsb = ml->_net_send_buffer;");
2314 print_net_send_buf_update_to_host();
2315 printer->push_block(
"for (int i=0; i < nsb->_cnt; i++)");
2316 printer->add_multi_line(R
"CODE(
2317 int type = nsb->_sendtype[i];
2319 double t = nsb->_nsb_t[i];
2320 double flag = nsb->_nsb_flag[i];
2321 int vdata_index = nsb->_vdata_index[i];
2322 int weight_index = nsb->_weight_index[i];
2323 int point_index = nsb->_pnt_index[i];
2324 net_sem_from_gpu(type, vdata_index, weight_index, tid, point_index, t, flag);
2326 printer->pop_block();
2327 printer->add_line("nsb->_cnt = 0;");
2328 print_net_send_buf_count_update_to_device();
2333 return fmt::format(
"void {}(NrnThread* nt)", method_name(
"net_buf_receive"));
2338 printer->add_line(
"Memb_list* ml = get_memb_list(nt);");
2339 printer->push_block(
"if (!ml)");
2340 printer->add_line(
"return;");
2341 printer->pop_block();
2342 printer->add_newline();
2347 printer->add_line(
"int count = nrb->_displ_cnt;");
2349 printer->push_block(
"for (int i = 0; i < count; i++)");
2354 printer->pop_block();
2359 if (!net_receive_required() || info.artificial_cell) {
2362 printer->add_newline(2);
2363 printer->push_block(net_receive_buffering_declaration());
2365 print_get_memb_list();
2367 const auto& net_receive = method_name(
"net_receive_kernel");
2369 print_kernel_data_present_annotation_block_begin();
2371 printer->add_line(
"NetReceiveBuffer_t* nrb = ml->_net_receive_buffer;");
2372 if (need_mech_inst) {
2373 printer->fmt_line(
"auto* const inst = static_cast<{0}*>(ml->instance);", instance_struct());
2375 print_net_receive_loop_begin();
2376 printer->add_line(
"int start = nrb->_displ[i];");
2377 printer->add_line(
"int end = nrb->_displ[i+1];");
2378 printer->push_block(
"for (int j = start; j < end; j++)");
2379 printer->add_multi_line(R
"CODE(
2380 int index = nrb->_nrb_index[j];
2381 int offset = nrb->_pnt_index[index];
2382 double t = nrb->_nrb_t[index];
2383 int weight_index = nrb->_weight_index[index];
2384 double flag = nrb->_nrb_flag[index];
2385 Point_process* point_process = nt->pntprocs + offset;
2387 printer->add_line(net_receive, "(t, point_process, inst, nt, ml, weight_index, flag);");
2388 printer->pop_block();
2389 print_net_receive_loop_end();
2391 print_device_stream_wait();
2392 printer->add_line(
"nrb->_displ_cnt = 0;");
2393 printer->add_line(
"nrb->_cnt = 0;");
2395 if (info.net_send_used || info.net_event_used) {
2396 print_send_event_move();
2399 print_kernel_data_present_annotation_block_end();
2400 printer->pop_block();
2405 printer->add_line(
"i = nsb->_cnt++;");
2410 printer->push_block(
"if (i >= nsb->_size)");
2411 printer->add_line(
"nsb->grow();");
2412 printer->pop_block();
2417 if (!net_send_buffer_required()) {
2421 printer->add_newline(2);
2423 "const NrnThread* nt, NetSendBuffer_t* nsb, int type, int vdata_index, "
2424 "int weight_index, int point_index, double t, double flag";
2425 printer->fmt_push_block(
"static inline void net_send_buffering({})", args);
2426 printer->add_line(
"int i = 0;");
2427 print_net_send_buffering_cnt_update();
2428 print_net_send_buffering_grow();
2429 printer->push_block(
"if (i < nsb->_size)");
2430 printer->add_multi_line(R
"CODE(
2431 nsb->_sendtype[i] = type;
2432 nsb->_vdata_index[i] = vdata_index;
2433 nsb->_weight_index[i] = weight_index;
2434 nsb->_pnt_index[i] = point_index;
2436 nsb->_nsb_flag[i] = flag;
2438 printer->pop_block();
2439 printer->pop_block();
2444 if (!net_receive_required()) {
2448 printing_net_receive =
true;
2449 const auto node = info.net_receive_node;
2456 if (!info.artificial_cell) {
2457 name = method_name(
"net_receive_kernel");
2458 params.emplace_back(
"",
"double",
"",
"t");
2459 params.emplace_back(
"",
"Point_process*",
"",
"pnt");
2460 params.emplace_back(
"", fmt::format(
"{}*", instance_struct()),
2462 params.emplace_back(
"",
"NrnThread*",
"",
"nt");
2463 params.emplace_back(
"",
"Memb_list*",
"",
"ml");
2464 params.emplace_back(
"",
"int",
"",
"weight_index");
2465 params.emplace_back(
"",
"double",
"",
"flag");
2467 name = method_name(
"net_receive");
2468 params.emplace_back(
"",
"Point_process*",
"",
"pnt");
2469 params.emplace_back(
"",
"int",
"",
"weight_index");
2470 params.emplace_back(
"",
"double",
"",
"flag");
2473 printer->add_newline(2);
2474 printer->fmt_push_block(
"static inline void {}({})", name, get_parameter_str(params));
2475 print_net_receive_common_code(*node, info.artificial_cell);
2476 if (info.artificial_cell) {
2477 printer->add_line(
"double t = nt->_t;");
2483 printer->add_line(
"int node_id = ml->nodeindices[id];");
2484 printer->add_line(
"v = nt->_actual_v[node_id];");
2487 printer->fmt_line(
"{} = t;", get_variable_name(
"tsave"));
2489 if (info.is_watch_used()) {
2490 printer->add_line(
"bool watch_remove = false;");
2493 printer->add_indent();
2494 node->get_statement_block()->accept(*
this);
2495 printer->add_newline();
2496 printer->pop_block();
2498 printing_net_receive =
false;
2503 if (!net_receive_required()) {
2507 printing_net_receive =
true;
2508 if (!info.artificial_cell) {
2509 const auto& name = method_name(
"net_receive");
2511 {
"",
"Point_process*",
"",
"pnt"},
2512 {
"",
"int",
"",
"weight_index"},
2513 {
"",
"double",
"",
"flag"}};
2514 printer->add_newline(2);
2515 printer->fmt_push_block(
"static void {}({})", name, get_parameter_str(params));
2516 printer->add_line(
"NrnThread* nt = nrn_threads + pnt->_tid;");
2517 printer->add_line(
"Memb_list* ml = get_memb_list(nt);");
2518 printer->add_line(
"NetReceiveBuffer_t* nrb = ml->_net_receive_buffer;");
2519 printer->push_block(
"if (nrb->_cnt >= nrb->_size)");
2520 printer->add_line(
"realloc_net_receive_buffer(nt, ml);");
2521 printer->pop_block();
2522 printer->add_multi_line(R
"CODE(
2524 nrb->_pnt_index[id] = pnt-nt->pntprocs;
2525 nrb->_weight_index[id] = weight_index;
2526 nrb->_nrb_t[id] = nt->_t;
2527 nrb->_nrb_flag[id] = flag;
2530 printer->pop_block();
2532 printing_net_receive = false;
2544 auto ext_args = external_method_arguments();
2545 auto ext_params = get_parameter_str(external_method_parameters());
2546 auto suffix = info.mod_suffix;
2547 auto list_num = info.derivimplicit_list_num;
2549 auto primes_size = info.primes_size;
2550 auto stride =
"*pnodecount+id";
2552 printer->add_newline(2);
2554 printer->push_block(
"namespace");
2555 printer->fmt_push_block(
"struct _newton_{}_{}", block_name, info.mod_suffix);
2556 printer->fmt_push_block(
"int operator()({}) const", get_parameter_str(external_method_parameters()));
2557 auto const instance = fmt::format(
"auto* const inst = static_cast<{0}*>(ml->instance);",
2559 auto const slist1 = fmt::format(
"auto const& slist{} = {};",
2561 get_variable_name(fmt::format(
"slist{}", list_num)));
2562 auto const slist2 = fmt::format(
"auto& slist{} = {};",
2564 get_variable_name(fmt::format(
"slist{}", list_num + 1)));
2565 auto const dlist1 = fmt::format(
"auto const& dlist{} = {};",
2567 get_variable_name(fmt::format(
"dlist{}", list_num)));
2568 auto const dlist2 = fmt::format(
2569 "double* dlist{} = static_cast<double*>(thread[dith{}()].pval) + ({}*pnodecount);",
2573 printer->add_line(instance);
2574 if (ion_variable_struct_required()) {
2575 print_ion_variable();
2577 printer->fmt_line(
"double* savstate{} = static_cast<double*>(thread[dith{}()].pval);",
2580 printer->add_line(slist1);
2581 printer->add_line(dlist1);
2582 printer->add_line(dlist2);
2586 printer->add_line(
"int counter = -1;");
2587 printer->fmt_push_block(
"for (int i=0; i<{}; i++)", info.num_primes);
2588 printer->fmt_push_block(
"if (*deriv{}_advance(thread))", list_num);
2590 "dlist{0}[(++counter){1}] = "
2591 "data[dlist{2}[i]{1}]-(data[slist{2}[i]{1}]-savstate{2}[i{1}])/nt->_dt;",
2595 printer->chain_block(
"else");
2596 printer->fmt_line(
"dlist{0}[(++counter){1}] = data[slist{2}[i]{1}]-savstate{2}[i{1}];",
2600 printer->pop_block();
2601 printer->pop_block();
2602 printer->add_line(
"return 0;");
2603 printer->pop_block();
2604 printer->pop_block(
";");
2605 printer->pop_block();
2606 printer->add_newline();
2607 printer->fmt_push_block(
"int {}_{}({})", block_name, suffix, ext_params);
2608 printer->add_line(instance);
2609 printer->fmt_line(
"double* savstate{} = (double*) thread[dith{}()].pval;", list_num, list_num);
2610 printer->add_line(slist1);
2611 printer->add_line(slist2);
2612 printer->add_line(dlist2);
2613 printer->fmt_push_block(
"for (int i=0; i<{}; i++)", info.num_primes);
2614 printer->fmt_line(
"savstate{}[i{}] = data[slist{}[i]{}];", list_num, stride, list_num, stride);
2615 printer->pop_block();
2617 "int reset = nrn_newton_thread(static_cast<NewtonSpace*>(*newtonspace{}(thread)), {}, "
2618 "slist{}, _newton_{}_{}{{}}, dlist{}, {});",
2626 printer->add_line(
"return reset;");
2627 printer->pop_block();
2628 printer->add_newline(2);
2643 if (!nrn_state_required()) {
2647 printer->add_newline(2);
2648 printer->add_line(
"/** update state */");
2651 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
2653 printer->add_line(
"int node_id = node_index[id];");
2654 printer->add_line(
"double v = voltage[node_id];");
2661 if (ion_variable_struct_required()) {
2662 print_ion_variable();
2666 for (
auto& statement: read_statements) {
2667 printer->add_line(statement);
2670 if (info.nrn_state_block) {
2671 info.nrn_state_block->visit_children(*
this);
2674 if (info.currents.empty() && info.breakpoint_node !=
nullptr) {
2675 auto block = info.breakpoint_node->get_statement_block();
2676 print_statement_block(*block,
false,
false);
2680 for (
auto& statement: write_statements) {
2681 const auto& text = process_shadow_update_statement(statement,
BlockType::State);
2682 printer->add_line(text);
2684 printer->pop_block();
2686 print_kernel_data_present_annotation_block_end();
2688 printer->pop_block();
2698 const auto& args = internal_method_parameters();
2700 printer->add_newline(2);
2701 printer->fmt_push_block(
"inline double nrn_current_{}({})",
2703 get_parameter_str(args));
2704 printer->add_line(
"double current = 0.0;");
2705 print_statement_block(*block,
false,
false);
2706 for (
auto& current: info.currents) {
2707 const auto& name = get_variable_name(current);
2708 printer->fmt_line(
"current += {};", name);
2710 printer->add_line(
"return current;");
2711 printer->pop_block();
2717 print_statement_block(*block,
false,
false);
2718 if (!info.currents.empty()) {
2720 for (
const auto& current: info.currents) {
2721 auto var = breakpoint_current(current);
2722 sum += get_variable_name(var);
2723 if (¤t != &info.currents.back()) {
2727 printer->fmt_line(
"double rhs = {};", sum);
2731 for (
const auto& conductance: info.conductances) {
2732 auto var = breakpoint_current(conductance.variable);
2733 sum += get_variable_name(var);
2734 if (&conductance != &info.conductances.back()) {
2738 printer->fmt_line(
"double g = {};", sum);
2740 for (
const auto& conductance: info.conductances) {
2741 if (!conductance.ion.empty()) {
2743 const auto& rhs = get_variable_name(conductance.variable);
2746 printer->add_line(text);
2753 printer->fmt_line(
"double g = nrn_current_{}({}+0.001);",
2755 internal_method_arguments());
2756 for (
auto& ion: info.ions) {
2757 for (
auto& var: ion.writes) {
2758 if (ion.is_ionic_current(var)) {
2759 const auto& name = get_variable_name(var);
2760 printer->fmt_line(
"double di{} = {};", ion.name, name);
2764 printer->fmt_line(
"double rhs = nrn_current_{}({});",
2766 internal_method_arguments());
2767 printer->add_line(
"g = (g-rhs)/0.001;");
2768 for (
auto& ion: info.ions) {
2769 for (
auto& var: ion.writes) {
2770 if (ion.is_ionic_current(var)) {
2772 auto rhs = fmt::format(
"(di{}-{})/0.001", ion.name, get_variable_name(var));
2773 if (info.point_process) {
2775 rhs += fmt::format(
"*1.e2/{}", area);
2779 printer->add_line(text);
2787 printer->add_line(
"int node_id = node_index[id];");
2788 printer->add_line(
"double v = voltage[node_id];");
2790 if (ion_variable_struct_required()) {
2791 print_ion_variable();
2795 for (
auto& statement: read_statements) {
2796 printer->add_line(statement);
2799 if (info.conductances.empty()) {
2800 print_nrn_cur_non_conductance_kernel();
2802 print_nrn_cur_conductance_kernel(node);
2806 for (
auto& statement: write_statements) {
2808 printer->add_line(text);
2811 if (info.point_process) {
2813 printer->fmt_line(
"double mfactor = 1.e2/{};", area);
2814 printer->add_line(
"g = g*mfactor;");
2815 printer->add_line(
"rhs = rhs*mfactor;");
2823 if (!info.electrode_current) {
2827 auto rhs_op = operator_for_rhs();
2828 auto d_op = operator_for_d();
2829 if (info.point_process) {
2830 rhs =
"shadow_rhs[id]";
2837 printer->push_block(
"if (nt->nrn_fast_imem)");
2838 if (nrn_cur_reduction_loop_required()) {
2839 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
2840 printer->add_line(
"int node_id = node_index[id];");
2842 printer->fmt_line(
"nt->nrn_fast_imem->nrn_sav_rhs[node_id] {} {};", rhs_op, rhs);
2843 printer->fmt_line(
"nt->nrn_fast_imem->nrn_sav_d[node_id] {} {};", d_op, d);
2844 if (nrn_cur_reduction_loop_required()) {
2845 printer->pop_block();
2847 printer->pop_block();
2852 if (!nrn_cur_required()) {
2856 if (info.conductances.empty()) {
2857 print_nrn_current(*info.breakpoint_node);
2860 printer->add_newline(2);
2861 printer->add_line(
"/** update current */");
2864 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
2865 print_nrn_cur_kernel(*info.breakpoint_node);
2866 print_nrn_cur_matrix_shadow_update();
2867 if (!nrn_cur_reduction_loop_required()) {
2868 print_fast_imem_calculation();
2870 printer->pop_block();
2872 if (nrn_cur_reduction_loop_required()) {
2873 printer->push_block(
"for (int id = 0; id < nodecount; id++)");
2874 print_nrn_cur_matrix_shadow_reduction();
2875 printer->pop_block();
2876 print_fast_imem_calculation();
2879 print_kernel_data_present_annotation_block_end();
2880 printer->pop_block();
2889 print_standard_includes();
2890 print_backend_includes();
2891 print_coreneuron_includes();
2896 print_first_pointer_var_index_getter();
2897 print_first_random_var_index_getter();
2898 print_net_receive_arg_size_getter();
2899 print_thread_getters();
2900 print_num_variable_getter();
2901 print_mech_type_getter();
2902 print_memb_list_getter();
2907 print_mechanism_global_var_structure(print_initializers);
2908 print_mechanism_range_var_structure(print_initializers);
2909 print_ion_var_structure();
2914 if (!info.vectorize) {
2917 printer->add_multi_line(R
"CODE(
2919 inst->v_unused[id] = v;
2926 printer->add_multi_line(R
"CODE(
2928 inst->g_unused[id] = g;
2935 print_top_verbatim_blocks();
2936 for (
const auto& procedure: info.procedures) {
2937 print_procedure(*procedure);
2939 for (
const auto&
function: info.functions) {
2940 print_function(*
function);
2942 for (
const auto&
function: info.function_tables) {
2943 print_function_tables(*
function);
2945 for (
size_t i = 0; i < info.before_after_blocks.size(); i++) {
2946 print_before_after_block(info.before_after_blocks[i], i);
2948 for (
const auto& callback: info.derivimplicit_callbacks) {
2949 const auto& block = *callback->get_node_to_solve();
2950 print_derivimplicit_kernel(block);
2952 print_net_send_buffering();
2954 print_watch_activate();
2955 print_watch_check();
2956 print_net_receive_kernel();
2957 print_net_receive();
2958 print_net_receive_buffering();
2966 print_backend_info();
2967 print_headers_include();
2968 print_namespace_start();
2969 print_nmodl_constants();
2970 print_prcellstate_macros();
2971 print_mechanism_info();
2972 print_data_structures(
true);
2973 print_global_variables_for_hoc();
2974 print_common_getters();
2975 print_memory_allocation_routine();
2976 print_abort_routine();
2977 print_thread_memory_callbacks();
2978 print_instance_variable_setup();
2980 print_nrn_constructor();
2981 print_nrn_destructor();
2982 print_function_prototypes();
2983 print_functors_definitions();
2984 print_compute_functions();
2985 print_check_table_thread_function();
2986 print_mechanism_register();
2987 print_namespace_stop();
2997 printer->fmt_line(
"{}_{}({});",
3000 external_method_arguments());
3012 for (
size_t i_arg = 0; i_arg < args.size(); ++i_arg) {
3016 const auto& new_name = fmt::format(
"weights[{} + nt->_fornetcon_weight_perm[i]]", i_arg);
3017 v.
set(old_name, new_name);
3018 statement_block->accept(v);
3023 printer->fmt_text(
"const size_t offset = {}*pnodecount + id;",
index);
3024 printer->add_newline();
3026 "const size_t for_netcon_start = nt->_fornetcon_perm_indices[indexes[offset]];");
3028 "const size_t for_netcon_end = nt->_fornetcon_perm_indices[indexes[offset] + 1];");
3030 printer->push_block(
"for (auto i = for_netcon_start; i < for_netcon_end; ++i)");
3031 print_statement_block(*statement_block,
false,
false);
3032 printer->pop_block();
3037 printer->add_text(fmt::format(
"nrn_watch_activate(inst, id, pnodecount, {}, v, watch_remove)",
3038 current_watch_statement++));
3043 print_atomic_reduction_pragma();
3044 printer->add_indent();
3046 printer->add_text(
";");