import struct 

from kernel.common.kernel_params_helper import (
    DimsHelper, conv_to_local_ptr,
)

from scheduler.conv.conv_config_builders import (
    ConvDims,
)

from utils.utils_common import (
    split_to_mode,
    log,
)


def setup_gemm_qdq_a16a16_transpose_params(
    dims: ConvDims
) -> bytes:
    """Build the packed kernel-parameter tail for the a16a16 GEMM+QDQ transpose kernel.

    Derives loop counts, group counts, byte strides and five ``DimsHelper``
    descriptors (``dimsA``, ``dimsB``, ``dimsQ``, ``dimsAs``, ``dimsW``) from
    ``dims``, logs every packed value, and packs them little-endian with no
    padding.

    The field layout is declared once as (name, struct code, value) triples.
    That single table drives both the log output and the struct format, so the
    two cannot drift apart. The generated format is equivalent to
    ``'<4H4B6H2I3i2I3i1I2i1I2i2I3i'`` (108 bytes).

    Args:
        dims: Dimensions object. This function reads ``transpose_wgts``,
            ``is_split``, ``X_loop``, ``Co_loop``, ``Y_loop``, ``Ci_loop``,
            ``Cis``, ``Cos`` and ``Xis``. It also passes ``dims`` to
            ``split_to_mode``.

    Returns:
        The packed parameter bytes.

    Raises:
        struct.error: If any value does not fit its field. The message names
            the offending field.
    """
    # The truthiness of transpose selects the transposed-weights layout below.
    transpose = dims.transpose_wgts
    # Base values handed to DimsHelper; their meaning is defined by DimsHelper,
    # which is not visible here.
    dimsA = DimsHelper(-128 * 16)
    dimsAs = DimsHelper(-128 * 16)
    dimsB = DimsHelper(128 - 8192) if transpose else DimsHelper(0)
    dimsQ = DimsHelper(0)
    dimsW = DimsHelper(0)

    outer_time_iters = dims.X_loop * dims.Co_loop * dims.Y_loop
    inner_time_iters = dims.Ci_loop
    # int(a / b) (truncating true division) is kept as in the original;
    # tsl_bound in particular truncates the *sum* of two fractional quotients.
    inner_g = int(dims.Cis / 64)
    Y_g = int(dims.Xis / 16)
    X_g = int(dims.Cos / 64)
    sizeof_I0 = 2  # element size of I0, in bytes
    step_Xi = 64 * dims.Xis * sizeof_I0
    step_Yi = 64 * sizeof_I0
    step_Kx = 64 * sizeof_I0
    step_Ky = 64 * dims.Cos * sizeof_I0
    # NOTE(review): the original comment gives the range as range(-32,31), but
    # the field is packed unsigned ('H') below, so a negative shift would be
    # rejected. Confirm the intended field type.
    shift_res = 10
    tsl_bound = int(dims.Cis / 64 + dims.Cos / 32)
    is_split = dims.is_split
    mode = split_to_mode(dims)

    dimsA = dimsA.from_steps((inner_g, Y_g), (step_Xi, step_Yi * 16, 0))
    if transpose:
        dimsB = dimsB.from_steps((inner_g, Y_g), (step_Ky, 0, step_Kx * 64))
    else:
        dimsB = dimsB.from_steps((2, 2), (128, 64, 256))
    # dimsQ and dimsAs receive a single bare int count (not a tuple) and two steps.
    dimsQ = dimsQ.from_steps(Y_g, (0, 512))
    dimsAs = dimsAs.from_steps(inner_g, (step_Xi, step_Yi * 16))
    dimsW = dimsW.from_steps(
        (16, int(dims.Cos / 32)),
        (256, 4096 if transpose else 8192, 8192 if transpose else 128),
    )

    # Packed layout, in order. The scalar fields come first.
    fields = [
        ("outer_time_iters", "H", outer_time_iters),
        ("inner_time_iters", "H", inner_time_iters),
        ("inner_g", "H", inner_g),
        ("tsl_bound", "H", tsl_bound),
        ("transpose", "B", transpose),
        ("is_split", "B", is_split),
        ("Y_g", "B", Y_g),
        ("X_g", "B", X_g),
        ("step_Xi", "H", step_Xi),
        ("step_Yi", "H", step_Yi),
        ("step_Kx", "H", step_Kx),
        ("step_Ky", "H", step_Ky),
        ("shift_res", "H", shift_res),
        ("mode", "H", mode),
    ]
    # Each DimsHelper follows: "num*" entries are packed as unsigned 32-bit
    # ("I") and "inc*" entries as signed 32-bit ("i").
    full_keys = ("num0", "num1", "inc0", "inc1", "inc2")
    short_keys = ("num0", "inc0", "inc1")
    for label, helper, keys in (
        ("dimsA", dimsA, full_keys),
        ("dimsB", dimsB, full_keys),
        ("dimsQ", dimsQ, short_keys),
        ("dimsAs", dimsAs, short_keys),
        ("dimsW", dimsW, full_keys),
    ):
        for key in keys:
            code = "I" if key.startswith("num") else "i"
            fields.append((f"{label}['{key}']", code, helper[key]))

    # Emits the same "<name>: <value>" lines, in the same order, as before.
    for name, _code, value in fields:
        log(f"{name}: {value}")

    return _pack_fields(fields)


def _pack_fields(fields: list) -> bytes:
    """Pack (name, struct_code, value) triples little-endian with no padding.

    On failure, this re-raises ``struct.error`` naming the first field whose
    value does not fit its struct code. The original error is chained as the
    cause.
    """
    fmt = "<" + "".join(code for _name, code, _value in fields)
    values = [value for _name, _code, value in fields]
    try:
        return struct.pack(fmt, *values)
    except struct.error as err:
        # Re-pack field by field only on the failure path, to find the culprit.
        for name, code, value in fields:
            try:
                struct.pack("<" + code, value)
            except struct.error as field_err:
                raise struct.error(
                    f"kernel param {name} = {value!r} does not fit "
                    f"struct code '{code}': {field_err}"
                ) from err
        raise

def gemm_transpose_layer_params(
    core_tdm_buffer_addr: int,
    core_wght_transpose_sb_addr: int,
    core_cfqdq_buffer_addr: int,
    qdq_addr: int,
) -> bytes:
    """Pack the four layer-level buffer addresses for the GEMM transpose kernel.

    Each address is passed through ``conv_to_local_ptr``, in argument order.
    The four results are packed as little-endian unsigned 32-bit integers
    (``'<4I'``, 16 bytes).

    Returns:
        The 16-byte packed address block.
    """
    raw_addresses = (
        core_tdm_buffer_addr,
        core_wght_transpose_sb_addr,
        core_cfqdq_buffer_addr,
        qdq_addr,
    )
    # The list comprehension converts eagerly and in order, matching the
    # sequence of conversion calls.
    local_pointers = [conv_to_local_ptr(address) for address in raw_addresses]
    return struct.pack('<4I', *local_pointers)


def generate_gemm_qdq_a16a16_transpose_params(
    dims: ConvDims,
    core_tdm_buffer_addr: int,
    core_wght_transpose_sb_addr: int,
    core_cfqdq_buffer_addr: int,
    qdq_addr: int,
) -> bytes:
    """Return the complete parameter blob: layer params followed by kernel params.

    Args:
        dims: Forwarded to ``setup_gemm_qdq_a16a16_transpose_params``.
        core_tdm_buffer_addr: Forwarded to ``gemm_transpose_layer_params``.
        core_wght_transpose_sb_addr: Forwarded to ``gemm_transpose_layer_params``.
        core_cfqdq_buffer_addr: Forwarded to ``gemm_transpose_layer_params``.
        qdq_addr: Forwarded to ``gemm_transpose_layer_params``.

    Returns:
        ``layer_params + kernel_params``.
    """
    layer_params = gemm_transpose_layer_params(
        core_tdm_buffer_addr,
        core_wght_transpose_sb_addr,
        core_cfqdq_buffer_addr,
        qdq_addr,
    )
    kernel_params = setup_gemm_qdq_a16a16_transpose_params(dims)
    # Concatenate once and reuse the result for both logging and the return value.
    params = layer_params + kernel_params
    log("Layer params size and kernel params size:")
    # Log both part sizes, as the label announces, followed by their total.
    log(f"{len(layer_params)} + {len(kernel_params)} = {len(params)}")
    return params
