
/*
    pybind11/numpy.h: Basic NumPy support, vectorize() wrapper

    Copyright (c) 2016 Wenzel Jakob <[email protected]>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include "pybind11.h"
#include "complex.h"
#include <numeric>
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <string>
#include <functional>
#include <type_traits>
#include <utility>
#include <vector>
#include <typeindex>

#if defined(_MSC_VER)
#  pragma warning(push)
#  pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif

/* This will be true on all flat address space platforms and allows us to reduce the
   whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
   and dimension types (e.g. shape, strides, indexing), instead of inflicting this
   upon the library user. */
static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
// We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares)

PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)

class array; // Forward declaration

PYBIND11_NAMESPACE_BEGIN(detail)

template <> struct handle_type_name<array> { static constexpr auto name = _("numpy.ndarray"); };

template <typename type, typename SFINAE = void> struct npy_format_descriptor;

struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;
    char type;
    char byteorder;
    char flags;
    int type_num;
    int elsize;
    int alignment;
    char *subarray;
    PyObject *fields;
    PyObject *names;
};

struct PyArray_Proxy {
    PyObject_HEAD
    char *data;
    int nd;
    ssize_t *dimensions;
    ssize_t *strides;
    PyObject *base;
    PyObject *descr;
    int flags;
};

struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;
    PyArrayDescr_Proxy *descr;
    int flags;
    PyObject *base;
};

struct numpy_type_info {
    PyObject* dtype_ptr;
    std::string format_str;
};

struct numpy_internals {
    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

    numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
        auto it = registered_dtypes.find(std::type_index(tinfo));
        if (it != registered_dtypes.end())
            return &(it->second);
        if (throw_if_missing)
            pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
        return nullptr;
    }

    template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
        return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
    }
};

inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
    ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
}

inline numpy_internals& get_numpy_internals() {
    static numpy_internals* ptr = nullptr;
    if (!ptr)
        load_numpy_internals(ptr);
    return *ptr;
}

template <typename T> struct same_size {
    template <typename U> using as = bool_constant<sizeof(T) == sizeof(U)>;
};

template <typename Concrete> constexpr int platform_lookup() { return -1; }

// Lookup a type according to its size, and return a value corresponding to the NumPy typenum.
template <typename Concrete, typename T, typename... Ts, typename... Ints>
constexpr int platform_lookup(int I, Ints... Is) {
    return sizeof(Concrete) == sizeof(T) ? I : platform_lookup<Concrete, Ts...>(Is...);
}
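
// Editorial note (illustrative, not part of the original header): on a typical LP64 platform
// sizeof(std::int32_t) matches `int` but not `long`, so the lookup
// platform_lookup<std::int32_t, long, int, short>(NPY_LONG_, NPY_INT_, NPY_SHORT_)
// used below resolves NPY_INT32_ to NPY_INT_; on 32-bit and LLP64 (Windows) platforms it
// resolves to NPY_LONG_ instead, mirroring how npy_common.h picks its integer aliases.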
struct npy_api {
    enum constants {
        NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
        NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
        NPY_ARRAY_OWNDATA_ = 0x0004,
        NPY_ARRAY_FORCECAST_ = 0x0010,
        NPY_ARRAY_ENSUREARRAY_ = 0x0040,
        NPY_ARRAY_ALIGNED_ = 0x0100,
        NPY_ARRAY_WRITEABLE_ = 0x0400,
        NPY_BOOL_ = 0,
        NPY_BYTE_, NPY_UBYTE_,
        NPY_SHORT_, NPY_USHORT_,
        NPY_INT_, NPY_UINT_,
        NPY_LONG_, NPY_ULONG_,
        NPY_LONGLONG_, NPY_ULONGLONG_,
        NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
        NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
        NPY_OBJECT_ = 17,
        NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
        // Platform-dependent normalization
        NPY_INT8_ = NPY_BYTE_,
        NPY_UINT8_ = NPY_UBYTE_,
        NPY_INT16_ = NPY_SHORT_,
        NPY_UINT16_ = NPY_USHORT_,
        // `npy_common.h` defines the integer aliases. In order, it checks:
        // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
        // and assigns the alias to the first matching size, so we should check in this order.
        NPY_INT32_ = platform_lookup<std::int32_t, long, int, short>(
            NPY_LONG_, NPY_INT_, NPY_SHORT_),
        NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
            NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
        NPY_INT64_ = platform_lookup<std::int64_t, long, long long, int>(
            NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
        NPY_UINT64_ = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
            NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
    };

    typedef struct {
        Py_intptr_t *ptr;
        int len;
    } PyArray_Dims;

    static npy_api& get() {
        static npy_api api = lookup();
        return api;
    }

    bool PyArray_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
    }
    bool PyArrayDescr_Check_(PyObject *obj) const {
        return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
    }

    unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
    PyObject *(*PyArray_DescrFromType_)(int);
    PyObject *(*PyArray_NewFromDescr_)
        (PyTypeObject *, PyObject *, int, Py_intptr_t const *,
         Py_intptr_t const *, void *, int, PyObject *);
    // Unused. Not removed because that affects ABI of the class.
    PyObject *(*PyArray_DescrNewFromType_)(int);
    int (*PyArray_CopyInto_)(PyObject *, PyObject *);
    PyObject *(*PyArray_NewCopy_)(PyObject *, int);
    PyTypeObject *PyArray_Type_;
    PyTypeObject *PyVoidArrType_Type_;
    PyTypeObject *PyArrayDescr_Type_;
    PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
    PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
    int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
    bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
    int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, unsigned char, PyObject **, int *,
                                             Py_intptr_t *, PyObject **, PyObject *);
    PyObject *(*PyArray_Squeeze_)(PyObject *);
    // Unused. Not removed because that affects ABI of the class.
    int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
    PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);

private:
    enum functions {
        API_PyArray_GetNDArrayCFeatureVersion = 211,
        API_PyArray_Type = 2,
        API_PyArrayDescr_Type = 3,
        API_PyVoidArrType_Type = 39,
        API_PyArray_DescrFromType = 45,
        API_PyArray_DescrFromScalar = 57,
        API_PyArray_FromAny = 69,
        API_PyArray_Resize = 80,
        API_PyArray_CopyInto = 82,
        API_PyArray_NewCopy = 85,
        API_PyArray_NewFromDescr = 94,
        API_PyArray_DescrNewFromType = 96,
        API_PyArray_DescrConverter = 174,
        API_PyArray_EquivTypes = 182,
        API_PyArray_GetArrayParamsFromObject = 278,
        API_PyArray_Squeeze = 136,
        API_PyArray_SetBaseObject = 282
    };

    static npy_api lookup() {
        module_ m = module_::import("numpy.core.multiarray");
        auto c = m.attr("_ARRAY_API");
#if PY_MAJOR_VERSION >= 3
        void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
#else
        void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
#endif
        npy_api api;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
        DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
        if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
            pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
        DECL_NPY_API(PyArray_Type);
        DECL_NPY_API(PyVoidArrType_Type);
        DECL_NPY_API(PyArrayDescr_Type);
        DECL_NPY_API(PyArray_DescrFromType);
        DECL_NPY_API(PyArray_DescrFromScalar);
        DECL_NPY_API(PyArray_FromAny);
        DECL_NPY_API(PyArray_Resize);
        DECL_NPY_API(PyArray_CopyInto);
        DECL_NPY_API(PyArray_NewCopy);
        DECL_NPY_API(PyArray_NewFromDescr);
        DECL_NPY_API(PyArray_DescrNewFromType);
        DECL_NPY_API(PyArray_DescrConverter);
        DECL_NPY_API(PyArray_EquivTypes);
        DECL_NPY_API(PyArray_GetArrayParamsFromObject);
        DECL_NPY_API(PyArray_Squeeze);
        DECL_NPY_API(PyArray_SetBaseObject);
#undef DECL_NPY_API
        return api;
    }
};
inline PyArray_Proxy* array_proxy(void* ptr) {
    return reinterpret_cast<PyArray_Proxy*>(ptr);
}

inline const PyArray_Proxy* array_proxy(const void* ptr) {
    return reinterpret_cast<const PyArray_Proxy*>(ptr);
}

inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
    return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
}

inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
    return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
}

inline bool check_flags(const void* ptr, int flag) {
    return (flag == (array_proxy(ptr)->flags & flag));
}

template <typename T> struct is_std_array : std::false_type { };
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
template <typename T> struct is_complex : std::false_type { };
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };

template <typename T> struct array_info_scalar {
    using type = T;
    static constexpr bool is_array = false;
    static constexpr bool is_empty = false;
    static constexpr auto extents = _("");
    static void append_extents(list& /* shape */) { }
};

// Computes underlying type and a comma-separated list of extents for array
// types (any mix of std::array and built-in arrays). An array of char is
// treated as scalar because it gets special handling.
template <typename T> struct array_info : array_info_scalar<T> { };
template <typename T, size_t N> struct array_info<std::array<T, N>> {
    using type = typename array_info<T>::type;
    static constexpr bool is_array = true;
    static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
    static constexpr size_t extent = N;

    // appends the extents to shape
    static void append_extents(list& shape) {
        shape.append(N);
        array_info<T>::append_extents(shape);
    }

    static constexpr auto extents = _<array_info<T>::is_array>(
        concat(_<N>(), array_info<T>::extents), _<N>()
    );
};

// For numpy we have special handling for arrays of characters, so we don't include
// the size in the array extents.
template <size_t N> struct array_info<char[N]> : array_info_scalar<char[N]> { };
template <size_t N> struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> { };
template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<T, N>> { };
template <typename T> using remove_all_extents_t = typename array_info<T>::type;

template <typename T> using is_pod_struct = all_of<
    std::is_standard_layout<T>,     // since we're accessing directly in memory we need a standard layout type
#if !defined(__GNUG__) || defined(_LIBCPP_VERSION) || defined(_GLIBCXX_USE_CXX11_ABI)
    // _GLIBCXX_USE_CXX11_ABI indicates that we're using libstdc++ from GCC 5 or newer, independent
    // of the actual compiler (Clang can also use libstdc++, but it always defines __GNUC__ == 4).
    std::is_trivially_copyable<T>,
#else
    // GCC 4 doesn't implement is_trivially_copyable, so approximate it
    std::is_trivially_destructible<T>,
    satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
#endif
    satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;

// Replacement for std::is_pod (deprecated in C++20)
template <typename T> using is_pod = all_of<
    std::is_standard_layout<T>,
    std::is_trivial<T>
>;

template <ssize_t Dim = 0, typename Strides> ssize_t byte_offset_unsafe(const Strides &) { return 0; }
template <ssize_t Dim = 0, typename Strides, typename... Ix>
ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) {
    return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
}
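
// Editorial note (illustrative, not part of the original header): for byte strides
// {s0, s1, s2} and indices (i, j, k), byte_offset_unsafe(strides, i, j, k) expands to
// i * s0 + j * s1 + k * s2, i.e. the byte offset of element (i, j, k) from the data pointer.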
/**
 * Proxy class providing unsafe, unchecked const access to array data. This is constructed through
 * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
 * will be -1 for dimensions determined at runtime.
 */
template <typename T, ssize_t Dims>
class unchecked_reference {
protected:
    static constexpr bool Dynamic = Dims < 0;
    const unsigned char *data_;
    // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
    // make large performance gains on big, nested loops, but requires compile-time dimensions
    conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>>
            shape_, strides_;
    const ssize_t dims_;

    friend class pybind11::array;
    // Constructor for compile-time dimensions:
    template <bool Dyn = Dynamic>
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<!Dyn, ssize_t>)
    : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
        for (size_t i = 0; i < (size_t) dims_; i++) {
            shape_[i] = shape[i];
            strides_[i] = strides[i];
        }
    }
    // Constructor for runtime dimensions:
    template <bool Dyn = Dynamic>
    unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t<Dyn, ssize_t> dims)
    : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}

public:
    /**
     * Unchecked const reference access to data at the given indices. For a compile-time known
     * number of dimensions, this requires the correct number of arguments; for run-time
     * dimensionality, this is not checked (and so is up to the caller to use safely).
     */
    template <typename... Ix> const T &operator()(Ix... index) const {
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
                "Invalid number of indices for unchecked array reference");
        return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, ssize_t(index)...));
    }
    /**
     * Unchecked const reference access to data; this operator only participates if the reference
     * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
     */
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    const T &operator[](ssize_t index) const { return operator()(index); }

    /// Pointer access to the data at the given indices.
    template <typename... Ix> const T *data(Ix... ix) const { return &operator()(ssize_t(ix)...); }

    /// Returns the item size, i.e. sizeof(T)
    constexpr static ssize_t itemsize() { return sizeof(T); }

    /// Returns the shape (i.e. size) of dimension `dim`
    ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }

    /// Returns the number of dimensions of the array
    ssize_t ndim() const { return dims_; }

    /// Returns the total number of elements in the referenced array, i.e. the product of the shapes
    template <bool Dyn = Dynamic>
    enable_if_t<!Dyn, ssize_t> size() const {
        return std::accumulate(shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
    }
    template <bool Dyn = Dynamic>
    enable_if_t<Dyn, ssize_t> size() const {
        return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
    }

    /// Returns the total number of bytes used by the referenced data. Note that the actual span in
    /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
    ssize_t nbytes() const {
        return size() * itemsize();
    }
};

template <typename T, ssize_t Dims>
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
    friend class pybind11::array;
    using ConstBase = unchecked_reference<T, Dims>;
    using ConstBase::ConstBase;
    using ConstBase::Dynamic;
public:
    // Bring in const-qualified versions from base class
    using ConstBase::operator();
    using ConstBase::operator[];

    /// Mutable, unchecked access to data at the given indices.
    template <typename... Ix> T& operator()(Ix... index) {
        static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
                "Invalid number of indices for unchecked array reference");
        return const_cast<T &>(ConstBase::operator()(index...));
    }
    /**
     * Mutable, unchecked access to data at the given index; this operator only participates if
     * the reference is to a 1-dimensional array (or has runtime dimensions). When present, this
     * is exactly equivalent to `obj(index)`.
     */
    template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
    T &operator[](ssize_t index) { return operator()(index); }

    /// Mutable pointer access to the data at the given indices.
    template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(ssize_t(ix)...); }
};

template <typename T, ssize_t Dim>
struct type_caster<unchecked_reference<T, Dim>> {
    static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
};
template <typename T, ssize_t Dim>
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
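
// Usage sketch (illustrative only, user-side code assuming `namespace py = pybind11;`):
// given a two-dimensional py::array_t<double> `a`, the unchecked proxies above provide fast
// element access without per-call bounds or dimensionality checks:
//
//     auto r = a.mutable_unchecked<2>();          // throws if a.ndim() != 2 or not writeable
//     for (py::ssize_t i = 0; i < r.shape(0); i++)
//         for (py::ssize_t j = 0; j < r.shape(1); j++)
//             r(i, j) *= 2.0;                     // unchecked element access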
PYBIND11_NAMESPACE_END(detail)

class dtype : public object {
public:
    PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);

    explicit dtype(const buffer_info &info) {
        dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
        // If info.itemsize == 0, use the value calculated from the format string
        m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
    }

    explicit dtype(const std::string &format) {
        m_ptr = from_args(pybind11::str(format)).release().ptr();
    }

    dtype(const char *format) : dtype(std::string(format)) { }

    dtype(list names, list formats, list offsets, ssize_t itemsize) {
        dict args;
        args["names"] = names;
        args["formats"] = formats;
        args["offsets"] = offsets;
        args["itemsize"] = pybind11::int_(itemsize);
        m_ptr = from_args(args).release().ptr();
    }

    /// This is essentially the same as calling numpy.dtype(args) in Python.
    static dtype from_args(object args) {
        PyObject *ptr = nullptr;
        if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) || !ptr)
            throw error_already_set();
        return reinterpret_steal<dtype>(ptr);
    }

    /// Return dtype associated with a C++ type.
    template <typename T> static dtype of() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
    }

    /// Size of the data type in bytes.
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(m_ptr)->elsize;
    }

    /// Returns true for structured data types.
    bool has_fields() const {
        return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
    }

    /// Single-character type code.
    char kind() const {
        return detail::array_descriptor_proxy(m_ptr)->kind;
    }

private:
    static object _dtype_from_pep3118() {
        static PyObject *obj = module_::import("numpy.core._internal")
            .attr("_dtype_from_pep3118").cast<object>().release().ptr();
        return reinterpret_borrow<object>(obj);
    }

    dtype strip_padding(ssize_t itemsize) {
        // Recursively strip all void fields with empty names that are generated for
        // padding fields (as of NumPy v1.11).
        if (!has_fields())
            return *this;

        struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
        std::vector<field_descr> field_descriptors;

        for (auto field : attr("fields").attr("items")()) {
            auto spec = field.cast<tuple>();
            auto name = spec[0].cast<pybind11::str>();
            auto format = spec[1].cast<tuple>()[0].cast<dtype>();
            auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
            if (!len(name) && format.kind() == 'V')
                continue;
            field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
        }

        std::sort(field_descriptors.begin(), field_descriptors.end(),
                  [](const field_descr& a, const field_descr& b) {
                      return a.offset.cast<int>() < b.offset.cast<int>();
                  });

        list names, formats, offsets;
        for (auto& descr : field_descriptors) {
            names.append(descr.name);
            formats.append(descr.format);
            offsets.append(descr.offset);
        }
        return dtype(names, formats, offsets, itemsize);
    }
};
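
// Usage sketch (illustrative only, user-side code assuming `namespace py = pybind11;`):
// a dtype can be obtained from a C++ type or from a format/typestring:
//
//     auto d1 = py::dtype::of<std::int32_t>();   // equivalent of numpy.dtype("int32")
//     auto d2 = py::dtype("float64");            // via PyArray_DescrConverter
//     py::ssize_t bytes = d2.itemsize();         // == 8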
class array : public buffer {
public:
    PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)

    enum {
        c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
        f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
        forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
    };

    array() : array(0, static_cast<const double *>(nullptr)) {}

    using ShapeContainer = detail::any_container<ssize_t>;
    using StridesContainer = detail::any_container<ssize_t>;

    // Constructs an array taking shape/strides from arbitrary container types
    array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
          const void *ptr = nullptr, handle base = handle()) {

        if (strides->empty())
            *strides = detail::c_strides(*shape, dt.itemsize());

        auto ndim = shape->size();
        if (ndim != strides->size())
            pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
        auto descr = dt;

        int flags = 0;
        if (base && ptr) {
            if (isinstance<array>(base))
                /* Copy flags from base (except ownership bit) */
                flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
            else
                /* Writable by default, easy to downgrade later on if needed */
                flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
        }

        auto &api = detail::npy_api::get();
        auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
            api.PyArray_Type_, descr.release().ptr(), (int) ndim,
            // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
            reinterpret_cast<Py_intptr_t*>(shape->data()),
            reinterpret_cast<Py_intptr_t*>(strides->data()),
            const_cast<void *>(ptr), flags, nullptr));
        if (!tmp)
            throw error_already_set();
        if (ptr) {
            if (base) {
                api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
            } else {
                tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
            }
        }
        m_ptr = tmp.release().ptr();
    }

    array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
        : array(dt, std::move(shape), {}, ptr, base) { }

    template <typename T, typename = detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
    array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
        : array(dt, {{count}}, ptr, base) { }

    template <typename T>
    array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
        : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }

    template <typename T>
    array(ShapeContainer shape, const T *ptr, handle base = handle())
        : array(std::move(shape), {}, ptr, base) { }

    template <typename T>
    explicit array(ssize_t count, const T *ptr, handle base = handle()) : array({count}, {}, ptr, base) { }

    explicit array(const buffer_info &info, handle base = handle())
        : array(pybind11::dtype(info), info.shape, info.strides, info.ptr, base) { }

    /// Array descriptor (dtype)
    pybind11::dtype dtype() const {
        return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
    }

    /// Total number of elements
    ssize_t size() const {
        return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
    }

    /// Byte size of a single element
    ssize_t itemsize() const {
        return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
    }

    /// Total number of bytes
    ssize_t nbytes() const {
        return size() * itemsize();
    }

    /// Number of dimensions
    ssize_t ndim() const {
        return detail::array_proxy(m_ptr)->nd;
    }

    /// Base object
    object base() const {
        return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
    }

    /// Dimensions of the array
    const ssize_t* shape() const {
        return detail::array_proxy(m_ptr)->dimensions;
    }

    /// Dimension along a given axis
    ssize_t shape(ssize_t dim) const {
        if (dim >= ndim())
            fail_dim_check(dim, "invalid axis");
        return shape()[dim];
    }

    /// Strides of the array
    const ssize_t* strides() const {
        return detail::array_proxy(m_ptr)->strides;
    }

    /// Stride along a given axis
    ssize_t strides(ssize_t dim) const {
        if (dim >= ndim())
            fail_dim_check(dim, "invalid axis");
        return strides()[dim];
    }

    /// Return the NumPy array flags
    int flags() const {
        return detail::array_proxy(m_ptr)->flags;
    }

    /// If set, the array is writeable (otherwise the buffer is read-only)
    bool writeable() const {
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
    }

    /// If set, the array owns the data (will be freed when the array is deleted)
    bool owndata() const {
        return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
    }

    /// Pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    template<typename... Ix> const void* data(Ix... index) const {
        return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
    }

    /// Mutable pointer to the contained data. If index is not provided, points to the
    /// beginning of the buffer. May throw if the index would lead to out of bounds access.
    /// May throw if the array is not writeable.
    template<typename... Ix> void* mutable_data(Ix... index) {
        check_writeable();
        return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
    }

    /// Byte offset from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
    template<typename... Ix> ssize_t offset_at(Ix... index) const {
        if ((ssize_t) sizeof...(index) > ndim())
            fail_dim_check(sizeof...(index), "too many indices for an array");
        return byte_offset(ssize_t(index)...);
    }

    ssize_t offset_at() const { return 0; }

    /// Item count from beginning of the array to a given index (full or partial).
    /// May throw if the index would lead to out of bounds access.
    template<typename... Ix> ssize_t index_at(Ix... index) const {
        return offset_at(index...) / itemsize();
    }

    /**
     * Returns a proxy object that provides access to the array's data without bounds or
     * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
    template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
        if (Dims >= 0 && ndim() != Dims)
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                                    "; expected " + std::to_string(Dims));
        return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
    }
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
     * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
     * underlying array have the `writeable` flag. Use with care: the array must not be destroyed
     * or reshaped for the duration of the returned object, and the caller must take care not to
     * access invalid dimensions or dimension indices.
     */
    template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
        if (Dims >= 0 && ndim() != Dims)
            throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
                                    "; expected " + std::to_string(Dims));
        return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
    }

    /// Return a new view with all of the dimensions of length 1 removed
    array squeeze() {
        auto& api = detail::npy_api::get();
        return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
    }
    /// Resize the array to the given shape.
    /// If refcheck is true and more than one reference exists to this array,
    /// the resize will succeed only if it is a reshape, i.e. the original size doesn't change.
    void resize(ShapeContainer new_shape, bool refcheck = true) {
        detail::npy_api::PyArray_Dims d = {
            // Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
            reinterpret_cast<Py_intptr_t*>(new_shape->data()),
            int(new_shape->size())
        };
        // try to resize; the ordering parameter is set to -1 because it's not used anyway
        auto new_array = reinterpret_steal<object>(
            detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1)
        );
        if (!new_array) throw error_already_set();
        if (isinstance<array>(new_array)) { *this = std::move(new_array); }
    }
    /// Ensure that the argument is a NumPy array
    /// In case of an error, nullptr is returned and the Python error is cleared.
    static array ensure(handle h, int ExtraFlags = 0) {
        auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
        if (!result)
            PyErr_Clear();
        return result;
    }

protected:
    template<typename, typename> friend struct detail::npy_format_descriptor;

    void fail_dim_check(ssize_t dim, const std::string& msg) const {
        throw index_error(msg + ": " + std::to_string(dim) +
                          " (ndim = " + std::to_string(ndim()) + ")");
    }

    template<typename... Ix> ssize_t byte_offset(Ix... index) const {
        check_dimensions(index...);
        return detail::byte_offset_unsafe(strides(), ssize_t(index)...);
    }

    void check_writeable() const {
        if (!writeable())
            throw std::domain_error("array is not writeable");
    }

    template<typename... Ix> void check_dimensions(Ix... index) const {
        check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
    }

    void check_dimensions_impl(ssize_t, const ssize_t*) const { }

    template<typename... Ix> void check_dimensions_impl(ssize_t axis, const ssize_t* shape, ssize_t i, Ix... index) const {
        if (i >= *shape) {
            throw index_error(std::string("index ") + std::to_string(i) +
                              " is out of bounds for axis " + std::to_string(axis) +
                              " with size " + std::to_string(*shape));
        }
        check_dimensions_impl(axis + 1, shape + 1, index...);
    }

    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
            return nullptr;
        }
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
    }
};
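
// Usage sketch (illustrative only, user-side code assuming `namespace py = pybind11;`):
// constructing an array from a C++ buffer; with no `base` handle the constructor copies the data,
// so the resulting array owns it:
//
//     std::vector<double> buf(12);
//     py::array a(py::dtype::of<double>(), {3, 4}, buf.data());  // default C-contiguous strides
//     // a.ndim() == 2, a.shape(0) == 3, a.strides(1) == sizeof(double), a.owndata() == true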
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
private:
    struct private_ctor {};
    // Delegating constructor needed when both moving and accessing in the same constructor
    array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base)
        : array(std::move(shape), std::move(strides), ptr, base) {}
public:
    static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");

    using value_type = T;

    array_t() : array(0, static_cast<const T *>(nullptr)) {}
    array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { }
    array_t(handle h, stolen_t) : array(h, stolen_t{}) { }

    PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
    array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
        if (!m_ptr) PyErr_Clear();
        if (!is_borrowed) Py_XDECREF(h.ptr());
    }

    array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
        if (!m_ptr) throw error_already_set();
    }

    explicit array_t(const buffer_info& info, handle base = handle()) : array(info, base) { }

    array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
        : array(std::move(shape), std::move(strides), ptr, base) { }

    explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
        : array_t(private_ctor{}, std::move(shape),
                  ExtraFlags & f_style
                      ? detail::f_strides(*shape, itemsize())
                      : detail::c_strides(*shape, itemsize()),
                  ptr, base) { }

    explicit array_t(ssize_t count, const T *ptr = nullptr, handle base = handle())
        : array({count}, {}, ptr, base) { }

    constexpr ssize_t itemsize() const {
        return sizeof(T);
    }

    template<typename... Ix> ssize_t index_at(Ix... index) const {
        return offset_at(index...) / itemsize();
    }

    template<typename... Ix> const T* data(Ix... index) const {
        return static_cast<const T*>(array::data(index...));
    }

    template<typename... Ix> T* mutable_data(Ix... index) {
        return static_cast<T*>(array::mutable_data(index...));
    }

    // Reference to element at a given index
    template<typename... Ix> const T& at(Ix... index) const {
        if ((ssize_t) sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
        return *(static_cast<const T*>(array::data()) + byte_offset(ssize_t(index)...) / itemsize());
    }

    // Mutable reference to element at a given index
    template<typename... Ix> T& mutable_at(Ix... index) {
        if ((ssize_t) sizeof...(index) != ndim())
            fail_dim_check(sizeof...(index), "index dimension mismatch");
        return *(static_cast<T*>(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize());
    }

    /**
     * Returns a proxy object that provides access to the array's data without bounds or
     * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
     * care: the array must not be destroyed or reshaped for the duration of the returned object,
     * and the caller must take care not to access invalid dimensions or dimension indices.
     */
    template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
        return array::mutable_unchecked<T, Dims>();
    }
    /**
     * Returns a proxy object that provides const access to the array's data without bounds or
     * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
     * underlying array have the `writeable` flag. Use with care: the array must not be destroyed
     * or reshaped for the duration of the returned object, and the caller must take care not to
     * access invalid dimensions or dimension indices.
     */
    template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const & {
        return array::unchecked<T, Dims>();
    }

    /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
    /// it). In case of an error, nullptr is returned and the Python error is cleared.
    static array_t ensure(handle h) {
        auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
        if (!result)
            PyErr_Clear();
        return result;
    }

    static bool check_(handle h) {
        const auto &api = detail::npy_api::get();
        return api.PyArray_Check_(h.ptr())
               && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr())
               && detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
    }

protected:
    /// Create array from any object -- always returns a new reference
    static PyObject *raw_array_t(PyObject *ptr) {
        if (ptr == nullptr) {
            PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
            return nullptr;
        }
        return detail::npy_api::get().PyArray_FromAny_(
            ptr, dtype::of<T>().release().ptr(), 0, 0,
            detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
    }
};
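
// Usage sketch (illustrative only, user-side code assuming `namespace py = pybind11;`):
// array_t<T> is typically taken as a bound function argument; with the default `forcecast`
// flag, compatible Python inputs (lists, arrays of other dtypes) are converted to arrays of T:
//
//     double sum_array(py::array_t<double> a) {      // assumes a 1-D input below
//         auto r = a.unchecked<1>();                  // throws if a.ndim() != 1
//         double total = 0;
//         for (py::ssize_t i = 0; i < r.shape(0); i++)
//             total += r(i);
//         return total;
//     }
//     // ... then, inside PYBIND11_MODULE: m.def("sum_array", &sum_array);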
template <typename T>
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
    static std::string format() {
        return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
    }
};

template <size_t N> struct format_descriptor<char[N]> {
    static std::string format() { return std::to_string(N) + "s"; }
};
template <size_t N> struct format_descriptor<std::array<char, N>> {
    static std::string format() { return std::to_string(N) + "s"; }
};

template <typename T>
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
    static std::string format() {
        return format_descriptor<
            typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
    }
};

template <typename T>
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
    static std::string format() {
        using namespace detail;
        static constexpr auto extents = _("(") + array_info<T>::extents + _(")");
        return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
    }
};

PYBIND11_NAMESPACE_BEGIN(detail)

template <typename T, int ExtraFlags>
struct pyobject_caster<array_t<T, ExtraFlags>> {
    using type = array_t<T, ExtraFlags>;

    bool load(handle src, bool convert) {
        if (!convert && !type::check_(src))
            return false;
        value = type::ensure(src);
        return static_cast<bool>(value);
    }

    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
        return src.inc_ref();
    }
    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
};

template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
    static bool compare(const buffer_info& b) {
        return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
    }
};

template <typename T, typename = void>
struct npy_format_descriptor_name;

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
    static constexpr auto name = _<std::is_same<T, bool>::value>(
        _("bool"), _<std::is_signed<T>::value>("numpy.int", "numpy.uint") + _<sizeof(T)*8>()
    );
};

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
    static constexpr auto name = _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
        _("numpy.float") + _<sizeof(T)*8>(), _("numpy.longdouble")
    );
};

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
    static constexpr auto name = _<std::is_same<typename T::value_type, float>::value
                                   || std::is_same<typename T::value_type, double>::value>(
        _("numpy.complex") + _<sizeof(typename T::value_type)*16>(), _("numpy.longcomplex")
    );
};

template <typename T>
struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
    : npy_format_descriptor_name<T> {
private:
    // NB: the order here must match the one in common.h
    constexpr static const int values[15] = {
        npy_api::NPY_BOOL_,
        npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_INT16_, npy_api::NPY_UINT16_,
        npy_api::NPY_INT32_, npy_api::NPY_UINT32_, npy_api::NPY_INT64_, npy_api::NPY_UINT64_,
        npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_,
        npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
    };

public:
    static constexpr int value = values[detail::is_fmt_numeric<T>::index];

    static pybind11::dtype dtype() {
        if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
            return reinterpret_steal<pybind11::dtype>(ptr);
        pybind11_fail("Unsupported buffer format!");
    }
};

#define PYBIND11_DECL_CHAR_FMT \
    static constexpr auto name = _("S") + _<N>(); \
    static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
#undef PYBIND11_DECL_CHAR_FMT

template<typename T> struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
private:
    using base_descr = npy_format_descriptor<typename array_info<T>::type>;
public:
    static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");

    static constexpr auto name = _("(") + array_info<T>::extents + _(")") + base_descr::name;
    static pybind11::dtype dtype() {
        list shape;
        array_info<T>::append_extents(shape);
        return pybind11::dtype::from_args(pybind11::make_tuple(base_descr::dtype(), shape));
    }
};

template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
private:
    using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
public:
    static constexpr auto name = base_descr::name;
    static pybind11::dtype dtype() { return base_descr::dtype(); }
};

struct field_descriptor {
    const char *name;
    ssize_t offset;
    ssize_t size;
    std::string format;
    dtype descr;
};

inline PYBIND11_NOINLINE void register_structured_dtype(
        any_container<field_descriptor> fields,
        const std::type_info& tinfo, ssize_t itemsize,
        bool (*direct_converter)(PyObject *, void *&)) {

    auto& numpy_internals = get_numpy_internals();
    if (numpy_internals.get_type_info(tinfo, false))
        pybind11_fail("NumPy: dtype is already registered");

    // Use ordered fields because order matters as of NumPy 1.14:
    // https://docs.scipy.org/doc/numpy/release.html#multiple-field-indexing-assignment-of-structured-arrays
    std::vector<field_descriptor> ordered_fields(std::move(fields));
    std::sort(ordered_fields.begin(), ordered_fields.end(),
              [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });

    list names, formats, offsets;
    for (auto& field : ordered_fields) {
        if (!field.descr)
            pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
                          field.name + "` @ " + tinfo.name());
        names.append(PYBIND11_STR_TYPE(field.name));
        formats.append(field.descr);
        offsets.append(pybind11::int_(field.offset));
    }
    auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();

    // There is an existing bug in NumPy (as of v1.11): trailing bytes are
    // not encoded explicitly into the format string. This will supposedly
    // get fixed in v1.12; for further details, see these:
    // - https://github.com/numpy/numpy/issues/7797
    // - https://github.com/numpy/numpy/pull/7798
    // Because of this, we won't use numpy's logic to generate buffer format
    // strings and will just do it ourselves.
    ssize_t offset = 0;
    std::ostringstream oss;
    // mark the structure as unaligned with '^', because numpy and C++ don't
    // always agree about alignment (particularly for complex), and we're
    // explicitly listing all our padding. This depends on none of the fields
    // overriding the endianness. Putting the ^ in front of individual fields
    // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
    oss << "^T{";
    for (auto& field : ordered_fields) {
        if (field.offset > offset)
            oss << (field.offset - offset) << 'x';
        oss << field.format << ':' << field.name << ':';
        offset = field.offset + field.size;
    }
    if (itemsize > offset)
        oss << (itemsize - offset) << 'x';
    oss << '}';
    auto format_str = oss.str();

    // Sanity check: verify that NumPy properly parses our buffer format string
    auto& api = npy_api::get();
    auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
    if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
        pybind11_fail("NumPy: invalid buffer descriptor!");

    auto tindex = std::type_index(tinfo);
    numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
    get_internals().direct_conversions[tindex].push_back(direct_converter);
}

template <typename T, typename SFINAE> struct npy_format_descriptor {
    static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");

    static constexpr auto name = make_caster<T>::name;

    static pybind11::dtype dtype() {
        return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
    }

    static std::string format() {
        static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
        return format_str;
    }

    static void register_dtype(any_container<field_descriptor> fields) {
        register_structured_dtype(std::move(fields), typeid(typename std::remove_cv<T>::type),
                                  sizeof(T), &direct_converter);
    }

private:
    static PyObject* dtype_ptr() {
        static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
        return ptr;
    }

    static bool direct_converter(PyObject *obj, void*& value) {
        auto& api = npy_api::get();
        if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
            return false;
        if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
            if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
                value = ((PyVoidScalarObject_Proxy *) obj)->obval;
                return true;
            }
        }
        return false;
    }
};
  1026. #ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code)
  1027. # define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0)
  1028. # define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0)
  1029. #else
  1030. #define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
  1031. ::pybind11::detail::field_descriptor { \
  1032. Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
  1033. ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
  1034. ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
  1035. }
  1036. // Extract name, offset and format descriptor for a struct field
  1037. #define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)
  1038. // The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
  1039. // (C) William Swanson, Paul Fultz
  1040. #define PYBIND11_EVAL0(...) __VA_ARGS__
  1041. #define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
  1042. #define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
  1043. #define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
  1044. #define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
  1045. #define PYBIND11_EVAL(...) PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
  1046. #define PYBIND11_MAP_END(...)
  1047. #define PYBIND11_MAP_OUT
  1048. #define PYBIND11_MAP_COMMA ,
  1049. #define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
  1050. #define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
  1051. #define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
  1052. #define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
  1053. #if defined(_MSC_VER) && !defined(__clang__) // MSVC is not as eager to expand macros, hence this workaround
  1054. #define PYBIND11_MAP_LIST_NEXT1(test, next) \
  1055. PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
  1056. #else
  1057. #define PYBIND11_MAP_LIST_NEXT1(test, next) \
  1058. PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
  1059. #endif
  1060. #define PYBIND11_MAP_LIST_NEXT(test, next) \
  1061. PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
  1062. #define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
  1063. f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
  1064. #define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
  1065. f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
  1066. // PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
  1067. #define PYBIND11_MAP_LIST(f, t, ...) \
  1068. PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))
  1069. #define PYBIND11_NUMPY_DTYPE(Type, ...) \
  1070. ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
  1071. (::std::vector<::pybind11::detail::field_descriptor> \
  1072. {PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
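// Usage sketch (hypothetical struct; typically invoked once from a module init function):
//     struct Particle { float x, y; double mass; };
//     PYBIND11_NUMPY_DTYPE(Particle, x, y, mass);
// After registration, `py::array_t<Particle>` can be used to exchange structured arrays with NumPy.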
#if defined(_MSC_VER) && !defined(__clang__)
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)

// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))

#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        (::std::vector<::pybind11::detail::field_descriptor> \
            {PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
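// Usage sketch with explicit NumPy field names (same hypothetical struct as above):
//     PYBIND11_NUMPY_DTYPE_EX(Particle, x, "x0", y, "y0", mass, "m");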
#endif // __CLION_IDE__

class common_iterator {
public:
    using container_type = std::vector<ssize_t>;
    using value_type = container_type::value_type;
    using size_type = container_type::size_type;

    common_iterator() : p_ptr(0), m_strides() {}

    common_iterator(void* ptr, const container_type& strides, const container_type& shape)
        : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
        m_strides.back() = static_cast<value_type>(strides.back());
        for (size_type i = m_strides.size() - 1; i != 0; --i) {
            size_type j = i - 1;
            auto s = static_cast<value_type>(shape[i]);
            m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
        }
    }
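    // Note (added for clarity): m_strides[dim] is the byte delta applied when `dim` is
    // incremented right after all inner dimensions have wrapped back to zero, i.e. the raw
    // stride of `dim` minus the advance already accumulated across the inner dimensions.
    // For a 2-D buffer with strides (t0, t1) and shape (S0, S1), this works out to
    // { t0 - t1 * (S1 - 1), t1 }.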
    void increment(size_type dim) {
        p_ptr += m_strides[dim];
    }

    void* data() const {
        return p_ptr;
    }

private:
    char* p_ptr;
    container_type m_strides;
};

template <size_t N> class multi_array_iterator {
public:
    using container_type = std::vector<ssize_t>;

    multi_array_iterator(const std::array<buffer_info, N> &buffers,
                         const container_type &shape)
        : m_shape(shape.size()), m_index(shape.size(), 0),
          m_common_iterator() {
        // Manual copy to avoid conversion warning if using std::copy
        for (size_t i = 0; i < shape.size(); ++i)
            m_shape[i] = shape[i];

        container_type strides(shape.size());
        for (size_t i = 0; i < N; ++i)
            init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
    }

    multi_array_iterator& operator++() {
        for (size_t j = m_index.size(); j != 0; --j) {
            size_t i = j - 1;
            if (++m_index[i] != m_shape[i]) {
                increment_common_iterator(i);
                break;
            } else {
                m_index[i] = 0;
            }
        }
        return *this;
    }

    template <size_t K, class T = void> T* data() const {
        return reinterpret_cast<T*>(m_common_iterator[K].data());
    }

private:
    using common_iter = common_iterator;
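    // Note (added for clarity): dimensions where a buffer's extent differs from the output
    // shape (i.e. broadcast singleton dimensions), as well as missing leading dimensions, are
    // given a stride of 0 below, so the same input element is revisited while the output index
    // advances along those dimensions.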
    void init_common_iterator(const buffer_info &buffer,
                              const container_type &shape,
                              common_iter &iterator,
                              container_type &strides) {
        auto buffer_shape_iter = buffer.shape.rbegin();
        auto buffer_strides_iter = buffer.strides.rbegin();
        auto shape_iter = shape.rbegin();
        auto strides_iter = strides.rbegin();

        while (buffer_shape_iter != buffer.shape.rend()) {
            if (*shape_iter == *buffer_shape_iter)
                *strides_iter = *buffer_strides_iter;
            else
                *strides_iter = 0;

            ++buffer_shape_iter;
            ++buffer_strides_iter;
            ++shape_iter;
            ++strides_iter;
        }

        std::fill(strides_iter, strides.rend(), 0);
        iterator = common_iter(buffer.ptr, strides, shape);
    }

    void increment_common_iterator(size_t dim) {
        for (auto &iter : m_common_iterator)
            iter.increment(dim);
    }

    container_type m_shape;
    container_type m_index;
    std::array<common_iter, N> m_common_iterator;
};

enum class broadcast_trivial { non_trivial, c_trivial, f_trivial };

// Populates the shape and number of dimensions for the set of buffers. Returns a broadcast_trivial
// enum value indicating whether the broadcast is "trivial"--that is, every buffer is either a
// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage
// buffer; returns `non_trivial` otherwise.
template <size_t N>
broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
    ndim = std::accumulate(buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
        return std::max(res, buf.ndim);
    });

    shape.clear();
    shape.resize((size_t) ndim, 1);

    // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or
    // the full size).
    for (size_t i = 0; i < N; ++i) {
        auto res_iter = shape.rbegin();
        auto end = buffers[i].shape.rend();
        for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) {
            const auto &dim_size_in = *shape_iter;
            auto &dim_size_out = *res_iter;

            // Each input dimension can either be 1 or `n`, but `n` values must match across buffers
            if (dim_size_out == 1)
                dim_size_out = dim_size_in;
            else if (dim_size_in != 1 && dim_size_in != dim_size_out)
                pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
        }
    }

    bool trivial_broadcast_c = true;
    bool trivial_broadcast_f = true;
    for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
        if (buffers[i].size == 1)
            continue;

        // Require the same number of dimensions:
        if (buffers[i].ndim != ndim)
            return broadcast_trivial::non_trivial;

        // Require all dimensions be full-size:
        if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin()))
            return broadcast_trivial::non_trivial;

        // Check for C contiguity (but only if previous inputs were also C contiguous)
        if (trivial_broadcast_c) {
            ssize_t expect_stride = buffers[i].itemsize;
            auto end = buffers[i].shape.crend();
            for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin();
                    trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) {
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_c = false;
            }
        }

        // Check for Fortran contiguity (if previous inputs were also F contiguous)
        if (trivial_broadcast_f) {
            ssize_t expect_stride = buffers[i].itemsize;
            auto end = buffers[i].shape.cend();
            for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin();
                    trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) {
                if (expect_stride == *stride_iter)
                    expect_stride *= *shape_iter;
                else
                    trivial_broadcast_f = false;
            }
        }
    }

    return
        trivial_broadcast_c ? broadcast_trivial::c_trivial :
        trivial_broadcast_f ? broadcast_trivial::f_trivial :
        broadcast_trivial::non_trivial;
}
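// Example (illustrative): buffers with shapes (3, 1) and (4,) broadcast to an output shape of
// (3, 4); since the non-singleton inputs do not match the full output shape, the call above
// returns broadcast_trivial::non_trivial and the caller falls back to element-wise iteration
// via multi_array_iterator.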
template <typename T>
struct vectorize_arg {
    static_assert(!std::is_rvalue_reference<T>::value, "Functions with rvalue reference arguments cannot be vectorized");
    // The wrapped function gets called with this type:
    using call_type = remove_reference_t<T>;
    // Is this a vectorized argument?
    static constexpr bool vectorize =
        satisfies_any_of<call_type, std::is_arithmetic, is_complex, is_pod>::value &&
        satisfies_none_of<call_type, std::is_pointer, std::is_array, is_std_array, std::is_enum>::value &&
        (!std::is_reference<T>::value ||
         (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
    // Accept this type: an array for vectorized types, otherwise the type as-is:
    using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
};
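// Examples (illustrative): `double`, `const std::complex<float> &`, and POD structs registered
// via PYBIND11_NUMPY_DTYPE are vectorized (the wrapper accepts an array_t of them), while
// pointers, enums, `std::string`, and non-const lvalue references are passed through unchanged.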
// py::vectorize when a return type is present
template <typename Func, typename Return, typename... Args>
struct vectorize_returned_array {
    using Type = array_t<Return>;

    static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
        if (trivial == broadcast_trivial::f_trivial)
            return array_t<Return, array::f_style>(shape);
        else
            return array_t<Return>(shape);
    }

    static Return *mutable_data(Type &array) {
        return array.mutable_data();
    }

    static Return call(Func &f, Args &... args) {
        return f(args...);
    }

    static void call(Return *out, size_t i, Func &f, Args &... args) {
        out[i] = f(args...);
    }
};

// py::vectorize when a return type is not present
template <typename Func, typename... Args>
struct vectorize_returned_array<Func, void, Args...> {
    using Type = none;

    static Type create(broadcast_trivial, const std::vector<ssize_t> &) {
        return none();
    }

    static void *mutable_data(Type &) {
        return nullptr;
    }

    static detail::void_type call(Func &f, Args &... args) {
        f(args...);
        return {};
    }

    static void call(void *, size_t, Func &f, Args &... args) {
        f(args...);
    }
};

template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
// NVCC for some reason breaks if NVectorized is private
#ifdef __CUDACC__
public:
#else
private:
#endif
    static constexpr size_t N = sizeof...(Args);
    static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
    static_assert(NVectorized >= 1,
                  "pybind11::vectorize(...) requires a function with at least one vectorizable argument");

public:
    template <typename T>
    explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) { }

    object operator()(typename vectorize_arg<Args>::type... args) {
        return run(args...,
                   make_index_sequence<N>(),
                   select_indices<vectorize_arg<Args>::vectorize...>(),
                   make_index_sequence<NVectorized>());
    }

private:
    remove_reference_t<Func> f;

    // Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4) when compiling
    // with the "/permissive-" flag and arg_call_types is manually inlined.
    using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
    template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;

    using returned_array = vectorize_returned_array<Func, Return, Args...>;

    // Runs a vectorized function given arguments tuple and three index sequences:
    // - Index is the full set of 0 ... (N-1) argument indices;
    // - VIndex is the subset of argument indices with vectorized parameters, letting us access
    //   vectorized arguments (anything not in this sequence is passed through)
    // - BIndex is an incremental sequence (beginning at 0) of the same size as VIndex, so that
    //   we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
    //   index BIndex in the array).
    template <size_t... Index, size_t... VIndex, size_t... BIndex> object run(
            typename vectorize_arg<Args>::type &...args,
            index_sequence<Index...> i_seq, index_sequence<VIndex...> vi_seq, index_sequence<BIndex...> bi_seq) {

        // Pointers to values the function was called with; the vectorized ones set here will start
        // out as array_t<T> pointers, but they will be changed to T pointers before we call the
        // wrapped function. Non-vectorized pointers are left as-is.
        std::array<void *, N> params{{ &args... }};

        // The array of `buffer_info`s of vectorized arguments:
        std::array<buffer_info, NVectorized> buffers{{ reinterpret_cast<array *>(params[VIndex])->request()... }};

        /* Determine the dimension parameters of the output array */
        ssize_t nd = 0;
        std::vector<ssize_t> shape(0);
        auto trivial = broadcast(buffers, nd, shape);
        auto ndim = (size_t) nd;

        size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());

        // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
        // not wrapped in an array).
        if (size == 1 && ndim == 0) {
            PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
            return cast(returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
        }

        auto result = returned_array::create(trivial, shape);

        if (size == 0) return std::move(result);

        /* Call the function */
        auto mutable_data = returned_array::mutable_data(result);
        if (trivial == broadcast_trivial::non_trivial)
            apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
        else
            apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);

        return std::move(result);
    }

    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
                       std::array<void *, N> &params,
                       Return *out,
                       size_t size,
                       index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {

        // Initialize an array of mutable byte references and sizes with references set to the
        // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size
        // (except for singletons, which get an increment of 0).
        std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{{
            std::pair<unsigned char *&, const size_t>(
                    reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
                    buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>)
            )...
        }};

        for (size_t i = 0; i < size; ++i) {
            returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
            for (auto &x : vecparams) x.first += x.second;
        }
    }

    template <size_t... Index, size_t... VIndex, size_t... BIndex>
    void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
                         std::array<void *, N> &params,
                         Return *out,
                         size_t size,
                         const std::vector<ssize_t> &output_shape,
                         index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {

        multi_array_iterator<NVectorized> input_iter(buffers, output_shape);

        for (size_t i = 0; i < size; ++i, ++input_iter) {
            PYBIND11_EXPAND_SIDE_EFFECTS((
                params[VIndex] = input_iter.template data<BIndex>()
            ));
            returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
        }
    }
};

template <typename Func, typename Return, typename... Args>
vectorize_helper<Func, Return, Args...>
vectorize_extractor(const Func &f, Return (*) (Args ...)) {
    return detail::vectorize_helper<Func, Return, Args...>(f);
}

template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
    static constexpr auto name = _("numpy.ndarray[") + npy_format_descriptor<T>::name + _("]");
};

PYBIND11_NAMESPACE_END(detail)

// Vanilla pointer vectorizer:
template <typename Return, typename... Args>
detail::vectorize_helper<Return (*)(Args...), Return, Args...>
vectorize(Return (*f) (Args ...)) {
    return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
}
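// Usage sketch (hypothetical free function):
//     double my_func(int x, float y, double z);
//     m.def("vectorized_func", py::vectorize(my_func));
// The binding then accepts NumPy arrays (or scalars) for x, y and z, broadcasts them against
// each other and returns an array of results.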
// lambda vectorizer:
template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
auto vectorize(Func &&f) -> decltype(
        detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr)) {
    return detail::vectorize_extractor(std::forward<Func>(f), (detail::function_signature_t<Func> *) nullptr);
}
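// e.g. (illustrative): m.def("scaled", py::vectorize([](double x, double s) { return x * s; }));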
// Vectorize a class method (non-const):
template <typename Return, typename Class, typename... Args,
          typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())), Return, Class *, Args...>>
Helper vectorize(Return (Class::*f)(Args...)) {
    return Helper(std::mem_fn(f));
}

// Vectorize a class method (const):
template <typename Return, typename Class, typename... Args,
          typename Helper = detail::vectorize_helper<decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())), Return, const Class *, Args...>>
Helper vectorize(Return (Class::*f)(Args...) const) {
    return Helper(std::mem_fn(f));
}
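// Usage sketch (hypothetical class): the implicit `this` argument is passed through while the
// arithmetic parameters broadcast over NumPy arrays:
//     cls.def("scale", py::vectorize(&MyClass::scale));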
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

#if defined(_MSC_VER)
#pragma warning(pop)
#endif