2014-03-18 22:17:40 +01:00

345 lines
14 KiB
C++

/*
* Copyright 2008-2012 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <thrust/detail/config.h>
#include <thrust/system/tbb/detail/reduce_by_key.h>
#include <thrust/iterator/reverse_iterator.h>
#include <thrust/system/cpp/execution_policy.h>
#include <thrust/system/tbb/detail/execution_policy.h>
#include <thrust/system/tbb/detail/reduce_intervals.h>
#include <thrust/detail/minmax.h>
#include <thrust/detail/temporary_array.h>
#include <thrust/detail/range/tail_flags.h>
#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>
#include <tbb/tbb_thread.h>
#include <cassert>
namespace thrust
{
namespace system
{
namespace tbb
{
namespace detail
{
namespace reduce_by_key_detail
{
// Integer division of numerator by denominator, rounded up (toward +infinity for
// non-negative operands). Used to turn an element count into a number of intervals.
template<typename L, typename R>
inline L divide_ri(const L numerator, const R denominator)
{
  // bias the numerator by (denominator - 1) so that truncating division rounds up
  return ((denominator - 1) + numerator) / denominator;
}
// partial_sum_type<InputIterator, BinaryFunction, OutputIterator>::type names the type used
// to hold partial reductions of values. The eval_if chain below selects, in order:
//   1. the type given by thrust::detail::result_type<BinaryFunction>, when
//      thrust::detail::has_result_type<BinaryFunction> is true;
//   2. otherwise, if OutputIterator is an output iterator (per
//      thrust::detail::is_output_iterator), the value type of InputIterator;
//   3. otherwise, the value type of OutputIterator.
template<typename InputIterator, typename BinaryFunction, typename OutputIterator = void>
struct partial_sum_type
: thrust::detail::eval_if<
thrust::detail::has_result_type<BinaryFunction>::value,
thrust::detail::result_type<BinaryFunction>,
thrust::detail::eval_if<
thrust::detail::is_output_iterator<OutputIterator>::value,
thrust::iterator_value<InputIterator>,
thrust::iterator_value<OutputIterator>
>
>
{};
// Specialization for the default (void) OutputIterator: use the functor's result type when
// it has one, otherwise fall back to the value type of InputIterator.
template<typename InputIterator, typename BinaryFunction>
struct partial_sum_type<InputIterator,BinaryFunction,void>
: thrust::detail::eval_if<
thrust::detail::has_result_type<BinaryFunction>::value,
thrust::detail::result_type<BinaryFunction>,
thrust::iterator_value<InputIterator>
>
{};
// Reduces the final key segment of [keys_first, keys_last) by walking it back to front.
//
// Starting from the last element, every preceding element whose key satisfies
// binary_pred(key, last_key) is consumed and its value folded into the running reduction.
// Each newly visited (earlier) value is applied as the LEFT operand, so the result equals a
// front-to-back reduction of the segment for any associative binary_op, including
// non-commutative ones.
//
// Precondition: keys_first != keys_last (the last element is dereferenced unconditionally).
//
// Returns a pair of:
//   - an iterator to the first key of the consumed segment, i.e. the end of the remaining
//     unreduced prefix (keys_first if the whole range was one segment);
//   - the pair (last key of the range, reduction of the segment's values).
template<typename InputIterator1,
         typename InputIterator2,
         typename BinaryPredicate,
         typename BinaryFunction>
  thrust::pair<
    InputIterator1,
    thrust::pair<
      typename thrust::iterator_value<InputIterator1>::type,
      typename partial_sum_type<InputIterator2,BinaryFunction>::type
    >
  >
    reduce_last_segment_backward(InputIterator1 keys_first,
                                 InputIterator1 keys_last,
                                 InputIterator2 values_first,
                                 BinaryPredicate binary_pred,
                                 BinaryFunction binary_op)
{
  // iterator_value (rather than InputIterator1::value_type) also works for raw pointers
  typedef typename thrust::iterator_value<InputIterator1>::type              key_type;
  typedef typename partial_sum_type<InputIterator2,BinaryFunction>::type     value_type;

  typename thrust::iterator_difference<InputIterator1>::type n = keys_last - keys_first;

  // reverse the ranges and consume from the end
  thrust::reverse_iterator<InputIterator1> keys_first_r(keys_last);
  thrust::reverse_iterator<InputIterator1> keys_last_r(keys_first);
  thrust::reverse_iterator<InputIterator2> values_first_r(values_first + n);

  key_type   result_key   = *keys_first_r;
  value_type result_value = *values_first_r;

  // consume the entirety of the last key's sequence
  for(++keys_first_r, ++values_first_r;
      (keys_first_r != keys_last_r) && binary_pred(*keys_first_r, result_key);
      ++keys_first_r, ++values_first_r)
  {
    // values arrive back to front, so prepend: keeps input order for non-commutative ops
    result_value = binary_op(*values_first_r, result_value);
  }

  // base() of the reverse iterator is one past the element it refers to,
  // i.e. the first element of the consumed segment
  return thrust::make_pair(keys_first_r.base(), thrust::make_pair(result_key, result_value));
}
// Sequential reduce_by_key over one interval that holds back the interval's final segment.
//
// The last key segment of [keys_first, keys_last) is reduced separately, via
// reduce_last_segment_backward, and returned as a "carry" instead of being written out. The
// remaining prefix is reduced with the sequential (cpp) reduce_by_key into keys_output /
// values_output. The caller decides whether the carry must be merged with the next interval
// or written as a regular output.
//
// Precondition: keys_first != keys_last.
//
// Returns a tuple of:
//   - the end of the written keys;
//   - the end of the written values;
//   - the carry's key;
//   - the carry's reduced value.
template<typename InputIterator1,
         typename InputIterator2,
         typename OutputIterator1,
         typename OutputIterator2,
         typename BinaryPredicate,
         typename BinaryFunction>
  thrust::tuple<
    OutputIterator1,
    OutputIterator2,
    typename thrust::iterator_value<InputIterator1>::type,
    typename partial_sum_type<InputIterator2,BinaryFunction>::type
  >
    reduce_by_key_with_carry(InputIterator1 keys_first,
                             InputIterator1 keys_last,
                             InputIterator2 values_first,
                             OutputIterator1 keys_output,
                             OutputIterator2 values_output,
                             BinaryPredicate binary_pred,
                             BinaryFunction binary_op)
{
  // iterator_value (rather than InputIterator1::value_type) also works for raw pointers
  typedef typename thrust::iterator_value<InputIterator1>::type              key_type;
  typedef typename partial_sum_type<InputIterator2,BinaryFunction>::type     value_type;

  // first, consume the last sequence to produce the carry;
  // keys_last is moved back to the start of that sequence
  // XXX is there an elegant way to pose this such that we don't need to default construct carry?
  thrust::pair<key_type, value_type> carry;
  thrust::tie(keys_last, carry) = reduce_last_segment_backward(keys_first, keys_last, values_first, binary_pred, binary_op);

  // finish the remaining prefix with a sequential reduce_by_key
  thrust::cpp::tag seq;
  thrust::tie(keys_output, values_output) =
    thrust::reduce_by_key(seq, keys_first, keys_last, values_first, keys_output, values_output, binary_pred, binary_op);

  return thrust::make_tuple(keys_output, values_output, carry.first, carry.second);
}
// Reports whether interval `interval_idx` ends in the middle of a key segment, i.e. whether
// its last segment "carries" into the following interval.
//
// tail_flags[i] is expected to be true when element i is the last element of its segment.
// The final interval (interval_idx + 1 >= num_intervals) never carries.
template<typename Iterator>
bool interval_has_carry(size_t interval_idx, size_t interval_size, size_t num_intervals, Iterator tail_flags)
{
  // nothing follows the final interval, so there is nowhere to carry into
  if(interval_idx + 1 >= num_intervals)
  {
    return false;
  }

  // the interval carries exactly when its last element is not a segment tail
  const size_t last_element = (interval_idx + 1) * interval_size - 1;
  return !tail_flags[last_element];
}
// Function object passed to ::tbb::parallel_for. Each invocation handles exactly one
// interval of the input (asserted via r.size() == 1):
//   - runs reduce_by_key_with_carry over the interval's keys/values;
//   - writes the interval's outputs starting at position result_offset[interval_idx];
//   - disposes of the held-back last segment (the "carry"):
//     - if the interval carries into the next one (see interval_has_carry), only the carry
//       value is stored, in carry_result[interval_idx];
//     - otherwise the carry is a complete segment and is written as a regular key/value
//       output.
template<typename Iterator1, typename Iterator2, typename Iterator3, typename Iterator4, typename Iterator5, typename Iterator6, typename BinaryPredicate, typename BinaryFunction>
struct serial_reduce_by_key_body
{
typedef typename thrust::iterator_difference<Iterator1>::type size_type;
// input keys and values (the full range, not just one interval)
Iterator1 keys_first;
Iterator2 values_first;
// per-interval output offsets, read as result_offset[interval_idx]
Iterator3 result_offset;
// destinations for reduced keys and values
Iterator4 keys_result;
Iterator5 values_result;
// per-interval carry slots, written only for intervals that carry
Iterator6 carry_result;
// total number of input elements
size_type n;
// elements per interval (the last interval may be shorter)
size_type interval_size;
size_type num_intervals;
BinaryPredicate binary_pred;
BinaryFunction binary_op;
serial_reduce_by_key_body(Iterator1 keys_first, Iterator2 values_first, Iterator3 result_offset, Iterator4 keys_result, Iterator5 values_result, Iterator6 carry_result, size_type n, size_type interval_size, size_type num_intervals, BinaryPredicate binary_pred, BinaryFunction binary_op)
: keys_first(keys_first), values_first(values_first),
result_offset(result_offset),
keys_result(keys_result),
values_result(values_result),
carry_result(carry_result),
n(n),
interval_size(interval_size),
num_intervals(num_intervals),
binary_pred(binary_pred),
binary_op(binary_op)
{}
// Processes the single interval r.begin(); the caller must use grainsize 1 with a
// simple_partitioner so that every range holds exactly one interval index.
void operator()(const ::tbb::blocked_range<size_type> &r) const
{
assert(r.size() == 1);
const size_type interval_idx = r.begin();
// this interval covers input elements [offset_to_first, offset_to_last)
const size_type offset_to_first = interval_size * interval_idx;
const size_type offset_to_last = thrust::min(n, offset_to_first + interval_size);
Iterator1 my_keys_first = keys_first + offset_to_first;
Iterator1 my_keys_last = keys_first + offset_to_last;
Iterator2 my_values_first = values_first + offset_to_first;
// this interval's first output position was precomputed by the caller
Iterator3 my_result_offset = result_offset + interval_idx;
Iterator4 my_keys_result = keys_result + *my_result_offset;
Iterator5 my_values_result = values_result + *my_result_offset;
Iterator6 my_carry_result = carry_result + interval_idx;
// consume the rest of the interval with reduce_by_key
typedef typename thrust::iterator_value<Iterator1>::type key_type;
typedef typename partial_sum_type<Iterator2,BinaryFunction>::type value_type;
// XXX is there a way to pose this so that we don't require default construction of carry?
thrust::pair<key_type, value_type> carry;
// after this call my_keys_result / my_values_result point one past the last written output
thrust::tie(my_keys_result, my_values_result, carry.first, carry.second) =
reduce_by_key_with_carry(my_keys_first,
my_keys_last,
my_values_first,
my_keys_result,
my_values_result,
binary_pred,
binary_op);
// store to carry only when we actually have a carry
// store to my_keys_result & my_values_result otherwise
// create tail_flags so we can check for a carry
thrust::detail::tail_flags<Iterator1,BinaryPredicate> flags = thrust::detail::make_tail_flags(keys_first, keys_first + n, binary_pred);
if(interval_has_carry(interval_idx, interval_size, num_intervals, flags.begin()))
{
// we can ignore the carry's key
// XXX because the carry result is uninitialized, we should copy construct
*my_carry_result = carry.second;
}
else
{
// the held-back segment ends inside this interval, so it is a finished output
*my_keys_result = carry.first;
*my_values_result = carry.second;
}
}
};
// Convenience factory that deduces the template arguments of serial_reduce_by_key_body.
//
// n, interval_size and num_intervals all take the key iterator's signed difference type.
// That is the type the body stores them as and the type the caller computes them in, so no
// signed -> unsigned -> signed round trip happens on the way in.
template<typename Iterator1, typename Iterator2, typename Iterator3, typename Iterator4, typename Iterator5, typename Iterator6, typename BinaryPredicate, typename BinaryFunction>
  serial_reduce_by_key_body<Iterator1,Iterator2,Iterator3,Iterator4,Iterator5,Iterator6,BinaryPredicate,BinaryFunction>
    make_serial_reduce_by_key_body(Iterator1 keys_first,
                                   Iterator2 values_first,
                                   Iterator3 result_offset,
                                   Iterator4 keys_result,
                                   Iterator5 values_result,
                                   Iterator6 carry_result,
                                   typename thrust::iterator_difference<Iterator1>::type n,
                                   typename thrust::iterator_difference<Iterator1>::type interval_size,
                                   typename thrust::iterator_difference<Iterator1>::type num_intervals,
                                   BinaryPredicate binary_pred,
                                   BinaryFunction binary_op)
{
  return serial_reduce_by_key_body<Iterator1,Iterator2,Iterator3,Iterator4,Iterator5,Iterator6,BinaryPredicate,BinaryFunction>(
    keys_first, values_first, result_offset, keys_result, values_result, carry_result,
    n, interval_size, num_intervals, binary_pred, binary_op);
}
} // end reduce_by_key_detail
// TBB-backend reduce_by_key.
//
// Small inputs (fewer than parallelism_threshold keys) use the sequential cpp backend.
// Larger inputs are processed in five steps:
//   1. split the keys into fixed-size intervals;
//   2. count the segment tails in each interval in parallel and scan the counts into
//      per-interval output offsets;
//   3. reduce each interval sequentially in parallel (serial_reduce_by_key_body);
//   4. record a "carry" for each interval whose last segment continues into the next one;
//   5. fold those carries sequentially into the outputs that finish their segments.
//
// Returns the ends of the written key and value ranges.
template<typename DerivedPolicy, typename Iterator1, typename Iterator2, typename Iterator3, typename Iterator4, typename BinaryPredicate, typename BinaryFunction>
  thrust::pair<Iterator3,Iterator4>
    reduce_by_key(thrust::tbb::execution_policy<DerivedPolicy> &exec,
                  Iterator1 keys_first, Iterator1 keys_last,
                  Iterator2 values_first,
                  Iterator3 keys_result,
                  Iterator4 values_result,
                  BinaryPredicate binary_pred,
                  BinaryFunction binary_op)
{
  typedef typename thrust::iterator_difference<Iterator1>::type difference_type;

  difference_type n = keys_last - keys_first;
  if(n == 0) return thrust::make_pair(keys_result, values_result);

  // XXX this value is a tuning opportunity
  const difference_type parallelism_threshold = 10000;

  if(n < parallelism_threshold)
  {
    // don't bother parallelizing for small n
    thrust::cpp::tag seq;
    return thrust::reduce_by_key(seq, keys_first, keys_last, values_first, keys_result, values_result, binary_pred, binary_op);
  }

  // count the number of processors
  const unsigned int p = thrust::max<unsigned int>(1u, ::tbb::tbb_thread::hardware_concurrency());

  // generate O(P) intervals of sequential work
  // XXX oversubscribing is a tuning opportunity
  // NOTE(review): max(n, n / (subscription_rate * p)) is always n. Since n >= parallelism_threshold
  // here, interval_size always equals parallelism_threshold, and the interval count is
  // ceil(n / parallelism_threshold) rather than O(P). This is left unchanged; revisit when tuning.
  const unsigned int subscription_rate = 1;
  difference_type interval_size = thrust::min<difference_type>(parallelism_threshold, thrust::max<difference_type>(n, n / (subscription_rate * p)));
  difference_type num_intervals = reduce_by_key_detail::divide_ri(n, interval_size);

  // decompose the input into intervals of size N / num_intervals
  // add one extra element to this vector to store the size of the entire result
  thrust::detail::temporary_array<difference_type, DerivedPolicy> interval_output_offsets(0, exec, num_intervals + 1);

  // first count the number of tail flags in each interval
  // (accumulate in difference_type, the element type of interval_output_offsets)
  thrust::detail::tail_flags<Iterator1,BinaryPredicate> tail_flags = thrust::detail::make_tail_flags(keys_first, keys_last, binary_pred);
  thrust::system::tbb::detail::reduce_intervals(exec, tail_flags.begin(), tail_flags.end(), interval_size, interval_output_offsets.begin() + 1, thrust::plus<difference_type>());
  interval_output_offsets[0] = 0;

  // scan the counts to get each body's output offset
  thrust::cpp::tag seq;
  thrust::inclusive_scan(seq,
                         interval_output_offsets.begin() + 1, interval_output_offsets.end(),
                         interval_output_offsets.begin() + 1);

  // do a reduce_by_key serially in each thread
  // the final interval never has a carry by definition, so don't reserve space for it
  typedef typename reduce_by_key_detail::partial_sum_type<Iterator2,BinaryFunction>::type carry_type;
  thrust::detail::temporary_array<carry_type, DerivedPolicy> carries(0, exec, num_intervals - 1);

  // force grainsize == 1 with simple_partitioner()
  ::tbb::parallel_for(::tbb::blocked_range<difference_type>(0, num_intervals, 1),
                      reduce_by_key_detail::make_serial_reduce_by_key_body(keys_first, values_first, interval_output_offsets.begin(), keys_result, values_result, carries.begin(), n, interval_size, num_intervals, binary_pred, binary_op),
                      ::tbb::simple_partitioner());

  difference_type size_of_result = interval_output_offsets[num_intervals];

  // sequentially accumulate the carries; the last interval does not have a carry.
  // A carry belongs to the output at the next interval's output offset, and that output
  // already holds the reduction of the segment's later elements. So the carry, which comes
  // from earlier input, must be the LEFT operand.
  // A segment spanning several intervals receives several carries at the same output index.
  // Walking the carries back to front applies them in the order
  // c_i op (c_{i+1} op (... op v)), which preserves input order for associative but
  // non-commutative operators.
  // XXX find a way to express this loop via a sequential algorithm, perhaps reduce_by_key
  typedef typename thrust::detail::temporary_array<carry_type,DerivedPolicy>::size_type carry_size_type;
  for(carry_size_type i = carries.size(); i > 0; --i)
  {
    const carry_size_type interval_idx = i - 1;

    // if our interval has a carry, then we need to combine it with the next interval's first output
    // if it does not have a carry, then carries[interval_idx] is uninitialized and must be ignored
    if(reduce_by_key_detail::interval_has_carry(interval_idx, interval_size, num_intervals, tail_flags.begin()))
    {
      difference_type output_idx = interval_output_offsets[interval_idx + 1];
      values_result[output_idx] = binary_op(carries[interval_idx], values_result[output_idx]);
    }
  }

  return thrust::make_pair(keys_result + size_of_result, values_result + size_of_result);
}
} // end detail
} // end tbb
} // end system
} // end thrust