User Guide
wrapper.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include "wrapper.hpp"
9 
11 #include "pybind/pyembed.hpp"
12 #include <fmt/format.h>
13 #include <optional>
14 #include <pybind11/embed.h>
15 #include <pybind11/stl.h>
16 
17 #include <set>
18 #include <vector>
19 
20 #include "ode_py.hpp"
21 
22 namespace py = pybind11;
23 using namespace py::literals;
24 
25 namespace nmodl {
26 namespace pybind_wrappers {
27 
28 std::tuple<std::vector<std::string>, std::vector<std::string>, std::string>
29 call_solve_linear_system(const std::vector<std::string>& eq_system,
30  const std::vector<std::string>& state_vars,
31  const std::set<std::string>& vars,
32  bool small_system,
33  bool elimination,
34  const std::string& tmp_unique_prefix,
35  const std::set<std::string>& function_calls) {
36  const auto locals = py::dict("eq_strings"_a = eq_system,
37  "state_vars"_a = state_vars,
38  "vars"_a = vars,
39  "small_system"_a = small_system,
40  "do_cse"_a = elimination,
41  "function_calls"_a = function_calls,
42  "tmp_unique_prefix"_a = tmp_unique_prefix);
43  std::string script = R"(
44 exception_message = ""
45 try:
46  solutions, new_local_vars = solve_lin_system(eq_strings,
47  state_vars,
48  vars,
49  function_calls,
50  tmp_unique_prefix,
51  small_system,
52  do_cse)
53 except Exception as e:
54  # if we fail, fail silently and return empty string
55  import traceback
56  solutions = [""]
57  new_local_vars = [""]
58  exception_message = traceback.format_exc()
59 )";
60 
61  py::exec(nmodl::pybind_wrappers::ode_py + script, locals);
62  // returns a vector of solutions, i.e. new statements to add to block:
63  auto solutions = locals["solutions"].cast<std::vector<std::string>>();
64  // and a vector of new local variables that need to be declared in the block:
65  auto new_local_vars = locals["new_local_vars"].cast<std::vector<std::string>>();
66  // may also return a python exception message:
67  auto exception_message = locals["exception_message"].cast<std::string>();
68 
69  return {std::move(solutions), std::move(new_local_vars), std::move(exception_message)};
70 }
71 
72 
73 std::tuple<std::vector<std::string>, std::string> call_solve_nonlinear_system(
74  const std::vector<std::string>& eq_system,
75  const std::vector<std::string>& state_vars,
76  const std::set<std::string>& vars,
77  const std::set<std::string>& function_calls) {
78  const auto locals = py::dict("equation_strings"_a = eq_system,
79  "state_vars"_a = state_vars,
80  "vars"_a = vars,
81  "function_calls"_a = function_calls);
82  std::string script = R"(
83 exception_message = ""
84 try:
85  solutions = solve_non_lin_system(equation_strings,
86  state_vars,
87  vars,
88  function_calls)
89 except Exception as e:
90  # if we fail, fail silently and return empty string
91  import traceback
92  solutions = [""]
93  exception_message = traceback.format_exc()
94 )";
95 
96  py::exec(nmodl::pybind_wrappers::ode_py + script, locals);
97  // returns a vector of solutions, i.e. new statements to add to block:
98  auto solutions = locals["solutions"].cast<std::vector<std::string>>();
99  // may also return a python exception message:
100  auto exception_message = locals["exception_message"].cast<std::string>();
101 
102  return {std::move(solutions), std::move(exception_message)};
103 }
104 
105 
106 std::tuple<std::string, std::string> call_diffeq_solver(const std::string& node_as_nmodl,
107  const std::string& dt_var,
108  const std::set<std::string>& vars,
109  bool use_pade_approx,
110  const std::set<std::string>& function_calls,
111  const std::string& method) {
112  const auto locals = py::dict("equation_string"_a = node_as_nmodl,
113  "dt_var"_a = dt_var,
114  "vars"_a = vars,
115  "use_pade_approx"_a = use_pade_approx,
116  "function_calls"_a = function_calls);
117 
118  if (method == codegen::naming::EULER_METHOD) {
119  // replace x' = f(x) differential equation
120  // with forwards Euler timestep:
121  // x = x + f(x) * dt
122  std::string script = R"(
123 exception_message = ""
124 try:
125  solution = forwards_euler2c(equation_string, dt_var, vars, function_calls)
126 except Exception as e:
127  # if we fail, fail silently and return empty string
128  import traceback
129  solution = ""
130  exception_message = traceback.format_exc()
131 )";
132 
133  py::exec(nmodl::pybind_wrappers::ode_py + script, locals);
134  } else if (method == codegen::naming::CNEXP_METHOD) {
135  // replace x' = f(x) differential equation
136  // with analytic solution for x(t+dt) in terms of x(t)
137  // x = ...
138  std::string script = R"(
139 exception_message = ""
140 try:
141  solution = integrate2c(equation_string, dt_var, vars,
142  use_pade_approx)
143 except Exception as e:
144  # if we fail, fail silently and return empty string
145  import traceback
146  solution = ""
147  exception_message = traceback.format_exc()
148 )";
149 
150  py::exec(nmodl::pybind_wrappers::ode_py + script, locals);
151  } else {
152  // nothing to do, but the caller should know.
153  return {};
154  }
155  auto solution = locals["solution"].cast<std::string>();
156  auto exception_message = locals["exception_message"].cast<std::string>();
157 
158  return {std::move(solution), std::move(exception_message)};
159 }
160 
161 
162 std::tuple<std::string, std::string> call_analytic_diff(
163  const std::vector<std::string>& expressions,
164  const std::set<std::string>& used_names_in_block) {
165  auto locals = py::dict("expressions"_a = expressions, "vars"_a = used_names_in_block);
166  std::string script = R"(
167 exception_message = ""
168 try:
169  rhs = expressions[-1].split("=", 1)[1]
170  solution = differentiate2c(rhs,
171  "v",
172  vars,
173  expressions[:-1]
174  )
175 except Exception as e:
176  # if we fail, fail silently and return empty string
177  import traceback
178  solution = ""
179  exception_message = traceback.format_exc()
180 )";
181 
182  py::exec(nmodl::pybind_wrappers::ode_py + script, locals);
183 
184  auto solution = locals["solution"].cast<std::string>();
185  auto exception_message = locals["exception_message"].cast<std::string>();
186 
187  return {std::move(solution), std::move(exception_message)};
188 }
189 
190 std::tuple<std::string, std::string> call_diff2c(
191  const std::string& expression,
192  const std::pair<std::string, std::optional<int>>& variable,
193  const std::unordered_set<std::string>& indexed_vars) {
194  std::string statements;
195  // only indexed variables require special treatment
196  for (const auto& var: indexed_vars) {
197  statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var);
198  }
199  auto [name, property] = variable;
200  if (property.has_value()) {
201  name = fmt::format("sp.IndexedBase('{}', shape=[1])", name);
202  statements += fmt::format("_allvars.append({})", name);
203  } else {
204  name = fmt::format("'{}'", name);
205  }
206  auto locals = py::dict("expression"_a = expression);
207  std::string script =
208  fmt::format(R"(
209 _allvars = []
210 {}
211 variable = {}
212 exception_message = ""
213 try:
214  solution = differentiate2c(expression,
215  variable,
216  _allvars,
217  )
218 except Exception as e:
219  # if we fail, fail silently and return empty string
220  solution = ""
221  exception_message = str(e)
222 )",
223  statements,
224  property.has_value() ? fmt::format("{}[{}]", name, property.value()) : name);
225 
226  py::exec(nmodl::pybind_wrappers::ode_py + script, locals);
227 
228  auto solution = locals["solution"].cast<std::string>();
229  auto exception_message = locals["exception_message"].cast<std::string>();
230 
231  return {std::move(solution), std::move(exception_message)};
232 }
233 
235  pybind11::initialize_interpreter(true);
236 }
237 
239  pybind11::finalize_interpreter();
240 }
241 
242 // Prevent mangling for easier `dlsym`.
243 extern "C" {
251  &call_diff2c};
252 }
253 }
254 
255 } // namespace pybind_wrappers
256 } // namespace nmodl
nmodl::pybind_wrappers::call_diffeq_solver
std::tuple< std::string, std::string > call_diffeq_solver(const std::string &node_as_nmodl, const std::string &dt_var, const std::set< std::string > &vars, bool use_pade_approx, const std::set< std::string > &function_calls, const std::string &method)
Definition: wrapper.cpp:106
wrapper.hpp
nmodl::codegen::naming::CNEXP_METHOD
static constexpr char CNEXP_METHOD[]
cnexp method in nmodl
Definition: codegen_naming.hpp:30
NMODL_EXPORT
#define NMODL_EXPORT
Definition: wrapper.hpp:73
nmodl
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
nmodl::pybind_wrappers::initialize_interpreter_func
void initialize_interpreter_func()
Definition: wrapper.cpp:234
nmodl::pybind_wrappers::call_solve_nonlinear_system
std::tuple< std::vector< std::string >, std::string > call_solve_nonlinear_system(const std::vector< std::string > &eq_system, const std::vector< std::string > &state_vars, const std::set< std::string > &vars, const std::set< std::string > &function_calls)
Definition: wrapper.cpp:73
codegen_naming.hpp
nmodl::pybind_wrappers::call_analytic_diff
std::tuple< std::string, std::string > call_analytic_diff(const std::vector< std::string > &expressions, const std::set< std::string > &used_names_in_block)
Definition: wrapper.cpp:162
nmodl::pybind_wrappers::call_diff2c
std::tuple< std::string, std::string > call_diff2c(const std::string &expression, const std::pair< std::string, std::optional< int >> &variable, const std::unordered_set< std::string > &indexed_vars)
Differentiates an expression with respect to a variable.
Definition: wrapper.cpp:190
nmodl::pybind_wrappers::pybind_wrap_api
Definition: wrapper.hpp:60
nmodl::codegen::naming::EULER_METHOD
static constexpr char EULER_METHOD[]
euler method in nmodl
Definition: codegen_naming.hpp:27
nmodl::pybind_wrappers::call_solve_linear_system
std::tuple< std::vector< std::string >, std::vector< std::string >, std::string > call_solve_linear_system(const std::vector< std::string > &eq_system, const std::vector< std::string > &state_vars, const std::set< std::string > &vars, bool small_system, bool elimination, const std::string &tmp_unique_prefix, const std::set< std::string > &function_calls)
Definition: wrapper.cpp:29
nmodl::pybind_wrappers::finalize_interpreter_func
void finalize_interpreter_func()
Definition: wrapper.cpp:238
nmodl::pybind_wrappers::nmodl_init_pybind_wrapper_api
NMODL_EXPORT pybind_wrap_api nmodl_init_pybind_wrapper_api() noexcept
Definition: wrapper.cpp:244
pyembed.hpp