/* * Copyright 2008-2012 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /*! \file binary_search.inl * \brief Inline file for binary_search.h */ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include namespace thrust { namespace detail { // XXX WAR circular #inclusion with this forward declaration template class temporary_array; } // end detail namespace system { namespace detail { namespace generic { namespace detail { // short names to avoid nvcc bug struct lbf { template __host__ __device__ typename thrust::iterator_traits::difference_type operator()(RandomAccessIterator begin, RandomAccessIterator end, const T& value, StrictWeakOrdering comp) { return thrust::system::detail::generic::scalar::lower_bound(begin, end, value, comp) - begin; } }; struct ubf { template __host__ __device__ typename thrust::iterator_traits::difference_type operator()(RandomAccessIterator begin, RandomAccessIterator end, const T& value, StrictWeakOrdering comp){ return thrust::system::detail::generic::scalar::upper_bound(begin, end, value, comp) - begin; } }; struct bsf { template __host__ __device__ bool operator()(RandomAccessIterator begin, RandomAccessIterator end, const T& value, StrictWeakOrdering comp){ RandomAccessIterator iter = thrust::system::detail::generic::scalar::lower_bound(begin, end, value, comp); thrust::detail::host_device_function wrapped_comp(comp); return iter != end && !wrapped_comp(value, *iter); } }; template struct binary_search_functor { ForwardIterator begin; ForwardIterator end; StrictWeakOrdering comp; BinarySearchFunction func; binary_search_functor(ForwardIterator begin, ForwardIterator end, StrictWeakOrdering comp, BinarySearchFunction func) : begin(begin), end(end), comp(comp), func(func) {} template __host__ __device__ void operator()(Tuple t) { thrust::get<1>(t) = func(begin, end, thrust::get<0>(t), comp); } }; // binary_search_functor // Vector Implementation template OutputIterator binary_search(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, InputIterator values_begin, InputIterator values_end, OutputIterator output, StrictWeakOrdering comp, BinarySearchFunction func) { thrust::for_each(exec, thrust::make_zip_iterator(thrust::make_tuple(values_begin, output)), thrust::make_zip_iterator(thrust::make_tuple(values_end, output + thrust::distance(values_begin, values_end))), detail::binary_search_functor(begin, end, comp, func)); return output + thrust::distance(values_begin, values_end); } // Scalar Implementation template OutputType binary_search(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, const T& value, StrictWeakOrdering comp, BinarySearchFunction func) { // use the vectorized path to implement the scalar version // allocate device buffers for value and output thrust::detail::temporary_array d_value(exec,1); thrust::detail::temporary_array d_output(exec,1); // copy value to device d_value[0] = value; // perform the query thrust::system::detail::generic::detail::binary_search(exec, begin, end, d_value.begin(), d_value.end(), d_output.begin(), comp, func); // copy result to host and return return d_output[0]; } } // end namespace detail ////////////////////// // Scalar Functions // ////////////////////// template ForwardIterator lower_bound(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, const T& value) { return thrust::lower_bound(exec, begin, end, value, thrust::less()); } template ForwardIterator lower_bound(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, const T& value, StrictWeakOrdering comp) { typedef typename thrust::iterator_traits::difference_type difference_type; return begin + detail::binary_search(exec, begin, end, value, comp, detail::lbf()); } template ForwardIterator upper_bound(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, const T& value) { return thrust::upper_bound(exec, begin, end, value, thrust::less()); } template ForwardIterator upper_bound(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, const T& value, StrictWeakOrdering comp) { typedef typename thrust::iterator_traits::difference_type difference_type; return begin + detail::binary_search(exec, begin, end, value, comp, detail::ubf()); } template bool binary_search(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, const T& value) { return thrust::binary_search(exec, begin, end, value, thrust::less()); } template bool binary_search(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, const T& value, StrictWeakOrdering comp) { return detail::binary_search(exec, begin, end, value, comp, detail::bsf()); } ////////////////////// // Vector Functions // ////////////////////// template OutputIterator lower_bound(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, InputIterator values_begin, InputIterator values_end, OutputIterator output) { typedef typename thrust::iterator_value::type ValueType; return thrust::lower_bound(exec, begin, end, values_begin, values_end, output, thrust::less()); } template OutputIterator lower_bound(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, InputIterator values_begin, InputIterator values_end, OutputIterator output, StrictWeakOrdering comp) { return detail::binary_search(exec, begin, end, values_begin, values_end, output, comp, detail::lbf()); } template OutputIterator upper_bound(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, InputIterator values_begin, InputIterator values_end, OutputIterator output) { typedef typename thrust::iterator_value::type ValueType; return thrust::upper_bound(exec, begin, end, values_begin, values_end, output, thrust::less()); } template OutputIterator upper_bound(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, InputIterator values_begin, InputIterator values_end, OutputIterator output, StrictWeakOrdering comp) { return detail::binary_search(exec, begin, end, values_begin, values_end, output, comp, detail::ubf()); } template OutputIterator binary_search(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, InputIterator values_begin, InputIterator values_end, OutputIterator output) { typedef typename thrust::iterator_value::type ValueType; return thrust::binary_search(exec, begin, end, values_begin, values_end, output, thrust::less()); } template OutputIterator binary_search(thrust::execution_policy &exec, ForwardIterator begin, ForwardIterator end, InputIterator values_begin, InputIterator values_end, OutputIterator output, StrictWeakOrdering comp) { return detail::binary_search(exec, begin, end, values_begin, values_end, output, comp, detail::bsf()); } template thrust::pair equal_range(thrust::execution_policy &exec, ForwardIterator first, ForwardIterator last, const LessThanComparable &value) { return thrust::equal_range(exec, first, last, value, thrust::less()); } template thrust::pair equal_range(thrust::execution_policy &exec, ForwardIterator first, ForwardIterator last, const T &value, StrictWeakOrdering comp) { ForwardIterator lb = thrust::lower_bound(exec, first, last, value, comp); ForwardIterator ub = thrust::upper_bound(exec, first, last, value, comp); return thrust::make_pair(lb, ub); } } // end namespace generic } // end namespace detail } // end namespace system } // end namespace thrust