You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
342 lines
13 KiB
342 lines
13 KiB
/* |
|
* Copyright 2008-2012 NVIDIA Corporation |
|
* |
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
* you may not use this file except in compliance with the License. |
|
* You may obtain a copy of the License at |
|
* |
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
* |
|
* Unless required by applicable law or agreed to in writing, software |
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
* See the License for the specific language governing permissions and |
|
* limitations under the License. |
|
*/ |
|
|
|
|
|
/*! \file binary_search.inl |
|
* \brief Inline file for binary_search.h |
|
*/ |
|
|
|
#pragma once |
|
|
|
#include <thrust/detail/config.h> |
|
#include <thrust/distance.h> |
|
#include <thrust/functional.h> |
|
#include <thrust/binary_search.h> |
|
#include <thrust/iterator/zip_iterator.h> |
|
#include <thrust/iterator/iterator_traits.h> |
|
#include <thrust/binary_search.h> |
|
|
|
#include <thrust/for_each.h> |
|
#include <thrust/detail/function.h> |
|
#include <thrust/system/detail/generic/scalar/binary_search.h> |
|
|
|
#include <thrust/detail/temporary_array.h> |
|
#include <thrust/detail/type_traits.h> |
|
|
|
namespace thrust |
|
{ |
|
namespace detail |
|
{ |
|
|
|
// XXX WAR circular #inclusion with this forward declaration |
|
template<typename,typename> class temporary_array; |
|
|
|
} // end detail |
|
namespace system |
|
{ |
|
namespace detail |
|
{ |
|
namespace generic |
|
{ |
|
namespace detail |
|
{ |
|
|
|
|
|
// short names to avoid nvcc bug |
|
struct lbf |
|
{ |
|
template <typename RandomAccessIterator, typename T, typename StrictWeakOrdering> |
|
__host__ __device__ |
|
typename thrust::iterator_traits<RandomAccessIterator>::difference_type |
|
operator()(RandomAccessIterator begin, RandomAccessIterator end, const T& value, StrictWeakOrdering comp) |
|
{ |
|
return thrust::system::detail::generic::scalar::lower_bound(begin, end, value, comp) - begin; |
|
} |
|
}; |
|
|
|
struct ubf |
|
{ |
|
template <typename RandomAccessIterator, typename T, typename StrictWeakOrdering> |
|
__host__ __device__ |
|
typename thrust::iterator_traits<RandomAccessIterator>::difference_type |
|
operator()(RandomAccessIterator begin, RandomAccessIterator end, const T& value, StrictWeakOrdering comp){ |
|
return thrust::system::detail::generic::scalar::upper_bound(begin, end, value, comp) - begin; |
|
} |
|
}; |
|
|
|
struct bsf |
|
{ |
|
template <typename RandomAccessIterator, typename T, typename StrictWeakOrdering> |
|
__host__ __device__ |
|
bool operator()(RandomAccessIterator begin, RandomAccessIterator end, const T& value, StrictWeakOrdering comp){ |
|
RandomAccessIterator iter = thrust::system::detail::generic::scalar::lower_bound(begin, end, value, comp); |
|
|
|
thrust::detail::host_device_function<StrictWeakOrdering,bool> wrapped_comp(comp); |
|
|
|
return iter != end && !wrapped_comp(value, *iter); |
|
} |
|
}; |
|
|
|
|
|
template <typename ForwardIterator, typename StrictWeakOrdering, typename BinarySearchFunction> |
|
struct binary_search_functor |
|
{ |
|
ForwardIterator begin; |
|
ForwardIterator end; |
|
StrictWeakOrdering comp; |
|
BinarySearchFunction func; |
|
|
|
binary_search_functor(ForwardIterator begin, ForwardIterator end, StrictWeakOrdering comp, BinarySearchFunction func) |
|
: begin(begin), end(end), comp(comp), func(func) {} |
|
|
|
template <typename Tuple> |
|
__host__ __device__ |
|
void operator()(Tuple t) |
|
{ |
|
thrust::get<1>(t) = func(begin, end, thrust::get<0>(t), comp); |
|
} |
|
}; // binary_search_functor |
|
|
|
|
|
// Vector Implementation |
|
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator, typename StrictWeakOrdering, typename BinarySearchFunction> |
|
OutputIterator binary_search(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
InputIterator values_begin, |
|
InputIterator values_end, |
|
OutputIterator output, |
|
StrictWeakOrdering comp, |
|
BinarySearchFunction func) |
|
{ |
|
thrust::for_each(exec, |
|
thrust::make_zip_iterator(thrust::make_tuple(values_begin, output)), |
|
thrust::make_zip_iterator(thrust::make_tuple(values_end, output + thrust::distance(values_begin, values_end))), |
|
detail::binary_search_functor<ForwardIterator, StrictWeakOrdering, BinarySearchFunction>(begin, end, comp, func)); |
|
|
|
return output + thrust::distance(values_begin, values_end); |
|
} |
|
|
|
|
|
|
|
// Scalar Implementation |
|
template <typename OutputType, typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering, typename BinarySearchFunction> |
|
OutputType binary_search(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
const T& value, |
|
StrictWeakOrdering comp, |
|
BinarySearchFunction func) |
|
{ |
|
// use the vectorized path to implement the scalar version |
|
|
|
// allocate device buffers for value and output |
|
thrust::detail::temporary_array<T,DerivedPolicy> d_value(exec,1); |
|
thrust::detail::temporary_array<OutputType,DerivedPolicy> d_output(exec,1); |
|
|
|
// copy value to device |
|
d_value[0] = value; |
|
|
|
// perform the query |
|
thrust::system::detail::generic::detail::binary_search(exec, begin, end, d_value.begin(), d_value.end(), d_output.begin(), comp, func); |
|
|
|
// copy result to host and return |
|
return d_output[0]; |
|
} |
|
|
|
} // end namespace detail |
|
|
|
|
|
////////////////////// |
|
// Scalar Functions // |
|
////////////////////// |
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename T> |
|
ForwardIterator lower_bound(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
const T& value) |
|
{ |
|
return thrust::lower_bound(exec, begin, end, value, thrust::less<T>()); |
|
} |
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering> |
|
ForwardIterator lower_bound(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
const T& value, |
|
StrictWeakOrdering comp) |
|
{ |
|
typedef typename thrust::iterator_traits<ForwardIterator>::difference_type difference_type; |
|
|
|
return begin + detail::binary_search<difference_type>(exec, begin, end, value, comp, detail::lbf()); |
|
} |
|
|
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename T> |
|
ForwardIterator upper_bound(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
const T& value) |
|
{ |
|
return thrust::upper_bound(exec, begin, end, value, thrust::less<T>()); |
|
} |
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering> |
|
ForwardIterator upper_bound(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
const T& value, |
|
StrictWeakOrdering comp) |
|
{ |
|
typedef typename thrust::iterator_traits<ForwardIterator>::difference_type difference_type; |
|
|
|
return begin + detail::binary_search<difference_type>(exec, begin, end, value, comp, detail::ubf()); |
|
} |
|
|
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename T> |
|
bool binary_search(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
const T& value) |
|
{ |
|
return thrust::binary_search(exec, begin, end, value, thrust::less<T>()); |
|
} |
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering> |
|
bool binary_search(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
const T& value, |
|
StrictWeakOrdering comp) |
|
{ |
|
return detail::binary_search<bool>(exec, begin, end, value, comp, detail::bsf()); |
|
} |
|
|
|
|
|
////////////////////// |
|
// Vector Functions // |
|
////////////////////// |
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator> |
|
OutputIterator lower_bound(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
InputIterator values_begin, |
|
InputIterator values_end, |
|
OutputIterator output) |
|
{ |
|
typedef typename thrust::iterator_value<InputIterator>::type ValueType; |
|
|
|
return thrust::lower_bound(exec, begin, end, values_begin, values_end, output, thrust::less<ValueType>()); |
|
} |
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator, typename StrictWeakOrdering> |
|
OutputIterator lower_bound(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
InputIterator values_begin, |
|
InputIterator values_end, |
|
OutputIterator output, |
|
StrictWeakOrdering comp) |
|
{ |
|
return detail::binary_search(exec, begin, end, values_begin, values_end, output, comp, detail::lbf()); |
|
} |
|
|
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator> |
|
OutputIterator upper_bound(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
InputIterator values_begin, |
|
InputIterator values_end, |
|
OutputIterator output) |
|
{ |
|
typedef typename thrust::iterator_value<InputIterator>::type ValueType; |
|
|
|
return thrust::upper_bound(exec, begin, end, values_begin, values_end, output, thrust::less<ValueType>()); |
|
} |
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator, typename StrictWeakOrdering> |
|
OutputIterator upper_bound(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
InputIterator values_begin, |
|
InputIterator values_end, |
|
OutputIterator output, |
|
StrictWeakOrdering comp) |
|
{ |
|
return detail::binary_search(exec, begin, end, values_begin, values_end, output, comp, detail::ubf()); |
|
} |
|
|
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator> |
|
OutputIterator binary_search(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
InputIterator values_begin, |
|
InputIterator values_end, |
|
OutputIterator output) |
|
{ |
|
typedef typename thrust::iterator_value<InputIterator>::type ValueType; |
|
|
|
return thrust::binary_search(exec, begin, end, values_begin, values_end, output, thrust::less<ValueType>()); |
|
} |
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator, typename StrictWeakOrdering> |
|
OutputIterator binary_search(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator begin, |
|
ForwardIterator end, |
|
InputIterator values_begin, |
|
InputIterator values_end, |
|
OutputIterator output, |
|
StrictWeakOrdering comp) |
|
{ |
|
return detail::binary_search(exec, begin, end, values_begin, values_end, output, comp, detail::bsf()); |
|
} |
|
|
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename LessThanComparable> |
|
thrust::pair<ForwardIterator,ForwardIterator> |
|
equal_range(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator first, |
|
ForwardIterator last, |
|
const LessThanComparable &value) |
|
{ |
|
return thrust::equal_range(exec, first, last, value, thrust::less<LessThanComparable>()); |
|
} |
|
|
|
|
|
template <typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering> |
|
thrust::pair<ForwardIterator,ForwardIterator> |
|
equal_range(thrust::execution_policy<DerivedPolicy> &exec, |
|
ForwardIterator first, |
|
ForwardIterator last, |
|
const T &value, |
|
StrictWeakOrdering comp) |
|
{ |
|
ForwardIterator lb = thrust::lower_bound(exec, first, last, value, comp); |
|
ForwardIterator ub = thrust::upper_bound(exec, first, last, value, comp); |
|
return thrust::make_pair(lb, ub); |
|
} |
|
|
|
|
|
} // end namespace generic |
|
} // end namespace detail |
|
} // end namespace system |
|
} // end namespace thrust |
|
|
|
|