# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from typing import Optional, Union

import pandas as pd


def add_postfix_to_rate_metrics(dfs: Union[pd.DataFrame, list[pd.DataFrame]]) -> None:
    """Rename bytes-sent/received metrics in place so their names carry a " rate" postfix.

    Only the four IB/Eth "Bytes received"/"Bytes sent" metric names are renamed; every other
    value in the ``metric_name`` column is left untouched. Empty dataframes are skipped.

    Parameters
    ----------
    dfs : list of dataframes or dataframe
        Dataframes whose ``metric_name`` column is rewritten in place.
    """
    # The mapping does not depend on the dataframe, so build it once for all of them.
    renames = {
        "IB: Bytes received": "IB: Bytes received rate",
        "IB: Bytes sent": "IB: Bytes sent rate",
        "Eth: Bytes received": "Eth: Bytes received rate",
        "Eth: Bytes sent": "Eth: Bytes sent rate",
    }

    frames = dfs if isinstance(dfs, list) else [dfs]

    for frame in frames:
        if frame.empty:
            continue
        # Names absent from the mapping fall back to themselves.
        frame["metric_name"] = frame["metric_name"].apply(
            lambda name: renames.get(name, name)
        )


def filter_by_bytes_threshold(
    dfs: Union[pd.DataFrame, list[pd.DataFrame]],
    bytes_threshold: int,
    byte_related_metrics: list[str],
) -> None:
    """Drop, in place, byte-related metric rows whose value is below the provided threshold.

    Rows whose ``metric_name`` is not in ``byte_related_metrics`` are always kept. Rows of
    byte-related metrics are kept only when their value is equal to or above the threshold.
    Original index labels of the retained rows are preserved, even when the index contains
    duplicate labels.

    Parameters
    ----------
    dfs : list of dataframes or dataframe
        Dataframes to filter. Each one is modified in place.
    bytes_threshold : int
        Bytes threshold for metrics that determines which devices we present in the plots.
        It is expressed per second and converted to per millisecond before comparison.
    byte_related_metrics : list of str
        A list of strings that correspond to metric names related to bytes.
    """
    if not isinstance(dfs, list):
        dfs = [dfs]

    # Values are stored in the report files in B/ms, while users pick the threshold from
    # per-second values (MiB/s in the Jupyter notebook, multiples of B/s in the GUI).
    # Dividing by 1000 converts the per-second threshold to the stored B/ms unit.
    # NOTE(review): assumes any MiB -> B conversion already happened in the caller; confirm.
    threshold_per_ms = bytes_threshold / 1000

    for df in dfs:
        if df.empty:
            continue

        below_threshold = df["metric_name"].isin(byte_related_metrics) & (
            df["value"] < threshold_per_ms
        )
        keep = ~below_threshold.to_numpy()
        if keep.all():
            continue

        # Drop by position rather than by label: `df.drop(labels)` removes *every* row that
        # shares a label with a masked row, which silently discards unmasked rows when the
        # index has duplicates (e.g. after `pd.concat` without `ignore_index`).
        original_index = df.index
        df.reset_index(drop=True, inplace=True)
        df.drop(df.index[~keep], inplace=True)
        df.index = original_index[keep]


def calculate_and_add_bytes_per_sample(
    df: pd.DataFrame, rate_metrics: list[str]
) -> pd.DataFrame:
    """Calculate the actual amount of bytes that were sent or received for each sample.

    The report file contains samples for byte rates per millisecond. We use the rate values together
    with the duration of each sample to reconstruct the counter values for bytes sent and received.
    The availability of the bytes sent and received values depends on whether the related rate
    metric has been captured in the report file.

    Parameters
    ----------
    df : dataframe
        Dataframe with ``metric_name``, ``value``, ``start`` and ``end`` columns. It is not
        modified.
    rate_metrics : list of str
        A list of strings that correspond to metric names related to rates, these are the metrics
        for which we will reconstruct the counter values for bytes.

    Returns
    -------
    dataframe
        ``df`` itself when it is empty; otherwise a new dataframe holding the original rows
        followed by one reconstructed row per matching rate row. The reconstructed rows are named
        after the rate metric with " rate" removed, and they keep the index labels of the rows
        they were derived from, so the result's index may contain duplicates.
    """
    if df.empty:
        return df

    for metric in rate_metrics:
        metric_rows = df["metric_name"] == metric

        # Explicit copy: assigning columns to a boolean-indexed selection otherwise writes to a
        # frame pandas tracks as a slice of `df`, triggering SettingWithCopyWarning.
        df_bytes = df[metric_rows].copy()

        # Get the duration of each sample in milliseconds. The timestamps are stored in
        # nanoseconds, thus we need to divide by 1'000'000.
        samples_duration_ms = (df_bytes["end"] - df_bytes["start"]) / 1000000

        # Calculate the counter values in bytes. The rate values are in B/ms.
        df_bytes["value"] = df_bytes["value"] * samples_duration_ms

        df_bytes["metric_name"] = metric.replace(" rate", "")

        df = pd.concat([df, df_bytes])

    return df


def remove_inactive_devices(df: pd.DataFrame, hide_inactive: Optional[bool]) -> None:
    """Remove inactive devices from the dataframe, in place.

    A row counts as inactive when its ``value`` is not strictly positive (zero, negative or
    NaN). Original index labels of the retained rows are preserved, even when the index
    contains duplicate labels.

    Parameters
    ----------
    df : dataframe
        Dataframe that contains data for network devices. Modified in place.
    hide_inactive : bool, optional
        Whether to hide inactive devices, i.e., devices with zero traffic. ``None`` and
        ``False`` leave the dataframe untouched.
    """
    if df.empty or not hide_inactive:
        return

    keep = (df["value"] > 0).to_numpy()
    if keep.all():
        return

    # Drop by position rather than by label: `df.drop(labels)` removes *every* row sharing a
    # label with an inactive row, which would discard active rows when the index has
    # duplicates (e.g. after `pd.concat` without `ignore_index`).
    original_index = df.index
    df.reset_index(drop=True, inplace=True)
    df.drop(df.index[~keep], inplace=True)
    df.index = original_index[keep]
