import sys
import os
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/infra/")
sys.path.append(infra_path)
from template_base_gemm import Gemm_base, GemmDims
import scheduler_utils as utils

class Overlay_8x4_base(Gemm_base):
    """Shared data-movement parameters for the 8x4 overlay variants.

    Subclasses add their split-specific settings in
    ``set_spatialsplit_params``.
    """

    def overlay_param(self, params):
        """Populate ``params`` in place with row, index and channel settings.

        Writes ``row_list``, the ``broadcast_idx`` / ``unicast_idx`` tables,
        the RoPE and element-wise index tables, the IFM/weight/OFM index and
        channel choices, and ``qdq_repeat``. Returns ``None``.

        Whether the weights or the IFM are broadcast is chosen by comparing
        ``params['dims'].split[1]`` against ``params['dims'].split[3]``.

        Several entries alias the same list object (for example ``rope_idx``,
        ``elew_idx`` and ``ofm_idx`` all reference the ``unicast_idx`` list),
        so mutating one of them mutates the others.
        """
        bcast_table = [[0, 2, 4, 6], [0, 1, 2, 3]]
        ucast_table = [[0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7]]

        params['row_list'] = [0, 1, 2, 3]
        params['broadcast_idx'] = bcast_table
        params['unicast_idx'] = ucast_table
        # RoPE SIN/COS and element-wise operands always use the unicast table.
        params['rope_idx'] = ucast_table
        params['elew_idx'] = ucast_table

        split = params['dims'].split
        if split[1] > split[3]:
            # Weights are broadcast; IFM and OFM are unicast.
            ifm_table, wgt_table = ucast_table, bcast_table
            ifm_mode, wgt_mode, ofm_mode = 'unicast', 'broadcast', 'unicast'
        else:
            # IFM is broadcast; weights are unicast.
            # NOTE(review): ofm_idx is still the unicast table here while
            # ofm_channel is 'broadcast' -- confirm this pairing is intended.
            ifm_table, wgt_table = bcast_table, ucast_table
            ifm_mode, wgt_mode, ofm_mode = 'broadcast', 'unicast', 'broadcast'

        params['ifm_idx'] = ifm_table
        params['wgt_idx'] = wgt_table
        params['ofm_idx'] = ucast_table
        params['ifm_channel'] = ifm_mode
        params['wgt_channel'] = wgt_mode
        params['ofm_channel'] = ofm_mode
        params['qdq_repeat'] = 0
        
class M4N8(Overlay_8x4_base):
    """Spatial-split settings for the M4N8 layout of the 8x4 overlay."""

    def set_spatialsplit_params(self, params):
        """Store the M4N8 kernel mode and shim/core channel ids in ``params``.

        Mutates ``params`` in place and returns ``None``.

        ``kernelmode`` is 0 when ``params['actxact']`` is truthy and
        ``params['dims'].act_bits == 16``; otherwise it is 1.
        """
        a16_actxact = params['actxact'] and params['dims'].act_bits == 16
        settings = {
            'kernelmode': 0 if a16_actxact else 1,
            'ShimParamMode': 'unicast',
            'ShimQdqMode': 'broadcast',
            'ShimParamChannelId': 1,
            'param_channel_id': 0,
            'ShimQdqChannelId': 0,
            'CoreQdqChId': 1,
        }
        # Assign key by key so only __setitem__ is required of ``params``.
        for key, value in settings.items():
            params[key] = value
        params['CoreQdqChId']           = 1

class M8N4(Overlay_8x4_base):
    """Spatial-split settings for the M8N4 layout of the 8x4 overlay."""

    def set_spatialsplit_params(self, params):
        """Store the M8N4 kernel mode, memtile and shim/core channel ids.

        Mutates ``params`` in place and returns ``None``.

        ``kernelmode`` is 1 when ``params['actxact']`` is truthy and
        ``params['dims'].act_bits == 16``; otherwise it is 0 (the inverse of
        the M4N8 choice).
        """
        a16_actxact = params['actxact'] and params['dims'].act_bits == 16
        settings = {
            'kernelmode': 1 if a16_actxact else 0,
            'in_ch_mode': 1,
            'MemtileWgtchid': 0,
            'MemtileActchid': 1,
            'ShimParamMode': 'unicast',
            'ShimQdqMode': 'broadcast',
            'ShimParamChannelId': 1,
            'param_channel_id': 0,
            'ShimQdqChannelId': 0,
            'CoreQdqChId': 1,
        }
        # Assign key by key so only __setitem__ is required of ``params``.
        for key, value in settings.items():
            params[key] = value

class M1N32(Overlay_8x4_base):
    """Spatial-split settings for the M1N32 layout of the 8x4 overlay."""

    def set_spatialsplit_params(self, params):
        """Store the M1N32 kernel mode, shim/core channel ids and requeue flag.

        Mutates ``params`` in place and returns ``None``.

        ``kernelmode`` is 0 when ``params['actxact']`` is truthy and
        ``params['dims'].act_bits == 16``; otherwise it is 1.

        ``shim_wgt_reengueue`` is True only when ``params['actxact']`` is
        falsy, ``dims.wgt_subv_bytes >= 4096`` and ``dims.Tn > 1``. The key's
        spelling is kept exactly as-is; it is presumably read under this
        name elsewhere.
        """
        a16_actxact = params['actxact'] and params['dims'].act_bits == 16
        settings = {
            'kernelmode': 0 if a16_actxact else 1,
            'ShimParamMode': 'unicast',
            'ShimQdqMode': 'broadcast',
            'ShimParamChannelId': 1,
            'param_channel_id': 0,
            'ShimQdqChannelId': 0,
            'CoreQdqChId': 1,
        }
        # Assign key by key so only __setitem__ is required of ``params``.
        for key, value in settings.items():
            params[key] = value

        dims = params['dims']
        params['shim_wgt_reengueue'] = bool(
            not params['actxact']
            and dims.wgt_subv_bytes >= 4096
            and dims.Tn > 1
        )

class M32N1(Overlay_8x4_base):
    """Spatial-split settings for the M32N1 layout of the 8x4 overlay."""

    def set_spatialsplit_params(self, params):
        """Store the M32N1 kernel mode, memtile/shim/core channel ids.

        Also disables ``shim_act_bdchaining``. Mutates ``params`` in place
        and returns ``None``.

        ``kernelmode`` is 1 when ``params['actxact']`` is truthy and
        ``params['dims'].act_bits == 16``; otherwise it is 0.
        """
        a16_actxact = params['actxact'] and params['dims'].act_bits == 16
        settings = {
            'kernelmode': 1 if a16_actxact else 0,
            'in_ch_mode': 1,
            'MemtileWgtchid': 0,
            'MemtileActchid': 1,
            'shim_act_bdchaining': False,
            'ShimParamMode': 'unicast',
            'ShimQdqMode': 'broadcast',
            'ShimParamChannelId': 1,
            'param_channel_id': 0,
            'ShimQdqChannelId': 0,
            'CoreQdqChId': 1,
        }
        # Assign key by key so only __setitem__ is required of ``params``.
        for key, value in settings.items():
            params[key] = value


class B4M8N1(Overlay_8x4_base):
    """Spatial-split settings for the B4M8N1 layout of the 8x4 overlay."""

    def set_spatialsplit_params(self, params):
        """Store the B4M8N1 kernel mode, memtile/shim/core channel ids.

        Mutates ``params`` in place and returns ``None``.

        ``kernelmode`` is 1 when ``params['actxact']`` is truthy and
        ``params['dims'].act_bits == 16``; otherwise it is 0.
        Unlike M8N4/M32N1, ``ShimQdqChannelId`` is 1 here rather than 0.
        """
        a16_actxact = params['actxact'] and params['dims'].act_bits == 16
        settings = {
            'kernelmode': 1 if a16_actxact else 0,
            'in_ch_mode': 1,
            'MemtileWgtchid': 0,
            'MemtileActchid': 1,
            'ShimParamMode': 'unicast',
            'ShimQdqMode': 'broadcast',
            'ShimParamChannelId': 1,
            'param_channel_id': 0,
            'ShimQdqChannelId': 1,
            'CoreQdqChId': 1,
        }
        # Assign key by key so only __setitem__ is required of ``params``.
        for key, value in settings.items():
            params[key] = value


class B32M1N1(Overlay_8x4_base):
    """Spatial-split settings for the B32M1N1 layout of the 8x4 overlay."""

    def set_spatialsplit_params(self, params):
        """Store the B32M1N1 kernel mode, channel ids and per-tensor routing.

        Mutates ``params`` in place and returns ``None``.

        ``kernelmode`` is 1 when ``params['actxact']`` is truthy and
        ``dims.act_bits == 16``; otherwise it is 0. IFM, weights and OFM all
        use ``params['unicast_idx']`` with the 'unicast_headpercore' channel,
        replacing any ``*_idx`` / ``*_channel`` values already in ``params``.
        ``params['unicast_idx']`` must already be present (presumably filled
        by ``overlay_param`` -- confirm call order with the caller).

        Raises:
            AssertionError: if ``dims.Tm``, ``dims.Tk`` and ``dims.Tn`` are
                not all 1.
        """
        dims = params['dims']
        # Explicit check instead of ``assert`` so the precondition is still
        # enforced under ``python -O``. AssertionError is kept so callers see
        # the same exception type and message as before.
        if not dims.Tm == dims.Tk == dims.Tn == 1:
            raise AssertionError("B32M1N1 only supports Tm=Tk=Tn=1")
        params['kernelmode']            = 1 if params['actxact'] and dims.act_bits == 16 else 0
        params['in_ch_mode']            = 1
        params['MemtileWgtchid']        = 0
        params['MemtileActchid']        = 1
        params['ShimParamMode']         = 'unicast'
        params['ShimQdqMode']           = 'broadcast'
        params['ShimParamChannelId']    = 1
        params['param_channel_id']      = 0
        params['ShimQdqChannelId']      = 1
        params['CoreQdqChId']           = 1
        # Every tensor is routed unicast, one head per core.
        params['ifm_idx']               = params['unicast_idx']
        params['wgt_idx']               = params['unicast_idx']
        params['ofm_idx']               = params['unicast_idx']
        params['ifm_channel']           = 'unicast_headpercore'
        params['wgt_channel']           = 'unicast_headpercore'
        params['ofm_channel']           = 'unicast_headpercore'

        
            


