#ifndef ROOTFINDER_HPP
#define ROOTFINDER_HPP

#include "rootcnt.hpp"

/* Find the roots of a function f that is twice continuously
   differentiable, using the approach of

      V. P. Plagianakos et al, Locating and computing in
      parallel all the simple roots of special functions using PVM,
      Journal of Computational and Applied Mathematics 133(2001) 545-554
      https://www.math.upatras.gr/~vrahatis/papers/journals/PlagianakosNV01_J_COMPUT_APPL_MATH_133_pp545-554_2001.pdf

      Parameters of the constructor:

      f   function which is two times continuously differentiable,
	  i.e. f' and f'' exist and both are continuous functions
      f1  first derivative f'(x)
      f2  second derivative f''(x)

   parallel version using OpenMP
*/

/* RootFinder: recursive interval halving driven by root counts.

   The number of roots reported for an interval comes from
   get_root_count() (not defined in this file). Intervals reported to
   contain exactly one root are refined by bisection; intervals with
   more roots are split at their midpoint and both halves are
   processed as OpenMP tasks.

   Template parameters:
      F, F1, F2  callable types of f, f' and f''
      Real       floating point type used for arguments and results
*/
template<typename F, typename F1, typename F2, typename Real = double>
class RootFinder {
   public:
      /* store copies of f and of its first and second derivative */
      RootFinder(F f, F1 f1, F2 f2) : f(f), f1(f1), f2(f2) {
      }

      /* Write the roots found between a and b through outit.

         outit  output iterator receiving the roots; it is taken by
                value, so the caller's iterator object is not advanced
         a, b   bounds of the search interval
         eps    bisection stops once the bracketing interval is
                shorter than eps

         The roots are not emitted in sorted order: a midpoint that is
         taken to be a root is written before the roots of the two
         halves, and with OpenMP enabled the order additionally
         depends on task scheduling. */
      template<typename OutputIterator>
      void get_roots(OutputIterator outit,
            Real a, Real b, Real eps) const {
         unsigned int numOfRoots = get_count(a, b);
         if (numOfRoots > 0) {
            // one thread seeds the recursion; the tasks it spawns are
            // executed by the threads of the parallel region
            #pragma omp parallel
            #pragma omp single
            get_roots(outit, a, b, eps, numOfRoots);
         }
      }

   private:
      F f; F1 f1; F2 f2;

      /* number of roots between a and b as reported by
         get_root_count() */
      unsigned int get_count(Real a, Real b) const {
         return get_root_count(a, b, f, f1, f2);
      }

      /* sign of x: -1, 0 or +1 */
      static int sgn(Real x) {
         return x < 0? -1: (x > 0? 1: 0);
      }

      /* Refine the single root bracketed by a and b by bisection.

         Returns the midpoint of the current bracket as soon as f
         vanishes there, the bracket is shorter than eps, or the
         bracket cannot be narrowed any further. */
      Real bisection(Real a, Real b, Real eps) const {
         Real fa = f(a); Real fb = f(b);
         for(;;) {
            Real midpoint = (a + b) / 2; Real fm = f(midpoint);
            if (fm == 0 || b - a < eps) return midpoint;
            // The midpoint coincides with a bound: no further
            // progress is possible (eps is non-positive or below the
            // spacing of Real at this magnitude). Without this guard
            // the loop would never terminate.
            if (midpoint == a || midpoint == b) return midpoint;

            bool rootInLeftHalf;
            if (!(fa == 0)) {
               // regular case: a sign change between a and the
               // midpoint brackets the root on the left
               rootInLeftHalf = sgn(fa) * sgn(fm) < 0;
            } else {
               // f(a) is exactly zero (a may be a previously emitted
               // midpoint root), so its sign carries no information.
               // Use f(b) as reference instead: if f has the same
               // sign at the midpoint and at b, the sign change must
               // lie to the left of the midpoint.
               rootInLeftHalf = sgn(fb) * sgn(fm) > 0;
            }

            if (rootInLeftHalf) {
               b = midpoint; fb = fm;
            } else {
               a = midpoint; fa = fm;
            }
         }
      }

      /* Recursive worker: emit the numOfRoots roots reported for the
         interval between a and b.

         outit is passed by reference and shared by all tasks; every
         write through it happens inside a critical section. Must be
         called from within an OpenMP parallel region for the tasks
         to run concurrently. */
      template<typename OutputIterator>
      void get_roots(OutputIterator& outit,
               Real a, Real b, Real eps, unsigned int numOfRoots) const {
         if (numOfRoots == 0) return;

         if (numOfRoots == 1) {
            Real root = bisection(a, b, eps);
            #pragma omp critical
            {
               *outit++ = root;
            }
         } else {
            Real midpoint = (a + b) / 2;
            unsigned int numOfLeftRoots = 0; // shared
            unsigned int numOfRightRoots = 0; // shared
            // count the roots of both halves concurrently; the
            // taskgroup waits for both counts before they are read
            #pragma omp taskgroup
            {
               #pragma omp task shared(numOfLeftRoots)
               numOfLeftRoots = get_count(a, midpoint);
               #pragma omp task shared(numOfRightRoots)
               numOfRightRoots = get_count(midpoint, b);
            }
            // the halves account for fewer roots than the whole
            // interval: the missing root is taken to be the midpoint
            if (numOfLeftRoots + numOfRightRoots < numOfRoots) {
               #pragma omp critical
               {
                  *outit++ = midpoint;
               }
            }
            // descend into both halves; these tasks are not awaited
            // here, they complete at the latest at the end of the
            // enclosing parallel region
            #pragma omp task shared(outit)
            get_roots(outit, a, midpoint, eps, numOfLeftRoots);
            #pragma omp task shared(outit)
            get_roots(outit, midpoint, b, eps, numOfRightRoots);
         }
      }
};

template<typename OutputIterator, typename F, typename Real = double>
void get_roots(OutputIterator outit, F f, Real a, Real b, Real eps) {
   NumF1<F, Real> f1(f);
   NumF2<F, Real> f2(f);
   RootFinder<F, decltype(f1), decltype(f2), Real> rfinder(f, f1, f2);
   rfinder.get_roots(outit, a, b, eps);
}

template<typename OutputIterator,
   typename F, typename F1, typename F2, typename Real = double>
void get_roots(OutputIterator outit, F f, F1 f1, F2 f2,
      Real a, Real b, Real eps) {
   RootFinder<F, F1, F2, Real> rfinder(f, f1, f2);
   rfinder.get_roots(outit, a, b, eps);
}

#endif
