142 lines
4.6 KiB
C++
142 lines
4.6 KiB
C++
/*
|
|
* Copyright 2008-2012 NVIDIA Corporation
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include <thrust/detail/config.h>
|
|
#include <thrust/find.h>
|
|
#include <thrust/reduce.h>
|
|
|
|
#include <thrust/tuple.h>
|
|
#include <thrust/extrema.h>
|
|
#include <thrust/iterator/counting_iterator.h>
|
|
#include <thrust/iterator/transform_iterator.h>
|
|
#include <thrust/iterator/zip_iterator.h>
|
|
#include <thrust/detail/internal_functional.h>
|
|
|
|
|
|
// Contributed by Erich Elsen
|
|
|
|
namespace thrust
|
|
{
|
|
namespace system
|
|
{
|
|
namespace detail
|
|
{
|
|
namespace generic
|
|
{
|
|
|
|
|
|
template<typename DerivedPolicy, typename InputIterator, typename T>
|
|
InputIterator find(thrust::execution_policy<DerivedPolicy> &exec,
|
|
InputIterator first,
|
|
InputIterator last,
|
|
const T& value)
|
|
{
|
|
// XXX consider a placeholder expression here
|
|
return thrust::find_if(exec, first, last, thrust::detail::equal_to_value<T>(value));
|
|
} // end find()
|
|
|
|
|
|
template<typename TupleType>
|
|
struct find_if_functor
|
|
{
|
|
__host__ __device__
|
|
TupleType operator()(const TupleType& lhs, const TupleType& rhs) const
|
|
{
|
|
// select the smallest index among true results
|
|
if (thrust::get<0>(lhs) && thrust::get<0>(rhs))
|
|
return TupleType(true, (thrust::min)(thrust::get<1>(lhs), thrust::get<1>(rhs)));
|
|
else if (thrust::get<0>(lhs))
|
|
return lhs;
|
|
else
|
|
return rhs;
|
|
}
|
|
};
|
|
|
|
|
|
template<typename DerivedPolicy, typename InputIterator, typename Predicate>
|
|
InputIterator find_if(thrust::execution_policy<DerivedPolicy> &exec,
|
|
InputIterator first,
|
|
InputIterator last,
|
|
Predicate pred)
|
|
{
|
|
typedef typename thrust::iterator_traits<InputIterator>::difference_type difference_type;
|
|
typedef typename thrust::tuple<bool,difference_type> result_type;
|
|
|
|
// empty sequence
|
|
if (first == last)
|
|
return last;
|
|
|
|
const difference_type n = thrust::distance(first, last);
|
|
|
|
// this implementation breaks up the sequence into separate intervals
|
|
// in an attempt to early-out as soon as a value is found
|
|
|
|
// TODO incorporate sizeof(InputType) into interval_threshold and round to multiple of 32
|
|
const difference_type interval_threshold = 1 << 20;
|
|
const difference_type interval_size = (std::min)(interval_threshold, n);
|
|
|
|
// force transform_iterator output to bool
|
|
typedef thrust::transform_iterator<Predicate, InputIterator, bool> XfrmIterator;
|
|
typedef thrust::tuple<XfrmIterator, thrust::counting_iterator<difference_type> > IteratorTuple;
|
|
typedef thrust::zip_iterator<IteratorTuple> ZipIterator;
|
|
|
|
IteratorTuple iter_tuple = thrust::make_tuple(XfrmIterator(first, pred),
|
|
thrust::counting_iterator<difference_type>(0));
|
|
|
|
ZipIterator begin = thrust::make_zip_iterator(iter_tuple);
|
|
ZipIterator end = begin + n;
|
|
|
|
for(ZipIterator interval_begin = begin; interval_begin < end; interval_begin += interval_size)
|
|
{
|
|
ZipIterator interval_end = interval_begin + interval_size;
|
|
if(end < interval_end)
|
|
{
|
|
interval_end = end;
|
|
} // end if
|
|
|
|
result_type result = thrust::reduce(exec,
|
|
interval_begin, interval_end,
|
|
result_type(false,interval_end - begin),
|
|
find_if_functor<result_type>());
|
|
|
|
// see if we found something
|
|
if (thrust::get<0>(result))
|
|
{
|
|
return first + thrust::get<1>(result);
|
|
}
|
|
}
|
|
|
|
//nothing was found if we reach here...
|
|
return first + n;
|
|
}
|
|
|
|
|
|
template<typename DerivedPolicy, typename InputIterator, typename Predicate>
|
|
InputIterator find_if_not(thrust::execution_policy<DerivedPolicy> &exec,
|
|
InputIterator first,
|
|
InputIterator last,
|
|
Predicate pred)
|
|
{
|
|
return thrust::find_if(exec, first, last, thrust::detail::not1(pred));
|
|
} // end find()
|
|
|
|
|
|
} // end namespace generic
|
|
} // end namespace detail
|
|
} // end namespace system
|
|
} // end namespace thrust
|
|
|