You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
705 lines
23 KiB
705 lines
23 KiB
/*
 *  Copyright 2008-2012 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
|
|
|
|
|
#include <thrust/detail/config.h> |
|
|
|
#include <thrust/iterator/iterator_traits.h> |
|
#include <thrust/iterator/counting_iterator.h> |
|
#include <thrust/iterator/transform_iterator.h> |
|
#include <thrust/iterator/permutation_iterator.h> |
|
#include <thrust/iterator/zip_iterator.h> |
|
#include <thrust/system/detail/generic/select_system.h> |
|
|
|
#include <thrust/detail/type_traits.h> |
|
#include <thrust/detail/type_traits/function_traits.h> |
|
#include <thrust/detail/type_traits/iterator/is_output_iterator.h> |
|
#include <thrust/detail/type_traits/iterator/is_discard_iterator.h> |
|
|
|
#include <thrust/detail/minmax.h> |
|
#include <thrust/detail/internal_functional.h> |
|
#include <thrust/detail/temporary_array.h> |
|
|
|
#include <thrust/reduce.h> |
|
#include <thrust/scan.h> |
|
#include <thrust/system/cuda/detail/default_decomposition.h> |
|
#include <thrust/system/cuda/detail/block/inclusive_scan.h> |
|
#include <thrust/system/cuda/detail/execution_policy.h> |
|
#include <thrust/system/cuda/detail/detail/launch_closure.h> |
|
#include <thrust/system/cuda/detail/reduce_intervals.h> |
|
#include <thrust/system/cuda/detail/detail/uninitialized.h> |
|
|
|
__THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_BEGIN |
|
|
|
namespace thrust |
|
{ |
|
namespace system |
|
{ |
|
namespace cuda |
|
{ |
|
namespace detail |
|
{ |
|
namespace reduce_by_key_detail |
|
{ |
|
|
|
// Functor applied to (index, key, next_key) tuples. It yields 1 when the element
// is the final one of the sequence (index == n - 1) or when binary_pred reports
// that the key and the next key differ, and 0 otherwise.
template <typename FlagType, typename IndexType, typename KeyType, typename BinaryPredicate>
struct tail_flag_functor
{
  BinaryPredicate binary_pred; // NB: this must be the first member for performance reasons
  IndexType n;                 // length of the key sequence

  typedef FlagType result_type;

  // the initializers are written in member declaration order, which is the
  // order in which the members are actually initialized
  tail_flag_functor(IndexType n, BinaryPredicate binary_pred)
    : binary_pred(binary_pred), n(n)
  {}

  // XXX why is this noticeably faster? (it may read past the end of input)
  //FlagType operator()(const thrust::tuple<IndexType,KeyType,KeyType>& t) const

  template <typename Tuple>
  __host__ __device__ __thrust_forceinline__
  FlagType operator()(const Tuple& t)
  {
    // the index test comes first so the predicate is never evaluated for the
    // final element
    const bool is_final = (thrust::get<0>(t) == (n - 1));

    return (is_final || !binary_pred(thrust::get<1>(t), thrust::get<2>(t))) ? 1 : 0;
  }
};
|
|
|
|
|
// Reads the flags of one tile (up to K * CTA_SIZE elements starting at iflags)
// and returns, for the calling thread, a mask whose bit k is the flag of tile
// element K * thread_index + k. The same mask is also left in
// sflag[thread_index] when the function returns.
//
// Elements at offsets >= n are treated as unflagged unless FullBlock is true,
// in which case no bounds test is made. The function contains barriers and
// therefore has to be called by every thread of the block.
template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename FlagIterator,
          typename FlagType>
__device__ __thrust_forceinline__
FlagType load_flags(Context context,
                    const unsigned int n,
                    FlagIterator iflags,
                    FlagType (&sflag)[CTA_SIZE])
{
  const unsigned int tid = context.thread_index();

  // first pass: thread tid reads elements tid, CTA_SIZE + tid, 2 * CTA_SIZE + tid, ...
  // and records the k-th of them in bit k
  FlagType strided_bits = 0;

  for(unsigned int k = 0; k < K; ++k)
  {
    const unsigned int src = k * CTA_SIZE + tid;

    if (FullBlock || src < n)
    {
      FlagIterator it = iflags + src;

      if (*it)
        strided_bits |= FlagType(1) << k;
    }
  }

  sflag[tid] = strided_bits;

  context.barrier();

  // second pass: collect the flags of elements [K * tid, K * tid + K) from
  // the masks published by the other threads
  FlagType blocked_bits = 0;

  for(unsigned int k = 0; k < K; ++k)
  {
    const unsigned int src = K * tid + k;

    if (FullBlock || src < n)
    {
      // element src was read by thread (src % CTA_SIZE) into bit (src / CTA_SIZE)
      const FlagType bit = (sflag[src % CTA_SIZE] >> (src / CTA_SIZE)) & FlagType(1);
      blocked_bits |= bit << k;
    }
  }

  // every thread must be done reading sflag before it is overwritten
  context.barrier();

  sflag[tid] = blocked_bits;

  context.barrier();

  return blocked_bits;
}
|
|
|
// Copies one tile of values from ivals into sdata such that tile element e
// ends up in sdata[e % K][e / K]. Elements at offsets >= n are skipped unless
// FullBlock is true. Ends with a barrier, so every thread of the block has to
// call it.
template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename InputIterator2,
          typename ValueType>
__device__ __thrust_forceinline__
void load_values(Context context,
                 const unsigned int n,
                 InputIterator2 ivals,
                 ValueType (&sdata)[K][CTA_SIZE + 1])
{
  const unsigned int tid = context.thread_index();

  for(unsigned int k = 0; k < K; ++k)
  {
    // thread tid reads elements tid, CTA_SIZE + tid, 2 * CTA_SIZE + tid, ...
    const unsigned int src = k * CTA_SIZE + tid;

    if (!FullBlock && src >= n)
      continue;

    InputIterator2 it = ivals + src;
    sdata[src % K][src / K] = *it;
  }

  context.barrier();
}
|
|
|
|
|
// Reduces one tile of up to K * CTA_SIZE consecutive (key, value) pairs with
// the whole thread block. Thread t owns tile elements [K * t, K * t + K).
//
//   n            number of valid elements in the tile (all of them when
//                FullBlock is true, in which case no bounds tests are made)
//   ikeys/ivals  start of the tile's keys and values
//   okeys/ovals  where this tile's outputs are written; one key and one value
//                is written for every element whose flag in iflags is set
//   iflags       start of the tile's flags (see load_flags)
//   sflag/sdata  scratch storage shared by the block
//   carry_in     in:  whether carry_value has to be combined into the tile's
//                     first element
//                out: true when the tile's last valid element is not flagged
//   carry_index  out: number of outputs this tile wrote
//   carry_value  in:  value carried over from the preceding tile
//                out: running value of the tile's last valid element
//
// The function contains barriers and therefore has to be called by every
// thread of the block. Flag counts are taken with __popc, which counts the
// bits of a 32-bit word.
template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename FlagIterator,
          typename FlagType,
          typename IndexType,
          typename ValueType>
__device__ __thrust_forceinline__
void reduce_by_key_body(Context context,
                        const unsigned int n,
                        InputIterator1   ikeys,
                        InputIterator2   ivals,
                        OutputIterator1  okeys,
                        OutputIterator2  ovals,
                        BinaryPredicate  binary_pred,
                        BinaryFunction   binary_op,
                        FlagIterator     iflags,
                        FlagType  (&sflag)[CTA_SIZE],
                        ValueType (&sdata)[K][CTA_SIZE + 1],
                        bool&      carry_in,
                        IndexType& carry_index,
                        ValueType& carry_value)
{
  // bit that describes a thread's last (K-th) element; built from FlagType so
  // that every flag mask in this function has the same type
  const FlagType last_flag_mask = FlagType(1) << (K - 1);

  // load flags
  const FlagType flag_bits  = load_flags<CTA_SIZE,K,FullBlock>(context, n, iflags, sflag);
  const FlagType flag_count = __popc(flag_bits); // TODO hide this behind a template
  const FlagType left_flag  = (context.thread_index() == 0) ? 0 : sflag[context.thread_index() - 1];

  // a thread is a "head" for the second level scan when it is thread 0, when
  // any of its first K - 1 elements is flagged, or when the last element of
  // the thread to its left is flagged
  const FlagType head_flag  = (context.thread_index() == 0 || (flag_bits & (last_flag_mask - 1)) || (left_flag & last_flag_mask)) ? 1 : 0;

  context.barrier();

  // scan flag counts
  sflag[context.thread_index()] = flag_count; context.barrier();

  block::inclusive_scan(context, sflag, thrust::plus<FlagType>());

  // exclusive prefix: the output slot of this thread's first flagged element
  const FlagType output_position = (context.thread_index() == 0) ? 0 : sflag[context.thread_index() - 1];
  const FlagType num_outputs = sflag[CTA_SIZE - 1];

  context.barrier();

  // shuffle keys and write keys out
  if (!thrust::detail::is_discard_iterator<OutputIterator1>::value)
  {
    // XXX this could be improved
    // outputs are handled in windows of CTA_SIZE: flagged elements publish
    // their tile index in sflag, then one thread per output copies the key
    for (unsigned int i = 0; i < num_outputs; i += CTA_SIZE)
    {
      FlagType position = output_position;

      for(unsigned int k = 0; k < K; k++)
      {
        if (flag_bits & (FlagType(1) << k))
        {
          if (i <= position && position < i + CTA_SIZE)
            sflag[position - i] = K * context.thread_index() + k;
          position++;
        }
      }

      context.barrier();

      if (i + context.thread_index() < num_outputs)
      {
        InputIterator1  tmp1 = ikeys + sflag[context.thread_index()];
        OutputIterator1 tmp2 = okeys + (i + context.thread_index());
        *tmp2 = *tmp1;
      }

      context.barrier();
    }
  }

  // load values
  load_values<CTA_SIZE,K,FullBlock> (context, n, ivals, sdata);

  // ldata[k] is tile element K * thread_index + k
  ValueType ldata[K];
  for (unsigned int k = 0; k < K; k++)
    ldata[k] = sdata[k][context.thread_index()];

  // carry in (if necessary)
  if (context.thread_index() == 0 && carry_in)
  {
    // XXX WAR sm_10 issue
    ValueType tmp1 = carry_value;
    ldata[0] = binary_op(tmp1, ldata[0]);
  }

  context.barrier();

  // sum local values
  {
    for(unsigned int k = 1; k < K; k++)
    {
      const unsigned int offset = K * context.thread_index() + k;

      if (FullBlock || offset < n)
      {
        // accumulate unless the previous element is flagged
        if (!(flag_bits & (FlagType(1) << (k - 1))))
          ldata[k] = binary_op(ldata[k - 1], ldata[k]);
      }
    }
  }

  // second level segmented scan
  {
    // use head flags for segmented scan
    sflag[context.thread_index()] = head_flag;  sdata[K - 1][context.thread_index()] = ldata[K - 1]; context.barrier();

    if (FullBlock)
      block::inclusive_scan_by_flag(context, sflag, sdata[K-1], binary_op);
    else
      block::inclusive_scan_by_flag_n(context, sflag, sdata[K-1], n, binary_op);
  }

  // update local values
  if (context.thread_index() > 0)
  {
    // bit 0: the left neighbour's last element is flagged; bit k + 1: this
    // thread's element k is flagged
    unsigned int update_bits  = (flag_bits << 1) | (left_flag >> (K - 1));
// TODO remove guard
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
    unsigned int update_count = __ffs(update_bits) - 1u; // NB: this might wrap around to UINT_MAX
#else
    unsigned int update_count = 0;
#endif // THRUST_DEVICE_COMPILER_NVCC

    if (!FullBlock)
    {
      // in a partial tile only slots below n hold loaded values; limit the
      // update to them so binary_op is never applied to a slot that was
      // never loaded (slots below n are updated exactly as without the limit)
      const unsigned int first_slot  = K * context.thread_index();
      const unsigned int valid_slots = (first_slot < n) ? (n - first_slot) : 0u;
      update_count = thrust::min(valid_slots, update_count);
    }

    // scanned total of the thread to the left
    ValueType left = sdata[K - 1][context.thread_index() - 1];

    for(unsigned int k = 0; k < K; k++)
    {
      if (k < update_count)
        ldata[k] = binary_op(left, ldata[k]);
    }
  }

  context.barrier();

  // store carry out
  if (FullBlock)
  {
    if (context.thread_index() == CTA_SIZE - 1)
    {
      carry_value = ldata[K - 1];
      carry_in    = (flag_bits & last_flag_mask) ? false : true;
      carry_index = num_outputs;
    }
  }
  else
  {
    // the thread that owns the last valid element, n - 1, stores the carry
    if (context.thread_index() == (n - 1) / K)
    {
      for (unsigned int k = 0; k < K; k++)
        if (k == (n - 1) % K)
          carry_value = ldata[k];
      carry_in    = (flag_bits & (FlagType(1) << ((n - 1) % K))) ? false : true;
      carry_index = num_outputs;
    }
  }

  // shuffle values
  {
    FlagType position = output_position;

    for(unsigned int k = 0; k < K; k++)
    {
      const unsigned int offset = K * context.thread_index() + k;

      if (FullBlock || offset < n)
      {
        if (flag_bits & (FlagType(1) << k))
        {
          // output slot p is staged at sdata[p / CTA_SIZE][p % CTA_SIZE]
          sdata[position / CTA_SIZE][position % CTA_SIZE] = ldata[k];
          position++;
        }
      }
    }
  }

  context.barrier();

  // write values out
  for(unsigned int k = 0; k < K; k++)
  {
    const unsigned int offset = CTA_SIZE * k + context.thread_index();

    if (offset < num_outputs)
    {
      OutputIterator2 tmp = ovals + offset;
      *tmp = sdata[k][context.thread_index()];
    }
  }

  context.barrier();
}
|
|
|
// Kernel body for the first pass of reduce_by_key. Each thread block takes the
// interval decomp[block_index], reduces it tile by tile with
// reduce_by_key_body, and finally stores the interval's carry value in
// interval_values[block_index] and its carry flag in interval_carry[block_index].
// Blocks other than block 0 read their starting output offset from
// interval_counts[block_index - 1].
template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename FlagIterator,
          typename IndexIterator,
          typename ValueIterator,
          typename BoolIterator,
          typename Decomposition,
          typename Context>
struct reduce_by_key_closure
{
  InputIterator1   ikeys;
  InputIterator2   ivals;
  OutputIterator1  okeys;
  OutputIterator2  ovals;
  BinaryPredicate  binary_pred;
  BinaryFunction   binary_op;
  FlagIterator     iflags;
  IndexIterator    interval_counts;
  ValueIterator    interval_values;
  BoolIterator     interval_carry;
  Decomposition    decomp;
  Context          context;

  typedef Context context_type;

  reduce_by_key_closure(InputIterator1   ikeys,
                        InputIterator2   ivals,
                        OutputIterator1  okeys,
                        OutputIterator2  ovals,
                        BinaryPredicate  binary_pred,
                        BinaryFunction   binary_op,
                        FlagIterator     iflags,
                        IndexIterator    interval_counts,
                        ValueIterator    interval_values,
                        BoolIterator     interval_carry,
                        Decomposition    decomp,
                        Context          context = Context())
    : ikeys(ikeys),
      ivals(ivals),
      okeys(okeys),
      ovals(ovals),
      binary_pred(binary_pred),
      binary_op(binary_op),
      iflags(iflags),
      interval_counts(interval_counts),
      interval_values(interval_values),
      interval_carry(interval_carry),
      decomp(decomp),
      context(context)
  {}

  __device__ __thrust_forceinline__
  void operator()(void)
  {
    typedef typename thrust::iterator_value<InputIterator1>::type KeyType;
    typedef typename thrust::iterator_value<ValueIterator>::type  ValueType;
    typedef typename Decomposition::index_type                    IndexType;
    typedef typename thrust::iterator_value<FlagIterator>::type   FlagType;

    const unsigned int CTA_SIZE = context_type::ThreadsPerBlock::value;

// TODO centralize this mapping (__CUDA_ARCH__ -> smem bytes)
#if __CUDA_ARCH__ >= 200
    const unsigned int SMEM = (48 * 1024);
#else
    const unsigned int SMEM = (16 * 1024) - 256;
#endif

    // bytes taken by the flag array and the three carry variables
    const unsigned int SMEM_FIXED = CTA_SIZE * sizeof(FlagType) + sizeof(ValueType) + sizeof(IndexType) + sizeof(bool);

    // K, the number of elements per thread, is the smallest of:
    const unsigned int BOUND_1 = (SMEM - SMEM_FIXED) / ((CTA_SIZE + 1) * sizeof(ValueType)); // rows of sdata that fit in the remaining bytes
    const unsigned int BOUND_2 = 8 * sizeof(FlagType);                                       // bits in one FlagType
    const unsigned int BOUND_3 = 6;                                                          // fixed cap

    // TODO replace this with a static_min<BOUND_1,BOUND_2,BOUND_3>::value
    const unsigned int BOUND_12 = (BOUND_1 < BOUND_2) ? BOUND_1 : BOUND_2;
    const unsigned int K        = (BOUND_12 < BOUND_3) ? BOUND_12 : BOUND_3;

    __shared__ detail::uninitialized<FlagType[CTA_SIZE]>         sflag;
    __shared__ detail::uninitialized<ValueType[K][CTA_SIZE + 1]> sdata;  // padded to avoid bank conflicts

    __shared__ detail::uninitialized<ValueType> carry_value; // storage for carry in and carry out
    __shared__ detail::uninitialized<IndexType> carry_index;
    __shared__ detail::uninitialized<bool>      carry_in;

    typename Decomposition::range_type interval = decomp[context.block_index()];
    //thrust::system::detail::internal::index_range<IndexType> interval = decomp[context.block_index()];

    // thread 0 sets up the carry state shared by the block
    if (context.thread_index() == 0)
    {
      carry_in = false; // act as though the previous segment terminated just before us

      if (context.block_index() == 0)
      {
        carry_index = 0;
      }
      else
      {
        // start writing where the preceding interval's count says
        interval_counts += (context.block_index() - 1);
        carry_index = *interval_counts;
      }
    }

    context.barrier();

    IndexType base = interval.begin();

    // advance input and output iterators
    ikeys  += base;
    ivals  += base;
    iflags += base;
    okeys  += carry_index;
    ovals  += carry_index;

    // number of elements in one full tile
    const unsigned int unit_size = K * CTA_SIZE;

    // process full units
    while (base + unit_size <= interval.end())
    {
      reduce_by_key_body<CTA_SIZE,K,true>(context, unit_size, ikeys, ivals, okeys, ovals, binary_pred, binary_op, iflags, sflag.get(), sdata.get(), carry_in.get(), carry_index.get(), carry_value.get());

      base   += unit_size;
      ikeys  += unit_size;
      ivals  += unit_size;
      iflags += unit_size;

      // carry_index now holds the number of outputs the tile produced
      okeys  += carry_index;
      ovals  += carry_index;
    }

    // process partially full unit at end of input (if necessary)
    if (base < interval.end())
    {
      const unsigned int n = interval.end() - base;
      reduce_by_key_body<CTA_SIZE,K,false>(context, n, ikeys, ivals, okeys, ovals, binary_pred, binary_op, iflags, sflag.get(), sdata.get(), carry_in.get(), carry_index.get(), carry_value.get());
    }

    // publish this interval's carry for the second pass
    if (context.thread_index() == 0)
    {
      interval_values += context.block_index();
      interval_carry  += context.block_index();
      *interval_values = carry_value;
      *interval_carry  = carry_in;
    }
  }
}; // end reduce_by_key_closure
|
|
|
// Default tuning policy for reduce_by_key: supplies the flag, index, key and
// temporary value types, the number of threads per block, and the
// decomposition of the key range.
template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction>
struct DefaultPolicy
{
  // typedefs
  typedef unsigned int                                                       FlagType;
  typedef typename thrust::iterator_traits<InputIterator1>::difference_type  IndexType;
  typedef typename thrust::iterator_traits<InputIterator1>::value_type       KeyType;
  typedef thrust::system::detail::internal::uniform_decomposition<IndexType> Decomposition;

  // the pseudocode for deducing the type of the temporary used below:
  //
  // if BinaryFunction is AdaptableBinaryFunction
  //   TemporaryType = AdaptableBinaryFunction::result_type
  // else if OutputIterator2 is a "pure" output iterator
  //   TemporaryType = InputIterator2::value_type
  // else
  //   TemporaryType = OutputIterator2::value_type
  //
  // XXX upon c++0x, TemporaryType needs to be:
  // result_of<BinaryFunction>::type

  typedef typename thrust::detail::eval_if<
    thrust::detail::has_result_type<BinaryFunction>::value,
    thrust::detail::result_type<BinaryFunction>,
    thrust::detail::eval_if<
      thrust::detail::is_output_iterator<OutputIterator2>::value,
      thrust::iterator_value<InputIterator2>,
      thrust::iterator_value<OutputIterator2>
    >
  >::type ValueType;

  // XXX WAR problem on sm_11
  // TODO tune this
  // 256 threads per block for POD value types, 192 otherwise
  static const unsigned int ThreadsPerBlock = (thrust::detail::is_pod<ValueType>::value) ? 256 : 192;

  // builds the default decomposition for the range [first1, last1)
  DefaultPolicy(InputIterator1 first1, InputIterator1 last1)
    : decomp(default_decomposition<IndexType>(last1 - first1))
  {}

  // member variables
  Decomposition decomp;
};
|
|
|
// Policy-driven implementation of reduce_by_key.
//
// The work is split into the intervals of policy.decomp:
//   1. the tail flags of each interval are counted and the counts are scanned,
//      which yields every interval's starting output offset and the total
//      output size;
//   2. one thread block per interval runs reduce_by_key_closure, writing the
//      interval's outputs plus its carry value and carry flag;
//   3. when there is more than one interval, the carry values are reduced by
//      (count, carry flag) and combined with binary_op into the entries of
//      values_output selected by the reduced counts, wherever the reduced
//      carry flag is set.
//
// Returns the ends of the key and value output ranges. Empty input returns the
// output iterators unchanged.
template <typename DerivedPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename Policy>
  thrust::pair<OutputIterator1,OutputIterator2>
  reduce_by_key(execution_policy<DerivedPolicy> &exec,
                InputIterator1 keys_first,
                InputIterator1 keys_last,
                InputIterator2 values_first,
                OutputIterator1 keys_output,
                OutputIterator2 values_output,
                BinaryPredicate binary_pred,
                BinaryFunction binary_op,
                Policy policy)
{
  typedef typename Policy::FlagType       FlagType;
  typedef typename Policy::Decomposition  Decomposition;
  typedef typename Policy::IndexType      IndexType;
  typedef typename Policy::KeyType        KeyType;
  typedef typename Policy::ValueType      ValueType;

  // temporary arrays
  typedef thrust::detail::temporary_array<IndexType,DerivedPolicy> IndexArray;
  typedef thrust::detail::temporary_array<KeyType,DerivedPolicy>   KeyArray;
  typedef thrust::detail::temporary_array<ValueType,DerivedPolicy> ValueArray;
  typedef thrust::detail::temporary_array<bool,DerivedPolicy>      BoolArray;

  Decomposition decomp = policy.decomp;

  // input size
  IndexType n = keys_last - keys_first;

  // nothing to reduce
  if (n == 0)
    return thrust::make_pair(keys_output, values_output);

  // one slot per interval
  IndexArray interval_counts(exec, decomp.size());
  ValueArray interval_values(exec, decomp.size());
  BoolArray  interval_carry(exec, decomp.size());

  // an ode to c++11 auto
  typedef thrust::counting_iterator<IndexType> CountingIterator;
  typedef tail_flag_functor<FlagType,IndexType,KeyType,BinaryPredicate> TailFlag;
  typedef thrust::transform_iterator<
    TailFlag,
    thrust::zip_iterator<
      thrust::tuple<CountingIterator,InputIterator1,InputIterator1>
    >
  > FlagIterator;

  // element i of iflag is computed from (i, keys[i], keys[i + 1])
  FlagIterator iflag = thrust::make_transform_iterator
    (thrust::make_zip_iterator(thrust::make_tuple(CountingIterator(0), keys_first, keys_first + 1)),
     TailFlag(n, binary_pred));

  // count number of tail flags per interval
  thrust::system::cuda::detail::reduce_intervals(exec, iflag, interval_counts.begin(), thrust::plus<IndexType>(), decomp);

  // turn the per-interval counts into running totals
  thrust::inclusive_scan(exec,
                         interval_counts.begin(), interval_counts.end(),
                         interval_counts.begin(),
                         thrust::plus<IndexType>());

  // determine output size
  const IndexType N = interval_counts[interval_counts.size() - 1];

  const static unsigned int ThreadsPerBlock = Policy::ThreadsPerBlock;
  typedef typename IndexArray::iterator IndexIterator;
  typedef typename ValueArray::iterator ValueIterator;
  typedef typename BoolArray::iterator  BoolIterator;
  typedef detail::statically_blocked_thread_array<ThreadsPerBlock> Context;
  typedef reduce_by_key_closure<InputIterator1,InputIterator2,OutputIterator1,OutputIterator2,BinaryPredicate,BinaryFunction,
                                FlagIterator,IndexIterator,ValueIterator,BoolIterator,Decomposition,Context> Closure;

  Closure closure(keys_first,  values_first,
                  keys_output, values_output,
                  binary_pred, binary_op,
                  iflag,
                  interval_counts.begin(),
                  interval_values.begin(),
                  interval_carry.begin(),
                  decomp);

  // one block per interval
  detail::launch_closure(closure, decomp.size(), ThreadsPerBlock);

  // with several intervals the carries still have to be folded in
  if (decomp.size() > 1)
  {
    ValueArray interval_values2(exec, decomp.size());
    IndexArray interval_counts2(exec, decomp.size());
    BoolArray  interval_carry2(exec, decomp.size());

    // reduce the carry values of intervals that share (count, carry flag)
    IndexType N2 =
      thrust::reduce_by_key
        (exec,
         thrust::make_zip_iterator(thrust::make_tuple(interval_counts.begin(), interval_carry.begin())),
         thrust::make_zip_iterator(thrust::make_tuple(interval_counts.end(),   interval_carry.end())),
         interval_values.begin(),
         thrust::make_zip_iterator(thrust::make_tuple(interval_counts2.begin(), interval_carry2.begin())),
         interval_values2.begin(),
         thrust::equal_to< thrust::tuple<IndexType,bool> >(),
         binary_op).first
      -
      thrust::make_zip_iterator(thrust::make_tuple(interval_counts2.begin(), interval_carry2.begin()));

    // where the carry flag is set:
    //   values_output[interval_counts2[i]] = binary_op(interval_values2[i], values_output[interval_counts2[i]])
    thrust::transform_if
      (exec,
       interval_values2.begin(), interval_values2.begin() + N2,
       thrust::make_permutation_iterator(values_output, interval_counts2.begin()),
       interval_carry2.begin(),
       thrust::make_permutation_iterator(values_output, interval_counts2.begin()),
       binary_op,
       thrust::identity<bool>());
  }

  return thrust::make_pair(keys_output + N, values_output + N);
}
|
|
|
} // end namespace reduce_by_key_detail |
|
|
|
|
|
// CUDA backend entry point for reduce_by_key: forwards to the policy-driven
// implementation in reduce_by_key_detail using a DefaultPolicy built from the
// key range.
template <typename DerivedPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction>
  thrust::pair<OutputIterator1,OutputIterator2>
  reduce_by_key(execution_policy<DerivedPolicy> &exec,
                InputIterator1 keys_first,
                InputIterator1 keys_last,
                InputIterator2 values_first,
                OutputIterator1 keys_output,
                OutputIterator2 values_output,
                BinaryPredicate binary_pred,
                BinaryFunction binary_op)
{
  typedef reduce_by_key_detail::DefaultPolicy<InputIterator1,InputIterator2,OutputIterator1,OutputIterator2,BinaryPredicate,BinaryFunction> Policy;

  Policy policy(keys_first, keys_last);

  return reduce_by_key_detail::reduce_by_key
    (exec,
     keys_first, keys_last,
     values_first,
     keys_output, values_output,
     binary_pred, binary_op,
     policy);
} // end reduce_by_key()
|
|
|
} // end namespace detail |
|
} // end namespace cuda |
|
} // end namespace system |
|
} // end namespace thrust |
|
|
|
__THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_END |
|
|
|
|