/*
 * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

// Compile (C++17 is required for <filesystem>):
//  c++ -std=c++17 NetworkPlugin.cpp -I${NVTX_PATH} -I${NVTXEXT_PATH} -ldl -o network_interface
//
// Run as an Nsight Systems plugin (a single command, wrapped here for readability):
//  nsys profile --enable network_interface,-i10000,-d"e.*",-m".*_bytes",-m".*error.*" -t nvtx
//    -o net -f true -s none sleep 7
//
// When testing an Nsight Systems plugin, you can profile it with Nsight Systems and trace its NVTX
// output to confirm it generates the expected data. Use the following `nsys` command line (again a
// single command, wrapped here) to test the plugin as a target process:
//  nsys profile -t nvtx -o net -s none ./network_interface -i 100000 -d "e.*" -m tx_bytes
//    -m ".*error"

#include <getopt.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <algorithm>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <limits>
#include <regex>
#include <string>
#include <system_error>
#include <unordered_map>
#include <vector>

#include <nvtx3/nvToolsExtCounters.h>

// Logging helpers writing one newline-terminated message to stderr. The format string and its
// arguments follow printf conventions. Both expand to a single statement (do/while(0)) so they are
// safe in unbraced if/else branches. LOG_DBG is compiled out entirely (arguments are not evaluated)
// unless DEBUG is defined.
#ifdef DEBUG
#define LOG_DBG(...)                  \
    do                                \
    {                                 \
        fprintf(stderr, __VA_ARGS__); \
        fputs("\n", stderr);          \
    } while (0)
#else
#define LOG_DBG(...) \
    do               \
    {                \
    } while (0)
#endif

#define LOG_ERR(...)                  \
    do                                \
    {                                 \
        fprintf(stderr, __VA_ARGS__); \
        fputs("\n", stderr);          \
    } while (0)

namespace fs = std::filesystem;

namespace {

// Sampling state of one network interface.
struct Device
{
    // Open statistics files keyed by metric (file) name. The map's iteration order defines the order
    // of the entries in the NVTX counter schema and of the values in every sample.
    std::unordered_map<std::string, std::ifstream> Metrics;
    // Counter id returned by nvtxCounterRegister; 0 until the device has been registered.
    uint64_t CountersId = 0;
    // Last raw value read for each metric, in `Metrics` iteration order.
    std::vector<int64_t> PrevValues;

    Device() = default;
    // Copying is deleted explicitly (the ifstreams cannot be copied anyway) so that containers such
    // as std::vector relocate Device objects by moving them.
    Device(const Device&) = delete;
    Device(Device&&) = default;
};

// A network interface selected for profiling, together with its sampling state.
struct FilesystemDevice
{
    // Interface name, i.e. the directory name below NetBasePath (not an absolute path).
    std::string DevicePath;
    // Open metric files and sampling bookkeeping for this interface.
    Device DeviceInfo;

    FilesystemDevice() = default;
    FilesystemDevice(const std::string& devicePath) : DevicePath{devicePath} {}
};

// All profiled interfaces; elements are unpacked as `[name, device]` via structured bindings.
using DeviceContainer = std::vector<FilesystemDevice>;

// Runtime options, filled in by ParseCliArgs from the command line.
struct Config
{
    // Value passed to usleep() in the sampling loop, in microseconds (`-i`).
    unsigned int UsecSleep{100000};
    // Regular expressions matched against network interface names (`-d`).
    std::vector<std::regex> DeviceFilters{};
    // Regular expressions matched against metric file names (`-m`).
    std::vector<std::regex> MetricFilters{};
    // When true, only interfaces accepted by IsDeviceValid are considered.
    bool OnlyPhysicalDevices{false};
};

// NVTX domain owning all counters emitted by this tool; created in main() before any counter is
// registered.
nvtxDomainHandle_t NetDomain = nullptr;

// Base path to network devices (one subdirectory per interface).
const fs::path NetBasePath = "/sys/class/net";

int64_t ReadFile(std::ifstream& fstream)
{
    std::string line;
    std::getline(fstream, line);

    fstream.clear();
    fstream.seekg(0);

    long long value = 0;
    try
    {
        value = std::stoll(line);
    }
    catch (std::logic_error& e)
    {
        LOG_ERR("Error while reading the value %s (%s)", line.c_str(), e.what());
    }

    return value;
}

Config ParseCliArgs(const int argc, char** argv)
{
    Config config;

    static const struct option longOptions[] = {
        {"help", no_argument, 0, 'h'},
        {"interval", required_argument, 0, 'i'},
        {"device", required_argument, 0, 'd'},
        {"metric", required_argument, 0, 'm'},
        {0, 0, 0, 0}};

    const char* optString = "hi:d:m:";
    auto PrintHelpOutput = [optString](const char* progName)
    {
        printf("Usage: %s -[%s]\n", progName, optString);
        printf("  -i | --interval \t Sampling interval in microseconds (default: 100000)\n");
        printf("  -d | --device   \t Device name as regular expression\n"
               "                  \t (default: all physical devices)\n");
        printf("  -m | --metric   \t Metric name as regular expression (default: .*_bytes)\n");
        printf("  -h | --help     \t Print this help message\n");

        printf("\nTo specify multiple devices and metrics the respective options can be used\n"
               "multiple times and as regular expression.\n");
        printf("When using `-d all` all adapters will be considered (not only physical ones).\n");
        printf("When using `-m all` all metrics are selected (not only metrics in bytes).\n");
    };

    // Silence `getopt_long` error messages.
    opterr = 0;
    int result, index = -1;
    while ((result = getopt_long(argc, argv, optString, longOptions, &index)) != -1)
    {
        switch (result)
        {
        case 'i':
            config.UsecSleep = std::stoul(optarg, nullptr);
            break;
        case 'd':
            if (std::string("all") == optarg)
            {
                config.DeviceFilters.clear();
                config.DeviceFilters.emplace_back(".+");
                break;
            }

            try
            {
                config.DeviceFilters.emplace_back(optarg);
            }
            catch (const std::regex_error& e)
            {
                LOG_ERR("Cannot create regex for %s (%s)", optarg, e.what());
                exit(3);
            }
            break;
        case 'm':
            if (std::string("all") == optarg)
            {
                config.MetricFilters.clear();
                config.MetricFilters.emplace_back(".+");
                break;
            }

            try
            {
                config.MetricFilters.emplace_back(optarg);
            }
            catch (const std::regex_error& e)
            {
                LOG_ERR("Cannot create regex for %s (%s)", optarg, e.what());
                exit(3);
            }
            break;
        case 'h':
            PrintHelpOutput(argv[0]);
            exit(0);
        default:
            if (optopt != '\0')
            {
                LOG_ERR("Unknown option or missing argument '-%c'\n", optopt);
            }
            else
            {
                LOG_ERR("Unknown option or missing argument '%s'\n", argv[optind - 1]);
            }
            PrintHelpOutput(argv[0]);
            exit(0);
        }
    }

    if (config.DeviceFilters.empty())
    {
        config.DeviceFilters.emplace_back(".+");
        config.OnlyPhysicalDevices = true;
    }

    // Set default metric names.
    if (config.MetricFilters.empty())
    {
        config.MetricFilters.emplace_back(".*_bytes");
    }

    return config;
}

// Decides whether the interface directory `devicePath` (below NetBasePath) is a physical device
// worth profiling. Conditions:
//   - its operstate starts with "up";
//   - its ifindex equals its iflink;
//   - it has a `device` entry;
//   - it has no `bridge` entry.
// Returns false (after logging an error) if ifindex or iflink cannot be parsed.
bool IsDeviceValid(const fs::path& devicePath)
{
    // https://www.kernel.org/doc/Documentation/ABI/testing/sysfs-class-net
    // https://github.com/torvalds/linux/blob/master/include/uapi/linux/if.h
    // https://github.com/torvalds/linux/blob/master/include/uapi/linux/if_arp.h

    // Returns the whole content of attribute file `name` of this device (empty if it cannot be
    // opened).
    const auto readAttribute = [&devicePath](const char* name)
    {
        std::ifstream attribute(devicePath / name);
        return std::string(
            std::istreambuf_iterator<char>(attribute), std::istreambuf_iterator<char>());
    };

    const bool isUp = readAttribute("operstate").compare(0, 2, "up") == 0;

    const std::string ifindexText = readAttribute("ifindex");
    uint32_t ifindex = 0;
    try
    {
        ifindex = std::stol(ifindexText, nullptr);
    }
    catch (const std::logic_error& e)
    {
        LOG_ERR("Error while reading the ifindex value for %s (%s)", devicePath.c_str(), e.what());
        return false;
    }

    const std::string iflinkText = readAttribute("iflink");
    uint32_t iflink = 0;
    try
    {
        iflink = std::stol(iflinkText, nullptr);
    }
    catch (const std::logic_error& e)
    {
        LOG_ERR("Error while reading the iflink value for %s (%s)", devicePath.c_str(), e.what());
        return false;
    }

    const bool hasDevice = fs::exists(devicePath / "device");
    const bool hasBridge = fs::exists(devicePath / "bridge");

    LOG_DBG("Network device path: %s", devicePath.c_str());
    LOG_DBG("  IsUp: %d", isUp);
    LOG_DBG("  Link: %u %u", ifindex, iflink);
    LOG_DBG("  HasDevice: %d", hasDevice);
    LOG_DBG("  HasBridge: %d", hasBridge);

    const bool isPhysical = isUp && ifindex == iflink && hasDevice && !hasBridge;
    if (!isPhysical)
    {
        LOG_DBG("%s is not a valid physical device for profiling.", devicePath.filename().c_str());
    }
    return isPhysical;
}

// Lists the network interfaces below NetBasePath whose directory name fully matches at least one
// expression in `devicesRegex`. When `onlyPhysicalDevices` is set, interfaces rejected by
// IsDeviceValid are skipped before matching. Returns an empty container (after logging an error)
// if NetBasePath cannot be listed.
DeviceContainer GetNetworkDevices(
    const std::vector<std::regex>& devicesRegex,
    const bool onlyPhysicalDevices)
{
    DeviceContainer devices;

    // Non-throwing overload: a missing or unreadable NetBasePath is reported as "no devices" to
    // the caller instead of terminating the program with an uncaught filesystem_error.
    std::error_code ec;
    fs::directory_iterator netDir(NetBasePath, ec);
    if (ec)
    {
        LOG_ERR("Cannot list network devices in %s (%s)", NetBasePath.c_str(), ec.message().c_str());
        return devices;
    }

    for (const auto& entry : netDir)
    {
        if (onlyPhysicalDevices && !IsDeviceValid(entry.path()))
        {
            continue;
        }

        // The interface name does not depend on the filter, so compute it once per entry.
        const std::string deviceName = entry.path().filename().string();
        for (const std::regex& deviceRegex : devicesRegex)
        {
            bool found = false;
            try
            {
                found = std::regex_match(deviceName, deviceRegex);
            }
            catch (const std::regex_error& e)
            {
                LOG_ERR("Error in applying regex to %s (%s)", deviceName.c_str(), e.what());
            }

            if (found)
            {
                devices.emplace_back(deviceName);
                break;
            }
        }
    }

    return devices;
}

// For every device in `devices`, opens the files in `NetBasePath/<device>/statistics` whose names
// fully match at least one expression in `config.MetricFilters`. The open streams are stored in
// `device.Metrics`, keyed by file name. Files that cannot be opened are logged and skipped.
// Returns the largest number of metrics opened for a single device (0 if none were opened).
size_t OpenMetricFiles(DeviceContainer& devices, const Config& config)
{
    // True if `metricName` fully matches one of the configured metric filters. A filter that fails
    // to apply is logged and treated as not matching.
    const auto isSelected = [&config](const std::string& metricName)
    {
        for (const std::regex& metricRegex : config.MetricFilters)
        {
            bool found = false;
            try
            {
                found = std::regex_match(metricName, metricRegex);
            }
            catch (const std::regex_error& e)
            {
                LOG_ERR("Error in applying regex to %s (%s)", metricName.c_str(), e.what());
            }

            if (found)
            {
                return true;
            }
        }
        return false;
    };

    size_t maxMetrics = 0;
    for (auto& [deviceName, device] : devices)
    {
        const fs::path statsPath = NetBasePath / deviceName / "statistics";

        // Non-throwing overload: a device whose statistics directory cannot be listed (for example
        // one that disappeared after it was discovered) is skipped instead of aborting the program.
        std::error_code ec;
        fs::directory_iterator statsDir(statsPath, ec);
        if (ec)
        {
            LOG_ERR("Cannot list metrics in %s (%s)", statsPath.c_str(), ec.message().c_str());
            continue;
        }

        for (const auto& entry : statsDir)
        {
            const fs::path& filePath = entry.path();
            if (!entry.exists() || !entry.is_regular_file())
            {
                LOG_ERR("%s does not exist or is not a regular file.", filePath.c_str());
                continue;
            }

            // Each metric is opened at most once, even if several filters match its name.
            const std::string metricName = filePath.filename().string();
            if (!isSelected(metricName))
            {
                continue;
            }

            const auto [metricIter, inserted] = device.Metrics.emplace(metricName, filePath);
            if (!inserted)
            {
                LOG_DBG("%s could not be inserted.", filePath.c_str());
                continue;
            }

            if (!metricIter->second.is_open())
            {
                LOG_ERR("Could not open %s for reading.", filePath.c_str());
                device.Metrics.erase(metricIter);
            }
        }

        maxMetrics = std::max(maxMetrics, device.Metrics.size());
    }

    return maxMetrics;
}

// Registers one NVTX counter group per device in NetDomain and stores its id in
// `device.CountersId`. Each group uses a static payload schema with one int64 entry per opened
// metric, in the iteration order of `device.Metrics`. Samples must supply the values in that same
// order.
void SetupNvtxCounters(DeviceContainer& devices)
{
    for (auto& [deviceName, device] : devices)
    {
        // Entry names point into the map keys; they stay valid while `device.Metrics` is unchanged.
        std::vector<nvtxPayloadSchemaEntry_t> schema;
        schema.reserve(device.Metrics.size());
        for (const auto& [metricName, metricStream] : device.Metrics)
        {
            schema.push_back(
                {0, NVTX_PAYLOAD_ENTRY_TYPE_INT64, metricName.c_str(), "", 0, 0, nullptr, nullptr});
        }

        const size_t sizeOfCounterGroup = schema.size() * sizeof(int64_t);

        // Value-initialize so every field not selected in `fieldMask` is zero, not indeterminate.
        nvtxPayloadSchemaAttr_t schemaAttr{};
        schemaAttr.fieldMask = NVTX_PAYLOAD_SCHEMA_ATTR_FIELD_TYPE | NVTX_PAYLOAD_SCHEMA_ATTR_FIELD_ENTRIES
                               | NVTX_PAYLOAD_SCHEMA_ATTR_FIELD_NUM_ENTRIES
                               | NVTX_PAYLOAD_SCHEMA_ATTR_FIELD_STATIC_SIZE;
        schemaAttr.type = NVTX_PAYLOAD_SCHEMA_TYPE_STATIC;
        schemaAttr.entries = schema.data();
        schemaAttr.numEntries = schema.size();
        schemaAttr.payloadStaticSize = sizeOfCounterGroup;
        const uint64_t schemaId = nvtxPayloadSchemaRegister(NetDomain, &schemaAttr);

        // Value-initialization zeroes all fields that are not set explicitly below.
        nvtxCounterAttr_t cntAttr{};
        cntAttr.structSize = sizeof(nvtxCounterAttr_t);
        cntAttr.schemaId = schemaId;
        cntAttr.name = deviceName.c_str();
        cntAttr.scopeId = NVTX_SCOPE_CURRENT_VM;
        cntAttr.semantics = nullptr;
        device.CountersId = nvtxCounterRegister(NetDomain, &cntAttr);
    }
}

} //namespace

int main(int argc, char** argv)
{
    Config config = ParseCliArgs(argc, argv);

    DeviceContainer devices = GetNetworkDevices(config.DeviceFilters, config.OnlyPhysicalDevices);
    if (devices.size() == 0)
    {
        LOG_ERR("No matching network devices available.");
        return 1;
    }

    size_t maxMetricsPerDevice = OpenMetricFiles(devices, config);
    if (maxMetricsPerDevice == 0)
    {
        LOG_ERR("No valid metrics found.");
        return 2;
    }

    NetDomain = nvtxDomainCreateA("Network");

    SetupNvtxCounters(devices);

    // Reserve space for the previous values and set the initial values,
    // because counters are incrementally increasing.
    for (auto& [deviceName, device] : devices)
    {
        device.PrevValues.reserve(device.Metrics.size());
        for (auto& [metricName, metricPath] : device.Metrics)
        {
            device.PrevValues.push_back(ReadFile(metricPath));
        }
    }

    // Reserve space for the values. No need to do this per device,
    // because devices are queried sequentially.
    std::vector<int64_t> values;
    values.reserve(maxMetricsPerDevice);

    while (true)
    {
        for (auto& [deviceName, device] : devices)
        {
            values.clear();
            if (device.PrevValues.size() > values.capacity())
            {
                LOG_ERR(
                    "The `values` container should hold equal or more elements than each of the"
                    "device value containers. Device %s has more elements than what `value` can "
                    "hold",
                    deviceName.c_str());
            }

            std::vector<int64_t>::iterator prevValueIter = device.PrevValues.begin();
            for (auto& [metricName, metricPath] : device.Metrics)
            {
                int64_t value = ReadFile(metricPath);

                if (value < *prevValueIter)
                {
                    LOG_ERR(
                        "[%s] %s: current value is negative or decreasing.",
                        metricName.c_str(),
                        deviceName.c_str());
                    LOG_ERR("  Current value: %ld, previous value: %ld", value, *prevValueIter);
                    value = 0;
                    values.push_back(value);
                }
                else
                {
                    values.push_back(value - *prevValueIter);
                }

                *prevValueIter = value;
                LOG_DBG("[%s] %s: %ld", deviceName.c_str(), metricName.c_str(), values.back());
                ++prevValueIter;
            }

            nvtxCounterSample(
                NetDomain,
                device.CountersId,
                values.data(),
                device.Metrics.size() * sizeof(int64_t));

            usleep(config.UsecSleep);
        }
    }

    // We could close the metric files here, but the program is going to exit anyway so the files
    // will be automatically closed.

    return 0;
}
