import struct

from kernel.common.kernel_params_helper import (
    DimsHelper,
)

from utils.utils_common import (
    log,
)


def setup_gemm_qdq_a16w4_params(
    Xos: int, Cos: int, Cis: int,
    Ky: int, Kx: int,
    X_loop: int, Co_loop: int, Ci_loop: int,
    full_iters: bool = True,
) -> bytes:
    """Compute and pack the kernel parameters for the a16w4 QDQ GEMM kernel.

    The loop counts and address steps are derived from the subvolume sizes,
    echoed to stdout for debugging, and packed little-endian with no padding
    using the format ``'<3H2B4H2I3i2I3i1I2i'`` (68 bytes).

    Args:
        Xos: Size split into groups of 32 (``Y_g = Xos // 32``); also scales
            ``step_Xi``.
        Cos: Size split into groups of 64 (``X_g = Cos // 64``); also scales
            ``step_Ky``.
        Cis: Size split into groups of 64 (``inner_loop = Cis // 64``).
        Ky, Kx: Accepted for signature compatibility; not used here.
        X_loop, Co_loop: Outer trip counts. ``outer_iters`` is
            ``X_loop * Co_loop`` when ``full_iters`` is true, else ``Co_loop``.
        Ci_loop: Packed as ``inner_iters``.
        full_iters: Whether ``X_loop`` is folded into ``outer_iters``.

    Returns:
        The packed kernel-parameter bytes.

    Raises:
        struct.error: If a value does not fit its field. For example, the
            ``H`` fields ``step_Xi`` (``128 * Xos``) and ``step_Ky``
            (``32 * Cos``) must lie in 0..65535, and ``Y_g``/``X_g`` (``B``)
            must lie in 0..255.
    """
    granM = 32
    step_back_AM_int16 = -128
    # Floor division keeps this an int. ``-4096 / 2`` would give the float
    # -2048.0, which leaks into the dimsB increments packed as ``i`` below.
    step_back_B_int4 = -4096 // 2
    dimsClassA = DimsHelper(step_back_AM_int16 * granM)
    dimsClassB = DimsHelper(step_back_B_int4)
    dimsClassQ = DimsHelper(0)

    outer_iters = X_loop * Co_loop if full_iters else Co_loop
    inner_iters = Ci_loop
    inner_loop = Cis // 64
    Y_g = Xos // 32
    X_g = Cos // 64
    step_Xi = 64 * Xos * 2
    step_Yi = 64 * 2
    step_Kx = 64
    step_Ky = Cos * (64 // 2)

    dimsA = dimsClassA.from_steps((inner_loop, Y_g), (step_Xi, step_Yi * granM, 0))
    dimsB = dimsClassB.from_steps((inner_loop, Y_g), (step_Ky, 0, step_Kx * 64))
    # NOTE(review): this passes a bare int, unlike the 2-tuples above (the
    # original spelled it ``(Y_g)``, which is not a tuple). Kept as-is because
    # DimsHelper.from_steps is not visible here; confirm whether ``(Y_g,)``
    # was intended.
    dimsQ = dimsClassQ.from_steps(Y_g, (0, 512))

    # Debug dump of every value that goes into the packed blob.
    scalar_fields = (
        ("outer_iters", outer_iters),
        ("inner_iters", inner_iters),
        ("inner_loop", inner_loop),
        ("Y_g", Y_g),
        ("X_g", X_g),
        ("step_Xi", step_Xi),
        ("step_Yi", step_Yi),
        ("step_Kx", step_Kx),
        ("step_Ky", step_Ky),
    )
    for name, value in scalar_fields:
        print(f"{name}: {value}")
    dims_fields = (
        ("dimsA", dimsA, ("num0", "num1", "inc0", "inc1", "inc2")),
        ("dimsB", dimsB, ("num0", "num1", "inc0", "inc1", "inc2")),
        ("dimsQ", dimsQ, ("num0", "inc0", "inc1")),
    )
    for label, dims, keys in dims_fields:
        for key in keys:
            print(f"{label}['{key}']: {dims[key]}")

    # Layout (little-endian, no padding):
    #   3H outer_iters, inner_iters, inner_loop
    #   2B Y_g, X_g
    #   4H step_Xi, step_Yi, step_Kx, step_Ky
    #   dimsA: 2I num0..1, 3i inc0..2
    #   dimsB: 2I num0..1, 3i inc0..2
    #   dimsQ: 1I num0,    2i inc0..1
    # Earlier revisions also packed shift_res (H), ctrl (B) and a reserved
    # byte (B) after step_Ky; those fields are no longer part of the layout.
    return struct.pack(
        '<3H2B4H2I3i2I3i1I2i',
        outer_iters, inner_iters, inner_loop,
        Y_g, X_g,
        step_Xi, step_Yi, step_Kx, step_Ky,
        dimsA['num0'], dimsA['num1'],
        dimsA['inc0'], dimsA['inc1'], dimsA['inc2'],
        dimsB['num0'], dimsB['num1'],
        # int() kept defensively in case DimsHelper returns float increments.
        int(dimsB['inc0']), int(dimsB['inc1']), int(dimsB['inc2']),
        dimsQ['num0'],
        dimsQ['inc0'], dimsQ['inc1'],
    )


def gemm_layer_params(
    ifm_ch_num: int,
    core_spill_buf: int,
    core_ifm_tmp_buffer: int,
    core_coeff_tmp_buffer: int,
    wgt_size: int,
    coeff_size: int,
) -> bytes:
    """Serialize the GEMM layer parameters as six little-endian uint32 words.

    Output word order is: core_spill_buf, core_ifm_tmp_buffer,
    core_coeff_tmp_buffer, wgt_size, coeff_size, ifm_ch_num. Note that
    ``ifm_ch_num`` comes first in the signature but last in the output.

    Returns:
        24 bytes.

    Raises:
        OverflowError: If any value is negative or needs more than 32 bits.
    """
    words = (
        core_spill_buf,
        core_ifm_tmp_buffer,
        core_coeff_tmp_buffer,
        wgt_size,
        coeff_size,
        ifm_ch_num,
    )
    return b"".join(
        word.to_bytes(length=4, byteorder='little', signed=False)
        for word in words
    )

def generate_gemm_qdq_a16w4_params(
    Xos: int, Cos: int, Cis: int, Ky: int, Kx: int,
    X_loop: int, Co_loop: int, Ci_loop: int,
    ifm_ch_num: int,
    core_spill_buf: int,
    core_ifm_tmp_buffer: int,
    core_coeff_tmp_buffer: int,
    full_iters: bool = True,
) -> bytes:
    """Build the complete a16w4 QDQ GEMM parameter blob.

    The result is the layer parameters (six little-endian uint32 words from
    ``gemm_layer_params``) immediately followed by the packed kernel
    parameters from ``setup_gemm_qdq_a16w4_params``.

    Args:
        Xos, Cos, Cis, Ky, Kx: Subvolume sizes. All five set the weight size;
            ``Xos``/``Cos``/``Cis`` also feed the kernel params.
        X_loop, Co_loop, Ci_loop, full_iters: Forwarded to
            ``setup_gemm_qdq_a16w4_params``.
        ifm_ch_num, core_spill_buf, core_ifm_tmp_buffer,
        core_coeff_tmp_buffer: Forwarded to ``gemm_layer_params``.

    Returns:
        Layer-parameter bytes concatenated with kernel-parameter bytes.
    """
    wgt_bits = 4
    # Weight subvolume size in bytes at ``wgt_bits`` bits per weight.
    raw_wgt_subv_size = Cos * Ky * Kx * Cis * wgt_bits // 8
    # Three 32-bit (4-byte) words per ``Cos`` channel.
    # NOTE(review): presumably per-channel quantization coefficients
    # (bias/scale/zero-point) — confirm against the kernel.
    coeff_size = 3 * (Cos * 32 // 8)
    layer_params = gemm_layer_params(
        ifm_ch_num=ifm_ch_num,
        core_spill_buf=core_spill_buf,
        core_ifm_tmp_buffer=core_ifm_tmp_buffer,
        core_coeff_tmp_buffer=core_coeff_tmp_buffer,
        wgt_size=raw_wgt_subv_size,
        coeff_size=coeff_size,
    )
    kernel_params = setup_gemm_qdq_a16w4_params(
        Xos, Cos, Cis, Ky, Kx, X_loop, Co_loop, Ci_loop, full_iters
    )
    return layer_params + kernel_params
