You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
169 lines
5.7 KiB
169 lines
5.7 KiB
11 years ago
|
/*
|
||
|
* Copyright 2008-2012 NVIDIA Corporation
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#include <thrust/iterator/iterator_traits.h>
|
||
|
#include <thrust/detail/raw_reference_cast.h>
|
||
|
#include <thrust/system/detail/generic/scalar/binary_search.h>
|
||
|
|
||
|
namespace thrust
|
||
|
{
|
||
|
namespace system
|
||
|
{
|
||
|
namespace cuda
|
||
|
{
|
||
|
namespace detail
|
||
|
{
|
||
|
namespace block
|
||
|
{
|
||
|
|
||
|
template<typename Context,
|
||
|
typename RandomAccessIterator1,
|
||
|
typename RandomAccessIterator2,
|
||
|
typename RandomAccessIterator3,
|
||
|
typename StrictWeakOrdering>
|
||
|
__device__ __thrust_forceinline__
|
||
|
RandomAccessIterator3 merge(Context context,
|
||
|
RandomAccessIterator1 first1,
|
||
|
RandomAccessIterator1 last1,
|
||
|
RandomAccessIterator2 first2,
|
||
|
RandomAccessIterator2 last2,
|
||
|
RandomAccessIterator3 result,
|
||
|
StrictWeakOrdering comp)
|
||
|
{
|
||
|
typedef typename thrust::iterator_difference<RandomAccessIterator1>::type difference1;
|
||
|
typedef typename thrust::iterator_difference<RandomAccessIterator2>::type difference2;
|
||
|
|
||
|
difference1 n1 = last1 - first1;
|
||
|
difference2 n2 = last2 - first2;
|
||
|
|
||
|
// find the rank of each element in the other array
|
||
|
difference2 rank2 = 0;
|
||
|
if(context.thread_index() < n1)
|
||
|
{
|
||
|
RandomAccessIterator1 x = first1;
|
||
|
x += context.thread_index();
|
||
|
|
||
|
// lower_bound ensures that x sorts before any equivalent element of input2
|
||
|
// this ensures stability
|
||
|
rank2 = thrust::system::detail::generic::scalar::lower_bound(first2, last2, raw_reference_cast(*x), comp) - first2;
|
||
|
} // end if
|
||
|
|
||
|
difference1 rank1 = 0;
|
||
|
if(context.thread_index() < n2)
|
||
|
{
|
||
|
RandomAccessIterator2 x = first2 + context.thread_index();
|
||
|
|
||
|
// upper_bound ensures that x sorts before any equivalent element of input1
|
||
|
// this ensures stability
|
||
|
rank1 = thrust::system::detail::generic::scalar::upper_bound(first1, last1, raw_reference_cast(*x), comp) - first1;
|
||
|
} // end if
|
||
|
|
||
|
if(context.thread_index() < n1)
|
||
|
{
|
||
|
// scatter each element from input1
|
||
|
RandomAccessIterator1 src = first1 + context.thread_index();
|
||
|
RandomAccessIterator3 dst = result + context.thread_index() + rank2;
|
||
|
|
||
|
*dst = *src;
|
||
|
}
|
||
|
|
||
|
if(context.thread_index() < n2)
|
||
|
{
|
||
|
// scatter each element from input2
|
||
|
RandomAccessIterator2 src = first2 + context.thread_index();
|
||
|
RandomAccessIterator3 dst = result + context.thread_index() + rank1;
|
||
|
|
||
|
*dst = *src;
|
||
|
}
|
||
|
|
||
|
return result + n1 + n2;
|
||
|
} // end merge
|
||
|
|
||
|
|
||
|
template<typename Context,
|
||
|
typename RandomAccessIterator1,
|
||
|
typename RandomAccessIterator2,
|
||
|
typename Size1,
|
||
|
typename Size2,
|
||
|
typename StrictWeakOrdering>
|
||
|
__device__ __thrust_forceinline__
|
||
|
void inplace_merge_by_key_n(Context context,
|
||
|
RandomAccessIterator1 keys_first,
|
||
|
RandomAccessIterator2 values_first,
|
||
|
Size1 n1,
|
||
|
Size2 n2,
|
||
|
StrictWeakOrdering comp)
|
||
|
{
|
||
|
RandomAccessIterator1 input1 = keys_first;
|
||
|
RandomAccessIterator1 input2 = keys_first + n1;
|
||
|
|
||
|
RandomAccessIterator2 input1val = values_first;
|
||
|
RandomAccessIterator2 input2val = values_first + n1;
|
||
|
|
||
|
typedef typename thrust::iterator_value<RandomAccessIterator1>::type KeyType;
|
||
|
typedef typename thrust::iterator_value<RandomAccessIterator2>::type ValueType;
|
||
|
|
||
|
// XXX use uninitialized here
|
||
|
KeyType inp1 = input1[context.thread_index()]; ValueType inp1val = input1val[context.thread_index()];
|
||
|
KeyType inp2 = input2[context.thread_index()]; ValueType inp2val = input2val[context.thread_index()];
|
||
|
|
||
|
// to merge input1 and input2, use binary search to find the rank of inp1 & inp2 in arrays input2 & input1, respectively
|
||
|
// as before, the "end" variables point to one element after the last element of the arrays
|
||
|
|
||
|
// start by looking through input2 for inp1's rank
|
||
|
unsigned int start_1 = 0;
|
||
|
|
||
|
// don't do the search if our value is beyond the end of input1
|
||
|
if(context.thread_index() < n1)
|
||
|
{
|
||
|
start_1 = thrust::system::detail::generic::scalar::lower_bound_n(input2, n2, inp1, comp) - input2;
|
||
|
} // end if
|
||
|
|
||
|
// now look through input1 for inp2's rank
|
||
|
unsigned int start_2 = 0;
|
||
|
|
||
|
// don't do the search if our value is beyond the end of input2
|
||
|
if(context.thread_index() < n2)
|
||
|
{
|
||
|
// upper_bound ensures that equivalent elements in the first range sort before the second
|
||
|
start_2 = thrust::system::detail::generic::scalar::upper_bound_n(input1, n1, inp2, comp) - input1;
|
||
|
} // end if
|
||
|
|
||
|
context.barrier();
|
||
|
|
||
|
// Write back into the right position to the input arrays; can be done in place since we read in
|
||
|
// the input arrays into registers before.
|
||
|
if(context.thread_index() < n1)
|
||
|
{
|
||
|
input1[start_1 + context.thread_index()] = inp1;
|
||
|
input1val[start_1 + context.thread_index()] = inp1val;
|
||
|
} // end if
|
||
|
|
||
|
if(context.thread_index() < n2)
|
||
|
{
|
||
|
input1[start_2 + context.thread_index()] = inp2;
|
||
|
input1val[start_2 + context.thread_index()] = inp2val;
|
||
|
} // end if
|
||
|
} // end inplace_merge_by_key_n()
|
||
|
|
||
|
|
||
|
} // end namespace block
|
||
|
} // end namespace detail
|
||
|
} // end namespace cuda
|
||
|
} // end namespace system
|
||
|
} // end namespace thrust
|
||
|
|