/*******************************************************************************
 * Copyright 2016 Intel Corporation.
 *
 *
 * This software and the related documents are Intel copyrighted materials, and your use of them is governed by
 * the express license under which they were provided to you ('License'). Unless the License provides otherwise,
 * you may not use, modify, copy, publish, distribute, disclose or transmit this software or the related
 * documents without Intel's prior written permission.
 * This software and the related documents are provided as is, with no express or implied warranties, other than
 * those that are expressly stated in the License.
 *******************************************************************************/

#include "base_proc.h"

#include <new>

/* Serial fallback backend and common base class for the TBB/OpenMP backends.
 * parallel_for ignores the tile size and hands the whole image to proc.Run(src, dst);
 * the thread count is always reported as 1 and requests to change it are ignored. */
class BaseParallelInterface
{
public:
    BaseParallelInterface()
    {
    }

    virtual ~BaseParallelInterface()
    {
    }

    virtual Status parallel_for(BaseProc &proc, Image &src, Image &dst, Size /*tile*/)
    {
        Status result = proc.Run(src, dst);
        return result;
    }

    virtual unsigned int ThreadsGetNum()
    {
        const unsigned int serialThreads = 1;
        return serialThreads;
    }

    virtual void ThreadsSetNum(unsigned int /*threads*/)
    {
        // Serial backend: the thread count is fixed at one.
    }
};

#ifdef USE_TBB
  #include "tbb/task_arena.h"
  #include "tbb/parallel_for.h"
  #include "tbb/blocked_range2d.h"
  #define TBB_PREVIEW_GLOBAL_CONTROL 1
  #include "tbb/global_control.h"

class TBBParallelInterface : public BaseParallelInterface
{
public:
    /* Initializes the shared thread count on first construction. */
    TBBParallelInterface() { ThreadsGetNum(); }

    virtual ~TBBParallelInterface() {}

    /* Runs proc over dst using a tbb::blocked_range2d whose grain sizes are the tile
     * dimensions (TBB may hand out sub-ranges up to that size, not exact tiles).
     * A zero tile dimension means "whole image in that direction", matching the OpenMP
     * and serial paths; this also guarantees the positive grain size TBB requires.
     * Concurrency is capped at m_threads via tbb::global_control for the duration of the call.
     * Returns STS_OK, or the Status thrown by a failing task (see TBBTask). */
    virtual Status parallel_for(BaseProc &proc, Image &src, Image &dst, Size tile)
    {
        if (!tile.width)
            tile.width = dst.m_size.width;
        if (!tile.height)
            tile.height = dst.m_size.height;

        // Empty destination: nothing to process, and a zero grain size is invalid for TBB.
        if (!tile.width || !tile.height)
            return STS_OK;

        tbb::blocked_range2d<size_t, size_t> tbbRange(0, (size_t)dst.m_size.height, (size_t)tile.height, 0, (size_t)dst.m_size.width,
                                                      (size_t)tile.width);
        try {
            tbb::global_control set_num_threads(tbb::global_control::max_allowed_parallelism, TBBParallelInterface::m_threads);
            tbb::parallel_for(tbbRange, TBBTask(proc, src, dst), m_tbbPart);
        } catch (Status status) {
            // TBB rethrows a task's exception in the calling thread.
            return status;
        }
        return STS_OK;
    }

    /* Returns the shared thread count, lazily initializing it to the TBB default. */
    virtual unsigned int ThreadsGetNum()
    {
        if (!TBBParallelInterface::m_threads)
            TBBParallelInterface::m_threads = DefaultThreads();
        return TBBParallelInterface::m_threads;
    }

    /* Sets the shared thread count; 0 restores the TBB default. */
    virtual void ThreadsSetNum(unsigned int threads)
    {
        TBBParallelInterface::m_threads = (threads > 0) ? threads : DefaultThreads();
    }

    /* TBB body: runs proc on one 2D sub-range and throws the Status if it is not STS_OK. */
    class TBBTask
    {
    public:
        TBBTask(BaseProc &proc, Image &src, Image &dst) : m_proc(proc), m_src(src), m_dst(dst) {}

        void operator()(tbb::blocked_range2d<size_t, size_t> &r) const
        {
            Status status = m_proc.Run(
                m_src, m_dst, Rect(r.cols().begin(), r.rows().begin(), r.cols().end() - r.cols().begin(), r.rows().end() - r.rows().begin()));
            if (status != STS_OK)
                throw(status);
        }

        BaseProc &m_proc;
        Image &m_src;
        Image &m_dst;
    };

protected:
    static unsigned int m_threads;
    tbb::affinity_partitioner m_tbbPart;

private:
    /* Default thread count: the maximum concurrency reported by TBB. */
    static unsigned int DefaultThreads()
    {
  #if TBB_INTERFACE_VERSION >= 9100
        return (unsigned int)tbb::this_task_arena::max_concurrency();
  #else
        return (unsigned int)tbb::task_arena::max_concurrency();
  #endif
    }
};

unsigned int TBBParallelInterface::m_threads = 0;
#endif

#ifdef _OPENMP
  #include "omp.h"

class OMPParallelInterface : public BaseParallelInterface
{
public:
    /* Initializes the shared thread count on first construction. */
    OMPParallelInterface() { ThreadsGetNum(); }

    virtual ~OMPParallelInterface() {}

    /* Splits dst into a grid of tile-sized ROIs and runs proc.Run() on them with an
     * OpenMP static schedule. A zero tile dimension means "whole image in that direction".
     * The OpenMP max-threads setting is switched to m_threads for the call and restored after.
     * Returns STS_OK, or STS_ERR_FAILED if any tile returned a negative Status or threw;
     * once a failure is recorded, tiles not yet started are skipped. */
    virtual Status parallel_for(BaseProc &proc, Image &src, Image &dst, Size tile)
    {
        Status mainStatus = STS_OK;

        if (!tile.width)
            tile.width = dst.m_size.width;
        if (!tile.height)
            tile.height = dst.m_size.height;

        // Empty destination: nothing to do; also avoids dividing by a zero tile size below.
        if (!tile.width || !tile.height)
            return STS_OK;

        int numTilesX = (int)((dst.m_size.width + tile.width - 1) / tile.width);
        int numTilesY = (int)((dst.m_size.height + tile.height - 1) / tile.height);
        int numTiles = numTilesX * numTilesY;

        int saved_threads_num = omp_get_max_threads();
        omp_set_num_threads(OMPParallelInterface::m_threads);

  #pragma omp parallel num_threads(ThreadsGetNum())
        {
  #pragma omp for schedule(static) nowait
            for (int tile_num = 0; tile_num < numTiles; tile_num++) {
                // mainStatus is shared by all threads: every access goes through the same
                // named critical section so reads and writes never race.
                Status current = STS_OK;
  #pragma omp critical(OMPParallelInterface_status)
                current = mainStatus;
                if (current != STS_OK)
                    continue;

                int tileY = tile_num / numTilesX;
                int tileX = tile_num - tileY * numTilesX;
                Rect roi(tileX * tile.width, tileY * tile.height, tile.width, tile.height);

                bool failed = false;
                try {
                    Status status = proc.Run(src, dst, roi);
                    if (status < 0)
                        failed = true;
                } catch (...) {
                    // Exceptions must not escape an OpenMP region; record them as failure.
                    failed = true;
                }

                if (failed) {
  #pragma omp critical(OMPParallelInterface_status)
                    mainStatus = STS_ERR_FAILED;
                }
            }
        }
        omp_set_num_threads(saved_threads_num);
        return mainStatus;
    }

    /* Returns the shared thread count, lazily initializing it from omp_get_max_threads(). */
    virtual unsigned int ThreadsGetNum()
    {
        if (!OMPParallelInterface::m_threads)
            OMPParallelInterface::m_threads = omp_get_max_threads();
        return OMPParallelInterface::m_threads;
    }

    /* Sets the shared thread count; 0 restores omp_get_max_threads(). */
    virtual void ThreadsSetNum(unsigned int threads)
    {
        if (threads > 0) {
            OMPParallelInterface::m_threads = threads;
        } else {
            OMPParallelInterface::m_threads = omp_get_max_threads();
        }
    }

protected:
    static unsigned int m_threads;
};

unsigned int OMPParallelInterface::m_threads = 0;
#endif

/* Resolves a requested ParallelInterface to one that is available in this build and,
 * if ppInter is non-NULL, allocates a matching interface object into *ppInter.
 * PARALLEL_ANY prefers TBB, then OpenMP, then the serial fallback. A request for a
 * backend that was not compiled in resolves to PARALLEL_NONE.
 * Allocation uses nothrow new, so on failure *ppInter stays NULL and callers (which
 * check for NULL) can report STS_ERR_ALLOC instead of std::bad_alloc escaping.
 * Returns the resolved interface type. */
static ParallelInterface allocateInterface(ParallelInterface parallel, BaseParallelInterface **ppInter)
{
    if (ppInter)
        *ppInter = NULL;

    switch (parallel) {
    case PARALLEL_ANY:
#ifdef USE_TBB
    case PARALLEL_TBB:
        if (ppInter)
            *ppInter = new (std::nothrow) TBBParallelInterface();
        return PARALLEL_TBB;
#endif
#ifdef _OPENMP
    case PARALLEL_OMP:
        if (ppInter)
            *ppInter = new (std::nothrow) OMPParallelInterface();
        return PARALLEL_OMP;
#endif
    case PARALLEL_NONE:
    default:
        if (ppInter)
            *ppInter = new (std::nothrow) BaseParallelInterface();
        return PARALLEL_NONE;
    }
}

/* Every processor starts with the serial fallback backend installed. */
BaseProc::BaseProc()
{
    m_pParallel    = new BaseParallelInterface();
    m_parallelType = PARALLEL_NONE;
}

/* Releases the owned parallel interface object. */
BaseProc::~BaseProc()
{
    RemoveParallelInterface();
}

/* Reports which backend `parallel` would resolve to in this build, without allocating one. */
ParallelInterface BaseProc::CheckParallelInterface(ParallelInterface parallel)
{
    BaseParallelInterface **noOutput = NULL;
    const ParallelInterface resolved = allocateInterface(parallel, noOutput);
    return resolved;
}

/* Replaces the current parallel backend with the one resolved from `parallel`.
 * Returns STS_OK on success, or STS_ERR_ALLOC if no interface object was produced.
 * On failure the processor is left with no interface and m_parallelType is reset to
 * PARALLEL_NONE, so the recorded type never names a backend that is not installed. */
Status BaseProc::SetParallelInterface(ParallelInterface parallel)
{
    RemoveParallelInterface();

    m_parallelType = allocateInterface(parallel, &m_pParallel);
    if (!m_pParallel) {
        m_parallelType = PARALLEL_NONE;
        return STS_ERR_ALLOC;
    }
    return STS_OK;
}

/* Destroys the installed parallel interface (if any) and records that no backend is set.
 * Always returns STS_OK. */
Status BaseProc::RemoveParallelInterface()
{
    m_parallelType = PARALLEL_NONE;
    delete m_pParallel; // deleting a NULL pointer is a no-op
    m_pParallel = NULL;
    return STS_OK;
}

/* Runs this processor over dst in tile-sized ROIs.
 * If an installed parallel interface reports more than one thread, the work is delegated
 * to it. Otherwise tiles are processed serially, row by row.
 * A zero tile dimension means "whole image in that direction".
 * Returns STS_OK, or the first negative Status from Run() on the serial path. Errors are
 * returned rather than thrown, matching the TBB/OpenMP backends. */
Status BaseProc::RunParallel(Image &src, Image &dst, Size tile)
{
    if (m_pParallel && m_pParallel->ThreadsGetNum() > 1)
        return m_pParallel->parallel_for(*this, src, dst, tile);

    if (!tile.width)
        tile.width = dst.m_size.width;
    if (!tile.height)
        tile.height = dst.m_size.height;

    Rect roi(0, 0, tile.width, tile.height);
    for (roi.y = 0; roi.y < dst.m_size.height; roi.y += roi.height) {
        for (roi.x = 0; roi.x < dst.m_size.width; roi.x += roi.width) {
            Status status = Run(src, dst, roi);
            if (status < 0)
                return status;
        }
    }
    return STS_OK;
}

/* Returns the thread count of the installed parallel interface, or 1 (serial) when no
 * interface is installed, e.g. after RemoveParallelInterface(). */
unsigned int BaseProc::ThreadsGetNum() { return m_pParallel ? m_pParallel->ThreadsGetNum() : 1; }

/* Forwards the requested thread count to the installed parallel interface; does nothing
 * when no interface is installed (e.g. after RemoveParallelInterface()). */
void BaseProc::ThreadsSetNum(unsigned int threads)
{
    if (m_pParallel)
        m_pParallel->ThreadsSetNum(threads);
}
