import os
# Directory containing this file; used below to locate sibling project packages.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
# Extend the import path relative to this file (not the working directory) so
# the dmacompiler package and the conv kernel helpers (presumably where
# `conv_params` lives) can be imported below. Must run before those imports.
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels', 'conv'))
from typing import List, Tuple, Optional, Type
from conv_params import generate_layer_kernel_params, generate_layer_kernel_params_xint8, generate_add_kernel_params, ConvSubvDims, OPMode
from kernels.conv.direct_conv_int8x8_generic.direct_conv_int8x8_generic_params import params
from typing import Dict

from dmacompiler import \
    CoreInstr, Loop, AcqBuffer, RelBuffer, ConfigBuffer, CallKernel, \
    DmaChannel, DmaDir, BackEnd, \
    set_dev_gen, DevGen, config

from dataflow_common import ceildiv, floordiv
# Module-level side effect at import time: configure dmacompiler for the
# AIE2P device generation.
set_dev_gen(DevGen.Aie2p)

class PingPong:
    """Pair of boolean flags for the ping and pong halves of a double buffer.

    ``str()`` renders the pair as ``"T_F"``-style text (first character of
    each flag's ``str``), ``int()`` counts how many flags are truthy, and
    :meth:`from_str` parses the ``"T_F"`` form back.
    """

    def __init__(self, ping: bool = True, pong: bool = True):
        self.ping = ping
        self.pong = pong

    def __str__(self):
        return "_".join(str(flag)[0] for flag in (self.ping, self.pong))

    def __int__(self):
        # Sum over every instance attribute (normally just ping and pong).
        return sum(int(flag) for flag in vars(self).values())

    @classmethod
    def from_str(cls, s: str):
        """Parse ``"<P>_<Q>"``; a flag is True only when its token is exactly ``'T'``."""
        ping_token, pong_token = s.split('_')
        return cls(ping=ping_token == 'T', pong=pong_token == 'T')

class ConvPingPong:
    """Ping/pong buffering flags for the four conv operands: ifm, ofm, wgt, tdm.

    Each operand holds a ``PingPong``; omitted operands default to
    ``PingPong(True, True)``. ``str()`` renders ``"ifm:T_T ofm:T_T ..."``
    in attribute order.
    """

    def __init__(
            self,
            ifm: Optional[PingPong] = None,
            ofm: Optional[PingPong] = None,
            wgt: Optional[PingPong] = None,
            tdm: Optional[PingPong] = None,
            ):
        # Build a fresh PingPong per instance: a PingPong(...) default argument
        # is evaluated once at definition time, so every ConvPingPong created
        # with defaults would share (and could mutate) the same object.
        self.ifm = PingPong(True, True) if ifm is None else ifm
        self.ofm = PingPong(True, True) if ofm is None else ofm
        self.wgt = PingPong(True, True) if wgt is None else wgt
        self.tdm = PingPong(True, True) if tdm is None else tdm

    def __str__(self):
        return " ".join(f"{key}:{val}" for key, val in vars(self).items())

    def __repr__(self):
        return str(self)

class ConvDims:
    """Geometry, tiling and buffer configuration of one convolution layer.

    Naming: ``Ci``/``Co`` are input/output channels, ``Yi``/``Xi`` and
    ``Yo``/``Xo`` input/output height/width, ``Ky``/``Kx`` kernel size,
    ``Sy``/``Sx`` strides and ``Py_b``/``Px_b``/``Py_a``/``Px_a`` padding
    before/after. An ``s`` suffix (``Cis``, ``Yos`` ...) denotes the per-core
    subvolume, ``*_gran`` a granularity, and ``*_split`` how a dimension is
    divided across AIE rows (``X_split * Co_split == aie_rows`` unless
    ``is_X8_split``). ``*_loop`` values are presumably loop trip counts.

    The constructor checks the consistency of these values with ``assert``
    (the checks vanish under ``python -O``), stores every argument as an
    attribute, and also copies each ``L1_sizes_addrs`` entry onto the instance
    as an attribute (presumably the L1 ``*_addr`` / ``*_subv_size`` values
    read by the core-instruction helpers).
    """

    def __init__(
        self,
        aie_cols: int, aie_rows: int,
        Ci: int, Cis: int, Ci_gran: int, Co: int, Cos: int, Co_gran: int, Co_split: int,
        Yi: int, Yis: int, Yo: int, Yos: int,
        Xi: int, Xis: int, Xo: int, Xos: int, X_gran: int, X_align: int, X_split: int,
        Y_loop: int, Co_loop: int, Ci_loop: int, X_loop: int,
        Ky: int, Kx: int,
        Sy: int, Sx: int,
        Py_b: int, Px_b: int, Py_a: int, Px_a: int,
        ifm_bits: int, wgt_bits: int, ofm_bits: int, tdm_bits: int,
        L1_sizes_addrs: Dict[str, int],
        is_standalone_dwc: bool = False,
        ifm_use_hwc_format: bool = True,
        ofm_use_hwc_format: bool = True,
        conv_pp: Optional[ConvPingPong] = None,
        enable_L2_IFM: bool = False,
        enable_L2_OFM: bool = False,
        load_input_from_ddr = False,
        store_output_to_ddr = False,
        is_X8_split = False,
        enable_ifm_streaming: bool = False,
        enable_wgt_reuse: bool = False,
        pin_ifm_l1: bool = False,
        pin_wgt_bias_l1: bool = False,
        Com: int = 0,
        Xom: int = 0,
        Cim: int = 0,
        wgt_memtile_size: int = 0,
        num_ifm_subv: int = 0,
        prm_memtile_size: int = 0,
        ifm_memtile_size: int = 0,
        ofm_memtile_size: int = 0,
        conv_kernel_param_size: int = 256,
        param_subv_size: int = 1024,
        mt_co_pack: int = 1,
        num_pack_wgt_subv: int = 1,
        enable_add: bool = False,
        is_xint8: int = 0
    ):

        # Standalone depthwise conv works on 4-row Y granules; everything else on single rows.
        Y_gran = 4 if is_standalone_dwc else 1
        # Bytes occupied by one Ci granule (used to align the input subvolume width).
        Ci_block = (Ci_gran * ifm_bits) // 8
        assert Yo == conv_output(Yi, Ky, Sy, Py_b, Py_a)
        assert Xo == conv_output(Xi, Kx, Sx, Px_b, Px_a)
        assert Yos == conv_output(Yis, Ky, Sy, 0, 0)
        # Input subvolume width = receptive field of Xos outputs, padded to X_align bytes.
        assert Xis == iceil(conv_input(Xos, Kx, Sx) * Ci_block, X_align) // Ci_block
        assert (Cis % Ci_gran) == 0
        assert (Cos % Co_gran) == 0 or is_standalone_dwc
        assert (Xos % X_gran) == 0

        assert ((Yos % Y_gran) == 0) or is_standalone_dwc
        if is_standalone_dwc:
            assert Ci == Co
            assert Cis == Cos
        self.aie_cols = aie_cols
        self.aie_rows = aie_rows
        self.Ci = Ci
        self.Cis = Cis
        self.Ci_gran = Ci_gran
        self.Co = Co
        self.Cos = Cos
        self.Co_gran = Co_gran
        self.Co_split = Co_split
        self.Ci_loop = Ci_loop
        self.Co_loop = Co_loop
        assert ((Ci % Cis) == 0) or (Ci < Cis)
        assert (Co % (Co_split * Cos) == 0) or (Co < (Co_split * Cos))
        self.Yi = Yi
        self.Yis = Yis
        self.Yo = Yo
        self.Yos = Yos
        self.Y_loop = Y_loop
        self.X_loop = X_loop
        self.Xi = Xi
        self.Xis = Xis
        self.Xo = Xo
        self.Xos = Xos
        self.X_gran = X_gran
        self.X_align = X_align
        self.X_split = X_split if not is_X8_split else self.X_loop        #X8 dataflow uses Xsplit for X_loop
        if not is_X8_split:
            assert Xo <= self.X_split * Xos
        # Every AIE row must own exactly one (X, Co) tile; the X8 dataflow splits only Co across rows.
        if not is_X8_split:
            assert self.X_split * self.Co_split == self.aie_rows
        else:
            assert self.Co_split == self.aie_rows
        self.Ky = Ky
        self.Kx = Kx
        self.Sy = Sy
        self.Sx = Sx
        self.Py_b = Py_b
        self.Py_a = Py_a
        self.Px_b = Px_b
        self.Px_a = Px_a
        self.ifm_bits = ifm_bits
        self.wgt_bits = wgt_bits
        self.ofm_bits = ofm_bits
        self.tdm_bits = tdm_bits
        self.L1_sizes_addrs = L1_sizes_addrs

        self.enable_ifm_streaming = enable_ifm_streaming
        self.enable_wgt_reuse = enable_wgt_reuse
        self.pin_ifm_l1 = pin_ifm_l1
        self.pin_wgt_bias_l1 = pin_wgt_bias_l1
        self.Com = Com
        self.Xom = Xom
        self.Cim = Cim
        self.wgt_memtile_size = wgt_memtile_size
        self.num_ifm_subv = num_ifm_subv
        self.prm_memtile_size = prm_memtile_size
        self.ifm_memtile_size = ifm_memtile_size
        self.ofm_memtile_size = ofm_memtile_size
        self.conv_kernel_param_size = conv_kernel_param_size
        self.param_subv_size = param_subv_size
        self.mt_co_pack = mt_co_pack
        self.num_pack_wgt_subv = num_pack_wgt_subv

        self.is_standalone_dwc = is_standalone_dwc
        self.ifm_use_hwc_format = ifm_use_hwc_format
        self.ofm_use_hwc_format = ofm_use_hwc_format
        # Build a fresh default per instance: a ConvPingPong() default argument
        # would be a single mutable object shared by every ConvDims.
        self.conv_pp = ConvPingPong() if conv_pp is None else conv_pp
        self.is_X8_split = is_X8_split
        self.enable_add = enable_add
        self.is_xint8 = is_xint8

        # Expose each L1 size/address entry as a direct attribute (e.g. dims.ifm_ping_addr).
        for key, value in self.L1_sizes_addrs.items():
            setattr(self, key, value)

        self.shim_BD_num = {'ifm': 10, 'wgt': 1, 'ofm': 4, 'prm': 1}

        #NOTE: L2_fusion struct is updated with temporary values and it is used only by L2 dataflow.
        self.L2_fusion = {
                "enable_ifm_L2_fusion" : enable_L2_IFM,
                "input_addr" : 0,        #NOTE DUMMY addr
                "load_input_from_ddr"  : load_input_from_ddr,
                "enable_ofm_L2_fusion" : enable_L2_OFM,
                "store_output_to_ddr"  : store_output_to_ddr,
                "ifm_L2_hwc_format" : False,
                "ofm_L2_hwc_format" : False,
                "stride_mode": True
                }

    def __str__(self):
        """Return one ``name : value`` line per attribute, preceded by a newline."""
        return "\n" + "\n".join(f"{key} : {val}" for key, val in vars(self).items())

class CoreAlloc:
    """Plain record of the L1 buffer addresses used by one AIE core.

    Attributes mirror the constructor parameters one-to-one and are set in
    parameter order (``vars()`` of an instance lists them in that order).
    ``None`` for a ``*_pong_addr`` means there is no second buffer.
    """

    def __init__(
        self,
        ifm_ping_addr: int,
        ifm_pong_addr: Optional[int],
        wgt_ping_addr: int,
        wgt_pong_addr: Optional[int],
        ofm_ping_addr: int,
        ofm_pong_addr: Optional[int],
        tdm_ping_addr: int,
        tdm_pong_addr: int,
        ifm_sum_addr: int,
        scratch_buf: int,
        tmp_buf : int,
        conv_kernelprm_addr: Optional[int],
        add_ifm_addr: Optional[int],
    ):
        # Targets are bound left to right, preserving attribute order.
        (
            self.ifm_ping_addr, self.ifm_pong_addr,
            self.wgt_ping_addr, self.wgt_pong_addr,
            self.ofm_ping_addr, self.ofm_pong_addr,
            self.tdm_ping_addr, self.tdm_pong_addr,
            self.ifm_sum_addr, self.scratch_buf, self.tmp_buf,
            self.conv_kernelprm_addr, self.add_ifm_addr,
        ) = (
            ifm_ping_addr, ifm_pong_addr,
            wgt_ping_addr, wgt_pong_addr,
            ofm_ping_addr, ofm_pong_addr,
            tdm_ping_addr, tdm_pong_addr,
            ifm_sum_addr, scratch_buf, tmp_buf,
            conv_kernelprm_addr, add_ifm_addr,
        )

def iceil(x: int, d: int) -> int:
    """Round ``x`` up to a multiple of ``d`` (via ``ceildiv``)."""
    return d * ceildiv(x, d)

def ifloor(x: int, d: int) -> int:
    """Round ``x`` down to a multiple of ``d`` (via ``floordiv``)."""
    return d * floordiv(x, d)

def conv_output(input: int, kernel: int, stride: int, pad_before: int, pad_after: int) -> int:
    """Number of outputs of a 1-D convolution along one axis.

    ``floor((input + pad_before + pad_after - kernel) / stride) + 1``
    (no dilation).
    """
    padded_extent = input + pad_before + pad_after
    return (padded_extent - kernel) // stride + 1

def conv_input(output: int, kernel: int, stride: int) -> int:
    """Input extent needed to produce ``output`` samples (no padding or dilation).

    Inverse of ``conv_output`` for the unpadded case: ``(output - 1) * stride + kernel``.
    """
    span_between_windows = (output - 1) * stride
    return span_between_windows + kernel

def X_index(dims: ConvDims, row: int) -> int:
    """Return the X-split position of AIE ``row``; rows cycle through ``X_split`` positions."""
    assert 0 <= row < dims.aie_rows
    _, x_position = divmod(row, dims.X_split)
    return x_position

def Co_index(dims: ConvDims, row: int) -> int:
    """Return the Co-split group of AIE ``row``; consecutive rows share a group."""
    assert 0 <= row < dims.aie_rows
    rows_per_group = dims.aie_rows // dims.Co_split
    return row // rows_per_group

def Xi_slice(dims: ConvDims, row: int) -> Tuple[int, int, int, int]:
    """Return ``(start, stop, stride, size)`` of the input columns read by AIE ``row``.

    The start is shifted left by the leading padding ``Px_b`` (so it can be
    negative). The stop covers the receptive field of ``Xos`` outputs,
    clipped to ``Xi + Px_a``. A start beyond ``Xi`` gives an empty range.
    """
    stride = dims.Xos * dims.Sx
    start = X_index(dims, row) * stride - dims.Px_b
    if start > dims.Xi:
        stop = start
    else:
        window = conv_input(dims.Xos, dims.Kx, dims.Sx)
        stop = min(start + window, dims.Xi + dims.Px_a)
    return (start, stop, stride, stop - start)

def Xo_slice(dims: ConvDims, row: int) -> Tuple[int, int, int, int]:
    """Return ``(start, stop, stride, size)`` of the output columns produced by AIE ``row``.

    Each X-split position owns ``Xos`` columns starting at ``X_index * Xos``.
    The range is clipped to ``Xo`` and is empty when it starts at or past ``Xo``.
    """
    stride = dims.Xos
    start = X_index(dims, row) * stride
    stop = start if start >= dims.Xo else min(start + stride, dims.Xo)
    return (start, stop, stride, stop - start)

def Co_slice(dims: ConvDims, row: int) -> Tuple[int, int, int, int]:
    """Return ``(start, stop, stride, size)`` of the output channels produced by AIE ``row``.

    Each Co-split group owns ``Cos`` channels starting at ``Co_index * Cos``.
    The stride is ``Cos * Co_split``. The range is clipped to ``Co`` and is
    empty (size 0) when it starts at or past ``Co``.
    """
    Co_stride = dims.Cos * dims.Co_split
    Co_start = Co_index(dims, row) * dims.Cos
    # ConvDims allows Co < Co_split * Cos, so a group's start can lie past Co.
    # Clamp to an empty slice (as Xo_slice does) instead of yielding a
    # negative size that would leak into DMA access-pattern strings.
    Co_stop = (
        min(Co_start + dims.Cos, dims.Co) if Co_start < dims.Co else
        Co_start
    )
    Co_size = Co_stop - Co_start
    return (Co_start, Co_stop, Co_stride, Co_size)

def ifm_core_memory(dims: ConvDims) -> str:
    """Describe the L1 IFM buffer layout as ``Yi:<Yis> Ci:<Cis> Xi:<Xis> Ci:<Ci_gran>``."""
    return " ".join([
        f"Yi:{dims.Yis}",
        f"Ci:{dims.Cis}",
        f"Xi:{dims.Xis}",
        f"Ci:{dims.Ci_gran}",
    ])

def ifm_core_s2mm(dims: ConvDims, row: int) -> str:
    """Build the IFM S2MM access pattern for AIE ``row``.

    The X8 split dataflow always transfers the full ``Xis`` width.
    Otherwise the width is the size of this row's ``Xi_slice``.
    """
    width = dims.Xis if dims.is_X8_split else Xi_slice(dims, row)[3]
    return f'Yi:0:{dims.Yis} Ci:0:{dims.Cis}:{dims.Ci_gran} Xi:0:{width} Ci:0:{dims.Ci_gran}'

def ofm_core_memory(dims: ConvDims) -> str:
    """Describe the L1 OFM buffer layout as ``Yo:<Yos> Co:<Cos> Xo:<Xos> Co:<Co_gran>``."""
    return " ".join([
        f"Yo:{dims.Yos}",
        f"Co:{dims.Cos}",
        f"Xo:{dims.Xos}",
        f"Co:{dims.Co_gran}",
    ])

def ofm_core_mm2s(dims: ConvDims, row: int) -> str:
    """Build the OFM MM2S access pattern for AIE ``row``, sized by its X and Co slices."""
    width = Xo_slice(dims, row)[3]
    channels = Co_slice(dims, row)[3]
    return f'Yo:0:{dims.Yos} Co:0:{channels}:{dims.Co_gran} Xo:0:{width} Co:0:{dims.Co_gran}'

def conv_preproc_directives(dims: ConvDims, back_end: BackEnd) -> List[str]:
    """Return the ``-D`` preprocessor defines describing this conv layer.

    For ``BackEnd.Adf`` each define is wrapped as ``--Xpreproc="-DNAME=VAL"``;
    every other back end gets a plain ``-DNAME=VAL``. ``TXN_MODE`` is 1 for
    non-ADF back ends and 0 for ADF.
    """
    defines = (
        ('AIE_ROWS', dims.aie_rows),
        ('AIE_COLS', dims.aie_cols),
        ('C_IN', dims.Ci),
        ('C_IN_SUBV', dims.Cis),
        ('XIS', dims.Xis),
        ('YIS', dims.Yis),
        ('Y_IN', dims.Yi),
        ('X_IN', dims.Xi),
        ('C_OUT', dims.Co),
        ('C_OUT_SUBV', dims.Cos),
        ('XOS', dims.Xos),
        ('YOS', dims.Yos),
        ('C_OUT_SPLIT', dims.Co_split),
        ('Y_OUT', dims.Yo),
        ('X_OUT', dims.Xo),
        ('KERNEL_Y', dims.Ky),
        ('KERNEL_X', dims.Kx),
        ('STRIDE_Y', dims.Sy),
        ('STRIDE_X', dims.Sx),
        ('PAD_Y_BEFORE', dims.Py_b),
        ('PAD_X_BEFORE', dims.Px_b),
        ('PAD_Y_AFTER', dims.Py_a),
        ('PAD_X_AFTER', dims.Px_a),
        ('IFM_IS_HWC', int(dims.ifm_use_hwc_format)),
        ('OFM_IS_HWC', int(dims.ofm_use_hwc_format)),
        ('TXN_MODE', int(back_end != BackEnd.Adf)),
        ('DIRECT_CONV_INT16X8_GENERIC_HAS_CONV', 1),
        ('DIRECT_CONV_INT16X8_GENERIC_HAS_DWC', 1),
    )
    if back_end == BackEnd.Adf:
        return [f'--Xpreproc="-D{name}={value}"' for name, value in defines]
    return [f"-D{name}={value}" for name, value in defines]

def conv_core_alloc(dims: ConvDims, stack_addr: int) -> CoreAlloc:
    """Build a ``CoreAlloc`` from the same-named L1 address attributes of ``dims``.

    The result is printed before being returned. ``stack_addr`` is accepted
    but currently unused.
    """
    field_names = (
        'ifm_ping_addr', 'ifm_pong_addr',
        'wgt_ping_addr', 'wgt_pong_addr',
        'ofm_ping_addr', 'ofm_pong_addr',
        'tdm_ping_addr', 'tdm_pong_addr',
        'ifm_sum_addr', 'scratch_buf',
        'tmp_buf',
        'conv_kernelprm_addr', 'add_ifm_addr',
    )
    core_alloc = CoreAlloc(**{name: getattr(dims, name) for name in field_names})
    print(f"core_alloc = {vars(core_alloc)}")
    return core_alloc

def conv_a16w8_qdq_kernel_name() -> str:
    """Kernel name passed to ``CallKernel`` for the 16-bit-activation / 8-bit-weight QDQ conv."""
    kernel_name = "run_conv_a16w8_qdq"
    return kernel_name

def conv_a8w8_qdq_kernel_name() -> str:
    """Kernel name passed to ``CallKernel`` for the 8-bit-activation / 8-bit-weight QDQ conv."""
    kernel_name = "run_conv_a8w8_qdq"
    return kernel_name

def conv_xint8_kernel_name() -> str:
    """Kernel name passed to ``CallKernel`` for the xint8 conv variant."""
    kernel_name = "run_conv_xint8"
    return kernel_name

def add_kernel_name() -> str:
    """Kernel name passed to ``CallKernel`` for the element-wise add (matadd) kernel."""
    kernel_name = "run_matadd"
    return kernel_name

def conv_xint8_params(
    input: Tuple[int, int, int],
    output: Tuple[int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    is_depthwise: bool,
    first_tdm: bool,
    final_tdm: bool,
    core_alloc: CoreAlloc,
    ifm_flag: bool,
    Tk: int
) -> bytes:
    """Build the layer-parameter blob for the xint8 conv kernel.

    Args:
        input: ``(Cis, Yis, Xis)`` input subvolume; only ``Cis`` is used.
        output: ``(Cos, Yos, Xos)`` output subvolume.
        kernel: ``(Ky, Kx)``.
        stride: ``(Sy, Sx)``.
        is_depthwise: selects ``OPMode.OP_DWC_ASYM`` instead of ``OPMode.OP_CONV_ASYM``.
        first_tdm, final_tdm: forwarded unchanged.
        core_alloc: provides the tdm ping/pong, kernel-param, OFM ping and IFM addresses.
        ifm_flag: forwarded unchanged.
        Tk: odd values select the IFM ping address, even values the pong
            address (ping is always used when there is no pong buffer).

    Returns:
        The blob produced by ``generate_layer_kernel_params_xint8``.
    """
    # TODO fetch from kernel subvol_constraints
    Y_gran = 1
    X_gran = 8
    Co_gran = 8
    Ci_gran = 8
    size_bytes = 2
    stride_efficiency = 0.5
    mem_align = 64
    Cis, _, _ = input
    Cos, Yos, Xos = output
    Ky, Kx = kernel
    Sy, Sx = stride
    op_mode = OPMode.OP_DWC_ASYM if is_depthwise else OPMode.OP_CONV_ASYM
    # Alternate IFM buffers by Tk parity; a single-buffered IFM (no pong) always uses ping.
    ifm_addr = (
        core_alloc.ifm_ping_addr
        if (Tk % 2 != 0 or core_alloc.ifm_pong_addr is None)
        else core_alloc.ifm_pong_addr
    )
    params_blob = generate_layer_kernel_params_xint8(
        first_tdm,
        final_tdm,
        core_alloc.tdm_ping_addr,
        core_alloc.tdm_pong_addr,
        core_alloc.conv_kernelprm_addr,
        core_alloc.ofm_ping_addr,
        ifm_flag,
        ConvSubvDims(
            1,
            Yos, Y_gran,
            Xos, X_gran,
            Cos, Co_gran,
            Cis, Ci_gran,
            Ky, Kx,
            Sy, Sx,
            op_mode,
            size_bytes,
            stride_efficiency,
            mem_align,
        ),
        ifm_addr,
    )
    return params_blob

def conv_a16w8_qdq_params(
    input: Tuple[int, int, int],
    output: Tuple[int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    is_depthwise: bool,
    first_tdm: bool,
    final_tdm: bool,
    tdm_1_addr: int,
    tdm_2_addr: int,
    ifm_sum_addr: int,
    scratch_buf: int,
    tmp_buf: int,
) -> bytes:
    """Build the layer-parameter blob for the a16w8 QDQ conv kernel.

    ``input`` is ``(Cis, Yis, Xis)``, of which only ``Cis`` is used.
    ``output`` is ``(Cos, Yos, Xos)``. ``kernel``/``stride`` are
    ``(Ky, Kx)``/``(Sy, Sx)``. ``is_depthwise`` selects the DWC ASYM op
    mode instead of the CONV ASYM one. The remaining arguments are
    forwarded unchanged to ``generate_layer_kernel_params``.
    """
    Cis, _, _ = input
    Cos, Yos, Xos = output
    Ky, Kx = kernel
    Sy, Sx = stride

    #TODO fetch from kernel subvol_constraints
    y_gran, x_gran, co_gran, ci_gran = 1, 8, 8, 8
    elem_bytes, stride_eff, align = 2, 0.5, 64

    mode = OPMode.OP_DWC_ASYM if is_depthwise else OPMode.OP_CONV_ASYM
    subv_dims = ConvSubvDims(
        1,
        Yos, y_gran,
        Xos, x_gran,
        Cos, co_gran,
        Cis, ci_gran,
        Ky, Kx,
        Sy, Sx,
        mode,
        elem_bytes,
        stride_eff,
        align,
    )
    return generate_layer_kernel_params(
        first_tdm, final_tdm,
        tdm_1_addr, tdm_2_addr,
        ifm_sum_addr, scratch_buf, tmp_buf,
        subv_dims,
    )

def conv_a8w8_qdq_params(
    dims : ConvDims,
    first_tdm: bool,
    final_tdm: bool,
    tdm_1_addr: int,
    tdm_2_addr: int,
    ifm_sum_addr: int,
    scratch_buf: int,
    op_mode : int
) -> bytes:
    """Build the parameter blob for the a8w8 QDQ conv kernel.

    Layout: a 20-byte little-endian unsigned header followed by the
    ``params(...)`` kernel blob, zero-padded up to 166 bytes.

    Header fields in order, with their byte widths: first_tdm (1),
    final_tdm (1), weight_size (2), Cos (2), tdm_1_addr (2),
    tdm_2_addr (2), ifm_sum_addr (2), scratch_buf (2), op_mode (2), and an
    always-zero dummy (2).

    ``op_mode`` is numeric; together with ``dims.is_standalone_dwc`` it is
    mapped to a mode name (12/14 -> dwc sym/asym, 9/11 -> sym/asym).
    Unmapped pairs give ``None``. Modes 9 and 12 disable ``has_sum``.
    Raises ``OverflowError`` if a header value does not fit its width.
    """
    is_dwc = dims.is_standalone_dwc
    mode_names = {
        (True, 12): "dwc_sym",
        (True, 14): "dwc_asym",
        (False, 9): "sym",
        (False, 11): "asym",
    }
    mode_name = mode_names.get((is_dwc, op_mode))

    templates = {
        "has_dwc": int(is_dwc),
        "has_conv": int(not is_dwc),
        "has_sum": 0 if op_mode in (9, 12) else 1,
        "has_vector_coeffs": 0,
    }
    parameters = {
        "subvolume": {
            "H": dims.Yos,
            "W": dims.Xos,
            "Co": dims.Cos,
            "Ci": dims.Cis,
            "Kh": dims.Ky,
            "Kw": dims.Kx,
            "Sh": dims.Sy,
            "Sw": dims.Sx,
            "Dh": 1,
            "Dw": 1,
        },
        "op_mode": mode_name,
        "dtype": {"I0": "uint8", "I1": "uint8", "O0": "uint8"},
        "transpose": {"I0": 0, "I1": 0},
        "quantization_coeffs": {
            "shift_res": 0,
            "zp_wght": 0,
            "vector_coeffs": 0,
            "qdq_c0": 0,
            "qdq_c1": 0,
            "qdq_c2": 0,
            "qdq_c3": 0,
        },
    }

    # Depthwise weights have no Ci factor; sizes are padded to 64 bytes.
    weight_elems = (
        dims.Cos * dims.Ky * dims.Kx if is_dwc
        else dims.Cos * dims.Cis * dims.Ky * dims.Kx
    )
    weight_size = iceil(weight_elems, 64)

    # Input extent needed for one output subvolume (dilation 1).
    in_h = (dims.Yos - 1) * dims.Sy + dims.Ky
    in_w = (dims.Xos - 1) * dims.Sx + dims.Kx

    kernel_params = params(templates, parameters, in_h, in_w)
    kernel_param_blob = kernel_params + b'\x00' * (166 - len(kernel_params))

    header_fields = (
        (first_tdm, 1),
        (final_tdm, 1),
        (weight_size, 2),
        (dims.Cos, 2),
        (tdm_1_addr, 2),
        (tdm_2_addr, 2),
        (ifm_sum_addr, 2),
        (scratch_buf, 2),
        (op_mode, 2),
        (0, 2),  # always-zero dummy field
    )
    layer_param_blob = b''.join(
        value.to_bytes(length=width, byteorder='little', signed=False)
        for value, width in header_fields
    )
    return layer_param_blob + kernel_param_blob

def add_call_kernel(
    dims: ConvDims,
    core_alloc: CoreAlloc,
) -> CallKernel:
    """Build the ``CallKernel`` for the element-wise add kernel.

    Its parameters are built from the kernel-param address, the OFM ping
    address and ``dims.ofm_subv_size``.
    """
    return CallKernel(
        add_kernel_name(),
        kernel_params=generate_add_kernel_params(
            core_alloc.conv_kernelprm_addr,
            core_alloc.ofm_ping_addr,
            dims.ofm_subv_size,
        ),
    )

def conv_call_kernel(
    dims: ConvDims,
    core_alloc: CoreAlloc,
    first_tdm: bool,
    final_tdm: bool,
    ifm_flag: Optional[bool] = False,
    acc_loop: Optional[int] = 1,
) -> CallKernel:
    """Pick the conv kernel variant for ``dims`` and build its ``CallKernel``.

    Dispatch order:
      1. 16-bit IFM with 8-bit weights -> a16w8 QDQ kernel;
      2. otherwise, if ``dims.is_xint8`` -> xint8 kernel (the only variant
         that uses ``ifm_flag`` and ``acc_loop``);
      3. otherwise -> a8w8 QDQ kernel in ASYM mode (14 for standalone
         depthwise, 11 otherwise).
    """
    if dims.ifm_bits == 16 and dims.wgt_bits == 8:
        return CallKernel(
            conv_a16w8_qdq_kernel_name(),
            kernel_params=conv_a16w8_qdq_params(
                (dims.Cis, dims.Yis, dims.Xis),
                (dims.Cos, dims.Yos, dims.Xos),
                (dims.Ky, dims.Kx),
                (dims.Sy, dims.Sx),
                dims.is_standalone_dwc,
                first_tdm,
                final_tdm,
                core_alloc.tdm_ping_addr,
                core_alloc.tdm_pong_addr,
                core_alloc.ifm_sum_addr,
                core_alloc.scratch_buf,
                core_alloc.tmp_buf,
            ),
        )

    if dims.is_xint8:
        return CallKernel(
            conv_xint8_kernel_name(),
            kernel_params=conv_xint8_params(
                (dims.Cis, dims.Yis, dims.Xis),
                (dims.Cos, dims.Yos, dims.Xos),
                (dims.Ky, dims.Kx),
                (dims.Sy, dims.Sx),
                dims.is_standalone_dwc,
                first_tdm,
                final_tdm,
                core_alloc,
                ifm_flag,
                acc_loop,
            ),
        )

    asym_mode = 14 if dims.is_standalone_dwc else 11
    return CallKernel(
        conv_a8w8_qdq_kernel_name(),
        kernel_params=conv_a8w8_qdq_params(
            dims,
            first_tdm,
            final_tdm,
            core_alloc.tdm_ping_addr,
            core_alloc.tdm_pong_addr,
            core_alloc.ifm_sum_addr,
            core_alloc.scratch_buf,
            asym_mode,
        ),
    )

def conv_core_instrs(
    dims: ConvDims,
    core_alloc: CoreAlloc,
    outer_loop: int,
    inner_loop: int,
    acc_loop: int,
    ifm_config: Optional[ConfigBuffer] = None,
    wgt_config: Optional[ConfigBuffer] = None,
    ofm_config: Optional[ConfigBuffer] = None,
    L2_fusion: Optional[bool] = False,
) -> List[Type[CoreInstr]]:


    if ifm_config is None:
        ifm_config = ConfigBuffer(
            DmaChannel(DmaDir.S2MM, 0),
            core_alloc.ifm_ping_addr, core_alloc.ifm_pong_addr, dims.ifm_subv_size
        )
    if wgt_config is None:
        wgt_config = ConfigBuffer(
            DmaChannel(DmaDir.S2MM, 1),
            core_alloc.wgt_ping_addr, core_alloc.wgt_pong_addr, dims.wgt_subv_size
        )
    if ofm_config is None:
        ofm_config = ConfigBuffer(
            DmaChannel(DmaDir.MM2S, 0),
            (core_alloc.tdm_ping_addr if dims.enable_add else core_alloc.ofm_ping_addr), core_alloc.ofm_pong_addr, dims.ofm_subv_size
        )
    conv_param_config = [ConfigBuffer(
        DmaChannel(DmaDir.S2MM, 1),
        core_alloc.conv_kernelprm_addr, None, 256
        ),
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1))
        ]

    if dims.is_xint8:
        print ("conv_kernel_param transfer is enabled")

    single_tdm_call = conv_call_kernel(dims, core_alloc, True, True, dims.enable_add, acc_loop)
    first_tdm_call  = conv_call_kernel(dims, core_alloc, True, False)
    middle_tdm_call = conv_call_kernel(dims, core_alloc, False, False)
    final_tdm_call  = conv_call_kernel(dims, core_alloc, False, True, dims.enable_add, acc_loop)
    add_param_call = add_call_kernel(dims, core_alloc)
    print("acc_loop", acc_loop)
    print("outer_loop", outer_loop)
    print("inner_loop", inner_loop)
    print("is_xint8", dims.is_xint8==1)

    core_instrs = conv_param_config if dims.is_xint8 else []
    num_buffers = 1 if core_alloc.ifm_pong_addr is not None else 2
    loop_constraint = (acc_loop * inner_loop * outer_loop <= (config.MAX_REPEAT_COUNT * config.MAX_TASK_QUEUE_SIZE * num_buffers))

    if not dims.enable_add:
        if acc_loop == 1:
            if (inner_loop == 1) and (dims.pin_wgt_bias_l1) : # NOTE: Pinned wgt in L1 so as to reuse it for whole Y_loop iteration
                core_instrs += [
                    ifm_config,
                    wgt_config,
                    ofm_config,
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    Loop(outer_loop, [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        single_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ]),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                ]
            elif (outer_loop == 1) and (dims.pin_ifm_l1) : # NOTE: Pinned ifm in L1 so as to reuse it for whole Co_loop iteration
                core_instrs += [
                        ifm_config,
                        wgt_config,
                        ofm_config,
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        Loop(inner_loop, [
                            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                            AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                            single_tdm_call,
                            RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        ]),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    ]
            else:
                core_instrs += [
                    ifm_config,
                    wgt_config,
                    ofm_config,
                    Loop(outer_loop, [
                        Loop(inner_loop, [
                            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                            AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                            single_tdm_call,
                            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                            RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        ]),
                    ]),
                ] if (loop_constraint) else [
                    Loop(outer_loop, [
                        ifm_config,
                        wgt_config,
                        ofm_config,
                        Loop(inner_loop, [
                            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                            AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                            single_tdm_call,
                            RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        ]),
                    ]),
                ]

        elif acc_loop == 2:
            core_instrs += [
                ifm_config,
                wgt_config,
                ofm_config,
                Loop(outer_loop, [
                    Loop(inner_loop, [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        first_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        final_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    ]),
                ]),
            ] if (loop_constraint) else [
                Loop(outer_loop, [
                    ifm_config,
                    wgt_config,
                    ofm_config,
                    Loop(inner_loop, [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        first_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        final_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    ]),
                ]),
            ]
        elif acc_loop == 3:
            core_instrs += [
                ifm_config,
                wgt_config,
                ofm_config,
                Loop(outer_loop, [
                    Loop(inner_loop, [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        first_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        middle_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        final_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    ]),
                ]),
            ] if (loop_constraint) else [
                Loop(outer_loop, [
                    ifm_config,
                    wgt_config,
                    ofm_config,
                    Loop(inner_loop, [
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        first_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        middle_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        final_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    ]),
                ]),
            ]
        else:
            inner_loop_body = [
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                first_tdm_call,
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                Loop(acc_loop - 2, [
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    middle_tdm_call,
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                ]),
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                final_tdm_call,
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            ]
            if (loop_constraint):
                core_instrs += [
                    ifm_config,
                    wgt_config,
                    ofm_config,
                    Loop(outer_loop, [
                        Loop(inner_loop, inner_loop_body),
                    ]),
                ]
            else:
                core_instrs += [
                    Loop(outer_loop, [
                        ifm_config,
                        wgt_config,
                        ofm_config ,
                        Loop(inner_loop, inner_loop_body),
                    ]),
                ]
    else: #MATADD KERNEL CORE INSTRUCTIONS
        if acc_loop == 1:
            core_instrs += [
                wgt_config,
                ofm_config,
                Loop(outer_loop, [
                    Loop(inner_loop, [
                        ifm_config,
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        ConfigBuffer( DmaChannel(DmaDir.S2MM, 0), core_alloc.add_ifm_addr, None, dims.ofm_subv_size),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        single_tdm_call,

                        #add kernel support
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        add_param_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    ]),
                ]),
            ] if (loop_constraint) else [
                Loop(outer_loop, [
                    wgt_config,
                    ofm_config,
                    Loop(inner_loop, [
                        ifm_config,
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        ConfigBuffer( DmaChannel(DmaDir.S2MM, 0), core_alloc.add_ifm_addr, None, dims.ofm_subv_size),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        single_tdm_call,

                        #add kernel support
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        add_param_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    ]),
                ]),
            ]
        elif acc_loop == 2:
            core_instrs += [
                wgt_config,
                ofm_config,
                Loop(outer_loop, [
                    Loop(inner_loop, [
                        ifm_config,
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        first_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),

                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        ConfigBuffer( DmaChannel(DmaDir.S2MM, 0), core_alloc.add_ifm_addr, None, dims.ofm_subv_size),
                        final_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),

                        #add kernel support
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        add_param_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    ]),
                ]),
            ] if (loop_constraint) else [
                Loop(outer_loop, [
                    wgt_config,
                    ofm_config,
                    Loop(inner_loop, [
                        ifm_config,
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        first_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),

                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        ConfigBuffer( DmaChannel(DmaDir.S2MM, 0), core_alloc.add_ifm_addr, None, dims.ofm_subv_size),
                        final_tdm_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),

                        #add kernel support
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        add_param_call,
                        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    ]),
                ]),
            ]
        else:
            inner_loop_body = [
                ifm_config,
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                first_tdm_call,
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),

                Loop(acc_loop - 2, [
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                    middle_tdm_call,
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                ]),

                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),

                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                ConfigBuffer( DmaChannel(DmaDir.S2MM, 0), core_alloc.add_ifm_addr, None, dims.ofm_subv_size),
                final_tdm_call,
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),

                #add kernel support
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                add_param_call,
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
            ]
            if (loop_constraint):
                core_instrs += [
                    wgt_config,
                    ofm_config,
                    Loop(outer_loop, [
                        Loop(inner_loop, inner_loop_body),
                    ]),
                ]
            else:
                core_instrs += [
                    Loop(outer_loop, [
                        wgt_config,
                        ofm_config,
                        Loop(inner_loop,inner_loop_body),
                    ]),
                ]
    return core_instrs

