706 lines
23 KiB
Plaintext
706 lines
23 KiB
Plaintext
|
/*
 *  Copyright 2008-2012 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
|
||
|
|
||
|
|
||
|
#include <thrust/detail/config.h>
|
||
|
|
||
|
#include <thrust/iterator/iterator_traits.h>
|
||
|
#include <thrust/iterator/counting_iterator.h>
|
||
|
#include <thrust/iterator/transform_iterator.h>
|
||
|
#include <thrust/iterator/permutation_iterator.h>
|
||
|
#include <thrust/iterator/zip_iterator.h>
|
||
|
#include <thrust/system/detail/generic/select_system.h>
|
||
|
|
||
|
#include <thrust/detail/type_traits.h>
|
||
|
#include <thrust/detail/type_traits/function_traits.h>
|
||
|
#include <thrust/detail/type_traits/iterator/is_output_iterator.h>
|
||
|
#include <thrust/detail/type_traits/iterator/is_discard_iterator.h>
|
||
|
|
||
|
#include <thrust/detail/minmax.h>
|
||
|
#include <thrust/detail/internal_functional.h>
|
||
|
#include <thrust/detail/temporary_array.h>
|
||
|
|
||
|
#include <thrust/reduce.h>
|
||
|
#include <thrust/scan.h>
|
||
|
#include <thrust/system/cuda/detail/default_decomposition.h>
|
||
|
#include <thrust/system/cuda/detail/block/inclusive_scan.h>
|
||
|
#include <thrust/system/cuda/detail/execution_policy.h>
|
||
|
#include <thrust/system/cuda/detail/detail/launch_closure.h>
|
||
|
#include <thrust/system/cuda/detail/reduce_intervals.h>
|
||
|
#include <thrust/system/cuda/detail/detail/uninitialized.h>
|
||
|
|
||
|
__THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_BEGIN
|
||
|
|
||
|
namespace thrust
|
||
|
{
|
||
|
namespace system
|
||
|
{
|
||
|
namespace cuda
|
||
|
{
|
||
|
namespace detail
|
||
|
{
|
||
|
namespace reduce_by_key_detail
|
||
|
{
|
||
|
|
||
|
// Unary functor producing a "tail flag" from a 3-element tuple.
//
// The tuple is read as (index, key, next_key): element 0 is compared against
// n - 1, elements 1 and 2 are handed to binary_pred.  In this file it is
// applied to a zip of (counting_iterator, keys_first, keys_first + 1).
//
// The result is 1 when the index is the last one (n - 1) or when
// binary_pred(key, next_key) is false; otherwise it is 0.
template <typename FlagType, typename IndexType, typename KeyType, typename BinaryPredicate>
struct tail_flag_functor
{
  BinaryPredicate binary_pred; // NB: this must be the first member for performance reasons
  IndexType n;                 // total number of elements; index n - 1 is always flagged

  typedef FlagType result_type;

  // The initializers are written in declaration order (binary_pred, then n).
  // Members are always initialized in declaration order regardless of how the
  // list is written, so listing them the other way round only produced a
  // -Wreorder warning and a misleading read.
  tail_flag_functor(IndexType n, BinaryPredicate binary_pred)
    : binary_pred(binary_pred), n(n)
  {}

  // XXX why is this noticeably faster? (it may read past the end of input)
  //FlagType operator()(const thrust::tuple<IndexType,KeyType,KeyType>& t) const

  // Returns 1 if t marks the end of a run (see the struct comment), else 0.
  // The index test comes first, so binary_pred is not invoked for the last
  // element; the tuple itself has already been formed by the caller, though.
  template <typename Tuple>
  __host__ __device__ __thrust_forceinline__
  FlagType operator()(const Tuple& t)
  {
    if (thrust::get<0>(t) == (n - 1) || !binary_pred(thrust::get<1>(t), thrust::get<2>(t)))
      return 1;
    else
      return 0;
  }
};
|
||
|
|
||
|
|
||
|
// Reads up to K * CTA_SIZE flags from iflags and returns, for the calling
// thread, a bit mask whose bit k is the flag of element K * thread_index + k.
//
// The flags are first fetched with a stride of CTA_SIZE between a thread's
// consecutive reads and packed into sflag, then re-read from sflag so that
// each thread ends up with the K consecutive elements it owns.
//
// When FullBlock is false only elements with offset < n are read; missing
// elements contribute a 0 bit.
//
// On return sflag[thread_index] holds the same mask that is returned, and a
// barrier has been executed so every thread may read its neighbours' masks.
template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename FlagIterator,
          typename FlagType>
__device__ __thrust_forceinline__
FlagType load_flags(Context context,
                    const unsigned int n,
                    FlagIterator iflags,
                    FlagType (&sflag)[CTA_SIZE])
{
  const unsigned int tid = context.thread_index();

  // pass 1: thread tid fetches elements tid, tid + CTA_SIZE, tid + 2 * CTA_SIZE, ...
  // and records the flag of its round-th fetch in bit `round`
  FlagType strided_bits = 0;

  for (unsigned int round = 0; round < K; ++round)
  {
    const unsigned int idx = round * CTA_SIZE + tid;

    if (FullBlock || idx < n)
    {
      FlagIterator it = iflags + idx;
      if (*it)
        strided_bits |= FlagType(1) << round;
    }
  }

  sflag[tid] = strided_bits;

  context.barrier();

  // pass 2: gather the flags for iflags[K * tid, K * tid + K) out of the packed words;
  // element idx was fetched by thread (idx % CTA_SIZE) in round (idx / CTA_SIZE)
  FlagType owned_bits = 0;

  for (unsigned int slot = 0; slot < K; ++slot)
  {
    const unsigned int idx = K * tid + slot;

    if (FullBlock || idx < n)
    {
      const unsigned int fetched_by = idx % CTA_SIZE;
      const unsigned int fetched_in = idx / CTA_SIZE;

      owned_bits |= ((sflag[fetched_by] >> fetched_in) & FlagType(1)) << slot;
    }
  }

  // everyone must be done reading the strided words before they are overwritten
  context.barrier();

  sflag[tid] = owned_bits;

  context.barrier();

  return owned_bits;
}
|
||
|
|
||
|
// Copies up to K * CTA_SIZE values from ivals into sdata.
//
// Reads are issued with a stride of CTA_SIZE between a thread's consecutive
// accesses.  Element idx is stored at sdata[idx % K][idx / K], so afterwards
// sdata[k][t] holds element K * t + k.  When FullBlock is false, elements
// with offset >= n are neither read nor written.  A barrier is executed
// before returning.
template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename InputIterator2,
          typename ValueType>
__device__ __thrust_forceinline__
void load_values(Context context,
                 const unsigned int n,
                 InputIterator2 ivals,
                 ValueType (&sdata)[K][CTA_SIZE + 1])
{
  const unsigned int tid = context.thread_index();

  for (unsigned int round = 0; round < K; ++round)
  {
    const unsigned int idx = round * CTA_SIZE + tid;

    if (FullBlock || idx < n)
    {
      InputIterator2 src = ivals + idx;
      sdata[idx % K][idx / K] = *src;
    }
  }

  context.barrier();
}
|
||
|
|
||
|
|
||
|
// Processes one unit of at most K * CTA_SIZE consecutive (key, value) pairs.
//
// Each thread owns K consecutive elements (K * thread_index + k).  The tail
// flags read from iflags mark the last element of every run of keys.  For
// each flagged element the key is copied to okeys and the reduction of the
// run (as far as it lies inside the data seen so far) is written to ovals.
//
// carry_in / carry_value / carry_index are read and written:
//   - on entry, if carry_in is set, carry_value is folded into the first
//     element of the unit;
//   - on exit, carry_value holds the running value at the last element of
//     the unit, carry_in tells whether that element was NOT flagged (i.e.
//     the run continues), and carry_index is the number of outputs produced
//     by this unit.
//
// When FullBlock is false only the first n elements of the unit exist.
// sflag and sdata are scratch storage shared by all threads of the block.
template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename FlagIterator,
          typename FlagType,
          typename IndexType,
          typename ValueType>
__device__ __thrust_forceinline__
void reduce_by_key_body(Context context,
                        const unsigned int n,
                        InputIterator1   ikeys,
                        InputIterator2   ivals,
                        OutputIterator1  okeys,
                        OutputIterator2  ovals,
                        BinaryPredicate  binary_pred,
                        BinaryFunction   binary_op,
                        FlagIterator     iflags,
                        FlagType  (&sflag)[CTA_SIZE],
                        ValueType (&sdata)[K][CTA_SIZE + 1],
                        bool&      carry_in,
                        IndexType& carry_index,
                        ValueType& carry_value)
{
  const unsigned int tid = context.thread_index();

  // ---- flags -------------------------------------------------------------
  // bit k of flag_bits is the tail flag of element K * tid + k
  const FlagType flag_bits  = load_flags<CTA_SIZE,K,FullBlock>(context, n, iflags, sflag);
  const FlagType flag_count = __popc(flag_bits); // TODO hide this behind a template

  // flags of the thread to the left (sflag still holds every thread's flag_bits here)
  FlagType left_flag = 0;
  if (tid != 0)
    left_flag = sflag[tid - 1];

  // head flag for the block-level segmented scan: set for thread 0, for a thread
  // with a tail flag among its first K - 1 elements, and for a thread whose left
  // neighbour ends on a tail flag
  FlagType head_flag = 0;
  if (tid == 0 || (flag_bits & ((1 << (K - 1)) - 1)) || (left_flag & (1 << (K - 1))))
    head_flag = 1;

  context.barrier();

  // ---- scan flag counts --------------------------------------------------
  sflag[tid] = flag_count;
  context.barrier();

  block::inclusive_scan(context, sflag, thrust::plus<FlagType>());

  // number of outputs produced by threads to the left of this one
  FlagType output_position = 0;
  if (tid != 0)
    output_position = sflag[tid - 1];

  const FlagType num_outputs = sflag[CTA_SIZE - 1];

  context.barrier();

  // ---- shuffle keys and write keys out -----------------------------------
  if (!thrust::detail::is_discard_iterator<OutputIterator1>::value)
  {
    // XXX this could be improved
    // outputs are handled in windows of CTA_SIZE so that sflag can hold the
    // source index of every output in the current window
    for (unsigned int window = 0; window < num_outputs; window += CTA_SIZE)
    {
      FlagType slot = output_position;

      for (unsigned int k = 0; k < K; ++k)
      {
        if (flag_bits & (FlagType(1) << k))
        {
          if (window <= slot && slot < window + CTA_SIZE)
            sflag[slot - window] = K * tid + k;
          ++slot;
        }
      }

      context.barrier();

      if (window + tid < num_outputs)
      {
        InputIterator1  src = ikeys + sflag[tid];
        OutputIterator1 dst = okeys + (window + tid);
        *dst = *src;
      }

      context.barrier();
    }
  }

  // ---- load values -------------------------------------------------------
  load_values<CTA_SIZE,K,FullBlock>(context, n, ivals, sdata);

  // this thread's K consecutive values
  ValueType ldata[K];
  for (unsigned int k = 0; k < K; ++k)
    ldata[k] = sdata[k][tid];

  // carry in (if necessary)
  if (tid == 0 && carry_in)
  {
    // XXX WAR sm_10 issue
    ValueType carried = carry_value;
    ldata[0] = binary_op(carried, ldata[0]);
  }

  context.barrier();

  // ---- sum local values --------------------------------------------------
  // running reduction inside the thread; it restarts after every tail flag
  for (unsigned int k = 1; k < K; ++k)
  {
    const unsigned int idx = K * tid + k;

    if (FullBlock || idx < n)
    {
      const bool previous_is_tail = (flag_bits & (FlagType(1) << (k - 1))) ? true : false;

      if (!previous_is_tail)
        ldata[k] = binary_op(ldata[k - 1], ldata[k]);
    }
  }

  // ---- second level segmented scan ---------------------------------------
  // use head flags for segmented scan over each thread's last local value
  sflag[tid]        = head_flag;
  sdata[K - 1][tid] = ldata[K - 1];
  context.barrier();

  if (FullBlock)
    block::inclusive_scan_by_flag(context, sflag, sdata[K-1], binary_op);
  else
    block::inclusive_scan_by_flag_n(context, sflag, sdata[K-1], n, binary_op);

  // ---- update local values -----------------------------------------------
  if (tid > 0)
  {
    // bit 0: left neighbour's last tail flag, bits 1..K: this thread's tail flags
    unsigned int update_bits = (flag_bits << 1) | (left_flag >> (K - 1));

    // position of the lowest set bit == number of leading local elements that
    // still belong to the run coming in from the left
// TODO remove guard
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
    unsigned int update_count = __ffs(update_bits) - 1u; // NB: this might wrap around to UINT_MAX
#else
    unsigned int update_count = 0;
#endif // THRUST_DEVICE_COMPILER_NVCC

    if (!FullBlock && (K + 1) * tid > n)
      update_count = thrust::min(n - K * tid, update_count);

    ValueType left = sdata[K - 1][tid - 1];

    for (unsigned int k = 0; k < K; ++k)
    {
      if (k < update_count)
        ldata[k] = binary_op(left, ldata[k]);
    }
  }

  context.barrier();

  // ---- store carry out ---------------------------------------------------
  if (FullBlock)
  {
    if (tid == CTA_SIZE - 1)
    {
      carry_value = ldata[K - 1];
      carry_in    = !(flag_bits & (FlagType(1) << (K - 1)));
      carry_index = num_outputs;
    }
  }
  else
  {
    // the thread that owns element n - 1
    if (tid == (n - 1) / K)
    {
      const unsigned int last_slot = (n - 1) % K;

      // select ldata[last_slot] by scanning all slots with a loop index
      for (unsigned int k = 0; k < K; ++k)
        if (k == last_slot)
          carry_value = ldata[k];

      carry_in    = !(flag_bits & (FlagType(1) << last_slot));
      carry_index = num_outputs;
    }
  }

  // ---- shuffle values ----------------------------------------------------
  // compact the values of flagged elements into sdata in output order
  {
    FlagType slot = output_position;

    for (unsigned int k = 0; k < K; ++k)
    {
      const unsigned int idx = K * tid + k;

      if (FullBlock || idx < n)
      {
        if (flag_bits & (FlagType(1) << k))
        {
          sdata[slot / CTA_SIZE][slot % CTA_SIZE] = ldata[k];
          ++slot;
        }
      }
    }
  }

  context.barrier();

  // ---- write values out --------------------------------------------------
  for (unsigned int k = 0; k < K; ++k)
  {
    const unsigned int out = CTA_SIZE * k + tid;

    if (out < num_outputs)
    {
      OutputIterator2 dst = ovals + out;
      *dst = sdata[k][tid];
    }
  }

  context.barrier();
}
|
||
|
|
||
|
// Kernel closure for the per-interval pass of reduce_by_key.
//
// Each thread block handles the interval decomp[block_index]: it walks the
// interval in units of K * CTA_SIZE elements, calling reduce_by_key_body on
// each, and finally records the block's carry value and carry flag in
// interval_values[block_index] / interval_carry[block_index].
//
// interval_counts is read at [block_index - 1] (for blocks other than 0) to
// obtain the offset at which this block starts writing into okeys / ovals.
template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename FlagIterator,
          typename IndexIterator,
          typename ValueIterator,
          typename BoolIterator,
          typename Decomposition,
          typename Context>
struct reduce_by_key_closure
{
  InputIterator1   ikeys;
  InputIterator2   ivals;
  OutputIterator1  okeys;
  OutputIterator2  ovals;
  BinaryPredicate  binary_pred;
  BinaryFunction   binary_op;
  FlagIterator     iflags;
  IndexIterator    interval_counts;
  ValueIterator    interval_values;
  BoolIterator     interval_carry;
  Decomposition    decomp;
  Context          context;

  typedef Context context_type;

  // Stores a copy of every argument in the member of the same name.
  reduce_by_key_closure(InputIterator1   ikeys_,
                        InputIterator2   ivals_,
                        OutputIterator1  okeys_,
                        OutputIterator2  ovals_,
                        BinaryPredicate  binary_pred_,
                        BinaryFunction   binary_op_,
                        FlagIterator     iflags_,
                        IndexIterator    interval_counts_,
                        ValueIterator    interval_values_,
                        BoolIterator     interval_carry_,
                        Decomposition    decomp_,
                        Context          context_ = Context())
    : ikeys(ikeys_),
      ivals(ivals_),
      okeys(okeys_),
      ovals(ovals_),
      binary_pred(binary_pred_),
      binary_op(binary_op_),
      iflags(iflags_),
      interval_counts(interval_counts_),
      interval_values(interval_values_),
      interval_carry(interval_carry_),
      decomp(decomp_),
      context(context_)
  {}

  __device__ __thrust_forceinline__
  void operator()(void)
  {
    typedef typename thrust::iterator_value<InputIterator1>::type KeyType;
    typedef typename thrust::iterator_value<ValueIterator>::type  ValueType;
    typedef typename Decomposition::index_type                    IndexType;
    typedef typename thrust::iterator_value<FlagIterator>::type   FlagType;

    const unsigned int CTA_SIZE = context_type::ThreadsPerBlock::value;

    // shared memory budget in bytes
// TODO centralize this mapping (__CUDA_ARCH__ -> smem bytes)
#if __CUDA_ARCH__ >= 200
    const unsigned int SMEM = (48 * 1024);
#else
    const unsigned int SMEM = (16 * 1024) - 256;
#endif

    // bytes taken by sflag plus the three carry variables declared below
    const unsigned int SMEM_FIXED = CTA_SIZE * sizeof(FlagType) + sizeof(ValueType) + sizeof(IndexType) + sizeof(bool);

    // K (elements per thread) is the smallest of:
    //   BOUND_1 - number of sdata rows that fit in the remaining budget
    //   BOUND_2 - number of bits in FlagType (one flag bit per element)
    //   BOUND_3 - a fixed cap
    const unsigned int BOUND_1 = (SMEM - SMEM_FIXED) / ((CTA_SIZE + 1) * sizeof(ValueType));
    const unsigned int BOUND_2 = 8 * sizeof(FlagType);
    const unsigned int BOUND_3 = 6;

    // TODO replace this with a static_min<BOUND_1,BOUND_2,BOUND_3>::value
    const unsigned int BOUND_12 = (BOUND_1 < BOUND_2) ? BOUND_1 : BOUND_2;
    const unsigned int K        = (BOUND_12 < BOUND_3) ? BOUND_12 : BOUND_3;

    __shared__ detail::uninitialized<FlagType[CTA_SIZE]>         sflag;
    __shared__ detail::uninitialized<ValueType[K][CTA_SIZE + 1]> sdata;  // padded to avoid bank conflicts

    __shared__ detail::uninitialized<ValueType> carry_value; // storage for carry in and carry out
    __shared__ detail::uninitialized<IndexType> carry_index;
    __shared__ detail::uninitialized<bool>      carry_in;

    typename Decomposition::range_type interval = decomp[context.block_index()];
    //thrust::system::detail::internal::index_range<IndexType> interval = decomp[context.block_index()];

    // thread 0 sets up the carry state for the whole block
    if (context.thread_index() == 0)
    {
      carry_in = false; // act as though the previous segment terminated just before us

      if (context.block_index() == 0)
      {
        carry_index = 0;
      }
      else
      {
        // output offset of this block, taken from the previous block's entry
        interval_counts += (context.block_index() - 1);
        carry_index = *interval_counts;
      }
    }

    context.barrier();

    IndexType base = interval.begin();

    // advance input and output iterators
    ikeys  += base;
    ivals  += base;
    iflags += base;
    okeys  += carry_index;
    ovals  += carry_index;

    const unsigned int unit_size = K * CTA_SIZE;

    // process full units
    while (base + unit_size <= interval.end())
    {
      reduce_by_key_body<CTA_SIZE,K,true>(context, unit_size,
                                          ikeys, ivals, okeys, ovals,
                                          binary_pred, binary_op, iflags,
                                          sflag.get(), sdata.get(),
                                          carry_in.get(), carry_index.get(), carry_value.get());

      base   += unit_size;
      ikeys  += unit_size;
      ivals  += unit_size;
      iflags += unit_size;

      // carry_index now holds the number of outputs the unit produced
      okeys  += carry_index;
      ovals  += carry_index;
    }

    // process partially full unit at end of input (if necessary)
    if (base < interval.end())
    {
      const unsigned int remaining = interval.end() - base;

      reduce_by_key_body<CTA_SIZE,K,false>(context, remaining,
                                           ikeys, ivals, okeys, ovals,
                                           binary_pred, binary_op, iflags,
                                           sflag.get(), sdata.get(),
                                           carry_in.get(), carry_index.get(), carry_value.get());
    }

    // publish this block's carry state for the follow-up pass
    if (context.thread_index() == 0)
    {
      interval_values += context.block_index();
      interval_carry  += context.block_index();
      *interval_values = carry_value;
      *interval_carry  = carry_in;
    }
  }
}; // end reduce_by_key_closure
|
||
|
|
||
|
// Bundles the types and tuning parameters used by reduce_by_key:
// the flag / index / key / value types, the decomposition type, the block
// size, and the decomposition instance built from the key range.
template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction>
struct DefaultPolicy
{
  // typedefs
  typedef unsigned int                                                      FlagType;
  typedef typename thrust::iterator_traits<InputIterator1>::difference_type IndexType;
  typedef typename thrust::iterator_traits<InputIterator1>::value_type      KeyType;
  typedef thrust::system::detail::internal::uniform_decomposition<IndexType> Decomposition;

  // How the type of the temporary (ValueType) is chosen:
  //
  //   if BinaryFunction is AdaptableBinaryFunction
  //     TemporaryType = AdaptableBinaryFunction::result_type
  //   else if OutputIterator2 is a "pure" output iterator
  //     TemporaryType = InputIterator2::value_type
  //   else
  //     TemporaryType = OutputIterator2::value_type
  //
  // XXX upon c++0x, TemporaryType needs to be:
  // result_of<BinaryFunction>::type
  typedef typename thrust::detail::eval_if<
    thrust::detail::has_result_type<BinaryFunction>::value,
    thrust::detail::result_type<BinaryFunction>,
    thrust::detail::eval_if<
      thrust::detail::is_output_iterator<OutputIterator2>::value,
      thrust::iterator_value<InputIterator2>,
      thrust::iterator_value<OutputIterator2>
    >
  >::type ValueType;

  // threads per block: 256 when ValueType is POD, 192 otherwise
  // XXX WAR problem on sm_11
  // TODO tune this
  static const unsigned int ThreadsPerBlock = (thrust::detail::is_pod<ValueType>::value) ? 256 : 192;

  // Builds the default decomposition for a key range of (keys_end - keys_begin) elements.
  DefaultPolicy(InputIterator1 keys_begin, InputIterator1 keys_end)
    : decomp(default_decomposition<IndexType>(keys_end - keys_begin))
  {}

  // member variables
  Decomposition decomp;
};
|
||
|
|
||
|
// Policy-driven implementation of reduce_by_key.
//
// Steps, in order:
//   1. build an iterator of tail flags over the keys (tail_flag_functor);
//   2. count the tail flags of every interval of policy.decomp and turn the
//      counts into an inclusive prefix sum (interval_counts);
//   3. launch reduce_by_key_closure with one block per interval, which writes
//      keys_output / values_output and records each interval's carry value
//      and carry flag;
//   4. if there is more than one interval, reduce the carry values by
//      (count, carry) key and fold them into values_output at the positions
//      stored in interval_counts2 wherever interval_carry2 is true.
//
// Returns the ends of the two output ranges; both are advanced by the last
// entry of the scanned interval_counts.  For an empty key range the output
// iterators are returned unchanged.
template <typename DerivedPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename Policy>
  thrust::pair<OutputIterator1,OutputIterator2>
  reduce_by_key(execution_policy<DerivedPolicy> &exec,
                InputIterator1 keys_first,
                InputIterator1 keys_last,
                InputIterator2 values_first,
                OutputIterator1 keys_output,
                OutputIterator2 values_output,
                BinaryPredicate binary_pred,
                BinaryFunction binary_op,
                Policy policy)
{
  typedef typename Policy::FlagType       FlagType;
  typedef typename Policy::Decomposition  Decomposition;
  typedef typename Policy::IndexType      IndexType;
  typedef typename Policy::KeyType        KeyType;
  typedef typename Policy::ValueType      ValueType;

  // temporary arrays
  typedef thrust::detail::temporary_array<IndexType,DerivedPolicy> IndexArray;
  typedef thrust::detail::temporary_array<KeyType,DerivedPolicy>   KeyArray;
  typedef thrust::detail::temporary_array<ValueType,DerivedPolicy> ValueArray;
  typedef thrust::detail::temporary_array<bool,DerivedPolicy>      BoolArray;

  Decomposition decomp = policy.decomp;

  // input size
  IndexType n = keys_last - keys_first;

  // nothing to reduce
  if (n == 0)
    return thrust::make_pair(keys_output, values_output);

  // one entry per interval of the decomposition
  IndexArray interval_counts(exec, decomp.size());
  ValueArray interval_values(exec, decomp.size());
  BoolArray  interval_carry(exec, decomp.size());

  // an ode to c++11 auto
  typedef thrust::counting_iterator<IndexType>                          CountingIterator;
  typedef tail_flag_functor<FlagType,IndexType,KeyType,BinaryPredicate> TailFlag;
  typedef thrust::zip_iterator<
    thrust::tuple<CountingIterator,InputIterator1,InputIterator1>
  >                                                                     KeyPairIterator;
  typedef thrust::transform_iterator<TailFlag,KeyPairIterator>          FlagIterator;

  // (index, key, next key) -> tail flag
  FlagIterator iflag =
    thrust::make_transform_iterator
      (thrust::make_zip_iterator(thrust::make_tuple(thrust::counting_iterator<IndexType>(0), keys_first, keys_first + 1)),
       TailFlag(n, binary_pred));

  // count number of tail flags per interval
  thrust::system::cuda::detail::reduce_intervals(exec, iflag, interval_counts.begin(), thrust::plus<IndexType>(), decomp);

  // running total of tail flags up to and including each interval
  thrust::inclusive_scan(exec,
                         interval_counts.begin(), interval_counts.end(),
                         interval_counts.begin(),
                         thrust::plus<IndexType>());

  // determine output size
  const IndexType N = interval_counts[interval_counts.size() - 1];

  static const unsigned int ThreadsPerBlock = Policy::ThreadsPerBlock;

  typedef typename IndexArray::iterator                            IndexIterator;
  typedef typename ValueArray::iterator                            ValueIterator;
  typedef typename BoolArray::iterator                             BoolIterator;
  typedef detail::statically_blocked_thread_array<ThreadsPerBlock> Context;
  typedef reduce_by_key_closure<
    InputIterator1,InputIterator2,
    OutputIterator1,OutputIterator2,
    BinaryPredicate,BinaryFunction,
    FlagIterator,IndexIterator,ValueIterator,BoolIterator,
    Decomposition,Context
  > Closure;

  Closure closure(keys_first,  values_first,
                  keys_output, values_output,
                  binary_pred, binary_op,
                  iflag,
                  interval_counts.begin(),
                  interval_values.begin(),
                  interval_carry.begin(),
                  decomp);

  // one block per interval
  detail::launch_closure(closure, decomp.size(), ThreadsPerBlock);

  // with several intervals, the per-interval carries still have to be applied
  if (decomp.size() > 1)
  {
    ValueArray interval_values2(exec, decomp.size());
    IndexArray interval_counts2(exec, decomp.size());
    BoolArray  interval_carry2(exec, decomp.size());

    // reduce the carry values, keyed by the (count, carry) pair of each interval;
    // N2 is the number of reduced entries
    IndexType N2 =
      thrust::reduce_by_key
        (exec,
         thrust::make_zip_iterator(thrust::make_tuple(interval_counts.begin(), interval_carry.begin())),
         thrust::make_zip_iterator(thrust::make_tuple(interval_counts.end(),   interval_carry.end())),
         interval_values.begin(),
         thrust::make_zip_iterator(thrust::make_tuple(interval_counts2.begin(), interval_carry2.begin())),
         interval_values2.begin(),
         thrust::equal_to< thrust::tuple<IndexType,bool> >(),
         binary_op).first
      -
      thrust::make_zip_iterator(thrust::make_tuple(interval_counts2.begin(), interval_carry2.begin()));

    // where the carry flag is set, combine the reduced carry with the value
    // already stored at values_output[interval_counts2[i]]
    thrust::transform_if
      (exec,
       interval_values2.begin(), interval_values2.begin() + N2,
       thrust::make_permutation_iterator(values_output, interval_counts2.begin()),
       interval_carry2.begin(),
       thrust::make_permutation_iterator(values_output, interval_counts2.begin()),
       binary_op,
       thrust::identity<bool>());
  }

  return thrust::make_pair(keys_output + N, values_output + N);
}
|
||
|
|
||
|
} // end namespace reduce_by_key_detail
|
||
|
|
||
|
|
||
|
// Backend entry point for reduce_by_key: builds a DefaultPolicy from the key
// range and forwards every argument to the policy-driven implementation in
// reduce_by_key_detail, returning its result unchanged.
template <typename DerivedPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction>
  thrust::pair<OutputIterator1,OutputIterator2>
  reduce_by_key(execution_policy<DerivedPolicy> &exec,
                InputIterator1 keys_first,
                InputIterator1 keys_last,
                InputIterator2 values_first,
                OutputIterator1 keys_output,
                OutputIterator2 values_output,
                BinaryPredicate binary_pred,
                BinaryFunction binary_op)
{
  typedef reduce_by_key_detail::DefaultPolicy<
    InputIterator1,InputIterator2,
    OutputIterator1,OutputIterator2,
    BinaryPredicate,BinaryFunction
  > DefaultPolicyType;

  DefaultPolicyType default_policy(keys_first, keys_last);

  return reduce_by_key_detail::reduce_by_key(exec,
                                             keys_first, keys_last,
                                             values_first,
                                             keys_output, values_output,
                                             binary_pred, binary_op,
                                             default_policy);
} // end reduce_by_key()
|
||
|
|
||
|
} // end namespace detail
|
||
|
} // end namespace cuda
|
||
|
} // end namespace system
|
||
|
} // end namespace thrust
|
||
|
|
||
|
__THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_END
|
||
|
|