2014-03-18 22:17:40 +01:00

706 lines
23 KiB
C++

/*
* Copyright 2008-2012 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/detail/config.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/system/detail/generic/select_system.h>
#include <thrust/detail/type_traits.h>
#include <thrust/detail/type_traits/function_traits.h>
#include <thrust/detail/type_traits/iterator/is_output_iterator.h>
#include <thrust/detail/type_traits/iterator/is_discard_iterator.h>
#include <thrust/detail/minmax.h>
#include <thrust/detail/internal_functional.h>
#include <thrust/detail/temporary_array.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/system/cuda/detail/default_decomposition.h>
#include <thrust/system/cuda/detail/block/inclusive_scan.h>
#include <thrust/system/cuda/detail/execution_policy.h>
#include <thrust/system/cuda/detail/detail/launch_closure.h>
#include <thrust/system/cuda/detail/reduce_intervals.h>
#include <thrust/system/cuda/detail/detail/uninitialized.h>
__THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_BEGIN
namespace thrust
{
namespace system
{
namespace cuda
{
namespace detail
{
namespace reduce_by_key_detail
{
// Functor that turns an (index, key, next key) tuple into a tail flag.
//
// The result is 1 when the element at `index` is the final element of the
// whole input (index == n - 1) or when binary_pred(key, next key) is false,
// and 0 otherwise.  The index test comes first, so binary_pred is not
// invoked for the final element.
template <typename FlagType, typename IndexType, typename KeyType, typename BinaryPredicate>
struct tail_flag_functor
{
  BinaryPredicate binary_pred; // NB: this must be the first member for performance reasons
  IndexType n;                 // total number of input elements

  typedef FlagType result_type;

  // The mem-initializers are written in declaration order (binary_pred, then n).
  // Members are always initialized in declaration order regardless of how the
  // list is spelled, so listing them the other way round only produced a
  // -Wreorder warning and a misleading read.
  tail_flag_functor(IndexType n, BinaryPredicate binary_pred)
    : binary_pred(binary_pred), n(n)
  {}

  // XXX why is this noticeably faster?  (it may read past the end of input)
  //FlagType operator()(const thrust::tuple<IndexType,KeyType,KeyType>& t) const

  // t is expected to provide thrust::get<0> (element index),
  // thrust::get<1> (key) and thrust::get<2> (following key).
  template <typename Tuple>
  __host__ __device__ __thrust_forceinline__
  FlagType operator()(const Tuple& t)
  {
    if (thrust::get<0>(t) == (n - 1) || !binary_pred(thrust::get<1>(t), thrust::get<2>(t)))
      return 1;
    else
      return 0;
  }
};
// Reads up to K * CTA_SIZE flags from iflags and returns, for the calling
// thread, a bit mask whose bit k is set when flag number
// K * thread_index + k is non-zero.
//
// The flags are first read with a stride of CTA_SIZE between consecutive
// reads of one thread, packed into sflag, and then re-extracted so that
// every thread ends up with K consecutive flags.  On return sflag[t] holds
// the mask that was returned to thread t.  Unless FullBlock is true, flags
// at positions >= n are treated as zero.
template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename FlagIterator,
          typename FlagType>
__device__ __thrust_forceinline__
FlagType load_flags(Context context,
                    const unsigned int n,
                    FlagIterator iflags,
                    FlagType (&sflag)[CTA_SIZE])
{
  const unsigned int tid = context.thread_index();

  // pass 1: bit k of strided_bits <- flag number k * CTA_SIZE + tid
  FlagType strided_bits = 0;
  for(unsigned int k = 0; k < K; ++k)
  {
    const unsigned int src = k * CTA_SIZE + tid;

    if (!FullBlock && src >= n)
      continue;

    FlagIterator flag_it = iflags + src;
    if (*flag_it)
      strided_bits |= FlagType(1) << k;
  }

  sflag[tid] = strided_bits;
  context.barrier();

  // pass 2: bit k of blocked_bits <- flag number K * tid + k, which pass 1
  // stored in bit (src / CTA_SIZE) of sflag[src % CTA_SIZE]
  FlagType blocked_bits = 0;
  for(unsigned int k = 0; k < K; ++k)
  {
    const unsigned int src = K * tid + k;

    if (!FullBlock && src >= n)
      continue;

    const FlagType packed = sflag[src % CTA_SIZE];
    const FlagType bit    = (packed >> (src / CTA_SIZE)) & FlagType(1);
    blocked_bits |= bit << k;
  }

  // everyone must be done reading the strided masks before they are replaced
  context.barrier();
  sflag[tid] = blocked_bits;
  context.barrier();

  return blocked_bits;
}
// Copies up to K * CTA_SIZE values from ivals into sdata.
//
// Each thread reads the values at positions k * CTA_SIZE + thread_index
// (k = 0 .. K-1) and stores value number i at sdata[i % K][i / K].  Unless
// FullBlock is true, positions >= n are skipped and the corresponding sdata
// entries are left untouched.  A barrier follows the stores.
template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename InputIterator2,
          typename ValueType>
__device__ __thrust_forceinline__
void load_values(Context context,
                 const unsigned int n,
                 InputIterator2 ivals,
                 ValueType (&sdata)[K][CTA_SIZE + 1])
{
  const unsigned int tid = context.thread_index();

  for(unsigned int k = 0; k < K; ++k)
  {
    const unsigned int src = k * CTA_SIZE + tid;

    if (!FullBlock && src >= n)
      continue;

    const unsigned int row = src % K;
    const unsigned int col = src / K;

    InputIterator2 value_it = ivals + src;
    sdata[row][col] = *value_it;
  }

  context.barrier();
}
// Reduces one unit of up to K * CTA_SIZE consecutive key/value pairs with the
// threads of one block and writes the resulting keys and values to okeys / ovals.
//
// Template parameters:
//   CTA_SIZE  - extent of sflag and, plus one padding column, of each sdata row
//               (the caller passes Context::ThreadsPerBlock::value)
//   K         - number of elements handled by each thread
//   FullBlock - when true, the `offset < n` bounds checks are compiled out
//
// Parameters:
//   context     - supplies thread_index() and barrier()
//   n           - number of valid elements in this unit
//   ikeys, ivals, iflags
//               - inputs, indexed relative to the first element of the unit
//   okeys, ovals
//               - outputs, written from position 0 upwards
//   binary_pred - not referenced in this body (key equality arrives via iflags)
//   binary_op   - reduction operator
//   sflag, sdata
//               - scratch arrays that are reused for several purposes below
//   carry_in    - in:  whether carry_value has to be combined into the first element
//                 out: false if the last valid element carries a flag, true otherwise
//   carry_index - out: number of outputs produced by this unit
//   carry_value - in:  partial result carried over from the preceding unit
//                 out: the accumulated value at the last valid element
//
// The three carry_* outputs are written by a single thread; barriers follow
// those writes before the function returns.
template <unsigned int CTA_SIZE,
unsigned int K,
bool FullBlock,
typename Context,
typename InputIterator1,
typename InputIterator2,
typename OutputIterator1,
typename OutputIterator2,
typename BinaryPredicate,
typename BinaryFunction,
typename FlagIterator,
typename FlagType,
typename IndexType,
typename ValueType>
__device__ __thrust_forceinline__
void reduce_by_key_body(Context context,
const unsigned int n,
InputIterator1 ikeys,
InputIterator2 ivals,
OutputIterator1 okeys,
OutputIterator2 ovals,
BinaryPredicate binary_pred,
BinaryFunction binary_op,
FlagIterator iflags,
FlagType (&sflag)[CTA_SIZE],
ValueType (&sdata)[K][CTA_SIZE + 1],
bool& carry_in,
IndexType& carry_index,
ValueType& carry_value)
{
// load flags: bit k of flag_bits is the flag of element K * thread_index + k,
// and load_flags leaves every thread's mask in sflag
const FlagType flag_bits = load_flags<CTA_SIZE,K,FullBlock>(context, n, iflags, sflag);
// number of flagged elements owned by this thread, i.e. the outputs it produces
const FlagType flag_count = __popc(flag_bits); // TODO hide this behind a template
// mask of the thread to the left (0 for thread 0)
const FlagType left_flag = (context.thread_index() == 0) ? 0 : sflag[context.thread_index() - 1];
// head_flag is 1 for thread 0, when any of this thread's first K - 1 elements
// is flagged, or when the left neighbour's last element is flagged
const FlagType head_flag = (context.thread_index() == 0 || flag_bits & ((1 << (K - 1)) - 1) || left_flag & (1 << (K - 1))) ? 1 : 0;
context.barrier();
// scan flag counts
sflag[context.thread_index()] = flag_count; context.barrier();
block::inclusive_scan(context, sflag, thrust::plus<FlagType>());
// after the inclusive scan: the entry to the left is this thread's first
// output slot and the last entry is the number of outputs of the whole unit
const FlagType output_position = (context.thread_index() == 0) ? 0 : sflag[context.thread_index() - 1];
const FlagType num_outputs = sflag[CTA_SIZE - 1];
context.barrier();
// shuffle keys and write keys out
if (!thrust::detail::is_discard_iterator<OutputIterator1>::value)
{
// XXX this could be improved
// outputs are handled in windows of CTA_SIZE: sflag is reused to hold, for
// each output slot of the window, the input index of its flagged element
for (unsigned int i = 0; i < num_outputs; i += CTA_SIZE)
{
FlagType position = output_position;
for(unsigned int k = 0; k < K; k++)
{
if (flag_bits & (FlagType(1) << k))
{
if (i <= position && position < i + CTA_SIZE)
sflag[position - i] = K * context.thread_index() + k;
position++;
}
}
context.barrier();
// each thread copies the key for one output slot of the current window
if (i + context.thread_index() < num_outputs)
{
InputIterator1 tmp1 = ikeys + sflag[context.thread_index()];
OutputIterator1 tmp2 = okeys + (i + context.thread_index());
*tmp2 = *tmp1;
}
context.barrier();
}
}
// load values
load_values<CTA_SIZE,K,FullBlock> (context, n, ivals, sdata);
// load_values stored element i at sdata[i % K][i / K], so sdata[k][thread_index]
// is element K * thread_index + k
ValueType ldata[K];
for (unsigned int k = 0; k < K; k++)
ldata[k] = sdata[k][context.thread_index()];
// carry in (if necessary)
if (context.thread_index() == 0 && carry_in)
{
// XXX WAR sm_10 issue
ValueType tmp1 = carry_value;
ldata[0] = binary_op(tmp1, ldata[0]);
}
context.barrier();
// sum local values
{
// sequential pass over this thread's elements: element k absorbs element
// k - 1 unless element k - 1 is flagged
for(unsigned int k = 1; k < K; k++)
{
const unsigned int offset = K * context.thread_index() + k;
if (FullBlock || offset < n)
{
if (!(flag_bits & (FlagType(1) << (k - 1))))
ldata[k] = binary_op(ldata[k - 1], ldata[k]);
}
}
}
// second level segmented scan
{
// use head flags for segmented scan
// the scan runs over each thread's last local value, staged in row K - 1 of sdata
sflag[context.thread_index()] = head_flag; sdata[K - 1][context.thread_index()] = ldata[K - 1]; context.barrier();
if (FullBlock)
block::inclusive_scan_by_flag(context, sflag, sdata[K-1], binary_op);
else
block::inclusive_scan_by_flag_n(context, sflag, sdata[K-1], n, binary_op);
}
// update local values
if (context.thread_index() > 0)
{
// bit 0: the left neighbour's last element is flagged; bit k + 1: this
// thread's element k is flagged
unsigned int update_bits = (flag_bits << 1) | (left_flag >> (K - 1));
// TODO remove guard
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
// index of the lowest set bit of update_bits = number of leading local
// elements that receive the scanned value from the left
unsigned int update_count = __ffs(update_bits) - 1u; // NB: this might wrap around to UINT_MAX
#else
// without __ffs no local element is updated
unsigned int update_count = 0;
#endif // THRUST_DEVICE_COMPILER_NVCC
if (!FullBlock && (K + 1) * context.thread_index() > n)
update_count = thrust::min(n - K * context.thread_index(), update_count);
// scanned result at the left neighbour's last element
ValueType left = sdata[K - 1][context.thread_index() - 1];
for(unsigned int k = 0; k < K; k++)
{
if (k < update_count)
ldata[k] = binary_op(left, ldata[k]);
}
}
context.barrier();
// store carry out
if (FullBlock)
{
// the last thread owns the last element of a full unit
if (context.thread_index() == CTA_SIZE - 1)
{
carry_value = ldata[K - 1];
carry_in = (flag_bits & (FlagType(1) << (K - 1))) ? false : true;
carry_index = num_outputs;
}
}
else
{
// element n - 1 belongs to thread (n - 1) / K at local position (n - 1) % K
if (context.thread_index() == (n - 1) / K)
{
// loop with a comparison instead of indexing ldata with a runtime value
for (unsigned int k = 0; k < K; k++)
if (k == (n - 1) % K)
carry_value = ldata[k];
carry_in = (flag_bits & (FlagType(1) << ((n - 1) % K))) ? false : true;
carry_index = num_outputs;
}
}
// shuffle values
{
// the value of each flagged element goes to its output slot `position`,
// stored at sdata[position / CTA_SIZE][position % CTA_SIZE]
FlagType position = output_position;
for(unsigned int k = 0; k < K; k++)
{
const unsigned int offset = K * context.thread_index() + k;
if (FullBlock || offset < n)
{
if (flag_bits & (FlagType(1) << k))
{
sdata[position / CTA_SIZE][position % CTA_SIZE] = ldata[k];
position++;
}
}
}
}
context.barrier();
// write values out
// output slot CTA_SIZE * k + thread_index is read back from sdata[k][thread_index]
for(unsigned int k = 0; k < K; k++)
{
const unsigned int offset = CTA_SIZE * k + context.thread_index();
if (offset < num_outputs)
{
OutputIterator2 tmp = ovals + offset;
*tmp = sdata[k][context.thread_index()];
}
}
context.barrier();
}
// Function object whose operator() is the per-block body of the
// reduce_by_key kernel.  The block with index b processes the input interval
// decomp[b] in units of K * CTA_SIZE elements via reduce_by_key_body and then
// stores its final carry value and carry flag at position b of
// interval_values / interval_carry.
template <typename InputIterator1,
typename InputIterator2,
typename OutputIterator1,
typename OutputIterator2,
typename BinaryPredicate,
typename BinaryFunction,
typename FlagIterator,
typename IndexIterator,
typename ValueIterator,
typename BoolIterator,
typename Decomposition,
typename Context>
struct reduce_by_key_closure
{
// input keys and values
InputIterator1 ikeys;
InputIterator2 ivals;
// output keys and values
OutputIterator1 okeys;
OutputIterator2 ovals;
// key comparison (only forwarded to reduce_by_key_body) and reduction operator
BinaryPredicate binary_pred;
BinaryFunction binary_op;
// per-element flags consumed by reduce_by_key_body
FlagIterator iflags;
// per-interval counts; entry b - 1 is read as the output offset of block b
IndexIterator interval_counts;
// per-interval carry value and carry flag, written at the end of operator()
ValueIterator interval_values;
BoolIterator interval_carry;
// maps a block index to the input interval it processes
Decomposition decomp;
// supplies block_index(), thread_index() and barrier()
Context context;
typedef Context context_type;
// Stores all arguments in the members of the same name.
reduce_by_key_closure(InputIterator1 ikeys,
InputIterator2 ivals,
OutputIterator1 okeys,
OutputIterator2 ovals,
BinaryPredicate binary_pred,
BinaryFunction binary_op,
FlagIterator iflags,
IndexIterator interval_counts,
ValueIterator interval_values,
BoolIterator interval_carry,
Decomposition decomp,
Context context = Context())
: ikeys(ikeys), ivals(ivals), okeys(okeys), ovals(ovals), binary_pred(binary_pred), binary_op(binary_op),
iflags(iflags), interval_counts(interval_counts), interval_values(interval_values), interval_carry(interval_carry),
decomp(decomp), context(context) {}
// Per-block entry point; see the struct comment for the overall contract.
__device__ __thrust_forceinline__
void operator()(void)
{
typedef typename thrust::iterator_value<InputIterator1>::type KeyType;
typedef typename thrust::iterator_value<ValueIterator>::type ValueType;
typedef typename Decomposition::index_type IndexType;
typedef typename thrust::iterator_value<FlagIterator>::type FlagType;
const unsigned int CTA_SIZE = context_type::ThreadsPerBlock::value;
// TODO centralize this mapping (__CUDA_ARCH__ -> smem bytes)
// shared memory budget in bytes assumed for the target architecture
#if __CUDA_ARCH__ >= 200
const unsigned int SMEM = (48 * 1024);
#else
const unsigned int SMEM = (16 * 1024) - 256;
#endif
// bytes taken by sflag, carry_value, carry_index and carry_in
const unsigned int SMEM_FIXED = CTA_SIZE * sizeof(FlagType) + sizeof(ValueType) + sizeof(IndexType) + sizeof(bool);
// number of sdata rows (CTA_SIZE + 1 values each) that fit in the remaining budget
const unsigned int BOUND_1 = (SMEM - SMEM_FIXED) / ((CTA_SIZE + 1) * sizeof(ValueType));
// one bit per element in a FlagType mask
const unsigned int BOUND_2 = 8 * sizeof(FlagType);
// fixed upper limit on the number of elements per thread
const unsigned int BOUND_3 = 6;
// TODO replace this with a static_min<BOUND_1,BOUND_2,BOUND_3>::value
// K = min(BOUND_1, BOUND_2, BOUND_3): elements handled per thread and unit
const unsigned int K = (BOUND_1 < BOUND_2) ? (BOUND_1 < BOUND_3 ? BOUND_1 : BOUND_3) : (BOUND_2 < BOUND_3 ? BOUND_2 : BOUND_3);
__shared__ detail::uninitialized<FlagType[CTA_SIZE]> sflag;
__shared__ detail::uninitialized<ValueType[K][CTA_SIZE + 1]> sdata; // padded to avoid bank conflicts
__shared__ detail::uninitialized<ValueType> carry_value; // storage for carry in and carry out
__shared__ detail::uninitialized<IndexType> carry_index;
__shared__ detail::uninitialized<bool> carry_in;
typename Decomposition::range_type interval = decomp[context.block_index()];
//thrust::system::detail::internal::index_range<IndexType> interval = decomp[context.block_index()];
// thread 0 initialises the carry state; the barrier below precedes its first use
if (context.thread_index() == 0)
{
carry_in = false; // act as though the previous segment terminated just before us
if (context.block_index() == 0)
{
carry_index = 0;
}
else
{
// output offset of this block = count stored for the preceding interval
interval_counts += (context.block_index() - 1);
carry_index = *interval_counts;
}
}
context.barrier();
IndexType base = interval.begin();
// advance input and output iterators
ikeys += base;
ivals += base;
iflags += base;
okeys += carry_index;
ovals += carry_index;
const unsigned int unit_size = K * CTA_SIZE;
// process full units
while (base + unit_size <= interval.end())
{
const unsigned int n = unit_size;
reduce_by_key_body<CTA_SIZE,K,true>(context, n, ikeys, ivals, okeys, ovals, binary_pred, binary_op, iflags, sflag.get(), sdata.get(), carry_in.get(), carry_index.get(), carry_value.get());
base += unit_size;
ikeys += unit_size;
ivals += unit_size;
iflags += unit_size;
// reduce_by_key_body left the number of outputs of this unit in carry_index
okeys += carry_index;
ovals += carry_index;
}
// process partially full unit at end of input (if necessary)
if (base < interval.end())
{
const unsigned int n = interval.end() - base;
reduce_by_key_body<CTA_SIZE,K,false>(context, n, ikeys, ivals, okeys, ovals, binary_pred, binary_op, iflags, sflag.get(), sdata.get(), carry_in.get(), carry_index.get(), carry_value.get());
}
// publish this block's final carry value and carry flag
if (context.thread_index() == 0)
{
interval_values += context.block_index();
interval_carry += context.block_index();
*interval_values = carry_value;
*interval_carry = carry_in;
}
}
}; // end reduce_by_key_closure
// Default tuning policy for reduce_by_key: supplies the flag, index, key and
// value types, the decomposition of the input into intervals, and the number
// of threads per block.
template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction>
struct DefaultPolicy
{
  // bit mask type used for per-thread flags
  typedef unsigned int FlagType;

  // index and key types follow the key iterator
  typedef typename thrust::iterator_traits<InputIterator1>::difference_type IndexType;
  typedef typename thrust::iterator_traits<InputIterator1>::value_type      KeyType;

  typedef thrust::system::detail::internal::uniform_decomposition<IndexType> Decomposition;

  // Type of the temporary that holds intermediate reduction results:
  //
  //   BinaryFunction has a nested result_type -> BinaryFunction::result_type
  //   OutputIterator2 is a "pure" output iterator -> InputIterator2::value_type
  //   anything else                             -> OutputIterator2::value_type
  //
  // XXX upon c++0x, this needs to be:
  //     result_of<BinaryFunction>::type
  typedef typename thrust::detail::eval_if<
    thrust::detail::has_result_type<BinaryFunction>::value,
    thrust::detail::result_type<BinaryFunction>,
    thrust::detail::eval_if<
      thrust::detail::is_output_iterator<OutputIterator2>::value,
      thrust::iterator_value<InputIterator2>,
      thrust::iterator_value<OutputIterator2>
    >
  >::type ValueType;

  // XXX WAR problem on sm_11
  // TODO tune this
  // 256 threads per block for POD value types, 192 for everything else
  static const unsigned int ThreadsPerBlock = thrust::detail::is_pod<ValueType>::value ? 256 : 192;

  // Builds the default decomposition for an input of last1 - first1 elements.
  DefaultPolicy(InputIterator1 first1, InputIterator1 last1)
    : decomp(default_decomposition<IndexType>(last1 - first1))
  {
  }

  // partition of the input into per-block intervals
  Decomposition decomp;
};
// Policy-driven implementation of reduce_by_key.
//
//   1. count the tail flags of every interval of policy.decomp and scan the counts
//   2. launch one block per interval; each block writes the keys and values
//      for the flags inside its interval and records a carry value / carry flag
//   3. if there is more than one interval, reduce the per-interval carries by
//      (output offset, carry flag) and combine them into values_output
//
// Returns the ends of the key and value output ranges.
template <typename DerivedPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename Policy>
  thrust::pair<OutputIterator1,OutputIterator2>
  reduce_by_key(execution_policy<DerivedPolicy> &exec,
                InputIterator1 keys_first,
                InputIterator1 keys_last,
                InputIterator2 values_first,
                OutputIterator1 keys_output,
                OutputIterator2 values_output,
                BinaryPredicate binary_pred,
                BinaryFunction binary_op,
                Policy policy)
{
  // types dictated by the policy
  typedef typename Policy::FlagType      FlagType;
  typedef typename Policy::Decomposition Decomposition;
  typedef typename Policy::IndexType     IndexType;
  typedef typename Policy::KeyType       KeyType;
  typedef typename Policy::ValueType     ValueType;

  // scratch storage types
  typedef thrust::detail::temporary_array<IndexType,DerivedPolicy> IndexArray;
  typedef thrust::detail::temporary_array<KeyType,DerivedPolicy>   KeyArray;
  typedef thrust::detail::temporary_array<ValueType,DerivedPolicy> ValueArray;
  typedef thrust::detail::temporary_array<bool,DerivedPolicy>      BoolArray;

  Decomposition decomp = policy.decomp;

  // nothing to do for an empty input
  const IndexType num_items = keys_last - keys_first;
  if (num_items == 0)
  {
    return thrust::make_pair(keys_output, values_output);
  }

  // one slot per interval
  IndexArray interval_counts(exec, decomp.size());
  ValueArray interval_values(exec, decomp.size());
  BoolArray  interval_carry(exec, decomp.size());

  // iterator producing a tail flag for every input position from
  // (position, key, following key)
  typedef thrust::counting_iterator<IndexType>                          CountingIterator;
  typedef tail_flag_functor<FlagType,IndexType,KeyType,BinaryPredicate> TailFlag;
  typedef thrust::zip_iterator<
    thrust::tuple<CountingIterator,InputIterator1,InputIterator1>
  >                                                                     ZippedKeys;
  typedef thrust::transform_iterator<TailFlag,ZippedKeys>               FlagIterator;

  FlagIterator tail_flags =
    thrust::make_transform_iterator(
      thrust::make_zip_iterator(thrust::make_tuple(CountingIterator(0), keys_first, keys_first + 1)),
      TailFlag(num_items, binary_pred));

  // tail flags per interval, then a running total over the intervals
  thrust::system::cuda::detail::reduce_intervals(exec, tail_flags, interval_counts.begin(), thrust::plus<IndexType>(), decomp);

  thrust::inclusive_scan(exec,
                         interval_counts.begin(),
                         interval_counts.end(),
                         interval_counts.begin(),
                         thrust::plus<IndexType>());

  // the last running total is the size of the output
  const IndexType num_outputs = interval_counts[interval_counts.size() - 1];

  const static unsigned int ThreadsPerBlock = Policy::ThreadsPerBlock;

  typedef typename IndexArray::iterator IndexIterator;
  typedef typename ValueArray::iterator ValueIterator;
  typedef typename BoolArray::iterator  BoolIterator;

  typedef detail::statically_blocked_thread_array<ThreadsPerBlock> Context;

  typedef reduce_by_key_closure<
    InputIterator1,InputIterator2,
    OutputIterator1,OutputIterator2,
    BinaryPredicate,BinaryFunction,
    FlagIterator,IndexIterator,ValueIterator,BoolIterator,
    Decomposition,Context
  > Closure;

  Closure closure(keys_first,
                  values_first,
                  keys_output,
                  values_output,
                  binary_pred,
                  binary_op,
                  tail_flags,
                  interval_counts.begin(),
                  interval_values.begin(),
                  interval_carry.begin(),
                  decomp);

  // one block per interval
  detail::launch_closure(closure, decomp.size(), ThreadsPerBlock);

  // stitch the per-interval carries together
  if (decomp.size() > 1)
  {
    ValueArray carry_values(exec, decomp.size());
    IndexArray carry_offsets(exec, decomp.size());
    BoolArray  carry_flags(exec, decomp.size());

    // combine the carries of intervals that share (output offset, carry flag)
    IndexType num_carries =
      thrust::reduce_by_key(exec,
                            thrust::make_zip_iterator(thrust::make_tuple(interval_counts.begin(), interval_carry.begin())),
                            thrust::make_zip_iterator(thrust::make_tuple(interval_counts.end(),   interval_carry.end())),
                            interval_values.begin(),
                            thrust::make_zip_iterator(thrust::make_tuple(carry_offsets.begin(), carry_flags.begin())),
                            carry_values.begin(),
                            thrust::equal_to< thrust::tuple<IndexType,bool> >(),
                            binary_op).first
      -
      thrust::make_zip_iterator(thrust::make_tuple(carry_offsets.begin(), carry_flags.begin()));

    // where the carry flag is set, combine the carried value with the output
    // stored at the recorded offset
    thrust::transform_if(exec,
                         carry_values.begin(),
                         carry_values.begin() + num_carries,
                         thrust::make_permutation_iterator(values_output, carry_offsets.begin()),
                         carry_flags.begin(),
                         thrust::make_permutation_iterator(values_output, carry_offsets.begin()),
                         binary_op,
                         thrust::identity<bool>());
  }

  return thrust::make_pair(keys_output + num_outputs, values_output + num_outputs);
}
} // end namespace reduce_by_key_detail
// CUDA backend entry point for reduce_by_key: forwards to the policy-driven
// implementation in reduce_by_key_detail, using a DefaultPolicy constructed
// from the key range.
template <typename DerivedPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction>
  thrust::pair<OutputIterator1,OutputIterator2>
  reduce_by_key(execution_policy<DerivedPolicy> &exec,
                InputIterator1 keys_first,
                InputIterator1 keys_last,
                InputIterator2 values_first,
                OutputIterator1 keys_output,
                OutputIterator2 values_output,
                BinaryPredicate binary_pred,
                BinaryFunction binary_op)
{
  typedef reduce_by_key_detail::DefaultPolicy<
    InputIterator1,
    InputIterator2,
    OutputIterator1,
    OutputIterator2,
    BinaryPredicate,
    BinaryFunction
  > Policy;

  return reduce_by_key_detail::reduce_by_key(exec,
                                             keys_first,
                                             keys_last,
                                             values_first,
                                             keys_output,
                                             values_output,
                                             binary_pred,
                                             binary_op,
                                             Policy(keys_first, keys_last));
} // end reduce_by_key()
} // end namespace detail
} // end namespace cuda
} // end namespace system
} // end namespace thrust
__THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_END