# fmt: on
import yaml
import os
import ast

import numpy as np
import json
import math
from typing import Optional

from OGOAT.src.Tiler.tiler import Tiler
from layer import Layer
from mha_tiler import TilingError, MHATiler
from cost_model import CostModel
from overlay import Overlay

from OGOAT.src.utils.context import Context, Logger
from OGOAT.src.L1_fusion.py_match.adv.attention import MHAMode

# Directory three levels above this file; "Collaterals/overlays.yaml" is resolved
# relative to it in MHATilingOpt.__init__.
parent_dir = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)


class MHATilingOpt:
    def __init__(
        self,
        layer,
        device,
        overlay_name,
        kernel,
        layer_id: str,
        context: Optional[Context] = None,
    ):
        with open(os.path.join(parent_dir, "Collaterals/overlays.yaml")) as f:
            all_overlays = yaml.safe_load(f)
        self.modes = list(all_overlays[overlay_name][layer.orig_op_type].keys())
        self.overlay = overlay_name
        self.has_bias = "bias" in layer.op_type

        match layer.attributes["mha_mode"][0]:
            case MHAMode.TWO_P_ONE_MINI | MHAMode.TWO_P_ONE:
                self.mha_mode = "2p1"
            case MHAMode.THREE_P_ZERO_MINI:
                self.mha_mode = "3p0"
            case _:
                assert False, "Unsupported mha mode found."

        if context is None:
            self.logger = Logger.get_null_logger()
        else:
            self.logger = Logger(name=layer_id + "_Tiler", context=context)

        self.tilers = []
        self.cost_models = []

        mini_mha_split_overlays = ["M4K1N1_1col", "M1K1N1_1core"]
        for mode in self.modes:
            if self.mha_mode == "3p0" and mode not in mini_mha_split_overlays:
                self.logger.debug(
                    f"Skipping overlay mode {mode} as not compatible with 3p0 1 column mha mode"
                )
                continue

            overlay = Overlay(overlay_name, layer.orig_op_type, mode)
            tiler = Tiler(layer, device, overlay, kernel, self.logger)
            self.tilers.append(tiler)
            self.cost_models.append(CostModel(tiler))

        self.indexes = []

    def calculate_tiling_cycles(self):
        self.logger.debug("Calculate tiling cycles:")
        for tiler, cost_model in zip(self.tilers, self.cost_models):
            self.logger.debug("\nCompute possible tiling with Tiler: " + str(tiler))
            try:
                tiler.calculate_memtile_tilings()

                # tiler.check_valid_memtile_tilings()
                tiler.calculate_array_tilings()
                tiler.check_core_constraints()
            except TilingError as e:
                self.logger.debug("Tiling failed with error: " + str(e))
                self.logger.debug("Continuing exploration with the next tiler")
                continue

            cost_model.calculate_kernel_cycles()
            cost_model.calculate_array_cycles()
            cost_model.calculate_layer_cycles()

    def get_padding_values(self, padding_dim: list[int]) -> list[Optional[str | int]]:
        """
        Return the values that should be used when padding this dimension:
          - None: no padding
          - zp: padding by zero point value
        """
        padding_values = list()

        for i in range(len(padding_dim)):
            padding = padding_dim[i]
            if padding < 0:
                raise RuntimeError(
                    f"error in shape padding computation, padding cannot be negative: {padding_dim}"
                )

            if padding == 0:
                padding_values.append(None)
            else:
                padding_values.append("zp")

        return padding_values

    def reverse_permutation_on_padding(
        self, tiler: MHATiler, layer_padding: list[dict]
    ) -> None:
        """
        In place transformation of the layer_padding for the K input if transposed
        """
        # No permutation to reverse, return the layer padding without changes
        if tiler.layer.permK_3d is None:
            return

        # Compute the reverse permutation
        reverse_perm = np.argsort(tiler.layer.permK_3d)

        # search for K input padding and reverse the permutation
        for padding_input in layer_padding:
            input_name = next(iter(padding_input.keys()))
            if input_name != "ifm_k":
                continue

            padding = padding_input[input_name]
            padding["dims"] = np.array(padding["dims"])[reverse_perm].tolist()
            padding["value"] = np.array(padding["value"])[reverse_perm].tolist()

    def create_host_and_dma_padding(self, tiler: MHATiler):
        """
        Boolean values specifying for each dimension of each inputs if that dimension is
        an inner dimension or not. This is needed as inner dimensions are padded with zero
        point values and the others with zero values.
         - None: no padding
         - 0: padding by 0 if padding needed
         - zp: padding by zero point value if padding needed (on all inner dimensions)
        """
        host_layer_padding = list(dict())
        for entry, host_padding in tiler.host_layer_padding.items():
            outer_dim = tiler.layer.activations_shapes[entry][0]
            extra_host_padding = host_padding - tiler.activations_2d_shapes[entry]

            padding = {
                entry: {
                    "dims": [outer_dim] + host_padding.astype(int).tolist(),
                    "value": [None] + self.get_padding_values(extra_host_padding),
                }
            }
            host_layer_padding.append(padding)
        self.reverse_permutation_on_padding(tiler, host_layer_padding)

        dma_layer_padding = list(dict())
        for entry, dma_padding in tiler.dma_layer_padding.items():
            outer_dim = tiler.layer.activations_shapes[entry][0]
            extra_dma_padding = dma_padding - tiler.host_layer_padding[entry]

            padding = {
                entry: {
                    "dims": [outer_dim] + dma_padding.astype(int).tolist(),
                    "value": [None] + self.get_padding_values(extra_dma_padding),
                }
            }
            dma_layer_padding.append(padding)
        self.reverse_permutation_on_padding(tiler, dma_layer_padding)

        return host_layer_padding, dma_layer_padding

    def create_layer_dict(self, tiler: Tiler) -> dict:
        layerdict = vars(tiler.layer).copy()

        # Remove input informations that are not needed
        layerdict.pop("in_datatype", None)
        layerdict.pop("activations_shapes", None)

        # Extract and add the relevants input tensors information
        layerdict["in_q_shape"] = tiler.dma_layer_padding["ifm_q"].astype(int).tolist()
        layerdict["orig_in_q_shape"] = list(
            map(int, tiler.activations_2d_shapes["ifm_q"])
        )

        layerdict["in_k_shape"] = tiler.dma_layer_padding["ifm_k"].astype(int).tolist()
        layerdict["orig_in_k_shape"] = list(
            map(int, tiler.activations_2d_shapes["ifm_k"])
        )

        if "ifm_m" in tiler.activations_2d_shapes:
            layerdict["in_m_shape"] = (
                tiler.dma_layer_padding["ifm_m"].astype(int).tolist()
            )
            layerdict["orig_in_m_shape"] = list(
                map(int, tiler.activations_2d_shapes["ifm_m"])
            )

        if "ifm_b" in tiler.activations_2d_shapes:
            layerdict["in_b_shape"] = (
                tiler.dma_layer_padding["ifm_b"].astype(int).tolist()
            )
            layerdict["orig_in_b_shape"] = list(
                map(int, tiler.activations_2d_shapes["ifm_b"])
            )

        if "ifm_v" in tiler.activations_2d_shapes:
            layerdict["in_v_shape"] = (
                tiler.dma_layer_padding["ifm_v"].astype(int).tolist()
            )
            layerdict["orig_in_v_shape"] = list(
                map(int, tiler.activations_2d_shapes["ifm_v"])
            )

        layerdict["out_ofm_shape"] = tiler.dma_layer_padding["ofm"].astype(int).tolist()
        layerdict["orig_out_ofm_shape"] = list(
            map(int, tiler.activations_2d_shapes["ofm"])
        )

        return layerdict

    def find_optimal_tiling(self):
        self.logger.debug("Find optimal tiling for mha operator:")
        self.calculate_tiling_cycles()

        all_cycles = []
        indexes = []
        iters = []
        for midx, cost_model in enumerate(self.cost_models):
            for sched, sched_cycles in cost_model.total_layer_cycles.items():
                for mem_sub_id, cycles in sched_cycles.items():
                    all_cycles.append(cycles)

                    indexes.append(
                        np.hstack(
                            (
                                np.tile([midx, sched, mem_sub_id], (len(cycles), 1)),
                                np.arange(len(cycles)).reshape((len(cycles), 1)),
                            )
                        )
                    )
                    iters.append(
                        np.tile(
                            self.tilers[midx].valid_memtile_iters[sched][mem_sub_id],
                            (len(cycles), 1),
                        )
                    )

        if len(all_cycles) == 0:
            raise RuntimeError("No valid tiling found")

        all_cycles = np.vstack(all_cycles)
        indexes = np.vstack(indexes)
        iters = np.vstack(iters)
        self.indexes = indexes
        # idx, pingpong = np.unravel_index(np.argmin(all_cycles), all_cycles.shape)

        idx, pp = np.where(all_cycles < np.min(all_cycles) * 1.05)

        total_iters = []
        for opt_midx, opt_sched, opt_memtiling, opt_coretiling in indexes[idx]:
            tm, tk, tn = (
                self.tilers[opt_midx]
                .valid_core_iters[opt_sched][opt_memtiling][opt_coretiling]
                .astype(int)
                .tolist()
            )
            valid_memtiler_iters = self.tilers[opt_midx].valid_memtile_iters[opt_sched][
                opt_memtiling
            ]
            total_iters.append(tm * tn * tk * np.prod(valid_memtiler_iters))
        itermin_idx = np.argmin(total_iters)

        nidx = idx[itermin_idx]
        pingpong = pp[itermin_idx]

        opt_midx = indexes[nidx][0]
        opt_sched = indexes[nidx][1]
        opt_memtiling = indexes[nidx][2]
        opt_coretiling = indexes[nidx][3]

        opt_tiler = self.tilers[opt_midx]
        self.logger.debug(f"Optimal tiler found: {opt_tiler}")

        # Add the Core sub volume and iterations
        valid_core_subv = opt_tiler.valid_core_subvols[opt_sched][opt_memtiling]
        core_tiling = {
            "q": valid_core_subv["ifm"][opt_coretiling].astype(int).tolist(),
            "k": valid_core_subv["wgt"][opt_coretiling].astype(int).tolist(),
            **(
                {"v": valid_core_subv["wgt"][opt_coretiling][::-1].astype(int).tolist()}
                if self.mha_mode == "3p0"
                else {}
            ),
            "ofm": valid_core_subv["ofm"][opt_coretiling].astype(int).tolist(),
        }

        valid_core_iters = opt_tiler.valid_core_iters[opt_sched][opt_memtiling][
            opt_coretiling
        ]
        tm, tk, tn = valid_core_iters.astype(int).tolist()
        core_iters = {
            "q": [tm, tk],
            "k": [tn, tk],
            **({"v": [tn, tk]} if self.mha_mode == "3p0" else {}),
            "ofm": [tm, tn],
        }

        # Add the memtile subvolume and iteration
        valid_memtile_subv = opt_tiler.valid_memtile_subvols[opt_sched][opt_memtiling]
        memtile_tiling = {
            "q": valid_memtile_subv["ifm"].astype(int).tolist(),
            "k": valid_memtile_subv["wgt"].astype(int).tolist(),
            **(
                {"v": valid_memtile_subv["wgt"][::-1].astype(int).tolist()}
                if self.mha_mode == "3p0"
                else {}
            ),
            "ofm": valid_memtile_subv["ofm"].astype(int).tolist(),
        }
        valid_memtiler_iters = opt_tiler.valid_memtile_iters[opt_sched][opt_memtiling]
        memtile_iters = {
            "q": valid_memtiler_iters[0:2].astype(int).tolist(),
            "k": valid_memtiler_iters[2:4].astype(int).tolist(),
            **(
                {"v": valid_memtiler_iters[2:4][::-1].astype(int).tolist()}
                if self.mha_mode == "3p0"
                else {}
            ),
            "ofm": valid_memtiler_iters[4:6].astype(int).tolist(),
        }

        ## delete the builins from the vars_dict
        del opt_tiler.vars_dict["__builtins__"]
        placement_dict = {}
        opt_tiler.vars_dict["Sq"] = core_tiling["q"][0]
        assert opt_tiler.vars_dict["Sq"] == 16, "Only Sq = 16 is supported currently"

        opt_tiler.vars_dict["So"] = core_tiling["ofm"][0]
        opt_tiler.vars_dict["Dh"] = core_tiling["q"][1]
        opt_tiler.vars_dict["Skv"] = core_tiling["k"][1]
        for buff, bankdict in opt_tiler.kernel.placement_constraints.items():
            placement_constraint = dict()
            for k, v in bankdict.items():
                res = int(eval(str(v), opt_tiler.vars_dict))

                # FIXME: Dataflow implementation expect the size of the output buffer
                # to be 2 times smaller than the values computed by the tiling engine
                # when the overlay is 8x4.
                # Remove this hack as soon as they are compatible
                if self.overlay == "8x4" and (buff.lower() == "ofm" or buff == "tdm"):
                    res //= 2

                placement_constraint[k] = res
            placement_dict[buff] = placement_constraint

        memtile_sizes = {
            "q": int(np.prod(memtile_tiling["q"]) * opt_tiler.layer.in_bytes),
            "k": int(
                np.prod(memtile_tiling["k"]) * opt_tiler.layer.wgt_bytes
                + core_tiling["k"][1]
                * opt_tiler.vars_dict["bias_bytes"]
                * np.prod(np.array(memtile_tiling["k"]) / np.array(core_tiling["k"]))
            ),
            **(
                {
                    "v": int(
                        np.prod(memtile_tiling["v"]) * opt_tiler.layer.wgt1_bytes
                        + core_tiling["k"][1]
                        * opt_tiler.vars_dict["bias_bytes"]
                        * np.prod(
                            np.array(memtile_tiling["k"]) / np.array(core_tiling["k"])
                        )
                    )
                }
                if self.mha_mode == "3p0"
                else {}
            ),
            "ofm": int(np.prod(memtile_tiling["ofm"]) * opt_tiler.layer.out_ofm_bytes),
        }

        # FIXME: Dataflow implementation expect the size of the output buffer
        # to be 2 times smaller than the values computed by the tiling engine
        # when the overlay is 8x4.
        # Remove this hack as soon as they are compatible
        if self.overlay == "8x4":
            memtile_sizes["ofm"] //= 2

        if opt_sched == 1:
            schedule_dict = {
                "q": "pin",
                "k": "full",
                "q_ping_pong": True if pingpong == 1 else False,
                "k_ping_pong": False,
                "ofm_ping_pong": False,
            }
            # Correction for ofm Tn values
            core_iters["ofm"][1] = 1
        elif opt_sched == 2:
            schedule_dict = {
                "q": "pin",
                "k": "stream",
                "q_ping_pong": True if pingpong == 1 else False,
                "k_ping_pong": True,
                "ofm_ping_pong": False,
            }
        elif opt_sched == 5:
            schedule_dict = {
                "q": "stream",
                "k": "stream",
                "q_ping_pong": True if pingpong == 1 else False,
                "k_ping_pong": True,
                "ofm_ping_pong": False,
            }
        elif opt_sched == 6:
            schedule_dict = {
                "q": "stream",
                "k": "pin",
                **({"v": "pin"} if self.mha_mode == "3p0" else {}),
                "q_ping_pong": True if pingpong == 1 else False,
                "k_ping_pong": True if pingpong == 1 else False,
                **(
                    {"v_ping_pong": bool(pingpong == 1)}
                    if self.mha_mode == "3p0"
                    else {}
                ),
                "ofm_ping_pong": True if pingpong == 1 else False,
            }
        else:
            assert False, f"opt_sched=={opt_sched} not implemented"

        shim_tilings = {
            k: (np.array(memtile_tiling[k]) * np.array(memtile_iters[k]))
            .astype(int)
            .tolist()
            for k in memtile_iters.keys()
        }
        shim_sizes = {
            k: (memtile_sizes[k] * np.prod(memtile_iters[k])).astype(int).tolist()
            for k in memtile_iters.keys()
        }

        dram_shapes = {
            "q": (shim_tilings["q"] * opt_tiler.overlay.mem_splits["ifm"])
            .astype(int)
            .tolist(),
            "k": (shim_tilings["k"] * opt_tiler.overlay.mem_splits["wgt"])
            .astype(int)
            .tolist(),
            **(
                {
                    "v": (shim_tilings["k"] * opt_tiler.overlay.mem_splits["wgt"][::-1])
                    .astype(int)
                    .tolist()
                }
                if self.mha_mode == "3p0"
                else {}
            ),
            "ofm": (shim_tilings["ofm"] * opt_tiler.overlay.mem_splits["ofm"])
            .astype(int)
            .tolist(),
        }
        dram_sizes = {
            "q": (shim_sizes["q"] * np.prod(opt_tiler.overlay.mem_splits["ifm"]))
            .astype(int)
            .tolist(),
            "k": (shim_sizes["k"] * np.prod(opt_tiler.overlay.mem_splits["wgt"]))
            .astype(int)
            .tolist(),
            **(
                {
                    "v": (
                        shim_sizes["k"]
                        * np.prod(opt_tiler.overlay.mem_splits["wgt"][::-1])
                    )
                    .astype(int)
                    .tolist()
                }
                if self.mha_mode == "3p0"
                else {}
            ),
            "ofm": (shim_sizes["ofm"] * np.prod(opt_tiler.overlay.mem_splits["ofm"]))
            .astype(int)
            .tolist(),
        }

        layerdict = self.create_layer_dict(opt_tiler)
        host_layer_padding, dma_layer_padding = self.create_host_and_dma_padding(
            opt_tiler
        )

        test_cpp_name = opt_tiler.kernel.testbench_args["HostName"]
        opt_tb_cflags = opt_tiler.kernel.testbench_args["CFLAGS"]
        tb_cflags = {
            opt_tb_cflags[0]: layerdict["in_q_shape"][0],  # M_GEMM_A16W8
            opt_tb_cflags[1]: layerdict["in_q_shape"][1],  # K_GEMM_A16W8
            opt_tb_cflags[2]: layerdict["in_k_shape"][1],  # N_GEMM_A16W8
            opt_tb_cflags[3]: core_tiling["q"][0],  # M_GEMM_SUBV_A16W8
            opt_tb_cflags[4]: core_tiling["q"][1],  # K_GEMM_SUBV_A16W8
            opt_tb_cflags[5]: core_tiling["k"][1],  # N_GEMM_SUBV_A16W8
        }

        tiling_params = {
            "core_tile_params": {"subvols": core_tiling, "iters": core_iters},
            "mem_tile_params": {
                "subvols": memtile_tiling,
                "iters": memtile_iters,
                "sizes": memtile_sizes,
            },
            "shim_tile_params": {"subvols": shim_tilings, "sizes": shim_sizes},
            "dram_params": {"shapes": dram_shapes, "sizes": dram_sizes},
            "scheduling": schedule_dict,
            "dma_layer_padding": dma_layer_padding,
            "host_layer_padding": host_layer_padding,
            "kernel_info": {"placement_constraints": placement_dict},
            "overlay_info": {
                "overlay": self.overlay,
                "mode": opt_tiler.overlay.mode,
                "shape": {
                    "row": self.tilers[opt_midx].overlay.rows,
                    "col": self.tilers[opt_midx].overlay.cols,
                },
            },
            "layer_info": layerdict,
            "testbench_args": {"HOST_NAME": test_cpp_name, "COMPILE_FLAGS": tb_cflags},
        }

        return tiling_params


if __name__ == "__main__":
    # Manual smoke run: tile the sample "tst" MHA layer for the strix device
    # with the 8x4 overlay. `json` and `Layer` come from the module imports.
    with open("OGOAT/src/Tiler/tst_layer.json") as layer_file:
        layer_desc = json.load(layer_file)["tst"]
    # Set both activation residencies to "L3" for this run.
    layer_desc.update(in_act_residency="L3", out_act_residency="L3")

    mha_layer = Layer(layer_desc)

    from kernel import Kernel
    from device import Device

    strix = Device("strix")
    mha_kernel = Kernel(mha_layer)

    tiling_opt = MHATilingOpt(mha_layer, strix, "8x4", mha_kernel, "MHA_tiler_test")
    tiling_params = tiling_opt.find_optimal_tiling()