import os
import sys
from typing import List

# Absolute directory of this file; anchor for the sys.path entries below.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Make the directory two levels above this file importable.
sys.path.append(os.path.join(CURRDIR, "..", "..", ))
# Also expose <two levels up>/OGOAT/src/L1_fusion on sys.path (presumably where
# some of the project modules imported below live -- TODO confirm).
sys.path.append(os.path.join(CURRDIR, '..', '..', 'OGOAT', 'src', 'L1_fusion'))

from dmacompiler import BackEnd, \
    set_dev_gen, DevGen, config
from concat_run_tiler import ConcatDims
from dataflow_common import ceildiv, iceil
# Import-time side effect: select the Aie2p device generation via dmacompiler.
set_dev_gen(DevGen.Aie2p)
# Import-time side effect: switch on dmacompiler's ENABLE_BUSY_POLL flag
# (its exact effect is defined inside dmacompiler and is not visible here).
config.ENABLE_BUSY_POLL = True

def generate_split(
    aie_cols: int, aie_rows: int,
    concat_mode: int,
    input_rows: list,
    input_cols: list,
    input_chs: list,
    Yis: list,
    Xis: list,
    Cis: list,
    Cos: int,
):
    """Slice every input into per-core, per-phase sub-volumes.

    Allocation strategy:
    1. The Y loop is split across the AIE columns.
    2. The X loop (``concat_mode == 0``) or the C loop (``concat_mode == 1``)
       is split across the AIE rows; whatever does not fit in one pass over
       the array becomes an additional phase.

    The loop counts are derived from input 0 only, so every input receives
    the same number of phases per core.

    Args:
        aie_cols: Number of AIE columns.
        aie_rows: Number of AIE rows.
        concat_mode: 0 walks X in steps of ``Xis[n]`` and fixes the C range
            to ``(0, Cis[n])``; 1 walks C in steps of ``Cis[n]`` and fixes
            the X range to ``(0, Xis[n])``.
        input_rows: Y extent of each input.
        input_cols: X extent of each input.
        input_chs: C extent of each input.
        Yis: Y sub-volume size of each input.
        Xis: X sub-volume size of each input.
        Cis: C sub-volume size of each input.
        Cos: Not used by this function; kept so the signature is unchanged.

    Returns:
        A nested list ``kernels[n][col][row]`` holding one entry per phase.
        Each entry is ``((Y_start, Y_stop), (X_start, X_stop),
        (C_start, C_stop))``, or ``[]`` when the phase starts at or beyond
        the last row of input ``n``.

    Raises:
        ValueError: If a slice has to be computed for a ``concat_mode``
            other than 0 or 1.
    """
    Y_loop = ceildiv(input_rows[0], Yis[0])
    X_loop = 1 if concat_mode == 1 else ceildiv(input_cols[0], Xis[0])
    C_loop = 1 if concat_mode == 0 else ceildiv(input_chs[0], Cis[0])
    # Passes over the array needed to cover the Y loop / the X*C loop.
    Y_iter = ceildiv(Y_loop, aie_cols)
    XC_iter = ceildiv(X_loop * C_loop, aie_rows)

    num_inputs = len(input_rows)
    kernels = [
        [[[] for _ in range(aie_rows)] for _ in range(aie_cols)]
        for _ in range(num_inputs)
    ]

    for n in range(num_inputs):
        Y = input_rows[n]
        X = input_cols[n]
        C = input_chs[n]
        for y_idx in range(Y_iter):
            for col in range(aie_cols):
                # Unclamped start: decides whether this core has real work.
                Y_start_real = y_idx * aie_cols * Yis[n] + col * Yis[n]
                Y_start = min(Y_start_real, Y)
                Y_stop = min(Y_start + Yis[n], Y)
                for xc_idx in range(XC_iter):
                    for row in range(aie_rows):
                        # Running tile index along the axis split across rows.
                        step = xc_idx * aie_rows + row
                        if concat_mode == 0:
                            X_start = min(step * Xis[n], X)
                            X_stop = min(X_start + Xis[n], X)
                            C_start = 0
                            C_stop = Cis[n]
                        elif concat_mode == 1:
                            X_start = 0
                            X_stop = Xis[n]
                            C_start = min(step * Cis[n], C)
                            C_stop = min(C_start + Cis[n], C)
                        else:
                            raise ValueError(
                                f"unsupported concat_mode: {concat_mode!r}"
                            )
                        kernels[n][col][row].append(
                            ((Y_start, Y_stop), (X_start, X_stop), (C_start, C_stop))
                            if Y_start_real < Y else []
                        )
    return kernels


def generate_shim_split(
    aie_cols: int, aie_rows: int,
    input_rows: list, input_cols: list, input_chs: list, Cos: int,
    num_inputs: int,
    concat_mode: int,
    kernel_subv: list,
):
    """Merge per-core slices into shim-level IFM and OFM transfer ranges.

    IFM (``shim_ifm[input][col][phase] == [Yim, Xim, Cim]``):
    the slices of all rows of one column are combined per phase.
    * ``concat_mode == 0``: Y and X are the min/max over the rows' slices,
      C is copied from row 0's slice.
    * ``concat_mode == 1``: Y and C are the min/max over the rows' slices,
      X is copied from row 0's slice.
    When no row has a slice, each range collapses to ``(dim, dim)`` of the
    corresponding input dimension.

    OFM (``shim_ofm[col][phase]``): one combined range for all inputs.
    * ``concat_mode == 0``: Y and X come from input 0, C is ``(0, Cos)``.
    * ``concat_mode == 1``: Y and C come from input 0, X is
      ``(smallest X bound, sum of the inputs' X sizes)``.

    Any other ``concat_mode`` only prints a message; the IFM entry is then
    ``[0, 0, 0]`` and the OFM entry ``None``.

    Returns:
        ``{'shim_ifm': ..., 'shim_ofm': ...}``.
    """
    def flat_bounds(row_split: list, idx: int):
        # Start/stop values of range `idx`, taken from every non-empty slice.
        return [
            bound
            for entry in row_split
            if len(entry) > 1 and entry[idx]
            for bound in entry[idx]
        ]

    def stacked_range(entries: list, idx: int):
        # (smallest bound, summed size) of range `idx` over non-empty entries.
        ranges = [entry[idx] for entry in entries if len(entry) > 1 and entry[idx]]
        bounds = [bound for rng in ranges for bound in rng]
        return (min(bounds), sum(rng[1] - rng[0] for rng in ranges))

    def envelope(values: list, full_dims: list, which: int):
        # min/max of the collected bounds, or the degenerate (dim, dim) range.
        if values:
            return (min(values), max(values))
        return (full_dims[which], full_dims[which])

    # ---- IFM: one entry per input / column / phase -------------------------
    shim_ifm = [[[] for _ in range(aie_cols)] for _ in range(num_inputs)]
    for idxInput, per_input in enumerate(kernel_subv):
        assert aie_cols == len(per_input)
        for idxCol, colSplit in enumerate(per_input):
            assert aie_rows == len(colSplit)
            for phase in range(len(colSplit[0])):
                row_split = [
                    colSplit[row][phase] if colSplit[row] else []
                    for row in range(aie_rows)
                ]
                if concat_mode == 0:  # channel-wise
                    Yim = envelope(flat_bounds(row_split, 0), input_rows, idxInput)
                    Xim = envelope(flat_bounds(row_split, 1), input_cols, idxInput)
                    if flat_bounds(row_split, 2):
                        Cim = row_split[0][2]
                    else:
                        Cim = (input_chs[idxInput], input_chs[idxInput])
                elif concat_mode == 1:
                    Yim = envelope(flat_bounds(row_split, 0), input_rows, idxInput)
                    Cim = envelope(flat_bounds(row_split, 2), input_chs, idxInput)
                    if flat_bounds(row_split, 1):
                        Xim = row_split[0][1]
                    else:
                        Xim = (input_cols[idxInput], input_cols[idxInput])
                else:
                    print("placeholder or Dont support!")
                    Yim = Xim = Cim = 0
                shim_ifm[idxInput][idxCol].append([Yim, Xim, Cim])

    # ---- OFM: all inputs folded into one range per column / phase ----------
    total_phase = [len(cols[0]) for cols in shim_ifm]
    shim_ofm = [[] for _ in range(aie_cols)]
    for col in range(aie_cols):
        for p in range(total_phase[0]):
            split = [
                s[col][p] if len(s[col]) > 0 and len(s[col][p]) > 0 else []
                for s in shim_ifm
            ]
            if all(len(inner) == 0 for inner in split):
                s2mm = ((0, 0), (0, 0), (0, 0))
            elif concat_mode == 0:
                first = split[0]
                # C is never split in this mode, so the output spans (0, Cos).
                s2mm = (
                    (first[0][0], first[0][1]),
                    (first[1][0], first[1][1]),
                    (0, Cos),
                )
            elif concat_mode == 1:
                first = split[0]
                s2mm = (
                    (first[0][0], first[0][1]),
                    stacked_range(split, 1),
                    (first[2][0], first[2][1]),
                )
            else:
                print("placeholder or Dont support!")
                s2mm = None
            shim_ofm[col].append(s2mm)

    return {'shim_ifm': shim_ifm, 'shim_ofm': shim_ofm}


def generate_mt_ifm_split_no_kernel(
    aie_cols: int, aie_rows: int,
    concat_mode: int, num_inputs: int,
    shim_ifm: list,
    ):
    """Build the memtile IFM layout where all inputs share one buffer.

    The (Y, X, C) size of every shim transfer is derived from ``shim_ifm``
    and the inputs are stacked along C (``concat_mode == 0``) or along X
    (``concat_mode == 1``).

    Args:
        aie_cols: Number of AIE columns; must equal ``len(shim_ifm[n])``.
        aie_rows: Not used by this function; kept for a uniform signature.
        concat_mode: 0 stacks the inputs along C, 1 stacks them along X.
        num_inputs: Number of inputs to stack.
        shim_ifm: ``shim_ifm[input][col][phase] == [Y, X, C]`` where each
            item is a ``(start, stop)`` pair.

    Returns:
        A dict with
        * ``'mt_ifm_mem'``: ``[col][phase]`` -> (Y, X, C) size of the stacked
          buffer; a size with a non-positive dimension is replaced by the
          size of column 0 / phase 0.
        * ``'mt_ifm_s2mm'``: ``[input][col][phase]`` -> ranges at which that
          input is written into the stacked buffer.
        * ``'mt_ifm_mm2s'``: ``[col][phase]`` -> ranges covering the whole
          stacked buffer.

    Raises:
        ValueError: If ``concat_mode`` is neither 0 nor 1 and at least one
            phase has to be laid out (previously this surfaced as an
            ``UnboundLocalError``).
    """
    def concat_axis() -> int:
        # Index, inside a (Y, X, C) size, of the axis the inputs are stacked on.
        if concat_mode == 0:
            return 2
        if concat_mode == 1:
            return 1
        raise ValueError(f"unsupported concat_mode: {concat_mode!r}")

    # (Y, X, C) size of every shim transfer: mt_ifm[input][col][phase].
    mt_ifm = [[[] for _ in range(aie_cols)] for _ in range(num_inputs)]
    for idx_input, per_input in enumerate(shim_ifm):
        assert aie_cols == len(per_input)
        for idx_col, col_phases in enumerate(per_input):
            for extent in col_phases:
                mt_ifm[idx_input][idx_col].append((
                    extent[0][1] - extent[0][0],
                    extent[1][1] - extent[1][0],
                    extent[2][1] - extent[2][0],
                ))
    total_phase = len(mt_ifm[0][0])

    def phase_sizes(col: int, p: int) -> list:
        # Size of every input for this column/phase; [] when the column has none.
        return [
            s[col][p] if len(s[col]) > 0 and len(s[col][p]) > 0 else []
            for s in mt_ifm
        ]

    def has_data(split: list) -> bool:
        return any(len(inner) != 0 for inner in split)

    def stacked_bounds(split: list, first: int, last: int) -> tuple:
        # Ranges (0, size) taken from input 0, except on the concat axis,
        # which spans inputs first..last-1 of the stack.
        axis = concat_axis()
        ranges = [(0, split[0][0]), (0, split[0][1]), (0, split[0][2])]
        ranges[axis] = (
            sum(split[i][axis] for i in range(first)),
            sum(split[i][axis] for i in range(last)),
        )
        return tuple(ranges)

    empty_bounds = ((0, 0), (0, 0), (0, 0))

    # Buffer size per column/phase.
    mt_ifm_mem = [[] for _ in range(aie_cols)]
    mem_empty = None
    for col in range(aie_cols):
        for p in range(total_phase):
            split = phase_sizes(col, p)
            if has_data(split):
                mem = tuple(
                    stop for _, stop in stacked_bounds(split, 0, num_inputs)
                )
                if col == 0 and p == 0:
                    # Column 0 / phase 0 is the stand-in for degenerate phases.
                    mem_empty = mem
            else:
                mem = mem_empty
            if any(m <= 0 for m in mem):
                mem = mem_empty
            mt_ifm_mem[col].append(mem)

    # Where each input lands inside the stacked buffer.
    mt_ifm_s2mm = [[[] for _ in range(aie_cols)] for _ in range(num_inputs)]
    for n in range(num_inputs):
        for col in range(aie_cols):
            for p in range(total_phase):
                split = phase_sizes(col, p)
                if has_data(split):
                    s2mm = stacked_bounds(split, n, n + 1)
                else:
                    s2mm = empty_bounds
                mt_ifm_s2mm[n][col].append(s2mm)

    # Read-out of the complete stacked buffer.
    mt_ifm_mm2s = [[] for _ in range(aie_cols)]
    for col in range(aie_cols):
        for p in range(total_phase):
            split = phase_sizes(col, p)
            if has_data(split):
                mm2s = stacked_bounds(split, 0, num_inputs)
            else:
                mm2s = empty_bounds
            mt_ifm_mm2s[col].append(mm2s)

    return {
        'mt_ifm_mem': mt_ifm_mem,
        'mt_ifm_s2mm': mt_ifm_s2mm,
        'mt_ifm_mm2s': mt_ifm_mm2s,
    }


def generate_mt_ifm_split_kernel(
    aie_cols: int, aie_rows: int,
    concat_mode: int, num_inputs: int,
    Yis: list, Xis: list, Cis: list,
    shim_ifm: list,
    ):
    """Build the memtile IFM layout where every input keeps its own buffer.

    Args:
        aie_cols: Number of AIE columns; must equal ``len(shim_ifm[n])``.
        aie_rows: Number of AIE rows (one read-out window per row).
        concat_mode: 0 tiles X across the rows, 1 tiles C across the rows.
            Any other value only prints a message and produces no read-out
            entries.
        num_inputs: Number of inputs.
        Yis: Y sub-volume size of each input.
        Xis: X sub-volume size of each input.
        Cis: C sub-volume size of each input.
        shim_ifm: ``shim_ifm[input][col][phase] == [Y, X, C]`` where each
            item is a ``(start, stop)`` pair.

    Returns:
        A dict with
        * ``'mt_ifm_mem'``: ``[input][col][phase]`` -> (Y, X, C) buffer size;
          a size containing a zero is replaced by
          ``(Yis[n], Xis[n], Cis[n])``.
        * ``'mt_ifm_s2mm'``: ``[input][col][phase]`` -> ranges
          ``(0, size)`` of the incoming transfer.
        * ``'mt_ifm_mm2s'``: ``[input][col][row][phase]`` -> one sub-volume
          sized window per row. A window that would start at or beyond the
          real extent is moved back to the last full window (or to 0 when
          the extent is 0).
    """
    # (Y, X, C) size of every shim transfer: mt_ifm[input][col][phase].
    mt_ifm = [[[] for _ in range(aie_cols)] for _ in range(num_inputs)]
    for idx_input, per_input in enumerate(shim_ifm):
        assert aie_cols == len(per_input)
        for idx_col, col_phases in enumerate(per_input):
            for extent in col_phases:
                mt_ifm[idx_input][idx_col].append((
                    extent[0][1] - extent[0][0],
                    extent[1][1] - extent[1][0],
                    extent[2][1] - extent[2][0],
                ))
    total_phase = len(mt_ifm[0][0])

    def row_window(row: int, tile: int, real: int) -> tuple:
        # Window of `tile` elements for `row`, kept inside the real extent.
        start = row * tile
        if start >= real:
            start = 0 if real == 0 else real - tile
        return (start, start + tile)

    # Buffer size per input/column/phase.
    mt_ifm_mem = [[[] for _ in range(aie_cols)] for _ in range(num_inputs)]
    mem_empty = None
    for n in range(num_inputs):
        for col in range(aie_cols):
            sizes = mt_ifm[n][col]
            for p in range(total_phase):
                if sizes:
                    mem = sizes[p]
                    if col == 0 and p == 0:
                        mem_empty = mem
                else:
                    mem = mem_empty
                if any(m == 0 for m in mem):
                    # Degenerate transfer: still reserve one full sub-volume.
                    mem = (Yis[n], Xis[n], Cis[n])
                mt_ifm_mem[n][col].append(mem)

    # Incoming transfer ranges per input/column/phase.
    mt_ifm_s2mm = [[[] for _ in range(aie_cols)] for _ in range(num_inputs)]
    for n in range(num_inputs):
        for col in range(aie_cols):
            sizes = mt_ifm[n][col]
            for p in range(total_phase):
                if sizes:
                    s2mm = ((0, sizes[p][0]), (0, sizes[p][1]), (0, sizes[p][2]))
                else:
                    s2mm = ((0, 0), (0, 0), (0, 0))
                mt_ifm_s2mm[n][col].append(s2mm)

    # Per-row read-out windows.
    mt_ifm_mm2s = [
        [[[] for _ in range(aie_rows)] for _ in range(aie_cols)]
        for _ in range(num_inputs)
    ]
    subv = (Yis, Xis, Cis)
    for col in range(aie_cols):
        for p in range(total_phase):
            split = [
                s[col][p] if len(s[col]) > 0 and len(s[col][p]) > 0 else []
                for s in mt_ifm
            ]
            has_data = any(len(inner) != 0 for inner in split)
            for n in range(num_inputs):
                # tiled_axis is split across the rows, whole_axis is read fully.
                if concat_mode == 0:
                    tiled_axis, whole_axis = 1, 2
                elif concat_mode == 1:
                    tiled_axis, whole_axis = 2, 1
                else:
                    print("placeholder or Dont support!")
                    continue
                tile = subv[tiled_axis][n]
                whole = subv[whole_axis][n]
                if has_data:
                    real = split[n][tiled_axis]
                    whole = max(split[n][whole_axis], whole)
                else:
                    # No transfer in this phase: behave like one full tile.
                    real = tile
                ranges = [(0, Yis[n]), None, None]
                ranges[whole_axis] = (0, whole)
                for row in range(aie_rows):
                    ranges[tiled_axis] = row_window(row, tile, real)
                    mt_ifm_mm2s[n][col][row].append(tuple(ranges))

    return {
        'mt_ifm_mem': mt_ifm_mem,
        'mt_ifm_s2mm': mt_ifm_s2mm,
        'mt_ifm_mm2s': mt_ifm_mm2s,
    }


def generate_mt_ifm_split_only_qdq(
    aie_cols: int, aie_rows: int,
    concat_mode: int, num_inputs: int,
    Yis: list, Xis: list, Cis: list,
    shim_ifm: list,
    ):
    """Build the memtile IFM layout: one stacked buffer, read out per row.

    The inputs are stacked along C (``concat_mode == 0``) or along X
    (``concat_mode == 1``) exactly as in the single-buffer layout; the
    read-out is then tiled across the AIE rows along the other axis using
    the sub-volume size of input 0.

    Args:
        aie_cols: Number of AIE columns; must equal ``len(shim_ifm[n])``.
        aie_rows: Number of AIE rows (one read-out window per row).
        concat_mode: 0 stacks along C and tiles X; 1 stacks along X and
            tiles C.
        num_inputs: Number of inputs to stack.
        Yis: Y sub-volume size of each input.
        Xis: X sub-volume size of each input.
        Cis: C sub-volume size of each input.
        shim_ifm: ``shim_ifm[input][col][phase] == [Y, X, C]`` where each
            item is a ``(start, stop)`` pair.

    Returns:
        A dict with
        * ``'mt_ifm_mem'``: ``[col][phase]`` -> (Y, X, C) size of the stacked
          buffer; a size with a non-positive dimension is replaced by the
          size of column 0 / phase 0.
        * ``'mt_ifm_s2mm'``: ``[input][col][phase]`` -> ranges at which that
          input is written into the stacked buffer.
        * ``'mt_ifm_mm2s'``: ``[col][row][phase]`` -> per-row window; the
          stacked axis spans the sum of the inputs' sub-volume sizes.

    Raises:
        ValueError: If ``concat_mode`` is neither 0 nor 1 and at least one
            phase has to be laid out (previously this surfaced as an
            ``UnboundLocalError``).
    """
    def concat_axis() -> int:
        # Index, inside a (Y, X, C) size, of the axis the inputs are stacked on.
        if concat_mode == 0:
            return 2
        if concat_mode == 1:
            return 1
        raise ValueError(f"unsupported concat_mode: {concat_mode!r}")

    # (Y, X, C) size of every shim transfer: mt_ifm[input][col][phase].
    mt_ifm = [[[] for _ in range(aie_cols)] for _ in range(num_inputs)]
    for idx_input, per_input in enumerate(shim_ifm):
        assert aie_cols == len(per_input)
        for idx_col, col_phases in enumerate(per_input):
            for extent in col_phases:
                mt_ifm[idx_input][idx_col].append((
                    extent[0][1] - extent[0][0],
                    extent[1][1] - extent[1][0],
                    extent[2][1] - extent[2][0],
                ))
    total_phase = len(mt_ifm[0][0])

    def phase_sizes(col: int, p: int) -> list:
        # Size of every input for this column/phase; [] when the column has none.
        return [
            s[col][p] if len(s[col]) > 0 and len(s[col][p]) > 0 else []
            for s in mt_ifm
        ]

    def has_data(split: list) -> bool:
        return any(len(inner) != 0 for inner in split)

    def stacked_bounds(split: list, first: int, last: int) -> tuple:
        # Ranges (0, size) taken from input 0, except on the concat axis,
        # which spans inputs first..last-1 of the stack.
        axis = concat_axis()
        ranges = [(0, split[0][0]), (0, split[0][1]), (0, split[0][2])]
        ranges[axis] = (
            sum(split[i][axis] for i in range(first)),
            sum(split[i][axis] for i in range(last)),
        )
        return tuple(ranges)

    def row_window(row: int, tile: int, real: int) -> tuple:
        # Window of `tile` elements for `row`, kept inside the real extent.
        start = row * tile
        if start >= real:
            start = 0 if real == 0 else real - tile
        return (start, start + tile)

    # Buffer size per column/phase.
    mt_ifm_mem = [[] for _ in range(aie_cols)]
    mem_empty = None
    for col in range(aie_cols):
        for p in range(total_phase):
            split = phase_sizes(col, p)
            if has_data(split):
                mem = tuple(
                    stop for _, stop in stacked_bounds(split, 0, num_inputs)
                )
                if col == 0 and p == 0:
                    # Column 0 / phase 0 is the stand-in for degenerate phases.
                    mem_empty = mem
            else:
                mem = mem_empty
            if any(m <= 0 for m in mem):
                mem = mem_empty
            mt_ifm_mem[col].append(mem)

    # Where each input lands inside the stacked buffer.
    mt_ifm_s2mm = [[[] for _ in range(aie_cols)] for _ in range(num_inputs)]
    for n in range(num_inputs):
        for col in range(aie_cols):
            for p in range(total_phase):
                split = phase_sizes(col, p)
                if has_data(split):
                    s2mm = stacked_bounds(split, n, n + 1)
                else:
                    s2mm = ((0, 0), (0, 0), (0, 0))
                mt_ifm_s2mm[n][col].append(s2mm)

    # Per-row read-out windows over the stacked buffer.
    mt_ifm_mm2s = [[[] for _ in range(aie_rows)] for _ in range(aie_cols)]
    subv = (Yis, Xis, Cis)
    for col in range(aie_cols):
        for p in range(total_phase):
            split = phase_sizes(col, p)
            axis = concat_axis()
            tiled_axis = 3 - axis  # the other one of X (1) / C (2)
            tile = subv[tiled_axis][0]
            # Without a transfer in this phase, behave like one full tile.
            real = split[0][tiled_axis] if has_data(split) else tile
            ranges = [(0, Yis[0]), None, None]
            ranges[axis] = (0, sum(subv[axis][i] for i in range(num_inputs)))
            for row in range(aie_rows):
                ranges[tiled_axis] = row_window(row, tile, real)
                mt_ifm_mm2s[col][row].append(tuple(ranges))

    return {
        'mt_ifm_mem': mt_ifm_mem,
        'mt_ifm_s2mm': mt_ifm_s2mm,
        'mt_ifm_mm2s': mt_ifm_mm2s,
    }


def generate_mt_ofm_split(
    aie_cols: int, aie_rows: int,
    concat_mode: int, num_inputs: int,
    Yis: list, Xis: list, Cis: list, Cos: int,
    is_kernel: bool,
    shim_ifm: list,
    ):
    """Build the memtile OFM layout: per-row writes, one combined read-out.

    Args:
        aie_cols: Number of AIE columns; must equal ``len(shim_ifm[n])``.
        aie_rows: Number of AIE rows (one write window per row).
        concat_mode: 0 concatenates along C (rows tile X); 1 concatenates
            along X (rows tile C).
        num_inputs: Number of concatenated inputs.
        Yis: Y sub-volume size of each input (only index 0 is read).
        Xis: X sub-volume size of each input.
        Cis: C sub-volume size of each input.
        Cos: Output channel count; used for the C extent in mode 0 when
            ``is_kernel`` is true.
        is_kernel: Selects ``Cos`` instead of the summed input sizes as the
            C extent in mode 0.
        shim_ifm: ``shim_ifm[input][col][phase] == [Y, X, C]`` where each
            item is a ``(start, stop)`` pair.

    Returns:
        A dict with
        * ``'mt_ofm_mem'``: ``[col][phase]`` -> (Y, X, C) buffer size; a size
          with a non-positive dimension is replaced by the size of
          column 0 / phase 0.
        * ``'mt_ofm_s2mm'``: ``[col][row][phase]`` -> window written by that
          row; a column without transfers reuses column 0's entry.
        * ``'mt_ofm_mm2s'``: ``[col][phase]`` -> ranges read back out.

    Raises:
        ValueError: If ``concat_mode`` is neither 0 nor 1 and at least one
            phase has to be laid out (previously this surfaced as an
            ``UnboundLocalError``).
    """
    def unsupported() -> ValueError:
        return ValueError(f"unsupported concat_mode: {concat_mode!r}")

    # (Y, X, C) size of every shim transfer: mt_ifm[input][col][phase].
    mt_ifm = [[[] for _ in range(aie_cols)] for _ in range(num_inputs)]
    for idx_input, per_input in enumerate(shim_ifm):
        assert aie_cols == len(per_input)
        for idx_col, col_phases in enumerate(per_input):
            for extent in col_phases:
                mt_ifm[idx_input][idx_col].append((
                    extent[0][1] - extent[0][0],
                    extent[1][1] - extent[1][0],
                    extent[2][1] - extent[2][0],
                ))
    total_phase = len(mt_ifm[0][0])

    def phase_sizes(col: int, p: int) -> list:
        # Size of every input for this column/phase; [] when the column has none.
        return [
            s[col][p] if len(s[col]) > 0 and len(s[col][p]) > 0 else []
            for s in mt_ifm
        ]

    def has_data(split: list) -> bool:
        return any(len(inner) != 0 for inner in split)

    def summed(split: list, axis: int):
        # Combined size of all inputs along `axis`.
        return sum(split[i][axis] for i in range(num_inputs))

    # Buffer size per column/phase.
    mt_ofm_mem = [[] for _ in range(aie_cols)]
    mem_empty = None
    for col in range(aie_cols):
        for p in range(total_phase):
            split = phase_sizes(col, p)
            if has_data(split):
                if concat_mode == 0:
                    depth = Cos if is_kernel else summed(split, 2)
                    mem = (Yis[0], aie_rows * Xis[0], depth)
                elif concat_mode == 1:
                    mem = (Yis[0], summed(split, 1), aie_rows * Cis[0])
                else:
                    raise unsupported()
                if col == 0 and p == 0:
                    # Column 0 / phase 0 is the stand-in for degenerate phases.
                    mem_empty = mem
            else:
                mem = mem_empty
            if any(m <= 0 for m in mem):
                mem = mem_empty
            mt_ofm_mem[col].append(mem)

    # Per-row write windows.
    mt_ofm_s2mm = [[[] for _ in range(aie_rows)] for _ in range(aie_cols)]
    for col in range(aie_cols):
        for p in range(total_phase):
            split = phase_sizes(col, p)
            if not has_data(split):
                # Nothing arrives in this column: mirror column 0's windows.
                for row in range(aie_rows):
                    mt_ofm_s2mm[col][row].append(mt_ofm_s2mm[0][row][p])
            elif concat_mode == 0:
                if is_kernel:
                    c_stop = Cos
                else:
                    c_stop = max(
                        sum(Cis[i] for i in range(num_inputs)),
                        summed(split, 2),
                    )
                for row in range(aie_rows):
                    x_start = row * Xis[0]
                    mt_ofm_s2mm[col][row].append(
                        ((0, Yis[0]), (x_start, x_start + Xis[0]), (0, c_stop))
                    )
            elif concat_mode == 1:
                x_stop = max(
                    sum(Xis[i] for i in range(num_inputs)),
                    summed(split, 1),
                )
                for row in range(aie_rows):
                    c_start = row * Cis[0]
                    mt_ofm_s2mm[col][row].append(
                        ((0, Yis[0]), (0, x_stop), (c_start, c_start + Cis[0]))
                    )
            else:
                raise unsupported()

    # Combined read-out per column/phase.
    mt_ofm_mm2s = [[] for _ in range(aie_cols)]
    for col in range(aie_cols):
        for p in range(total_phase):
            split = phase_sizes(col, p)
            if not has_data(split):
                mm2s = ((0, 0), (0, 0), (0, 0))
            elif concat_mode == 0:
                c_stop = Cos if is_kernel else summed(split, 2)
                mm2s = ((0, split[0][0]), (0, split[0][1]), (0, c_stop))
            elif concat_mode == 1:
                mm2s = ((0, split[0][0]), (0, summed(split, 1)), (0, split[0][2]))
            else:
                raise unsupported()
            mt_ofm_mm2s[col].append(mm2s)

    return {
        'mt_ofm_mem': mt_ofm_mem,
        'mt_ofm_s2mm': mt_ofm_s2mm,
        'mt_ofm_mm2s': mt_ofm_mm2s,
    }


def gen_transfers(dims: ConcatDims):
    """Compute the shim, memtile-IFM and memtile-OFM transfer descriptions.

    The per-core slices are generated first, folded into shim transfers, and
    the shim IFM ranges then drive both memtile layouts. The memtile IFM
    variant is chosen from ``dims.is_kernel`` and ``dims.is_qdq``.

    Returns:
        ``(mt_ifm_transfer, mt_ofm_transfer, shim_transfer)``.
    """
    grid = (dims.aie_cols, dims.aie_rows)

    kernel_subv = generate_split(
        *grid,
        dims.concat_mode,
        dims.input_rows, dims.input_cols, dims.input_chs_p,
        dims.Yis, dims.Xis, dims.Cis, dims.Cos,
    )
    shim_transfer = generate_shim_split(
        *grid,
        dims.input_rows, dims.input_cols, dims.input_chs_p, dims.Cos,
        dims.num_inputs, dims.concat_mode, kernel_subv,
    )
    shim_ifm = shim_transfer['shim_ifm']

    if dims.is_kernel:
        # Every input keeps its own memtile buffer.
        mt_ifm_transfer = generate_mt_ifm_split_kernel(
            *grid,
            dims.concat_mode, dims.num_inputs,
            dims.Yis, dims.Xis, dims.Cis,
            shim_ifm,
        )
    elif dims.is_qdq:
        mt_ifm_transfer = generate_mt_ifm_split_only_qdq(
            *grid,
            dims.concat_mode, dims.num_inputs,
            dims.Yis, dims.Xis, dims.Cis,
            shim_ifm,
        )
    else:
        mt_ifm_transfer = generate_mt_ifm_split_no_kernel(
            *grid,
            dims.concat_mode, dims.num_inputs,
            shim_ifm,
        )

    mt_ofm_transfer = generate_mt_ofm_split(
        *grid,
        dims.concat_mode, dims.num_inputs,
        dims.Yis, dims.Xis, dims.Cis, dims.Cos,
        dims.is_kernel,
        shim_ifm,
    )
    return mt_ifm_transfer, mt_ofm_transfer, shim_transfer


def concat_preproc_directives(dims: ConcatDims, back_end: BackEnd) -> List[str]:
    """Dump the input shapes to ``shapes.txt`` and return the ``-D`` defines.

    Side effect: (re)writes ``shapes.txt`` in the current working directory
    with four comma-separated lines: original rows, original cols, channels
    and input types.

    Returns:
        One preprocessor define per setting. For ``BackEnd.Adf`` each define
        is wrapped as ``--Xpreproc="-D<name>=<value>"``; otherwise it is the
        plain ``-D<name>=<value>``.
    """
    with open("shapes.txt", "w") as file:
        for field in ("input_rows_orig", "input_cols_orig", "input_chs", "input_types"):
            file.write(",".join(map(str, getattr(dims, field))) + "\n")

    def directive(ident: str, val: int) -> str:
        define = f"-D{ident}={val}"
        if back_end == BackEnd.Adf:
            return f'--Xpreproc="{define}"'
        return define

    # TXN mode is on for every back end except Adf.
    txn_mode = int(back_end != BackEnd.Adf)
    # Macro names are emitted verbatim (including the "KENEL" spelling);
    # presumably they must match whatever consumes them -- verify before renaming.
    defines = [
        ("AIE_ROWS", dims.aie_rows),
        ("AIE_COLS", dims.aie_cols),
        ("CONCAT_MODE", dims.concat_mode),
        ("NUM_INPUTS", dims.num_inputs),
        ("IS_KERNEL", int(dims.is_kernel)),
        ("IS_KERNEL_DEPAD", int(dims.is_kernel_depad_available)),
        ("PADDING_EN", int(dims.padding_enable)),
        ("QDQ_MODE", dims.qdq_mode),
        ("INT_16", int(dims.is_int16)),
        ("KENEL_IS_INT16_CONCAT", int(dims.is_int16_concat)),
        ("TXN_MODE", txn_mode),
        # ("IS_CONST_INPUT", int(dims.is_const_input)),
    ]
    return [directive(ident, val) for ident, val in defines]
