/*
 *  Copyright 2008-2012 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */


/*! \file reduce_by_key.inl
 *  \brief Inline file for reduce_by_key.h.
 */

#pragma once

// NOTE(review): the header names below were rebuilt from the names this file
// actually uses -- confirm the list against the repository copy of this file.
#include <thrust/detail/config.h>
#include <thrust/system/detail/generic/reduce_by_key.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/detail/type_traits.h>
#include <thrust/detail/type_traits/iterator/is_output_iterator.h>
#include <thrust/detail/type_traits/function_traits.h>
#include <thrust/detail/internal_functional.h>
#include <thrust/detail/temporary_array.h>
#include <thrust/functional.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>
#include <thrust/copy.h>
#include <thrust/transform.h>
#include <thrust/scan.h>
#include <thrust/scatter.h>

namespace thrust
{
namespace system
{
namespace detail
{
namespace generic
{
namespace detail
{


/*! Binary operator over (value, flag) tuples used by the scan in
 *  reduce_by_key below.
 *
 *  For operands \c a (left) and \c b (right):
 *    - the resulting value is b's value when b's flag is non-zero, otherwise
 *      <tt>binary_op(a.value, b.value)</tt>;
 *    - the resulting flag is the bitwise OR of both flags.
 *
 *  Scanning (value, head_flag) pairs with this operator therefore restarts
 *  the accumulation at every element whose flag is set.
 *
 *  \tparam ValueType           type of the accumulated value (tuple element 0)
 *  \tparam TailFlagType        type of the flag (tuple element 1); must
 *                              support contextual conversion to bool and
 *                              <tt>operator|</tt>
 *  \tparam AssociativeOperator type of the wrapped reduction operator
 */
template<typename ValueType, typename TailFlagType, typename AssociativeOperator>
struct reduce_by_key_functor
{
  AssociativeOperator binary_op;

  typedef typename thrust::tuple<ValueType, TailFlagType> result_type;

  __host__ __device__
  reduce_by_key_functor(AssociativeOperator _binary_op)
    : binary_op(_binary_op)
  {}

  __host__ __device__
  result_type operator()(result_type a, result_type b)
  {
    return result_type(thrust::get<1>(b) ? thrust::get<0>(b)
                                         : binary_op(thrust::get<0>(a), thrust::get<0>(b)),
                       thrust::get<1>(a) | thrust::get<1>(b));
  }
};


} // end namespace detail


/*! Reduces each run of consecutive keys that compare equivalent under
 *  \p binary_pred into a single (key, value) pair, combining the run's values
 *  with \p binary_op.
 *
 *  The key written for a run is the run's first key; the value written is the
 *  scan result found at the run's last element.
 *
 *  The iterators must support the arithmetic used here:
 *  <tt>keys_last - keys_first</tt>, <tt>keys_last - 1</tt>,
 *  <tt>keys_first + 1</tt>, <tt>values_first + n</tt> and
 *  <tt>keys_output + N</tt> / <tt>values_output + N</tt>.
 *
 *  \param exec           execution policy forwarded to every algorithm call
 *                        and to the temporary allocations
 *  \param keys_first     beginning of the key range
 *  \param keys_last      end of the key range
 *  \param values_first   beginning of the value range; its length is taken to
 *                        be <tt>keys_last - keys_first</tt>
 *  \param keys_output    beginning of the output key range
 *  \param values_output  beginning of the output value range
 *  \param binary_pred    predicate deciding whether two adjacent keys belong
 *                        to the same run
 *  \param binary_op      operator used to combine the values of a run
 *  \return a pair holding the ends of the output key and value ranges; for an
 *          empty input, <tt>(keys_output, values_output)</tt>
 */
template<typename DerivedPolicy,
         typename InputIterator1,
         typename InputIterator2,
         typename OutputIterator1,
         typename OutputIterator2,
         typename BinaryPredicate,
         typename BinaryFunction>
  thrust::pair<OutputIterator1,OutputIterator2>
    reduce_by_key(thrust::execution_policy<DerivedPolicy> &exec,
                  InputIterator1 keys_first,
                  InputIterator1 keys_last,
                  InputIterator2 values_first,
                  OutputIterator1 keys_output,
                  OutputIterator2 values_output,
                  BinaryPredicate binary_pred,
                  BinaryFunction binary_op)
{
  typedef typename thrust::iterator_traits<InputIterator1>::difference_type difference_type;

  // Flags double as scatter indices after the exclusive scan below, so the
  // number of runs is accumulated in this type.
  typedef unsigned int FlagType;  // TODO use difference_type

  // the pseudocode for deducing the type of the temporary used below:
  //
  // if BinaryFunction is AdaptableBinaryFunction
  //   TemporaryType = AdaptableBinaryFunction::result_type
  // else if OutputIterator2 is a "pure" output iterator
  //   TemporaryType = InputIterator2::value_type
  // else
  //   TemporaryType = OutputIterator2::value_type
  //
  // XXX upon c++0x, TemporaryType needs to be:
  // result_of<BinaryFunction>::type
  typedef typename thrust::detail::eval_if<
    thrust::detail::has_result_type<BinaryFunction>::value,
    thrust::detail::result_type<BinaryFunction>,
    thrust::detail::eval_if<
      thrust::detail::is_output_iterator<OutputIterator2>::value,
      thrust::iterator_value<InputIterator2>,
      thrust::iterator_value<OutputIterator2>
    >
  >::type ValueType;

  // nothing to reduce; this also guarantees n >= 1 for the indexing below
  if(keys_first == keys_last)
  {
    return thrust::make_pair(keys_output, values_output);
  }

  // input size
  difference_type n = keys_last - keys_first;

  InputIterator2 values_last = values_first + n;

  // compute head flags: element i starts a run when it is the first element
  // or when binary_pred(keys[i - 1], keys[i]) is false
  thrust::detail::temporary_array<FlagType,DerivedPolicy> head_flags(exec, n);
  thrust::transform(exec,
                    keys_first, keys_last - 1,
                    keys_first + 1,
                    head_flags.begin() + 1,
                    thrust::detail::not2(binary_pred));
  head_flags[0] = 1;

  // compute tail flags: element i ends a run exactly when element i + 1
  // starts one, so the head flags shifted left by one are copied instead of
  // evaluating binary_pred a second time; the last element always ends a run
  thrust::detail::temporary_array<FlagType,DerivedPolicy> tail_flags(exec, n);
  thrust::copy(exec, head_flags.begin() + 1, head_flags.end(), tail_flags.begin());
  tail_flags[n - 1] = 1;

  // scan the values by flag
  thrust::detail::temporary_array<ValueType,DerivedPolicy> scanned_values(exec, n);
  thrust::detail::temporary_array<FlagType,DerivedPolicy>  scanned_tail_flags(exec, n);

  // the flag half of this scan's output lands in scanned_tail_flags only as
  // scratch; the exclusive_scan that follows overwrites it
  thrust::inclusive_scan
    (exec,
     thrust::make_zip_iterator(thrust::make_tuple(values_first,           head_flags.begin())),
     thrust::make_zip_iterator(thrust::make_tuple(values_last,            head_flags.end())),
     thrust::make_zip_iterator(thrust::make_tuple(scanned_values.begin(), scanned_tail_flags.begin())),
     detail::reduce_by_key_functor<ValueType, FlagType, BinaryFunction>(binary_op));

  // scanned_tail_flags[i] = number of runs that end before element i,
  // i.e. the output slot of the run containing element i
  thrust::exclusive_scan(exec,
                         tail_flags.begin(), tail_flags.end(),
                         scanned_tail_flags.begin(),
                         FlagType(0),
                         thrust::plus<FlagType>());

  // number of unique keys
  FlagType N = scanned_tail_flags[n - 1] + 1;

  // scatter the keys and accumulated values: keys are taken from run heads,
  // values from run tails
  thrust::scatter_if(exec,
                     keys_first, keys_last,
                     scanned_tail_flags.begin(),
                     head_flags.begin(),
                     keys_output);
  thrust::scatter_if(exec,
                     scanned_values.begin(), scanned_values.end(),
                     scanned_tail_flags.begin(),
                     tail_flags.begin(),
                     values_output);

  return thrust::make_pair(keys_output + N, values_output + N);
} // end reduce_by_key()


/*! Overload that compares keys with \c thrust::equal_to on the key
 *  iterator's value type and forwards to <tt>thrust::reduce_by_key</tt>.
 *
 *  \return whatever the forwarded call returns
 */
template<typename DerivedPolicy,
         typename InputIterator1,
         typename InputIterator2,
         typename OutputIterator1,
         typename OutputIterator2>
  thrust::pair<OutputIterator1,OutputIterator2>
    reduce_by_key(thrust::execution_policy<DerivedPolicy> &exec,
                  InputIterator1 keys_first,
                  InputIterator1 keys_last,
                  InputIterator2 values_first,
                  OutputIterator1 keys_output,
                  OutputIterator2 values_output)
{
  typedef typename thrust::iterator_value<InputIterator1>::type KeyType;

  // use equal_to<KeyType> as default BinaryPredicate
  return thrust::reduce_by_key(exec,
                               keys_first, keys_last,
                               values_first,
                               keys_output,
                               values_output,
                               thrust::equal_to<KeyType>());
} // end reduce_by_key()


/*! Overload that combines values with \c thrust::plus and forwards to
 *  <tt>thrust::reduce_by_key</tt>.
 *
 *  The operand type of \c plus is the value type of \p values_first when
 *  OutputIterator2 is a "pure" output iterator, and the value type of
 *  \p values_output otherwise.
 *
 *  \return whatever the forwarded call returns
 */
template<typename DerivedPolicy,
         typename InputIterator1,
         typename InputIterator2,
         typename OutputIterator1,
         typename OutputIterator2,
         typename BinaryPredicate>
  thrust::pair<OutputIterator1,OutputIterator2>
    reduce_by_key(thrust::execution_policy<DerivedPolicy> &exec,
                  InputIterator1 keys_first,
                  InputIterator1 keys_last,
                  InputIterator2 values_first,
                  OutputIterator1 keys_output,
                  OutputIterator2 values_output,
                  BinaryPredicate binary_pred)
{
  typedef typename thrust::detail::eval_if<
    thrust::detail::is_output_iterator<OutputIterator2>::value,
    thrust::iterator_value<InputIterator2>,
    thrust::iterator_value<OutputIterator2>
  >::type T;

  // use plus<T> as default BinaryFunction
  return thrust::reduce_by_key(exec,
                               keys_first, keys_last,
                               values_first,
                               keys_output,
                               values_output,
                               binary_pred,
                               thrust::plus<T>());
} // end reduce_by_key()


} // end namespace generic
} // end namespace detail
} // end namespace system
} // end namespace thrust