/* * Copyright 2008-2012 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /*! \file distance.h * \brief Device implementations for distance. */ #pragma once #include #include #include #include #include #include #include #include #include namespace thrust { namespace system { namespace detail { namespace generic { namespace detail { ////////////// // Functors // ////////////// // return the smaller/larger element making sure to prefer the // first occurance of the minimum/maximum element template struct min_element_reduction { BinaryPredicate comp; __host__ __device__ min_element_reduction(BinaryPredicate comp) : comp(comp){} __host__ __device__ thrust::tuple operator()(const thrust::tuple& lhs, const thrust::tuple& rhs ) { if(comp(thrust::get<0>(lhs), thrust::get<0>(rhs))) return lhs; if(comp(thrust::get<0>(rhs), thrust::get<0>(lhs))) return rhs; // values are equivalent, prefer value with smaller index if(thrust::get<1>(lhs) < thrust::get<1>(rhs)) return lhs; else return rhs; } // end operator()() }; // end min_element_reduction template struct max_element_reduction { BinaryPredicate comp; __host__ __device__ max_element_reduction(BinaryPredicate comp) : comp(comp){} __host__ __device__ thrust::tuple operator()(const thrust::tuple& lhs, const thrust::tuple& rhs ) { if(comp(thrust::get<0>(lhs), thrust::get<0>(rhs))) return rhs; if(comp(thrust::get<0>(rhs), thrust::get<0>(lhs))) return lhs; // values are equivalent, prefer value with smaller index if(thrust::get<1>(lhs) < thrust::get<1>(rhs)) return lhs; else return rhs; } // end operator()() }; // end max_element_reduction // return the smaller & larger element making sure to prefer the // first occurance of the minimum/maximum element template struct minmax_element_reduction { BinaryPredicate comp; minmax_element_reduction(BinaryPredicate comp) : comp(comp){} __host__ __device__ thrust::tuple< thrust::tuple, thrust::tuple > operator()(const thrust::tuple< thrust::tuple, thrust::tuple >& lhs, const thrust::tuple< thrust::tuple, thrust::tuple >& rhs ) { return thrust::make_tuple(min_element_reduction(comp)(thrust::get<0>(lhs), thrust::get<0>(rhs)), max_element_reduction(comp)(thrust::get<1>(lhs), thrust::get<1>(rhs))); } // end operator()() }; // end minmax_element_reduction template struct duplicate_tuple { __host__ __device__ thrust::tuple< thrust::tuple, thrust::tuple > operator()(const thrust::tuple& t) { return thrust::make_tuple(t, t); } }; // end duplicate_tuple } // end namespace detail template ForwardIterator min_element(thrust::execution_policy &exec, ForwardIterator first, ForwardIterator last) { typedef typename thrust::iterator_value::type value_type; return thrust::min_element(exec, first, last, thrust::less()); } // end min_element() template ForwardIterator min_element(thrust::execution_policy &exec, ForwardIterator first, ForwardIterator last, BinaryPredicate comp) { if (first == last) return last; typedef typename thrust::iterator_traits::value_type InputType; typedef typename thrust::iterator_traits::difference_type IndexType; thrust::tuple result = thrust::reduce (exec, thrust::make_zip_iterator(thrust::make_tuple(first, thrust::counting_iterator(0))), thrust::make_zip_iterator(thrust::make_tuple(first, thrust::counting_iterator(0))) + (last - first), thrust::tuple(*first, 0), detail::min_element_reduction(comp)); return first + thrust::get<1>(result); } // end min_element() template ForwardIterator max_element(thrust::execution_policy &exec, ForwardIterator first, ForwardIterator last) { typedef typename thrust::iterator_value::type value_type; return thrust::max_element(exec, first, last, thrust::less()); } // end max_element() template ForwardIterator max_element(thrust::execution_policy &exec, ForwardIterator first, ForwardIterator last, BinaryPredicate comp) { if (first == last) return last; typedef typename thrust::iterator_traits::value_type InputType; typedef typename thrust::iterator_traits::difference_type IndexType; thrust::tuple result = thrust::reduce (exec, thrust::make_zip_iterator(thrust::make_tuple(first, thrust::counting_iterator(0))), thrust::make_zip_iterator(thrust::make_tuple(first, thrust::counting_iterator(0))) + (last - first), thrust::tuple(*first, 0), detail::max_element_reduction(comp)); return first + thrust::get<1>(result); } // end max_element() template thrust::pair minmax_element(thrust::execution_policy &exec, ForwardIterator first, ForwardIterator last) { typedef typename thrust::iterator_value::type value_type; return thrust::minmax_element(exec, first, last, thrust::less()); } // end minmax_element() template thrust::pair minmax_element(thrust::execution_policy &exec, ForwardIterator first, ForwardIterator last, BinaryPredicate comp) { if (first == last) return thrust::make_pair(last, last); typedef typename thrust::iterator_traits::value_type InputType; typedef typename thrust::iterator_traits::difference_type IndexType; thrust::tuple< thrust::tuple, thrust::tuple > result = thrust::transform_reduce (exec, thrust::make_zip_iterator(thrust::make_tuple(first, thrust::counting_iterator(0))), thrust::make_zip_iterator(thrust::make_tuple(first, thrust::counting_iterator(0))) + (last - first), detail::duplicate_tuple(), detail::duplicate_tuple()(thrust::tuple(*first, 0)), detail::minmax_element_reduction(comp)); return thrust::make_pair(first + thrust::get<1>(thrust::get<0>(result)), first + thrust::get<1>(thrust::get<1>(result))); } // end minmax_element() } // end namespace generic } // end namespace detail } // end namespace system } // end namespace thrust