import sys
import os
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/infra/")
sys.path.append(infra_path)
from template_base_gemm import Gemm_base, GemmDims
import scheduler_utils as utils

class Overlay_4x4_base(Gemm_base):
    def overlay_param(self, params):
        params['row_list'] = [0,1,2,3]
        params['broadcast_idx'] = [[0,1,2,3],[0,1,2,3]]
        params['unicast_idx']   = [[0,1,2,3],[0,1,2,3]]
        if params['SpatialSplit'][0] > params['SpatialSplit'][2]:
            params['ifm_idx'] = params['unicast_idx']
            params['wgt_idx'] = params['broadcast_idx']
            params['ofm_idx'] = params['unicast_idx']
            params['ifm_channel']   = 'unicast'
            params['wgt_channel']   = 'broadcast'
        else:
            params['ifm_idx'] = params['broadcast_idx']
            params['wgt_idx'] = params['unicast_idx']
            params['ofm_idx'] = params['unicast_idx']
            params['ifm_channel']   = 'broadcast'
            params['wgt_channel']   = 'unicast'


class M4N4(Overlay_4x4_base):
    # All tiling formats are in the tiling_funcs class; the hooks listed below are not overridden in this class.
    #act_shim_memory
    #act_shim_mm2s
    #act_memtile_memory
    #act_memtile_s2mm
    #act_memtile_mm2s
    #wgt_shim_memory
    #wgt_shim_mm2s
    #wgt_memtile_memory
    #wgt_memtile_s2mm
    #wgt_memtile_mm2s
    #out_shim_memory
    #out_shim_s2mm
    #out_memtile_memory
    #out_memtile_s2mm
    #out_memtile_mm2s
    def set_spatialsplit_params(self, params):
        params['ShimParamMode']      = 'broadcast'
        params['ShimQdqMode']        = 'unicast'
        params['ShimParamChannelId'] = 1
        params['param_channel_id']   = 1
        params['ShimQdqChannelId']   = 0
        params['CoreQdqChId']        = 0

class M1N16(Overlay_4x4_base):
    def set_spatialsplit_params(self, params):
        params['shim_wgt_reengueue'] = True
        self.set_kernelmode_prm(params)
        self.set_in_ch_mode(params)
        params['ShimParamMode']      = 'broadcast'
        params['ShimQdqMode']        = 'unicast'
        params['ShimParamChannelId'] = 1
        params['param_channel_id']   = 1
        params['ShimQdqChannelId']   = 0
        params['CoreQdqChId']        = 0
    #act_shim_memory
    #act_memtile_memory
    #act_memtile_mm2s

    def act_shim_mm2s(self, params, col: int = 0, col_list: int = [], itr: int = 0, bd_last:int = False, pad_last:int = False, Mpad=0, Kpad=0, batch: int = 0) -> str:
        dims = params['dims']
        if [params['ActDataFlow'], params['ActFormat']] == ["stream", "default"]: #TBD
            H_stride = dims.M_subv
            H_start = dims.M_subv
            H_stop = dims.M_subv - Mpad
            return f'H:0:{dims.M}:{H_stride} W:0:{dims.K}:{dims.K_subv} H:0:{H_stop} W:0:{dims.K_subv}'
        if [params['ActDataFlow'], params['ActFormat']] == ["pin",    "default"]:
            return f'H:0:{dims.M-Mpad} W:0:{dims.K}'
        if [params['ActDataFlow'], params['ActFormat']] == ["stream", "BFP_A8"]:          #TBD
            raise RuntimeError('BFP with Activation stream mode is not supported!')
        if [params['ActDataFlow'], params['ActFormat']] == ["pin",     "BFP_A8"]:
            H_stride = 4 * dims.M_subv
            H_start = col * dims.M_subv
            H_stop = H_start + dims.M_subv
            return f"H:0:{dims.M}:{H_stride} W:0:{dims.K}:{params['min_cols']} H:{H_start}:{H_stop} W:0:{params['min_cols']}"

    def act_memtile_s2mm(self, params, col=0, col_list=[0,1,2,3], Mpad=0, min_cols=8, batch=0) -> str:
        dims = params['dims']
        if [params['ActDataFlow'], params['ActFormat']] == ["stream", "default"]:
            if Mpad == 0:
                return f'H:0:{dims.M_subv} W:0:{dims.K_subv}'
            else:
                return f'H:0:{dims.M-Mpad} W:0:{dims.K_subv}'
        if [params['ActDataFlow'], params['ActFormat']] == ["pin",    "default"]:
            if Mpad == 0:
                return f'H:0:{dims.M_subv} W:0:{dims.K}'
            else:
                return f'H:0:{dims.M-Mpad} W:0:{dims.K}'
        if [params['ActDataFlow'], params['ActFormat']] == ["stream", "BFP_A8"]:
            raise RuntimeError('BFP with Activation stream mode is not supported!')
            return f'H:0:{dims.M_subv-Mpad} W:0:{dims.K_subv}'                      #TBD
        if [params['ActDataFlow'], params['ActFormat']] == ["pin",    "BFP_A8"]:
            return f'W:0:{dims.K}:{min_cols} H:0:{dims.M_subv} W:0:{min_cols}'

    #wgt_shim_memory

    def wgt_shim_mm2s(self, params, array: int=0, col: int=0, Kpad: int = 0, Npad: int = 0, itr: int=0, batch: int = 0) -> str:
        dims = params['dims']
        if params['actxact']:
            if [params['WgtDataFlow'], params['WgtFormat']] == ["full",    "default"]:
                H_stride = dims.aie_cols * dims.N_subv
                H_start = col * dims.N_subv
                H_stop = H_start + dims.N_subv
                return f'H:0:{dims.K} W:0:{dims.N}:{H_stride} W:{H_start}:{H_stop}'
            if [params['WgtDataFlow'], params['WgtFormat']] == ["stream",    "default"]:
                H_stride = dims.aie_cols * dims.N_subv
                H_start = col * dims.N_subv
                H_stop = H_start + dims.N_subv
                return f'H:0:{dims.K} W:0:{dims.N}:{H_stride} W:{H_start}:{H_stop}'
        else:
            if params['WgtDataFlow'] == 'full':
                col_start = col * 4
                col_stop = col_start + 4
                col_stride = dims.aie_arrays * dims.aie_cols
                return (
                    f'Cols:0:{dims.wgt_subv_cols}:{dims.aie_cols * dims.aie_rows} ' +
                    f'Cols:{col_start}:{col_stop} ' +
                    f'Rows:0:{dims.wgt_subv_rows} ' +
                    f'Bytes:0:{dims.wgt_subv_bytes}'
                )
            elif params['WgtDataFlow'] == 'stream':
                col_start = col * 4 + itr * dims.aie_cols * dims.aie_rows
                col_stop = col_start + 4
                col_stride = dims.aie_arrays * dims.aie_cols
                return (
                    f'Rows:0:{dims.wgt_subv_rows} ' +
                    f'Cols:{col_start}:{col_stop} ' +
                    f'Bytes:0:{dims.wgt_subv_bytes}'
                )

    def wgt_memtile_memory(self, params, M_itr, N_itr, K_itr) -> str:
        dims = params['dims']
        if params['actxact']:
            if params['WgtDataFlow'] == "full":
                return f"W:{dims.N//dims.aie_cols} H:{dims.K} W:{params['min_cols']}"
            if params['WgtDataFlow'] == "stream":
                return f"W:{dims.N_subv} H:{dims.K_subv} W:{params['min_cols']}"
        else:
            if params['WgtDataFlow'] == "stream":
                return f'Cols:{dims.aie_rows} Bytes:{dims.wgt_subv_bytes}'
            if params['WgtDataFlow'] == "full":
                num_cols = dims.wgt_subv_cols // (dims.aie_cols * dims.aie_arrays)
                return f'Cols:{dims.aie_rows} Rows:{dims.wgt_subv_rows} Bytes:{dims.wgt_subv_bytes}'

    def wgt_memtile_mm2s(self, params, row, col) -> str:
        dims = params['dims']
        mode = params['WgtDataFlow']

        if params['actxact']:
            if [params['WgtDataFlow'], params['WgtFormat']] == ["full",    "default"]:
                return f"H:0:{dims.K}:{dims.K_subv} W:0:{dims.N//dims.aie_cols}:{params['min_cols']} H:0:{dims.K_subv} W:0:{params['min_cols']}"
            if [params['WgtDataFlow'], params['WgtFormat']] == ["stream",    "default"]:
                return f"W:0:{dims.N_subv}:{params['min_cols']} H:0:{dims.K_subv} W:0:{params['min_cols']}"
        else:
            num_cols = dims.wgt_subv_cols // (dims.aie_cols * dims.aie_arrays)
            if mode=="stream":
                return f'Cols:{row}:{row+1} Bytes:0:{dims.wgt_subv_bytes}'
            else:
                return f'Cols:{row}:{row+1} Rows:0:{dims.wgt_subv_rows} Bytes:0:{dims.wgt_subv_bytes}'

    def wgt_memtile_s2mm(self, params, col: int, col_list=[], Kpad: int = 0, Npad: int = 0, batch: int = 0) -> str:
        dims = params['dims']
        mode = params['WgtDataFlow']

        if params['actxact']:
            if [params['WgtDataFlow'], params['WgtFormat']] == ["full",    "default"]:
                return f'H:0:{dims.K} W:0:{dims.N//dims.aie_cols}'
            if [params['WgtDataFlow'], params['WgtFormat']] == ["stream",    "default"]:
                return f'H:0:{dims.K_subv} W:0:{dims.N_subv}'
        else:
            num_cols = dims.wgt_subv_cols // (dims.aie_cols * dims.aie_arrays)
            if mode=="stream":
                return f'Cols:0:{dims.aie_rows} Bytes:0:{dims.wgt_subv_bytes}'
            else:
                return f'Cols:0:{num_cols} Rows:0:{dims.wgt_subv_rows} Bytes:0:{dims.wgt_subv_bytes}'

    #out_shim_memory

    def out_memtile_memory(self, params) -> str:
        dims = params['dims']
        if params['OfmFormat'] == "default":
            return f'H:{dims.M_subv} W:{dims.N_subv * dims.aie_rows}'
        if params['OfmFormat'] == "BFP_A8":
            return f"H:{dims.M_subv * dims.aie_rows} W:{dims.N_subv} H:{dims.M_subv} W:{params['min_cols']}"
        if params['OfmFormat'] == "BF16_A8": #same as default
            return f"H:{dims.M_subv * dims.aie_rows} W:{dims.N_subv}"

    def out_memtile_s2mm(self, params, row: int, col: int, batch=0) -> str:
        dims = params['dims']
        if params['OfmFormat'] == "default":
            W_start = row * dims.N_subv
            W_stop = W_start + dims.N_subv
            return f"W:{W_start}:{W_stop}:8 H:0:{dims.M_subv} W:0:8"
        if params['OfmFormat'] == "BFP_A8":
            H_start = row * dims.M_subv
            H_stop = H_start + dims.M_subv
            return f"W:0:{dims.N_subv}:{params['min_cols']} H:{H_start}:{H_stop} W:0:{params['min_cols']}"
        if params['OfmFormat'] == "BF16_A8": 
            H_start = row * dims.M_subv
            H_stop = H_start + dims.M_subv
            return f"W:0:{dims.N_subv}:{params['min_cols']} H:{H_start}:{H_stop} W:0:{params['min_cols']}"
        
    def out_memtile_mm2s(self, params, col, Mpad=0, batch=0) -> str:
        dims = params['dims']
        if params['OfmFormat'] == "default":
            return f'H:0:{dims.M_subv-Mpad} W:0:{dims.N_subv * dims.aie_rows}'
        if params['OfmFormat'] == "BFP_A8":
            return f"W:0:{dims.N_subv}:{params['min_cols']} H:0:{dims.M_subv*dims.aie_rows} W:0:{params['min_cols']}"
        if params['OfmFormat'] == "BF16_A8":
            return f'H:0:{dims.M_subv * dims.aie_rows} W:0:{dims.N_subv}' #same as default

    def out_shim_s2mm(self, params, array: int, col: int, Munpad=0, last=0, batch: int = 0) -> str:
        dims = params['dims']
        if params['OfmFormat'] == "default":
            H_start = 0
            H_stop = H_start + (dims.M_subv * dims.aie_rows)
            H_stride = dims.M_subv * dims.aie_rows
            W_start = self.col_index(dims, array, col) * dims.N_subv * dims.aie_rows
            W_stop = W_start + dims.N_subv * dims.aie_rows
            W_stride = dims.N_subv * dims.aie_cols * dims.aie_rows
            return f'H:0:{dims.M}:{dims.M_subv} W:0:{dims.N}:{W_stride} H:0:{dims.M_subv-Munpad} W:{W_start}:{W_stop}'
        if params['OfmFormat'] == "BFP_A8":
            H_stride = dims.M_subv* dims.aie_rows
            H_start = col * H_stride
            H_stop = H_start + H_stride
            W_stride = dims.N_subv
            W_start = col * W_stride
            W_stop  = W_start + W_stride 
            return f"H:0:{dims.M}:{H_stride} W:0:{dims.N}:{dims.N_subv*dims.aie_cols} W:{W_start}:{W_stop}:{params['min_cols']} H:0:{H_stride} W:0:{params['min_cols']}"
        if params['OfmFormat'] == "BF16_A8": #same as default
            H_start = 0
            H_stop = H_start + (dims.M_subv * dims.aie_rows)
            H_stride = dims.M_subv * dims.aie_rows
            W_start = self.col_index(dims, array, col) * dims.N_subv
            W_stop = W_start + dims.N_subv
            W_stride = dims.N_subv * dims.aie_cols * dims.aie_arrays
            return f'H:0:{dims.M}:{H_stride} W:0:{dims.N}:{W_stride} H:{H_start}:{H_stop} W:{W_start}:{W_stop}'

class M16N1(Overlay_4x4_base):
    def set_spatialsplit_params(self, params):
        params['kernelmode'] = 1
        params['in_ch_mode'] = 1
        params['MemtileWgtchid'] = 0
        params['MemtileActchid'] = 1
        params['ShimParamMode']      = 'broadcast'
        params['ShimQdqMode']        = 'unicast'
        params['ShimParamChannelId'] = 1
        params['param_channel_id']   = 1
        params['ShimQdqChannelId']   = 0
        params['CoreQdqChId']        = 0
    #act_shim_mm2s
    #act_memtile_memory

    def act_shim_mm2s(self, params, col: int = 0, col_list: int = [], itr: int = 0, bd_last:int = False, pad_last:int = False, Mpad=0, Kpad=0, batch: int = 0) -> str:
        dims = params['dims']
        if [params['ActDataFlow'], params['ActFormat']] == ["stream", "default"]: #TBD
            H_stride = dims.M_subv * dims.aie_rows * dims.aie_cols
            H_start = dims.M_subv * dims.aie_rows * col
            H_stop = H_start + dims.M_subv * dims.aie_rows
            return f'H:0:{dims.M}:{H_stride} W:0:{dims.K}:{dims.K_subv} H:{H_start}:{H_stop} W:0:{dims.K_subv}'
        if [params['ActDataFlow'], params['ActFormat']] == ["pin",    "default"]:
            H_stride = dims.M_subv*dims.aie_rows*dims.aie_cols
            H_start = col*dims.M_subv*dims.aie_rows
            H_stop = H_start + dims.M_subv*dims.aie_rows
            return f'H:0:{dims.M}:{H_stride} H:{H_start}:{H_stop} W:0:{dims.K-Kpad}'
        if [params['ActDataFlow'], params['ActFormat']] == ["stream", "BFP_A8"]:          #TBD
            raise RuntimeError('BFP with Activation stream mode is not supported!')
        if [params['ActDataFlow'], params['ActFormat']] == ["pin",     "BFP_A8"]:
            H_stride = 4 * dims.M_subv
            H_start = col * dims.M_subv
            H_stop = H_start + dims.M_subv
            return f"H:0:{dims.M}:{H_stride} W:0:{dims.K}:{params['min_cols']} H:{H_start}:{H_stop} W:0:{params['min_cols']}"

    def act_memtile_memory(self, params, col=0, col_list=[0,1,2,3], Mpad=0, Kpad=0) -> str:
        dims = params['dims']
        if params['ActDataFlow'] == 'stream':
            return f'W:{dims.K_subv} H:{dims.M_subv*dims.aie_rows} W:8'
        else:
            return f'W:{dims.K-Kpad} H:{dims.M_subv*dims.aie_rows-Mpad} W:8'

    def act_memtile_s2mm(self, params, col=0, col_list=[0,1,2,3], Mpad=0, min_cols=8, batch=0) -> str:
        dims = params['dims']
        if [params['ActDataFlow'], params['ActFormat']] == ["stream", "default"]:
            return f'H:0:{dims.M_subv*dims.aie_rows} W:0:{dims.K_subv}'
        if [params['ActDataFlow'], params['ActFormat']] == ["pin",    "default"]:
            return f'H:0:{dims.M_subv*dims.aie_rows-Mpad} W:0:{dims.K}'
        if [params['ActDataFlow'], params['ActFormat']] == ["stream", "BFP_A8"]:
            raise RuntimeError('BFP with Activation stream mode is not supported!')
            return f'H:0:{dims.M_subv-Mpad} W:0:{dims.K_subv}'                      #TBD
        if [params['ActDataFlow'], params['ActFormat']] == ["pin",    "BFP_A8"]:
            return f'W:0:{dims.K}:{min_cols} H:0:{dims.M_subv} W:0:{min_cols}'

    def act_memtile_mm2s(self, params, col, row, col_list=[0,1,2,3], Mpad=0, Kpad=0, batch=0) -> str:
        dims = params['dims']
        if params['ActDataFlow'] == "stream":
            H_start = dims.M_subv * row
            H_stop  = H_start + dims.M_subv
            return f'W:0:{dims.K_subv}:8 H:{H_start}:{H_stop} W:0:8'
        if params['ActDataFlow'] == "pin":
            H_start = row * dims.M_subv
            H_stop = H_start + dims.M_subv - Mpad
            return f'W:0:{dims.K+Kpad}:8 H:{H_start}:{H_stop} W:0:8'

    #wgt_shim_memory

    def wgt_shim_mm2s(self, params, array: int=0, col: int=0, Kpad: int = 0, Npad:int = 0, itr: int=0, batch: int = 0) -> str:
        dims = params['dims']
        if params['actxact']:
            if [params['WgtDataFlow'], params['WgtFormat']] == ["full",    "default"]:
                H_stride = dims.aie_cols * dims.N_subv
                H_start = col * dims.N_subv
                H_stop = H_start + dims.N_subv
                return f'H:0:{dims.K} W:0:{dims.N}'
            if [params['WgtDataFlow'], params['WgtFormat']] == ["stream",    "default"]:
                H_stride = dims.aie_cols * dims.N_subv
                H_start = col * dims.N_subv
                H_stop = H_start + dims.N_subv
                return f'W:0:{dims.N}:{dims.N_subv} H:0:{dims.K} W:0:{dims.N_subv}'
        else:
            if params['WgtDataFlow'] == 'full':
                return (
                    f'Cols:0:{dims.wgt_subv_cols} ' +
                    f'Rows:0:{dims.wgt_subv_rows} ' +
                    f'Bytes:0:{dims.wgt_subv_bytes}'
                )
            elif params['WgtDataFlow'] == 'stream':
                return (
                    f'Rows:0:{dims.wgt_subv_rows} ' +
                    f'Cols:0:{dims.wgt_subv_cols} ' +
                    f'Bytes:0:{dims.wgt_subv_bytes}'
                )

    def wgt_memtile_memory(self, params, M_itr, N_itr, K_itr) -> str:
        dims = params['dims']
        if params['actxact']:
            if params['WgtDataFlow'] == "full":
                return f"W:{dims.N} H:{dims.K} W:{params['min_cols']}"
            if params['WgtDataFlow'] == "stream":
                return f"W:{dims.N_subv} H:{dims.K_subv} W:{params['min_cols']}"
        else:
            if params['WgtDataFlow'] == "stream":
                return f'Bytes:{dims.wgt_subv_bytes}'
            if params['WgtDataFlow'] == "full":
                return f'Cols:{dims.wgt_subv_cols} Rows:{dims.wgt_subv_rows} Bytes:{dims.wgt_subv_bytes}'

    def wgt_memtile_mm2s(self, params, row, col) -> str:
        dims = params['dims']
        mode = params['WgtDataFlow']

        if params['actxact']:
            if [params['WgtDataFlow'], params['WgtFormat']] == ["full",    "default"]:
                return f"W:0:{dims.N}:{dims.N_subv} H:0:{dims.K}:{dims.K_subv} W:0:{dims.N}:{params['min_cols']} H:0:{dims.K_subv} W:0:{params['min_cols']}"
            if [params['WgtDataFlow'], params['WgtFormat']] == ["stream",    "default"]:
                return f"W:0:{dims.N_subv}:{params['min_cols']} H:0:{dims.K_subv} W:0:{params['min_cols']}"
        else:
            if mode=="stream":
                return f'Bytes:0:{dims.wgt_subv_bytes}'
            else:
                return f'Cols:0:{dims.wgt_subv_cols} Rows:0:{dims.wgt_subv_rows} Bytes:0:{dims.wgt_subv_bytes}'

    def wgt_memtile_s2mm(self, params, col: int, col_list=[], Kpad: int = 0, Npad: int = 0, batch: int = 0) -> str:
        dims = params['dims']
        mode = params['WgtDataFlow']

        if params['actxact']:
            if [params['WgtDataFlow'], params['WgtFormat']] == ["full",    "default"]:
                return f'H:0:{dims.K} W:0:{dims.N}'
            if [params['WgtDataFlow'], params['WgtFormat']] == ["stream",    "default"]:
                return f'H:0:{dims.K_subv} W:0:{dims.N_subv}'
        else:
            num_cols = dims.wgt_subv_cols // (dims.aie_cols * dims.aie_arrays)
            if mode=="stream":
                return f'Bytes:0:{dims.wgt_subv_bytes}'
            else:
                return f'Cols:0:{dims.wgt_subv_cols} Rows:0:{dims.wgt_subv_rows} Bytes:0:{dims.wgt_subv_bytes}'

    #out_shim_memory
    #out_memtile_memory
    #out_memtile_s2mm
    #out_memtile_mm2s

    def out_shim_s2mm(self, params, array: int, col: int, Munpad=0, last=0, batch: int = 0) -> str:
        dims = params['dims']
        if params['OfmFormat'] == "default":
            H_start = col * (dims.M_subv * dims.aie_rows)
            H_stop = H_start + (dims.M_subv * dims.aie_rows)
            H_stride = dims.M_subv * dims.aie_rows * dims.aie_cols
            W_start = self.col_index(dims, array, col) * dims.N_subv
            W_stop = W_start + dims.N_subv * dims.aie_rows
            W_stride = dims.N_subv * dims.aie_cols * dims.aie_rows
            return f'H:0:{dims.M}:{H_stride} W:0:{dims.N}:{dims.N_subv} H:{H_start}:{H_stop} W:0:{dims.N_subv}'
        if params['OfmFormat'] == "BFP_A8":
            H_stride = dims.M_subv* dims.aie_rows
            H_start = col * H_stride
            H_stop = H_start + H_stride
            W_stride = dims.N_subv
            W_start = col * W_stride
            W_stop  = W_start + W_stride 
            return f"H:0:{dims.M}:{H_stride} W:0:{dims.N}:{dims.N_subv*dims.aie_cols} W:{W_start}:{W_stop}:{params['min_cols']} H:0:{H_stride} W:0:{params['min_cols']}"
        if params['OfmFormat'] == "BF16_A8": #same as default
            H_start = 0
            H_stop = H_start + (dims.M_subv * dims.aie_rows)
            H_stride = dims.M_subv * dims.aie_rows
            W_start = self.col_index(dims, array, col) * dims.N_subv
            W_stop = W_start + dims.N_subv
            W_stride = dims.N_subv * dims.aie_cols * dims.aie_arrays
            return f'H:0:{dims.M}:{H_stride} W:0:{dims.N}:{W_stride} H:{H_start}:{H_stop} W:{W_start}:{W_stop}'
