import os
import sys
from typing import List
from dataclasses import dataclass

# Directory containing this file; the project-relative import roots below are
# resolved against it so the imports that follow work regardless of the CWD.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([
    os.path.join(CURRDIR, "..", ".."),
    os.path.join(CURRDIR, "..", "..", "OGOAT", "src", "L1_fusion"),
])

from dmacompiler import config
from dataflow_common import ceildiv, iceil
from dataflow_utils import CommonDims
from kernel_func_list import kernel_func_list


def concat_L1_limit_addr() -> int:
    """Return the core-memory budget used by the concat tiler.

    The value is ``iceil(57000, 64)``; presumably ``iceil`` rounds up to a
    multiple of its second argument (helper defined in ``dataflow_common``
    -- confirm there).
    """
    raw_limit = 57000
    alignment = 64
    return iceil(raw_limit, alignment)


def padding(inDim: int, padding_enable: bool = True):
    """Return ``inDim`` aligned with ``iceil(inDim, 8)``, or unchanged.

    When ``padding_enable`` is falsy the dimension is returned as-is.
    A former variant, ``max(64, iceil(inDim, 8))``, is no longer applied.
    """
    if not padding_enable:
        return inDim
    return iceil(inDim, 8)


def find_wrap_less_than_1024(y):
    """Factor ``y`` as ``x * n`` with ``1 <= x <= 1024`` and ``n <= 1024``.

    Divisors are tried in increasing order, so the smallest valid ``x`` wins.

    Returns:
        The tuple ``(x, n)`` for the first divisor that qualifies, or ``None``
        when no divisor in ``1..1024`` yields a quotient of at most 1024.
    """
    for divisor in range(1, 1025):
        quotient, remainder = divmod(y, divisor)
        if remainder == 0 and quotient <= 1024:
            return divisor, quotient
    return None

def split_cost(
    aie_cols: int,
    aie_rows: int,
    input_rows: list,
    input_cols: list,
    input_chs: list,
    output_ch_p: int,
    is_kernel_dePad: bool,
    concat_mode: int,
    ifm_bits: int,
    ofm_bits: int,
    is_qdq: bool,
    has_scratch_buf: bool,
    scratch_buf_bits: int,
):
    """Enumerate per-core sub-volume splits for a concat and pick the cheapest.

    Every candidate (Yis, Xis, Cis) is checked against the core budget
    (``concat_L1_limit_addr()``) and, for channel concat, against the usable
    memtile size and the memtile MM2S wrap limit. Surviving candidates are
    ranked by over-compute ratio; candidates whose ratios differ by at most
    ``1e-2`` are ranked by total loop count instead.

    Args:
        aie_cols: Number of AIE columns; the Y (row) dimension is spread over it.
        aie_rows: Number of AIE rows; X (mode 0) or C (mode 1) is spread over it.
        input_rows: Row count of each input.
        input_cols: Column count of each input.
        input_chs: Channel count of each input.
        output_ch_p: Output channel count used for the OFM sub-volume size.
        is_kernel_dePad: Returned unchanged as ``is_kernel``; also affects the
            channel-evenness rule in column concat.
        concat_mode: 0 = channel (C) concat, 1 = column (W) concat.
        ifm_bits: Bits per input element.
        ofm_bits: Bits per output element.
        is_qdq: Whether QDQ is enabled (affects the evenness rule in mode 1).
        has_scratch_buf: Whether a scratch buffer must be budgeted.
        scratch_buf_bits: Bits per scratch-buffer element.

    Returns:
        ``(is_kernel, Yis, Xis, Cis, Cos)``. ``Yis``/``Xis``/``Cis`` are
        per-input lists. ``Cos`` is the integer ``output_ch_p`` for
        ``concat_mode == 0`` and the ``Cis`` list for ``concat_mode == 1``.

    Raises:
        AssertionError: ``concat_mode`` is neither 0 nor 1.
        IndexError: no candidate split satisfies the constraints.
    """
    from functools import cmp_to_key

    def sort_subv_cost(
        aie_cols: int, aie_rows: int,
        input_rows: list,
        input_cols: list,
        input_chs: list,
        subv_splits: list,
        eps: float = 1e-2
    ) -> list:
        """Sort candidates by over-compute, then loop count within ``eps``."""
        def cmp(a, b):
            _, _, _, _, oc_a, tl_a = a
            _, _, _, _, oc_b, tl_b = b
            if abs(oc_a - oc_b) <= eps:
                # Near-equal over-compute: prefer the split with fewer loops.
                if tl_a != tl_b:
                    return -1 if tl_a < tl_b else 1
                # fall back to over_compute to keep ordering deterministic
            if oc_a == oc_b:
                return 0
            return -1 if oc_a < oc_b else 1

        return sorted(subv_splits, key=cmp_to_key(cmp))

    num_inputs = len(input_rows)
    usable_mt_size = config.MAX_MEMTILE_ADDR - aie_rows * config.MAX_CORE_LAYER_PARAM_SIZE
    usable_core_size = concat_L1_limit_addr()

    is_kernel = is_kernel_dePad

    # NOTE: the output dims might violate the shim OFM step-size limit.
    # Only one case is excluded here:
    #   1) when the output is >= 2^20 32-bit words, Yis is forced to 1.
    if concat_mode == 0:
        out_row = input_rows[0]
        out_col = input_cols[0]
        # out_ch  = sum(input_chs)
    elif concat_mode == 1:
        out_row = input_rows[0]
        out_col = sum(input_cols)
        # out_ch  = input_chs[0]
    else:
        # Raised explicitly (not `assert False`) so the guard survives `python -O`.
        raise AssertionError("Not Support Row(H) concat")

    out_ch = output_ch_p

    shim_ofm_stepsize_limit_3dim = ((out_row * out_col * out_ch) * ofm_bits // 32) >= 2**20
    # Cis:
    #     when concat_mode = 0, Cis = C[i], i = 0..num_inputs-1
    #     else: Cis = n * C_gran, C_gran = 1 or 2, combining with X_gran
    # Xis:
    #     when concat_mode = 1, Xis = X[i], i = 0..num_inputs-1
    #     else: Xis = n * X_gran, X_gran = 1 or 2, combining with C_gran
    # Yis:
    #     = n * Y_gran, Y_gran = 1

    Yis_grid = [1 + n for n in range(input_rows[0])] if not shim_ofm_stepsize_limit_3dim else [1]

    total_elements = sum(input_rows[n] * input_cols[n] * input_chs[n]
                         for n in range(num_inputs))
    subv_splits = []
    for Yis_0 in Yis_grid:
        Yis = [Yis_0] * num_inputs
        Y_loop = ceildiv(input_rows[0], Yis_0 * aie_cols)
        if concat_mode == 0:  # C-wise
            Cis = input_chs
            C_loop = 1
            # Loop-invariant: any odd channel count forces X to be rounded below.
            C_odd = 1 if any(n % 2 == 1 for n in input_chs) else 0
            Xis_grid = [1 + n for n in range(input_cols[0])]
            for Xis_0 in Xis_grid:
                Xis_0 = iceil(Xis_0, 32 // ifm_bits) if C_odd else Xis_0
                Xis = [Xis_0] * num_inputs
                X_loop = ceildiv(input_cols[0], Xis_0 * aie_rows)
                ifm_subv_size = sum([Yis[n] * Xis[n] * Cis[n] * ifm_bits // 8 for n in range(num_inputs)])
                if has_scratch_buf:
                    # Per-input scratch rounded via iceil(..., 64), plus 64 extra elements.
                    scratch_buf = sum([iceil(Yis[n] * Xis[n] * Cis[n], 64) * scratch_buf_bits // 8
                                       for n in range(num_inputs)]) + 64 * scratch_buf_bits // 8
                else:
                    scratch_buf = 0
                ofm_subv_size = Yis[0] * Xis[0] * out_ch * ofm_bits // 8
                total_loop = Y_loop * X_loop * C_loop

                all_mt_size = Yis[0] * Xis[0] * aie_rows * sum(Cis) * ifm_bits // 8
                all_mt_size += Yis[0] * Xis[0] * aie_rows * out_ch * ofm_bits // 8
                # NOTE: with qdq enabled, the element count of each input has to be
                # 64-element aligned, so some buffer is reserved before the ifm
                # start for the 64-element overflow.
                qdq_overflow = sum(iceil(y * x * c, 64) - y * x * c for y, x, c in zip(Yis, Xis, Cis))

                is_valid = scratch_buf + qdq_overflow + ifm_subv_size + ofm_subv_size <= usable_core_size and all_mt_size <= usable_mt_size
                # NOTE: a fuse-failure exception is needed here to avoid the split.
                # 1. if any(Cis[n]) is able to be fused with Xis[0] -- use offset
                #    1) Xis[0] is dividable, it can be fused with Cis[n];
                #    2) if odd or (not dividable), Xis[0] <= wrap_max(=1023)
                # 2. if any(Cis[n]) is too big : > 1023*4 bytes, can't be fused.
                #    1) has to be de-fused with Xis[n]
                mt_mm2s_D0_D1_fusion = all(Xis[0] * Cis[n] * (ifm_bits // 8) // 4 <= 1023 for n in range(num_inputs)) \
                                    if input_cols[0] % (Xis[0]) != 0 else True  # 1023 is the AIE2P limitation
                mt_mm2s_D0_D1_nofusion = Xis[0] <= 1023 and all(Cis[n] * (ifm_bits // 8) // 4 <= 1023 for n in range(num_inputs))
                mt_mm2s_fusion_valid = mt_mm2s_D0_D1_fusion or mt_mm2s_D0_D1_nofusion
                # mt_mm2s_fusion_valid = True
                total_compute = total_loop * Yis_0 * Xis_0 * \
                                sum(input_chs) * aie_cols * aie_rows
                over_compute_ratio = total_compute / total_elements
                Cos = out_ch
                if is_valid and mt_mm2s_fusion_valid:
                    subv_splits.append((
                        Yis, Xis, Cis, Cos,
                        # Y_loop, X_loop, C_loop,
                        over_compute_ratio,
                        total_loop
                    ))
        elif concat_mode == 1:  # W-wise
            Xis = input_cols
            X_loop = 1
            Cis_grid = [1 + n for n in range(input_chs[0])]
            for Cis_0 in Cis_grid:
                Cis = [Cis_0] * num_inputs
                C_loop = ceildiv(input_chs[0], Cis_0 * aie_rows)
                # NOTE: the kernel should not pad ch (if with kernel) because the input is already padded.
                ifm_subv_size = sum([Yis[n] * Xis[n] * Cis[n] * ifm_bits // 8 for n in range(num_inputs)])
                ofm_subv_size = sum([Yis[n] * Xis[n] * Cis[n] * ofm_bits // 8 for n in range(num_inputs)])
                if has_scratch_buf:
                    scratch_buf = sum([iceil(Yis[n] * Xis[n] * Cis[n], 64) * scratch_buf_bits // 8
                                       for n in range(num_inputs)])
                else:
                    scratch_buf = 0
                total_loop = Y_loop * X_loop * C_loop
                # NOTE: with qdq enabled, the element count of each input has to be
                # 64-element aligned, so some buffer is reserved before the ifm
                # start for the 64-element overflow.
                qdq_overflow = sum(iceil(y * x * c, 64) - y * x * c for y, x, c in zip(Yis, Xis, Cis))
                is_valid = scratch_buf + qdq_overflow + ifm_subv_size + ofm_subv_size <= usable_core_size
                # NOTE(review): this uses any() and a modulus of ifm_bits // 8;
                # kept as-is -- confirm that is the intended alignment rule.
                is_word_alignment = 1 if any((Cis[n] * Xis[n]) % (ifm_bits // 8) == 0
                                             for n in range(num_inputs)) else 0
                if is_kernel or is_qdq:
                    is_Cis_even = (concat_mode == 1) and (Cis[0] % (32 // ifm_bits) == 0) and (Cis[0] % (32 // ofm_bits) == 0)
                else:
                    if any(input_cols[n] % (32 // ifm_bits) != 0 for n in range(num_inputs)):
                        is_Cis_even = Cis[0] % (32 // ifm_bits) == 0 and Cis[0] % (32 // ofm_bits) == 0
                    else:
                        is_Cis_even = True
                total_compute = total_loop * Yis_0 * Cis_0 * sum(input_cols) * aie_cols * aie_rows
                over_compute_ratio = total_compute / total_elements
                Cos = Cis

                if is_valid and is_word_alignment and is_Cis_even:
                    subv_splits.append((
                        Yis, Xis, Cis, Cos,
                        over_compute_ratio,
                        total_loop
                    ))

    if not subv_splits:
        # Same exception type the bare `[0]` lookup used to raise, but with
        # enough context to diagnose the failing shape.
        raise IndexError(
            "split_cost: no valid sub-volume split found for "
            f"concat_mode={concat_mode}, input_rows={input_rows}, "
            f"input_cols={input_cols}, input_chs={input_chs}"
        )

    sorted_subv = sort_subv_cost(aie_cols, aie_rows,
                                 input_rows, input_cols, input_chs,
                                 subv_splits,
                                 )
    Yis, Xis, Cis, Cos, _, _ = sorted_subv[0]

    return is_kernel, Yis, Xis, Cis, Cos


@dataclass
class ConcatDims(CommonDims):
    """Shape, datatype and tiling description of one concat operator.

    The constructor derives element bit widths and scratch-buffer needs from
    ``is_int16``/``qdq_mode``, pads channel counts, decides whether the concat
    runs through a kernel, asks ``split_cost`` for the per-core sub-volume
    split and finally records the kernel to use and the number of phases.

    Supported combinations:
       1. concat + no_kernel                   --> concat done @ifm_mt.s2mm
       2. concat + kernel_concat               --> all input subv sent to the kernel to concat
       3. concat + kernel_qdq                  --> concat done @ifm_mt.s2mm, kernel does qdq
       4. concat + kernel_concat + kernel_qdq  --> kernel does concat then qdq

       for 1:  is_kernel = False,  is_qdq = False
       for 2:  is_kernel = True,   is_qdq = False
       for 3:  is_kernel = False,  is_qdq = True
       for 4:  is_kernel = True,   is_qdq = True

    Padding / depadding:
       1. The C-dim is assumed to be padded by previous layers or the host.
       2. If all inputs have an even C-dim, the dataflow DMA does the depad
          from shim.mm2s.
       3. If any input has an odd C-dim, it is sent to the kernel for depad.
       4. The C-dim at ofm.mm2s is padded.
    """

    def __init__(
        self,
        aie_cols: int,
        aie_rows: int,
        num_inputs: int,
        concat_mode: int,
        input_rows: list,
        input_cols: list,
        input_chs: list,
        # ifm_bits: int,
        # ofm_bits: int,
        is_int16: bool,
        is_qdq: bool,
        qdq_mode: int = 3,
        is_signed: bool = False,
        padding_enable: bool = True,
        input_types: list = None,
    ):
        """Build the concat dimensions.

        Args:
            aie_cols: Number of AIE columns.
            aie_rows: Number of AIE rows.
            num_inputs: Number of input tensors; must equal ``len(input_rows)``.
            concat_mode: 0 = channel concat, 1 = column concat.
            input_rows: Row count of each input.
            input_cols: Column count of each input.
            input_chs: Channel count of each input.
            is_int16: Select the 16-bit configuration; otherwise int8 rules
                keyed by ``qdq_mode`` apply.
            is_qdq: Whether QDQ is enabled.
            qdq_mode: 0 (dq only), 1 (q only), 2 or 3; converted with ``int()``.
            is_signed: Stored as ``self.is_signed``.
            padding_enable: Stored as ``self.padding_enable``.
            input_types: Per-input type tags; defaults to a fresh
                ``["act", "act"]`` list for every instance.

        Raises:
            AssertionError: invalid ``qdq_mode``, mismatched input shapes,
                too many inputs, or a concat dimension that exceeds the core
                memory budget.
        """
        # A literal list default would be shared by every instance; build a
        # fresh one per call instead.
        if input_types is None:
            input_types = ["act", "act"]

        has_scratch_buf = False
        scratch_buf_bits = 8
        ifm_bits = 16
        ofm_bits = 16
        qdq_mode = int(qdq_mode)
        if is_int16:
            ifm_bits = 16
            ofm_bits = ifm_bits
            if qdq_mode == 0 or qdq_mode == 2:
                has_scratch_buf = True
            else:
                has_scratch_buf = False
            scratch_buf_bits = 16
            is_int16_concat = True
        else:  # int8
            if qdq_mode == 0:  # dq only
                ifm_bits = 8
                ofm_bits = 16
                has_scratch_buf = True
                scratch_buf_bits = 16
                is_int16_concat = True
            elif qdq_mode == 1:  # q only
                ifm_bits = 16
                ofm_bits = 8
                has_scratch_buf = False  # q output use ifm buffer
                scratch_buf_bits = 8
                is_int16_concat = False
            elif qdq_mode == 2:
                ifm_bits = 8
                ofm_bits = 8
                has_scratch_buf = True
                scratch_buf_bits = 16
                is_int16_concat = False
            elif qdq_mode == 3:
                ifm_bits = 8
                ofm_bits = 8
                has_scratch_buf = False
                scratch_buf_bits = 8
                is_int16_concat = False
            else:
                # Explicit raise: `assert False` would vanish under `python -O`
                # and leave `is_int16_concat` unbound.
                raise AssertionError(f"qdq_mode:{qdq_mode} is not in range(0..3) !")
        # Initialize the base class with known CommonDims fields
        super().__init__(
            aie_cols=aie_cols,
            aie_rows=aie_rows,
            ifm_bits=ifm_bits,
            ofm_bits=ofm_bits,
            Yis=[],  # Will be populated later via split_cost
            Xis=[],
            Cis=[],
            Cos=[]
        )

        self.MAX_INPUTS = 6
        self.MAX_INPUTS_NO_KERNEL = 12
        self.num_inputs = num_inputs
        self.concat_mode = concat_mode
        self.input_rows = input_rows
        self.input_cols = input_cols
        self.input_chs = input_chs
        self.is_qdq = is_qdq
        self.padding_enable = padding_enable
        self.is_int16 = is_int16
        self.is_signed = is_signed
        self.has_scratch_buf = has_scratch_buf
        self.scratch_buf_bits = scratch_buf_bits
        self.is_int16_concat = is_int16_concat
        if concat_mode == 0:  # channel concat
            assert all(y == self.input_rows[0] for y in self.input_rows)
            assert all(x == self.input_cols[0] for x in self.input_cols)
        elif concat_mode == 1:  # column concat
            assert all(y == self.input_rows[0] for y in self.input_rows)
            assert all(c == self.input_chs[0] for c in self.input_chs)

        assert num_inputs == len(input_rows), \
            f"The num of inputs:{num_inputs} doesn't match the tensor num given: {len(input_rows)}"

        self.CoreqdqPrmSize = 256
        self.wgt_subv_size = self.CoreqdqPrmSize

        self.input_types = input_types
        # At most one const input is allowed
        # self.num_act_inputs = (self.num_inputs - 1) if self.is_const_input else self.num_inputs

        self.row_alignment = 8  # this is the W8 alignment

        self.is_kernel_depad_available = True

        # NOTE(review): padding() is called with its default padding_enable=True;
        # self.padding_enable is not forwarded -- confirm this is intended.
        self.input_chs_p = [padding(input_chs[n]) for n in range(num_inputs)]
        if self.is_kernel_depad_available:
            self.output_ch_core = padding(sum(self.input_chs) if concat_mode == 0 else self.input_chs[0])
        else:
            # Without a depad kernel the core output keeps every padded input.
            self.output_ch_core = sum(self.input_chs_p) if concat_mode == 0 else self.input_chs_p[0]
        self.output_ch_dma = padding(sum(self.input_chs) if concat_mode == 0 else self.input_chs[0])

        # NOTE: for some big C-dim, the wrap exception needs to be considered:
        # 1. wrap limit: 1024 words
        # 2. the C-dim expressed via (32 // ifm_bits) // 4 exceeds 1024
        # 3. and it can't be factored (no divisor found)
        if any(find_wrap_less_than_1024(c * (32 // ifm_bits) // 4) is None for c in input_chs):
            C_wrap_valid = False
        else:
            C_wrap_valid = True

        # NOTE: the shim ifm access has to use 3 dims or fewer.
        #   1) if the DMA does the depad and the size is too big, C may need 2 dims
        #   2) plus X and Y, that would be 4 dims in total, but shim has only 3.
        if any(self.input_chs[n] * (32 // ifm_bits) // 4 > 1024 and self.input_chs[n] != self.input_chs_p[n]
               for n in range(num_inputs)) and concat_mode == 0:
            shim_ifm_wrap_less_4 = False
        else:
            shim_ifm_wrap_less_4 = True

        # NOTE: when any input has an odd C-dim, the C-dim of all inputs is
        # sent to the kernel, which produces the concatenated output.
        if concat_mode == 0:
            if any(input_chs[n] % (32 // ifm_bits) != 0 for n in range(num_inputs)) or \
                    not C_wrap_valid or not shim_ifm_wrap_less_4:
                self.is_kernel_dePad = True
                self.output_ch_p = self.output_ch_core
                self.input_chs_valid = self.input_chs_p
            else:
                # The below commented block is to enable dma concat on inner-most dimension
                # self.is_kernel_dePad = False
                # self.output_ch_p = self.output_ch_dma
                # self.input_chs_valid = self.input_chs
                self.is_kernel_dePad = True
                self.output_ch_p = self.output_ch_core
                self.input_chs_valid = self.input_chs_p
        else:
            self.is_kernel_dePad = False
            self.output_ch_p = self.output_ch_dma
            self.input_chs_valid = self.input_chs_p

        # Temporary solution for the case when Y is small:
        #   1. the current split from shim to mt is Y8
        #   2. from mt to core, X and C are combined
        #   3. for small Y and even X, Y and X can be reconstructed (rows grow,
        #      columns shrink, element count preserved)
        self.input_rows_orig = self.input_rows
        self.input_cols_orig = self.input_cols
        if concat_mode == 0:
            Yi = input_rows[0]
            Xi = input_cols[0]
            star_iter = 1
            while Yi * star_iter < aie_cols and Xi % (star_iter * 32 // ifm_bits) == 0:
                star_iter *= 2
                Yi *= star_iter
                Xi //= star_iter
            # Only adopt the reshaped dims when no element was lost to the
            # integer division above.
            if Yi * Xi == input_rows[0] * input_cols[0]:
                self.input_rows = [Yi] * num_inputs
                self.input_cols = [Xi] * num_inputs

        sum_ch = sum(self.input_chs_p) * (ifm_bits + ofm_bits) // 8
        sum_col = sum(self.input_cols) * (ifm_bits + ofm_bits) // 8 * 2  # C_gran = 2
        if concat_mode == 0 and sum_ch >= concat_L1_limit_addr():
            raise AssertionError(
                f"the Concated C-dim: {sum_ch} exceeds the core allowed memory size, can't support ")
        if concat_mode == 1 and sum_col >= concat_L1_limit_addr():
            raise AssertionError(
                f"the Concated W-dim: {sum_col} exceeds the core allowed memory size, can't support ")

        # self.ifm_bits / self.ofm_bits are presumably set by CommonDims.__init__
        # from the keyword arguments passed above.
        self.is_kernel, \
        self.Yis, \
        self.Xis, \
        self.Cis, \
        self.Cos = split_cost(aie_cols, aie_rows,
                    self.input_rows, self.input_cols, self.input_chs_valid, self.output_ch_p,
                    self.is_kernel_dePad,
                    self.concat_mode,
                    self.ifm_bits, self.ofm_bits,
                    self.is_qdq,
                    self.has_scratch_buf, self.scratch_buf_bits,
                    )
        # Uniform inputs are allowed to exceed MAX_INPUTS_NO_KERNEL on the
        # no-kernel path.
        self.num_inputs_exception = not self.is_kernel and self.num_inputs > self.MAX_INPUTS_NO_KERNEL and \
                       ((concat_mode == 0 and all(c == self.input_chs[0] for c in self.input_chs)) or
                        (concat_mode == 1 and all(x == self.input_cols[0] for x in self.input_cols)))

        self.kernel_names = {}
        self.kernel_includes = ["super.hh"]

        if not (self.is_kernel) and (is_qdq):
            self.kernel_names["run_combined_qdq"] = kernel_func_list.index("run_combined_qdq")
            self.kernel_includes.append("qdq/wrapper_qdq.cc")
        elif self.is_kernel and num_inputs == 2 and self.input_chs == [63, 1]:
            self.kernel_names["run_concat_c6463"] = kernel_func_list.index("run_concat_c6463")
            self.kernel_includes.append("concat_c6463/wrapper_concat_c6463.cc")
        else:
            self.kernel_names["run_concat"] = kernel_func_list.index("run_concat")
            self.kernel_includes.append("concat/wrapper_concat.cc")

        if self.is_kernel:
            assert num_inputs <= self.MAX_INPUTS, \
                f"total input is {num_inputs}, it should equal or less than {self.MAX_INPUTS}"
        else:
            if not self.num_inputs_exception:
                assert num_inputs <= self.MAX_INPUTS_NO_KERNEL, \
                    f"total input {num_inputs}, it should equal or less than {self.MAX_INPUTS_NO_KERNEL}"

        # Use the (possibly reshaped) self.input_rows so the phase count agrees
        # with the rows/cols that split_cost tiled; X_loop below already uses
        # the reshaped columns.
        Y_loop = ceildiv(self.input_rows[0], aie_cols * self.Yis[0])
        if self.concat_mode == 0:
            X_loop = ceildiv(self.input_cols[0], aie_rows * self.Xis[0])
            C_loop = 1
        else:
            X_loop = 1
            # NOTE(review): split_cost received the padded channels
            # (input_chs_valid) while this uses the unpadded input_chs -- confirm.
            C_loop = ceildiv(self.input_chs[0], aie_rows * self.Cis[0])

        self.phase = Y_loop * X_loop * C_loop

        if self.concat_mode == 0:
            self.output_row = self.input_rows[0]
            self.output_col = self.input_cols[0]
        elif self.concat_mode == 1:
            self.output_row = self.input_rows[0]
            self.output_col = sum(self.input_cols)
        else:
            self.output_row = sum(self.input_rows)
            self.output_col = self.input_cols[0]
        self.output_ch = self.output_ch_p
        self.param_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE
        self.qdq_mode = qdq_mode


def run_tiler(aie_cols, aie_rows,
        num_inputs, concat_mode,
        input_rows, input_cols, input_chs,
        # ifm_bits, ofm_bits,
        is_int16,
        is_qdq, qdq_mode, is_signed,
        input_types):
    """Construct and return the ``ConcatDims`` for one concat operator.

    All arguments are forwarded unchanged to ``ConcatDims``;
    ``padding_enable`` is left at that class's default.
    """
    return ConcatDims(
        aie_cols=aie_cols,
        aie_rows=aie_rows,
        num_inputs=num_inputs,
        concat_mode=concat_mode,
        input_rows=input_rows,
        input_cols=input_cols,
        input_chs=input_chs,
        is_int16=is_int16,
        is_qdq=is_qdq,
        qdq_mode=qdq_mode,
        is_signed=is_signed,
        input_types=input_types,
    )
