2014-03-18 22:17:40 +01:00

288 lines
11 KiB
C++

/*
* Copyright 2008-2012 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file sort.inl
* \brief Inline file for sort.h
*/
#include <thrust/system/cuda/detail/detail/stable_merge_sort.h>
#include <thrust/system/cuda/detail/detail/stable_primitive_sort.h>
#include <thrust/reverse.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/detail/type_traits.h>
#include <thrust/system/cuda/detail/execution_policy.h>
#include <thrust/system/cuda/detail/temporary_indirect_permutation.h>
#include <thrust/detail/trivial_sequence.h>
/*
* This file implements the following dispatch procedure for cuda::stable_sort()
* and cuda::stable_sort_by_key(). The first level inspects the KeyType
* and StrictWeakOrdering to determine whether a sort assuming primitive-typed
* data may be applied.
*
* If a sort assuming primitive-typed data can be applied (i.e., a radix sort),
* the input ranges are first trivialized (turned into simple contiguous ranges
* if they are not already). To implement descending orderings, an ascending
* sort will be reversed.
*
* If a sort assuming primitive-typed data cannot be applied, a comparison-based
* sort is used. Depending on the size of the key and value types, one level of
* indirection may be applied to their input ranges. This transformation
* may be applied to either range to convert an ill-suited problem (i.e. sorting with
* large keys or large values) into a problem more amenable to the underlying
* merge sort algorithm.
*/
namespace thrust
{
namespace system
{
namespace cuda
{
namespace detail
{
namespace stable_sort_detail
{
// Compile-time predicate deciding whether the primitive (radix) sort path may
// be used. It is true only when BOTH conditions hold:
//   1. KeyType is arithmetic according to thrust::detail::is_arithmetic, and
//   2. StrictWeakCompare is exactly thrust::less<KeyType> or
//      thrust::greater<KeyType>.
// The comparator test is a type identity check (is_same), so any other functor
// type -- even one that orders keys the same way -- yields false.
template<typename KeyType, typename StrictWeakCompare>
struct can_use_primitive_sort
: thrust::detail::and_<
thrust::detail::is_arithmetic<KeyType>,
thrust::detail::or_<
thrust::detail::is_same<StrictWeakCompare,thrust::less<KeyType> >,
thrust::detail::is_same<StrictWeakCompare,thrust::greater<KeyType> >
>
>
{};
// SFINAE helper used as the return type of the primitive-sort overloads.
// The nested ::type exists only when can_use_primitive_sort is true for the
// iterator's value type and StrictWeakCompare; otherwise the overload that
// names it is removed from overload resolution.
// NOTE(review): ::type is presumably void (single-argument enable_if) --
// confirm against thrust/detail/type_traits.h.
template<typename RandomAccessIterator, typename StrictWeakCompare>
struct enable_if_primitive_sort
: thrust::detail::enable_if<
can_use_primitive_sort<
typename iterator_value<RandomAccessIterator>::type,
StrictWeakCompare
>::value
>
{};
// SFINAE helper used as the return type of the comparison-sort overloads.
// It is the exact complement of enable_if_primitive_sort: the nested ::type
// exists only when can_use_primitive_sort is false for the iterator's value
// type and StrictWeakCompare, so exactly one of the two overload families is
// viable for any given instantiation.
// NOTE(review): ::type is presumably void (single-argument disable_if) --
// confirm against thrust/detail/type_traits.h.
template<typename RandomAccessIterator, typename StrictWeakCompare>
struct enable_if_comparison_sort
: thrust::detail::disable_if<
can_use_primitive_sort<
typename iterator_value<RandomAccessIterator>::type,
StrictWeakCompare
>::value
>
{};
// Primitive-key overload of stable_sort.
//
// Participates in overload resolution only when the key type is arithmetic and
// StrictWeakOrdering is thrust::less or thrust::greater of that key type (see
// enable_if_primitive_sort).
//
// exec   execution policy forwarded to every algorithm invoked here
// first  beginning of the range, sorted in place
// last   end of the range, sorted in place
// comp   not read at run time; only its type selects ascending vs. descending
template<typename DerivedPolicy,
         typename RandomAccessIterator,
         typename StrictWeakOrdering>
  typename enable_if_primitive_sort<RandomAccessIterator,StrictWeakOrdering>::type
    stable_sort(execution_policy<DerivedPolicy> &exec,
                RandomAccessIterator first,
                RandomAccessIterator last,
                StrictWeakOrdering comp)
{
  // ensure sequence has trivial iterators
  thrust::detail::trivial_sequence<RandomAccessIterator,DerivedPolicy> keys(exec, first, last);

  // CUDA path for thrust::stable_sort with primitive keys
  // (e.g. int, float, short, etc.) and a less<T> or greater<T> comparison
  // method is implemented with a primitive sort
  thrust::system::cuda::detail::detail::stable_primitive_sort(exec, keys.begin(), keys.end());

  // copy results back, if necessary
  if(!thrust::detail::is_trivial_iterator<RandomAccessIterator>::value)
  {
    thrust::copy(exec, keys.begin(), keys.end(), first);
  }

  // if comp is greater<T> then reverse the keys
  typedef typename thrust::iterator_traits<RandomAccessIterator>::value_type KeyType;
  const static bool reverse = thrust::detail::is_same<StrictWeakOrdering, typename thrust::greater<KeyType> >::value;

  if(reverse)
  {
    // pass exec explicitly so the reversal runs under the caller's policy,
    // like every other algorithm call in this file, instead of being
    // re-dispatched from the iterator's system tag
    thrust::reverse(exec, first, last);
  }
}
// Comparison-based overload of stable_sort.
//
// Participates in overload resolution whenever the primitive path does not
// apply (see enable_if_comparison_sort). The range [first, last) is sorted
// with a stable merge sort using comp.
template<typename DerivedPolicy,
         typename RandomAccessIterator,
         typename StrictWeakOrdering>
  typename enable_if_comparison_sort<RandomAccessIterator,StrictWeakOrdering>::type
    stable_sort(execution_policy<DerivedPolicy> &exec,
                RandomAccessIterator first,
                RandomAccessIterator last,
                StrictWeakOrdering comp)
{
  typedef typename thrust::iterator_value<RandomAccessIterator>::type key_type;

  // keys wider than 8 bytes are sorted through one level of indirection
  typedef thrust::detail::integral_constant<bool, (sizeof(key_type) > 8)> indirect_keys;

  typedef conditional_temporary_indirect_ordering<
    indirect_keys,
    DerivedPolicy,
    RandomAccessIterator,
    StrictWeakOrdering
  > key_ordering;

  // wraps the input range and comparator, applying indirection only if selected
  key_ordering ordering(derived_cast(exec), first, last, comp);

  thrust::system::cuda::detail::detail::stable_merge_sort(exec,
                                                          ordering.begin(),
                                                          ordering.end(),
                                                          ordering.comp());
}
// Primitive-key overload of stable_sort_by_key.
//
// Participates in overload resolution only when the key type is arithmetic and
// StrictWeakOrdering is thrust::less or thrust::greater of that key type (see
// enable_if_primitive_sort). Keys in [keys_first, keys_last) and the equally
// long value range starting at values_first are reordered in place.
//
// comp is not read at run time; only its type selects ascending vs. descending.
template<typename DerivedPolicy,
         typename RandomAccessIterator1,
         typename RandomAccessIterator2,
         typename StrictWeakOrdering>
  typename enable_if_primitive_sort<RandomAccessIterator1,StrictWeakOrdering>::type
    stable_sort_by_key(execution_policy<DerivedPolicy> &exec,
                       RandomAccessIterator1 keys_first,
                       RandomAccessIterator1 keys_last,
                       RandomAccessIterator2 values_first,
                       StrictWeakOrdering comp)
{
  typedef typename thrust::iterator_traits<RandomAccessIterator1>::value_type key_type;

  // a greater<T> comparator requests descending order
  static const bool descending = thrust::detail::is_same<StrictWeakOrdering, typename thrust::greater<key_type> >::value;

  // the value range has exactly as many elements as the key range
  const RandomAccessIterator2 values_last = values_first + (keys_last - keys_first);

  // Descending order is produced by reversing an ascending sort. To keep the
  // result stable, the unsorted input is reversed first as well.
  if(descending)
  {
    thrust::reverse(exec, keys_first, keys_last);
    thrust::reverse(exec, values_first, values_last);
  }

  // obtain trivial-iterator views of both ranges for the primitive sort
  thrust::detail::trivial_sequence<RandomAccessIterator1,DerivedPolicy> trivial_keys(exec, keys_first, keys_last);
  thrust::detail::trivial_sequence<RandomAccessIterator2,DerivedPolicy> trivial_values(exec, values_first, values_last);

  thrust::system::cuda::detail::detail::stable_primitive_sort_by_key(exec,
                                                                     trivial_keys.begin(),
                                                                     trivial_keys.end(),
                                                                     trivial_values.begin());

  // write sorted keys back when the caller's key iterator was not trivial
  if(!thrust::detail::is_trivial_iterator<RandomAccessIterator1>::value)
  {
    thrust::copy(exec, trivial_keys.begin(), trivial_keys.end(), keys_first);
  }

  // likewise for the values
  if(!thrust::detail::is_trivial_iterator<RandomAccessIterator2>::value)
  {
    thrust::copy(exec, trivial_values.begin(), trivial_values.end(), values_first);
  }

  // turn the ascending result into the requested descending one
  if(descending)
  {
    thrust::reverse(exec, keys_first, keys_last);
    thrust::reverse(exec, values_first, values_last);
  }
}
// Comparison-based overload of stable_sort_by_key.
//
// Participates in overload resolution whenever the primitive path does not
// apply (see enable_if_comparison_sort). Keys in [keys_first, keys_last) and
// the equally long value range starting at values_first are sorted with a
// stable merge sort using comp.
template<typename DerivedPolicy,
         typename RandomAccessIterator1,
         typename RandomAccessIterator2,
         typename StrictWeakOrdering>
  typename enable_if_comparison_sort<RandomAccessIterator1,StrictWeakOrdering>::type
    stable_sort_by_key(execution_policy<DerivedPolicy> &exec,
                       RandomAccessIterator1 keys_first,
                       RandomAccessIterator1 keys_last,
                       RandomAccessIterator2 values_first,
                       StrictWeakOrdering comp)
{
  typedef typename thrust::iterator_value<RandomAccessIterator1>::type key_type;
  typedef typename thrust::iterator_value<RandomAccessIterator2>::type value_type;

  // keys wider than 8 bytes and values wider than 4 bytes are each handled
  // through one level of indirection
  typedef thrust::detail::integral_constant<bool, (sizeof(key_type) > 8)>   indirect_keys;
  typedef thrust::detail::integral_constant<bool, (sizeof(value_type) > 4)> indirect_values;

  typedef conditional_temporary_indirect_ordering<
    indirect_keys,
    DerivedPolicy,
    RandomAccessIterator1,
    StrictWeakOrdering
  > key_ordering;

  typedef conditional_temporary_indirect_permutation<
    indirect_values,
    DerivedPolicy,
    RandomAccessIterator2
  > value_permutation;

  // the key wrapper is constructed before the value wrapper, as in the
  // original code, so their construction/destruction order is unchanged
  key_ordering sorted_keys(derived_cast(exec), keys_first, keys_last, comp);
  value_permutation sorted_values(derived_cast(exec), values_first, values_first + (keys_last - keys_first));

  thrust::system::cuda::detail::detail::stable_merge_sort_by_key(exec,
                                                                 sorted_keys.begin(),
                                                                 sorted_keys.end(),
                                                                 sorted_values.begin(),
                                                                 sorted_keys.comp());
}
} // end namespace stable_sort_detail
// CUDA backend entry point for stable_sort.
//
// Statically checks that the translation unit is being compiled by nvcc, then
// forwards to the stable_sort_detail overload selected for the iterator's
// value type and StrictWeakOrdering.
template<typename DerivedPolicy,
         typename RandomAccessIterator,
         typename StrictWeakOrdering>
  void stable_sort(execution_policy<DerivedPolicy> &exec,
                   RandomAccessIterator first,
                   RandomAccessIterator last,
                   StrictWeakOrdering comp)
{
  // The condition is routed through a type that depends on
  // RandomAccessIterator, presumably so that the assertion is only evaluated
  // once this template is actually instantiated.
  typedef thrust::detail::depend_on_instantiation<
    RandomAccessIterator,
    THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
  > compiled_with_nvcc;

  // we're attempting to launch a kernel, assert we're compiling with nvcc
  // ========================================================================
  // X Note to the user: If you've found this line due to a compiler error, X
  // X you need to compile your code using nvcc, rather than g++ or cl.exe  X
  // ========================================================================
  THRUST_STATIC_ASSERT( (compiled_with_nvcc::value) );

  stable_sort_detail::stable_sort(exec, first, last, comp);
}
// CUDA backend entry point for stable_sort_by_key.
//
// Statically checks that the translation unit is being compiled by nvcc, then
// forwards to the stable_sort_detail overload selected for the key iterator's
// value type and StrictWeakOrdering.
template<typename DerivedPolicy,
         typename RandomAccessIterator1,
         typename RandomAccessIterator2,
         typename StrictWeakOrdering>
  void stable_sort_by_key(execution_policy<DerivedPolicy> &exec,
                          RandomAccessIterator1 keys_first,
                          RandomAccessIterator1 keys_last,
                          RandomAccessIterator2 values_first,
                          StrictWeakOrdering comp)
{
  // The condition is routed through a type that depends on
  // RandomAccessIterator1, presumably so that the assertion is only evaluated
  // once this template is actually instantiated.
  typedef thrust::detail::depend_on_instantiation<
    RandomAccessIterator1,
    THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
  > compiled_with_nvcc;

  // we're attempting to launch a kernel, assert we're compiling with nvcc
  // ========================================================================
  // X Note to the user: If you've found this line due to a compiler error, X
  // X you need to compile your code using nvcc, rather than g++ or cl.exe  X
  // ========================================================================
  THRUST_STATIC_ASSERT( (compiled_with_nvcc::value) );

  stable_sort_detail::stable_sort_by_key(exec, keys_first, keys_last, values_first, comp);
}
} // end namespace detail
} // end namespace cuda
} // end namespace system
} // end namespace thrust