/* * Copyright 2008-2012 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Copyright Jens Maurer 2000-2001 * Distributed under the Boost Software License, Version 1.0. (See * accompanying file LICENSE_1_0.txt or copy at * http://www.boost.org/LICENSE_1_0.txt) */ #pragma once #include #include #include #include #include namespace thrust { namespace random { namespace detail { // this version samples the normal distribution directly // and uses the non-standard math function erfcinv template class normal_distribution_nvcc { protected: template __host__ __device__ RealType sample(UniformRandomNumberGenerator &urng, const RealType mean, const RealType stddev) { typedef typename UniformRandomNumberGenerator::result_type uint_type; const uint_type urng_range = UniformRandomNumberGenerator::max - UniformRandomNumberGenerator::min; // Constants for conversion const RealType S1 = static_cast(1) / urng_range; const RealType S2 = S1 / 2; RealType S3 = static_cast(-1.4142135623730950488016887242097); // -sqrt(2) // Get the integer value uint_type u = urng() - UniformRandomNumberGenerator::min; // Ensure the conversion to float will give a value in the range [0,0.5) if(u > (urng_range / 2)) { u = urng_range - u; S3 = -S3; } // Convert to floating point in [0,0.5) RealType p = u*S1 + S2; // Apply inverse error function return mean + stddev * S3 * erfcinv(2 * p); } // no-op __host__ __device__ void reset() {} }; // this version samples the normal distribution using // Marsaglia's "polar method" template class normal_distribution_portable { protected: normal_distribution_portable() : m_valid(false) {} normal_distribution_portable(const normal_distribution_portable &other) : m_valid(other.m_valid) {} void reset() { m_valid = false; } // note that we promise to call this member function with the same mean and stddev template __host__ __device__ RealType sample(UniformRandomNumberGenerator &urng, const RealType mean, const RealType stddev) { // implementation from Boost // allow for Koenig lookup using std::sqrt; using std::log; using std::sin; using std::cos; if(!m_valid) { uniform_real_distribution u01; m_r1 = u01(urng); m_r2 = u01(urng); m_cached_rho = sqrt(-RealType(2) * log(RealType(1)-m_r2)); m_valid = true; } else { m_valid = false; } const RealType pi = RealType(3.14159265358979323846); RealType result = m_cached_rho * (m_valid ? cos(RealType(2)*pi*m_r1) : sin(RealType(2)*pi*m_r1)); return result; } private: RealType m_r1, m_r2, m_cached_rho; bool m_valid; }; template struct normal_distribution_base { #if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC typedef normal_distribution_nvcc type; #else typedef normal_distribution_portable type; #endif }; } // end detail } // end random } // end thrust