You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
251 lines
8.1 KiB
251 lines
8.1 KiB
/* |
|
* Copyright 2008-2012 NVIDIA Corporation |
|
* |
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
* you may not use this file except in compliance with the License. |
|
* You may obtain a copy of the License at |
|
* |
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
* |
|
* Unless required by applicable law or agreed to in writing, software |
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
* See the License for the specific language governing permissions and |
|
* limitations under the License. |
|
*/ |
|
|
|
#include <thrust/detail/config.h> |
|
#include <thrust/detail/temporary_array.h> |
|
#include <thrust/detail/copy.h> |
|
#include <thrust/system/detail/internal/scalar/sort.h> |
|
#include <thrust/iterator/iterator_traits.h> |
|
#include <thrust/distance.h> |
|
#include <thrust/merge.h> |
|
#include <tbb/parallel_invoke.h> |
|
|
|
namespace thrust |
|
{ |
|
namespace system |
|
{ |
|
namespace tbb |
|
{ |
|
namespace detail |
|
{ |
|
namespace sort_detail |
|
{ |
|
|
|
// TODO tune this based on data type and comp |
|
const static int threshold = 128 * 1024; |
|
|
|
template <typename DerivedPolicy, typename Iterator1, typename Iterator2, typename StrictWeakOrdering> |
|
void merge_sort(execution_policy<DerivedPolicy> &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, StrictWeakOrdering comp, bool inplace); |
|
|
|
template <typename DerivedPolicy, typename Iterator1, typename Iterator2, typename StrictWeakOrdering> |
|
struct merge_sort_closure |
|
{ |
|
execution_policy<DerivedPolicy> &exec; |
|
Iterator1 first1, last1; |
|
Iterator2 first2; |
|
StrictWeakOrdering comp; |
|
bool inplace; |
|
|
|
merge_sort_closure(execution_policy<DerivedPolicy> &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, StrictWeakOrdering comp, bool inplace) |
|
: exec(exec), first1(first1), last1(last1), first2(first2), comp(comp), inplace(inplace) |
|
{} |
|
|
|
void operator()(void) const |
|
{ |
|
merge_sort(exec, first1, last1, first2, comp, inplace); |
|
} |
|
}; |
|
|
|
|
|
template <typename DerivedPolicy, typename Iterator1, typename Iterator2, typename StrictWeakOrdering> |
|
void merge_sort(execution_policy<DerivedPolicy> &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, StrictWeakOrdering comp, bool inplace) |
|
{ |
|
typedef typename thrust::iterator_difference<Iterator1>::type difference_type; |
|
|
|
difference_type n = thrust::distance(first1, last1); |
|
|
|
if (n < threshold) |
|
{ |
|
thrust::system::detail::internal::scalar::stable_sort(first1, last1, comp); |
|
|
|
if (!inplace) |
|
thrust::system::detail::internal::scalar::copy(first1, last1, first2); |
|
|
|
return; |
|
} |
|
|
|
Iterator1 mid1 = first1 + (n / 2); |
|
Iterator2 mid2 = first2 + (n / 2); |
|
Iterator2 last2 = first2 + n; |
|
|
|
typedef merge_sort_closure<DerivedPolicy,Iterator1,Iterator2,StrictWeakOrdering> Closure; |
|
|
|
Closure left (exec, first1, mid1, first2, comp, !inplace); |
|
Closure right(exec, mid1, last1, mid2, comp, !inplace); |
|
|
|
::tbb::parallel_invoke(left, right); |
|
|
|
if (inplace) thrust::merge(exec, first2, mid2, mid2, last2, first1, comp); |
|
else thrust::merge(exec, first1, mid1, mid1, last1, first2, comp); |
|
} |
|
|
|
} // end namespace sort_detail |
|
|
|
|
|
namespace sort_by_key_detail |
|
{ |
|
|
|
// TODO tune this based on data type and comp |
|
const static int threshold = 128 * 1024; |
|
|
|
template <typename DerivedPolicy, |
|
typename Iterator1, |
|
typename Iterator2, |
|
typename Iterator3, |
|
typename Iterator4, |
|
typename StrictWeakOrdering> |
|
void merge_sort_by_key(execution_policy<DerivedPolicy> &exec, |
|
Iterator1 first1, |
|
Iterator1 last1, |
|
Iterator2 first2, |
|
Iterator3 first3, |
|
Iterator4 first4, |
|
StrictWeakOrdering comp, |
|
bool inplace); |
|
|
|
template <typename DerivedPolicy, |
|
typename Iterator1, |
|
typename Iterator2, |
|
typename Iterator3, |
|
typename Iterator4, |
|
typename StrictWeakOrdering> |
|
struct merge_sort_by_key_closure |
|
{ |
|
execution_policy<DerivedPolicy> &exec; |
|
Iterator1 first1, last1; |
|
Iterator2 first2; |
|
Iterator3 first3; |
|
Iterator4 first4; |
|
StrictWeakOrdering comp; |
|
bool inplace; |
|
|
|
merge_sort_by_key_closure(execution_policy<DerivedPolicy> &exec, |
|
Iterator1 first1, |
|
Iterator1 last1, |
|
Iterator2 first2, |
|
Iterator3 first3, |
|
Iterator4 first4, |
|
StrictWeakOrdering comp, |
|
bool inplace) |
|
: exec(exec), first1(first1), last1(last1), first2(first2), first3(first3), first4(first4), comp(comp), inplace(inplace) |
|
{} |
|
|
|
void operator()(void) const |
|
{ |
|
merge_sort_by_key(exec, first1, last1, first2, first3, first4, comp, inplace); |
|
} |
|
}; |
|
|
|
|
|
template <typename DerivedPolicy, |
|
typename Iterator1, |
|
typename Iterator2, |
|
typename Iterator3, |
|
typename Iterator4, |
|
typename StrictWeakOrdering> |
|
void merge_sort_by_key(execution_policy<DerivedPolicy> &exec, |
|
Iterator1 first1, |
|
Iterator1 last1, |
|
Iterator2 first2, |
|
Iterator3 first3, |
|
Iterator4 first4, |
|
StrictWeakOrdering comp, |
|
bool inplace) |
|
{ |
|
typedef typename thrust::iterator_difference<Iterator1>::type difference_type; |
|
|
|
difference_type n = thrust::distance(first1, last1); |
|
|
|
Iterator1 mid1 = first1 + (n / 2); |
|
Iterator2 mid2 = first2 + (n / 2); |
|
Iterator3 mid3 = first3 + (n / 2); |
|
Iterator4 mid4 = first4 + (n / 2); |
|
Iterator2 last2 = first2 + n; |
|
Iterator3 last3 = first3 + n; |
|
|
|
if (n < threshold) |
|
{ |
|
thrust::system::detail::internal::scalar::stable_sort_by_key(first1, last1, first2, comp); |
|
|
|
if (!inplace) |
|
{ |
|
thrust::system::detail::internal::scalar::copy(first1, last1, first3); |
|
thrust::system::detail::internal::scalar::copy(first2, last2, first4); |
|
} |
|
|
|
return; |
|
} |
|
|
|
typedef merge_sort_by_key_closure<DerivedPolicy,Iterator1,Iterator2,Iterator3,Iterator4,StrictWeakOrdering> Closure; |
|
|
|
Closure left (exec, first1, mid1, first2, first3, first4, comp, !inplace); |
|
Closure right(exec, mid1, last1, mid2, mid3, mid4, comp, !inplace); |
|
|
|
::tbb::parallel_invoke(left, right); |
|
|
|
if(inplace) |
|
{ |
|
thrust::merge_by_key(exec, first3, mid3, mid3, last3, first4, mid4, first1, first2, comp); |
|
} |
|
else |
|
{ |
|
thrust::merge_by_key(exec, first1, mid1, mid1, last1, first2, mid2, first3, first4, comp); |
|
} |
|
} |
|
|
|
} // end namespace sort_detail |
|
|
|
template<typename DerivedPolicy, |
|
typename RandomAccessIterator, |
|
typename StrictWeakOrdering> |
|
void stable_sort(execution_policy<DerivedPolicy> &exec, |
|
RandomAccessIterator first, |
|
RandomAccessIterator last, |
|
StrictWeakOrdering comp) |
|
{ |
|
typedef typename thrust::iterator_value<RandomAccessIterator>::type key_type; |
|
|
|
thrust::detail::temporary_array<key_type, DerivedPolicy> temp(exec, first, last); |
|
|
|
sort_detail::merge_sort(exec, first, last, temp.begin(), comp, true); |
|
} |
|
|
|
template<typename DerivedPolicy, |
|
typename RandomAccessIterator1, |
|
typename RandomAccessIterator2, |
|
typename StrictWeakOrdering> |
|
void stable_sort_by_key(execution_policy<DerivedPolicy> &exec, |
|
RandomAccessIterator1 first1, |
|
RandomAccessIterator1 last1, |
|
RandomAccessIterator2 first2, |
|
StrictWeakOrdering comp) |
|
{ |
|
typedef typename thrust::iterator_value<RandomAccessIterator1>::type key_type; |
|
typedef typename thrust::iterator_value<RandomAccessIterator2>::type val_type; |
|
|
|
RandomAccessIterator2 last2 = first2 + thrust::distance(first1, last1); |
|
|
|
thrust::detail::temporary_array<key_type, DerivedPolicy> temp1(exec, first1, last1); |
|
thrust::detail::temporary_array<val_type, DerivedPolicy> temp2(exec, first2, last2); |
|
|
|
sort_by_key_detail::merge_sort_by_key(exec, first1, last1, first2, temp1.begin(), temp2.begin(), comp, true); |
|
} |
|
|
|
} // end namespace detail |
|
} // end namespace tbb |
|
} // end namespace system |
|
} // end namespace thrust |
|
|
|
|