You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
213 lines
7.7 KiB
213 lines
7.7 KiB
11 years ago
|
/*
 *  Copyright 2008-2012 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
|
||
|
|
||
|
|
||
|
/*! \file reduce_by_key.inl
 *  \brief Inline file for reduce_by_key.h.
 */
|
||
|
|
||
|
#pragma once
|
||
|
|
||
|
#include <thrust/iterator/iterator_traits.h>
|
||
|
#include <thrust/iterator/detail/minimum_system.h>
|
||
|
#include <thrust/detail/type_traits.h>
|
||
|
#include <thrust/detail/type_traits/iterator/is_output_iterator.h>
|
||
|
#include <thrust/detail/type_traits/function_traits.h>
|
||
|
#include <thrust/transform.h>
|
||
|
#include <thrust/scatter.h>
|
||
|
#include <thrust/iterator/zip_iterator.h>
|
||
|
#include <limits>
|
||
|
|
||
|
#include <thrust/detail/internal_functional.h>
|
||
|
#include <thrust/scan.h>
|
||
|
#include <thrust/detail/temporary_array.h>
|
||
|
|
||
|
namespace thrust
|
||
|
{
|
||
|
namespace system
|
||
|
{
|
||
|
namespace detail
|
||
|
{
|
||
|
namespace generic
|
||
|
{
|
||
|
namespace detail
|
||
|
{
|
||
|
|
||
|
template <typename ValueType, typename TailFlagType, typename AssociativeOperator>
|
||
|
struct reduce_by_key_functor
|
||
|
{
|
||
|
AssociativeOperator binary_op;
|
||
|
|
||
|
typedef typename thrust::tuple<ValueType, TailFlagType> result_type;
|
||
|
|
||
|
__host__ __device__
|
||
|
reduce_by_key_functor(AssociativeOperator _binary_op) : binary_op(_binary_op) {}
|
||
|
|
||
|
__host__ __device__
|
||
|
result_type operator()(result_type a, result_type b)
|
||
|
{
|
||
|
return result_type(thrust::get<1>(b) ? thrust::get<0>(b) : binary_op(thrust::get<0>(a), thrust::get<0>(b)),
|
||
|
thrust::get<1>(a) | thrust::get<1>(b));
|
||
|
}
|
||
|
};
|
||
|
|
||
|
} // end namespace detail
|
||
|
|
||
|
|
||
|
template<typename ExecutionPolicy,
|
||
|
typename InputIterator1,
|
||
|
typename InputIterator2,
|
||
|
typename OutputIterator1,
|
||
|
typename OutputIterator2,
|
||
|
typename BinaryPredicate,
|
||
|
typename BinaryFunction>
|
||
|
thrust::pair<OutputIterator1,OutputIterator2>
|
||
|
reduce_by_key(thrust::execution_policy<ExecutionPolicy> &exec,
|
||
|
InputIterator1 keys_first,
|
||
|
InputIterator1 keys_last,
|
||
|
InputIterator2 values_first,
|
||
|
OutputIterator1 keys_output,
|
||
|
OutputIterator2 values_output,
|
||
|
BinaryPredicate binary_pred,
|
||
|
BinaryFunction binary_op)
|
||
|
{
|
||
|
typedef typename thrust::iterator_traits<InputIterator1>::difference_type difference_type;
|
||
|
typedef typename thrust::iterator_traits<InputIterator1>::value_type KeyType;
|
||
|
|
||
|
typedef unsigned int FlagType; // TODO use difference_type
|
||
|
|
||
|
// the pseudocode for deducing the type of the temporary used below:
|
||
|
//
|
||
|
// if BinaryFunction is AdaptableBinaryFunction
|
||
|
// TemporaryType = AdaptableBinaryFunction::result_type
|
||
|
// else if OutputIterator2 is a "pure" output iterator
|
||
|
// TemporaryType = InputIterator2::value_type
|
||
|
// else
|
||
|
// TemporaryType = OutputIterator2::value_type
|
||
|
//
|
||
|
// XXX upon c++0x, TemporaryType needs to be:
|
||
|
// result_of<BinaryFunction>::type
|
||
|
|
||
|
typedef typename thrust::detail::eval_if<
|
||
|
thrust::detail::has_result_type<BinaryFunction>::value,
|
||
|
thrust::detail::result_type<BinaryFunction>,
|
||
|
thrust::detail::eval_if<
|
||
|
thrust::detail::is_output_iterator<OutputIterator2>::value,
|
||
|
thrust::iterator_value<InputIterator2>,
|
||
|
thrust::iterator_value<OutputIterator2>
|
||
|
>
|
||
|
>::type ValueType;
|
||
|
|
||
|
if (keys_first == keys_last)
|
||
|
return thrust::make_pair(keys_output, values_output);
|
||
|
|
||
|
// input size
|
||
|
difference_type n = keys_last - keys_first;
|
||
|
|
||
|
InputIterator2 values_last = values_first + n;
|
||
|
|
||
|
// compute head flags
|
||
|
thrust::detail::temporary_array<FlagType,ExecutionPolicy> head_flags(exec, n);
|
||
|
thrust::transform(exec, keys_first, keys_last - 1, keys_first + 1, head_flags.begin() + 1, thrust::detail::not2(binary_pred));
|
||
|
head_flags[0] = 1;
|
||
|
|
||
|
// compute tail flags
|
||
|
thrust::detail::temporary_array<FlagType,ExecutionPolicy> tail_flags(exec, n); //COPY INSTEAD OF TRANSFORM
|
||
|
thrust::transform(exec, keys_first, keys_last - 1, keys_first + 1, tail_flags.begin(), thrust::detail::not2(binary_pred));
|
||
|
tail_flags[n-1] = 1;
|
||
|
|
||
|
// scan the values by flag
|
||
|
thrust::detail::temporary_array<ValueType,ExecutionPolicy> scanned_values(exec, n);
|
||
|
thrust::detail::temporary_array<FlagType,ExecutionPolicy> scanned_tail_flags(exec, n);
|
||
|
|
||
|
thrust::inclusive_scan
|
||
|
(exec,
|
||
|
thrust::make_zip_iterator(thrust::make_tuple(values_first, head_flags.begin())),
|
||
|
thrust::make_zip_iterator(thrust::make_tuple(values_last, head_flags.end())),
|
||
|
thrust::make_zip_iterator(thrust::make_tuple(scanned_values.begin(), scanned_tail_flags.begin())),
|
||
|
detail::reduce_by_key_functor<ValueType, FlagType, BinaryFunction>(binary_op));
|
||
|
|
||
|
thrust::exclusive_scan(exec, tail_flags.begin(), tail_flags.end(), scanned_tail_flags.begin(), FlagType(0), thrust::plus<FlagType>());
|
||
|
|
||
|
// number of unique keys
|
||
|
FlagType N = scanned_tail_flags[n - 1] + 1;
|
||
|
|
||
|
// scatter the keys and accumulated values
|
||
|
thrust::scatter_if(exec, keys_first, keys_last, scanned_tail_flags.begin(), head_flags.begin(), keys_output);
|
||
|
thrust::scatter_if(exec, scanned_values.begin(), scanned_values.end(), scanned_tail_flags.begin(), tail_flags.begin(), values_output);
|
||
|
|
||
|
return thrust::make_pair(keys_output + N, values_output + N);
|
||
|
} // end reduce_by_key()
|
||
|
|
||
|
|
||
|
template<typename ExecutionPolicy,
|
||
|
typename InputIterator1,
|
||
|
typename InputIterator2,
|
||
|
typename OutputIterator1,
|
||
|
typename OutputIterator2>
|
||
|
thrust::pair<OutputIterator1,OutputIterator2>
|
||
|
reduce_by_key(thrust::execution_policy<ExecutionPolicy> &exec,
|
||
|
InputIterator1 keys_first,
|
||
|
InputIterator1 keys_last,
|
||
|
InputIterator2 values_first,
|
||
|
OutputIterator1 keys_output,
|
||
|
OutputIterator2 values_output)
|
||
|
{
|
||
|
typedef typename thrust::iterator_value<InputIterator1>::type KeyType;
|
||
|
|
||
|
// use equal_to<KeyType> as default BinaryPredicate
|
||
|
return thrust::reduce_by_key(exec, keys_first, keys_last, values_first, keys_output, values_output, thrust::equal_to<KeyType>());
|
||
|
} // end reduce_by_key()
|
||
|
|
||
|
|
||
|
template<typename ExecutionPolicy,
|
||
|
typename InputIterator1,
|
||
|
typename InputIterator2,
|
||
|
typename OutputIterator1,
|
||
|
typename OutputIterator2,
|
||
|
typename BinaryPredicate>
|
||
|
thrust::pair<OutputIterator1,OutputIterator2>
|
||
|
reduce_by_key(thrust::execution_policy<ExecutionPolicy> &exec,
|
||
|
InputIterator1 keys_first,
|
||
|
InputIterator1 keys_last,
|
||
|
InputIterator2 values_first,
|
||
|
OutputIterator1 keys_output,
|
||
|
OutputIterator2 values_output,
|
||
|
BinaryPredicate binary_pred)
|
||
|
{
|
||
|
typedef typename thrust::detail::eval_if<
|
||
|
thrust::detail::is_output_iterator<OutputIterator2>::value,
|
||
|
thrust::iterator_value<InputIterator2>,
|
||
|
thrust::iterator_value<OutputIterator2>
|
||
|
>::type T;
|
||
|
|
||
|
// use plus<T> as default BinaryFunction
|
||
|
return thrust::reduce_by_key(exec,
|
||
|
keys_first, keys_last,
|
||
|
values_first,
|
||
|
keys_output,
|
||
|
values_output,
|
||
|
binary_pred,
|
||
|
thrust::plus<T>());
|
||
|
} // end reduce_by_key()
|
||
|
|
||
|
|
||
|
} // end namespace generic
|
||
|
} // end namespace detail
|
||
|
} // end namespace system
|
||
|
} // end namespace thrust
|
||
|
|