/*******************************************************************************
* Copyright (C) 2024 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       SPARSE BLAS and BLAS USM APIs to solve a system of linear equations (Ax=b)
*       by preconditioned Conjugate Gradient (PCG) method with the Symmetric
*       Gauss-Seidel preconditioner:
*
*       Solve A*x = b
*
*       x_0 initial guess
*       r_0 = b - A*x_0
*       k = 0
*       while (||r_k|| / ||r_0|| > relTol and k < maxIter )
*
*           solve M*z_k = r_k for z_k
*           if (k == 0)
*               p_1 = z_0
*           else
*               beta_k = dot(r_k, z_k) / dot(r_{k-1}, z_{k-1})
*               p_{k+1} = z_k + beta_k * p_k
*           end if
*           Ap_{k+1} = A*p_{k+1}
*           alpha_{k+1} = (r_k, z_k) / (p_{k+1}, Ap_{k+1})
*
*           x_{k+1} = x_k + alpha_{k+1} * p_{k+1}
*           r_{k+1} = r_k - alpha_{k+1} * Ap_{k+1}
*           if (||r_{k+1}|| <= absTol) break with convergence
*
*           k=k+1
*       end
*
*       where A = L+D+L^T; M = (D+L)*D^{-1}*(D+L^T).
*
*       Note that
*
*         x is the solution
*         r is the residual
*         z is the preconditioned residual
*         p is the search direction
*
*       and we are using ||r||_2 for stopping criteria and alpha/beta scalars are
*       provided as constants from the host.
*
*
*       The supported floating point data types for matrix data in this example
*       are:
*           float
*           double
*
*       This example uses a matrix in CSR format.
*
*/

// stl includes
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <stdexcept>
#include <string>
#include <vector>

#include "oneapi/mkl.hpp"
#include <sycl/sycl.hpp>

// local includes
#include "common_for_examples.hpp"
#include "./include/common_for_sparse_examples.hpp"

// Kernel name type for the parallel_for launched by extract_diagonal().
// Only declared (never defined); it is used purely as a template argument.
template <typename dataType, typename intType>
class extractDiagonalClass;

// Kernel name type for the parallel_for launched by modify_diagonal().
template <typename dataType, typename intType>
class modifyDiagonalClass;

// Kernel name type for the parallel_for launched by diagonal_mv().
template <typename dataType, typename intType>
class diagonalMVClass;


//
// Extract the diagonal of a CSR matrix
//
// One work-item is launched per row in [0, n).  Each work-item scans the
// entries ia_d[row] .. ia_d[row+1]-1 of its row and copies the value of the
// first entry whose column index equals the row index into d_d[row].
//
//   q     queue the kernel is submitted to
//   n     number of rows
//   ia_d  CSR row pointer array (n+1 entries)
//   ja_d  CSR column index array
//   a_d   CSR values array
//   d_d   output, one diagonal value per row
//   deps  events the kernel must wait for
//
// Returns the event of the submitted kernel.
//
// Notes:
//   * column indices are compared directly with the row number, i.e. the
//     matrix is assumed to use zero-based indexing;
//   * a row without a stored diagonal entry leaves d_d[row] unmodified.
//
template <typename dataType, typename intType>
sycl::event extract_diagonal(sycl::queue q,
                             const intType n,
                             const intType *ia_d,
                             const intType *ja_d,
                             const dataType *a_d,
                                   dataType *d_d,
                             const std::vector<sycl::event> &deps = {})
{
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);
        auto kernel = [=](sycl::item<1> item) {
            // keep the row index in intType so that it is not narrowed to int
            // when a wider index type is used
            const intType row = static_cast<intType>(item.get_id(0));
            const intType rowEnd = ia_d[row + 1];
            for (intType i = ia_d[row]; i < rowEnd; i++) {
                if (ja_d[i] == row) {
                    d_d[row] = a_d[i];
                    break; // only the first matching entry is used
                }
            }
        };
        cgh.parallel_for<class extractDiagonalClass<dataType, intType>>(sycl::range<1>(n), kernel);
    });
}

//
// Modify diagonal value in matrix
//
// One work-item is launched per row in [0, n).  Each work-item scans the
// entries ia_d[row] .. ia_d[row+1]-1 of its row; the first entry whose column
// index equals the row index is overwritten with new_diagVal in a_d, and the
// same value is stored in d_d[row].
//
//   q            queue the kernel is submitted to
//   new_diagVal  value written to every stored diagonal entry; must be
//                non-zero (checked with assert, i.e. only when NDEBUG is not
//                defined)
//   n            number of rows
//   ia_d         CSR row pointer array (n+1 entries)
//   ja_d         CSR column index array
//   a_d          CSR values array, modified in place
//   d_d          diagonal array, modified in place
//   deps         events the kernel must wait for
//
// Returns the event of the submitted kernel.
//
// Notes:
//   * column indices are compared directly with the row number, i.e. the
//     matrix is assumed to use zero-based indexing;
//   * a row without a stored diagonal entry leaves a_d and d_d[row] unmodified.
//
template <typename dataType, typename intType>
sycl::event modify_diagonal(sycl::queue q,
                            const dataType new_diagVal,
                            const intType n,
                            const intType *ia_d,
                            const intType *ja_d,
                                  dataType *a_d, // to be modified
                                  dataType *d_d, // to be modified
                            const std::vector<sycl::event> &deps = {})
{
    assert(new_diagVal != dataType(0.0) );
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);
        auto kernel = [=](sycl::item<1> item) {
            // keep the row index in intType so that it is not narrowed to int
            // when a wider index type is used
            const intType row = static_cast<intType>(item.get_id(0));
            const intType rowEnd = ia_d[row + 1];
            for (intType i = ia_d[row]; i < rowEnd; i++) {
                if (ja_d[i] == row) {
                    a_d[i] = new_diagVal;
                    d_d[row] = new_diagVal;
                    break; // only the first matching entry is modified
                }
            }
        };
        cgh.parallel_for<class modifyDiagonalClass<dataType, intType>>(sycl::range<1>(n), kernel);
    });
}


//
// Scale by diagonal
//
// t = D * t
//
// One work-item is launched per element in [0, n); element i of t is
// multiplied in place by element i of d.
//
//   q     queue the kernel is submitted to
//   n     number of elements in d and t
//   d     diagonal values (read only)
//   t     vector scaled in place
//   deps  events the kernel must wait for
//
// Returns the event of the submitted kernel.
//
template <typename dataType, typename intType>
sycl::event diagonal_mv(sycl::queue q,
                        const intType n,
                        const dataType *d,
                              dataType *t,
                        const std::vector<sycl::event> &deps = {})
{
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);
        auto kernel = [=](sycl::item<1> item) {
            // index with the id's native type instead of narrowing it to int
            const auto idx = item.get_id(0);
            t[idx] *= d[idx];
        };
        cgh.parallel_for<class diagonalMVClass<dataType, intType>>(sycl::range<1>(n), kernel);
    });
}


//
// Symmetric Gauss-Seidel preconditioner application
//
// Solves M z = r with M = (L+D)*inv(D)*(D+U) as a chain of three dependent
// device operations:
//
//   1) t = inv(D+L) * r    lower triangular solve on csrA
//   2) t = D * t           element-wise scaling by the diagonal d
//   3) z = inv(D+U) * t    upper triangular solve on csrA
//
//   q     queue all operations are submitted to
//   n     length of the vectors d, r, t and z
//   csrA  sparse matrix handle used for both triangular solves
//   d     diagonal values of the matrix
//   r     input vector
//   t     temporary workspace, overwritten
//   z     output vector
//   deps  events the first solve must wait for
//
// Returns the event of the final (upper) triangular solve.
//
template <typename dataType, typename intType>
sycl::event precon_gauss_seidel(sycl::queue q,
                                const intType n,
                                oneapi::mkl::sparse::matrix_handle_t csrA,
                                const dataType *d,
                                const dataType *r,
                                      dataType *t, // temporary workspace
                                      dataType *z, // output
                                const std::vector<sycl::event> &deps = {})
{
    namespace mkl = oneapi::mkl;

    const dataType alpha = dataType(1.0);

    // step 1: forward solve, t = inv(D+L) * r
    sycl::event forwardSolveDone = mkl::sparse::trsv(
            q, mkl::uplo::lower, mkl::transpose::nontrans, mkl::diag::nonunit,
            alpha, csrA, r, t, deps);

    // step 2: diagonal scaling, t = D * t
    sycl::event scalingDone = diagonal_mv<dataType, intType>(q, n, d, t, {forwardSolveDone});

    // step 3: backward solve, z = inv(D+U) * t
    return mkl::sparse::trsv(
            q, mkl::uplo::upper, mkl::transpose::nontrans, mkl::diag::nonunit,
            alpha, csrA, t, z, {scalingDone});
}



//
// Run the preconditioned CG example for one data type / index type pair
//
// Steps:
//   1) generate a 27 point stencil matrix in CSR format on the host with
//      generate_sparse_matrix() and set b = 1, x_0 = 0
//   2) copy matrix and vectors to USM device memory, extract the diagonal
//      and overwrite the stored diagonal entries with 52
//   3) create the sparse matrix handle and run the trsv/gemv optimize steps
//   4) iterate PCG with the Symmetric Gauss-Seidel preconditioner until
//      ||r||_2 / ||r_0||_2 <= relTol, ||r||_2 <= absTol or maxIter is reached
//   5) release the matrix handle and free all USM allocations
//
//   q  queue used for all allocations, copies and kernels
//
// Returns 0 when the iteration converged and no exception was caught,
// 1 otherwise.  SYCL and std exceptions thrown inside the solve are caught
// and reported here; the cleanup at the end runs in either case.
//
template <typename dataType, typename intType>
int run_sparse_blas_example(sycl::queue &q)
{
    bool good = true;

    // handle for sparse matrix
    oneapi::mkl::sparse::matrix_handle_t csrA = nullptr;

    // declare all arrays to be allocated
    intType *ia_h = nullptr, *ja_h = nullptr;
    intType *ia_d = nullptr, *ja_d = nullptr;
    dataType *a_h = nullptr, *x_h = nullptr, *b_h = nullptr;
    dataType *a_d = nullptr, *x_d = nullptr, *b_d = nullptr;
    dataType *r_d = nullptr, *z_d = nullptr, *p_d = nullptr, *d_d = nullptr, *t_d = nullptr;
    dataType *temp_d = nullptr, *temp_h = nullptr;

    try {

        // Matrix data size
        const intType size  = 16;
        const intType n = size * size * size; // A is n x n

        const intType nnzUB = 27 * n; // upper bound of nnz from 27 point stencil

        // PCG settings
        const intType maxIter = 500;
        const dataType relTol = 1.0e-5;
        const dataType absTol = 5.0e-4;

        // Input matrix in CSR format
        ia_h = sycl::malloc_host<intType>(n+1, q);
        ja_h = sycl::malloc_host<intType>(nnzUB, q);
        a_h  = sycl::malloc_host<dataType>(nnzUB, q);
        x_h  = sycl::malloc_host<dataType>(n, q);
        b_h  = sycl::malloc_host<dataType>(n, q);

        if (!ia_h || !ja_h || !a_h || !x_h || !b_h ) {
            std::string errorMessage =
                "Failed to allocate USM host memory arrays \n"
                " for CSR A matrix: ia_h(" + std::to_string((n+1)*sizeof(intType)) + " bytes)\n"
                "                   ja_h(" + std::to_string((nnzUB)*sizeof(intType)) + " bytes)\n"
                "                   a_h(" + std::to_string((nnzUB)*sizeof(dataType)) + " bytes)\n"
                " and vectors:      x_h(" + std::to_string((n)*sizeof(dataType)) + " bytes)\n"
                "                   b_h(" + std::to_string((n)*sizeof(dataType)) + " bytes)\n";

            throw std::runtime_error(errorMessage);
        }

        //
        // Generate a 27 point stencil for 3D laplacian using size elements in each dimension
        //
        generate_sparse_matrix<dataType, intType>(size, ia_h, ja_h, a_h);

        const intType nnz = ia_h[n]; // assumes zero indexing

        // Init right hand side and vector x
        for (intType i = 0; i < n; i++) {
            b_h[i] = set_fp_value(dataType(1.0), dataType(0.0)); // rhs b = 1
            x_h[i] = set_fp_value(dataType(0.0), dataType(0.0)); // initial guess x0 = 0
        }

        //
        // Execute Preconditioned Conjugate Gradient Algorithm
        //
        // solve  A x = b  starting with initial guess x = x0
        //

        std::cout << "\n\t\tsparse PCG parameters:\n";

        std::cout << "\t\t\tA size: (" << n << ", " << n << ") with nnz = " << nnz << " elements stored" << std::endl;
        std::cout << "\t\t\tPreconditioner = Symmetric Gauss-Seidel" << std::endl;
        std::cout << "\t\t\tmax iterations = " << maxIter << std::endl;
        std::cout << "\t\t\trelative tolerance limit = " << relTol << std::endl;
        std::cout << "\t\t\tabsolute tolerance limit = " << absTol << std::endl;



        // device copies of the matrix and the work vectors of the PCG iteration
        ia_d = sycl::malloc_device<intType>(n+1, q);  // matrix rowptr
        ja_d = sycl::malloc_device<intType>(nnz, q);  // matrix columns
        a_d  = sycl::malloc_device<dataType>(nnz, q); // matrix values
        x_d  = sycl::malloc_device<dataType>(n, q);   // solution
        b_d  = sycl::malloc_device<dataType>(n, q);   // right hand side
        r_d  = sycl::malloc_device<dataType>(n, q);   // residual
        z_d  = sycl::malloc_device<dataType>(n, q);   // preconditioned residual
        p_d  = sycl::malloc_device<dataType>(n, q);   // search direction
        t_d  = sycl::malloc_device<dataType>(n, q);   // helper array
        d_d  = sycl::malloc_device<dataType>(n, q);   // matrix diagonals

        const intType width = 8; // width * sizeof(dataType) >= cacheline size (64 Bytes)
        temp_d = sycl::malloc_device<dataType>(3*width, q);
        temp_h = sycl::malloc_host<dataType>(3*width, q);

        // every allocation made above is checked, including the residual r_d
        if ( !ia_d || !ja_d || !a_d || !x_d || !b_d || !r_d || !z_d || !p_d || !t_d || !d_d || !temp_d || !temp_h) {
            std::string errorMessage =
                "Failed to allocate USM device/USM host memory arrays \n"
                " for CSR A matrix: ia_d(" + std::to_string((n+1)*sizeof(intType)) + " bytes)\n"
                "                   ja_d(" + std::to_string((nnz)*sizeof(intType)) + " bytes)\n"
                "                   a_d(" + std::to_string((nnz)*sizeof(dataType)) + " bytes)\n"
                " and vectors:      x_d(" + std::to_string((n)*sizeof(dataType)) + " bytes)\n"
                "                   b_d(" + std::to_string((n)*sizeof(dataType)) + " bytes)\n"
                "                   r_d(" + std::to_string((n)*sizeof(dataType)) + " bytes)\n"
                "                   z_d(" + std::to_string((n)*sizeof(dataType)) + " bytes)\n"
                "                   p_d(" + std::to_string((n)*sizeof(dataType)) + " bytes)\n"
                "                   t_d(" + std::to_string((n)*sizeof(dataType)) + " bytes)\n"
                "                   d_d(" + std::to_string((n)*sizeof(dataType)) + " bytes)\n"
                " and temp arrays:  temp_d(" + std::to_string((3*width)*sizeof(dataType)) + " bytes)\n"
                "                   temp_h(" + std::to_string((3*width)*sizeof(dataType)) + " bytes)";

            throw std::runtime_error(errorMessage);
        }

        // host and device side aliases into the temp arrays, scattered by width elements each
        dataType *normr_h  = temp_h;
        dataType *rtz_h    = temp_h+1*width;
        dataType *pAp_h    = temp_h+2*width;
        dataType *normr_d  = temp_d;
        dataType *rtz_d    = temp_d+1*width;
        dataType *pAp_d    = temp_d+2*width;

        // copy data from host to device arrays
        q.copy(ia_h, ia_d, n+1).wait();
        q.copy(ja_h, ja_d, nnz).wait();
        q.copy(a_h, a_d, nnz).wait();
        q.copy(x_h, x_d, n).wait();
        q.copy(b_h, b_d, n).wait();

        extract_diagonal<dataType, intType>(q,n, ia_d, ja_d, a_d, d_d, {}).wait();

        modify_diagonal<dataType, intType>(q, dataType(52.0), n, ia_d, ja_d, a_d, d_d, {}).wait();

        // setup optimizations and properties we know about A matrix
        oneapi::mkl::sparse::init_matrix_handle(&csrA);

        auto ev_set = oneapi::mkl::sparse::set_csr_data(q, csrA, n, n,
                oneapi::mkl::index_base::zero, ia_d, ja_d, a_d, {});

        oneapi::mkl::sparse::set_matrix_property(csrA, oneapi::mkl::sparse::property::symmetric);
        oneapi::mkl::sparse::set_matrix_property(csrA, oneapi::mkl::sparse::property::sorted);

        auto ev_optSvL = oneapi::mkl::sparse::optimize_trsv(q,
                oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans,
                oneapi::mkl::diag::nonunit, csrA, {ev_set});
        auto ev_optSvU = oneapi::mkl::sparse::optimize_trsv(q,
                oneapi::mkl::uplo::upper, oneapi::mkl::transpose::nontrans,
                oneapi::mkl::diag::nonunit, csrA, {ev_optSvL});
        auto ev_optGemv = oneapi::mkl::sparse::optimize_gemv(q,
                oneapi::mkl::transpose::nontrans, csrA, {ev_optSvU});
        // done setting up optimizations for A matrix

        // initial residual r_0 = b - A * x_0
        auto ev_r = oneapi::mkl::sparse::gemv(q, oneapi::mkl::transpose::nontrans, 1.0, csrA,
                                      x_d, 0.0, r_d, {ev_optGemv}); // r := A * x

        ev_r = oneapi::mkl::blas::axpby(q, n, 1.0, b_d, 1, -1.0, r_d, 1, {ev_r}); // r := 1 * b + -1 * r

        // normr_d = ||r_0||_2
        auto ev_normr = oneapi::mkl::blas::nrm2(q, n, r_d, 1, normr_d, {ev_r});
        dataType oldrTz = 0.0, rTz = 0.0, pAp = 0.0, normr = 0.0, normr_0 = 0.0;
        {
            // only the single nrm2 result is needed on the host; nrm2 already
            // yields the Euclidean norm, so no further square root is taken
            q.copy(normr_d, normr_h, 1, {ev_normr}).wait();
            normr = normr_h[0];
            normr_0 = normr;
        }

        sycl::event ev_z, ev_rtz, ev_p, ev_Ap, ev_pAp, ev_x;

        std::int32_t k = 0;
        while ( normr / normr_0 > relTol && k < maxIter) {

            // Calculation z_k = M^{-1}r_k
            ev_z = precon_gauss_seidel<dataType, intType>(q, n, csrA, d_d, r_d, t_d, z_d, {ev_r});

            if (k == 0 ) {
                ev_rtz = oneapi::mkl::blas::dot(q, n, r_d, 1, z_d, 1, rtz_d, {ev_r, ev_z});
                {
                    q.copy(rtz_d, rtz_h, 1, {ev_rtz}).wait(); // synch point
                    rTz = rtz_h[0];
                }

                // copy D2D: p_1 = z_0
                ev_p = oneapi::mkl::blas::copy(q, n, z_d, 1, p_d, 1, {ev_z, ev_rtz});
            }
            else {
                // beta_{k+1} = dot(r_k, z_k) / dot(r_{k-1}, z_{k-1})
                ev_rtz = oneapi::mkl::blas::dot(q, n, r_d, 1, z_d, 1, rtz_d, {ev_r, ev_z});
                {
                    q.copy(rtz_d, rtz_h, 1, {ev_rtz}).wait(); // synch point
                    oldrTz = rTz;
                    rTz = rtz_h[0];
                }

                // Calculate p_{k+1} = z_{k+1} + beta_{k+1} * p_k
                //
                // The x update of the previous iteration (ev_x) still reads p_d and
                // none of the synch points above wait for it, so it is added as an
                // explicit dependency before p_d is overwritten (q is not guaranteed
                // to be an in-order queue).
                ev_p = oneapi::mkl::blas::axpby(q, n, 1.0, z_d, 1, rTz / oldrTz, p_d, 1, {ev_rtz, ev_x});

            }

            // Calculate Ap_{k+1} = A*p_{k+1}
            ev_Ap = oneapi::mkl::sparse::gemv(q, oneapi::mkl::transpose::nontrans,
                    1.0, csrA, p_d, 0.0, t_d, {ev_p});

            // alpha_{k+1} = dot(r_k, z_k) / dot(p_{k+1}, Ap_{k+1})
            ev_pAp = oneapi::mkl::blas::dot(q, n, p_d, 1, t_d, 1, pAp_d, {ev_Ap});
            {
                q.copy(pAp_d, pAp_h, 1, {ev_pAp}).wait(); // synch point
                pAp = pAp_h[0];
            }

            // Calculate x_{k+1} = x_k + alpha_{k+1}*p_{k+1}
            ev_x = oneapi::mkl::blas::axpy(q, n, rTz / pAp, p_d, 1, x_d, 1, {});

            // Calculate r_{k+1} = r_k - alpha_{k+1}*Ap_{k+1} (note that t = A*p_{k+1} right now so it can be reused here)
            ev_r = oneapi::mkl::blas::axpy(q, n, -rTz / pAp, t_d, 1, r_d, 1, {});

            // normr_d = ||r_{k+1}||_2, taken from the updated residual r_d
            // (not from the preconditioned residual z_d)
            ev_normr = oneapi::mkl::blas::nrm2(q, n, r_d, 1, normr_d, {ev_r});
            {
                q.copy(normr_d, normr_h, 1, {ev_normr}).wait(); // synch point
                normr = normr_h[0];
            }

            k++; // increment k counter
            std::cout << "\t\t\t\trelative norm of residual on " << std::setw(4) << k  // output in 1 base indexing
                      << " iteration: " << normr / normr_0 << std::endl;
            if (normr <= absTol) {
                std::cout << "\t\t\t\tabsolute norm of residual on " << std::setw(4) << k // output in 1-based indexing
                    << " iteration: " <<  normr << std::endl;
                break;
            }

        } // while normr / normr_0 > relTol && k < maxIter

        // same comparison as the break condition inside the loop
        if (normr <= absTol) {
            std::cout << "" << std::endl;
            std::cout << "\t\tPreconditioned CG process has successfully converged in absolute error in " << std::setw(4) << k << " steps with" << std::endl;
            good = true;
        }
        else if (k <= maxIter && normr / normr_0 <= relTol) {
            std::cout << "" << std::endl;
            std::cout << "\t\tPreconditioned CG process has successfully converged in relative error in " << std::setw(4) << k << " steps with" << std::endl;
            good = true;
        } else {
            std::cout << "" << std::endl;
            std::cout << "\t\tPreconditioned CG process has not converged after " << k << " steps with" << std::endl;
            good = false;
        }

        oneapi::mkl::sparse::release_matrix_handle(q, &csrA, {}).wait();

        std::cout << "\t\t relative error ||r||_2 / ||r_0||_2 = " << normr / normr_0 << (normr / normr_0 < relTol ? " < " : " > ") << relTol << std::endl;
        std::cout << "\t\t absolute error ||r||_2             = " << normr << (normr < absTol ? " < " : " > ") << absTol << std::endl;
        std::cout << "" << std::endl;

        q.wait_and_throw();
    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception:\n" << e.what() << std::endl;
        good = false;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception:\n" << e.what() << std::endl;
        good = false;
    }

    q.wait();

    // backup cleaning of matrix handle and others for if exceptions happened
    oneapi::mkl::sparse::release_matrix_handle(q, &csrA, {}).wait();

    //  clean up USM memory allocations
    if(ia_h) sycl::free(ia_h, q);
    if(ja_h) sycl::free(ja_h, q);
    if(a_h) sycl::free(a_h, q);
    if(x_h) sycl::free(x_h, q);
    if(b_h) sycl::free(b_h, q);
    if(ia_d) sycl::free(ia_d, q);
    if(ja_d) sycl::free(ja_d, q);
    if(a_d) sycl::free(a_d, q);
    if(x_d) sycl::free(x_d, q);
    if(b_d) sycl::free(b_d, q);
    if(r_d) sycl::free(r_d, q);
    if(z_d) sycl::free(z_d, q);
    if(p_d) sycl::free(p_d, q);
    if(t_d) sycl::free(t_d, q);
    if(d_d) sycl::free(d_d, q);
    if(temp_d) sycl::free(temp_d, q);
    if(temp_h) sycl::free(temp_h, q);

    q.wait();

    return good ? 0 : 1;
}

//
// Print the example banner to std::cout: what is solved, the storage format
// of the operands and the floating point precisions that are supported.
//
void print_example_banner()
{
    // full-width separator framing the banner
    static const char *const separator =
        "###############################################################"
        "#########";

    // one entry per output line; empty entries produce blank lines
    const char *const bannerLines[] = {
        "",
        separator,
        "# Sparse Preconditioned CG Example with USM: ",
        "# ",
        "# A * x = b",
        "# ",
        "# where A is a sparse matrix in CSR format, x and b are "
        "dense vectors",
        "# and the alpha/beta scalars are managed from host side",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "# ",
        separator,
        "",
    };

    // every line is terminated (and flushed) with std::endl
    for (const char *text : bannerLines) {
        std::cout << text << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//
//  For each device selected and each supported data type,
//  run_sparse_blas_example is run with all supported data types,
//  if any fail, we move on to the next device.
//

int main(int argc, char **argv)
{

    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int status = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        try {
            sycl::device my_dev;
            bool my_dev_is_found = false;
            get_sycl_device(my_dev, my_dev_is_found, *it);

            if (my_dev_is_found) {
                std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

                // Catch asynchronous exceptions
                auto exception_handler = [](sycl::exception_list exceptions) {
                    for (std::exception_ptr const &e : exceptions) {
                        try {
                            std::rethrow_exception(e);
                        }
                        catch (sycl::exception const &e) {
                            std::cout << "Caught asynchronous SYCL exception: \n"
                                << e.what() << std::endl;
                        }
                    }
                };

                sycl::queue q(my_dev, exception_handler);

                std::cout << "\tRunning with single precision real data type:" << std::endl;
                status |= run_sparse_blas_example<float, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision real data type:" << std::endl;
                    status |= run_sparse_blas_example<double, std::int32_t>(q);
                }

            }
            else {
#ifdef FAIL_ON_MISSING_DEVICES
                std::cout << "No " << sycl_device_names[*it]
                    << " devices found; Fail on missing devices "
                    "is enabled.\n";
                return 1;
#else
                std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                    << sycl_device_names[*it] << " tests.\n";
#endif
            }
        }
        catch (sycl::exception const &e) {
            std::cout << "\t\tCaught SYCL exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }
        catch (std::exception const &e) {
            std::cout << "\t\tCaught std exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }


    } // for loop over devices

    mkl_free_buffers();
    return status;
}
