"""Common utils for Uniop"""
from typing import Tuple, List
from itertools import accumulate

from dmacompiler import (
    DataTransfer,
    DmaDir,
    AieTile,
    TileType,
    memtile_dma,
    shim_dma,
    pack_reconfig_transfers,
)
from utils.utils_common import (
   log,
)
from scheduler.common import (
    overlay_3x4_F_ids,
    overlay_3x4_A_ids,
    overlay_3x4_O_ids,
    overlay_3x4_S_ids,
)


class SpatialSplitModes:
    """Lookup table of the supported N/X/C spatial split modes.

    ``Table`` maps a mode name such as ``"N1X2C6"`` to a dict with the keys
    ``"N"``, ``"X"`` and ``"C"``.  Each value is a pair of split factors; in
    every entry the product of the pair equals the number that follows the
    axis letter in the mode name (e.g. ``"C6"`` -> ``(3, 2)``).
    """

    def __init__(self):
        # One row per mode: (name, N factors, X factors, C factors).
        # Row order is the insertion order of the resulting dict.
        mode_rows = (
            ("N1X1C12", (1, 1), (1, 1), (3, 4)),
            ("N1X2C6", (1, 1), (1, 2), (3, 2)),
            ("N1X3C4", (1, 1), (3, 1), (1, 4)),
            ("N1X4C3", (1, 1), (1, 4), (3, 1)),
            ("N1X6C2", (1, 1), (3, 2), (1, 2)),
            ("N1X12C1", (1, 1), (3, 4), (1, 1)),
            ("N2X1C6", (1, 2), (1, 1), (3, 2)),
            ("N2X2C3", (1, 2), (1, 2), (3, 1)),
            ("N2X3C2", (1, 2), (3, 1), (1, 2)),
            ("N2X6C1", (1, 2), (3, 2), (1, 1)),
            ("N3X1C4", (3, 1), (1, 1), (1, 4)),
            ("N3X2C2", (3, 1), (1, 2), (1, 2)),
            ("N3X4C1", (3, 1), (1, 4), (1, 1)),
            ("N4X1C3", (1, 4), (1, 1), (3, 1)),
            ("N4X3C1", (1, 4), (3, 1), (1, 1)),
            ("N6X1C2", (3, 2), (1, 1), (1, 2)),
            ("N6X2C1", (3, 2), (1, 2), (1, 1)),
            ("N12X1C1", (3, 4), (1, 1), (1, 1)),
        )
        # A fresh dict is built per instance, so callers may mutate their copy.
        self.Table = {
            mode_name: {"N": n_split, "X": x_split, "C": c_split}
            for mode_name, n_split, x_split, c_split in mode_rows
        }
        # An earlier revision of this table only split X and C, with the keys
        # 'X12C1', 'X3C4', 'X1C12', 'X2C6', 'X4C3' and 'X6C2'; those factor
        # pairs are the X/C columns of the "N1..." rows above.


class Axis:
    """Describe how one tensor axis is split across shim columns, memtile rows and time.

    Three sizes are tracked per axis:

    * ``l3_dim`` - the full axis length (``dim``),
    * ``l1_dim`` - the length of one sub-volume (``dim_subv``),
    * ``l2_dim`` - ``mem_spatial_split_factor * l1_dim``.

    ``shm_start`` / ``shm_end`` hold, per AIE column, the half-open interval of
    the axis assigned to that column; ``shm_split_dim`` holds the matching
    per-column lengths.
    """

    def __init__(self, dim: int, dim_subv: int, space_split_factors: Tuple):
        """Compute all split sizes and per-column intervals for the axis.

        Args:
            dim: Full length of the axis (must be > 0).
            dim_subv: Length of one sub-volume along the axis (must be > 0).
            space_split_factors: Pair ``(shim_factor, mem_factor)``; the shim
                factor must be 1 or 3 and the mem factor 1, 2 or 4.

        Raises:
            AssertionError: If any of the constraints above is violated.
        """
        shm_factor = space_split_factors[0]
        mem_factor = space_split_factors[1]

        # Validate up front: the arithmetic below divides by these values, so
        # checking afterwards would surface bad input as ZeroDivisionError.
        assert dim > 0 and dim_subv > 0
        assert shm_factor in {1, 3}
        assert mem_factor in {1, 2, 4}

        self.all_spatial_split_factor = shm_factor * mem_factor
        self.shm_spatial_split_factor = shm_factor
        self.mem_spatial_split_factor = mem_factor
        self.AieCols = 3

        self.l1_dim = dim_subv
        self.l2_dim = self.mem_spatial_split_factor * self.l1_dim
        self.l3_dim = dim

        # Number of full l1-sized steps each spatial slice needs (floor division;
        # a possible remainder is accounted for in get_iters()).
        self.time_split_factor = (
            self.l3_dim // self.all_spatial_split_factor
        ) // self.l1_dim

        self.mem_split_dim = self.l2_dim // self.mem_spatial_split_factor

        if self.shm_spatial_split_factor == 1:
            # No split across columns: every column covers the whole axis.
            self.shm_split_dim = [self.l3_dim] * self.AieCols
            self.shm_start = [0] * self.AieCols
            self.shm_end = [self.l3_dim] * self.AieCols
        else:
            if self.l3_dim % (self.AieCols * self.l2_dim) == 0:
                # Even case: each column gets the same share.
                per_col = self.l3_dim // self.AieCols
                self.shm_split_dim = [per_col] * (self.AieCols - 1) + [
                    self.l3_dim - (self.AieCols - 1) * per_col
                ]
            else:
                # Uneven case: deal l2-sized chunks to the columns round-robin;
                # the final chunk may be shorter than l2_dim.
                self.shm_split_dim = [0] * self.AieCols
                remain = self.l3_dim
                chunk_idx = 0
                while remain > 0:
                    log("remain :", remain)
                    delta = self.l2_dim if (remain >= self.l2_dim) else remain
                    self.shm_split_dim[chunk_idx % self.AieCols] += delta
                    chunk_idx += 1
                    remain -= delta
                log("self.shm_split_dim:", self.shm_split_dim)

            assert sum(self.shm_split_dim) == self.l3_dim
            self.shm_end = list(accumulate(self.shm_split_dim))
            self.shm_start = [0] + self.shm_end[:-1]

        assert self.l2_dim >= self.l1_dim

    # ########################################################################################
    # ##   Obtain number of iterations (subvolumes) to cover all the elements in axis
    # ########################################################################################
    def get_iters(self) -> int:
        """Return the number of iterations needed to cover the axis.

        This is ``time_split_factor`` when the axis divides evenly into
        ``shm_spatial_split_factor * l2_dim`` sized pieces, otherwise one more
        to cover the remainder.  The divisor uses the actual shim split factor
        (not the fixed column count) so the result agrees with
        ``UniOpTensor.get_iters`` when the axis is not split across columns.
        """
        splits_evenly = (
            self.l3_dim % (self.shm_spatial_split_factor * self.l2_dim) == 0
        )
        return self.time_split_factor if splits_evenly else self.time_split_factor + 1

    # ########################################################################################
    # ##   For L3 (Shimtile) axis access pattern :
    # ########################################################################################
    def get_shm_interval(self, aie_col_id: int, phase: int = -1) -> Tuple[int, int]:
        """Return the ``(start, end)`` interval of the axis owned by a column.

        Args:
            aie_col_id: Column index, one of 0, 1, 2.
            phase: Accepted for signature symmetry; currently unused.
        """
        _ = phase
        assert aie_col_id in {0, 1, 2}
        return (self.shm_start[aie_col_id], self.shm_end[aie_col_id])

    # ########################################################################################
    # ##   For L2 (Memtile) axis access pattern :
    # ########################################################################################
    def get_mem_interval(
        self, aie_col_id=-1, aie_row_id=-1, phase: int = -1
    ) -> Tuple[int, int]:
        """Return the ``(start, end)`` interval inside the L2 buffer.

        Exactly one of ``aie_row_id`` / ``aie_col_id`` must be given:

        * ``aie_row_id`` in 0..3 selects the l1-sized slice for that row
          (input memtile mm2s, L2->L1, or output memtile s2mm, L1->L2).
        * ``aie_col_id`` selects the span exchanged with L3 (output memtile
          mm2s, L2->L3, or input memtile s2mm, L3->L2).  For ``phase`` -1 or 0
          this is the column's share capped at ``l2_dim``; for any other
          phase it is what is left of the column's share after
          ``time_split_factor`` full ``l2_dim`` steps.

        Raises:
            AssertionError: On an invalid row id, when both ids are given for
                a row query, or when neither id is given.
        """
        assert aie_row_id in {-1, 0, 1, 2, 3}

        if aie_row_id != -1:
            # For input  memtile mm2s (L2->L1) or output memtile s2mm (L1->L2)
            assert aie_col_id == -1
            slot = aie_row_id % self.mem_spatial_split_factor

            lo = slot * self.mem_split_dim
            hi = min(lo + self.mem_split_dim, self.l2_dim)

            assert hi - lo == self.l1_dim
            return (lo, hi)

        # For output memtile mm2s (L2->L3)	or input memtile s2mm (L3->L2)
        assert aie_col_id != -1
        col_span = self.shm_end[aie_col_id] - self.shm_start[aie_col_id]
        if phase in [-1, 0]:
            return (0, min(col_span, self.l2_dim))
        # phase == 1: the remainder after all full-sized steps
        return (0, col_span - self.time_split_factor * self.l2_dim)


class UniOpTensor:  # NXC or e.q. HMN/HMK  , for 3 dimensional tensor
    """Three-axis (N, X, C) tensor descriptor that builds memtile and shim transfers.

    Each axis is modelled by an ``Axis`` built from the full dimension, the
    sub-volume dimension and the split factors of the selected spatial split
    mode.  The class then derives buffer sizes, repeat-count lists, tiling
    format strings and finally the ``DataTransfer`` objects for L2 (memtile)
    and L3 (shim), one per AIE column.

    When the X axis does not divide evenly into
    ``X.shm_spatial_split_factor * X.l2_dim`` pieces, transfers are emitted in
    two phases (``num_phases == 2``): the full-sized steps and one remainder
    step.
    """

    def __init__(
        self,
        dims: Tuple,
        dims_subv: Tuple,
        spatial_split_mode_index: str,
        bytes_per_elem=2,
        inputBuf=True,
        L2BufferAddr=0,
        input_npass=1,
        granC=64
    ):
        """Set up the axes, buffer sizes and phase information.

        Args:
            dims: Full tensor dimensions in ``(N, X, C)`` order.
            dims_subv: Sub-volume dimensions in ``(N, X, C)`` order.
            spatial_split_mode_index: Key into ``SpatialSplitModes().Table``.
            bytes_per_elem: Element size in bytes; must be 2 or 4.
            inputBuf: True for an input tensor, False for an output tensor.
            L2BufferAddr: Address used for the L2 buffer and its offsets.
            input_npass: Pass count used when building the repeat-count lists.
            granC: Granularity of the inner C dimension in the L1-side tiling
                strings.

        Raises:
            AssertionError: If ``bytes_per_elem`` is not 2 or 4.
            KeyError: If ``spatial_split_mode_index`` is not a known mode.
        """
        assert bytes_per_elem in [2, 4]
        log("dim:", dims)
        log("dims_subv:", dims_subv)

        self.SplitMode = SpatialSplitModes()
        self.space_split_factors = self.SplitMode.Table[spatial_split_mode_index]
        self.AieRows = 4
        self.AieCols = 3
        self.bytes = bytes_per_elem
        self.bits_per_elem = self.bytes * 8
        self.is_input_tensor = inputBuf
        self.input_npass = input_npass
        self.npass = self.input_npass  # if self.is_input_tensor else 1

        self.N = Axis(dims[0], dims_subv[0], self.space_split_factors["N"])
        self.X = Axis(dims[1], dims_subv[1], self.space_split_factors["X"])
        self.C = Axis(dims[2], dims_subv[2], self.space_split_factors["C"])

        self.memtile_mem_format = (
            f"N:{self.N.l2_dim} X:{self.X.l2_dim} C:{self.C.l2_dim}"
        )
        self.shmtile_mem_format = (
            f"N:{self.N.l3_dim} X:{self.X.l3_dim} C:{self.C.l3_dim}"
        )

        # Default buffer-object id: 1 for inputs, 0 for outputs (see set_bo_id).
        self.BufferObject_Id = 1 if (self.is_input_tensor) else 0
        self.BufferOject_offset = 0
        self.L2BufferAddr = L2BufferAddr
        self.L2BufferSize = self.N.l2_dim * self.X.l2_dim * self.C.l2_dim * self.bytes
        self.L3BufferSize = self.N.l3_dim * self.X.l3_dim * self.C.l3_dim * self.bytes

        self.num_phases = 1 if self._x_splits_evenly() else 2
        self.time_split_factor = self.X.time_split_factor  # temporary

        self.granC = granC

    def _x_splits_evenly(self) -> bool:
        """Return True when X divides into shim-factor * l2 pieces with no remainder."""
        return (
            self.X.l3_dim % (self.X.shm_spatial_split_factor * self.X.l2_dim) == 0
        )

    def set_bo_id(self, bo_id):
        """Override the buffer-object id used by the shim transfers."""
        self.BufferObject_Id = bo_id

    def set_bo_offset(self, bo_offset):
        """Set the buffer offset passed to the shim transfers."""
        self.BufferOject_offset = bo_offset

    def shim_tiling_format(self, phase: int, aie_col_id: int) -> str:
        """Return the shim tiling string ``N:a:b X:c:d C:e:f`` for one column.

        ``phase`` is forwarded to the X axis only.
        """
        n_lo, n_hi = self.N.get_shm_interval(aie_col_id)
        x_lo, x_hi = self.X.get_shm_interval(aie_col_id, phase)
        c_lo, c_hi = self.C.get_shm_interval(aie_col_id)

        return f"N:{n_lo}:{n_hi} X:{x_lo}:{x_hi} C:{c_lo}:{c_hi}"

    def _l3_side_tiling(self, phase: int, aie_col_id: int) -> str:
        """Tiling string in NXC format for the memtile channel that faces L3."""
        n_lo, n_hi = self.N.get_mem_interval(aie_col_id=aie_col_id)
        x_lo, x_hi = self.X.get_mem_interval(aie_col_id=aie_col_id, phase=phase)
        c_lo, c_hi = self.C.get_mem_interval(aie_col_id=aie_col_id)

        # Intervals are re-based to start at 0; only their lengths matter here.
        return f"N:0:{n_hi - n_lo} X:0:{x_hi - x_lo} C:0:{c_hi - c_lo}"

    def _l1_side_tiling(self, aie_row_id: int) -> str:
        """Tiling string in NCXC<granC> format for the memtile channel that faces L1."""
        assert aie_row_id != -1 and aie_row_id in {0, 1, 2, 3}
        n_lo, n_hi = self.N.get_mem_interval(aie_row_id=aie_row_id)
        x_lo, x_hi = self.X.get_mem_interval(aie_row_id=aie_row_id)
        c_lo, c_hi = self.C.get_mem_interval(aie_row_id=aie_row_id)

        return (
            f"N:{n_lo}:{n_hi} C:{c_lo}:{c_hi}:{self.granC} "
            f"X:{x_lo}:{x_hi} C:0:{self.granC}"
        )

    def mem_s2mm_tiling_format(
        self, phase: int, aie_col_id: int, aie_row_id: int
    ) -> str:
        """Return the memtile S2MM tiling string.

        For an input tensor this is the L3 -> L2 side (NXC format, selected by
        ``aie_col_id``; ``aie_row_id`` must be -1).  For an output tensor it is
        the L1 -> L2 side (selected by ``aie_row_id`` in 0..3; ``phase`` and
        ``aie_col_id`` are not used).
        """
        if self.is_input_tensor:  # L3 -> L2 in NXC    format if input
            assert aie_row_id == -1
            return self._l3_side_tiling(phase, aie_col_id)

        # L1 -> L2 in NCXC64 format if output
        return self._l1_side_tiling(aie_row_id)

    def mem_mm2s_tiling_format(
        self, phase: int, aie_col_id: int, aie_row_id: int
    ) -> str:
        """Return the memtile MM2S tiling string.

        For an input tensor this is the L2 -> L1 side (selected by
        ``aie_row_id`` in 0..3; ``phase`` and ``aie_col_id`` are not used).
        For an output tensor it is the L2 -> L3 side (NXC format, selected by
        ``aie_col_id``; ``aie_row_id`` must be -1).
        """
        if self.is_input_tensor:  # L2 -> L1 in NCXC64 format if input
            return self._l1_side_tiling(aie_row_id)

        # L2 -> L3 in NXC    format if output
        assert aie_row_id == -1
        return self._l3_side_tiling(phase, aie_col_id)

    def core_amount(self,) -> int:
        """Return the size in bytes of one L1 sub-volume (N * X * C * bytes per element)."""
        return self.N.l1_dim * self.X.l1_dim * self.C.l1_dim * self.bytes

    def get_iters(self, axis_id: str) -> int:
        """Return the iteration count along ``axis_id`` ("N", "X" or "C").

        This is the axis' ``time_split_factor``; the X axis gets one extra
        iteration when it does not split evenly.
        """
        assert axis_id in {"N", "X", "C"}
        iters = getattr(self, axis_id).time_split_factor
        if axis_id == "X" and not self._x_splits_evenly():
            iters += 1  # remainder step
        return iters

    def num_total_subvolumes(self,) -> int:
        """Return the product of the three axes' ``time_split_factor`` values.

        NOTE(review): unlike get_iters(), this does not add the extra X
        remainder iteration - confirm that is intended.
        """
        return (
            self.N.time_split_factor
            * self.X.time_split_factor
            * self.C.time_split_factor
        )

    def len_rptcnt_list(self,) -> int:
        """Return the common length of the L2 and L3 repeat-count lists."""
        l2_rptcnt_list = self.L2_repeat()
        l3_rptcnt_list = self.L3_repeat()

        assert isinstance(l2_rptcnt_list, list)
        assert isinstance(l3_rptcnt_list, list)
        assert len(l2_rptcnt_list) == len(l3_rptcnt_list)

        return len(l2_rptcnt_list)

    def L2_repeat(self,) -> List:
        """Return the memtile repeat-count list.

        Even split: a single entry (scaled by ``input_npass`` for inputs).
        Uneven split: ``[time_split_factor, 1]`` per pass for inputs; for
        outputs the same pair, preceded by ``[0, 0]`` when ``input_npass == 2``.
        """
        steps = self.time_split_factor
        if self._x_splits_evenly():
            return [steps * self.input_npass] if self.is_input_tensor else [steps]
        if self.is_input_tensor:
            return [steps, 1] * self.input_npass
        if self.input_npass == 2:
            return [0, 0, steps, 1]
        return [steps, 1]

    def L3_repeat(self,) -> List:
        """Return the shim repeat-count list (same length as L2_repeat()).

        Even split: ``[input_npass]`` for inputs, ``[1]`` for outputs.
        Uneven split: ``[1, 0]`` per pass for inputs; for outputs the same
        pair, preceded by ``[0, 0]`` when ``input_npass == 2``.
        """
        if self._x_splits_evenly():
            return [self.input_npass] if self.is_input_tensor else [1]
        if self.is_input_tensor:
            return [1, 0] * self.input_npass
        if self.input_npass == 2:
            return [0, 0, 1, 0]
        return [1, 0]

    def _pack_transfers(self, dma, mem_format, phase_tilings, num_entries, buffer_offset):
        """Call pack_reconfig_transfers with one memory format / tiling per repeat entry.

        ``phase_tilings`` holds one tiling string per phase and is repeated so
        that its length matches ``num_entries``.
        """
        return pack_reconfig_transfers(
            dma,
            [mem_format] * num_entries,
            phase_tilings * (num_entries // self.num_phases),
            bits_per_elem=self.bits_per_elem,
            buffer_offset=[buffer_offset],
        )

    # ################################################################################################################
    # return a list of DataTransfer, each item in the list is a DataTransfer object for an Aie Column
    # ################################################################################################################
    def L2_DataTransfer(self, ):
        """Build the memtile DataTransfer objects, one per AIE column.

        For an input tensor each column has one S2MM transfer (from L3) and one
        MM2S transfer per AIE row (to L1).  For an output tensor each column
        has one S2MM transfer per AIE row (from L1) and one MM2S transfer
        (to L3).
        """
        aie_nullid = -1

        log("self.L2_repeat():", self.L2_repeat())

        # The repeat-list length and the phase range are identical for every
        # packed transfer, so compute them once instead of per call.
        num_entries = self.len_rptcnt_list()
        phases = range(self.num_phases)

        def pack(dma, phase_tilings):
            """Pack one memtile channel against the L2 buffer."""
            return self._pack_transfers(
                dma, self.memtile_mem_format, phase_tilings,
                num_entries, self.L2BufferAddr,
            )

        def s2mm_params(aie_col_id: int):
            """S2MM side: IFM arrives from L3 on one channel, OFM from L1 per row."""
            if self.is_input_tensor:
                return [
                    pack(
                        memtile_dma(aie_col_id, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                        [
                            self.mem_s2mm_tiling_format(p, aie_col_id, aie_nullid)
                            for p in phases
                        ],
                    )
                ]
            return [
                pack(
                    memtile_dma(aie_col_id, DmaDir.S2MM, overlay_3x4_O_ids()[aie_row_id]),
                    [
                        self.mem_s2mm_tiling_format(p, aie_nullid, aie_row_id)
                        for p in phases
                    ],
                )
                for aie_row_id in range(self.AieRows)
            ]

        def mm2s_params(aie_col_id: int):
            """MM2S side: IFM leaves to L1 per row, OFM leaves to L3 on one channel."""
            if self.is_input_tensor:
                return [
                    pack(
                        memtile_dma(aie_col_id, DmaDir.MM2S, overlay_3x4_A_ids()[aie_row_id]),
                        [
                            self.mem_mm2s_tiling_format(p, aie_nullid, aie_row_id)
                            for p in phases
                        ],
                    )
                    for aie_row_id in range(self.AieRows)
                ]
            return [
                pack(
                    memtile_dma(aie_col_id, DmaDir.MM2S, overlay_3x4_S_ids(aie_col_id)[0]),
                    [
                        self.mem_mm2s_tiling_format(p, aie_col_id, aie_nullid)
                        for p in phases
                    ],
                )
            ]

        return [
            DataTransfer(
                self.L2_repeat(),
                AieTile(TileType.Memtile, aie_col_id),
                [self.L2BufferAddr],
                self.L2BufferSize,
                s2mm_params(aie_col_id),
                mm2s_params(aie_col_id),
            )
            for aie_col_id in range(self.AieCols)
        ]

    # ################################################################################################################
    # # return a list of DataTransfer, each item in the list is a DataTransfer object for an Aie Column
    # ################################################################################################################
    def L3_DataTransfer(
        self,
    ):
        """Build the shim DataTransfer objects, one per AIE column.

        An input tensor uses the shim MM2S direction and an output tensor the
        S2MM direction; the list for the unused direction is left empty.
        """
        SHIM_CHANNEL_ID = 0  # For both s2mm and mm2s ### !!!! CHECK !!!!!!!!!!!!
        DmaDirection = DmaDir.MM2S if self.is_input_tensor else DmaDir.S2MM

        log("self.L3_repeat():", self.L3_repeat())

        # Hoisted: identical for every column.
        num_entries = self.len_rptcnt_list()

        def get_pattern(aie_col_id: int):
            """Pack the single shim channel of one column."""
            return [
                self._pack_transfers(
                    shim_dma(aie_col_id, DmaDirection, SHIM_CHANNEL_ID),
                    self.shmtile_mem_format,
                    [
                        self.shim_tiling_format(p, aie_col_id)
                        for p in range(self.num_phases)
                    ],
                    num_entries,
                    self.BufferOject_offset,
                )
            ]

        return [
            DataTransfer(
                self.L3_repeat(),
                AieTile(TileType.Shim, aie_col_id, 0),
                [self.BufferObject_Id],
                self.L3BufferSize,
                [] if self.is_input_tensor else get_pattern(aie_col_id),
                get_pattern(aie_col_id) if self.is_input_tensor else [],
            )
            for aie_col_id in range(self.AieCols)
        ]
