/*
 *  Copyright 2008-2012 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#pragma once

#include <thrust/detail/config.h>
#include <thrust/system/tbb/detail/reduce_by_key.h>
#include <thrust/iterator/reverse_iterator.h>
#include <thrust/system/cpp/execution_policy.h>
#include <thrust/system/tbb/detail/execution_policy.h>
#include <thrust/system/tbb/detail/reduce_intervals.h>
#include <thrust/detail/minmax.h>
#include <thrust/detail/temporary_array.h>
#include <thrust/detail/range/tail_flags.h>
#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>
#include <tbb/tbb_thread.h>
#include <cassert>


namespace thrust
{
namespace system
{
namespace tbb
{
namespace detail
{
namespace reduce_by_key_detail
{


// Divides x by y, biasing the numerator by (y - 1) before the division so that
// the quotient is rounded up rather than truncated (for non-negative x and
// positive y). The quotient is returned as type L.
template<typename L, typename R>
  inline L divide_ri(const L x, const R y)
{
  const L rounded_up_quotient = (x + (y - 1)) / y;
  return rounded_up_quotient;
}


// Metafunction naming the type used to hold partial sums while reducing values.
//
// The nested ::type is chosen as follows:
//   1. if BinaryFunction has a nested result_type, that result_type;
//   2. otherwise, if OutputIterator is detected as an output iterator,
//      the value type of InputIterator;
//   3. otherwise, the value type of OutputIterator.
//
// OutputIterator defaults to void; that case is handled by the partial
// specialization for void.
template<typename InputIterator, typename BinaryFunction, typename OutputIterator = void>
  struct partial_sum_type
    : thrust::detail::eval_if<
        thrust::detail::has_result_type<BinaryFunction>::value,
        thrust::detail::result_type<BinaryFunction>,
        thrust::detail::eval_if<
          thrust::detail::is_output_iterator<OutputIterator>::value,
          thrust::iterator_value<InputIterator>,
          thrust::iterator_value<OutputIterator>
        >
      >
{};


// Specialization used when no OutputIterator is supplied (OutputIterator == void):
// the partial sum type is BinaryFunction's nested result_type when it has one,
// and the value type of InputIterator otherwise.
template<typename InputIterator, typename BinaryFunction>
  struct partial_sum_type<InputIterator,BinaryFunction,void>
    : thrust::detail::eval_if<
        thrust::detail::has_result_type<BinaryFunction>::value,
        thrust::detail::result_type<BinaryFunction>,
        thrust::iterator_value<InputIterator>
      >
{};


// Reduces the final run of equivalent keys in [keys_first, keys_last) by walking
// the keys and their corresponding values backward from the end of the range.
//
// Precondition: the key range is non-empty; the last key and the last value are
// dereferenced unconditionally.
//
// Returns a pair whose
//   - first  is an iterator to the first key belonging to the final run, i.e. the
//            end of the keys which were NOT consumed by this function, and whose
//   - second is a (key, value) pair holding the last key of the range and the
//            reduction of the values of the final run.
//
// The values of the run are combined in their original sequence order, so
// binary_op does not need to be commutative.
template<typename InputIterator1,
         typename InputIterator2,
         typename BinaryPredicate,
         typename BinaryFunction>
  thrust::pair<
    InputIterator1,
    thrust::pair<
      typename thrust::iterator_value<InputIterator1>::type,
      typename partial_sum_type<InputIterator2,BinaryFunction>::type
    >
  >
    reduce_last_segment_backward(InputIterator1 keys_first,
                                 InputIterator1 keys_last,
                                 InputIterator2 values_first,
                                 BinaryPredicate binary_pred,
                                 BinaryFunction binary_op)
{
  typename thrust::iterator_difference<InputIterator1>::type n = keys_last - keys_first;

  // reverse the ranges and consume from the end
  thrust::reverse_iterator<InputIterator1> keys_first_r(keys_last);
  thrust::reverse_iterator<InputIterator1> keys_last_r(keys_first);
  thrust::reverse_iterator<InputIterator2> values_first_r(values_first + n);

  // seed the result with the very last key/value of the range
  typename thrust::iterator_value<InputIterator1>::type result_key = *keys_first_r;
  typename partial_sum_type<InputIterator2,BinaryFunction>::type result_value = *values_first_r;

  // consume the entirety of the first key's sequence
  for(++keys_first_r, ++values_first_r;
      (keys_first_r != keys_last_r) && binary_pred(*keys_first_r, result_key);
      ++keys_first_r, ++values_first_r)
  {
    // we are walking backward: the value being visited precedes everything
    // accumulated so far, so it must be the left operand of binary_op
    result_value = binary_op(*values_first_r, result_value);
  }

  // base() converts the reverse iterator back into a forward iterator
  // positioned at the beginning of the run which was just consumed
  return thrust::make_pair(keys_first_r.base(), thrust::make_pair(result_key, result_value));
}


// Sequentially reduces [keys_first, keys_last) by key, but withholds the final
// run of equivalent keys from the output and returns it as a "carry" instead.
//
// The final run is reduced first with reduce_last_segment_backward; the keys
// preceding it are then reduced with the sequential (cpp) reduce_by_key, which
// writes to keys_output / values_output.
//
// Precondition: the key range is non-empty (required by reduce_last_segment_backward).
//
// Returns a tuple of
//   - the end of the keys written to keys_output,
//   - the end of the values written to values_output,
//   - the key of the withheld final run,
//   - the reduced value of the withheld final run.
//
// The key type is named via thrust::iterator_value so that key iterators without
// a nested value_type are accepted, consistent with serial_reduce_by_key_body.
template<typename InputIterator1,
         typename InputIterator2,
         typename OutputIterator1,
         typename OutputIterator2,
         typename BinaryPredicate,
         typename BinaryFunction>
  thrust::tuple<
    OutputIterator1,
    OutputIterator2,
    typename thrust::iterator_value<InputIterator1>::type,
    typename partial_sum_type<InputIterator2,BinaryFunction>::type
  >
    reduce_by_key_with_carry(InputIterator1 keys_first, 
                             InputIterator1 keys_last,
                             InputIterator2 values_first,
                             OutputIterator1 keys_output,
                             OutputIterator2 values_output,
                             BinaryPredicate binary_pred,
                             BinaryFunction binary_op)
{
  // first, consume the last sequence to produce the carry
  // XXX is there an elegant way to pose this such that we don't need to default construct carry?
  thrust::pair<
    typename thrust::iterator_value<InputIterator1>::type,
    typename partial_sum_type<InputIterator2,BinaryFunction>::type
  > carry;

  // keys_last is pulled back to the beginning of the final run, so the
  // sequential pass below only sees the keys preceding the carry
  thrust::tie(keys_last, carry) = reduce_last_segment_backward(keys_first, keys_last, values_first, binary_pred, binary_op);

  // finish with sequential reduce_by_key
  thrust::cpp::tag seq;
  thrust::tie(keys_output, values_output) =
    thrust::reduce_by_key(seq, keys_first, keys_last, values_first, keys_output, values_output, binary_pred, binary_op);
  
  return thrust::make_tuple(keys_output, values_output, carry.first, carry.second);
}


// Reports whether the interval with index interval_idx ends in the middle of a
// run of keys, i.e. whether its trailing partial sum must be carried into the
// following interval.
//
// tail_flags is indexed by element position; a true flag marks the last element
// of a run. The final interval never has a carry by definition.
template<typename Iterator>
  bool interval_has_carry(size_t interval_idx, size_t interval_size, size_t num_intervals, Iterator tail_flags)
{
  // the final interval (or anything beyond it) has nothing to carry into
  if(interval_idx + 1 >= num_intervals)
  {
    return false;
  }

  // inspect the tail flag of the interval's last element:
  // if that element does not end a run, the run spills into the next interval
  const size_t last_element_of_interval = (interval_idx + 1) * interval_size - 1;

  return !tail_flags[last_element_of_interval];
}


// Body functor handed to tbb::parallel_for by reduce_by_key.
//
// Each invocation handles exactly one interval of the input: it reduces the
// interval by key with reduce_by_key_with_carry, writing completed runs to the
// output at the interval's precomputed offset. The interval's trailing run is
// either stored to the carry array (when the run continues into the next
// interval) or appended to the output (when it does not).
template<typename Iterator1, typename Iterator2, typename Iterator3, typename Iterator4, typename Iterator5, typename Iterator6, typename BinaryPredicate, typename BinaryFunction>
  struct serial_reduce_by_key_body
{
  typedef typename thrust::iterator_difference<Iterator1>::type size_type;

  Iterator1 keys_first;     // beginning of the entire key range
  Iterator2 values_first;   // beginning of the entire value range
  Iterator3 result_offset;  // per-interval offset into the outputs
  Iterator4 keys_result;    // beginning of the key output
  Iterator5 values_result;  // beginning of the value output
  Iterator6 carry_result;   // per-interval storage for carried partial sums

  size_type n;              // total number of input elements
  size_type interval_size;  // number of elements per interval (the last may be shorter)
  size_type num_intervals;  // total number of intervals

  BinaryPredicate binary_pred;
  BinaryFunction binary_op;

  serial_reduce_by_key_body(Iterator1 keys_first, Iterator2 values_first, Iterator3 result_offset, Iterator4 keys_result, Iterator5 values_result, Iterator6 carry_result, size_type n, size_type interval_size, size_type num_intervals, BinaryPredicate binary_pred, BinaryFunction binary_op)
    : keys_first(keys_first),
      values_first(values_first),
      result_offset(result_offset),
      keys_result(keys_result),
      values_result(values_result),
      carry_result(carry_result),
      n(n),
      interval_size(interval_size),
      num_intervals(num_intervals),
      binary_pred(binary_pred),
      binary_op(binary_op)
  {}

  // Processes the single interval whose index is r.begin().
  void operator()(const ::tbb::blocked_range<size_type> &r) const
  {
    // the caller is expected to hand out ranges of exactly one interval
    assert(r.size() == 1);

    typedef typename thrust::iterator_value<Iterator1>::type key_type;
    typedef typename partial_sum_type<Iterator2,BinaryFunction>::type value_type;

    // locate this interval within the input; the last interval is clamped to n
    const size_type which        = r.begin();
    const size_type begin_offset = interval_size * which;
    const size_type end_offset   = thrust::min(n, begin_offset + interval_size);

    Iterator1 interval_keys_begin   = keys_first   + begin_offset;
    Iterator1 interval_keys_end     = keys_first   + end_offset;
    Iterator2 interval_values_begin = values_first + begin_offset;

    // locate where this interval's results begin in the outputs
    Iterator3 offset_iter = result_offset + which;
    Iterator4 keys_out    = keys_result   + *offset_iter;
    Iterator5 values_out  = values_result + *offset_iter;

    // reduce the interval; its trailing run is returned rather than written
    // XXX is there a way to pose this so that we don't require default construction of the trailing run?
    thrust::pair<key_type, value_type> trailing_run;

    thrust::tie(keys_out, values_out, trailing_run.first, trailing_run.second) =
      reduce_by_key_with_carry(interval_keys_begin,
                               interval_keys_end,
                               interval_values_begin,
                               keys_out,
                               values_out,
                               binary_pred,
                               binary_op);

    // build tail flags over the whole input so we can ask whether
    // this interval's trailing run continues into the next interval
    thrust::detail::tail_flags<Iterator1,BinaryPredicate> all_flags = thrust::detail::make_tail_flags(keys_first, keys_first + n, binary_pred);

    const bool run_continues = interval_has_carry(which, interval_size, num_intervals, all_flags.begin());

    if(!run_continues)
    {
      // the trailing run is complete: append it to this interval's output
      *keys_out = trailing_run.first;
      *values_out = trailing_run.second;
    }
    else
    {
      // the trailing run spills into the next interval: stash its partial sum
      // the run's key can be ignored
      // XXX because the carry result is uninitialized, we should copy construct
      Iterator6 carry_slot = carry_result + which;
      *carry_slot = trailing_run.second;
    }
  }
};


// Convenience factory which deduces the template arguments of
// serial_reduce_by_key_body from the arguments and forwards them to its constructor.
template<typename Iterator1, typename Iterator2, typename Iterator3, typename Iterator4, typename Iterator5, typename Iterator6, typename BinaryPredicate, typename BinaryFunction>
  serial_reduce_by_key_body<Iterator1,Iterator2,Iterator3,Iterator4,Iterator5,Iterator6,BinaryPredicate,BinaryFunction>
    make_serial_reduce_by_key_body(Iterator1 keys_first, Iterator2 values_first, Iterator3 result_offset, Iterator4 keys_result, Iterator5 values_result, Iterator6 carry_result, typename thrust::iterator_difference<Iterator1>::type n, size_t interval_size, size_t num_intervals, BinaryPredicate binary_pred, BinaryFunction binary_op)
{
  typedef serial_reduce_by_key_body<
    Iterator1,Iterator2,Iterator3,Iterator4,Iterator5,Iterator6,BinaryPredicate,BinaryFunction
  > body_type;

  return body_type(keys_first,
                   values_first,
                   result_offset,
                   keys_result,
                   values_result,
                   carry_result,
                   n,
                   interval_size,
                   num_intervals,
                   binary_pred,
                   binary_op);
}


} // end reduce_by_key_detail


// TBB implementation of reduce_by_key.
//
// Reduces each run of consecutive keys in [keys_first, keys_last) which are
// equivalent under binary_pred, combining the corresponding values with
// binary_op. One key and one reduced value are written per run to keys_result
// and values_result.
//
// Inputs smaller than an internal threshold are handled by the sequential (cpp)
// reduce_by_key. Larger inputs are split into fixed-size intervals:
//   1. the number of run ends (tail flags) in each interval is counted and
//      scanned to give each interval its output offset;
//   2. each interval is reduced independently inside tbb::parallel_for; a run
//      which continues past the end of its interval leaves its partial sum in
//      a carry array instead of the output;
//   3. the carries are folded sequentially into the output element of the run
//      they belong to.
//
// Values are always combined in input order, so binary_op is not required to
// be commutative.
//
// Returns a pair of iterators to the ends of the key and value outputs.
template<typename DerivedPolicy, typename Iterator1, typename Iterator2, typename Iterator3, typename Iterator4, typename BinaryPredicate, typename BinaryFunction>
  thrust::pair<Iterator3,Iterator4>
    reduce_by_key(thrust::tbb::execution_policy<DerivedPolicy> &exec,
                  Iterator1 keys_first, Iterator1 keys_last, 
                  Iterator2 values_first,
                  Iterator3 keys_result,
                  Iterator4 values_result,
                  BinaryPredicate binary_pred,
                  BinaryFunction binary_op)
{

  typedef typename thrust::iterator_difference<Iterator1>::type difference_type;
  difference_type n = keys_last - keys_first;

  // nothing to reduce
  if(n == 0) return thrust::make_pair(keys_result, values_result);

  // XXX this value is a tuning opportunity
  const difference_type parallelism_threshold = 10000;

  if(n < parallelism_threshold)
  {
    // don't bother parallelizing for small n
    thrust::cpp::tag seq;
    return thrust::reduce_by_key(seq, keys_first, keys_last, values_first, keys_result, values_result, binary_pred, binary_op);
  }

  // count the number of processors
  const unsigned int p = thrust::max<unsigned int>(1u, ::tbb::tbb_thread::hardware_concurrency());

  // generate O(P) intervals of sequential work
  // XXX oversubscribing is a tuning opportunity
  // NOTE(review): max(n, n / (subscription_rate * p)) is always n, and n >= parallelism_threshold here,
  //               so interval_size always equals parallelism_threshold and p has no effect -- confirm intent
  const unsigned int subscription_rate = 1;
  difference_type interval_size = thrust::min<difference_type>(parallelism_threshold, thrust::max<difference_type>(n, n / (subscription_rate * p)));
  difference_type num_intervals = reduce_by_key_detail::divide_ri(n, interval_size);

  // decompose the input into intervals of size N / num_intervals
  // add one extra element to this vector to store the size of the entire result
  thrust::detail::temporary_array<difference_type, DerivedPolicy> interval_output_offsets(0, exec, num_intervals + 1);

  // first count the number of tail flags in each interval
  // the counts are stored starting at index 1 so that the scan below yields, at index i,
  // the number of outputs produced by all intervals preceding interval i
  thrust::detail::tail_flags<Iterator1,BinaryPredicate> tail_flags = thrust::detail::make_tail_flags(keys_first, keys_last, binary_pred);
  thrust::system::tbb::detail::reduce_intervals(exec, tail_flags.begin(), tail_flags.end(), interval_size, interval_output_offsets.begin() + 1, thrust::plus<size_t>());
  interval_output_offsets[0] = 0;

  // scan the counts to get each body's output offset
  thrust::cpp::tag seq;
  thrust::inclusive_scan(seq,
                         interval_output_offsets.begin() + 1, interval_output_offsets.end(), 
                         interval_output_offsets.begin() + 1);

  // do a reduce_by_key serially in each thread
  // the final interval never has a carry by definition, so don't reserve space for it
  typedef typename reduce_by_key_detail::partial_sum_type<Iterator2,BinaryFunction>::type carry_type;
  thrust::detail::temporary_array<carry_type, DerivedPolicy> carries(0, exec, num_intervals - 1);

  // force grainsize == 1 with simple_partitioner()
  ::tbb::parallel_for(::tbb::blocked_range<difference_type>(0, num_intervals, 1),
    reduce_by_key_detail::make_serial_reduce_by_key_body(keys_first, values_first, interval_output_offsets.begin(), keys_result, values_result, carries.begin(), n, interval_size, num_intervals, binary_pred, binary_op),
    ::tbb::simple_partitioner());

  // the last element of the scanned offsets is the total number of runs
  difference_type size_of_result = interval_output_offsets[num_intervals];

  // sequentially accumulate the carries
  // note that the last interval does not have a carry
  //
  // a carry holds values which PRECEDE the partial sum stored at the next interval's output offset,
  // so the carry must be the left operand of binary_op. when a single run spans several intervals,
  // all of their carries target the same output element; visiting the carries from last to first
  // prepends them one at a time and thereby keeps the values in input order.
  // XXX find a way to express this loop via a sequential algorithm, perhaps reduce_by_key
  typedef typename thrust::detail::temporary_array<carry_type,DerivedPolicy>::size_type carry_index_type;

  for(carry_index_type remaining = carries.size(); remaining > 0; --remaining)
  {
    const carry_index_type i = remaining - 1;

    // if our interval has a carry, then we need to sum the carry to the next interval's output offset
    // if it does not have a carry, then we need to ignore carries[i]
    if(reduce_by_key_detail::interval_has_carry(i, interval_size, num_intervals, tail_flags.begin()))
    {
      difference_type output_idx = interval_output_offsets[i+1];

      values_result[output_idx] = binary_op(carries[i], values_result[output_idx]);
    }
  }

  return thrust::make_pair(keys_result + size_of_result, values_result + size_of_result);
}


} // end detail
} // end tbb
} // end system
} // end thrust