/*******************************************************************************
* Copyright (C) 2014 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file SetupHalo.cpp

 HPCG routine
 */

#ifndef HPCG_NO_MPI
#include <mpi.h>
#include <map>
#include <set>
#endif

#include "SetupHalo.hpp"
#include "UsmUtil.hpp"
#include "SetupHalo_ref.hpp"
#include "PrefixSum.hpp"

// Shorthand for a sycl::atomic_ref<T> with relaxed ordering and device scope.
// The kernels in this file use it to bump counters that many work-items
// update concurrently.
template <class T>
using device_atomic_ref =
    sycl::atomic_ref<T, sycl::memory_order::relaxed, sycl::memory_scope::device>;

/*!
  Prepares the system matrix data structure and creates the data necessary
  for communication of boundary values of this process.

  In an MPI build with 32-bit local indices and more than one rank, this routine:
    - determines on the device which boundary rows must be sent to each neighbor,
    - exchanges the per-neighbor send counts with MPI_Alltoall and the global
      indices of the exchanged elements with point-to-point messages,
    - rewrites the negative (external) entries of mtxIndL to
      localNumberOfRows + position of that column in the receive buffer, and
    - stores counts, displacements, send lists and send buffers in A.
  With a single rank it only marks A as having no external values. Other build
  configurations defer to SetupHalo_ref.

  If an allocation fails, the routine frees every buffer it allocated so far
  and returns early. The halo fields of A are then left unset.

  @param[inout] A          The known system matrix
  @param[in]    main_queue SYCL queue used for USM allocations, copies and kernels

  @see ExchangeHalo
*/

void SetupHalo(SparseMatrix & A, sycl::queue & main_queue)
{
#if !defined(HPCG_LOCAL_LONG_LONG) && !defined(HPCG_NO_MPI)
    if( A.geom->size > 1 )
    {
        local_int_t localNumberOfRows = A.localNumberOfRows;
        char  * nonzerosInRow = A.nonzerosInRow;
        global_int_t ** mtxIndG = A.mtxIndG;
        local_int_t ** mtxIndL = A.mtxIndL;

        local_int_t number_of_neighbors = A.numberOfSendNeighbors;
        local_int_t *map_neib_recv_h = A.work;

        // Every buffer is declared up front (NULL) so that release_all() can free
        // whatever has been allocated when an early return is taken.
        local_int_t  *totalToBeSent_h = NULL, *totalToBeSent_d = NULL;
        local_int_t  *receiveLength_s = NULL, *sendLength_s = NULL, *map_neib_send_s = NULL;
        int          *neighbors_h = NULL;
        global_int_t *map_send_s = NULL;
        double       *sendBuffer_d = NULL, *sendBuffer_h = NULL;
        local_int_t  *all_send_h = NULL, *all_recv_h = NULL;
        local_int_t  *elementsToSend_h = NULL, *elementsToSend_d = NULL;
        global_int_t *elementsToSend_G_h = NULL, *elementsToSend_G_d = NULL;
        global_int_t *elementsToRecv_G_h = NULL, *elementsToRecv_G_d = NULL;
        local_int_t  *sendLength_sum_d = NULL;
        global_int_t *syclmap_keys_d = NULL;
        local_int_t  *syclmap_values_d = NULL;
        MPI_Request  *request_h = NULL;
        local_int_t  *sdispls = NULL, *rdispls = NULL, *scounts = NULL, *rcounts = NULL;

        auto usm_free = [&main_queue](void *p) { if (p != NULL) sycl::free(p, main_queue); };

        // Failure path only: nothing has been published to A yet, so every buffer
        // is still owned here. Wait first, because queued copies/kernels may still
        // reference these buffers.
        auto release_all = [&]() {
            main_queue.wait();
            usm_free(totalToBeSent_h);   usm_free(totalToBeSent_d);
            usm_free(map_send_s);        usm_free(map_neib_send_s);
            usm_free(neighbors_h);       usm_free(receiveLength_s);   usm_free(sendLength_s);
            usm_free(sendBuffer_d);      usm_free(elementsToSend_d);
            usm_free(elementsToSend_G_d);usm_free(elementsToRecv_G_d);
            usm_free(sendLength_sum_d);  usm_free(all_send_h);        usm_free(all_recv_h);
            usm_free(request_h);
            usm_free(syclmap_keys_d);    usm_free(syclmap_values_d);
            usm_free(sdispls); usm_free(rdispls); usm_free(scounts); usm_free(rcounts);
            // These four come from plain malloc.
            free(sendBuffer_h);
            free(elementsToSend_h);
            free(elementsToSend_G_h);
            free(elementsToRecv_G_h);
        };

        totalToBeSent_h = (local_int_t*) sparse_malloc_host(   sizeof(local_int_t)*1, main_queue);
        totalToBeSent_d = (local_int_t*) sparse_malloc_device( sizeof(local_int_t)*1, main_queue);

        map_send_s      = (global_int_t*) sparse_malloc_shared( sizeof(global_int_t)*A.numOfBoundaryRows*number_of_neighbors, main_queue );
        map_neib_send_s = (local_int_t *) sparse_malloc_shared( sizeof(local_int_t )*A.geom->size                           , main_queue );
        neighbors_h     = (int*         ) sparse_malloc_host(   sizeof(int         )*number_of_neighbors                    , main_queue );
        receiveLength_s = (local_int_t *) sparse_malloc_shared( sizeof(local_int_t )*number_of_neighbors                    , main_queue );
        sendLength_s    = (local_int_t *) sparse_malloc_shared( sizeof(local_int_t )*number_of_neighbors                    , main_queue );

        if ( totalToBeSent_h == NULL || totalToBeSent_d == NULL ||
             map_neib_send_s == NULL || map_send_s == NULL || neighbors_h == NULL ||
             receiveLength_s == NULL || sendLength_s == NULL ) { release_all(); return; }

        local_int_t size = A.geom->size;

        // Map each neighboring rank to a dense neighbor slot (loop is short; runs on host).
        {
            local_int_t cnt = 0;
            for (local_int_t i = 0; i < size; i ++ ) {
              map_neib_send_s[i] = 0;
              if (map_neib_recv_h[i] > 0) {
                neighbors_h[cnt] = i;
                map_neib_send_s[i] = cnt++;
              }
            }
        }

        // The send counter is accumulated with fetch_add below, so it must start at
        // zero; USM device allocations are not zero-initialized.
        auto ev_zero_count = main_queue.fill<local_int_t>(totalToBeSent_d, 0, 1);

        // Mark every (neighbor, boundary row) slot as unused and zero the
        // per-neighbor lengths. The range covers both arrays, so the lengths are
        // cleared even when there are no boundary rows.
        const local_int_t map_send_size = A.numOfBoundaryRows*number_of_neighbors;
        const local_int_t init_size = map_send_size > number_of_neighbors ? map_send_size : number_of_neighbors;
        auto ev = main_queue.submit([&](sycl::handler &cgh) {
            auto kernel = [=](sycl::item<1> item) {
                local_int_t row = item.get_id(0);
                if(row < map_send_size) map_send_s[row] = -1;
                if(row < number_of_neighbors) {
                  receiveLength_s[row] = 0;
                  sendLength_s[row] = 0;
                }
            };
            cgh.parallel_for<class SetupHaloClass1>(sycl::range<1>(init_size), kernel);
        });
        ev.wait();
        ev_zero_count.wait();

        global_int_t nx = A.geom->nx;
        global_int_t ny = A.geom->ny;
        global_int_t gnx = A.geom->gnx;
        global_int_t gny = A.geom->gny;
        int npx = A.geom->npx;
        int npy = A.geom->npy;
        local_int_t npartz = A.geom->npartz;
        local_int_t *partz_nz = A.geom->partz_nz;
        int *partz_ids = A.geom->partz_ids;

        // device-side subroutine
        // TODO: This is nearly identical to a lambda-function in GenerateProblem.cpp, could
        //       maybe share the code somehow
        auto ComputeRankOfMatrixRow = [=](global_int_t gnx, global_int_t gny, local_int_t npartz,
                                          local_int_t *partz_nz,int *partz_ids, local_int_t nx,
                                          local_int_t ny, int npx, int npy,
                                          global_int_t index) -> int {

            global_int_t iz = index/(gny*gnx);
            global_int_t iy = (index-iz*gny*gnx)/gnx;
            global_int_t ix = index%gnx;
            // We now permit varying values for nz for any nx-by-ny plane of MPI processes.
            // npartz is the number of different groups of nx-by-ny groups of processes.
            // partz_ids is an array of length npartz where each value indicates the z process of
            //    the last process in the ith nx-by-ny group.
            // partz_nz is an array of length npartz containing the value of nz for the ith group.
            //        With no variation, npartz = 1, partz_ids[0] = npz, partz_nz[0] = nz

            int ipz = 0;
            int ipartz_ids = 0;
            for (int i=0; i< npartz; ++i) {
                int ipart_nz = partz_nz[i];
                ipartz_ids = partz_ids[i] - ipartz_ids;
                if (iz<= ipart_nz*ipartz_ids) {
                    ipz += iz/ipart_nz;
                    break;
                } else {
                    ipz += ipartz_ids;
                    iz -= ipart_nz*ipartz_ids;
                }
            }
            int ipy = iy/ny;
            int ipx = ix/nx;
            int rank = ipx+ipy*npx+ipz*npy*npx;
            return rank;
        };

        local_int_t* A_boundaryRows = A.boundaryRows;
        global_int_t* A_localToGlobalMap = A.localToGlobalMap;
        local_int_t A_numOfBoundaryRows = A.numOfBoundaryRows;
        // For each boundary row, record (once per neighbor) that the row must be sent
        // to the owner of each external column, and count the distinct pairs.
        ev = main_queue.submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::range<1>(A_numOfBoundaryRows),
                [=](sycl::item<1> item) {
                    auto row = item.get_id(0);
                    local_int_t i = A_boundaryRows[row];
                    for (local_int_t j = 0; j < nonzerosInRow[i]; j++) {
                        if (mtxIndL[i][j] < 0) {
                            global_int_t curIndex = mtxIndG[i][j];
                            int rankIdOfColumnEntry = ComputeRankOfMatrixRow(
                                gnx, gny, npartz, partz_nz, partz_ids, nx, ny, npx, npy, curIndex);
                            local_int_t jj = map_neib_send_s[rankIdOfColumnEntry];

                            // Each work-item owns its own row, so this check-then-set is race-free.
                            if (map_send_s[jj*A_numOfBoundaryRows + row] < 0) {
                                map_send_s[jj*A_numOfBoundaryRows + row] = A_localToGlobalMap[i];
                                device_atomic_ref<local_int_t>(*totalToBeSent_d).fetch_add(1);
                            }
                        }
                    }
                });
        });
        ev.wait();
        main_queue.memcpy(totalToBeSent_h, totalToBeSent_d, sizeof(local_int_t)).wait();

        sendBuffer_d       = (double      *) sparse_malloc_device(sizeof(double      )* (totalToBeSent_h[0]), main_queue);
        sendBuffer_h       = (double      *) malloc(              sizeof(double      )* (totalToBeSent_h[0]));
        elementsToSend_d   = (local_int_t *) sparse_malloc_device(sizeof(local_int_t )* (totalToBeSent_h[0]), main_queue);
        elementsToSend_h   = (local_int_t *) malloc(              sizeof(local_int_t )* (totalToBeSent_h[0]));
        elementsToSend_G_d = (global_int_t*) sparse_malloc_device(sizeof(global_int_t)* (totalToBeSent_h[0]), main_queue);
        elementsToSend_G_h = (global_int_t*) malloc(              sizeof(global_int_t)* (totalToBeSent_h[0]));

        if ( sendBuffer_d == NULL       || sendBuffer_h == NULL       ||
             elementsToSend_d == NULL   || elementsToSend_h == NULL   ||
             elementsToSend_G_d == NULL || elementsToSend_G_h == NULL) { release_all(); return; }

        local_int_t numOfBoundaryRows = A.numOfBoundaryRows;
        // Per-neighbor send lengths.
        ev = main_queue.submit([&](sycl::handler &cgh) {
            cgh.parallel_for(
                sycl::range<2>(number_of_neighbors, numOfBoundaryRows),
                [=](sycl::item<2> item) {
                    const local_int_t i = item.get_id(0);
                    const local_int_t j = item.get_id(1);
                    global_int_t ind = map_send_s[i*numOfBoundaryRows + j];
                    if (ind >= 0) {
                        device_atomic_ref<local_int_t>(sendLength_s[i]).fetch_add(1);
                    }
                });
        });

        sendLength_sum_d = (local_int_t*) sparse_malloc_device(sizeof(local_int_t) * (number_of_neighbors + 1), main_queue);
        if ( sendLength_sum_d == NULL ) { release_all(); return; }

        // sendLength_sum_d = [0, len_0, len_1, ...] turned into per-neighbor start
        // offsets, then used as atomic cursors to pack the global send indices.
        auto ev_fill = main_queue.fill<local_int_t>(sendLength_sum_d, 0, number_of_neighbors + 1);
        auto ev_copy_sendLength = main_queue.memcpy(sendLength_sum_d + 1, sendLength_s, number_of_neighbors * sizeof(local_int_t), {ev, ev_fill});
        auto ev_prefix_sum = prefix_sum(main_queue, number_of_neighbors, sendLength_sum_d, {ev_copy_sendLength});
        ev = main_queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(ev_prefix_sum);
            cgh.parallel_for(
                sycl::range<2>(number_of_neighbors, numOfBoundaryRows),
                [=](sycl::item<2> item) {
                    const local_int_t i = item.get_id(0);
                    const local_int_t j = item.get_id(1);
                    global_int_t ind = map_send_s[i*numOfBoundaryRows + j];
                    if (ind >= 0) {
                        local_int_t k = device_atomic_ref<local_int_t>(sendLength_sum_d[i]).fetch_add(1);
                        elementsToSend_G_d[k] = ind;
                    }
                });
        });
        main_queue.memcpy(elementsToSend_G_h, elementsToSend_G_d, totalToBeSent_h[0]*sizeof(global_int_t), ev).wait();
        all_send_h = (local_int_t *) sparse_malloc_host(sizeof(local_int_t )*A.geom->size, main_queue);
        all_recv_h = (local_int_t *) sparse_malloc_host(sizeof(local_int_t )*A.geom->size, main_queue);

        if ( all_send_h == NULL || all_recv_h == NULL ) { release_all(); return; }

        //parallel_for -- too small, should we offload?
        for( local_int_t i = 0 ; i < A.geom->size; i++ ) { all_send_h[i] = 0; all_recv_h[i] = 0; }

        //parallel_for -- too small, should we offload?
        for (int i = 0; i < number_of_neighbors; i++)
        {
            all_send_h[ neighbors_h[i] ] = sendLength_s[i];
        }
        // local_int_t is 32-bit in this build branch, matching MPI_INT.
        MPI_Alltoall( all_send_h, 1, MPI_INT, all_recv_h, 1, MPI_INT, MPI_COMM_WORLD );
        local_int_t totalToBeRecv = 0;

        //single task or parallel_for with atomic? -- too small, should we offload?
        for (int i = 0; i < number_of_neighbors; i++)
        {
            receiveLength_s[ i ] = all_recv_h[ neighbors_h[i] ];
            totalToBeRecv += receiveLength_s[ i ];
        }

        elementsToRecv_G_d = (global_int_t*) sparse_malloc_device(sizeof(global_int_t)*totalToBeRecv, main_queue);
        if ( elementsToRecv_G_d == NULL ) { release_all(); return; }

        elementsToRecv_G_h = (global_int_t*) malloc(sizeof(global_int_t)*totalToBeRecv);
        if ( elementsToRecv_G_h == NULL ) { release_all(); return; }

        int MPI_MY_TAG = 98;

        request_h = (MPI_Request*) sparse_malloc_host(sizeof(MPI_Request)*number_of_neighbors, main_queue);

        if (request_h == NULL ) { release_all(); return; }

        // Exchange the global indices: post all receives, then send.
        global_int_t *elementsToRecv_pt_h = elementsToRecv_G_h;
        for (int i = 0; i < number_of_neighbors; i++) {
            local_int_t n_recv = receiveLength_s[i];
            MPI_Irecv(elementsToRecv_pt_h, n_recv, MPI_LONG_LONG_INT, neighbors_h[i], MPI_MY_TAG, MPI_COMM_WORLD, request_h+i);
            elementsToRecv_pt_h += n_recv;
        }
        global_int_t *elementsToSend_pt_h = elementsToSend_G_h;
        for (int i = 0; i < number_of_neighbors; i++) {
            local_int_t n_send = sendLength_s[i];
            MPI_Send(elementsToSend_pt_h, n_send, MPI_LONG_LONG_INT, neighbors_h[i], MPI_MY_TAG, MPI_COMM_WORLD);
            elementsToSend_pt_h += n_send;
        }
        MPI_Status status;
        for (int i = 0; i < number_of_neighbors; i++) {
            if ( MPI_Wait(request_h+i, &status) ) {
                std::exit(-1); // TODO: have better error exit
            }
        }
        auto ev_copy_recv = main_queue.memcpy(elementsToRecv_G_d, elementsToRecv_G_h, totalToBeRecv*sizeof(global_int_t));

        //using map - what can be done?
        for( local_int_t i = 0; i < totalToBeSent_h[0]; i++ )
        {
            elementsToSend_h[ i ] = A.globalToLocalMap[ elementsToSend_G_h[ i ] ];
        }

        auto ev_copy_send = main_queue.memcpy(elementsToSend_d, elementsToSend_h, sizeof(local_int_t)*totalToBeSent_h[0]);

        //
        // set up a simple hash map from global indices to local indices for the boundary values
        //
        std::int32_t mapsize = 4*totalToBeRecv; // larger for fewer collisions (speed), costs memory
        syclmap_keys_d = (global_int_t*) sparse_malloc_device(sizeof(global_int_t) * mapsize, main_queue);
        syclmap_values_d = (local_int_t*) sparse_malloc_device(sizeof(local_int_t) * mapsize, main_queue);
        if ( syclmap_keys_d == NULL || syclmap_values_d == NULL ) { release_all(); return; }
        main_queue.fill<global_int_t>(syclmap_keys_d, -1, mapsize).wait();

        auto hash_func = [=](const global_int_t key, const local_int_t hash_size) -> std::int32_t {
            return static_cast<std::int32_t>(key % static_cast<global_int_t>(hash_size));
        };

        // Linear probing.
        auto hash_probe = [=](const std::int32_t h, const local_int_t hash_size) -> std::int32_t {
            return static_cast<std::int32_t>( static_cast<local_int_t>(h + 1) % hash_size);
        };

        auto map_insert = [=](local_int_t N, global_int_t* keys, local_int_t* values,
                              global_int_t key, local_int_t value) {
            // choose bucket with simple hash of the key
            std::int32_t h = hash_func(key, N);
            while (true) {
                global_int_t ckey = keys[h];
                if (ckey == -1) {
                    // atomics to de-conflict parallel setting to the same bucket
                    sycl::atomic_ref<global_int_t,
                                     sycl::memory_order::relaxed,
                                     sycl::memory_scope::device> akey(keys[h]);
                    bool result = akey.compare_exchange_strong(ckey, key);
                    if (result) {
                        values[h] = value;
                        return;
                    }
                }
                h = hash_probe(h, N);
            }
        };

        // Assumes the key is present (every external column was received); otherwise it would not terminate.
        auto map_get = [=](local_int_t N, global_int_t* keys, local_int_t* values,
                           global_int_t key) -> local_int_t {
            std::int32_t h = hash_func(key, N);
            while (true) {
                if (keys[h] == key) {
                    return values[h];
                }
                h = hash_probe(h, N);
            }
        };

        // fill the hash map with the received global indices in elementsToRecv_G_d
        auto ev_map_insert = main_queue.submit([&](sycl::handler& cgh) {
            cgh.depends_on(ev_copy_recv);
            cgh.parallel_for(
                sycl::range<1>(totalToBeRecv),
                [=](sycl::id<1> id) {
                    map_insert(mapsize, syclmap_keys_d, syclmap_values_d, elementsToRecv_G_d[id], id);
                });
        });
        // fill mtxIndL[i][j] with local indices corresponding to received global indices
        ev = main_queue.submit([&](sycl::handler& cgh) {
            cgh.depends_on(ev_map_insert);
            cgh.parallel_for(
                sycl::range<1>(A.numOfBoundaryRows),
                [=](sycl::item<1> item) {
                    local_int_t row = item.get_id(0);
                    local_int_t i = A_boundaryRows[row];
                    for (local_int_t j = 0; j < nonzerosInRow[i]; j++) {
                        if (mtxIndL[i][j] < 0) {
                            global_int_t curIndex = mtxIndG[i][j];
                            mtxIndL[i][j] = localNumberOfRows + map_get(mapsize, syclmap_keys_d, syclmap_values_d, curIndex);
                        }
                    }
                });
        });
        ev.wait();
        ev_copy_send.wait();

        sycl::free(syclmap_keys_d, main_queue);
        sycl::free(syclmap_values_d, main_queue);
        // Already released; clear so release_all() cannot free them twice.
        syclmap_keys_d = NULL;
        syclmap_values_d = NULL;

        sdispls = (local_int_t*) sparse_malloc_host( sizeof(local_int_t)*A.geom->size, main_queue);
        rdispls = (local_int_t*) sparse_malloc_host( sizeof(local_int_t)*A.geom->size, main_queue);
        scounts = (local_int_t*) sparse_malloc_host( sizeof(local_int_t)*A.geom->size, main_queue);
        rcounts = (local_int_t*) sparse_malloc_host( sizeof(local_int_t)*A.geom->size, main_queue);

        local_int_t tmp_s = 0, tmp_r = 0;

        if(sdispls == NULL || rdispls == NULL || scounts == NULL || rcounts == NULL) { release_all(); return; }

        //parallel_for -- too small, should we offload?
        for( local_int_t i = 0; i < A.geom->size; i++ )
        {
            scounts[i] = 0;
            rcounts[i] = 0;
            sdispls[i] = 0;
            rdispls[i] = 0;
        }
        //single_task -- too small, should we offload?
        for( local_int_t i = 0; i < number_of_neighbors; i++ )
        {
            local_int_t root = neighbors_h[i];
            scounts[root] = sendLength_s[i];
            rcounts[root] = receiveLength_s[i];
            sdispls[root] = tmp_s; tmp_s+=sendLength_s[i];
            rdispls[root] = tmp_r; tmp_r+=receiveLength_s[i];
        }
        A.scounts = scounts;
        A.rcounts = rcounts;
        A.sdispls = sdispls;
        A.rdispls = rdispls;

        A.numberOfExternalValues = totalToBeRecv;
        A.localNumberOfColumns = A.localNumberOfRows + A.numberOfExternalValues;
        A.numberOfSendNeighbors = number_of_neighbors;
        A.totalToBeSent = totalToBeSent_h[0];
        A.elementsToSend_d = elementsToSend_d;
        A.elementsToSend_h = elementsToSend_h;
        A.neighbors = neighbors_h;
        A.receiveLength = receiveLength_s;
        A.sendLength = sendLength_s;
        A.sendBuffer = sendBuffer_d;
        A.sendBuffer_h = sendBuffer_h;

        // Release the temporaries (the buffers published to A above are kept).
        sycl::free(totalToBeSent_h, main_queue);
        sycl::free(totalToBeSent_d, main_queue);
        sycl::free(map_send_s, main_queue);
        sycl::free(map_neib_send_s, main_queue);
        sycl::free(sendLength_sum_d, main_queue);
        free(elementsToSend_G_h);
        sycl::free(elementsToSend_G_d, main_queue);
        sycl::free(all_send_h, main_queue);
        sycl::free(all_recv_h, main_queue);
        free(elementsToRecv_G_h);
        sycl::free(elementsToRecv_G_d, main_queue);
        sycl::free(request_h, main_queue);

    } else {
        A.numberOfExternalValues = 0;
        A.localNumberOfColumns = A.localNumberOfRows;
        A.numberOfSendNeighbors = 0;
        A.totalToBeSent = 0;
        A.elementsToSend_h = NULL;
        A.elementsToSend_d = NULL;
    }
#else
    SetupHalo_ref(A, main_queue);
#endif
}
