You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
212 lines
7.7 KiB
212 lines
7.7 KiB
/* |
|
* Copyright 2008-2012 NVIDIA Corporation |
|
* |
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
* you may not use this file except in compliance with the License. |
|
* You may obtain a copy of the License at |
|
* |
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
* |
|
* Unless required by applicable law or agreed to in writing, software |
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
* See the License for the specific language governing permissions and |
|
* limitations under the License. |
|
*/ |
|
|
|
|
|
/*! \file reduce_by_key.inl |
|
* \brief Inline file for reduce_by_key.h. |
|
*/ |
|
|
|
#pragma once |
|
|
|
#include <thrust/iterator/iterator_traits.h> |
|
#include <thrust/iterator/detail/minimum_system.h> |
|
#include <thrust/detail/type_traits.h> |
|
#include <thrust/detail/type_traits/iterator/is_output_iterator.h> |
|
#include <thrust/detail/type_traits/function_traits.h> |
|
#include <thrust/transform.h> |
|
#include <thrust/scatter.h> |
|
#include <thrust/iterator/zip_iterator.h> |
|
#include <limits> |
|
|
|
#include <thrust/detail/internal_functional.h> |
|
#include <thrust/scan.h> |
|
#include <thrust/detail/temporary_array.h> |
|
|
|
namespace thrust |
|
{ |
|
namespace system |
|
{ |
|
namespace detail |
|
{ |
|
namespace generic |
|
{ |
|
namespace detail |
|
{ |
|
|
|
template <typename ValueType, typename TailFlagType, typename AssociativeOperator> |
|
struct reduce_by_key_functor |
|
{ |
|
AssociativeOperator binary_op; |
|
|
|
typedef typename thrust::tuple<ValueType, TailFlagType> result_type; |
|
|
|
__host__ __device__ |
|
reduce_by_key_functor(AssociativeOperator _binary_op) : binary_op(_binary_op) {} |
|
|
|
__host__ __device__ |
|
result_type operator()(result_type a, result_type b) |
|
{ |
|
return result_type(thrust::get<1>(b) ? thrust::get<0>(b) : binary_op(thrust::get<0>(a), thrust::get<0>(b)), |
|
thrust::get<1>(a) | thrust::get<1>(b)); |
|
} |
|
}; |
|
|
|
} // end namespace detail |
|
|
|
|
|
template<typename ExecutionPolicy, |
|
typename InputIterator1, |
|
typename InputIterator2, |
|
typename OutputIterator1, |
|
typename OutputIterator2, |
|
typename BinaryPredicate, |
|
typename BinaryFunction> |
|
thrust::pair<OutputIterator1,OutputIterator2> |
|
reduce_by_key(thrust::execution_policy<ExecutionPolicy> &exec, |
|
InputIterator1 keys_first, |
|
InputIterator1 keys_last, |
|
InputIterator2 values_first, |
|
OutputIterator1 keys_output, |
|
OutputIterator2 values_output, |
|
BinaryPredicate binary_pred, |
|
BinaryFunction binary_op) |
|
{ |
|
typedef typename thrust::iterator_traits<InputIterator1>::difference_type difference_type; |
|
typedef typename thrust::iterator_traits<InputIterator1>::value_type KeyType; |
|
|
|
typedef unsigned int FlagType; // TODO use difference_type |
|
|
|
// the pseudocode for deducing the type of the temporary used below: |
|
// |
|
// if BinaryFunction is AdaptableBinaryFunction |
|
// TemporaryType = AdaptableBinaryFunction::result_type |
|
// else if OutputIterator2 is a "pure" output iterator |
|
// TemporaryType = InputIterator2::value_type |
|
// else |
|
// TemporaryType = OutputIterator2::value_type |
|
// |
|
// XXX upon c++0x, TemporaryType needs to be: |
|
// result_of<BinaryFunction>::type |
|
|
|
typedef typename thrust::detail::eval_if< |
|
thrust::detail::has_result_type<BinaryFunction>::value, |
|
thrust::detail::result_type<BinaryFunction>, |
|
thrust::detail::eval_if< |
|
thrust::detail::is_output_iterator<OutputIterator2>::value, |
|
thrust::iterator_value<InputIterator2>, |
|
thrust::iterator_value<OutputIterator2> |
|
> |
|
>::type ValueType; |
|
|
|
if (keys_first == keys_last) |
|
return thrust::make_pair(keys_output, values_output); |
|
|
|
// input size |
|
difference_type n = keys_last - keys_first; |
|
|
|
InputIterator2 values_last = values_first + n; |
|
|
|
// compute head flags |
|
thrust::detail::temporary_array<FlagType,ExecutionPolicy> head_flags(exec, n); |
|
thrust::transform(exec, keys_first, keys_last - 1, keys_first + 1, head_flags.begin() + 1, thrust::detail::not2(binary_pred)); |
|
head_flags[0] = 1; |
|
|
|
// compute tail flags |
|
thrust::detail::temporary_array<FlagType,ExecutionPolicy> tail_flags(exec, n); //COPY INSTEAD OF TRANSFORM |
|
thrust::transform(exec, keys_first, keys_last - 1, keys_first + 1, tail_flags.begin(), thrust::detail::not2(binary_pred)); |
|
tail_flags[n-1] = 1; |
|
|
|
// scan the values by flag |
|
thrust::detail::temporary_array<ValueType,ExecutionPolicy> scanned_values(exec, n); |
|
thrust::detail::temporary_array<FlagType,ExecutionPolicy> scanned_tail_flags(exec, n); |
|
|
|
thrust::inclusive_scan |
|
(exec, |
|
thrust::make_zip_iterator(thrust::make_tuple(values_first, head_flags.begin())), |
|
thrust::make_zip_iterator(thrust::make_tuple(values_last, head_flags.end())), |
|
thrust::make_zip_iterator(thrust::make_tuple(scanned_values.begin(), scanned_tail_flags.begin())), |
|
detail::reduce_by_key_functor<ValueType, FlagType, BinaryFunction>(binary_op)); |
|
|
|
thrust::exclusive_scan(exec, tail_flags.begin(), tail_flags.end(), scanned_tail_flags.begin(), FlagType(0), thrust::plus<FlagType>()); |
|
|
|
// number of unique keys |
|
FlagType N = scanned_tail_flags[n - 1] + 1; |
|
|
|
// scatter the keys and accumulated values |
|
thrust::scatter_if(exec, keys_first, keys_last, scanned_tail_flags.begin(), head_flags.begin(), keys_output); |
|
thrust::scatter_if(exec, scanned_values.begin(), scanned_values.end(), scanned_tail_flags.begin(), tail_flags.begin(), values_output); |
|
|
|
return thrust::make_pair(keys_output + N, values_output + N); |
|
} // end reduce_by_key() |
|
|
|
|
|
template<typename ExecutionPolicy, |
|
typename InputIterator1, |
|
typename InputIterator2, |
|
typename OutputIterator1, |
|
typename OutputIterator2> |
|
thrust::pair<OutputIterator1,OutputIterator2> |
|
reduce_by_key(thrust::execution_policy<ExecutionPolicy> &exec, |
|
InputIterator1 keys_first, |
|
InputIterator1 keys_last, |
|
InputIterator2 values_first, |
|
OutputIterator1 keys_output, |
|
OutputIterator2 values_output) |
|
{ |
|
typedef typename thrust::iterator_value<InputIterator1>::type KeyType; |
|
|
|
// use equal_to<KeyType> as default BinaryPredicate |
|
return thrust::reduce_by_key(exec, keys_first, keys_last, values_first, keys_output, values_output, thrust::equal_to<KeyType>()); |
|
} // end reduce_by_key() |
|
|
|
|
|
template<typename ExecutionPolicy, |
|
typename InputIterator1, |
|
typename InputIterator2, |
|
typename OutputIterator1, |
|
typename OutputIterator2, |
|
typename BinaryPredicate> |
|
thrust::pair<OutputIterator1,OutputIterator2> |
|
reduce_by_key(thrust::execution_policy<ExecutionPolicy> &exec, |
|
InputIterator1 keys_first, |
|
InputIterator1 keys_last, |
|
InputIterator2 values_first, |
|
OutputIterator1 keys_output, |
|
OutputIterator2 values_output, |
|
BinaryPredicate binary_pred) |
|
{ |
|
typedef typename thrust::detail::eval_if< |
|
thrust::detail::is_output_iterator<OutputIterator2>::value, |
|
thrust::iterator_value<InputIterator2>, |
|
thrust::iterator_value<OutputIterator2> |
|
>::type T; |
|
|
|
// use plus<T> as default BinaryFunction |
|
return thrust::reduce_by_key(exec, |
|
keys_first, keys_last, |
|
values_first, |
|
keys_output, |
|
values_output, |
|
binary_pred, |
|
thrust::plus<T>()); |
|
} // end reduce_by_key() |
|
|
|
|
|
} // end namespace generic |
|
} // end namespace detail |
|
} // end namespace system |
|
} // end namespace thrust |
|
|
|
|