/*
 *  Copyright 2008-2012 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#include <thrust/detail/config.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>
#include <thrust/functional.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
#include <thrust/reduce.h>
#include <thrust/detail/minmax.h>
#include <thrust/detail/temporary_array.h>
#include <thrust/detail/type_traits.h>
#include <thrust/detail/type_traits/function_traits.h>
#include <thrust/detail/type_traits/iterator/is_output_iterator.h>
#include <thrust/detail/type_traits/iterator/is_discard_iterator.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/system/detail/internal/decompose.h>
#include <thrust/system/cuda/detail/execution_policy.h>
#include <thrust/system/cuda/detail/reduce_intervals.h>
#include <thrust/system/cuda/detail/default_decomposition.h>
#include <thrust/system/cuda/detail/block/inclusive_scan.h>
#include <thrust/system/cuda/detail/detail/launch_closure.h>
#include <thrust/system/cuda/detail/detail/uninitialized.h>

__THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_BEGIN

namespace thrust
{
namespace system
{
namespace cuda
{
namespace detail
{
namespace reduce_by_key_detail
{

template <typename FlagType,
          typename IndexType,
          typename KeyType,
          typename BinaryPredicate>
struct tail_flag_functor
{
  BinaryPredicate binary_pred; // NB: this must be the first member for performance reasons
  IndexType n;

  typedef FlagType result_type;

  tail_flag_functor(IndexType n, BinaryPredicate binary_pred)
    : n(n), binary_pred(binary_pred)
  {}

  // XXX why is this noticeably faster? (it may read past the end of input)
  //FlagType operator()(const thrust::tuple<IndexType,KeyType,KeyType>& t) const
  template <typename Tuple>
  __host__ __device__ __thrust_forceinline__
  FlagType operator()(const Tuple& t)
  {
    if (thrust::get<0>(t) == (n - 1) || !binary_pred(thrust::get<1>(t), thrust::get<2>(t)))
      return 1;
    else
      return 0;
  }
};

template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename FlagIterator,
          typename FlagType>
__device__ __thrust_forceinline__
FlagType load_flags(Context context,
                    const unsigned int n,
                    FlagIterator iflags,
                    FlagType (&sflag)[CTA_SIZE])
{
  FlagType flag_bits = 0;

  // load flags in unordered fashion
  for(unsigned int k = 0; k < K; k++)
  {
    const unsigned int offset = k*CTA_SIZE + context.thread_index();

    if (FullBlock || offset < n)
    {
      FlagIterator temp = iflags + offset;
      if (*temp)
        flag_bits |= FlagType(1) << k;
    }
  }

  sflag[context.thread_index()] = flag_bits;

  context.barrier();

  flag_bits = 0;

  // obtain flags for iflags[K * context.thread_index(), K * context.thread_index() + K)
  for(unsigned int k = 0; k < K; k++)
  {
    const unsigned int offset = K * context.thread_index() + k;

    if (FullBlock || offset < n)
    {
      flag_bits |= ((sflag[offset % CTA_SIZE] >> (offset / CTA_SIZE)) & FlagType(1)) << k;
    }
  }

  context.barrier();

  sflag[context.thread_index()] = flag_bits;

  context.barrier();

  return flag_bits;
}

template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename InputIterator2,
          typename ValueType>
__device__ __thrust_forceinline__
void load_values(Context context,
                 const unsigned int n,
                 InputIterator2 ivals,
                 ValueType (&sdata)[K][CTA_SIZE + 1])
{
  for(unsigned int k = 0; k < K; k++)
  {
    const unsigned int offset = k*CTA_SIZE + context.thread_index();

    if (FullBlock || offset < n)
    {
      InputIterator2 temp = ivals + offset;
      sdata[offset % K][offset / K] = *temp;
    }
  }

  context.barrier();
}
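// Explanatory note (added): load_flags reads the tail flags in CTA-strided
// (coalesced) order, so after the first pass thread t holds, at bit k, the
// flag of element k*CTA_SIZE + t. The shared-memory exchange then transposes
// this layout so each thread owns the flags of its K contiguous elements
// [K*t, K*t + K). For intuition, with CTA_SIZE = 2 and K = 2 over flags
// F0..F3: the first pass gives thread 0 bits (F0,F2) and thread 1 bits
// (F1,F3); after the transpose thread 0 holds (F0,F1) and thread 1 (F2,F3).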
template <unsigned int CTA_SIZE,
          unsigned int K,
          bool FullBlock,
          typename Context,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename FlagIterator,
          typename IndexType,
          typename FlagType,
          typename ValueType>
__device__ __thrust_forceinline__
void reduce_by_key_body(Context context,
                        const unsigned int n,
                        InputIterator1  ikeys,
                        InputIterator2  ivals,
                        OutputIterator1 okeys,
                        OutputIterator2 ovals,
                        BinaryPredicate binary_pred,
                        BinaryFunction  binary_op,
                        FlagIterator    iflags,
                        FlagType  (&sflag)[CTA_SIZE],
                        ValueType (&sdata)[K][CTA_SIZE + 1],
                        bool&      carry_in,
                        IndexType& carry_index,
                        ValueType& carry_value)
{
  // load flags
  const FlagType flag_bits  = load_flags<CTA_SIZE,K,FullBlock>(context, n, iflags, sflag);
  const FlagType flag_count = __popc(flag_bits); // TODO hide this behind a template
  const FlagType left_flag  = (context.thread_index() == 0) ? 0 : sflag[context.thread_index() - 1];
  const FlagType head_flag  = (context.thread_index() == 0 || flag_bits & ((1 << (K - 1)) - 1) || left_flag & (1 << (K - 1))) ? 1 : 0;

  context.barrier();

  // scan flag counts
  sflag[context.thread_index()] = flag_count;
  context.barrier();

  block::inclusive_scan(context, sflag, thrust::plus<FlagType>());

  const FlagType output_position = (context.thread_index() == 0) ? 0 : sflag[context.thread_index() - 1];
  const FlagType num_outputs     = sflag[CTA_SIZE - 1];

  context.barrier();

  // shuffle keys and write keys out
  if (!thrust::detail::is_discard_iterator<OutputIterator1>::value)
  {
    // XXX this could be improved
    for (unsigned int i = 0; i < num_outputs; i += CTA_SIZE)
    {
      FlagType position = output_position;

      for(unsigned int k = 0; k < K; k++)
      {
        if (flag_bits & (FlagType(1) << k))
        {
          if (i <= position && position < i + CTA_SIZE)
            sflag[position - i] = K * context.thread_index() + k;
          position++;
        }
      }

      context.barrier();

      if (i + context.thread_index() < num_outputs)
      {
        InputIterator1  tmp1 = ikeys + sflag[context.thread_index()];
        OutputIterator1 tmp2 = okeys + (i + context.thread_index());
        *tmp2 = *tmp1;
      }

      context.barrier();
    }
  }

  // load values
  load_values<CTA_SIZE,K,FullBlock>(context, n, ivals, sdata);

  ValueType ldata[K];
  for (unsigned int k = 0; k < K; k++)
    ldata[k] = sdata[k][context.thread_index()];

  // carry in (if necessary)
  if (context.thread_index() == 0 && carry_in)
  {
    // XXX WAR sm_10 issue
    ValueType tmp1 = carry_value;
    ldata[0] = binary_op(tmp1, ldata[0]);
  }

  context.barrier();

  // sum local values
  {
    for(unsigned int k = 1; k < K; k++)
    {
      const unsigned int offset = K * context.thread_index() + k;

      if (FullBlock || offset < n)
      {
        if (!(flag_bits & (FlagType(1) << (k - 1))))
          ldata[k] = binary_op(ldata[k - 1], ldata[k]);
      }
    }
  }

  // second level segmented scan
  {
    // use head flags for segmented scan
    sflag[context.thread_index()]        = head_flag;
    sdata[K - 1][context.thread_index()] = ldata[K - 1];
    context.barrier();

    if (FullBlock)
      block::inclusive_scan_by_flag(context, sflag, sdata[K-1], binary_op);
    else
      block::inclusive_scan_by_flag_n(context, sflag, sdata[K-1], n, binary_op);
  }

  // update local values
  if (context.thread_index() > 0)
  {
    unsigned int update_bits  = (flag_bits << 1) | (left_flag >> (K - 1));

// TODO remove guard
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
    unsigned int update_count = __ffs(update_bits) - 1u; // NB: this might wrap around to UINT_MAX
#else
    unsigned int update_count = 0;
#endif // THRUST_DEVICE_COMPILER_NVCC

    if (!FullBlock && (K + 1) * context.thread_index() > n)
      update_count = thrust::min(n - K * context.thread_index(), update_count);

    ValueType left = sdata[K - 1][context.thread_index() - 1];

    for(unsigned int k = 0; k < K; k++)
    {
      if (k < update_count)
        ldata[k] = binary_op(left, ldata[k]);
    }
  }

  context.barrier();
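  // Explanatory note (added): the "carry" shuttles the partial reduction of
  // the last segment in this unit to the next unit. carry_in is true when the
  // last element processed here has no tail flag, i.e. its segment continues
  // into the next unit; carry_value holds that segment's running reduction,
  // and carry_index (set to num_outputs) tells the caller how many outputs
  // this unit produced, so the output iterators can be advanced accordingly.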
  // store carry out
  if (FullBlock)
  {
    if (context.thread_index() == CTA_SIZE - 1)
    {
      carry_value = ldata[K - 1];
      carry_in    = (flag_bits & (FlagType(1) << (K - 1))) ? false : true;
      carry_index = num_outputs;
    }
  }
  else
  {
    if (context.thread_index() == (n - 1) / K)
    {
      for (unsigned int k = 0; k < K; k++)
        if (k == (n - 1) % K)
          carry_value = ldata[k];
      carry_in    = (flag_bits & (FlagType(1) << ((n - 1) % K))) ? false : true;
      carry_index = num_outputs;
    }
  }

  // shuffle values
  {
    FlagType position = output_position;

    for(unsigned int k = 0; k < K; k++)
    {
      const unsigned int offset = K * context.thread_index() + k;

      if (FullBlock || offset < n)
      {
        if (flag_bits & (FlagType(1) << k))
        {
          sdata[position / CTA_SIZE][position % CTA_SIZE] = ldata[k];
          position++;
        }
      }
    }
  }

  context.barrier();

  // write values out
  for(unsigned int k = 0; k < K; k++)
  {
    const unsigned int offset = CTA_SIZE * k + context.thread_index();

    if (offset < num_outputs)
    {
      OutputIterator2 tmp = ovals + offset;
      *tmp = sdata[k][context.thread_index()];
    }
  }

  context.barrier();
}

template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename FlagIterator,
          typename IndexIterator,
          typename ValueIterator,
          typename BoolIterator,
          typename Decomposition,
          typename Context>
struct reduce_by_key_closure
{
  InputIterator1  ikeys;
  InputIterator2  ivals;
  OutputIterator1 okeys;
  OutputIterator2 ovals;
  BinaryPredicate binary_pred;
  BinaryFunction  binary_op;
  FlagIterator    iflags;
  IndexIterator   interval_counts;
  ValueIterator   interval_values;
  BoolIterator    interval_carry;
  Decomposition   decomp;
  Context         context;

  typedef Context context_type;

  reduce_by_key_closure(InputIterator1  ikeys,
                        InputIterator2  ivals,
                        OutputIterator1 okeys,
                        OutputIterator2 ovals,
                        BinaryPredicate binary_pred,
                        BinaryFunction  binary_op,
                        FlagIterator    iflags,
                        IndexIterator   interval_counts,
                        ValueIterator   interval_values,
                        BoolIterator    interval_carry,
                        Decomposition   decomp,
                        Context         context = Context())
    : ikeys(ikeys), ivals(ivals), okeys(okeys), ovals(ovals),
      binary_pred(binary_pred), binary_op(binary_op), iflags(iflags),
      interval_counts(interval_counts), interval_values(interval_values),
      interval_carry(interval_carry), decomp(decomp), context(context)
  {}

  __device__ __thrust_forceinline__
  void operator()(void)
  {
    typedef typename thrust::iterator_value<InputIterator1>::type KeyType;
    typedef typename thrust::iterator_value<ValueIterator>::type  ValueType;
    typedef typename Decomposition::index_type                    IndexType;
    typedef typename thrust::iterator_value<FlagIterator>::type   FlagType;

    const unsigned int CTA_SIZE = context_type::ThreadsPerBlock::value;

// TODO centralize this mapping (__CUDA_ARCH__ -> smem bytes)
#if __CUDA_ARCH__ >= 200
    const unsigned int SMEM = (48 * 1024);
#else
    const unsigned int SMEM = (16 * 1024) - 256;
#endif
    const unsigned int SMEM_FIXED = CTA_SIZE * sizeof(FlagType) + sizeof(ValueType) + sizeof(IndexType) + sizeof(bool);
    const unsigned int BOUND_1 = (SMEM - SMEM_FIXED) / ((CTA_SIZE + 1) * sizeof(ValueType));
    const unsigned int BOUND_2 = 8 * sizeof(FlagType);
    const unsigned int BOUND_3 = 6;

    // TODO replace this with a static_min<BOUND_1,BOUND_2,BOUND_3>::value
    const unsigned int K = (BOUND_1 < BOUND_2) ? (BOUND_1 < BOUND_3 ? BOUND_1 : BOUND_3) : (BOUND_2 < BOUND_3 ? BOUND_2 : BOUND_3);
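    // Worked example (added; the numbers are illustrative, not normative):
    // on sm_20 with CTA_SIZE = 256, FlagType = unsigned int, ValueType = float
    // and an 8-byte IndexType, SMEM_FIXED is roughly 1 KB, so
    // BOUND_1 ~ (49152 - 1037) / (257 * 4) ~ 46, BOUND_2 = 32, BOUND_3 = 6.
    // The ternary above computes min(min(BOUND_1, BOUND_2), BOUND_3),
    // so K = 6 values are processed per thread.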
    __shared__ detail::uninitialized<FlagType[CTA_SIZE]>         sflag;
    __shared__ detail::uninitialized<ValueType[K][CTA_SIZE + 1]> sdata; // padded to avoid bank conflicts

    __shared__ detail::uninitialized<ValueType> carry_value; // storage for carry in and carry out
    __shared__ detail::uninitialized<IndexType> carry_index;
    __shared__ detail::uninitialized<bool>      carry_in;

    typename Decomposition::range_type interval = decomp[context.block_index()];
    //thrust::system::detail::internal::index_range<IndexType> interval = decomp[context.block_index()];

    if (context.thread_index() == 0)
    {
      carry_in = false; // act as though the previous segment terminated just before us

      if (context.block_index() == 0)
      {
        carry_index = 0;
      }
      else
      {
        interval_counts += (context.block_index() - 1);
        carry_index = *interval_counts;
      }
    }

    context.barrier();

    IndexType base = interval.begin();

    // advance input and output iterators
    ikeys  += base;
    ivals  += base;
    iflags += base;
    okeys  += carry_index;
    ovals  += carry_index;

    const unsigned int unit_size = K * CTA_SIZE;

    // process full units
    while (base + unit_size <= interval.end())
    {
      const unsigned int n = unit_size;

      reduce_by_key_body<CTA_SIZE,K,true>(context, n, ikeys, ivals, okeys, ovals, binary_pred, binary_op, iflags, sflag.get(), sdata.get(), carry_in.get(), carry_index.get(), carry_value.get());

      base   += unit_size;
      ikeys  += unit_size;
      ivals  += unit_size;
      iflags += unit_size;
      okeys  += carry_index;
      ovals  += carry_index;
    }

    // process partially full unit at end of input (if necessary)
    if (base < interval.end())
    {
      const unsigned int n = interval.end() - base;

      reduce_by_key_body<CTA_SIZE,K,false>(context, n, ikeys, ivals, okeys, ovals, binary_pred, binary_op, iflags, sflag.get(), sdata.get(), carry_in.get(), carry_index.get(), carry_value.get());
    }

    if (context.thread_index() == 0)
    {
      interval_values += context.block_index();
      interval_carry  += context.block_index();
      *interval_values = carry_value;
      *interval_carry  = carry_in;
    }
  }
}; // end reduce_by_key_closure

template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator2,
          typename BinaryFunction>
struct DefaultPolicy
{
  // typedefs
  typedef unsigned int                                                        FlagType;
  typedef typename thrust::iterator_traits<InputIterator1>::difference_type  IndexType;
  typedef typename thrust::iterator_traits<InputIterator1>::value_type       KeyType;
  typedef thrust::system::detail::internal::uniform_decomposition<IndexType> Decomposition;

  // the pseudocode for deducing the type of the temporary used below:
  //
  // if BinaryFunction is AdaptableBinaryFunction
  //   TemporaryType = AdaptableBinaryFunction::result_type
  // else if OutputIterator2 is a "pure" output iterator
  //   TemporaryType = InputIterator2::value_type
  // else
  //   TemporaryType = OutputIterator2::value_type
  //
  // XXX upon c++0x, TemporaryType needs to be:
  // result_of<BinaryFunction>::type
  typedef typename thrust::detail::eval_if<
    thrust::detail::has_result_type<BinaryFunction>::value,
    thrust::detail::result_type<BinaryFunction>,
    thrust::detail::eval_if<
      thrust::detail::is_output_iterator<OutputIterator2>::value,
      thrust::iterator_value<InputIterator2>,
      thrust::iterator_value<OutputIterator2>
    >
  >::type ValueType;
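  // Illustration (added): with BinaryFunction = thrust::plus<float>, the
  // functor exposes result_type, so ValueType = float. For a functor without
  // result_type, ValueType falls back to the value_type of the values input
  // when the output is a "pure" output iterator (e.g. a discard_iterator),
  // and to the output iterator's value_type otherwise.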
  // XXX WAR problem on sm_11
  // TODO tune this
  const static unsigned int ThreadsPerBlock = (thrust::detail::is_pod<ValueType>::value) ? 256 : 192;

  DefaultPolicy(InputIterator1 first1, InputIterator1 last1)
    : decomp(default_decomposition<IndexType>(last1 - first1))
  {}

  // member variables
  Decomposition decomp;
};

template <typename DerivedPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction,
          typename Policy>
thrust::pair<OutputIterator1,OutputIterator2>
reduce_by_key(execution_policy<DerivedPolicy> &exec,
              InputIterator1 keys_first,
              InputIterator1 keys_last,
              InputIterator2 values_first,
              OutputIterator1 keys_output,
              OutputIterator2 values_output,
              BinaryPredicate binary_pred,
              BinaryFunction binary_op,
              Policy policy)
{
  typedef typename Policy::FlagType      FlagType;
  typedef typename Policy::Decomposition Decomposition;
  typedef typename Policy::IndexType     IndexType;
  typedef typename Policy::KeyType       KeyType;
  typedef typename Policy::ValueType     ValueType;

  // temporary arrays
  typedef thrust::detail::temporary_array<IndexType,DerivedPolicy> IndexArray;
  typedef thrust::detail::temporary_array<KeyType,DerivedPolicy>   KeyArray;
  typedef thrust::detail::temporary_array<ValueType,DerivedPolicy> ValueArray;
  typedef thrust::detail::temporary_array<bool,DerivedPolicy>      BoolArray;

  Decomposition decomp = policy.decomp;

  // input size
  IndexType n = keys_last - keys_first;

  if (n == 0)
    return thrust::make_pair(keys_output, values_output);

  IndexArray interval_counts(exec, decomp.size());
  ValueArray interval_values(exec, decomp.size());
  BoolArray  interval_carry(exec, decomp.size());

  // an ode to c++11 auto
  typedef thrust::counting_iterator<IndexType> CountingIterator;
  typedef thrust::transform_iterator<
    tail_flag_functor<FlagType,IndexType,KeyType,BinaryPredicate>,
    thrust::zip_iterator<
      thrust::tuple<CountingIterator,InputIterator1,InputIterator1>
    >
  > FlagIterator;

  FlagIterator iflag = thrust::make_transform_iterator
    (thrust::make_zip_iterator(thrust::make_tuple(CountingIterator(0), keys_first, keys_first + 1)),
     tail_flag_functor<FlagType,IndexType,KeyType,BinaryPredicate>(n, binary_pred));

  // count number of tail flags per interval
  thrust::system::cuda::detail::reduce_intervals(exec, iflag, interval_counts.begin(), thrust::plus<IndexType>(), decomp);

  thrust::inclusive_scan(exec,
                         interval_counts.begin(), interval_counts.end(),
                         interval_counts.begin(),
                         thrust::plus<IndexType>());

  // determine output size
  const IndexType N = interval_counts[interval_counts.size() - 1];

  const static unsigned int ThreadsPerBlock = Policy::ThreadsPerBlock;
  typedef typename IndexArray::iterator IndexIterator;
  typedef typename ValueArray::iterator ValueIterator;
  typedef typename BoolArray::iterator  BoolIterator;
  typedef detail::statically_blocked_thread_array<ThreadsPerBlock> Context;
  typedef reduce_by_key_closure<InputIterator1,InputIterator2,OutputIterator1,OutputIterator2,
                                BinaryPredicate,BinaryFunction,FlagIterator,
                                IndexIterator,ValueIterator,BoolIterator,
                                Decomposition,Context> Closure;

  Closure closure
    (keys_first, values_first,
     keys_output, values_output,
     binary_pred, binary_op,
     iflag,
     interval_counts.begin(),
     interval_values.begin(),
     interval_carry.begin(),
     decomp);

  detail::launch_closure(closure, decomp.size(), ThreadsPerBlock);
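  // Explanatory note (added): each interval recorded its unfinished last
  // segment in interval_values/interval_carry. The pass below combines
  // carries from consecutive intervals that spill into the same output slot
  // (intervals compare equal when their (cumulative count, carry flag) pairs
  // match), then transform_if folds each surviving carry into values_output
  // at the position recorded in interval_counts2, but only where the carry
  // flag is set.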
  if (decomp.size() > 1)
  {
    ValueArray interval_values2(exec, decomp.size());
    IndexArray interval_counts2(exec, decomp.size());
    BoolArray  interval_carry2(exec, decomp.size());

    IndexType N2 =
      thrust::reduce_by_key
        (exec,
         thrust::make_zip_iterator(thrust::make_tuple(interval_counts.begin(), interval_carry.begin())),
         thrust::make_zip_iterator(thrust::make_tuple(interval_counts.end(),   interval_carry.end())),
         interval_values.begin(),
         thrust::make_zip_iterator(thrust::make_tuple(interval_counts2.begin(), interval_carry2.begin())),
         interval_values2.begin(),
         thrust::equal_to< thrust::tuple<IndexType,bool> >(),
         binary_op).first
      -
      thrust::make_zip_iterator(thrust::make_tuple(interval_counts2.begin(), interval_carry2.begin()));

    thrust::transform_if
      (exec,
       interval_values2.begin(), interval_values2.begin() + N2,
       thrust::make_permutation_iterator(values_output, interval_counts2.begin()),
       interval_carry2.begin(),
       thrust::make_permutation_iterator(values_output, interval_counts2.begin()),
       binary_op,
       thrust::identity<bool>());
  }

  return thrust::make_pair(keys_output + N, values_output + N);
}

} // end namespace reduce_by_key_detail


template <typename DerivedPolicy,
          typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator1,
          typename OutputIterator2,
          typename BinaryPredicate,
          typename BinaryFunction>
thrust::pair<OutputIterator1,OutputIterator2>
reduce_by_key(execution_policy<DerivedPolicy> &exec,
              InputIterator1 keys_first,
              InputIterator1 keys_last,
              InputIterator2 values_first,
              OutputIterator1 keys_output,
              OutputIterator2 values_output,
              BinaryPredicate binary_pred,
              BinaryFunction binary_op)
{
  return reduce_by_key_detail::reduce_by_key
    (exec,
     keys_first, keys_last,
     values_first,
     keys_output, values_output,
     binary_pred, binary_op,
     reduce_by_key_detail::DefaultPolicy<InputIterator1,InputIterator2,OutputIterator2,BinaryFunction>(keys_first, keys_last));
} // end reduce_by_key()

} // end namespace detail
} // end namespace cuda
} // end namespace system
} // end namespace thrust

__THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING_END
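// Usage sketch (added comment, not part of the original source): this backend
// realizes the usual thrust::reduce_by_key semantics. For example, with
// keys {1,1,2,2,2,3}, values {1,2,3,4,5,6}, binary_pred = thrust::equal_to<int>
// and binary_op = thrust::plus<int>, the result is keys_output {1,2,3} and
// values_output {3,12,6}.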