from enum import IntEnum
import struct
import ctypes

def ceildiv(x: int, d: int) -> int:
    """Return ``x / d`` rounded toward positive infinity.

    Uses only floor division, so integer inputs give an exact integer
    result. Raises ZeroDivisionError when ``d`` is zero.
    """
    # Floor-dividing by the negated divisor rounds the (negated) quotient
    # down; flipping the sign back turns that into rounding up.
    negated_quotient = x // -d
    return -negated_quotient

def iceil(x: int, d: int) -> int:
    """Round ``x`` up to the nearest multiple of ``d``.

    Raises ZeroDivisionError when ``d`` is zero.
    """
    multiples_needed = ceildiv(x, d)
    return multiples_needed * d

class OPMode(IntEnum):
    """Operation-mode codes accepted by GemmSubvDims.

    setup_gemm_params packs the numeric value into the kernel parameter
    record as a single signed byte, so the values themselves are part of
    the binary format. Values 6 and 7 are not assigned.
    """
    OP_NONE = 0
    OP_CONV = 1
    OP_SUM = 2
    OP_SUM_T = 3
    # The only mode setup_gemm_params special-cases: it forces the
    # input-channel group count to 1.
    OP_DWC = 4
    OP_DWC_SUM = 5
    # NOTE(review): the *_SYM / *_ASYM names presumably refer to symmetric /
    # asymmetric quantization -- confirm against the kernel documentation.
    OP_CONV_SYM = 8
    OP_CONV_ASYM = 9
    OP_DWC_SYM = 10
    OP_DWC_ASYM = 11
    OP_QDQ = 12
    OP_ASYM = 13

class Ctrl(ctypes.Structure):
    """Control word made of c_uint32 bit fields totalling 32 bits.

    setup_gemm_params copies the raw bytes of an instance into the tail of
    the kernel parameter record, so field order and widths define the
    serialized layout. The exact bit positions are decided by the platform
    ABI that ctypes follows.
    NOTE(review): presumably the first field is the least significant bit on
    the target -- confirm against the kernel-side definition.
    """
    _fields_= [
        ("zero_init", ctypes.c_uint32, 1),      # 1 bit
        ("sign_N", ctypes.c_uint32, 1),         # 1 bit
        ("sign_O", ctypes.c_uint32, 1),         # 1 bit, set from dims.sign_out
        ("reserved3", ctypes.c_uint32, 3),      # 3 bits
        ("skip_casc_in", ctypes.c_uint32, 1),   # 1 bit
        ("skip_casc_out", ctypes.c_uint32, 1),  # 1 bit
        ("sign_W", ctypes.c_uint32, 1),         # 1 bit, set from dims.sign_wgt
        ("sign_A", ctypes.c_uint32, 1),         # 1 bit, set from dims.sign_act
        ("out_32", ctypes.c_uint32, 1),         # 1 bit
        ("add_bias", ctypes.c_uint32, 1),       # 1 bit
        # Despite the "10" in its name this field is 12 bits wide.
        ("reserved10", ctypes.c_uint32, 12),
        ("norm_ch_g", ctypes.c_uint32, 8)       # 8 bits
    ]

class GemmSubvDims:
    """Plain container for the dimensions of one GEMM/conv sub-volume.

    Every constructor argument is stored unchanged on the instance under the
    same name. The constructor only validates; it derives nothing.

    NOTE(review): the names suggest N = batch, Y/X = output rows/columns,
    Co/Ci = output/input channels, Ky/Kx = kernel extent and Sy/Sx = strides,
    with ``*_gran`` the processing granularity of each axis -- confirm
    against the kernel documentation.
    """

    def __init__(
        self,
        N : int,
        Y : int, Y_gran : int,
        X : int, X_gran : int,
        Co : int, Co_gran : int,
        Ci : int, Ci_gran : int,
        Ky : int,
        Kx : int,
        Sy : int,
        Sx : int,
        op_mode : OPMode,
        size_bytes : int,
        stride_efficiency : float,
        mem_align : int,
        sign_act: int=None, # pylint: disable=no-value-for-parameter
        sign_wgt: int=None, # pylint: disable=no-value-for-parameter
        sign_out: int=None, # pylint: disable=no-value-for-parameter
    ):
        """Validate and store the sub-volume description.

        Raises:
            AssertionError: if Y, X, Co or Ci is not a whole multiple of its
                granularity, or if ``op_mode`` does not equal any OPMode
                member. (Asserts are skipped when Python runs with -O.)
            ZeroDivisionError: if one of the granularities is zero.
        """
        # Each tiled axis must be an exact multiple of its granularity.
        for extent, granularity in (
            (Y, Y_gran),
            (X, X_gran),
            (Co, Co_gran),
            (Ci, Ci_gran),
        ):
            assert extent % granularity == 0
        # Compared by value, so a plain int equal to a member is accepted.
        assert op_mode in tuple(OPMode)

        self.N = N
        self.Y, self.Y_gran = Y, Y_gran
        self.X, self.X_gran = X, X_gran
        self.Co, self.Co_gran = Co, Co_gran
        self.Ci, self.Ci_gran = Ci, Ci_gran
        self.Ky, self.Kx = Ky, Kx
        self.Sy, self.Sx = Sy, Sx
        self.op_mode = op_mode
        self.size_bytes = size_bytes
        self.stride_efficiency = stride_efficiency
        self.mem_align = mem_align
        # Sign flags are optional and stay None when not supplied.
        self.sign_act = sign_act
        self.sign_wgt = sign_wgt
        self.sign_out = sign_out

def setup_gemm_params(
    dims: GemmSubvDims
) -> bytes:
    """Pack the kernel-level parameter record for one sub-volume.

    Derives group counts and byte steps from ``dims`` and serializes them,
    followed by the raw bytes of a ``Ctrl`` word, using ``struct`` native
    mode. With 4-byte ``int`` the format needs no padding and the record is
    40 bytes long.

    Args:
        dims: sub-volume description. ``sign_act``, ``sign_wgt`` and
            ``sign_out`` may be None, in which case the corresponding
            control bit is left at 0.

    Returns:
        The packed record as ``bytes``.

    Raises:
        struct.error: if a derived value does not fit its packed field.
        ZeroDivisionError: if a granularity or alignment divisor is zero.
    """
    size_bytes = dims.size_bytes
    # The stride-efficiency factor is only applied when the X stride is > 1.
    strided = dims.Sx > 1

    # Input row length covering X outputs at stride Sx with a Kx-wide
    # kernel, rounded up to a multiple of
    # mem_align // Ci_gran // size_bytes.
    Xi_exp = iceil(dims.X * dims.Sx + dims.Kx - 1,
                   dims.mem_align // dims.Ci_gran // size_bytes)

    Kx_g = dims.Kx
    Ky_g = dims.Ky
    Ci_g = ceildiv(dims.Ci, dims.Ci_gran)
    # NOTE(review): only OP_DWC collapses the input-channel groups; the other
    # OP_DWC_* modes keep the full count -- confirm this is intended.
    if dims.op_mode == OPMode.OP_DWC:
        Ci_g = 1
    S_g = dims.Sx
    N_g = dims.N
    # ceildiv returns a float when stride_efficiency is a float, hence int().
    X_g = int(ceildiv(dims.N * dims.X // dims.X_gran,
                      dims.stride_efficiency if strided else 1))
    Y_g = ceildiv(dims.Y, dims.Y_gran)
    Co_g = ceildiv(dims.Co, dims.Co_gran)
    inner_g = Kx_g * Ky_g * Ci_g
    outer_g = X_g * Y_g * Co_g

    # These fields are always emitted as zero by this function.
    shift_tdm = 0
    shift_res = 0
    zp_wght = 0
    param_value = 0

    # Byte steps on the input side.
    step_Kx = dims.Ci_gran * size_bytes
    step_Ky = Xi_exp * size_bytes * dims.Ci
    step_Ci = Xi_exp * dims.Ci_gran * size_bytes
    step_Xi = int(dims.X_gran
                  * (dims.stride_efficiency * dims.Sx if strided else 1)
                  * size_bytes)
    step_Yi = Xi_exp * size_bytes * dims.Ci * dims.Sy

    # Byte steps on the output side; the row length is X rounded up to a
    # multiple of ceil(mem_align / Co).
    X_aligned = iceil(dims.X, ceildiv(dims.mem_align, dims.Co))
    step_Xo = dims.Co_gran * size_bytes
    step_Yo = X_aligned * dims.Co * size_bytes
    step_Co = dims.Co_gran * size_bytes * X_aligned

    ctrl = Ctrl()
    # The sign flags default to None on GemmSubvDims, and assigning None to
    # a ctypes bit field raises TypeError. Treat "not specified" as 0 so a
    # dims object built without sign flags can still be packed.
    ctrl.sign_A = 0 if dims.sign_act is None else dims.sign_act
    ctrl.sign_W = 0 if dims.sign_wgt is None else dims.sign_wgt
    ctrl.sign_O = 0 if dims.sign_out is None else dims.sign_out

    struct_fields = (
        Kx_g,
        Ky_g,
        Ci_g,
        S_g,
        N_g,
        X_g,
        Y_g,
        Co_g,
        inner_g,
        outer_g,
        shift_tdm,
        shift_res,
        zp_wght,
        dims.op_mode,
        step_Kx,
        step_Ky,
        step_Ci,
        step_Xi,
        step_Yi,
        step_Xo,
        step_Yo,
        step_Co,
        param_value,
        # Raw in-memory bytes of the control word.
        ctypes.string_at(ctypes.addressof(ctrl), ctypes.sizeof(ctrl)),
    )
    # One format character per entry of struct_fields, in the same order.
    format_string = 'BBBbBBBBHHbbbbHHHHHHHHi4s'
    return struct.pack(format_string, *struct_fields)


def gemm_params(
    dims: GemmSubvDims,
    zero_acc: int,
    mode: int,
    qdq_addr: int,
    sum_addr: int,
    tdm1_addr: int,
    tdm2_addr: int,
    wgt_bits: int,
) -> bytes:
    """Serialize the 16-byte layer-level parameter header.

    All fields are unsigned little-endian: ``zero_acc`` and ``mode`` take one
    byte each; the weight size, element count, the four addresses and a zero
    reserved field take two bytes each.

    Raises:
        AssertionError: if ``zero_acc`` or ``mode`` is not 0 or 1.
        OverflowError: if a value is negative or too large for its field.
    """
    # NOTE: Mode0: args_params[1] == 0 => MatA broadcast, MatB unicast
    #       Mode1: args_params[1] == 1 => MatA unicast, MatB broadcast
    # Here, weights are unicast, so that is mode 1.
    assert zero_acc in (0, 1)
    assert mode in (0, 1)

    n_elems = dims.Ci * dims.Co
    # Weight footprint in bytes, rounded up to a multiple of 64.
    # NOTE(review): the bits-to-bytes step floors; a bit count that is not a
    # multiple of 8 loses its partial byte before the round-up -- confirm
    # this is intended.
    wgt_size = iceil((n_elems * wgt_bits) // 8, 64)
    reserved = 0

    # (value, width in bytes) in wire order.
    layout = (
        (zero_acc, 1),
        (mode, 1),
        (wgt_size, 2),
        (n_elems, 2),
        (qdq_addr, 2),
        (sum_addr, 2),
        (tdm1_addr, 2),
        (tdm2_addr, 2),
        (reserved, 2),
    )
    return b"".join(
        value.to_bytes(length=width, byteorder='little', signed=False)
        for value, width in layout
    )

def generate_layer_kernel_params(
    zero_acc: int,
    mode : int,
    qdq_addr: int,
    ifmsum_addr : int,
    tdm_1_addr: int,
    tdm_2_addr: int,
    dims : GemmSubvDims,
    int4_wgt: int = 0,
    wgt_unpack_addr: int = 0,
    wgt_bits: int = 8,
) -> bytes:
    """Build the full parameter blob for one layer.

    The result is the concatenation, in this order, of:
      1. the layer header from ``gemm_params``,
      2. the kernel record from ``setup_gemm_params``,
      3. ``int4_wgt`` and ``wgt_unpack_addr`` as two unsigned little-endian
         16-bit values.

    Also prints ``int4_wgt`` and ``wgt_unpack_addr`` to stdout.

    Raises:
        OverflowError: if ``int4_wgt`` or ``wgt_unpack_addr`` is negative or
            does not fit in 16 bits.
        Whatever ``gemm_params`` / ``setup_gemm_params`` raise.
    """
    layer_params = gemm_params(
        dims,
        zero_acc,
        mode,
        qdq_addr,
        ifmsum_addr,
        tdm_1_addr,
        tdm_2_addr,
        wgt_bits,
    )
    kernel_params = setup_gemm_params(dims)
    int4_wgt_params = (
        int4_wgt.to_bytes(length=2, byteorder='little', signed=False)
        + wgt_unpack_addr.to_bytes(length=2, byteorder='little', signed=False)
    )
    # Diagnostic output kept for compatibility with existing logs; the labels
    # are plain literals because they contain no placeholders.
    print("int4_wgt", int4_wgt)
    print("wgt_unpack_addr", wgt_unpack_addr)
    return layer_params + kernel_params + int4_wgt_params

def main():
    """Smoke test: pack kernel parameters for one fixed sub-volume.

    The packed bytes are discarded; the call only checks that the
    configuration below can be serialized without raising.
    """
    dims = GemmSubvDims(
        N=1,
        Y=1, Y_gran=1,
        X=32, X_gran=8,
        Co=80, Co_gran=8,
        Ci=80, Ci_gran=8,
        Ky=1, Kx=1,
        Sy=1, Sx=1,
        op_mode=OPMode.OP_CONV_ASYM,
        size_bytes=2,
        stride_efficiency=0.5,
        mem_align=64,
        # Pass the sign flags explicitly: they default to None, which cannot
        # be written into the ctypes control-word bit fields.
        # NOTE(review): 0 is a placeholder for this smoke test -- confirm the
        # intended signedness for this configuration.
        sign_act=0,
        sign_wgt=0,
        sign_out=0,
    )

    setup_gemm_params(dims)


if __name__ == '__main__':
    main()
