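// NOTE: source-viewer residue captured along with this file (not part of the
// original source): "343 lines", "13 KiB", "Plaintext",
// "Raw Normal View History", "2014-03-18 22:17:40 +01:00".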
/*
* Copyright 2008-2012 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file binary_search.inl
* \brief Inline file for binary_search.h
*/
#pragma once
#include <thrust/detail/config.h>
#include <thrust/distance.h>
#include <thrust/functional.h>
#include <thrust/binary_search.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/binary_search.h>
#include <thrust/for_each.h>
#include <thrust/detail/function.h>
#include <thrust/system/detail/generic/scalar/binary_search.h>
#include <thrust/detail/temporary_array.h>
#include <thrust/detail/type_traits.h>
namespace thrust
{
namespace detail
{
// XXX WAR circular #inclusion with this forward declaration
// Forward-declares thrust::detail::temporary_array (used by the scalar
// detail::binary_search helper further down) so the name is declared even if
// the <thrust/detail/temporary_array.h> include above is short-circuited by
// an include cycle.
// NOTE(review): the exact cycle is not visible here — confirm before removing.
template<typename,typename> class temporary_array;
} // end detail
namespace system
{
namespace detail
{
namespace generic
{
namespace detail
{
// Deliberately short functor names (lbf = lower_bound, ubf = upper_bound,
// bsf = binary_search) to work around an nvcc bug.
struct lbf
{
  // Lower-bound query on [begin, end) under comp, reported as an offset from
  // begin (not as an iterator) so it can be stored in a plain output range.
  template <typename RandomAccessIterator, typename T, typename StrictWeakOrdering>
  __host__ __device__
  typename thrust::iterator_traits<RandomAccessIterator>::difference_type
  operator()(RandomAccessIterator begin, RandomAccessIterator end, const T& value, StrictWeakOrdering comp)
  {
    RandomAccessIterator position =
      thrust::system::detail::generic::scalar::lower_bound(begin, end, value, comp);

    return position - begin;
  }
};
struct ubf
{
  // Upper-bound query on [begin, end) under comp, reported as the distance
  // from begin to the position returned by the scalar upper_bound.
  template <typename RandomAccessIterator, typename T, typename StrictWeakOrdering>
  __host__ __device__
  typename thrust::iterator_traits<RandomAccessIterator>::difference_type
  operator()(RandomAccessIterator begin, RandomAccessIterator end, const T& value, StrictWeakOrdering comp)
  {
    typedef typename thrust::iterator_traits<RandomAccessIterator>::difference_type offset_type;

    const offset_type offset =
      thrust::system::detail::generic::scalar::upper_bound(begin, end, value, comp) - begin;

    return offset;
  }
};
struct bsf
{
  // Membership query on [begin, end) under comp. Returns true when the
  // lower-bound position is not end AND value is not ordered before the
  // element stored there; false otherwise.
  template <typename RandomAccessIterator, typename T, typename StrictWeakOrdering>
  __host__ __device__
  bool operator()(RandomAccessIterator begin, RandomAccessIterator end, const T& value, StrictWeakOrdering comp)
  {
    RandomAccessIterator candidate =
      thrust::system::detail::generic::scalar::lower_bound(begin, end, value, comp);

    // Wrap comp in host_device_function (presumably so it is callable from
    // __host__ __device__ code) before using it directly.
    thrust::detail::host_device_function<StrictWeakOrdering,bool> ordered_before(comp);

    if (candidate != end)
    {
      return !ordered_before(value, *candidate);
    }

    return false;
  }
};
template <typename ForwardIterator, typename StrictWeakOrdering, typename BinarySearchFunction>
struct binary_search_functor
{
  // The searched range [begin, end), its ordering, and the per-query search
  // (lbf / ubf / bsf) applied to every element visited by for_each.
  ForwardIterator begin;
  ForwardIterator end;
  StrictWeakOrdering comp;
  BinarySearchFunction func;

  binary_search_functor(ForwardIterator first,
                        ForwardIterator last,
                        StrictWeakOrdering ordering,
                        BinarySearchFunction search)
    : begin(first), end(last), comp(ordering), func(search)
  {}

  // t is a (query, output) tuple: runs func for get<0>(t) on the stored range
  // and assigns the result through get<1>(t).
  template <typename Tuple>
  __host__ __device__
  void operator()(Tuple t)
  {
    thrust::get<1>(t) = func(begin, end, thrust::get<0>(t), comp);
  }
}; // binary_search_functor
// Vector Implementation
//
// For every query in [values_begin, values_end), evaluates func against the
// range [begin, end) (ordered by comp) and stores the result at the matching
// position of output. The work is dispatched as a single thrust::for_each over
// a zip of (query, output) pairs, each visited by binary_search_functor.
//
// exec                     - execution policy the for_each runs on
// begin, end               - the range being searched
// values_begin, values_end - the queries
// output                   - destination; receives one result per query
// comp                     - strict weak ordering used by the search
// func                     - per-query search function (lbf, ubf or bsf)
//
// Returns output advanced by the number of queries (one past the last
// result written).
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator, typename StrictWeakOrdering, typename BinarySearchFunction>
OutputIterator binary_search(thrust::execution_policy<DerivedPolicy> &exec,
                             ForwardIterator begin,
                             ForwardIterator end,
                             InputIterator values_begin,
                             InputIterator values_end,
                             OutputIterator output,
                             StrictWeakOrdering comp,
                             BinarySearchFunction func)
{
  typedef typename thrust::iterator_traits<InputIterator>::difference_type difference_type;

  // Measure the query range once and reuse the result for both the end of the
  // zipped range and the return value: thrust::distance is linear in the
  // range length for iterators that are not random access.
  const difference_type num_values = thrust::distance(values_begin, values_end);
  OutputIterator output_end = output + num_values;

  thrust::for_each(exec,
                   thrust::make_zip_iterator(thrust::make_tuple(values_begin, output)),
                   thrust::make_zip_iterator(thrust::make_tuple(values_end, output_end)),
                   detail::binary_search_functor<ForwardIterator, StrictWeakOrdering, BinarySearchFunction>(begin, end, comp, func));

  return output_end;
}
// Scalar Implementation
//
// Answers a single query by routing it through the vectorized overload:
// value is staged in a one-element temporary_array, searched as a batch of
// one, and the single result is read back and returned.
//
// OutputType - type of the result produced by func (the public overloads in
//              this file use a difference_type for lbf/ubf and bool for bsf)
template <typename OutputType, typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering, typename BinarySearchFunction>
OutputType binary_search(thrust::execution_policy<DerivedPolicy> &exec,
                         ForwardIterator begin,
                         ForwardIterator end,
                         const T& value,
                         StrictWeakOrdering comp,
                         BinarySearchFunction func)
{
  // One-element temporaries constructed with exec: the query, then the answer.
  thrust::detail::temporary_array<T,DerivedPolicy>          query(exec, 1);
  thrust::detail::temporary_array<OutputType,DerivedPolicy> answer(exec, 1);

  query[0] = value;

  thrust::system::detail::generic::detail::binary_search(exec,
                                                         begin, end,
                                                         query.begin(), query.end(),
                                                         answer.begin(),
                                                         comp, func);

  // Read the single result back out of the temporary before it is released.
  OutputType result = answer[0];
  return result;
}
} // end namespace detail
//////////////////////
// Scalar Functions //
//////////////////////
// lower_bound(exec, begin, end, value): re-dispatches through
// thrust::lower_bound with thrust::less<T> as the ordering.
// NOTE(review): thrust::less<T> presumably takes both operands as T, so
// elements of a different type are compared after conversion — confirm this
// is intended for mixed-type searches.
template <typename DerivedPolicy, typename ForwardIterator, typename T>
ForwardIterator lower_bound(thrust::execution_policy<DerivedPolicy> &exec,
                            ForwardIterator begin,
                            ForwardIterator end,
                            const T& value)
{
  thrust::less<T> default_ordering;
  return thrust::lower_bound(exec, begin, end, value, default_ordering);
}
// lower_bound(exec, begin, end, value, comp): runs one lbf query through the
// scalar detail::binary_search helper and turns the returned offset back into
// an iterator into [begin, end].
template <typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering>
ForwardIterator lower_bound(thrust::execution_policy<DerivedPolicy> &exec,
                            ForwardIterator begin,
                            ForwardIterator end,
                            const T& value,
                            StrictWeakOrdering comp)
{
  typedef typename thrust::iterator_traits<ForwardIterator>::difference_type offset_type;

  const offset_type offset =
    detail::binary_search<offset_type>(exec, begin, end, value, comp, detail::lbf());

  return begin + offset;
}
// upper_bound(exec, begin, end, value): re-dispatches through
// thrust::upper_bound with thrust::less<T> as the ordering.
// NOTE(review): as with lower_bound, thrust::less<T> presumably compares
// after converting elements to T — confirm for mixed-type searches.
template <typename DerivedPolicy, typename ForwardIterator, typename T>
ForwardIterator upper_bound(thrust::execution_policy<DerivedPolicy> &exec,
                            ForwardIterator begin,
                            ForwardIterator end,
                            const T& value)
{
  thrust::less<T> default_ordering;
  return thrust::upper_bound(exec, begin, end, value, default_ordering);
}
// upper_bound(exec, begin, end, value, comp): runs one ubf query through the
// scalar detail::binary_search helper and turns the returned offset back into
// an iterator into [begin, end].
template <typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering>
ForwardIterator upper_bound(thrust::execution_policy<DerivedPolicy> &exec,
                            ForwardIterator begin,
                            ForwardIterator end,
                            const T& value,
                            StrictWeakOrdering comp)
{
  typedef typename thrust::iterator_traits<ForwardIterator>::difference_type offset_type;

  const offset_type offset =
    detail::binary_search<offset_type>(exec, begin, end, value, comp, detail::ubf());

  return begin + offset;
}
// binary_search(exec, begin, end, value): re-dispatches through
// thrust::binary_search with thrust::less<T> as the ordering.
template <typename DerivedPolicy, typename ForwardIterator, typename T>
bool binary_search(thrust::execution_policy<DerivedPolicy> &exec,
                   ForwardIterator begin,
                   ForwardIterator end,
                   const T& value)
{
  thrust::less<T> default_ordering;
  return thrust::binary_search(exec, begin, end, value, default_ordering);
}
// binary_search(exec, begin, end, value, comp): runs one bsf query through the
// scalar detail::binary_search helper and returns its bool answer.
template <typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering>
bool binary_search(thrust::execution_policy<DerivedPolicy> &exec,
                   ForwardIterator begin,
                   ForwardIterator end,
                   const T& value,
                   StrictWeakOrdering comp)
{
  const bool found =
    detail::binary_search<bool>(exec, begin, end, value, comp, detail::bsf());

  return found;
}
//////////////////////
// Vector Functions //
//////////////////////
// Vectorized lower_bound without a comparator: orders with thrust::less of the
// query iterator's value type and re-dispatches through thrust::lower_bound.
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator>
OutputIterator lower_bound(thrust::execution_policy<DerivedPolicy> &exec,
                           ForwardIterator begin,
                           ForwardIterator end,
                           InputIterator values_begin,
                           InputIterator values_end,
                           OutputIterator output)
{
  typedef typename thrust::iterator_value<InputIterator>::type query_type;

  thrust::less<query_type> default_ordering;
  return thrust::lower_bound(exec, begin, end, values_begin, values_end, output, default_ordering);
}
// Vectorized lower_bound with a comparator: writes each query's lower-bound
// offset into [begin, end) to output (via lbf) and returns the end of the
// written output range.
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator, typename StrictWeakOrdering>
OutputIterator lower_bound(thrust::execution_policy<DerivedPolicy> &exec,
                           ForwardIterator begin,
                           ForwardIterator end,
                           InputIterator values_begin,
                           InputIterator values_end,
                           OutputIterator output,
                           StrictWeakOrdering comp)
{
  detail::lbf search;
  return detail::binary_search(exec, begin, end, values_begin, values_end, output, comp, search);
}
// Vectorized upper_bound without a comparator: orders with thrust::less of the
// query iterator's value type and re-dispatches through thrust::upper_bound.
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator>
OutputIterator upper_bound(thrust::execution_policy<DerivedPolicy> &exec,
                           ForwardIterator begin,
                           ForwardIterator end,
                           InputIterator values_begin,
                           InputIterator values_end,
                           OutputIterator output)
{
  typedef typename thrust::iterator_value<InputIterator>::type query_type;

  thrust::less<query_type> default_ordering;
  return thrust::upper_bound(exec, begin, end, values_begin, values_end, output, default_ordering);
}
// Vectorized upper_bound with a comparator: writes each query's upper-bound
// offset into [begin, end) to output (via ubf) and returns the end of the
// written output range.
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator, typename StrictWeakOrdering>
OutputIterator upper_bound(thrust::execution_policy<DerivedPolicy> &exec,
                           ForwardIterator begin,
                           ForwardIterator end,
                           InputIterator values_begin,
                           InputIterator values_end,
                           OutputIterator output,
                           StrictWeakOrdering comp)
{
  detail::ubf search;
  return detail::binary_search(exec, begin, end, values_begin, values_end, output, comp, search);
}
// Vectorized binary_search without a comparator: orders with thrust::less of
// the query iterator's value type and re-dispatches through
// thrust::binary_search.
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator>
OutputIterator binary_search(thrust::execution_policy<DerivedPolicy> &exec,
                             ForwardIterator begin,
                             ForwardIterator end,
                             InputIterator values_begin,
                             InputIterator values_end,
                             OutputIterator output)
{
  typedef typename thrust::iterator_value<InputIterator>::type query_type;

  thrust::less<query_type> default_ordering;
  return thrust::binary_search(exec, begin, end, values_begin, values_end, output, default_ordering);
}
// Vectorized binary_search with a comparator: writes each query's membership
// result (via bsf) to output and returns the end of the written output range.
template <typename DerivedPolicy, typename ForwardIterator, typename InputIterator, typename OutputIterator, typename StrictWeakOrdering>
OutputIterator binary_search(thrust::execution_policy<DerivedPolicy> &exec,
                             ForwardIterator begin,
                             ForwardIterator end,
                             InputIterator values_begin,
                             InputIterator values_end,
                             OutputIterator output,
                             StrictWeakOrdering comp)
{
  detail::bsf search;
  return detail::binary_search(exec, begin, end, values_begin, values_end, output, comp, search);
}
// equal_range(exec, first, last, value): re-dispatches through
// thrust::equal_range with thrust::less<LessThanComparable> as the ordering.
template <typename DerivedPolicy, typename ForwardIterator, typename LessThanComparable>
thrust::pair<ForwardIterator,ForwardIterator>
equal_range(thrust::execution_policy<DerivedPolicy> &exec,
            ForwardIterator first,
            ForwardIterator last,
            const LessThanComparable &value)
{
  thrust::less<LessThanComparable> default_ordering;
  return thrust::equal_range(exec, first, last, value, default_ordering);
}
// equal_range(exec, first, last, value, comp): pairs the results of a
// lower_bound search and an upper_bound search (issued in that order) over
// the full range [first, last).
template <typename DerivedPolicy, typename ForwardIterator, typename T, typename StrictWeakOrdering>
thrust::pair<ForwardIterator,ForwardIterator>
equal_range(thrust::execution_policy<DerivedPolicy> &exec,
            ForwardIterator first,
            ForwardIterator last,
            const T &value,
            StrictWeakOrdering comp)
{
  typedef thrust::pair<ForwardIterator,ForwardIterator> range_type;

  ForwardIterator range_begin = thrust::lower_bound(exec, first, last, value, comp);
  ForwardIterator range_end   = thrust::upper_bound(exec, first, last, value, comp);

  return range_type(range_begin, range_end);
}
} // end namespace generic
} // end namespace detail
} // end namespace system
} // end namespace thrust