/*
 *  Copyright 2008-2012 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

/*! \file internal_functional.inl
 *  \brief Non-public functionals used to implement algorithm internals.
 */

#pragma once

#include <thrust/tuple.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/detail/type_traits.h>
#include <thrust/iterator/detail/tuple_of_iterator_references.h>
#include <thrust/detail/raw_reference_cast.h>
#include <memory> // for ::new

namespace thrust
{
namespace detail
{

// unary_negate does not need to know argument_type
template<typename Predicate>
struct unary_negate
{
    typedef bool result_type;

    Predicate pred;

    __host__ __device__
    explicit unary_negate(const Predicate& pred) : pred(pred) {}

    template <typename T>
    __host__ __device__
    bool operator()(const T& x)
    {
        return !bool(pred(x));
    }
};

// binary_negate does not need to know first_argument_type or second_argument_type
template<typename Predicate>
struct binary_negate
{
    typedef bool result_type;

    Predicate pred;

    __host__ __device__
    explicit binary_negate(const Predicate& pred) : pred(pred) {}

    template <typename T1, typename T2>
    __host__ __device__
    bool operator()(const T1& x, const T2& y)
    {
        return !bool(pred(x,y));
    }
};

template<typename Predicate>
__host__ __device__
thrust::detail::unary_negate<Predicate> not1(const Predicate &pred)
{
    return thrust::detail::unary_negate<Predicate>(pred);
}

template<typename Predicate>
__host__ __device__
thrust::detail::binary_negate<Predicate> not2(const Predicate &pred)
{
    return thrust::detail::binary_negate<Predicate>(pred);
}


// convert a predicate to a 0 or 1 integral value
template <typename Predicate, typename IntegralType>
struct predicate_to_integral
{
    Predicate pred;

    __host__ __device__
    explicit predicate_to_integral(const Predicate& pred) : pred(pred) {}

    template <typename T>
    __host__ __device__
    IntegralType operator()(const T& x)
    {
        return pred(x) ? IntegralType(1) : IntegralType(0);
    }
};


// note that detail::equal_to does not force conversion from T2 -> T1 as thrust::equal_to does
template <typename T1>
struct equal_to
{
    typedef bool result_type;

    template <typename T2>
    __host__ __device__
    bool operator()(const T1& lhs, const T2& rhs) const
    {
        return lhs == rhs;
    }
};

// note that equal_to_value does not force conversion from T2 -> T1 as thrust::equal_to does
template <typename T2>
struct equal_to_value
{
    T2 rhs;

    equal_to_value(const T2& rhs) : rhs(rhs) {}

    template <typename T1>
    __host__ __device__
    bool operator()(const T1& lhs) const
    {
        return lhs == rhs;
    }
};


template <typename Predicate>
struct tuple_binary_predicate
{
    typedef bool result_type;

    __host__ __device__
    tuple_binary_predicate(const Predicate& p) : pred(p) {}

    template <typename Tuple>
    __host__ __device__
    bool operator()(const Tuple& t) const
    {
        return pred(thrust::get<0>(t), thrust::get<1>(t));
    }

    Predicate pred;
};

template <typename Predicate>
struct tuple_not_binary_predicate
{
    typedef bool result_type;

    __host__ __device__
    tuple_not_binary_predicate(const Predicate& p) : pred(p) {}

    template <typename Tuple>
    __host__ __device__
    bool operator()(const Tuple& t) const
    {
        return !pred(thrust::get<0>(t), thrust::get<1>(t));
    }

    Predicate pred;
};
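// A minimal usage sketch (illustrative only, not part of the library proper):
// tuple_binary_predicate adapts an ordinary binary predicate so it can be applied
// to the 2-tuples produced by dereferencing a zip_iterator. Assuming the caller
// has thrust::equal_to available from <thrust/functional.h>, it behaves roughly
// like this:
//
//   typedef thrust::detail::tuple_binary_predicate<thrust::equal_to<int> > pred_type;
//   pred_type pred = pred_type(thrust::equal_to<int>());
//
//   bool same = pred(thrust::make_tuple(1, 1));  // pred(get<0>, get<1>) -> true
//   bool diff = pred(thrust::make_tuple(1, 2));  // -> false
//
// tuple_not_binary_predicate composes the same way but negates the result, and
// not1/not2 build the negating unary_negate/binary_negate adaptors above in the
// same spirit.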
template<typename Generator>
struct host_generate_functor
{
    typedef void result_type;

    __host__ __device__
    host_generate_functor(Generator g)
        : gen(g) {}

    // operator() does not take an lvalue reference because some iterators
    // produce temporary proxy references when dereferenced. for example,
    // consider the temporary tuple of references produced by zip_iterator.
    // such temporaries cannot bind to an lvalue reference.
    //
    // to WAR this, accept a const reference (which is bindable to a temporary),
    // and const_cast in the implementation.
    //
    // XXX change to an rvalue reference upon c++0x (which either a named variable
    //     or temporary can bind to)
    template<typename T>
    __host__
    void operator()(const T &x)
    {
        // we have to be naughty and const_cast this to get it to work
        T &lvalue = const_cast<T&>(x);

        // this assigns correctly whether x is a true reference or proxy
        lvalue = gen();
    }

    Generator gen;
};

template<typename Generator>
struct device_generate_functor
{
    typedef void result_type;

    __host__ __device__
    device_generate_functor(Generator g)
        : gen(g) {}

    // operator() does not take an lvalue reference because some iterators
    // produce temporary proxy references when dereferenced. for example,
    // consider the temporary tuple of references produced by zip_iterator.
    // such temporaries cannot bind to an lvalue reference.
    //
    // to WAR this, accept a const reference (which is bindable to a temporary),
    // and const_cast in the implementation.
    //
    // XXX change to an rvalue reference upon c++0x (which either a named variable
    //     or temporary can bind to)
    template<typename T>
    __host__ __device__
    void operator()(const T &x)
    {
        // we have to be naughty and const_cast this to get it to work
        T &lvalue = const_cast<T&>(x);

        // this assigns correctly whether x is a true reference or proxy
        lvalue = gen();
    }

    Generator gen;
};

template<typename System, typename Generator>
struct generate_functor
    : thrust::detail::eval_if<
        thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
        thrust::detail::identity_<host_generate_functor<Generator> >,
        thrust::detail::identity_<device_generate_functor<Generator> >
      >
{};


template<typename ResultType, typename BinaryFunction>
struct zipped_binary_op
{
    typedef ResultType result_type;

    __host__ __device__
    zipped_binary_op(BinaryFunction binary_op)
        : m_binary_op(binary_op) {}

    template <typename Tuple>
    __host__ __device__
    inline result_type operator()(Tuple t)
    {
        return m_binary_op(thrust::get<0>(t), thrust::get<1>(t));
    }

    BinaryFunction m_binary_op;
};


template<typename T>
struct is_non_const_reference
    : thrust::detail::and_<
        thrust::detail::not_<thrust::detail::is_const<T> >,
        thrust::detail::is_reference<T>
      >
{};

template<typename T>
struct is_tuple_of_iterator_references
    : thrust::detail::false_type
{};

template<typename T1, typename T2, typename T3, typename T4, typename T5,
         typename T6, typename T7, typename T8, typename T9, typename T10>
struct is_tuple_of_iterator_references<
    thrust::detail::tuple_of_iterator_references<
        T1,T2,T3,T4,T5,T6,T7,T8,T9,T10
    >
>
    : thrust::detail::true_type
{};

// use this enable_if to avoid assigning to temporaries in the transform functors below
// XXX revisit this problem with c++11 perfect forwarding
template<typename T>
struct enable_if_non_const_reference_or_tuple_of_iterator_references
    : thrust::detail::enable_if<
        is_non_const_reference<T>::value || is_tuple_of_iterator_references<T>::value
      >
{};


template<typename UnaryFunction>
struct host_unary_transform_functor
{
    typedef void result_type;

    UnaryFunction f;

    host_unary_transform_functor(UnaryFunction f_)
        : f(f_) {}

    template<typename Tuple>
    inline __host__
    typename enable_if_non_const_reference_or_tuple_of_iterator_references<
        typename thrust::tuple_element<1,Tuple>::type
    >::type
    operator()(Tuple t)
    {
        thrust::get<1>(t) = f(thrust::get<0>(t));
    }
};

template<typename UnaryFunction>
struct device_unary_transform_functor
{
    typedef void result_type;

    UnaryFunction f;

    device_unary_transform_functor(UnaryFunction f_)
        : f(f_) {}

    // add __host__ to allow the omp backend to compile with nvcc
    template<typename Tuple>
    inline __host__ __device__
    typename enable_if_non_const_reference_or_tuple_of_iterator_references<
        typename thrust::tuple_element<1,Tuple>::type
    >::type
    operator()(Tuple t)
    {
        thrust::get<1>(t) = f(thrust::get<0>(t));
    }
};

template<typename System, typename UnaryFunction>
struct unary_transform_functor
    : thrust::detail::eval_if<
        thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
        thrust::detail::identity_<host_unary_transform_functor<UnaryFunction> >,
        thrust::detail::identity_<device_unary_transform_functor<UnaryFunction> >
      >
{};
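// Illustrative sketch of how these transform functors are driven (an assumed usage
// pattern, not a copy of the actual algorithm implementation): a transform over a
// range can be expressed as for_each over a zip_iterator that pairs each input
// element with its output slot. Assuming random access iterators first/last/result
// and a unary functor op:
//
//   typedef thrust::tuple<InputIterator,OutputIterator> IteratorTuple;
//   typedef thrust::zip_iterator<IteratorTuple>         ZipIterator;
//
//   ZipIterator zipped_first = thrust::make_zip_iterator(thrust::make_tuple(first, result));
//   ZipIterator zipped_last  = thrust::make_zip_iterator(thrust::make_tuple(last, result + (last - first)));
//
//   thrust::for_each(zipped_first, zipped_last,
//                    device_unary_transform_functor<UnaryFunction>(op));
//
// The enable_if return type makes operator() callable only when tuple element <1>
// is a writable (non-const) reference or a tuple of iterator references, so the
// functor never silently assigns into a temporary value.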
template<typename BinaryFunction>
struct host_binary_transform_functor
{
    BinaryFunction f;

    host_binary_transform_functor(BinaryFunction f_)
        : f(f_) {}

    template<typename Tuple>
    __host__
    void operator()(Tuple t)
    {
        thrust::get<2>(t) = f(thrust::get<0>(t), thrust::get<1>(t));
    }
}; // end host_binary_transform_functor

template<typename BinaryFunction>
struct device_binary_transform_functor
{
    BinaryFunction f;

    device_binary_transform_functor(BinaryFunction f_)
        : f(f_) {}

    // add __host__ to allow the omp backend to compile with nvcc
    template<typename Tuple>
    inline __host__ __device__
    typename enable_if_non_const_reference_or_tuple_of_iterator_references<
        typename thrust::tuple_element<2,Tuple>::type
    >::type
    operator()(Tuple t)
    {
        thrust::get<2>(t) = f(thrust::get<0>(t), thrust::get<1>(t));
    }
}; // end device_binary_transform_functor

template<typename System, typename BinaryFunction>
struct binary_transform_functor
    : thrust::detail::eval_if<
        thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
        thrust::detail::identity_<host_binary_transform_functor<BinaryFunction> >,
        thrust::detail::identity_<device_binary_transform_functor<BinaryFunction> >
      >
{};


template<typename UnaryFunction, typename Predicate>
struct host_unary_transform_if_functor
{
    UnaryFunction unary_op;
    Predicate pred;

    host_unary_transform_if_functor(UnaryFunction unary_op_, Predicate pred_)
        : unary_op(unary_op_), pred(pred_) {}

    template<typename Tuple>
    inline __host__
    typename enable_if_non_const_reference_or_tuple_of_iterator_references<
        typename thrust::tuple_element<1,Tuple>::type
    >::type
    operator()(Tuple t)
    {
        if(pred(thrust::get<0>(t)))
        {
            thrust::get<1>(t) = unary_op(thrust::get<0>(t));
        }
    }
}; // end host_unary_transform_if_functor

template<typename UnaryFunction, typename Predicate>
struct device_unary_transform_if_functor
{
    UnaryFunction unary_op;
    Predicate pred;

    device_unary_transform_if_functor(UnaryFunction unary_op_, Predicate pred_)
        : unary_op(unary_op_), pred(pred_) {}

    template<typename Tuple>
    inline __host__ __device__
    typename enable_if_non_const_reference_or_tuple_of_iterator_references<
        typename thrust::tuple_element<1,Tuple>::type
    >::type
    operator()(Tuple t)
    {
        if(pred(thrust::get<0>(t)))
        {
            thrust::get<1>(t) = unary_op(thrust::get<0>(t));
        }
    }
}; // end device_unary_transform_if_functor

template<typename System, typename UnaryFunction, typename Predicate>
struct unary_transform_if_functor
    : thrust::detail::eval_if<
        thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
        thrust::detail::identity_<host_unary_transform_if_functor<UnaryFunction,Predicate> >,
        thrust::detail::identity_<device_unary_transform_if_functor<UnaryFunction,Predicate> >
      >
{};


template<typename UnaryFunction, typename Predicate>
struct host_unary_transform_if_with_stencil_functor
{
    UnaryFunction unary_op;
    Predicate pred;

    host_unary_transform_if_with_stencil_functor(UnaryFunction _unary_op, Predicate _pred)
        : unary_op(_unary_op), pred(_pred) {}

    template<typename Tuple>
    inline __host__
    typename enable_if_non_const_reference_or_tuple_of_iterator_references<
        typename thrust::tuple_element<2,Tuple>::type
    >::type
    operator()(Tuple t)
    {
        if(pred(thrust::get<1>(t)))
            thrust::get<2>(t) = unary_op(thrust::get<0>(t));
    }
}; // end host_unary_transform_if_with_stencil_functor

template<typename UnaryFunction, typename Predicate>
struct device_unary_transform_if_with_stencil_functor
{
    UnaryFunction unary_op;
    Predicate pred;

    device_unary_transform_if_with_stencil_functor(UnaryFunction _unary_op, Predicate _pred)
        : unary_op(_unary_op), pred(_pred) {}

    // add __host__ to allow the omp backend to compile with nvcc
    template<typename Tuple>
    inline __host__ __device__
    typename enable_if_non_const_reference_or_tuple_of_iterator_references<
        typename thrust::tuple_element<2,Tuple>::type
    >::type
    operator()(Tuple t)
    {
        if(pred(thrust::get<1>(t)))
            thrust::get<2>(t) = unary_op(thrust::get<0>(t));
    }
}; // end device_unary_transform_if_with_stencil_functor

template<typename System, typename UnaryFunction, typename Predicate>
struct unary_transform_if_with_stencil_functor
    : thrust::detail::eval_if<
        thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
        thrust::detail::identity_<host_unary_transform_if_with_stencil_functor<UnaryFunction,Predicate> >,
        thrust::detail::identity_<device_unary_transform_if_with_stencil_functor<UnaryFunction,Predicate> >
      >
{};
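// Note on tuple layout (illustrative, with hypothetical values): the *_if functors
// assume the zip_iterator packs their operands in argument order, followed by the
// output slot.
//
//   unary_transform_if:               (input, output)             writes <1> when pred(<0>)
//   unary_transform_if_with_stencil:  (input, stencil, output)    writes <2> when pred(<1>)
//   binary_transform_if (below):      (in1, in2, stencil, output) writes <3> when pred(<2>)
//
// For example, with thrust::negate<int> as the op and thrust::identity<bool> as the
// predicate, a stencil tuple of (5, true, out) would leave out == -5, while
// (5, false, out) would leave out untouched.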
template<typename BinaryFunction, typename Predicate>
struct host_binary_transform_if_functor
{
    BinaryFunction binary_op;
    Predicate pred;

    host_binary_transform_if_functor(BinaryFunction _binary_op, Predicate _pred)
        : binary_op(_binary_op), pred(_pred) {}

    template<typename Tuple>
    inline __host__
    typename enable_if_non_const_reference_or_tuple_of_iterator_references<
        typename thrust::tuple_element<3,Tuple>::type
    >::type
    operator()(Tuple t)
    {
        if(pred(thrust::get<2>(t)))
            thrust::get<3>(t) = binary_op(thrust::get<0>(t), thrust::get<1>(t));
    }
}; // end host_binary_transform_if_functor

template<typename BinaryFunction, typename Predicate>
struct device_binary_transform_if_functor
{
    BinaryFunction binary_op;
    Predicate pred;

    device_binary_transform_if_functor(BinaryFunction _binary_op, Predicate _pred)
        : binary_op(_binary_op), pred(_pred) {}

    // add __host__ to allow the omp backend to compile with nvcc
    template<typename Tuple>
    inline __host__ __device__
    typename enable_if_non_const_reference_or_tuple_of_iterator_references<
        typename thrust::tuple_element<3,Tuple>::type
    >::type
    operator()(Tuple t)
    {
        if(pred(thrust::get<2>(t)))
            thrust::get<3>(t) = binary_op(thrust::get<0>(t), thrust::get<1>(t));
    }
}; // end device_binary_transform_if_functor

template<typename System, typename BinaryFunction, typename Predicate>
struct binary_transform_if_functor
    : thrust::detail::eval_if<
        thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
        thrust::detail::identity_<host_binary_transform_if_functor<BinaryFunction,Predicate> >,
        thrust::detail::identity_<device_binary_transform_if_functor<BinaryFunction,Predicate> >
      >
{};


template<typename T>
struct host_destroy_functor
{
    __host__
    void operator()(T &x) const
    {
        x.~T();
    } // end operator()()
}; // end host_destroy_functor

template<typename T>
struct device_destroy_functor
{
    // add __host__ to allow the omp backend to compile with nvcc
    __host__ __device__
    void operator()(T &x) const
    {
        x.~T();
    } // end operator()()
}; // end device_destroy_functor

template<typename System, typename T>
struct destroy_functor
    : thrust::detail::eval_if<
        thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
        thrust::detail::identity_<host_destroy_functor<T> >,
        thrust::detail::identity_<device_destroy_functor<T> >
      >
{};


template <typename T>
struct fill_functor
{
    const T exemplar;

    fill_functor(const T& _exemplar)
        : exemplar(_exemplar) {}

    __host__ __device__
    T operator()(void) const
    {
        return exemplar;
    }
};


template<typename T>
struct uninitialized_fill_functor
{
    T exemplar;

    uninitialized_fill_functor(T x) : exemplar(x) {}

    __host__ __device__
    void operator()(T &x)
    {
        ::new(static_cast<void*>(&x)) T(exemplar);
    } // end operator()()
}; // end uninitialized_fill_functor


// this predicate tests two two-element tuples:
// we first use a Compare for the first elements;
// if the first elements are equivalent, we use
// operator< for the second elements
// (see the usage sketch at the end of this file)
template<typename Compare>
struct compare_first_less_second
{
    compare_first_less_second(Compare c)
        : comp(c) {}

    template<typename T1, typename T2>
    __host__ __device__
    bool operator()(T1 lhs, T2 rhs)
    {
        return comp(thrust::get<0>(lhs), thrust::get<0>(rhs))
            || (!comp(thrust::get<0>(rhs), thrust::get<0>(lhs)) && thrust::get<1>(lhs) < thrust::get<1>(rhs));
    }

    Compare comp;
}; // end compare_first_less_second


template<typename Compare>
struct compare_first
{
    Compare comp;

    compare_first(Compare comp)
        : comp(comp) {}

    template<typename Tuple1, typename Tuple2>
    __host__ __device__
    bool operator()(const Tuple1 &x, const Tuple2 &y)
    {
        return comp(thrust::raw_reference_cast(thrust::get<0>(x)),
                    thrust::raw_reference_cast(thrust::get<0>(y)));
    }
}; // end compare_first

} // end namespace detail
} // end namespace thrust
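// Usage sketch for compare_first_less_second (illustrative only; the values below are
// hypothetical): algorithms that need a stable ordering can compare (key, index)
// tuples, ordering keys with the user's Compare and breaking ties with operator< on
// the index.
//
//   thrust::detail::compare_first_less_second<thrust::less<int> >
//     comp((thrust::less<int>()));
//
//   comp(thrust::make_tuple(1, 7), thrust::make_tuple(2, 3));  // true:  1 < 2
//   comp(thrust::make_tuple(2, 3), thrust::make_tuple(2, 7));  // true:  keys tie, 3 < 7
//   comp(thrust::make_tuple(2, 7), thrust::make_tuple(2, 3));  // false: keys tie, 7 >= 3
//
// compare_first ignores the second elements entirely and simply forwards the raw
// references of the first elements to Compare.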