# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from datetime import datetime
from pathlib import Path

import pandas as pd

from nsys_recipe import log
from nsys_recipe.data_service import DataService
from nsys_recipe.lib import data_utils, gpu_metrics, helpers, nvtx, recipe
from nsys_recipe.lib.args import Option
from nsys_recipe.lib.table_config import CompositeTable
from nsys_recipe.log import logger


class GpuMetricUtilSum(recipe.Recipe):
    """Recipe that computes GPU metric statistics for each input report.

    Each report is handled by a mapper that produces one statistics DataFrame,
    computed either over the whole report or per range (CUDA kernels or NVTX
    ranges projected onto the GPU). The reducer tags every result with a rank
    and writes the combined output files.
    """

    @staticmethod
    def get_range_df(df_dict, parsed_args, report_path):
        """Build the DataFrame of ranges used for the per-range statistics.

        Args:
            df_dict: Tables read from the report, keyed by table identifier.
            parsed_args: Parsed recipe arguments. Reads ``per_kernel``,
                ``per_projected_nvtx`` and ``match_ranges``.
            report_path: Path of the report; used only in log messages.

        Returns:
            The range DataFrame, with its label column renamed to ``name`` and
            merged with the ``TARGET_INFO_CUDA_DEVICE`` table on ``pid``.
            None when neither range mode is selected, when no NVTX range could
            be projected, or when no range matches ``match_ranges``.
        """
        if parsed_args.per_kernel:
            range_df = df_dict[CompositeTable.CUDA_KERNEL].rename(
                columns={"shortName": "name"}
            )
        elif parsed_args.per_projected_nvtx:
            range_df = nvtx.project_nvtx_onto_gpu(
                df_dict[CompositeTable.NVTX],
                df_dict[CompositeTable.CUDA_COMBINED],
                group_columns="deviceId",
            )
            if range_df.empty:
                logger.info(
                    f"{report_path}: Report does not contain any NVTX data that can be projected onto the GPU."
                )
                return None
            range_df = range_df.rename(columns={"text": "name"})
        else:
            # No per-range mode requested: there is nothing to build.
            return None

        # Left merge keeps every range even when no device row shares its pid.
        range_df = range_df.merge(
            df_dict["TARGET_INFO_CUDA_DEVICE"], on="pid", how="left"
        )

        if parsed_args.match_ranges is None:
            return range_df

        range_df = data_utils.filter_by_pattern(
            range_df, parsed_args.match_ranges, "name"
        )
        if range_df.empty:
            logger.error(
                f"{report_path}: Report does not contain any range matching '{parsed_args.match_ranges}'."
            )
            return None

        return range_df

    @staticmethod
    def _mapper_func(report_path, parsed_args):
        """Compute the statistics DataFrame for a single report.

        Args:
            report_path: Path of the report to process.
            parsed_args: Parsed recipe arguments.

        Returns:
            A ``(filename, stats_df)`` tuple, where ``filename`` is the stem of
            ``report_path``. None when the tables cannot be read, the time
            filtering reports an error, any queued table is empty, or no
            statistics could be computed.
        """
        service = DataService(report_path, parsed_args)

        service.queue_custom_table(CompositeTable.GPU_METRICS)
        service.queue_table("TARGET_INFO_CUDA_DEVICE", ["gpuId", "pid"])

        # Range tables are only needed when per-range statistics are requested.
        if parsed_args.per_kernel:
            service.queue_custom_table(CompositeTable.CUDA_KERNEL)
        elif parsed_args.per_projected_nvtx:
            service.queue_custom_table(CompositeTable.NVTX)
            service.queue_custom_table(CompositeTable.CUDA_COMBINED)
        per_range = bool(parsed_args.per_kernel or parsed_args.per_projected_nvtx)

        df_dict = service.read_queued_tables()
        if df_dict is None:
            return None

        gpu_metrics_df = df_dict[CompositeTable.GPU_METRICS]
        err_msg = service.filter_and_adjust_time(gpu_metrics_df)
        if err_msg is not None:
            logger.error(f"{report_path}: {err_msg}")
            return None

        # Every queued table is required; a single empty one means no result.
        if any(df.empty for df in df_dict.values()):
            logger.info(
                f"{report_path}: Report was successfully processed, but no data was found."
            )
            return None

        if per_range:
            range_df = GpuMetricUtilSum.get_range_df(df_dict, parsed_args, report_path)
            if range_df is None:
                return None

            stats_df = gpu_metrics.calculate_stats_by_range(
                gpu_metrics_df, range_df, parsed_args.longest_n
            )
        else:
            stats_df = gpu_metrics.calculate_stats(gpu_metrics_df)

        if stats_df is None:
            logger.info(
                f"{report_path}: Report does not contain any data that can be matched with the GPU metrics."
            )
            return None

        return Path(report_path).stem, stats_df

    @log.time("Mapper")
    def mapper_func(self, context):
        """Run ``_mapper_func`` on every input report and wait for the results."""
        return context.wait(
            context.map(
                self._mapper_func,
                self._parsed_args.input,
                parsed_args=self._parsed_args,
            )
        )

    @log.time("Reducer")
    def reducer_func(self, mapper_res):
        """Combine the mapper results and write the output files.

        Writes ``files.parquet`` (file name per rank) and ``rank_stats.parquet``
        (all statistics, tagged with a ``Rank`` column). The CSV equivalents are
        written as well when the ``csv`` argument is set.

        Args:
            mapper_res: Results returned by ``mapper_func``.
        """
        filtered_res = helpers.filter_none(mapper_res)
        # Sort by file name so that the rank assignment is deterministic.
        filtered_res = sorted(filtered_res, key=lambda x: x[0])
        # NOTE(review): unpacking raises ValueError when no result is left;
        # presumably helpers.filter_none guards against that - confirm.
        filenames, stats_by_device_dfs = zip(*filtered_res)

        files_df = pd.DataFrame({"File": filenames}).rename_axis("Rank")
        files_df.to_parquet(self.add_output_file("files.parquet"))

        # The rank is the position of the report in the sorted file list, which
        # matches the index of files_df.
        rank_stats_by_device_df = pd.concat(
            [df.assign(Rank=rank) for rank, df in enumerate(stats_by_device_dfs)]
        )
        rank_stats_by_device_df.to_parquet(self.add_output_file("rank_stats.parquet"))

        if self._parsed_args.csv:
            files_df.to_csv(self.add_output_file("files.csv"))
            rank_stats_by_device_df.to_csv(self.add_output_file("rank_stats.csv"))

    def save_notebook(self):
        """Create the ``stats.ipynb`` notebook and add its helper file."""
        self.create_notebook(
            "stats.ipynb", replace_dict={"REPLACE_N": self._parsed_args.longest_n}
        )
        self.add_notebook_helper_file("nsys_display.py")

    def save_analysis_file(self):
        """Record the end time and the output files, then write the analysis file."""
        self._analysis_dict.update(
            {
                "EndTime": str(datetime.now()),
                "Outputs": self._output_files,
            }
        )
        self.create_analysis_file()

    def run(self, context):
        """Execute the recipe: map, reduce, then save the notebook and analysis file."""
        super().run(context)

        mapper_res = self.mapper_func(context)
        self.reducer_func(mapper_res)

        self.save_notebook()
        self.save_analysis_file()

    @classmethod
    def get_argument_parser(cls):
        """Return the argument parser extended with this recipe's options."""
        parser = super().get_argument_parser()

        # The two range modes cannot be combined.
        per_group = parser.recipe_group.add_mutually_exclusive_group()
        parser.add_argument_to_group(
            per_group,
            "--per-projected-nvtx",
            action="store_true",
            help="Generate results for each projected NVTX range, grouped by GPU.",
        )
        parser.add_argument_to_group(
            per_group,
            "--per-kernel",
            action="store_true",
            help="Generate results for each kernel.",
        )

        parser.add_recipe_argument(Option.INPUT, required=True)
        parser.add_recipe_argument(
            "--match-ranges",
            nargs="+",
            type=str,
            metavar="PATTERN",
            help="Output results for the ranges that match the provided regex pattern.\n"
            "Valid only when used with --per-kernel or --per-projected-nvtx.",
        )
        parser.add_recipe_argument(
            "--longest-n",
            type=int,
            metavar="N",
            default=100,
            # No leading space after the newline, consistent with --match-ranges.
            help="Output results for the longest N ranges.\n"
            "Valid only when used with --per-kernel or --per-projected-nvtx.",
        )
        parser.add_recipe_argument(Option.CSV)

        filter_group = parser.recipe_group.add_mutually_exclusive_group()
        parser.add_argument_to_group(filter_group, Option.FILTER_TIME)
        parser.add_argument_to_group(filter_group, Option.FILTER_NVTX)
        parser.add_argument_to_group(filter_group, Option.FILTER_PROJECTED_NVTX)

        return parser
