Skip to content

Commit

Permalink
Merge pull request numpy#21042 from charris/style-fixups-npysort
Browse files Browse the repository at this point in the history
MAINT, STY: Style fixups.
  • Loading branch information
mattip authored Feb 13, 2022
2 parents 1e26776 + 4e0e767 commit d14d94c
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 114 deletions.
218 changes: 117 additions & 101 deletions numpy/core/src/npysort/binsearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,65 @@

#define NPY_NO_DEPRECATED_API NPY_API_VERSION

#include "npy_sort.h"
#include "numpy_tag.h"
#include <numpy/npy_common.h>
#include <numpy/ndarraytypes.h>
#include "numpy/ndarraytypes.h"
#include "numpy/npy_common.h"

#include "npy_binsearch.h"
#include "npy_sort.h"
#include "numpy_tag.h"

#include <array>
#include <functional> // for std::less and std::less_equal

// Enumerators for the variant of binsearch
enum arg_t { noarg, arg};
enum side_t { left, right};
enum arg_t
{
noarg,
arg
};
enum side_t
{
left,
right
};

// Mapping from enumerators to comparators
template<class Tag, side_t side>
template <class Tag, side_t side>
struct side_to_cmp;
template<class Tag>
struct side_to_cmp<Tag, left> { static constexpr auto value = Tag::less; };
template<class Tag>
struct side_to_cmp<Tag, right> { static constexpr auto value = Tag::less_equal; };

template<side_t side>
template <class Tag>
struct side_to_cmp<Tag, left> {
static constexpr auto value = Tag::less;
};

template <class Tag>
struct side_to_cmp<Tag, right> {
static constexpr auto value = Tag::less_equal;
};

template <side_t side>
struct side_to_generic_cmp;
template<>
struct side_to_generic_cmp<left> { using type = std::less<int>; };
template<>
struct side_to_generic_cmp<right> { using type = std::less_equal<int>; };

template <>
struct side_to_generic_cmp<left> {
using type = std::less<int>;
};

template <>
struct side_to_generic_cmp<right> {
using type = std::less_equal<int>;
};

/*
*****************************************************************************
** NUMERIC SEARCHES **
*****************************************************************************
*/
template<class Tag, side_t side>
template <class Tag, side_t side>
static void
binsearch(const char *arr, const char *key, char *ret,
npy_intp arr_len, npy_intp key_len,
npy_intp arr_str, npy_intp key_str, npy_intp ret_str,
PyArrayObject*)
binsearch(const char *arr, const char *key, char *ret, npy_intp arr_len,
npy_intp key_len, npy_intp arr_str, npy_intp key_str,
npy_intp ret_str, PyArrayObject *)
{
using T = typename Tag::type;
auto cmp = side_to_cmp<Tag, side>::value;
Expand Down Expand Up @@ -73,7 +92,7 @@ binsearch(const char *arr, const char *key, char *ret,

while (min_idx < max_idx) {
const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
const T mid_val = *(const T *)(arr + mid_idx*arr_str);
const T mid_val = *(const T *)(arr + mid_idx * arr_str);
if (cmp(mid_val, key_val)) {
min_idx = mid_idx + 1;
}
Expand All @@ -85,13 +104,12 @@ binsearch(const char *arr, const char *key, char *ret,
}
}

template<class Tag, side_t side>
template <class Tag, side_t side>
static int
argbinsearch(const char *arr, const char *key,
const char *sort, char *ret,
npy_intp arr_len, npy_intp key_len,
npy_intp arr_str, npy_intp key_str,
npy_intp sort_str, npy_intp ret_str, PyArrayObject*)
argbinsearch(const char *arr, const char *key, const char *sort, char *ret,
npy_intp arr_len, npy_intp key_len, npy_intp arr_str,
npy_intp key_str, npy_intp sort_str, npy_intp ret_str,
PyArrayObject *)
{
using T = typename Tag::type;
auto cmp = side_to_cmp<Tag, side>::value;
Expand Down Expand Up @@ -123,14 +141,14 @@ argbinsearch(const char *arr, const char *key,

while (min_idx < max_idx) {
const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str);
const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx * sort_str);
T mid_val;

if (sort_idx < 0 || sort_idx >= arr_len) {
return -1;
}

mid_val = *(const T *)(arr + sort_idx*arr_str);
mid_val = *(const T *)(arr + sort_idx * arr_str);

if (cmp(mid_val, key_val)) {
min_idx = mid_idx + 1;
Expand All @@ -150,12 +168,11 @@ argbinsearch(const char *arr, const char *key,
*****************************************************************************
*/

template<side_t side>
template <side_t side>
static void
npy_binsearch(const char *arr, const char *key, char *ret,
npy_intp arr_len, npy_intp key_len,
npy_intp arr_str, npy_intp key_str, npy_intp ret_str,
PyArrayObject *cmp)
npy_binsearch(const char *arr, const char *key, char *ret, npy_intp arr_len,
npy_intp key_len, npy_intp arr_str, npy_intp key_str,
npy_intp ret_str, PyArrayObject *cmp)
{
using Cmp = typename side_to_generic_cmp<side>::type;
PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare;
Expand All @@ -181,7 +198,7 @@ npy_binsearch(const char *arr, const char *key, char *ret,

while (min_idx < max_idx) {
const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
const char *arr_ptr = arr + mid_idx*arr_str;
const char *arr_ptr = arr + mid_idx * arr_str;

if (Cmp{}(compare(arr_ptr, key, cmp), 0)) {
min_idx = mid_idx + 1;
Expand All @@ -194,14 +211,12 @@ npy_binsearch(const char *arr, const char *key, char *ret,
}
}

template<side_t side>
template <side_t side>
static int
npy_argbinsearch(const char *arr, const char *key,
const char *sort, char *ret,
npy_intp arr_len, npy_intp key_len,
npy_intp arr_str, npy_intp key_str,
npy_intp sort_str, npy_intp ret_str,
PyArrayObject *cmp)
npy_argbinsearch(const char *arr, const char *key, const char *sort, char *ret,
npy_intp arr_len, npy_intp key_len, npy_intp arr_str,
npy_intp key_str, npy_intp sort_str, npy_intp ret_str,
PyArrayObject *cmp)
{
using Cmp = typename side_to_generic_cmp<side>::type;
PyArray_CompareFunc *compare = PyArray_DESCR(cmp)->f->compare;
Expand All @@ -227,14 +242,14 @@ npy_argbinsearch(const char *arr, const char *key,

while (min_idx < max_idx) {
const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx*sort_str);
const npy_intp sort_idx = *(npy_intp *)(sort + mid_idx * sort_str);
const char *arr_ptr;

if (sort_idx < 0 || sort_idx >= arr_len) {
return -1;
}

arr_ptr = arr + sort_idx*arr_str;
arr_ptr = arr + sort_idx * arr_str;

if (Cmp{}(compare(arr_ptr, key, cmp), 0)) {
min_idx = mid_idx + 1;
Expand All @@ -254,88 +269,86 @@ npy_argbinsearch(const char *arr, const char *key,
*****************************************************************************
*/

template<arg_t arg>
template <arg_t arg>
struct binsearch_base;

template<>
template <>
struct binsearch_base<arg> {
using function_type = PyArray_ArgBinSearchFunc*;
using function_type = PyArray_ArgBinSearchFunc *;
struct value_type {
int typenum;
function_type binsearch[NPY_NSEARCHSIDES];
};
template<class... Tags>
static constexpr std::array<value_type, sizeof...(Tags)> make_binsearch_map(npy::taglist<Tags...>) {
template <class... Tags>
static constexpr std::array<value_type, sizeof...(Tags)>
make_binsearch_map(npy::taglist<Tags...>)
{
return std::array<value_type, sizeof...(Tags)>{
value_type{
Tags::type_value,
{
(function_type)&argbinsearch<Tags, left>,
(function_type)argbinsearch<Tags, right>
}
}...
};
value_type{Tags::type_value,
{(function_type)&argbinsearch<Tags, left>,
(function_type)argbinsearch<Tags, right>}}...};
}
static constexpr std::array<function_type, 2> npy_map = {
(function_type)&npy_argbinsearch<left>,
(function_type)&npy_argbinsearch<right>
};
(function_type)&npy_argbinsearch<left>,
(function_type)&npy_argbinsearch<right>};
};
constexpr std::array<binsearch_base<arg>::function_type, 2> binsearch_base<arg>::npy_map;
constexpr std::array<binsearch_base<arg>::function_type, 2>
binsearch_base<arg>::npy_map;

template<>
template <>
struct binsearch_base<noarg> {
using function_type = PyArray_BinSearchFunc*;
using function_type = PyArray_BinSearchFunc *;
struct value_type {
int typenum;
function_type binsearch[NPY_NSEARCHSIDES];
};
template<class... Tags>
static constexpr std::array<value_type, sizeof...(Tags)> make_binsearch_map(npy::taglist<Tags...>) {
template <class... Tags>
static constexpr std::array<value_type, sizeof...(Tags)>
make_binsearch_map(npy::taglist<Tags...>)
{
return std::array<value_type, sizeof...(Tags)>{
value_type{
Tags::type_value,
{
(function_type)&binsearch<Tags, left>,
(function_type)binsearch<Tags, right>
}
}...
};
value_type{Tags::type_value,
{(function_type)&binsearch<Tags, left>,
(function_type)binsearch<Tags, right>}}...};
}
static constexpr std::array<function_type, 2> npy_map = {
(function_type)&npy_binsearch<left>,
(function_type)&npy_binsearch<right>
};
(function_type)&npy_binsearch<left>,
(function_type)&npy_binsearch<right>};
};
constexpr std::array<binsearch_base<noarg>::function_type, 2> binsearch_base<noarg>::npy_map;
constexpr std::array<binsearch_base<noarg>::function_type, 2>
binsearch_base<noarg>::npy_map;

// Handle generation of all binsearch variants
template<arg_t arg>
template <arg_t arg>
struct binsearch_t : binsearch_base<arg> {
using binsearch_base<arg>::make_binsearch_map;
using value_type = typename binsearch_base<arg>::value_type;

using taglist = npy::taglist<
/* If adding new types, make sure to keep them ordered by type num */
npy::bool_tag, npy::byte_tag, npy::ubyte_tag, npy::short_tag,
npy::ushort_tag, npy::int_tag, npy::uint_tag, npy::long_tag,
npy::ulong_tag, npy::longlong_tag, npy::ulonglong_tag, npy::half_tag,
npy::float_tag, npy::double_tag, npy::longdouble_tag, npy::cfloat_tag,
npy::cdouble_tag, npy::clongdouble_tag, npy::datetime_tag,
npy::timedelta_tag>;

static constexpr std::array<value_type, taglist::size> map = make_binsearch_map(taglist());
/* If adding new types, make sure to keep them ordered by type num
*/
npy::bool_tag, npy::byte_tag, npy::ubyte_tag, npy::short_tag,
npy::ushort_tag, npy::int_tag, npy::uint_tag, npy::long_tag,
npy::ulong_tag, npy::longlong_tag, npy::ulonglong_tag,
npy::half_tag, npy::float_tag, npy::double_tag,
npy::longdouble_tag, npy::cfloat_tag, npy::cdouble_tag,
npy::clongdouble_tag, npy::datetime_tag, npy::timedelta_tag>;

static constexpr std::array<value_type, taglist::size> map =
make_binsearch_map(taglist());
};
template<arg_t arg>
constexpr std::array<typename binsearch_t<arg>::value_type, binsearch_t<arg>::taglist::size> binsearch_t<arg>::map;

template <arg_t arg>
constexpr std::array<typename binsearch_t<arg>::value_type,
binsearch_t<arg>::taglist::size>
binsearch_t<arg>::map;

template<arg_t arg>
template <arg_t arg>
static NPY_INLINE typename binsearch_t<arg>::function_type
_get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side)
{
using binsearch = binsearch_t<arg>;
npy_intp nfuncs = binsearch::map.size();;
npy_intp nfuncs = binsearch::map.size();
npy_intp min_idx = 0;
npy_intp max_idx = nfuncs;
int type = dtype->type_num;
Expand All @@ -359,8 +372,7 @@ _get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side)
}
}

if (min_idx < nfuncs &&
binsearch::map[min_idx].typenum == type) {
if (min_idx < nfuncs && binsearch::map[min_idx].typenum == type) {
return binsearch::map[min_idx].binsearch[side];
}

Expand All @@ -371,17 +383,21 @@ _get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side)
return NULL;
}


/*
*****************************************************************************
** C INTERFACE **
*****************************************************************************
*/
extern "C" {
NPY_NO_EXPORT PyArray_BinSearchFunc* get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) {
return _get_binsearch_func<noarg>(dtype, side);
}
NPY_NO_EXPORT PyArray_ArgBinSearchFunc* get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side) {
return _get_binsearch_func<arg>(dtype, side);
}
NPY_NO_EXPORT PyArray_BinSearchFunc *
get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side)
{
return _get_binsearch_func<noarg>(dtype, side);
}

NPY_NO_EXPORT PyArray_ArgBinSearchFunc *
get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side)
{
return _get_binsearch_func<arg>(dtype, side);
}
}
Loading

0 comments on commit d14d94c

Please sign in to comment.