/* * Copyright 2008-2012 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include namespace thrust { namespace system { namespace tbb { namespace detail { namespace merge_detail { template struct range { InputIterator1 first1, last1; InputIterator2 first2, last2; OutputIterator result; StrictWeakOrdering comp; size_t grain_size; range(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result, StrictWeakOrdering comp, size_t grain_size = 1024) : first1(first1), last1(last1), first2(first2), last2(last2), result(result), comp(comp), grain_size(grain_size) {} range(range& r, ::tbb::split) : first1(r.first1), last1(r.last1), first2(r.first2), last2(r.last2), result(r.result), comp(r.comp), grain_size(r.grain_size) { // we can assume n1 and n2 are not both 0 size_t n1 = thrust::distance(first1, last1); size_t n2 = thrust::distance(first2, last2); InputIterator1 mid1 = first1; InputIterator2 mid2 = first2; if (n1 > n2) { mid1 += n1 / 2; mid2 = thrust::system::detail::internal::scalar::lower_bound(first2, last2, raw_reference_cast(*mid1), comp); } else { mid2 += n2 / 2; mid1 = thrust::system::detail::internal::scalar::upper_bound(first1, last1, raw_reference_cast(*mid2), comp); } // set first range to [first1, mid1), [first2, mid2), result r.last1 = mid1; r.last2 = mid2; // set second range to [mid1, last1), [mid2, last2), result + (mid1 - first1) + (mid2 - first2) first1 = mid1; first2 = mid2; result += thrust::distance(r.first1, mid1) + thrust::distance(r.first2, mid2); } bool empty(void) const { return (first1 == last1) && (first2 == last2); } bool is_divisible(void) const { return static_cast(thrust::distance(first1, last1) + thrust::distance(first2, last2)) > grain_size; } }; struct body { template void operator()(Range& r) const { thrust::system::detail::internal::scalar::merge (r.first1, r.last1, r.first2, r.last2, r.result, r.comp); } }; } // end namespace merge_detail namespace merge_by_key_detail { template struct range { InputIterator1 keys_first1, keys_last1; InputIterator2 keys_first2, keys_last2; InputIterator3 values_first1; InputIterator4 values_first2; OutputIterator1 keys_result; OutputIterator2 values_result; StrictWeakOrdering comp; size_t grain_size; range(InputIterator1 keys_first1, InputIterator1 keys_last1, InputIterator2 keys_first2, InputIterator2 keys_last2, InputIterator3 values_first1, InputIterator4 values_first2, OutputIterator1 keys_result, OutputIterator2 values_result, StrictWeakOrdering comp, size_t grain_size = 1024) : keys_first1(keys_first1), keys_last1(keys_last1), keys_first2(keys_first2), keys_last2(keys_last2), values_first1(values_first1), values_first2(values_first2), keys_result(keys_result), values_result(values_result), comp(comp), grain_size(grain_size) {} range(range& r, ::tbb::split) : keys_first1(r.keys_first1), keys_last1(r.keys_last1), keys_first2(r.keys_first2), keys_last2(r.keys_last2), values_first1(r.values_first1), values_first2(r.values_first2), keys_result(r.keys_result), values_result(r.values_result), comp(r.comp), grain_size(r.grain_size) { // we can assume n1 and n2 are not both 0 size_t n1 = thrust::distance(keys_first1, keys_last1); size_t n2 = thrust::distance(keys_first2, keys_last2); InputIterator1 mid1 = keys_first1; InputIterator2 mid2 = keys_first2; if (n1 > n2) { mid1 += n1 / 2; mid2 = thrust::system::detail::internal::scalar::lower_bound(keys_first2, keys_last2, raw_reference_cast(*mid1), comp); } else { mid2 += n2 / 2; mid1 = thrust::system::detail::internal::scalar::upper_bound(keys_first1, keys_last1, raw_reference_cast(*mid2), comp); } // set first range to [keys_first1, mid1), [keys_first2, mid2), keys_result, values_result r.keys_last1 = mid1; r.keys_last2 = mid2; // set second range to [mid1, keys_last1), [mid2, keys_last2), keys_result + (mid1 - keys_first1) + (mid2 - keys_first2), values_result + (mid1 - keys_first1) + (mid2 - keys_first2) keys_first1 = mid1; keys_first2 = mid2; values_first1 += thrust::distance(r.keys_first1, mid1); values_first2 += thrust::distance(r.keys_first2, mid2); keys_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2); values_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2); } bool empty(void) const { return (keys_first1 == keys_last1) && (keys_first2 == keys_last2); } bool is_divisible(void) const { return static_cast(thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2)) > grain_size; } }; struct body { template void operator()(Range& r) const { thrust::system::detail::internal::scalar::merge_by_key (r.keys_first1, r.keys_last1, r.keys_first2, r.keys_last2, r.values_first1, r.values_first2, r.keys_result, r.values_result, r.comp); } }; } // end namespace merge_by_key_detail template OutputIterator merge(execution_policy &exec, InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result, StrictWeakOrdering comp) { typedef typename merge_detail::range Range; typedef merge_detail::body Body; Range range(first1, last1, first2, last2, result, comp); Body body; ::tbb::parallel_for(range, body); thrust::advance(result, thrust::distance(first1, last1) + thrust::distance(first2, last2)); return result; } // end merge() template thrust::pair merge_by_key(execution_policy &exec, InputIterator1 keys_first1, InputIterator1 keys_last1, InputIterator2 keys_first2, InputIterator2 keys_last2, InputIterator3 values_first3, InputIterator4 values_first4, OutputIterator1 keys_result, OutputIterator2 values_result, StrictWeakOrdering comp) { typedef typename merge_by_key_detail::range Range; typedef merge_by_key_detail::body Body; Range range(keys_first1, keys_last1, keys_first2, keys_last2, values_first3, values_first4, keys_result, values_result, comp); Body body; ::tbb::parallel_for(range, body); thrust::advance(keys_result, thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2)); thrust::advance(values_result, thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2)); return thrust::make_pair(keys_result,values_result); } } // end namespace detail } // end namespace tbb } // end namespace system } // end namespace thrust