import os
import sys
from typing import List, Union, Dict
from dataclasses import dataclass

# Absolute path of the directory that contains this file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Import-time side effect: add <this dir>/../../dmacompiler to sys.path so the
# `from dmacompiler import ...` statement below can be resolved.
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))

from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)

from dataflow_common import iceil, ceildiv
from dataflow_utils import CommonDims

# Module-level configuration applied as soon as this module is imported:
# select the Aie2p device generation and set dmacompiler's ENABLE_BUSY_POLL
# config flag to True.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

def make_slice_dict(updated_input: list, updated_axis: int, out_start: int, out_stop: int):
    """
    Build the per-axis [start, stop] table describing a slice along one axis.

    Args:
        updated_input: Extents indexed as [N, H, W, C]; the first four entries
            are read.
        updated_axis: Index (0..3, or a negative list index) of the axis that
            is being sliced.
        out_start: Start of the slice on the selected axis.
        out_stop: Stop of the slice on the selected axis.

    Returns:
        Dict with keys 'N', 'H', 'W', 'C' (in that order).  Every axis maps to
        [0, extent] except the selected one, which maps to
        [out_start, out_stop].
    """
    axis_names = ['N', 'H', 'W', 'C']
    sliced_axis = axis_names[updated_axis]
    # Start with the full range on every axis, then narrow the sliced one.
    ranges = {name: [0, updated_input[pos]] for pos, name in enumerate(axis_names)}
    ranges[sliced_axis] = [out_start, out_stop]
    return ranges

# NOTE(review): this class declares no annotated fields of its own and defines
# __init__ by hand, so any methods @dataclass generates (e.g. __repr__/__eq__)
# can only cover fields inherited from CommonDims, if it has any — confirm
# that the decorator is intended here.
@dataclass
class SliceDims(CommonDims):
    """
    Dimension/configuration record for the slice operator.

    Extends CommonDims with slice-specific settings.  The constructor forwards
    the shared shape arguments to CommonDims.__init__ and stores every
    remaining argument unchanged as an instance attribute of the same name.
    """

    def __init__(
        self,
        aie_cols: int,
        aie_rows: int,
        Ni: Union[int, List[int]],
        Yi: Union[int, List[int]],
        Xi: Union[int, List[int]],
        Ci: Union[int, List[int]],
        No: Union[int, List[int]],
        Yo: Union[int, List[int]],
        Xo: Union[int, List[int]],
        Co: Union[int, List[int]],
        Ni_gran: int,
        Yis: int,
        Xis: int,
        Cis: int,
        Cos: int,
        Cop: Union[int, List[int]],
        ifm_bits: int,
        ofm_bits: int,
        fixed_point_bits: int,

        # SliceDims-specific fields
        slice: Dict,
        Ci_orig: int,
        Ni_slice_start: int,
        Ni_slice_stop: int,
        axis: int,
        out_start: int,
        out_stop: int,
        wgt_subv_size: int,
        is_qdq: bool,
        qdq_mode: int,
        is_kernel: bool,
        enable_padding: int,
        kernel_padding: bool,
        shard: str,
        innerC: int,
        startC: int,
        Com: int,
        row_alignment: int,
        param_subv_size: int,
        subv_elem: int,
        has_scratch_buf: bool,
        out_step_dims: list,
    ):
        """
        Initialise the common dimensions and record the slice-specific ones.

        aie_cols, aie_rows, ifm_bits, ofm_bits, Ni/Yi/Xi/Ci, No/Yo/Xo/Co, Cop,
        Ni_gran, Yis, Xis and Cis are passed to CommonDims.__init__.  All
        other arguments (including Cos and fixed_point_bits) are not forwarded
        and are stored directly on this instance.
        """
        # Initialize CommonDims fields via super
        super().__init__(
            aie_cols=aie_cols,
            aie_rows=aie_rows,
            ifm_bits=ifm_bits,
            ofm_bits=ofm_bits,
            Ni=Ni,
            Yi=Yi,
            Xi=Xi,
            Ci=Ci,
            No=No,
            Yo=Yo,
            Xo=Xo,
            Co=Co,
            Cop=Cop,
            Ni_gran=Ni_gran,
            Yis=Yis,
            Xis=Xis,
            Cis=Cis,
        )

        # SliceDims-specific fields (stored after the base-class initialiser
        # has run, so they take effect on top of whatever it set up).
        self.fixed_point_bits = fixed_point_bits
        # Dict read elsewhere in this file via the 'H', 'W' and 'C' keys, each
        # holding a [start, stop] pair.
        self.slice = slice
        self.Ci_orig = Ci_orig
        self.Ni_slice_start = Ni_slice_start
        self.Ni_slice_stop = Ni_slice_stop
        self.axis = axis
        self.out_start = out_start
        self.out_stop = out_stop
        self.wgt_subv_size = wgt_subv_size
        self.is_qdq = is_qdq
        self.qdq_mode = qdq_mode
        self.is_kernel = is_kernel
        self.enable_padding = enable_padding
        self.kernel_padding = kernel_padding
        self.shard = shard
        self.innerC = innerC
        self.startC = startC
        self.Com = Com
        # Cos is a constructor argument that is not forwarded to CommonDims.
        self.Cos = Cos
        self.row_alignment = row_alignment
        self.param_subv_size = param_subv_size
        self.subv_elem= subv_elem
        self.has_scratch_buf=has_scratch_buf
        # Indexed by `axis` when the OUT_STEP preprocessor macro is emitted.
        self.out_step_dims=out_step_dims


def split_cost(dims: SliceDims):
    """
    Compute the per-column, per-phase transfer ranges for a slice operation.

    The sliced region (dims.slice, keys 'H', 'W', 'C', each a [start, stop]
    pair) is tiled into subvolumes of (Yis, Xis, Cis).  Consecutive Y tiles
    are handed to consecutive AIE columns and consecutive X tiles to
    consecutive AIE rows; whatever does not fit in one pass becomes an
    additional phase.  C is never split.

    All ranges below are (start, stop) pairs.

    Returns:
        A tuple (shim_transfer, mt_ifm_transfer, mt_ofm_transfer) of dicts:

        * shim_transfer
            - 'shim_ifm'[col][phase]: [Yim, Xim, Cim], the bounding ranges of
              the subvolumes of all rows, in the coordinates of dims.slice.
            - 'shim_ofm'[col][phase]: ((Y), (X), (C)) with Y and X relative
              to the slice start and C spanning 0..Co (0..iceil(Co, 8) when
              dims.enable_padding is truthy).
        * mt_ifm_transfer
            - 'mt_ifm_mem'[col][phase]: (Y, X, C) extent tuple.
            - 'mt_ifm_s2mm'[col][phase]: ((0, Y), (0, X), (0, C)).
            - 'mt_ifm_mm2s': indexed [col][row][phase] when dims.is_kernel
              is truthy, otherwise [col][phase].
        * mt_ofm_transfer
            - 'mt_ofm_mem'[col][phase]: (Yis, aie_rows * Xis, Co).
            - 'mt_ofm_s2mm'[col][row][phase]: one Xis-wide window per row.
            - 'mt_ofm_mm2s'[col][phase]: ((0, Y), (0, X), (0, C)).
    """
    aie_cols = dims.aie_cols
    aie_rows = dims.aie_rows
    Yo = dims.Yo
    Xo = dims.Xo
    Co = dims.Co
    # Deliberately not called `slice`, so the builtin is not shadowed.
    slice_ranges = dims.slice
    is_kernel = dims.is_kernel
    enable_padding = dims.enable_padding
    Yis = dims.Yis
    Xis = dims.Xis
    Cis = dims.Cis

    def phase_sizes(shim_ifm: list):
        """Turn every [Yim, Xim, Cim] range triple into a (Y, X, C) extent tuple."""
        return [
            [
                (Yim[1] - Yim[0], Xim[1] - Xim[0], Cim[1] - Cim[0])
                for Yim, Xim, Cim in col_phases
            ]
            for col_phases in shim_ifm
        ]

    def origin_ranges(size):
        """Return ((0, Y), (0, X), (0, C)) for a (Y, X, C) extent tuple."""
        return ((0, size[0]), (0, size[1]), (0, size[2]))

    def generate_split():
        """
        Allocation strategy:
            1. the Y loop is split across columns
            2. the X_loop * C_loop iterations are split across rows, in phases
            3. pack consecutive aie_rows subvolumes into the memtile

        Returns kernels[col][row] = list with one entry per phase.  An entry
        is ((Y_start, Y_stop), (X_start, X_stop), (C_start, C_stop)), or []
        when the column's Y tile starts at or beyond the end of the slice.
        """
        Yi_start = slice_ranges['H'][0]
        Yi_stop = slice_ranges['H'][1]
        Xi_start = slice_ranges['W'][0]
        Xi_stop = slice_ranges['W'][1]
        Ci_start = slice_ranges['C'][0]

        # C is not split, so the combined X*C loop is just the X loop.
        C_loop = 1
        Y_iter = ceildiv(ceildiv(Yo, Yis), aie_cols)
        XC_iter = ceildiv(ceildiv(Xo, Xis) * C_loop, aie_rows)

        kernels = [[[] for _ in range(aie_rows)] for _ in range(aie_cols)]

        for y_idx in range(Y_iter):
            for col in range(aie_cols):
                Y_start_real = Yi_start + y_idx * aie_cols * Yis + col * Yis
                col_has_data = Y_start_real < Yi_stop
                # Clamp to the slice so tail tiles are short rather than overrunning.
                Y_start = min(Y_start_real, Yi_stop)
                Y_stop = min(Y_start + Yis, Yi_stop)
                # x_idx runs over rows first, then over XC phases.
                x_idx = 0
                for _ in range(XC_iter):
                    for row in range(aie_rows):
                        X_start = min(Xi_start + x_idx * Xis, Xi_stop)
                        X_stop = min(X_start + Xis, Xi_stop)
                        x_idx += 1
                        if col_has_data:
                            subv = ((Y_start, Y_stop),
                                    (X_start, X_stop),
                                    (Ci_start, Ci_start + Cis))
                        else:
                            subv = []
                        kernels[col][row].append(subv)
        return kernels

    def generate_shim_split(kernel_subv: list):
        """
        Derive the shim-level IFM/OFM ranges from the per-core subvolumes.

        IFM:
          1. each input comes from one BD
          2. each column has ceildiv(X_loop * C_loop, aie_rows) *
             ceildiv(Y_loop, aie_cols) phases
          3. phase reconfig for each subv: to be optimized.
        OFM:
          1. all inputs are combined into one output with one BD
          2. same phase count as IFM
          3. phase reconfig for each subv: to be optimized.
        """
        Yi_start = slice_ranges['H'][0]
        Xi_start = slice_ranges['W'][0]

        def bounding_range(valid_subvs: list, idx: int, axis: str):
            """(min, max) over range `idx` of all subvolumes, or an empty range at the slice stop."""
            if not valid_subvs:
                stop = slice_ranges[axis][1]
                return (stop, stop)
            values = [v for subv in valid_subvs for v in subv[idx]]
            return (min(values), max(values))

        shim_ifm = [[] for _ in range(aie_cols)]
        for col, col_split in enumerate(kernel_subv):
            assert aie_rows == len(col_split)
            for phase in range(len(col_split[0])):
                row_split = [rows[phase] if rows else [] for rows in col_split]
                # Placeholder entries ([]) carry no ranges and are ignored.
                valid_subvs = [subv for subv in row_split if len(subv) > 1]
                # Combine the rows of this column into one bounding box.
                Yim = bounding_range(valid_subvs, 0, 'H')
                Xim = bounding_range(valid_subvs, 1, 'W')
                if valid_subvs:
                    # C is not split, so every row carries the same C range.
                    Cim = valid_subvs[0][2]
                else:
                    Cim = (slice_ranges['C'][1], slice_ranges['C'][1])
                shim_ifm[col].append([Yim, Xim, Cim])

        # NOTE(review): phases without data keep the zero-extent Y/X ranges
        # computed above (shifted to slice-relative coordinates) together with
        # the full C range.  The previous ((0, 0), (0, 0), (0, 0)) fallback
        # could never be reached because every shim_ifm entry is a
        # three-element list; confirm the current result is the intended one.
        shim_ofm = []
        for col_phases in shim_ifm:
            col_ofm = []
            for Yim, Xim, _ in col_phases:
                # C is not split (C_loop == 1), so the output always starts at
                # channel 0 and covers all output channels.
                col_ofm.append((
                    (Yim[0] - Yi_start, Yim[1] - Yi_start),
                    (Xim[0] - Xi_start, Xim[1] - Xi_start),
                    (0, Co if not enable_padding else iceil(Co, 8)),
                ))
            shim_ofm.append(col_ofm)

        return {'shim_ifm': shim_ifm, 'shim_ofm': shim_ofm}

    def generate_mt_ifm_split_no_kernel(shim_ifm: list):
        """Memtile IFM layout when dims.is_kernel is falsy: every transfer is [col][phase]."""
        sizes = phase_sizes(shim_ifm)

        # A buffer with a non-positive extent falls back to the extents of
        # column 0 / phase 0.
        mt_ifm_mem = [
            [sizes[0][0] if any(m <= 0 for m in size) else size for size in col_sizes]
            for col_sizes in sizes
        ]
        # s2mm and mm2s both cover the whole phase extent from the origin.
        mt_ifm_s2mm = [[origin_ranges(size) for size in col_sizes] for col_sizes in sizes]
        mt_ifm_mm2s = [[origin_ranges(size) for size in col_sizes] for col_sizes in sizes]

        return {
            'mt_ifm_mem': mt_ifm_mem,
            'mt_ifm_s2mm': mt_ifm_s2mm,
            'mt_ifm_mm2s': mt_ifm_mm2s,
        }

    def generate_mt_ifm_split_kernel(shim_ifm: list):
        """Memtile IFM layout when dims.is_kernel is truthy: mm2s is split per row."""
        sizes = phase_sizes(shim_ifm)

        # A buffer with any zero extent is replaced by one full subvolume.
        mt_ifm_mem = [
            [(Yis, Xis, Cis) if any(m == 0 for m in size) else size for size in col_sizes]
            for col_sizes in sizes
        ]
        mt_ifm_s2mm = [[origin_ranges(size) for size in col_sizes] for col_sizes in sizes]

        mt_ifm_mm2s = [[[] for _ in range(aie_rows)] for _ in range(aie_cols)]
        for col, col_sizes in enumerate(sizes):
            for size in col_sizes:
                Xi_real = size[1]
                C_range = (0, max(size[2], Cis))
                for row in range(aie_rows):
                    mt_X_start = row * Xis
                    if mt_X_start >= Xi_real:
                        # This row's window starts past the received data:
                        # re-read the last Xis-wide window instead (or
                        # offset 0 when nothing was received).
                        # NOTE(review): this goes negative when
                        # 0 < Xi_real < Xis — behavior kept as-is.
                        mt_X_start = 0 if Xi_real == 0 else Xi_real - Xis
                    mt_ifm_mm2s[col][row].append(
                        ((0, Yis), (mt_X_start, mt_X_start + Xis), C_range))

        return {
            'mt_ifm_mem': mt_ifm_mem,
            'mt_ifm_s2mm': mt_ifm_s2mm,
            'mt_ifm_mm2s': mt_ifm_mm2s,
        }

    def generate_mt_ofm_split(shim_ifm: list):
        """Memtile OFM layout: per-row s2mm windows, one combined mm2s per phase."""
        sizes = phase_sizes(shim_ifm)

        # The buffer is the same for every column and phase: one Y tile, one
        # X tile per row, and all output channels.
        mem = (Yis, aie_rows * Xis, Co)
        mt_ofm_mem = [[mem for _ in col_sizes] for col_sizes in sizes]

        # Each row writes its own Xis-wide window of the buffer.
        mt_ofm_s2mm = [
            [
                [((0, Yis), (row * Xis, row * Xis + Xis), (0, Co)) for _ in col_sizes]
                for row in range(aie_rows)
            ]
            for col_sizes in sizes
        ]

        # Only the Y/X extent that was actually received is sent out; C
        # covers all output channels, optionally padded.
        mt_ofm_mm2s = [
            [
                ((0, size[0]),
                 (0, size[1]),
                 (0, Co if not enable_padding else iceil(Co, 8)))
                for size in col_sizes
            ]
            for col_sizes in sizes
        ]

        return {
            'mt_ofm_mem': mt_ofm_mem,
            'mt_ofm_s2mm': mt_ofm_s2mm,
            'mt_ofm_mm2s': mt_ofm_mm2s,
        }

    kernel_subv = generate_split()

    shim_transfer = generate_shim_split(kernel_subv)

    if is_kernel:
        mt_ifm_transfer = generate_mt_ifm_split_kernel(shim_transfer['shim_ifm'])
    else:
        mt_ifm_transfer = generate_mt_ifm_split_no_kernel(shim_transfer['shim_ifm'])

    mt_ofm_transfer = generate_mt_ofm_split(shim_transfer['shim_ifm'])

    return shim_transfer, mt_ifm_transfer, mt_ofm_transfer


def slice_preproc_directives(
    dims: SliceDims,
    back_end: BackEnd,
) -> List[str]:
    """
    Render the slice kernel's preprocessor macro definitions.

    Args:
        dims: Source of the macro values.
        back_end: Selects the flag syntax.  For BackEnd.Adf each definition is
            wrapped as --Xpreproc="-DNAME=VALUE"; for any other back end it is
            emitted as -DNAME=VALUE.

    Returns:
        One flag string per macro, in a fixed order.  TXN_MODE is 0 for
        BackEnd.Adf and 1 otherwise.
    """
    wrap_for_adf = back_end == BackEnd.Adf
    txn_mode = int(back_end != BackEnd.Adf)

    macros = [
        ('AIE_COLS', dims.aie_cols),
        ('AIE_ROWS', dims.aie_rows),
        ('N_IN', dims.Ni),
        ('Y_IN', dims.Yi),
        ('X_IN', dims.Xi),
        ('C_IN', dims.Ci_orig),
        ('N_OUT', dims.No),
        ('Y_OUT', dims.Yo),
        ('X_OUT', dims.Xo),
        ('C_OUT', dims.Co),
        ('AXIS', dims.axis),
        ('OUT_START', dims.out_start),
        ('OUT_STOP', dims.out_stop),
        # Step of the sliced axis only.
        ('OUT_STEP', dims.out_step_dims[dims.axis]),
        ('WGT_SIZE', dims.wgt_subv_size),
        ('QDQ_MODE', dims.qdq_mode),
        ('COUT_PAD', int(dims.enable_padding)),
        ('TXN_MODE', txn_mode),
        ('FIXED_POINT_BIT_SIZE', dims.fixed_point_bits),
    ]

    flags = []
    for ident, val in macros:
        if wrap_for_adf:
            flags.append(f'--Xpreproc="-D{ident}={val}"')
        else:
            flags.append(f"-D{ident}={val}")
    return flags