You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
678 lines
18 KiB
678 lines
18 KiB
/* |
|
* Copyright 2008-2012 NVIDIA Corporation |
|
* |
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
* you may not use this file except in compliance with the License. |
|
* You may obtain a copy of the License at |
|
* |
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
* |
|
* Unless required by applicable law or agreed to in writing, software |
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
* See the License for the specific language governing permissions and |
|
* limitations under the License. |
|
*/ |
|
|
|
|
|
/*! \file internal_functional.inl |
|
* \brief Non-public functionals used to implement algorithm internals. |
|
*/ |
|
|
|
#pragma once |
|
|
|
#include <thrust/tuple.h> |
|
#include <thrust/iterator/iterator_traits.h> |
|
#include <thrust/detail/type_traits.h> |
|
#include <thrust/iterator/detail/tuple_of_iterator_references.h> |
|
#include <thrust/detail/raw_reference_cast.h> |
|
#include <memory> // for ::new |
|
|
|
namespace thrust |
|
{ |
|
namespace detail |
|
{ |
|
|
|
// Adaptor producing the logical complement of a unary predicate.
// The predicate is not required to advertise an argument_type: the call
// operator is templated on the argument it receives.
template <typename Predicate>
struct unary_negate
{
  typedef bool result_type;

  // wrapped predicate, held by value
  Predicate pred;

  __host__ __device__
  explicit unary_negate(const Predicate& pred)
    : pred(pred)
  {}

  // returns true exactly when pred(x), explicitly converted to bool, is false
  template <typename T>
  __host__ __device__
  bool operator()(const T& x)
  {
    const bool holds = static_cast<bool>(pred(x));
    return !holds;
  }
};
|
|
|
// Adaptor producing the logical complement of a binary predicate.
// The predicate is not required to advertise first_argument_type or
// second_argument_type: the call operator is templated on both arguments.
template <typename Predicate>
struct binary_negate
{
  typedef bool result_type;

  // wrapped predicate, held by value
  Predicate pred;

  __host__ __device__
  explicit binary_negate(const Predicate& pred)
    : pred(pred)
  {}

  // returns true exactly when pred(x,y), explicitly converted to bool, is false
  template <typename T1, typename T2>
  __host__ __device__
  bool operator()(const T1& x, const T2& y)
  {
    const bool holds = static_cast<bool>(pred(x, y));
    return !holds;
  }
};
|
|
|
// Convenience factory: wraps a unary predicate in a
// thrust::detail::unary_negate, deducing the predicate type from the argument.
template<typename Predicate>
__host__ __device__
thrust::detail::unary_negate<Predicate> not1(const Predicate &pred)
{
  typedef thrust::detail::unary_negate<Predicate> negated_type;
  return negated_type(pred);
}
|
|
|
// Convenience factory: wraps a binary predicate in a
// thrust::detail::binary_negate, deducing the predicate type from the argument.
template<typename Predicate>
__host__ __device__
thrust::detail::binary_negate<Predicate> not2(const Predicate &pred)
{
  typedef thrust::detail::binary_negate<Predicate> negated_type;
  return negated_type(pred);
}
|
|
|
|
|
// Converts the outcome of a unary predicate into an integral 0 or 1.
//
// Predicate    - unary predicate whose result is usable as a condition
// IntegralType - type of the produced value; must be constructible from
//                the literals 0 and 1
template <typename Predicate, typename IntegralType>
struct predicate_to_integral
{
  typedef IntegralType result_type;

  // wrapped predicate, held by value
  Predicate pred;

  __host__ __device__
  explicit predicate_to_integral(const Predicate& pred) : pred(pred) {}

  // returns IntegralType(1) when pred(x) holds and IntegralType(0) otherwise.
  // the return type is IntegralType (not bool) so that the value is delivered
  // in the type the caller asked for instead of being narrowed to bool first.
  template <typename T>
  __host__ __device__
  IntegralType operator()(const T& x)
  {
    return pred(x) ? IntegralType(1) : IntegralType(0);
  }
};
|
|
|
|
|
// Equality predicate whose left operand type is fixed by the class template
// and whose right operand type is deduced per call. Because the right operand
// keeps its own type, no T2 -> T1 conversion is forced before the comparison.
template <typename T1>
struct equal_to
{
  typedef bool result_type;

  // returns the result of lhs == rhs
  template <typename T2>
  __host__ __device__
  bool operator()(const T1& lhs, const T2& rhs) const
  {
    const bool same = (lhs == rhs);
    return same;
  }
};
|
|
|
// Unary predicate comparing its argument against a stored value.
// The argument keeps its own type T1, so no T1 <-> T2 conversion is forced
// before the comparison.
template <typename T2>
struct equal_to_value
{
  // the value every argument is compared against
  T2 rhs;

  // the constructor carries the same __host__ __device__ qualification as
  // operator(), so the predicate can be built wherever it can be invoked
  __host__ __device__
  equal_to_value(const T2& rhs) : rhs(rhs) {}

  // returns the result of lhs == rhs, where rhs is the stored value
  template <typename T1>
  __host__ __device__
  bool operator()(const T1& lhs) const
  {
    return lhs == rhs;
  }
};
|
|
|
// Adapts a binary predicate so that it can be applied to a single tuple:
// elements 0 and 1 of the tuple become the predicate's two arguments.
template <typename Predicate>
struct tuple_binary_predicate
{
  typedef bool result_type;

  __host__ __device__
  tuple_binary_predicate(const Predicate& p)
    : pred(p)
  {}

  // returns pred(get<0>(t), get<1>(t))
  template<typename Tuple>
  __host__ __device__
  bool operator()(const Tuple& t) const
  {
    return pred(thrust::get<0>(t),
                thrust::get<1>(t));
  }

  // wrapped predicate, held by value
  Predicate pred;
};
|
|
|
// Adapts a binary predicate so that it can be applied to a single tuple and
// negates the outcome: elements 0 and 1 of the tuple become the predicate's
// two arguments.
template <typename Predicate>
struct tuple_not_binary_predicate
{
  typedef bool result_type;

  __host__ __device__
  tuple_not_binary_predicate(const Predicate& p)
    : pred(p)
  {}

  // returns !pred(get<0>(t), get<1>(t))
  template<typename Tuple>
  __host__ __device__
  bool operator()(const Tuple& t) const
  {
    return !pred(thrust::get<0>(t),
                 thrust::get<1>(t));
  }

  // wrapped predicate, held by value
  Predicate pred;
};
|
|
|
// Assigns the next value produced by a generator to whatever it is invoked on.
// The call operator is host-only.
template<typename Generator>
struct host_generate_functor
{
  typedef void result_type;

  __host__ __device__
  host_generate_functor(Generator g)
    : gen(g)
  {}

  // The parameter is a const reference rather than an lvalue reference:
  // dereferencing some iterators yields a temporary proxy (for instance the
  // tuple of references produced by zip_iterator), and a temporary cannot
  // bind to a non-const lvalue reference but can bind to a const one.
  // The constness is then cast away so the assignment can take place.
  //
  // XXX switch to an rvalue reference once c++0x is available (both named
  //     variables and temporaries can bind to it)
  template<typename T>
  __host__
  void operator()(const T &x)
  {
    // deliberately strip const; the assignment works whether x is a true
    // reference or a proxy
    const_cast<T&>(x) = gen();
  }

  // the generator, held by value
  Generator gen;
};
|
|
|
// Assigns the next value produced by a generator to whatever it is invoked on.
// The call operator is available in both host and device code.
template<typename Generator>
struct device_generate_functor
{
  typedef void result_type;

  __host__ __device__
  device_generate_functor(Generator g)
    : gen(g)
  {}

  // The parameter is a const reference rather than an lvalue reference:
  // dereferencing some iterators yields a temporary proxy (for instance the
  // tuple of references produced by zip_iterator), and a temporary cannot
  // bind to a non-const lvalue reference but can bind to a const one.
  // The constness is then cast away so the assignment can take place.
  //
  // XXX switch to an rvalue reference once c++0x is available (both named
  //     variables and temporaries can bind to it)
  template<typename T>
  __host__ __device__
  void operator()(const T &x)
  {
    // deliberately strip const; the assignment works whether x is a true
    // reference or a proxy
    const_cast<T&>(x) = gen();
  }

  // the generator, held by value
  Generator gen;
};
|
|
|
// Chooses the generate functor for a system: host_generate_functor when
// System is convertible to thrust::host_system_tag, device_generate_functor
// otherwise. The choice is delivered through the eval_if base
// (presumably as its nested ::type -- confirm against eval_if).
template<typename System, typename Generator>
struct generate_functor
  : thrust::detail::eval_if<
      thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
      thrust::detail::identity_<host_generate_functor<Generator> >,
      thrust::detail::identity_<device_generate_functor<Generator> >
    >
{};
|
|
|
|
|
// Adapts a binary function so that it can be applied to a single tuple:
// elements 0 and 1 of the tuple become the function's two arguments and the
// outcome is returned as ResultType.
template<typename ResultType, typename BinaryFunction>
struct zipped_binary_op
{
  typedef ResultType result_type;

  __host__ __device__
  zipped_binary_op(BinaryFunction binary_op)
    : m_binary_op(binary_op)
  {}

  // returns m_binary_op(get<0>(t), get<1>(t)); the tuple is taken by value
  template<typename Tuple>
  __host__ __device__
  inline result_type operator()(Tuple t)
  {
    return m_binary_op(thrust::get<0>(t),
                       thrust::get<1>(t));
  }

  // wrapped binary function, held by value
  BinaryFunction m_binary_op;
};
|
|
|
|
|
// Trait: true when T is a reference and is_const<T> does not hold.
//
// NOTE(review): is_const is applied to T itself, not to the referred-to type.
// If thrust::detail::is_const behaves like the standard trait, a type such as
// `const U&` would also satisfy this trait -- confirm whether that is intended.
template<typename T>
struct is_non_const_reference
  : thrust::detail::and_<
      thrust::detail::not_<thrust::detail::is_const<T> >,
      thrust::detail::is_reference<T>
    >
{};
|
|
|
// Trait: detects instantiations of thrust::detail::tuple_of_iterator_references.

// primary template: an arbitrary type is not a tuple_of_iterator_references
template<typename T>
struct is_tuple_of_iterator_references
  : thrust::detail::false_type
{};

// partial specialization: matches tuple_of_iterator_references with its
// ten type parameters
template<typename T1, typename T2, typename T3, typename T4, typename T5,
         typename T6, typename T7, typename T8, typename T9, typename T10>
struct is_tuple_of_iterator_references<
  thrust::detail::tuple_of_iterator_references<
    T1,T2,T3,T4,T5,T6,T7,T8,T9,T10
  >
>
  : thrust::detail::true_type
{};
|
|
|
// Constraint used as the return type of the transform functors' operator():
// it is satisfied when T is either a non-const reference or a
// tuple_of_iterator_references, which keeps those functors from assigning
// to temporaries.
// XXX revisit this problem with c++11 perfect forwarding
template<typename T>
struct enable_if_non_const_reference_or_tuple_of_iterator_references
  : thrust::detail::enable_if<
      is_non_const_reference<T>::value || is_tuple_of_iterator_references<T>::value
    >
{};
|
|
|
|
|
// Host-only transform step over a zipped (input, output) tuple:
// element 1 receives the result of applying f to element 0.
template<typename UnaryFunction>
struct host_unary_transform_functor
{
  typedef void result_type;

  // the transformation, held by value
  UnaryFunction f;

  host_unary_transform_functor(UnaryFunction func)
    : f(func)
  {}

  // participates only when tuple element 1 is a non-const reference or a
  // tuple of iterator references, i.e. something that may be assigned to
  template<typename Tuple>
  inline __host__
  typename enable_if_non_const_reference_or_tuple_of_iterator_references<
    typename thrust::tuple_element<1,Tuple>::type
  >::type
    operator()(Tuple t)
  {
    thrust::get<1>(t) = f(thrust::get<0>(t));
  }
}; // end host_unary_transform_functor
|
|
|
// Transform step over a zipped (input, output) tuple:
// element 1 receives the result of applying f to element 0.
template<typename UnaryFunction>
struct device_unary_transform_functor
{
  typedef void result_type;

  // the transformation, held by value
  UnaryFunction f;

  device_unary_transform_functor(UnaryFunction func)
    : f(func)
  {}

  // __host__ is included so that the omp backend can compile with nvcc.
  // participates only when tuple element 1 is a non-const reference or a
  // tuple of iterator references, i.e. something that may be assigned to
  template<typename Tuple>
  inline __host__ __device__
  typename enable_if_non_const_reference_or_tuple_of_iterator_references<
    typename thrust::tuple_element<1,Tuple>::type
  >::type
    operator()(Tuple t)
  {
    thrust::get<1>(t) = f(thrust::get<0>(t));
  }
}; // end device_unary_transform_functor
|
|
|
|
|
// Chooses the unary transform functor for a system:
// host_unary_transform_functor when System is convertible to
// thrust::host_system_tag, device_unary_transform_functor otherwise.
// The choice is delivered through the eval_if base
// (presumably as its nested ::type -- confirm against eval_if).
template<typename System, typename UnaryFunction>
struct unary_transform_functor
  : thrust::detail::eval_if<
      thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
      thrust::detail::identity_<host_unary_transform_functor<UnaryFunction> >,
      thrust::detail::identity_<device_unary_transform_functor<UnaryFunction> >
    >
{};
|
|
|
|
|
// Host-only transform step over a zipped (input1, input2, output) tuple:
// element 2 receives the result of applying f to elements 0 and 1.
template <typename BinaryFunction>
struct host_binary_transform_functor
{
  // the transformation, held by value
  BinaryFunction f;

  host_binary_transform_functor(BinaryFunction func)
    : f(func)
  {}

  // unconditionally available: the return type is plain void, with no
  // enable_if constraint on the output element
  template <typename Tuple>
  __host__
  void operator()(Tuple t)
  {
    thrust::get<2>(t) = f(thrust::get<0>(t),
                          thrust::get<1>(t));
  }
}; // end host_binary_transform_functor
|
|
|
|
|
// Transform step over a zipped (input1, input2, output) tuple:
// element 2 receives the result of applying f to elements 0 and 1.
template <typename BinaryFunction>
struct device_binary_transform_functor
{
  // the transformation, held by value
  BinaryFunction f;

  device_binary_transform_functor(BinaryFunction func)
    : f(func)
  {}

  // __host__ is included so that the omp backend can compile with nvcc.
  // participates only when tuple element 2 is a non-const reference or a
  // tuple of iterator references, i.e. something that may be assigned to
  template <typename Tuple>
  inline __host__ __device__
  typename enable_if_non_const_reference_or_tuple_of_iterator_references<
    typename thrust::tuple_element<2,Tuple>::type
  >::type
    operator()(Tuple t)
  {
    thrust::get<2>(t) = f(thrust::get<0>(t),
                          thrust::get<1>(t));
  }
}; // end device_binary_transform_functor
|
|
|
|
|
// Chooses the binary transform functor for a system:
// host_binary_transform_functor when System is convertible to
// thrust::host_system_tag, device_binary_transform_functor otherwise.
// The choice is delivered through the eval_if base
// (presumably as its nested ::type -- confirm against eval_if).
template<typename System, typename BinaryFunction>
struct binary_transform_functor
  : thrust::detail::eval_if<
      thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
      thrust::detail::identity_<host_binary_transform_functor<BinaryFunction> >,
      thrust::detail::identity_<device_binary_transform_functor<BinaryFunction> >
    >
{};
|
|
|
|
|
// Host-only conditional transform step over a zipped (input, output) tuple:
// when pred accepts element 0, element 1 receives unary_op applied to
// element 0; otherwise element 1 is left untouched.
template <typename UnaryFunction, typename Predicate>
struct host_unary_transform_if_functor
{
  UnaryFunction unary_op; // the transformation, held by value
  Predicate     pred;     // decides, from the input, whether to transform

  host_unary_transform_if_functor(UnaryFunction op, Predicate p)
    : unary_op(op), pred(p)
  {}

  // participates only when tuple element 1 is a non-const reference or a
  // tuple of iterator references, i.e. something that may be assigned to
  template<typename Tuple>
  inline __host__
  typename enable_if_non_const_reference_or_tuple_of_iterator_references<
    typename thrust::tuple_element<1,Tuple>::type
  >::type
    operator()(Tuple t)
  {
    if(pred(thrust::get<0>(t)))
    {
      thrust::get<1>(t) = unary_op(thrust::get<0>(t));
    }
  }
}; // end host_unary_transform_if_functor
|
|
|
|
|
// Conditional transform step over a zipped (input, output) tuple:
// when pred accepts element 0, element 1 receives unary_op applied to
// element 0; otherwise element 1 is left untouched.
template <typename UnaryFunction, typename Predicate>
struct device_unary_transform_if_functor
{
  UnaryFunction unary_op; // the transformation, held by value
  Predicate     pred;     // decides, from the input, whether to transform

  device_unary_transform_if_functor(UnaryFunction op, Predicate p)
    : unary_op(op), pred(p)
  {}

  // participates only when tuple element 1 is a non-const reference or a
  // tuple of iterator references, i.e. something that may be assigned to
  template<typename Tuple>
  inline __host__ __device__
  typename enable_if_non_const_reference_or_tuple_of_iterator_references<
    typename thrust::tuple_element<1,Tuple>::type
  >::type
    operator()(Tuple t)
  {
    if(pred(thrust::get<0>(t)))
    {
      thrust::get<1>(t) = unary_op(thrust::get<0>(t));
    }
  }
}; // end device_unary_transform_if_functor
|
|
|
|
|
// Chooses the conditional unary transform functor for a system:
// host_unary_transform_if_functor when System is convertible to
// thrust::host_system_tag, device_unary_transform_if_functor otherwise.
// The choice is delivered through the eval_if base
// (presumably as its nested ::type -- confirm against eval_if).
template<typename System, typename UnaryFunction, typename Predicate>
struct unary_transform_if_functor
  : thrust::detail::eval_if<
      thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
      thrust::detail::identity_<host_unary_transform_if_functor<UnaryFunction,Predicate> >,
      thrust::detail::identity_<device_unary_transform_if_functor<UnaryFunction,Predicate> >
    >
{};
|
|
|
|
|
// Host-only stencil-guarded transform step over a zipped
// (input, stencil, output) tuple: when pred accepts element 1, element 2
// receives unary_op applied to element 0; otherwise element 2 is left
// untouched.
template <typename UnaryFunction, typename Predicate>
struct host_unary_transform_if_with_stencil_functor
{
  UnaryFunction unary_op; // the transformation, held by value
  Predicate     pred;     // decides, from the stencil, whether to transform

  host_unary_transform_if_with_stencil_functor(UnaryFunction op, Predicate p)
    : unary_op(op), pred(p)
  {}

  // participates only when tuple element 2 is a non-const reference or a
  // tuple of iterator references, i.e. something that may be assigned to
  template <typename Tuple>
  inline __host__
  typename enable_if_non_const_reference_or_tuple_of_iterator_references<
    typename thrust::tuple_element<2,Tuple>::type
  >::type
    operator()(Tuple t)
  {
    if(pred(thrust::get<1>(t)))
    {
      thrust::get<2>(t) = unary_op(thrust::get<0>(t));
    }
  }
}; // end host_unary_transform_if_with_stencil_functor
|
|
|
|
|
// Stencil-guarded transform step over a zipped (input, stencil, output)
// tuple: when pred accepts element 1, element 2 receives unary_op applied to
// element 0; otherwise element 2 is left untouched.
template <typename UnaryFunction, typename Predicate>
struct device_unary_transform_if_with_stencil_functor
{
  UnaryFunction unary_op; // the transformation, held by value
  Predicate     pred;     // decides, from the stencil, whether to transform

  device_unary_transform_if_with_stencil_functor(UnaryFunction op, Predicate p)
    : unary_op(op), pred(p)
  {}

  // __host__ is included so that the omp backend can compile with nvcc.
  // participates only when tuple element 2 is a non-const reference or a
  // tuple of iterator references, i.e. something that may be assigned to
  template <typename Tuple>
  inline __host__ __device__
  typename enable_if_non_const_reference_or_tuple_of_iterator_references<
    typename thrust::tuple_element<2,Tuple>::type
  >::type
    operator()(Tuple t)
  {
    if(pred(thrust::get<1>(t)))
    {
      thrust::get<2>(t) = unary_op(thrust::get<0>(t));
    }
  }
}; // end device_unary_transform_if_with_stencil_functor
|
|
|
|
|
// Chooses the stencil-guarded unary transform functor for a system:
// host_unary_transform_if_with_stencil_functor when System is convertible to
// thrust::host_system_tag, device_unary_transform_if_with_stencil_functor
// otherwise. The choice is delivered through the eval_if base
// (presumably as its nested ::type -- confirm against eval_if).
template<typename System, typename UnaryFunction, typename Predicate>
struct unary_transform_if_with_stencil_functor
  : thrust::detail::eval_if<
      thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
      thrust::detail::identity_<host_unary_transform_if_with_stencil_functor<UnaryFunction,Predicate> >,
      thrust::detail::identity_<device_unary_transform_if_with_stencil_functor<UnaryFunction,Predicate> >
    >
{};
|
|
|
|
|
// Host-only stencil-guarded binary transform step over a zipped
// (input1, input2, stencil, output) tuple: when pred accepts element 2,
// element 3 receives binary_op applied to elements 0 and 1; otherwise
// element 3 is left untouched.
template <typename BinaryFunction, typename Predicate>
struct host_binary_transform_if_functor
{
  BinaryFunction binary_op; // the transformation, held by value
  Predicate      pred;      // decides, from the stencil, whether to transform

  host_binary_transform_if_functor(BinaryFunction op, Predicate p)
    : binary_op(op), pred(p)
  {}

  // participates only when tuple element 3 is a non-const reference or a
  // tuple of iterator references, i.e. something that may be assigned to
  template <typename Tuple>
  inline __host__
  typename enable_if_non_const_reference_or_tuple_of_iterator_references<
    typename thrust::tuple_element<3,Tuple>::type
  >::type
    operator()(Tuple t)
  {
    if(pred(thrust::get<2>(t)))
    {
      thrust::get<3>(t) = binary_op(thrust::get<0>(t),
                                    thrust::get<1>(t));
    }
  }
}; // end host_binary_transform_if_functor
|
|
|
|
|
// Stencil-guarded binary transform step over a zipped
// (input1, input2, stencil, output) tuple: when pred accepts element 2,
// element 3 receives binary_op applied to elements 0 and 1; otherwise
// element 3 is left untouched.
template <typename BinaryFunction, typename Predicate>
struct device_binary_transform_if_functor
{
  BinaryFunction binary_op; // the transformation, held by value
  Predicate      pred;      // decides, from the stencil, whether to transform

  device_binary_transform_if_functor(BinaryFunction op, Predicate p)
    : binary_op(op), pred(p)
  {}

  // __host__ is included so that the omp backend can compile with nvcc.
  // participates only when tuple element 3 is a non-const reference or a
  // tuple of iterator references, i.e. something that may be assigned to
  template <typename Tuple>
  inline __host__ __device__
  typename enable_if_non_const_reference_or_tuple_of_iterator_references<
    typename thrust::tuple_element<3,Tuple>::type
  >::type
    operator()(Tuple t)
  {
    if(pred(thrust::get<2>(t)))
    {
      thrust::get<3>(t) = binary_op(thrust::get<0>(t),
                                    thrust::get<1>(t));
    }
  }
}; // end device_binary_transform_if_functor
|
|
|
|
|
// Chooses the stencil-guarded binary transform functor for a system:
// host_binary_transform_if_functor when System is convertible to
// thrust::host_system_tag, device_binary_transform_if_functor otherwise.
// The choice is delivered through the eval_if base
// (presumably as its nested ::type -- confirm against eval_if).
template<typename System, typename BinaryFunction, typename Predicate>
struct binary_transform_if_functor
  : thrust::detail::eval_if<
      thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
      thrust::detail::identity_<host_binary_transform_if_functor<BinaryFunction,Predicate> >,
      thrust::detail::identity_<device_binary_transform_if_functor<BinaryFunction,Predicate> >
    >
{};
|
|
|
|
|
// Host-only functor that ends the lifetime of an object by invoking its
// destructor explicitly. No storage is released.
template<typename T>
struct host_destroy_functor
{
  // runs T's destructor on x
  __host__
  void operator()(T &x) const
  {
    x.~T();
  }
}; // end host_destroy_functor
|
|
|
|
|
// Functor that ends the lifetime of an object by invoking its destructor
// explicitly. No storage is released.
template<typename T>
struct device_destroy_functor
{
  // runs T's destructor on x.
  // __host__ is included so that the omp backend can compile with nvcc.
  __host__ __device__
  void operator()(T &x) const
  {
    x.~T();
  }
}; // end device_destroy_functor
|
|
|
|
|
// Chooses the destroy functor for a system: host_destroy_functor when System
// is convertible to thrust::host_system_tag, device_destroy_functor otherwise.
// The choice is delivered through the eval_if base
// (presumably as its nested ::type -- confirm against eval_if).
template<typename System, typename T>
struct destroy_functor
  : thrust::detail::eval_if<
      thrust::detail::is_convertible<System, thrust::host_system_tag>::value,
      thrust::detail::identity_<host_destroy_functor<T> >,
      thrust::detail::identity_<device_destroy_functor<T> >
    >
{};
|
|
|
|
|
// Nullary functor that returns a copy of a fixed value on every call.
template <typename T>
struct fill_functor
{
  // the value handed out by operator().
  // deliberately not declared const: a const data member would suppress the
  // implicit copy-assignment operator and make the functor non-assignable.
  // operator() is const, so invoking the functor never alters it.
  T exemplar;

  fill_functor(const T& _exemplar)
    : exemplar(_exemplar) {}

  // returns a copy of the stored value
  __host__ __device__
  T operator()(void) const
  {
    return exemplar;
  }
};
|
|
|
|
|
// Functor that copy-constructs a fixed value into the storage occupied by
// its argument, using placement new.
template<typename T>
struct uninitialized_fill_functor
{
  // the value every target is constructed from
  T exemplar;

  uninitialized_fill_functor(T value)
    : exemplar(value)
  {}

  // constructs a T at the address of x as a copy of exemplar;
  // nothing is destroyed or assigned beforehand
  __host__ __device__
  void operator()(T &x)
  {
    void *where = static_cast<void*>(&x);
    ::new(where) T(exemplar);
  }
}; // end uninitialized_fill_functor
|
|
|
|
|
// Ordering predicate over two two-element tuples.
// The first elements are ordered with the user-supplied Compare; when neither
// first element precedes the other under Compare, the second elements are
// ordered with operator<.
template<typename Compare>
struct compare_first_less_second
{
  compare_first_less_second(Compare c)
    : comp(c)
  {}

  // returns true when lhs precedes rhs; both tuples are taken by value
  template<typename T1, typename T2>
  __host__ __device__
  bool operator()(T1 lhs, T2 rhs)
  {
    // the first elements alone settle it when lhs comes strictly first
    if(comp(thrust::get<0>(lhs), thrust::get<0>(rhs)))
    {
      return true;
    }

    // otherwise lhs precedes rhs only when the first elements are equivalent
    // and the second elements are in ascending order
    return !comp(thrust::get<0>(rhs), thrust::get<0>(lhs))
        && thrust::get<1>(lhs) < thrust::get<1>(rhs);
  }

  // ordering applied to the first elements, held by value
  Compare comp;
}; // end compare_first_less_second
|
|
|
|
|
// Ordering predicate over two tuples that looks only at their first
// elements, delegating the comparison to the user-supplied Compare.
template<typename Compare>
struct compare_first
{
  // ordering applied to the first elements, held by value
  Compare comp;

  compare_first(Compare comp)
    : comp(comp)
  {}

  // returns comp applied to element 0 of x and element 0 of y; each element
  // is passed through thrust::raw_reference_cast before being handed to comp
  template<typename Tuple1, typename Tuple2>
  __host__ __device__
  bool operator()(const Tuple1 &x, const Tuple2 &y)
  {
    return comp(
      thrust::raw_reference_cast(thrust::get<0>(x)),
      thrust::raw_reference_cast(thrust::get<0>(y))
    );
  }
}; // end compare_first
|
|
|
|
|
} // end namespace detail |
|
} // end namespace thrust |
|
|
|
|