# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from collections.abc import Sequence
from typing import Union

import pandas as pd

from nsys_recipe.lib import data_utils


def combine_runtime_gpu_dfs(
    runtime_df: pd.DataFrame, gpu_df: pd.DataFrame, merge_col: str = "correlationId"
) -> pd.DataFrame:
    """Combine the runtime dataframe and the GPU dataframes.

    The two dataframes are inner-joined on ``merge_col`` and 'pid'. The
    'start' and 'end' columns of the GPU dataframe are renamed to 'gpu_start'
    and 'gpu_end'. For any other column present in both dataframes, the GPU
    value is kept. The 'globalPid' column is dropped from the result because
    it can be retrieved from 'globalTid'.

    ``data_utils.decompose_bit_fields`` is applied to both inputs in place.
    Any columns it adds are removed from ``runtime_df`` and ``gpu_df`` again
    before returning, even if the join fails. The caller's dataframes
    therefore keep their original set of columns.
    """
    runtime_cols = set(runtime_df.columns)
    gpu_cols = set(gpu_df.columns)

    try:
        data_utils.decompose_bit_fields(gpu_df)
        data_utils.decompose_bit_fields(runtime_df)

        renamed_gpu_df = gpu_df.rename(
            columns={"start": "gpu_start", "end": "gpu_end"}
        )
        cuda_df = runtime_df.merge(
            renamed_gpu_df,
            on=[merge_col, "pid"],
            how="inner",
            suffixes=("_drop", ""),
        )
        # We drop duplicated columns and the globalPid column that can be
        # retrieved from globalTid.
        cuda_df = cuda_df.loc[:, ~cuda_df.columns.str.endswith("_drop")].drop(
            columns=["globalPid"]
        )
    finally:
        # Restore the caller's dataframes. Only drop the columns that did not
        # exist before decompose_bit_fields ran. drop() must be in place here:
        # it has to act on the caller's objects, not on a copy.
        runtime_df.drop(
            columns=[col for col in runtime_df.columns if col not in runtime_cols],
            inplace=True,
        )
        gpu_df.drop(
            columns=[col for col in gpu_df.columns if col not in gpu_cols],
            inplace=True,
        )

    return cuda_df


def combine_runtime_graph_dfs(
    runtime_df: pd.DataFrame, graph_df: pd.DataFrame
) -> pd.DataFrame:
    """Combine the runtime dataframe and the graph event dataframe.

    Each graph event is paired with the runtime event that has the same
    'globalTid' and the latest 'start' at or before the graph event's
    'start'. This is a backward ``pd.merge_asof``; the runtime 'end' is not
    compared against the graph event.

    For columns present in both dataframes, the runtime values are kept.
    Rows whose runtime 'start' or 'end' ends up missing are dropped. This
    includes graph events with no preceding runtime event on their thread.
    Neither input dataframe is modified.
    """
    runtime_sorted = runtime_df.sort_values("start")
    # The graph 'start' column is suffixed away by the merge, so a copy of it
    # under a dedicated name serves as the as-of key.
    graph_sorted = graph_df.sort_values("start").assign(
        ref_start=lambda frame: frame["start"]
    )

    combined = pd.merge_asof(
        graph_sorted,
        runtime_sorted,
        left_on="ref_start",
        right_on="start",
        by="globalTid",
        suffixes=("_drop", ""),
    )
    is_graph_duplicate = combined.columns.str.endswith("_drop")
    combined = combined.loc[:, ~is_graph_duplicate]
    combined = combined.drop(columns=["ref_start"])

    # Unmatched graph events (e.g. events with negative start and end times)
    # carry NaN in the runtime 'start'/'end' columns.
    return combined.dropna(subset=["start", "end"])


def _update_graph_node_ids(graph_df: pd.DataFrame) -> pd.DataFrame:
    """Update the graph ID of the runtime graph events to match those of GPU
    events.

    Rows with a non-null 'originalGraphNodeId' are treated as derivative
    events (e.g. 'GraphClone', 'GraphInstantiate'). All other rows are
    original events. The returned dataframe is the concatenation of:

    * One row per (original event, derivative event) pair. Pairs are matched
      on the original's 'graphNodeId' equal to the derivative's
      'originalGraphNodeId', within the same 'globalTid'. Each row keeps the
      original event's columns but takes the derivative's (remapped)
      'graphNodeId'. It also gets a 'groupId' that numbers each distinct
      original ('graphNodeId', 'globalTid') pair.
    * Every derivative event, with 'groupId' left as NaN.

    Original events without a matching derivative event are not included.
    The input dataframe is not modified.
    """
    is_derivative = graph_df["originalGraphNodeId"].notnull()
    derivative_graph_df = graph_df.loc[is_derivative].copy()
    original_graph_df = graph_df.loc[~is_derivative].copy()

    # The 'GraphInstantiate' events contain the graph ID used for correlating
    # with the GPU graph events. To associate each runtime graph event with a
    # GPU event, we correlate them with the corresponding 'GraphInstantiate'
    # event and update the ID accordingly.
    # Note: If a graph is cloned, the 'GraphInstantiate' event will have the
    # 'originalGraphNodeId' pointing to the 'graphNodeId' of the 'GraphClone'
    # event, which itself contains the 'originalGraphNodeId' of the original
    # graph event.
    instantiate_mask = derivative_graph_df["name"].str.contains(
        "GraphInstantiate", regex=False, na=False
    )
    instantiate_df = derivative_graph_df[instantiate_mask]
    instantiate_id_map = dict(
        zip(instantiate_df["originalGraphNodeId"], instantiate_df["graphNodeId"])
    )

    # Update the graph ID for derivative graph events such as 'GraphClone'.
    # The write goes through .loc on the frame itself. Calling the in-place
    # Series.update() on derivative_graph_df["graphNodeId"] would be chained
    # assignment, which does not write back under pandas Copy-on-Write.
    graph_ids = derivative_graph_df["graphNodeId"]
    mapped_ids = graph_ids.map(instantiate_id_map)
    # As with Series.update(), only non-instantiate rows that have a mapping
    # are overwritten.
    update_mask = ~instantiate_mask & mapped_ids.notna()
    derivative_graph_df.loc[update_mask, "graphNodeId"] = (
        mapped_ids[update_mask].astype(graph_ids.dtype).to_numpy()
    )

    # When an original graph event (ex. cudaGraphAdd*, cudaLaunchKernel, etc.)
    # is associated with multiple graphs, it will be correlated with the
    # corresponding GPU event from each graph, forming a one-to-many
    # relationship. Consequently, the output dataframe may have more rows than
    # the original dataframe.
    merged_df = original_graph_df.merge(
        derivative_graph_df,
        left_on=["graphNodeId", "globalTid"],
        right_on=["originalGraphNodeId", "globalTid"],
        suffixes=("", "_drop"),
    )

    # Computed before 'graphNodeId' is overwritten, so groups are keyed by
    # the original event's own graph node ID.
    merged_df["groupId"] = merged_df.groupby(["graphNodeId", "globalTid"]).ngroup()

    # Update the graph ID for original graph events.
    merged_df["graphNodeId"] = merged_df["graphNodeId_drop"]

    merged_df = merged_df.loc[:, ~merged_df.columns.str.endswith("_drop")]

    # This final dataframe includes both original and derivative graph events.
    # Original graph events have a non-NaN 'groupId'; derivative graph events
    # have no 'groupId' of their own, so theirs is NaN.
    return pd.concat([merged_df, derivative_graph_df])


def _filter_duplicated_graphs(
    graph_df: pd.DataFrame, subset: Union[str, Sequence[str]]
) -> pd.DataFrame:
    """Drop graph rows that repeat the key column(s) given by ``subset``.

    'GraphClone' and 'GraphInstantiate' calls produce several events. The
    earlier ones describe the creation, and the final one describes the
    resulting graph event. For simplicity, only the last row for each key
    is retained, and the result gets a fresh 0-based index.
    """
    last_per_key = graph_df.drop_duplicates(subset=subset, keep="last")
    return last_per_key.reset_index(drop=True)


def derive_node_df(
    runtime_df: pd.DataFrame, node_events_df: pd.DataFrame, gpu_df: pd.DataFrame
) -> pd.DataFrame:
    """Build a graph node dataframe from runtime, node event and GPU data.

    The steps are:
    1. Attach the node events to their runtime events.
    2. Keep the last row per 'graphNodeId'.
    3. Remap the node IDs to the graph IDs used by the GPU events.
    4. Join the result with the GPU dataframe on 'graphNodeId' (and 'pid').
    """
    node_df = combine_runtime_graph_dfs(runtime_df, node_events_df)
    node_df = _update_graph_node_ids(
        _filter_duplicated_graphs(node_df, ["graphNodeId"])
    )
    return combine_runtime_gpu_dfs(node_df, gpu_df, merge_col="graphNodeId")


def derive_graph_df(
    runtime_df: pd.DataFrame, graph_events_df: pd.DataFrame, gpu_df: pd.DataFrame
) -> pd.DataFrame:
    """Build a graph dataframe from runtime, graph event and GPU data.

    The steps are:
    1. Attach the graph events to their runtime events.
    2. Keep the last row per ('correlationId', 'pid').
    3. Join the result with the GPU dataframe on 'graphId' (and 'pid').
    """
    # NOTE(review): combine_runtime_graph_dfs does not derive 'pid', so it
    # presumably has to be present on one of the inputs already for the
    # de-duplication below. Confirm against the callers.
    deduplicated = _filter_duplicated_graphs(
        combine_runtime_graph_dfs(runtime_df, graph_events_df),
        ["correlationId", "pid"],
    )
    return combine_runtime_gpu_dfs(deduplicated, gpu_df, merge_col="graphId")
