from kernels.common.kernel_params_helper import DimsHelper
import struct 
import math

# Key order used when dumping / packing DimsHelper results.
_DIMS_KEYS_1D = ('num0', 'inc0', 'inc1')
_DIMS_KEYS_2D = ('num0', 'num1', 'inc0', 'inc1', 'inc2')


def _print_scalars(pairs):
    """Print one ``name: value`` debug line per (name, value) pair, in order."""
    for name, value in pairs:
        print(f"{name}: {value}")


def _print_dims(label, dims, keys):
    """Print one ``label['key']: value`` debug line per key, in order."""
    for key in keys:
        print(f"{label}['{key}']: {dims[key]}")


def _dims_as_ints(dims, keys):
    """Return ``int(dims[key])`` for every key, in order."""
    return [int(dims[key]) for key in keys]


def setup_gemm_qdq_int16x2_params(
    W: int,
    Ci: int,
    Co: int,
    mode: int,
    zero_init: int,
    final_tdm_iter: int,
    wgt_size: int,
    bias_size: int,
    zp_size: int,
    weights_unpack_addr: int,
    tdm1_addr: int,
    tdm2_addr: int,
    tdm1s_addr: int,
    tdm2s_addr: int,
    qdq_addr: int,
    sign_A: int,
    sign_W: int,
    sign_out: int
) -> bytes:
    """Derive the GEMM/QDQ loop parameters and pack them into a byte string.

    Loop counts, offsets and ``DimsHelper`` step descriptors are computed
    from ``W``, ``Ci`` and ``Co``; every field is printed to stdout for
    debugging and then serialized with ``struct.pack`` using a little-endian,
    unpadded layout (224 bytes for the format used here).

    Args:
        W, Ci, Co: Problem sizes used in all derived quantities.
            NOTE(review): presumably width / input channels / output
            channels -- confirm against the kernel. Derived values use true
            division and are truncated with ``int()`` only at pack time, so
            sizes that are not multiples of the divisors (8, 32, 64, 512)
            are silently truncated -- confirm callers guarantee divisibility.
        mode, zero_init, final_tdm_iter, wgt_size, bias_size, zp_size,
        weights_unpack_addr, tdm1_addr, tdm2_addr, tdm1s_addr, tdm2s_addr,
        qdq_addr, sign_A, sign_W: Passed through unchanged; each is packed as
            an unsigned 32-bit field.
        sign_out: Passed through and packed as an unsigned 8-bit field.

    Returns:
        The packed parameter block as ``bytes``.

    Raises:
        struct.error: If any value does not fit its field (for example a
            negative value in an unsigned field, which ``W/4 - 1`` produces
            when ``W < 4``).
    """
    dimsZ = DimsHelper(0)
    dimsA = DimsHelper(0)
    dimsAO = DimsHelper(0)
    dimsW = DimsHelper(0)
    dimsSum = DimsHelper(0)
    dimsQnt = DimsHelper(0)
    dimsOut = DimsHelper(0)

    # Weight-unpack and GEMM parameter setup.
    unpack_inner_loop = Ci*Co // (512)
    unpack_dimsZ = dimsZ.from_steps((1, Ci/64), (0, 2, 2*Ci/64))
    outer_g = (W/8)*Co/8
    inner_g = 8
    block_g = Ci/64
    dimsW = dimsW.from_steps((64/8, Co/8), (64, Ci*8, 0))
    # NOTE(review): `(8)` is a plain int, not a 1-tuple (same for the
    # dims_sum / dims_qnt calls below) -- confirm from_steps accepts a scalar.
    dimsA = dimsA.from_steps((8), (W*8*2, 0))
    dimsAO = dimsAO.from_steps((W/8, Co/8), (8*8*2, 0, 0))
    loop_blocked = W*Co/32
    blocked_A_offset = W*64
    blocked_B_offset = 8*64
    blocked_sw_offset = 8
    tdm_scaled_sum_offset = W*Co/2
    sgemm_c2_wrap = W/4-1
    sgemm_c2_step = 4*8*Ci/64

    # QDQ Param setup
    loop = W*Co/32
    split_mode = 0
    vector_coeffs = 0
    dims_in1_wrap0 = 2
    dims_in1_wrap1 = 0
    dims_in1_step = 128
    dims_sum = dimsSum.from_steps((Ci/8), (8, 0))
    dims_qnt = dimsQnt.from_steps((W/4), (0, 64))
    dims_out = dimsOut.from_steps((W*Co/32, 0), (64, 0, 0))

    reserved = 0

    # Caller-supplied fields, in the order they are printed and packed.
    header = (
        ("mode", mode),
        ("zero_init", zero_init),
        ("final_tdm_iter", final_tdm_iter),
        ("wgt_size", wgt_size),
        ("bias_size", bias_size),
        ("zp_size", zp_size),
        ("weights_unpack_addr", weights_unpack_addr),
        ("tdm1_addr", tdm1_addr),
        ("tdm2_addr", tdm2_addr),
        ("tdm1s_addr", tdm1s_addr),
        ("tdm2s_addr", tdm2s_addr),
        ("qdq_addr", qdq_addr),
        ("sign_A", sign_A),
        ("sign_W", sign_W),
    )

    # Debug dump of every non-reserved field (values shown before int()
    # truncation, so derived quantities may print as floats).
    _print_scalars(header)
    _print_scalars((("unpack_inner_loop", unpack_inner_loop),))
    _print_dims("unpack_dimsZ", unpack_dimsZ, _DIMS_KEYS_2D)
    _print_scalars((
        ("outer_g", outer_g),
        ("inner_g", inner_g),
        ("block_g", block_g),
    ))
    _print_dims("dimsW", dimsW, _DIMS_KEYS_2D)
    _print_dims("dimsA", dimsA, _DIMS_KEYS_1D)
    _print_dims("dimsAO", dimsAO, _DIMS_KEYS_2D)
    _print_scalars((
        ("loop_blocked", loop_blocked),
        ("blocked_A_offset", blocked_A_offset),
        ("blocked_B_offset", blocked_B_offset),
        ("blocked_sw_offset", blocked_sw_offset),
        ("tdm_scaled_sum_offset", tdm_scaled_sum_offset),
        ("sgemm_c2_wrap", sgemm_c2_wrap),
        ("sgemm_c2_step", sgemm_c2_step),
    ))
    _print_scalars((
        ("loop", loop),
        ("split_mode", split_mode),
        ("sign_out", sign_out),
        ("vector_coeffs", vector_coeffs),
        ("dims_in1_wrap0", dims_in1_wrap0),
        ("dims_in1_wrap1", dims_in1_wrap1),
        ("dims_in1_step", dims_in1_step),
    ))
    _print_dims("dims_sum", dims_sum, _DIMS_KEYS_1D)
    _print_dims("dims_qnt", dims_qnt, _DIMS_KEYS_1D)
    _print_dims("dims_out", dims_out, _DIMS_KEYS_2D)

    # Little-endian, standard sizes, no padding. The pieces concatenate to
    # '<14I2H5i4H13i2h6Ih2B2b3h11i'.
    layout = (
        '<'
        '14I'   # mode .. sign_W
        '2H'    # unpack_inner_loop, reserved
        '5i'    # unpack_dimsZ
        '4H'    # outer_g, inner_g, block_g, reserved
        '13i'   # dimsW (5), dimsA (3), dimsAO (5)
        '2h'    # loop_blocked, reserved
        '6I'    # blocked_A_offset .. sgemm_c2_step
        'h'     # loop
        '2B'    # split_mode, sign_out
        '2b'    # vector_coeffs, reserved
        '3h'    # dims_in1_wrap0, dims_in1_wrap1, dims_in1_step
        '11i'   # dims_sum (3), dims_qnt (3), dims_out (5)
    )

    # Field values in layout order; derived quantities are truncated to int
    # here, caller-supplied header values are packed as given.
    fields = [value for _, value in header]
    fields += [unpack_inner_loop, reserved]
    fields += _dims_as_ints(unpack_dimsZ, _DIMS_KEYS_2D)
    fields += [int(outer_g), int(inner_g), int(block_g), reserved]
    fields += _dims_as_ints(dimsW, _DIMS_KEYS_2D)
    fields += _dims_as_ints(dimsA, _DIMS_KEYS_1D)
    fields += _dims_as_ints(dimsAO, _DIMS_KEYS_2D)
    fields += [
        int(loop_blocked),
        reserved,
        int(blocked_A_offset),
        int(blocked_B_offset),
        int(blocked_sw_offset),
        int(tdm_scaled_sum_offset),
        int(sgemm_c2_wrap),
        int(sgemm_c2_step),
    ]
    fields += [
        int(loop),
        int(split_mode),
        int(sign_out),
        int(vector_coeffs),
        reserved,
        int(dims_in1_wrap0),
        int(dims_in1_wrap1),
        int(dims_in1_step),
    ]
    fields += _dims_as_ints(dims_sum, _DIMS_KEYS_1D)
    fields += _dims_as_ints(dims_qnt, _DIMS_KEYS_1D)
    fields += _dims_as_ints(dims_out, _DIMS_KEYS_2D)

    packed_params = struct.pack(layout, *fields)
    print("Len of packed params", len(packed_params))
    return packed_params
    
    
    