286 lines
9.1 KiB
Plaintext
Raw Normal View History

2014-03-18 22:17:40 +01:00
/*
* Copyright 2008-2012 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/iterator/iterator_traits.h>
#include <thrust/detail/temporary_array.h>
#include <thrust/system/tbb/detail/execution_policy.h>
#include <thrust/system/detail/internal/scalar/merge.h>
#include <thrust/system/detail/internal/scalar/binary_search.h>
#include <tbb/parallel_for.h>
namespace thrust
{
namespace system
{
namespace tbb
{
namespace detail
{
namespace merge_detail
{
// A splittable TBB range describing the merge of two sorted input subranges
// [first1, last1) and [first2, last2) into the output beginning at result.
// tbb::parallel_for splits it recursively (see the splitting constructor) and
// hands each leaf range to a body that merges it sequentially.
template<typename InputIterator1,
         typename InputIterator2,
         typename OutputIterator,
         typename StrictWeakOrdering>
struct range
{
  InputIterator1 first1, last1;   // current subrange of the first input
  InputIterator2 first2, last2;   // current subrange of the second input
  OutputIterator result;          // where this subrange's merged output begins
  StrictWeakOrdering comp;        // ordering used to merge and to pick split points
  size_t grain_size;              // combined size at or below which splitting stops

  range(InputIterator1 first1, InputIterator1 last1,
        InputIterator2 first2, InputIterator2 last2,
        OutputIterator result,
        StrictWeakOrdering comp,
        size_t grain_size = 1024)
    : first1(first1), last1(last1),
      first2(first2), last2(last2),
      result(result), comp(comp), grain_size(grain_size)
  {}

  // Splitting constructor: r keeps the left half and *this becomes the right half.
  // The pivot is the midpoint of the longer input. The matching position in the
  // other input is found by binary search, so every element of the left half
  // precedes every element of the right half in the merged output.
  // Note: the `+=` arithmetic on the iterators requires random-access iterators.
  range(range& r, ::tbb::split)
    : first1(r.first1), last1(r.last1),
      first2(r.first2), last2(r.last2),
      result(r.result), comp(r.comp), grain_size(r.grain_size)
  {
    // is_divisible() returned true, so n1 and n2 are not both 0
    size_t n1 = thrust::distance(first1, last1);
    size_t n2 = thrust::distance(first2, last2);

    InputIterator1 mid1 = first1;
    InputIterator2 mid2 = first2;

    if (n1 > n2)
    {
      mid1 += n1 / 2;
      // lower_bound: elements of the second input equal to the pivot go to the
      // right half, i.e. after the pivot taken from the first input
      mid2 = thrust::system::detail::internal::scalar::lower_bound(first2, last2, raw_reference_cast(*mid1), comp);
    }
    else
    {
      mid2 += n2 / 2;
      // upper_bound: elements of the first input equal to the pivot stay in the
      // left half, i.e. ahead of the pivot taken from the second input
      mid1 = thrust::system::detail::internal::scalar::upper_bound(first1, last1, raw_reference_cast(*mid2), comp);
    }

    // set first range to [first1, mid1), [first2, mid2), result
    r.last1 = mid1;
    r.last2 = mid2;

    // set second range to [mid1, last1), [mid2, last2), result + (mid1 - first1) + (mid2 - first2)
    first1 = mid1;
    first2 = mid2;
    result += thrust::distance(r.first1, mid1) + thrust::distance(r.first2, mid2);
  }

  // True when both input subranges are exhausted.
  bool empty(void) const
  {
    return (first1 == last1) && (first2 == last2);
  }

  // True when the range should be split further.
  bool is_divisible(void) const
  {
    // The splitting constructor only guarantees two non-empty halves when the
    // combined size is at least 3. For example, with n1 == n2 == 1 and
    // comp(*first2, *first1), the upper_bound branch yields an empty left half
    // and a right half identical to the original, so splitting would never make
    // progress. Treat such tiny ranges as indivisible whatever grain_size is
    // (the default of 1024 is unaffected).
    const size_t min_grain_size = 2;
    const size_t threshold = (grain_size < min_grain_size) ? min_grain_size : grain_size;

    return static_cast<size_t>(thrust::distance(first1, last1) + thrust::distance(first2, last2)) > threshold;
  }
};
struct body
{
template <typename Range>
void operator()(Range& r) const
{
thrust::system::detail::internal::scalar::merge
(r.first1, r.last1,
r.first2, r.last2,
r.result,
r.comp);
}
};
} // end namespace merge_detail
namespace merge_by_key_detail
{
// A splittable TBB range describing the merge-by-key of two sorted key
// subranges [keys_first1, keys_last1) and [keys_first2, keys_last2), with their
// associated values starting at values_first1 / values_first2. Merged keys are
// written starting at keys_result and merged values starting at values_result.
// tbb::parallel_for splits it recursively and hands each leaf to a body that
// merges it sequentially.
template<typename InputIterator1,
         typename InputIterator2,
         typename InputIterator3,
         typename InputIterator4,
         typename OutputIterator1,
         typename OutputIterator2,
         typename StrictWeakOrdering>
struct range
{
  InputIterator1 keys_first1, keys_last1;   // current key subrange of the first input
  InputIterator2 keys_first2, keys_last2;   // current key subrange of the second input
  InputIterator3 values_first1;             // values paired with keys_first1
  InputIterator4 values_first2;             // values paired with keys_first2
  OutputIterator1 keys_result;              // where this subrange's merged keys begin
  OutputIterator2 values_result;            // where this subrange's merged values begin
  StrictWeakOrdering comp;                  // ordering on keys
  size_t grain_size;                        // combined size at or below which splitting stops

  range(InputIterator1 keys_first1, InputIterator1 keys_last1,
        InputIterator2 keys_first2, InputIterator2 keys_last2,
        InputIterator3 values_first1,
        InputIterator4 values_first2,
        OutputIterator1 keys_result,
        OutputIterator2 values_result,
        StrictWeakOrdering comp,
        size_t grain_size = 1024)
    : keys_first1(keys_first1), keys_last1(keys_last1),
      keys_first2(keys_first2), keys_last2(keys_last2),
      values_first1(values_first1),
      values_first2(values_first2),
      keys_result(keys_result), values_result(values_result),
      comp(comp), grain_size(grain_size)
  {}

  // Splitting constructor: r keeps the left half and *this becomes the right half.
  // The pivot is the midpoint of the longer key input. The matching position in
  // the other key input is found by binary search. The value and output
  // iterators of the right half are advanced by the same offsets as the keys.
  // Note: the `+=` arithmetic on the iterators requires random-access iterators.
  range(range& r, ::tbb::split)
    : keys_first1(r.keys_first1), keys_last1(r.keys_last1),
      keys_first2(r.keys_first2), keys_last2(r.keys_last2),
      values_first1(r.values_first1),
      values_first2(r.values_first2),
      keys_result(r.keys_result), values_result(r.values_result),
      comp(r.comp), grain_size(r.grain_size)
  {
    // is_divisible() returned true, so n1 and n2 are not both 0
    size_t n1 = thrust::distance(keys_first1, keys_last1);
    size_t n2 = thrust::distance(keys_first2, keys_last2);

    InputIterator1 mid1 = keys_first1;
    InputIterator2 mid2 = keys_first2;

    if (n1 > n2)
    {
      mid1 += n1 / 2;
      // lower_bound: keys of the second input equal to the pivot go to the
      // right half, i.e. after the pivot taken from the first input
      mid2 = thrust::system::detail::internal::scalar::lower_bound(keys_first2, keys_last2, raw_reference_cast(*mid1), comp);
    }
    else
    {
      mid2 += n2 / 2;
      // upper_bound: keys of the first input equal to the pivot stay in the
      // left half, i.e. ahead of the pivot taken from the second input
      mid1 = thrust::system::detail::internal::scalar::upper_bound(keys_first1, keys_last1, raw_reference_cast(*mid2), comp);
    }

    // set first range to [keys_first1, mid1), [keys_first2, mid2), keys_result, values_result
    r.keys_last1 = mid1;
    r.keys_last2 = mid2;

    // set second range to [mid1, keys_last1), [mid2, keys_last2), keys_result + (mid1 - keys_first1) + (mid2 - keys_first2), values_result + (mid1 - keys_first1) + (mid2 - keys_first2)
    keys_first1 = mid1;
    keys_first2 = mid2;
    values_first1 += thrust::distance(r.keys_first1, mid1);
    values_first2 += thrust::distance(r.keys_first2, mid2);
    keys_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2);
    values_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2);
  }

  // True when both key subranges are exhausted.
  bool empty(void) const
  {
    return (keys_first1 == keys_last1) && (keys_first2 == keys_last2);
  }

  // True when the range should be split further.
  bool is_divisible(void) const
  {
    // The splitting constructor only guarantees two non-empty halves when the
    // combined size is at least 3. For example, with n1 == n2 == 1 and
    // comp(*keys_first2, *keys_first1), the upper_bound branch yields an empty
    // left half and a right half identical to the original, so splitting would
    // never make progress. Treat such tiny ranges as indivisible whatever
    // grain_size is (the default of 1024 is unaffected).
    const size_t min_grain_size = 2;
    const size_t threshold = (grain_size < min_grain_size) ? min_grain_size : grain_size;

    return static_cast<size_t>(thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2)) > threshold;
  }
};
struct body
{
template <typename Range>
void operator()(Range& r) const
{
thrust::system::detail::internal::scalar::merge_by_key
(r.keys_first1, r.keys_last1,
r.keys_first2, r.keys_last2,
r.values_first1,
r.values_first2,
r.keys_result,
r.values_result,
r.comp);
}
};
} // end namespace merge_by_key_detail
// TBB-backend merge: merges the sorted ranges [first1, last1) and
// [first2, last2) under comp into the output starting at result, and returns
// an iterator one past the last element written.
// The exec policy argument is not used by this implementation.
template<typename DerivedPolicy,
         typename InputIterator1,
         typename InputIterator2,
         typename OutputIterator,
         typename StrictWeakOrdering>
OutputIterator merge(execution_policy<DerivedPolicy> &exec,
                     InputIterator1 first1,
                     InputIterator1 last1,
                     InputIterator2 first2,
                     InputIterator2 last2,
                     OutputIterator result,
                     StrictWeakOrdering comp)
{
  // The whole problem as one splittable range. TBB subdivides it and merges
  // each leaf sequentially via merge_detail::body.
  merge_detail::range<InputIterator1,InputIterator2,OutputIterator,StrictWeakOrdering>
    whole(first1, last1, first2, last2, result, comp);

  ::tbb::parallel_for(whole, merge_detail::body());

  // The output receives every element of both inputs.
  thrust::advance(result, thrust::distance(first1, last1) + thrust::distance(first2, last2));

  return result;
} // end merge()
// TBB-backend merge_by_key: merges the sorted key ranges [keys_first1, keys_last1)
// and [keys_first2, keys_last2) under comp, moving the values that start at
// values_first3 / values_first4 along with their keys. Returns iterators one
// past the last key and the last value written.
// The exec policy argument is not used by this implementation.
template <typename DerivedPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename InputIterator3,
          typename InputIterator4,
          typename OutputIterator1,
          typename OutputIterator2,
          typename StrictWeakOrdering>
thrust::pair<OutputIterator1,OutputIterator2>
merge_by_key(execution_policy<DerivedPolicy> &exec,
             InputIterator1 keys_first1,
             InputIterator1 keys_last1,
             InputIterator2 keys_first2,
             InputIterator2 keys_last2,
             InputIterator3 values_first3,
             InputIterator4 values_first4,
             OutputIterator1 keys_result,
             OutputIterator2 values_result,
             StrictWeakOrdering comp)
{
  // The whole problem as one splittable range. TBB subdivides it and merges
  // each leaf sequentially via merge_by_key_detail::body.
  merge_by_key_detail::range<InputIterator1,InputIterator2,InputIterator3,InputIterator4,OutputIterator1,OutputIterator2,StrictWeakOrdering>
    whole(keys_first1, keys_last1, keys_first2, keys_last2, values_first3, values_first4, keys_result, values_result, comp);

  ::tbb::parallel_for(whole, merge_by_key_detail::body());

  // Both outputs receive one element per input key.
  thrust::pair<OutputIterator1,OutputIterator2> ends = thrust::make_pair(keys_result, values_result);
  thrust::advance(ends.first, thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2));
  thrust::advance(ends.second, thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2));

  return ends;
}
} // end namespace detail
} // end namespace tbb
} // end namespace system
} // end namespace thrust