from enum import IntEnum
from typing import Tuple, List, Optional
import struct

def ceildiv(x: int, d: int) -> int:
    """Return ``x / d`` rounded toward positive infinity.

    Implemented as the negation of the floor-division of ``x`` by ``-d``.
    Also accepts float divisors (the result is then a float).

    >>> ceildiv(7, 2)
    4
    >>> ceildiv(8, 2)
    4
    """
    floored_of_negated = x // -d
    return -floored_of_negated

def iceil(x: int, d: int) -> int:
    """Round ``x`` up to the nearest multiple of ``d``.

    >>> iceil(10, 4)
    12
    >>> iceil(16, 4)
    16
    """
    # Number of d-sized chunks needed to cover x (ceiling division).
    chunks = -(x // -d)
    return chunks * d

class OPMode(IntEnum):
    """Operation-mode codes for the conv kernel parameter record.

    The numbering is not contiguous: values 6 and 7 are unused.
    """

    # Codes 0-5.
    OP_NONE = 0
    OP_CONV = 1
    OP_SUM = 2
    OP_SUM_T = 3
    OP_DWC = 4
    OP_DWC_SUM = 5
    # Codes 8-13.
    OP_CONV_SYM = 8
    OP_CONV_ASYM = 9
    OP_DWC_SYM = 10
    OP_DWC_ASYM = 11
    OP_QDQ = 12
    OP_ASYM = 13

class ConvSubvDims:
    """Dimensions and layout parameters of one convolution subvolume.

    Each of Y, X, Co and Ci must be a multiple of its ``*_gran`` granularity,
    and op_mode must compare equal to an OPMode member. Both conditions are
    checked with ``assert``, so they are skipped under ``python -O``.
    Every constructor argument is stored unchanged as an attribute of the
    same name.
    """

    def __init__(
        self,
        N: int,
        Y: int, Y_gran: int,
        X: int, X_gran: int,
        Co: int, Co_gran: int,
        Ci: int, Ci_gran: int,
        Ky: int,
        Kx: int,
        Sy: int,
        Sx: int,
        op_mode: OPMode,
        size_bytes: int,
        stride_efficiency: float,
        mem_align: int,
    ):
        # Each extent must split evenly into its granularity (checked in
        # Y, X, Co, Ci order).
        for extent, granularity in ((Y, Y_gran), (X, X_gran),
                                    (Co, Co_gran), (Ci, Ci_gran)):
            assert extent % granularity == 0
        assert op_mode in tuple(OPMode)

        # Output / spatial extents and their granularities.
        self.N = N
        self.Y = Y
        self.Y_gran = Y_gran
        self.X = X
        self.X_gran = X_gran
        # Channel extents and their granularities.
        self.Co = Co
        self.Co_gran = Co_gran
        self.Ci = Ci
        self.Ci_gran = Ci_gran
        # Kernel size and strides.
        self.Ky = Ky
        self.Kx = Kx
        self.Sy = Sy
        self.Sx = Sx
        # Mode and memory-layout parameters.
        self.op_mode = op_mode
        self.size_bytes = size_bytes
        self.stride_efficiency = stride_efficiency
        self.mem_align = mem_align

def setup_conv_params(
    dims: ConvSubvDims,
    op_mode : OPMode
):
    """Pack the conv kernel's loop counts and address steps into a record.

    Derives loop iteration counts (``*_g``) and address increments
    (``step_*``) for the subvolume described by *dims*. It packs them,
    together with *op_mode* and zeroed shift/zero-point/param/ctrl fields,
    into a 40-byte little-endian record.

    Args:
        dims: Subvolume dimensions and memory-layout parameters.
        op_mode: Mode code written into the record's op_mode byte.

    Returns:
        The packed record as ``bytes`` (40 bytes).

    Raises:
        struct.error: if any value does not fit its packed field width.
    """
    # Input row width in pixels, rounded up to a multiple of
    # mem_align // Ci_gran // size_bytes pixels.
    Xi_exp = iceil((dims.X - 1) * dims.Sx + dims.Kx,
                   dims.mem_align // dims.Ci_gran // dims.size_bytes)

    # ---- loop iteration counts ----
    Kx_g = dims.Kx
    Ky_g = dims.Ky
    Ci_g = ceildiv(dims.Ci, dims.Ci_gran)
    # NOTE(review): this check reads dims.op_mode, while the packed mode is the
    # op_mode argument. generate_layer_kernel_params passes dims.op_mode, so
    # they agree there; confirm the intent for other callers.
    if dims.op_mode == OPMode.OP_DWC_ASYM:
        Ci_g = 1
    S_g = dims.Sx
    N_g = dims.N
    # For strided convs the count is divided by stride_efficiency (a float),
    # so ceildiv returns a float and is cast back to int.
    X_g = int(ceildiv(dims.N * dims.X // dims.X_gran,
                      dims.stride_efficiency if dims.Sx > 1 else 1))
    Y_g = ceildiv(dims.Y, dims.Y_gran)
    Co_g = ceildiv(dims.Co, dims.Co_gran)
    inner_g = Kx_g * Ky_g * Ci_g
    outer_g = X_g * Y_g * Co_g

    shift_tdm = 0
    shift_res = 0
    zp_wght = 0

    # ---- address increments (scaled by size_bytes; presumably bytes) ----
    step_Kx = dims.Ci_gran * dims.size_bytes
    step_Ky = Xi_exp * dims.size_bytes * dims.Ci
    step_Ci = Xi_exp * dims.Ci_gran * dims.size_bytes
    step_Xi = int(dims.X_gran * (dims.stride_efficiency * dims.Sx
                                 if dims.Sx > 1 else 1) * dims.size_bytes)
    step_Yi = Xi_exp * dims.size_bytes * dims.Ci * dims.Sy
    # Output row width in pixels, rounded up to a multiple of
    # mem_align // Co_gran pixels (shared by step_Yo and step_Co).
    Xo_exp = iceil(dims.X, dims.mem_align // dims.Co_gran)
    step_Xo = dims.Co_gran * dims.size_bytes
    step_Yo = Xo_exp * dims.Co * dims.size_bytes
    step_Co = dims.Co_gran * dims.size_bytes * Xo_exp

    param_value = 0
    ctrl = 0

    # '<' selects little-endian byte order with standard sizes and no padding.
    # The bytes no longer depend on the host, and they match the little-endian
    # layer headers built elsewhere in this module. This field sequence needs
    # no alignment padding in native mode either, so the output on
    # little-endian hosts is unchanged.
    format_string = '<BBBbBBBBHHbbbbHHHHHHHHii'
    return struct.pack(
        format_string,
        Kx_g, Ky_g, Ci_g, S_g, N_g, X_g, Y_g, Co_g,     # 8 x u8/i8
        inner_g, outer_g,                               # 2 x u16
        shift_tdm, shift_res, zp_wght, op_mode,         # 4 x i8
        step_Kx, step_Ky, step_Ci, step_Xi,             # 8 x u16
        step_Yi, step_Xo, step_Yo, step_Co,
        param_value, ctrl,                              # 2 x i32
    )


def a16w8_conv_params(
    first_tdm: int,
    final_tdm: int,
    tdm_1_addr: int,
    tdm_2_addr: int,
    ifm_sum_addr: int,
    scratch_buf: int,
    tmp_buf: int,
) -> bytes:
    """Pack the 16-byte layer-parameter header for the a16w8 conv kernel.

    Layout, little-endian:
        first_tdm (u8), final_tdm (u8),
        tdm_1_addr, tdm_2_addr, ifm_sum_addr, scratch_buf, tmp_buf (u16 each),
        two zero u16 pad words.

    Args:
        first_tdm: Flag, must be 0 or 1.
        final_tdm: Flag, must be 0 or 1.
        tdm_1_addr, tdm_2_addr, ifm_sum_addr, scratch_buf, tmp_buf:
            Addresses, each in ``[0, 2**16)``.

    Returns:
        The packed header (16 bytes).

    Raises:
        AssertionError: if a flag or address is out of range.
    """
    assert first_tdm in (0, 1)
    assert final_tdm in (0, 1)
    assert 0 <= tdm_1_addr < 2**16
    assert 0 <= tdm_2_addr < 2**16
    assert 0 <= ifm_sum_addr < 2**16
    assert 0 <= scratch_buf < 2**16
    assert 0 <= tmp_buf < 2**16
    # The two trailing zero words pad the header to 16 bytes, a multiple of 8,
    # for 8-byte alignment.
    return struct.pack(
        '<BBHHHHHHH',
        first_tdm,
        final_tdm,
        tdm_1_addr,
        tdm_2_addr,
        ifm_sum_addr,
        scratch_buf,
        tmp_buf,
        0,
        0,
    )

def generate_layer_kernel_params(
    first_tdm: int,
    final_tdm: int,
    tdm_1_addr: int,
    tdm_2_addr: int,
    ifm_sum_addr: int,
    scratch_buf: int,
    tmp_buf: int,
    dims: ConvSubvDims,
) -> bytes:
    """Build the full a16w8 parameter blob for one conv layer.

    Concatenates the layer header produced by ``a16w8_conv_params`` with the
    kernel record produced by ``setup_conv_params(dims, dims.op_mode)``, in
    that order.
    """
    header = a16w8_conv_params(first_tdm, final_tdm, tdm_1_addr, tdm_2_addr,
                               ifm_sum_addr, scratch_buf, tmp_buf)
    return header + setup_conv_params(dims, dims.op_mode)

def generate_add_kernel_params(
    conv_kernelprm_addr: int,
    ofm_addr: int,
    ofm_subsize: int,
) -> bytes:
    """Pack the 6-byte parameter record for the add kernel.

    Layout, little-endian u16 each: conv_kernelprm_addr, ofm_addr, ofm_subsize.

    Args:
        conv_kernelprm_addr: Address in ``[0, 2**16)``.
        ofm_addr: Address in ``[0, 2**16)``.
        ofm_subsize: Size in ``[0, 2**16)``.

    Returns:
        The packed record (6 bytes).

    Raises:
        AssertionError: if any field is out of the u16 range.
    """
    assert 0 <= ofm_addr < 2**16
    assert 0 <= conv_kernelprm_addr < 2**16
    assert 0 <= ofm_subsize < 2**16
    return struct.pack('<HHH', conv_kernelprm_addr, ofm_addr, ofm_subsize)

def xint8_conv_params(
    first_tdm: int,
    final_tdm: int,
    tdm_1_addr: int,
    tdm_2_addr: int,
    conv_kernelprm_addr: int,
    wgt_subsize: int,
    ofm_addr: int,
    ifm_flag: bool,
    ifm_addr: int,
) -> bytes:
    """Pack the 18-byte layer-parameter record for the xint8 conv kernel.

    Every field is a little-endian u16, in this order: first_tdm, final_tdm,
    tdm_1_addr, tdm_2_addr, conv_kernelprm_addr, wgt_subsize, ofm_addr,
    ifm_flag, ifm_addr.

    Args:
        first_tdm: Flag, must be 0 or 1.
        final_tdm: Flag, must be 0 or 1.
        tdm_1_addr, tdm_2_addr, conv_kernelprm_addr, ofm_addr, ifm_addr:
            Addresses, each in ``[0, 2**16)``.
        wgt_subsize: Weight subvolume size, in ``[0, 2**16)``.
        ifm_flag: Boolean flag, packed as 0 or 1.

    Returns:
        The packed record (18 bytes).

    Raises:
        AssertionError: if a flag or u16 field is out of range.
    """
    assert first_tdm in (0, 1)
    assert final_tdm in (0, 1)
    assert 0 <= tdm_1_addr < 2**16
    assert 0 <= tdm_2_addr < 2**16
    assert 0 <= ofm_addr < 2**16
    assert 0 <= conv_kernelprm_addr < 2**16
    # wgt_subsize and ifm_addr are also packed as u16; check them like the
    # other fields instead of relying on the packer to fail.
    assert 0 <= wgt_subsize < 2**16
    assert 0 <= ifm_addr < 2**16
    return struct.pack(
        '<9H',
        first_tdm,
        final_tdm,
        tdm_1_addr,
        tdm_2_addr,
        conv_kernelprm_addr,
        wgt_subsize,
        ofm_addr,
        ifm_flag,
        ifm_addr,
    )

def generate_layer_kernel_params_xint8(
    first_tdm: int,
    final_tdm: int,
    tdm_1_addr: int,
    tdm_2_addr: int,
    conv_kernelprm_addr: int,
    ofm_addr: int,
    ifm_flag: bool,
    dims: ConvSubvDims,
    ifm_addr: int,
) -> bytes:
    """Build the xint8 layer-parameter record for one conv subvolume.

    Forwards every argument to ``xint8_conv_params`` and supplies
    ``wgt_subsize`` as ``Co * Ci * Kx * Ky`` taken from *dims*.
    """
    # Element count of the weight subvolume. No size_bytes factor is applied;
    # presumably one byte per weight for int8 — confirm against the kernel.
    weight_elems = dims.Co * dims.Ci * dims.Kx * dims.Ky
    return xint8_conv_params(
        first_tdm=first_tdm,
        final_tdm=final_tdm,
        tdm_1_addr=tdm_1_addr,
        tdm_2_addr=tdm_2_addr,
        conv_kernelprm_addr=conv_kernelprm_addr,
        wgt_subsize=weight_elems,
        ofm_addr=ofm_addr,
        ifm_flag=ifm_flag,
        ifm_addr=ifm_addr,
    )

def main():
    """Smoke-test the a16w8 parameter generator on three sample subvolumes.

    Builds two stride-1 cases (X=16 and X=8) and one stride-2 case (X=32),
    all OP_CONV with 2-byte elements and 64-byte alignment. It packs each
    case with generate_layer_kernel_params and prints every result.
    """
    dims_s1_w16 = ConvSubvDims(
        N=1,
        Y=4, Y_gran=1,
        X=16, X_gran=8,
        Co=16, Co_gran=8,
        Ci=40, Ci_gran=8,
        Ky=1, Kx=1,
        Sy=1, Sx=1,
        op_mode=OPMode.OP_CONV,
        size_bytes=2,
        stride_efficiency=0.5,
        mem_align=64,
    )
    dims_s1_w8 = ConvSubvDims(
        N=1,
        Y=2, Y_gran=1,
        X=8, X_gran=8,
        Co=16, Co_gran=8,
        Ci=64, Ci_gran=8,
        Ky=3, Kx=3,
        Sy=1, Sx=1,
        op_mode=OPMode.OP_CONV,
        size_bytes=2,
        stride_efficiency=0.5,
        mem_align=64,
    )
    dims_s2_w32 = ConvSubvDims(
        N=1,
        Y=2, Y_gran=1,
        X=32, X_gran=8,
        Co=16, Co_gran=8,
        Ci=16, Ci_gran=8,
        Ky=3, Kx=3,
        Sy=2, Sx=2,
        op_mode=OPMode.OP_CONV,
        size_bytes=2,
        stride_efficiency=0.5,
        mem_align=64,
    )

    # (dims, tdm_1_addr, tdm_2_addr, ifm_sum_addr) per case; the remaining
    # layer fields are identical for all cases.
    cases = (
        (dims_s1_w16, 5120, 21504, 49984),
        (dims_s1_w8, 16384, 32768, 49152),
        (dims_s2_w32, 16384, 32768, 49152),
    )
    for dims, tdm_1_addr, tdm_2_addr, ifm_sum_addr in cases:
        params = generate_layer_kernel_params(
            first_tdm=1,
            final_tdm=1,
            tdm_1_addr=tdm_1_addr,
            tdm_2_addr=tdm_2_addr,
            ifm_sum_addr=ifm_sum_addr,
            scratch_buf=1,
            tmp_buf=0,
            dims=dims,
        )
        # Print every case so that each generated blob can be inspected.
        print(params)


if __name__ == '__main__':
    main()
