from enum import IntEnum
import struct

def ceildiv(x: int, d: int) -> int:
    """Return ceil(x / d) using only floor division.

    Negating the divisor turns Python's floor division into a ceiling
    division without any float conversion, so large ints stay exact.
    A float ``d`` is also accepted, in which case the result is a float.
    """
    floored = x // -d
    return -floored

def iceil(x: int, d: int) -> int:
    """Round x up to the nearest multiple of d (e.g. iceil(10, 4) == 12)."""
    multiples = ceildiv(x, d)
    return multiples * d

class OPMode(IntEnum):
    """Operation-mode codes stored in the kernel parameter blob.

    ``setup_mha_params`` packs the selected mode as a signed 8-bit field.
    Codes 6 and 7 are not assigned.
    """

    # Codes 0-5.
    OP_NONE      = 0
    OP_CONV      = 1
    OP_SUM       = 2
    OP_SUM_T     = 3
    OP_DWC       = 4
    OP_DWC_SUM   = 5

    # Codes 8-13 (6 and 7 are skipped).
    OP_CONV_SYM  = 8
    OP_CONV_ASYM = 9
    OP_DWC_SYM   = 10
    OP_DWC_ASYM  = 11
    OP_QDQ       = 12
    OP_ASYM      = 13

class MhaSubvDims:
    """Tiling dimensions and packing settings for one MHA sub-volume.

    A plain attribute bag: every constructor argument is stored unchanged
    under the same attribute name. Construction asserts two things:

    * Y, X, Co and Ci are exact multiples of their ``*_gran`` counterparts.
    * ``op_mode`` is an ``OPMode`` member, or an int equal to one.

    These are ``assert`` checks, so they are skipped under ``python -O``.
    """

    def __init__(
        self,
        N : int,
        Y : int, Y_gran : int,
        X : int, X_gran : int,
        Ci : int, Ci_gran : int,
        Co : int, Co_gran : int,
        Ky : int,
        Kx : int,
        Sy : int,
        Sx : int,
        op_mode : OPMode,
        size_bytes : int,
        stride_efficiency : float,
        mem_align : int,
    ):
        # Each extent must split into whole granules; checked in the order
        # Y, X, Co, Ci (a zero granule raises ZeroDivisionError here).
        for extent, granule in ((Y, Y_gran), (X, X_gran),
                                (Co, Co_gran), (Ci, Ci_gran)):
            assert extent % granule == 0
        assert op_mode in tuple(OPMode)

        # Attributes are assigned in a fixed order, so vars(self) is stable.
        self.N = N
        self.Y, self.Y_gran = Y, Y_gran
        self.X, self.X_gran = X, X_gran
        self.Co, self.Co_gran = Co, Co_gran
        self.Ci, self.Ci_gran = Ci, Ci_gran
        self.Ky, self.Kx = Ky, Kx
        self.Sy, self.Sx = Sy, Sx
        self.op_mode = op_mode
        self.size_bytes = size_bytes
        self.stride_efficiency = stride_efficiency
        self.mem_align = mem_align

def setup_mha_params(
    dims: MhaSubvDims
) -> bytes:
    """Pack the kernel loop counts and byte strides for one MHA sub-volume.

    Args:
        dims: Sub-volume dimensions and packing settings.

    Returns:
        A 40-byte blob of 24 little-endian fields, in this order:

        * Kx_g, Ky_g, Ci_g (u8); S_g (i8); N_g, X_g, Y_g, Co_g (u8)
        * inner_g, outer_g (u16)
        * shift_tdm, shift_res, zp_wght, op_mode (i8)
        * step_Kx, step_Ky, step_Ci, step_Xi, step_Yi, step_Xo, step_Yo,
          step_Co (u16)
        * param_value, ctrl (i32)

    Raises:
        struct.error: if a computed field does not fit its packed width.
        ZeroDivisionError: if a divisor is zero, e.g. when
            ``mem_align // Ci_gran // size_bytes`` is 0 or ``Co`` is 0.
    """
    strided = dims.Sx > 1

    # Padded input width: X*Sx + Kx - 1 rounded up to a multiple of
    # mem_align // Ci_gran // size_bytes.
    # NOTE(review): for Sx > 1 this exceeds the minimal (X-1)*Sx + Kx by
    # Sx - 1; confirm the extra padding is intended.
    Xi_exp = iceil(dims.X * dims.Sx + dims.Kx - 1,
                   dims.mem_align // dims.Ci_gran // dims.size_bytes)

    # Loop counts.
    Kx_g = dims.Kx
    Ky_g = dims.Ky
    # Only OP_DWC collapses the Ci loop to a single iteration.
    # NOTE(review): OP_DWC_SUM/_SYM/_ASYM keep the full Ci loop; confirm.
    Ci_g = 1 if dims.op_mode == OPMode.OP_DWC else ceildiv(dims.Ci, dims.Ci_gran)
    S_g = dims.Sx
    N_g = dims.N
    # When striding, the X loop count is divided by stride_efficiency
    # (rounded up).
    X_g = int(ceildiv(dims.N * dims.X // dims.X_gran,
                      dims.stride_efficiency if strided else 1))
    Y_g = ceildiv(dims.Y, dims.Y_gran)
    Co_g = ceildiv(dims.Co, dims.Co_gran)
    inner_g = Kx_g * Ky_g * Ci_g
    outer_g = X_g * Y_g * Co_g

    # This function always packs the shift and weight zero-point fields as 0.
    shift_tdm = 0
    shift_res = 0
    zp_wght = 0

    # Byte strides; every term is scaled by size_bytes.
    step_Kx = dims.Ci_gran * dims.size_bytes
    step_Ky = Xi_exp * dims.size_bytes * dims.Ci
    step_Ci = Xi_exp * dims.Ci_gran * dims.size_bytes
    step_Xi = int(dims.X_gran
                  * (dims.stride_efficiency * dims.Sx if strided else 1)
                  * dims.size_bytes)
    step_Yi = Xi_exp * dims.size_bytes * dims.Ci * dims.Sy
    step_Xo = dims.Co_gran * dims.size_bytes
    # Output width X rounded up to a multiple of ceil(mem_align / Co);
    # shared by the Yo and Co strides.
    Xo_padded = iceil(dims.X, ceildiv(dims.mem_align, dims.Co))
    step_Yo = Xo_padded * dims.Co * dims.size_bytes
    step_Co = dims.Co_gran * dims.size_bytes * Xo_padded

    param_value = 0
    ctrl = 0
    struct_fields = (
        Kx_g,
        Ky_g,
        Ci_g,
        S_g,
        N_g,
        X_g,
        Y_g,
        Co_g,
        inner_g,
        outer_g,
        shift_tdm,
        shift_res,
        zp_wght,
        dims.op_mode,
        step_Kx,
        step_Ky,
        step_Ci,
        step_Xi,
        step_Yi,
        step_Xo,
        step_Yo,
        step_Co,
        param_value,
        ctrl,
    )
    # '<' gives explicit little-endian, standard sizes and no alignment
    # padding. This matches the little-endian header from qkt_sfmx_params,
    # which generate_layer_kernel_params1 places in front of this blob.
    # Native ('@') mode would follow the host byte order instead.
    format_string = '<BBBbBBBBHHbbbbHHHHHHHHii'
    return struct.pack(format_string, *struct_fields)

def qkt_sfmx_params(
        dims : MhaSubvDims,
        mha_mode : int,
        multi_core : int,
        col : int,
        row : int,
        Sin_kv : int,
        core_tdm1_addr : int,
        core_tdm2_addr : int,
        core_qdq_addr : int,
        core_act1_addr : int,
        core_act2_addr : int,
        core_C0_addr : int,
        core_scratch_addr : int,
        core_query_addr : int,
        core_key_addr : int,
        core_val_addr : int,
        core_msk_addr : int,
) -> bytes:
    """Serialize the qkt_sfmx layer header into 34 little-endian bytes.

    Layout:

    * four unsigned 8-bit fields: mha_mode, multi_core, col, row;
    * then fifteen unsigned 16-bit fields: dims.X, dims.Ci, dims.Co,
      Sin_kv, followed by the eleven core_*_addr values in parameter order.

    Raises:
        OverflowError: if any value is negative or too wide for its field.
    """
    one_byte = b"".join(
        field.to_bytes(length=1, byteorder='little', signed=False)
        for field in (mha_mode, multi_core, col, row)
    )
    two_byte = b"".join(
        field.to_bytes(length=2, byteorder='little', signed=False)
        for field in (
            dims.X, dims.Ci, dims.Co, Sin_kv,
            core_tdm1_addr, core_tdm2_addr, core_qdq_addr,
            core_act1_addr, core_act2_addr, core_C0_addr,
            core_scratch_addr, core_query_addr, core_key_addr,
            core_val_addr, core_msk_addr,
        )
    )
    return one_byte + two_byte

def generate_layer_kernel_params1(
    mha_mode : int,
    multi_core : int,
    col_id : int,
    row_id : int,
    Sin_kv : int,
    core_tdm1_addr : int,
    core_tdm2_addr : int,
    core_qdq_addr : int,
    core_act1_addr : int,
    core_act2_addr : int,
    core_C0_addr : int,
    core_scratch_addr : int,
    core_query_addr : int,
    core_key_addr : int,
    core_val_addr : int,
    core_msk_addr : int,
    dims : MhaSubvDims
) -> bytes:
    """Build the full parameter blob: layer header followed by kernel params.

    The result is ``qkt_sfmx_params(...) + setup_mha_params(dims)``.
    col_id and row_id are forwarded as the header's col and row. The header
    is packed first, so any error from it surfaces before kernel packing.
    """
    header = qkt_sfmx_params(
        dims,
        mha_mode=mha_mode,
        multi_core=multi_core,
        col=col_id,
        row=row_id,
        Sin_kv=Sin_kv,
        core_tdm1_addr=core_tdm1_addr,
        core_tdm2_addr=core_tdm2_addr,
        core_qdq_addr=core_qdq_addr,
        core_act1_addr=core_act1_addr,
        core_act2_addr=core_act2_addr,
        core_C0_addr=core_C0_addr,
        core_scratch_addr=core_scratch_addr,
        core_query_addr=core_query_addr,
        core_key_addr=core_key_addr,
        core_val_addr=core_val_addr,
        core_msk_addr=core_msk_addr,
    )
    return header + setup_mha_params(dims)

def main():
    """Smoke-test setup_mha_params with a representative sub-volume.

    The packed bytes are discarded. The call only checks that this
    configuration passes MhaSubvDims validation and packs without error.
    """
    sample_dims = MhaSubvDims(
        N=1,
        Y=1, Y_gran=1,
        X=16, X_gran=8,
        Ci=64, Ci_gran=8,
        Co=40, Co_gran=8,
        Ky=1, Kx=1,
        Sy=1, Sx=1,
        op_mode=OPMode.OP_CONV_ASYM,
        size_bytes=2,
        stride_efficiency=0.5,
        mem_align=64,
    )
    setup_mha_params(sample_dims)


# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    main()