You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
141 lines
4.6 KiB
141 lines
4.6 KiB
/* |
|
* Copyright 2008-2012 NVIDIA Corporation |
|
* |
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
* you may not use this file except in compliance with the License. |
|
* You may obtain a copy of the License at |
|
* |
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
* |
|
* Unless required by applicable law or agreed to in writing, software |
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
* See the License for the specific language governing permissions and |
|
* limitations under the License. |
|
*/ |
|
|
|
#include <thrust/detail/config.h> |
|
#include <thrust/find.h> |
|
#include <thrust/reduce.h> |
|
|
|
#include <thrust/tuple.h> |
|
#include <thrust/extrema.h> |
|
#include <thrust/iterator/counting_iterator.h> |
|
#include <thrust/iterator/transform_iterator.h> |
|
#include <thrust/iterator/zip_iterator.h> |
|
#include <thrust/detail/internal_functional.h> |
|
|
|
|
|
// Contributed by Erich Elsen |
|
|
|
namespace thrust |
|
{ |
|
namespace system |
|
{ |
|
namespace detail |
|
{ |
|
namespace generic |
|
{ |
|
|
|
|
|
template<typename DerivedPolicy, typename InputIterator, typename T> |
|
InputIterator find(thrust::execution_policy<DerivedPolicy> &exec, |
|
InputIterator first, |
|
InputIterator last, |
|
const T& value) |
|
{ |
|
// XXX consider a placeholder expression here |
|
return thrust::find_if(exec, first, last, thrust::detail::equal_to_value<T>(value)); |
|
} // end find() |
|
|
|
|
|
template<typename TupleType> |
|
struct find_if_functor |
|
{ |
|
__host__ __device__ |
|
TupleType operator()(const TupleType& lhs, const TupleType& rhs) const |
|
{ |
|
// select the smallest index among true results |
|
if (thrust::get<0>(lhs) && thrust::get<0>(rhs)) |
|
return TupleType(true, (thrust::min)(thrust::get<1>(lhs), thrust::get<1>(rhs))); |
|
else if (thrust::get<0>(lhs)) |
|
return lhs; |
|
else |
|
return rhs; |
|
} |
|
}; |
|
|
|
|
|
template<typename DerivedPolicy, typename InputIterator, typename Predicate> |
|
InputIterator find_if(thrust::execution_policy<DerivedPolicy> &exec, |
|
InputIterator first, |
|
InputIterator last, |
|
Predicate pred) |
|
{ |
|
typedef typename thrust::iterator_traits<InputIterator>::difference_type difference_type; |
|
typedef typename thrust::tuple<bool,difference_type> result_type; |
|
|
|
// empty sequence |
|
if (first == last) |
|
return last; |
|
|
|
const difference_type n = thrust::distance(first, last); |
|
|
|
// this implementation breaks up the sequence into separate intervals |
|
// in an attempt to early-out as soon as a value is found |
|
|
|
// TODO incorporate sizeof(InputType) into interval_threshold and round to multiple of 32 |
|
const difference_type interval_threshold = 1 << 20; |
|
const difference_type interval_size = (std::min)(interval_threshold, n); |
|
|
|
// force transform_iterator output to bool |
|
typedef thrust::transform_iterator<Predicate, InputIterator, bool> XfrmIterator; |
|
typedef thrust::tuple<XfrmIterator, thrust::counting_iterator<difference_type> > IteratorTuple; |
|
typedef thrust::zip_iterator<IteratorTuple> ZipIterator; |
|
|
|
IteratorTuple iter_tuple = thrust::make_tuple(XfrmIterator(first, pred), |
|
thrust::counting_iterator<difference_type>(0)); |
|
|
|
ZipIterator begin = thrust::make_zip_iterator(iter_tuple); |
|
ZipIterator end = begin + n; |
|
|
|
for(ZipIterator interval_begin = begin; interval_begin < end; interval_begin += interval_size) |
|
{ |
|
ZipIterator interval_end = interval_begin + interval_size; |
|
if(end < interval_end) |
|
{ |
|
interval_end = end; |
|
} // end if |
|
|
|
result_type result = thrust::reduce(exec, |
|
interval_begin, interval_end, |
|
result_type(false,interval_end - begin), |
|
find_if_functor<result_type>()); |
|
|
|
// see if we found something |
|
if (thrust::get<0>(result)) |
|
{ |
|
return first + thrust::get<1>(result); |
|
} |
|
} |
|
|
|
//nothing was found if we reach here... |
|
return first + n; |
|
} |
|
|
|
|
|
template<typename DerivedPolicy, typename InputIterator, typename Predicate> |
|
InputIterator find_if_not(thrust::execution_policy<DerivedPolicy> &exec, |
|
InputIterator first, |
|
InputIterator last, |
|
Predicate pred) |
|
{ |
|
return thrust::find_if(exec, first, last, thrust::detail::not1(pred)); |
|
} // end find() |
|
|
|
|
|
} // end namespace generic |
|
} // end namespace detail |
|
} // end namespace system |
|
} // end namespace thrust |
|
|
|
|