/* * Copyright 2008-2012 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include namespace thrust { namespace system { namespace tbb { namespace detail { namespace sort_detail { // TODO tune this based on data type and comp const static int threshold = 128 * 1024; template void merge_sort(execution_policy &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, StrictWeakOrdering comp, bool inplace); template struct merge_sort_closure { execution_policy &exec; Iterator1 first1, last1; Iterator2 first2; StrictWeakOrdering comp; bool inplace; merge_sort_closure(execution_policy &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, StrictWeakOrdering comp, bool inplace) : exec(exec), first1(first1), last1(last1), first2(first2), comp(comp), inplace(inplace) {} void operator()(void) const { merge_sort(exec, first1, last1, first2, comp, inplace); } }; template void merge_sort(execution_policy &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, StrictWeakOrdering comp, bool inplace) { typedef typename thrust::iterator_difference::type difference_type; difference_type n = thrust::distance(first1, last1); if (n < threshold) { thrust::system::detail::internal::scalar::stable_sort(first1, last1, comp); if (!inplace) thrust::system::detail::internal::scalar::copy(first1, last1, first2); return; } Iterator1 mid1 = first1 + (n / 2); Iterator2 mid2 = first2 + (n / 2); Iterator2 last2 = first2 + n; typedef merge_sort_closure Closure; Closure left (exec, first1, mid1, first2, comp, !inplace); Closure right(exec, mid1, last1, mid2, comp, !inplace); ::tbb::parallel_invoke(left, right); if (inplace) thrust::merge(exec, first2, mid2, mid2, last2, first1, comp); else thrust::merge(exec, first1, mid1, mid1, last1, first2, comp); } } // end namespace sort_detail namespace sort_by_key_detail { // TODO tune this based on data type and comp const static int threshold = 128 * 1024; template void merge_sort_by_key(execution_policy &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterator3 first3, Iterator4 first4, StrictWeakOrdering comp, bool inplace); template struct merge_sort_by_key_closure { execution_policy &exec; Iterator1 first1, last1; Iterator2 first2; Iterator3 first3; Iterator4 first4; StrictWeakOrdering comp; bool inplace; merge_sort_by_key_closure(execution_policy &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterator3 first3, Iterator4 first4, StrictWeakOrdering comp, bool inplace) : exec(exec), first1(first1), last1(last1), first2(first2), first3(first3), first4(first4), comp(comp), inplace(inplace) {} void operator()(void) const { merge_sort_by_key(exec, first1, last1, first2, first3, first4, comp, inplace); } }; template void merge_sort_by_key(execution_policy &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterator3 first3, Iterator4 first4, StrictWeakOrdering comp, bool inplace) { typedef typename thrust::iterator_difference::type difference_type; difference_type n = thrust::distance(first1, last1); Iterator1 mid1 = first1 + (n / 2); Iterator2 mid2 = first2 + (n / 2); Iterator3 mid3 = first3 + (n / 2); Iterator4 mid4 = first4 + (n / 2); Iterator2 last2 = first2 + n; Iterator3 last3 = first3 + n; if (n < threshold) { thrust::system::detail::internal::scalar::stable_sort_by_key(first1, last1, first2, comp); if (!inplace) { thrust::system::detail::internal::scalar::copy(first1, last1, first3); thrust::system::detail::internal::scalar::copy(first2, last2, first4); } return; } typedef merge_sort_by_key_closure Closure; Closure left (exec, first1, mid1, first2, first3, first4, comp, !inplace); Closure right(exec, mid1, last1, mid2, mid3, mid4, comp, !inplace); ::tbb::parallel_invoke(left, right); if(inplace) { thrust::merge_by_key(exec, first3, mid3, mid3, last3, first4, mid4, first1, first2, comp); } else { thrust::merge_by_key(exec, first1, mid1, mid1, last1, first2, mid2, first3, first4, comp); } } } // end namespace sort_detail template void stable_sort(execution_policy &exec, RandomAccessIterator first, RandomAccessIterator last, StrictWeakOrdering comp) { typedef typename thrust::iterator_value::type key_type; thrust::detail::temporary_array temp(exec, first, last); sort_detail::merge_sort(exec, first, last, temp.begin(), comp, true); } template void stable_sort_by_key(execution_policy &exec, RandomAccessIterator1 first1, RandomAccessIterator1 last1, RandomAccessIterator2 first2, StrictWeakOrdering comp) { typedef typename thrust::iterator_value::type key_type; typedef typename thrust::iterator_value::type val_type; RandomAccessIterator2 last2 = first2 + thrust::distance(first1, last1); thrust::detail::temporary_array temp1(exec, first1, last1); thrust::detail::temporary_array temp2(exec, first2, last2); sort_by_key_detail::merge_sort_by_key(exec, first1, last1, first2, temp1.begin(), temp2.begin(), comp, true); } } // end namespace detail } // end namespace tbb } // end namespace system } // end namespace thrust