import numpy as np
import sys
import struct
from ml_dtypes import bfloat16
from L1_utils.utils import find_closest_shifted_int32, find_closest_shifted_int16

# Look up the "bfloat16" dtype by name. The result is discarded, so this only
# acts as an import-time check that NumPy recognises the name (presumably it is
# registered by the ml_dtypes import above -- TODO confirm).
np.dtype("bfloat16")

def round_half_to_odd(inp):
    """Round to the nearest integer, sending exact .5 ties to the odd neighbour.

    Values that are not exactly halfway between two integers are rounded
    exactly as ``np.round`` does. For ties, the odd one of the two adjacent
    integers is chosen (e.g. 0.5 -> 1, 1.5 -> 1, 2.5 -> 3, -2.5 -> -3).

    Non-finite inputs are passed through unchanged (``inf`` stays ``inf``,
    ``nan`` stays ``nan``). The previous mask-multiplication form turned
    ``inf`` into ``nan`` because it evaluated ``inf * 0``.

    Args:
        inp: Float scalar or array.

    Returns:
        Rounded value(s) with the same shape as ``inp`` and the floating
        dtype produced by ``np.round(inp)``.
    """
    # inf - inf and inf % 2 are NaN by design here; silence the warnings.
    with np.errstate(invalid='ignore'):
        nearest = np.round(inp)
        lower = np.floor(inp)
        is_tie = np.abs(inp - nearest) == 0.5
        # Of the two neighbours floor and floor + 1, exactly one is odd.
        odd_neighbour = np.where(lower % 2 == 1, lower, lower + 1)

    rounded = np.where(is_tie, odd_neighbour, nearest)
    # np.where yields a 0-d array for scalar input; [()] turns it back into a
    # NumPy scalar and leaves n-d arrays untouched.
    return rounded[()]

def float_to_bfloat16(f):
    """Return the upper 16 bits of the big-endian float32 encoding of ``f``.

    The lower 16 bits of the float32 pattern are dropped (truncation, no
    rounding). The result is a plain int in the range 0..65535.
    """
    float32_bytes = struct.pack('>f', f)
    upper_half = float32_bytes[:2]
    (bits,) = struct.unpack('>H', upper_half)
    return bits

def bfloat16_to_float(bf):
    """Expand a 16-bit pattern into a float by zero-filling the low 16 bits.

    ``bf`` is packed as an unsigned big-endian 16-bit integer, two zero bytes
    are appended, and the four bytes are decoded as a big-endian float32.
    """
    float32_bytes = struct.pack('>H', bf) + b'\x00\x00'
    (value,) = struct.unpack('>f', float32_bytes)
    return value

def test_dequant_linear(inp, scale, zero_pt):
    """Reference dequantisation: ``scale * (inp - zero_pt)`` in bfloat16.

    The zero-point subtraction is done in int32; the difference and the scale
    are each converted to bfloat16 before multiplying, and the product is
    returned as float32.
    """
    scale_bf = bfloat16(scale)
    centred = np.int32(inp) - np.int32(zero_pt)
    product_bf = bfloat16(scale_bf * bfloat16(centred))
    return np.float32(product_bf)

def find_closest_shifted_int32_vec(float_val_vec, INT32_MAX = 2147483647, shift_max = np.inf):
    """Approximate each float as ``int / 2**shift`` using one common shift.

    Every element is first converted on its own with
    ``find_closest_shifted_int32``. The smallest per-element shift is then
    used for the whole vector: integers that were produced with a larger
    shift are shifted right by the difference (round-half-to-even, via
    ``srs_int32_even_fast``).

    Args:
        float_val_vec: NumPy array indexed by a single integer from 0 to
            ``size - 1`` (i.e. expected to be 1-D).
        INT32_MAX: Forwarded to ``find_closest_shifted_int32``.
        shift_max: Forwarded to ``find_closest_shifted_int32``. The default is
            spelled ``np.inf``; the former ``np.infty`` alias no longer exists
            in NumPy 2.0 and made this ``def`` fail at import time.

    Returns:
        ``[int_vec, shift_min]`` where ``int_vec`` is an int32 array and
        ``shift_min`` is the common shift. For an empty input ``shift_min``
        is the initial sentinel 10000.
    """
    num_vals = float_val_vec.size
    int_vec = np.zeros(num_vals).astype(np.int32)
    shift_vec = np.zeros(num_vals).astype(np.int32)
    shift_min = 10000  # some arbitrarily large number

    # Pass 1: independent approximation of every element, tracking the
    # smallest shift seen.
    for i in range(num_vals):
        [int_val, shift_val] = find_closest_shifted_int32(float_val_vec[i], INT32_MAX, shift_max)
        int_vec[i] = int_val
        shift_vec[i] = shift_val

        if shift_val < shift_min:
            shift_min = shift_val

    # Pass 2: bring every integer down to the common (smallest) shift.
    for i in range(num_vals):
        shift_val = shift_vec[i]

        if shift_val > shift_min:
            int_vec[i] = srs_int32_even_fast(int_vec[i], shift_val - shift_min)

    # NOTE: a former third loop computed a max relative error that was never
    # used or returned (and divided by the input value, which may be zero);
    # it has been removed.
    return [int_vec, shift_min]

def dq_uint16_leakyrelu_q_param_gen(
    alpha_val,
    a_dq_xscale,
    a_dq_xzero_pt,
    a_q_yscale,
    a_q_yzero_pt,
):
    """Generate integer parameters for a fused dequant -> LeakyReLU -> quant.

    Two fused scales are approximated as ``C / 2**shift`` with
    ``find_closest_shifted_int16``: ``alpha_val*a_dq_xscale/a_q_yscale``
    (the "LR" pair) and ``a_dq_xscale/a_q_yscale`` (the "NLR" pair). For each
    pair an int32 offset ``(a_q_yzero_pt << shift) - a_dq_xzero_pt * C`` is
    derived.

    Args:
        alpha_val: Multiplier applied to the input scale for the LR pair
            (presumably the LeakyReLU negative slope -- confirm with caller).
        a_dq_xscale, a_dq_xzero_pt: Input scale and zero point; the zero
            point must provide ``.astype``.
        a_q_yscale, a_q_yzero_pt: Output scale and zero point; the zero
            point must provide ``.astype``.

    Returns:
        ``[C_LR, C_NLR, shift_LR, shift_NLR, C0, C1]`` with ``C0``/``C1``
        cast to int32.

    Raises:
        SystemExit: if C0 or C1 leaves the int32 range, or if C_LR / C_NLR
            exceeds 16383.
    """
    int32_hi = 2147483647
    int32_lo = -2147483648

    # Coefficient and shift limits are chosen so that the upshift-and-add in
    # the fused qdq kernel never exceeds 31 bits.
    coeff_limit = 16383
    shift_limit = 14

    def _to_int32_or_exit(value, message):
        # Abort the process instead of silently wrapping on the int32 cast.
        if value > int32_hi or value < int32_lo:
            sys.exit (message)
        return value.astype(np.int32)

    in_zero_pt = a_dq_xzero_pt.astype(np.int64)
    out_zero_pt = a_q_yzero_pt.astype(np.int64)

    leaky_scale = alpha_val*a_dq_xscale/a_q_yscale
    plain_scale = a_dq_xscale/a_q_yscale

    C_LR, shift_LR = find_closest_shifted_int16(leaky_scale, coeff_limit, shift_limit)
    C_NLR, shift_NLR = find_closest_shifted_int16(plain_scale, coeff_limit, shift_limit)

    C0 = _to_int32_or_exit(
        (out_zero_pt << shift_LR) - (in_zero_pt * np.int64(C_LR)),
        'CO in leaky relu calculation has exceeded int32 range',
    )
    C1 = _to_int32_or_exit(
        (out_zero_pt << shift_NLR) - (in_zero_pt * np.int64(C_NLR)),
        'C1 in leaky relu calculation has exceeded int32 range',
    )

    if C_LR > coeff_limit:
        sys.exit('C_LR in leaky relu calculation has exceeded int14 range')
    if C_NLR > coeff_limit:
        sys.exit('C_NLR in leaky relu calculation has exceeded int14 range')

    return [C_LR, C_NLR, shift_LR, shift_NLR, C0, C1]


def test_quant_linear(inp, scale, zero_pt):
    """Reference quantisation to int8.

    Computes ``inp * (1/scale)`` with both factors in bfloat16, rounds with
    ``np.round``, adds the float32 zero point and clips to [-128, 127].
    ``inp`` and ``zero_pt`` must provide ``.astype``.
    """
    inp_f32 = inp.astype(np.float32)
    zp_f32 = zero_pt.astype(np.float32)

    inp_bf = bfloat16(inp_f32)
    inv_scale_bf = bfloat16(1/scale)
    scaled_bf = bfloat16(inp_bf * inv_scale_bf)

    shifted = np.round(scaled_bf) + zp_f32
    return np.clip(shifted, -128, 127).astype(np.int8)


def test_quant_linear_uint8(inp, scale, zero_pt):
    """Reference quantisation to uint8.

    Computes ``inp * (1/scale)`` with both factors in bfloat16, rounds with
    ``np.round``, adds the float32 zero point and clips to [0, 255].
    ``inp`` and ``zero_pt`` must provide ``.astype``.
    """
    inp_f32 = inp.astype(np.float32)
    zp_f32 = zero_pt.astype(np.float32)

    inp_bf = bfloat16(inp_f32)
    inv_scale_bf = bfloat16(1 / scale)
    scaled_bf = bfloat16(inp_bf * inv_scale_bf)

    shifted = np.round(scaled_bf) + zp_f32
    return np.clip(shifted, 0, 255).astype(np.uint8)


def test_quant_linear_uint16(inp, scale, zero_pt):
    """Reference quantisation to uint16.

    Computes ``inp * (1/scale)`` with both factors in bfloat16, rounds with
    ``np.round``, adds the float32 zero point and clips to [0, 65535].
    ``inp`` and ``zero_pt`` must provide ``.astype``.
    """
    inp_f32 = inp.astype(np.float32)
    zp_f32 = zero_pt.astype(np.float32)

    inp_bf = bfloat16(inp_f32)
    inv_scale_bf = bfloat16(1 / scale)
    scaled_bf = bfloat16(inp_bf * inv_scale_bf)

    shifted = np.round(scaled_bf) + zp_f32
    return np.clip(shifted, 0, 65535).astype(np.uint16)


def test_quant_linear_int16(inp, scale, zero_pt):
    """Reference quantisation to int16.

    Computes ``inp * (1/scale)`` with both factors in bfloat16, rounds with
    ``np.round``, adds the float32 zero point and clips to [-32768, 32767].
    ``inp`` and ``zero_pt`` must provide ``.astype``.
    """
    inp_f32 = inp.astype(np.float32)
    zp_f32 = zero_pt.astype(np.float32)

    inp_bf = bfloat16(inp_f32)
    inv_scale_bf = bfloat16(1 / scale)
    scaled_bf = bfloat16(inp_bf * inv_scale_bf)

    shifted = np.round(scaled_bf) + zp_f32
    return np.clip(shifted, -32768, 32767).astype(np.int16)


def test_quant_linear_int8(inp, scale, zero_pt):
    """Reference quantisation to int8.

    Computes ``inp * (1/scale)`` with both factors in bfloat16, rounds with
    ``np.round``, adds the float32 zero point and clips to [-128, 127].
    ``inp`` and ``zero_pt`` must provide ``.astype``.
    """
    inp_f32 = inp.astype(np.float32)
    zp_f32 = zero_pt.astype(np.float32)

    inp_bf = bfloat16(inp_f32)
    inv_scale_bf = bfloat16(1 / scale)
    scaled_bf = bfloat16(inp_bf * inv_scale_bf)

    shifted = np.round(scaled_bf) + zp_f32
    return np.clip(shifted, -128, 127).astype(np.int8)


#### from CARF repo

def dequant_linear_uint16_cstm(
    inp, scale, zero_pt
):
    """Dequantise ``inp`` as ``scale * (inp - zero_pt)`` using bfloat16 math.

    When ``inp.shape`` is empty the subtraction uses Python ints; otherwise
    it uses int32. The difference and the scale are converted to bfloat16,
    multiplied, and the product is returned as float32.
    """
    is_scalar = len(inp.shape) == 0
    scale_bf = bfloat16(scale)
    if is_scalar:
        centred = int(inp) - int(zero_pt)
    else:
        centred = np.int32(inp) - np.int32(zero_pt)
    return np.float32(bfloat16(scale_bf * bfloat16(centred)))
    
def dequant_linear_uint8_cstm(
    inp, scale, zero_pt
):
    """Dequantise ``inp`` as ``scale * (inp - zero_pt)`` using bfloat16 math.

    When ``inp.shape`` is empty the subtraction uses Python ints; otherwise
    it uses int32. The difference and the scale are converted to bfloat16,
    multiplied, and the product is returned as float32.
    """
    is_scalar = len(inp.shape) == 0
    scale_bf = bfloat16(scale)
    if is_scalar:
        centred = int(inp) - int(zero_pt)
    else:
        centred = np.int32(inp) - np.int32(zero_pt)
    return np.float32(bfloat16(scale_bf * bfloat16(centred)))


def quant_linear_uint16_cstm(
    inp, scale, zero_pt
):
    """Quantise ``inp`` to uint16 with zero-point-aware tie breaking.

    ``inp`` and ``1/scale`` are rounded to bfloat16 and multiplied in
    float32, and the float32 zero point is added *before* rounding. An even
    zero point uses ``np.round`` (half-to-even); an odd one uses
    ``round_half_to_odd``. The result is clipped to [0, 65535].

    ``zero_pt`` must provide ``.astype`` and behave as a single value in the
    parity test.
    """
    inp_bf = bfloat16(inp)
    inv_scale_bf = bfloat16(1 / scale)

    zp_i32 = zero_pt.astype(np.int32)
    zp_f32 = zero_pt.astype(np.float32)

    shifted = np.float32(inp_bf)*np.float32(inv_scale_bf) + zp_f32

    # The tie rule depends on the zero point's parity: half-to-even for an
    # even zero point, half-to-odd for an odd one.
    rounder = np.round if zp_i32%2 == 0 else round_half_to_odd
    rounded = rounder(shifted)

    return np.clip(rounded, 0, 65535).astype(np.uint16)

def quant_linear_uint8_cstm(
    inp, scale, zero_pt
):
    """Quantise ``inp`` to uint8 with zero-point-aware tie breaking.

    ``inp`` and ``1/scale`` are rounded to bfloat16 and multiplied in
    float32, and the float32 zero point is added *before* rounding. An even
    zero point uses ``np.round`` (half-to-even); an odd one uses
    ``round_half_to_odd``. The result is clipped to [0, 255].

    ``zero_pt`` must provide ``.astype`` and behave as a single value in the
    parity test.
    """
    inp_bf = bfloat16(inp)
    inv_scale_bf = bfloat16(1 / scale)

    zp_i32 = zero_pt.astype(np.int32)
    zp_f32 = zero_pt.astype(np.float32)

    shifted = np.float32(inp_bf)*np.float32(inv_scale_bf) + zp_f32

    # The tie rule depends on the zero point's parity: half-to-even for an
    # even zero point, half-to-odd for an odd one.
    rounder = np.round if zp_i32%2 == 0 else round_half_to_odd
    rounded = rounder(shifted)

    return np.clip(rounded, 0, 255).astype(np.uint8)


def dq_uint8A_uint8A_matmul_q_param_gen(in_ch_dim, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Generate integer parameters for a fused dequant -> matmul -> quant
    where both matmul operands are activations.

    Parameter naming follows ``(C3*gemm_result + C2*IFM1 + C1*IFM2 + C0) >> shft``.

    Args:
        in_ch_dim: Reduction (input-channel) length. A multiple of 49 is
            padded to the matching multiple of 64 (windowed-attention hack).
        a_dq_xscale, a_dq_xzero_pt: Scale / zero point of the first operand.
        w_dq_xscale, w_dq_xzero_pt: Scale / zero point of the second operand.
        a_q_yscale, a_q_yzero_pt: Scale / zero point of the output.
            All zero points must provide ``.astype``.

    Returns:
        ``[C3, C2, C1, C0, shft]``, each wrapped in ``np.int64``.

    Raises:
        SystemExit: if the C2 or C1 coefficient does not fit in int32.
    """
    INT32_MAX = 2147483647
    unsupported_msg = 'Current AIE uint8A_uint8A qdq implementation does not support ifm sum shift'

    a_dq_xzero_pt = a_dq_xzero_pt.astype(np.int64)
    w_dq_xzero_pt = w_dq_xzero_pt.astype(np.int64)
    a_q_yzero_pt = a_q_yzero_pt.astype(np.int64)

    print('Premodified input channel dim is: ' + str(in_ch_dim))
    if (in_ch_dim % 49 == 0):  # hacky way for padding windowed attention
        in_ch_dim = np.int64(np.ceil(in_ch_dim / 49) * 64)
    print(' Modified input channel dim is: ' + str(in_ch_dim))

    # Fixed-point approximation of the combined scale.
    c2_coeff = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    [c2_coeff_prime, shft] = find_closest_shifted_int32(c2_coeff)
    c2_coeff_prime = np.int64(c2_coeff_prime)

    c3_coeff_scale = np.int64(-c2_coeff_prime * w_dq_xzero_pt)

    # A right shift of this coefficient ("ifm sum shift") is not supported by
    # the kernel, so an out-of-range value aborts.
    if np.abs(c3_coeff_scale) > INT32_MAX:
        sys.exit(unsupported_msg)

    # The range check must run on the wide int64 value: after an int32 cast
    # the value has already wrapped and the comparison can never be true.
    c1_wide = (-a_dq_xzero_pt) * c2_coeff_prime
    if np.abs(c1_wide) > INT32_MAX:
        sys.exit(unsupported_msg)

    C3 = c2_coeff_prime.astype(np.int32)
    C2 = c3_coeff_scale.astype(np.int32)
    C1 = c1_wide.astype(np.int32)

    C0 = ((a_q_yzero_pt << shft) + np.int64(np.int64(a_dq_xzero_pt) * np.int64(w_dq_xzero_pt) *
                                            np.int64(in_ch_dim) * c2_coeff_prime.astype(np.int64))).astype(np.int64)

    return [np.int64(C3), np.int64(C2), np.int64(C1), np.int64(C0), np.int64(shft)]

def dq_uint8A_uint8W_matmul_q_param_gen(weights, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Generate integer parameters for a fused dequant -> matmul -> quant
    with constant weights.

    Parameter naming follows ``(C2*gemm_result + C1*IFM1_sum + C0) >> shft``.

    Args:
        weights: Weight array; axis -2 is used as the input-channel axis and
            is summed over.
        a_dq_xscale, a_dq_xzero_pt: Scale / zero point of the activation.
        w_dq_xscale, w_dq_xzero_pt: Scale / zero point of the weights.
        a_q_yscale, a_q_yzero_pt: Scale / zero point of the output.
            All zero points must provide ``.astype``.

    Returns:
        ``[C2, C1, C0, shft]``, each wrapped in ``np.int64``. ``C0`` has the
        shape of ``weights`` summed over axis -2.

    Raises:
        SystemExit: if the IFM-sum coefficient does not fit in int32.
    """
    int32_max = 2147483647

    act_zp = a_dq_xzero_pt.astype(np.int64)
    wgt_zp = w_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)
    in_ch_dim = weights.shape[-2]

    print('Premodified input channel dim is: ' + str(in_ch_dim))
    if in_ch_dim % 49 == 0:
        # hacky padding for windowed attention: multiples of 49 -> multiples of 64
        in_ch_dim = np.int64(np.ceil(in_ch_dim / 49) * 64)
    print(' Modified input channel dim is: ' + str(in_ch_dim))

    # Fixed-point approximation of the combined scale.
    combined_scale = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    gemm_coeff, shft = find_closest_shifted_int32(combined_scale)
    gemm_coeff = np.int64(gemm_coeff)

    # Constant term: activation zero point times the weight sums, plus the
    # up-shifted output zero point.
    weight_sums = np.sum(weights, axis=(-2), dtype=np.int64)
    const_term = np.int64((-act_zp) * gemm_coeff * weight_sums + (out_zp << shft))

    ifm_sum_coeff = np.int64(-gemm_coeff * wgt_zp)
    ifm_sum_offset = np.int32(-act_zp * in_ch_dim)

    # A right shift of the IFM-sum coefficient is not supported by the
    # kernel, so a value outside int32 aborts and the shift is always 0.
    if np.abs(ifm_sum_coeff) > int32_max:
        sys.exit('Current AIE uint8A_uint8W qdq implementation does not support ifm sum shift')
    ifm_sum_shift = 0

    ifm_sum_coeff = (ifm_sum_coeff >> ifm_sum_shift).astype(np.int32)

    C2 = gemm_coeff.astype(np.int32)
    C1 = ifm_sum_coeff.astype(np.int32)

    offset_wide = (ifm_sum_offset.astype(np.int64) << ifm_sum_shift).astype(np.int64)
    zp_cross_term = (ifm_sum_coeff.astype(np.int64) * offset_wide).astype(np.int64)
    C0 = zp_cross_term + const_term

    return [np.int64(C2), np.int64(C1), np.int64(C0), np.int64(shft)]

def dq_uint16A_uint16A_matmul_q_param_gen(in_ch_dim, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Generate integer parameters for a fused dequant -> matmul -> quant
    where both matmul operands are activations (16-bit variant).

    Parameter naming follows ``(C3*gemm_result + C2*IFM1 + C1*IFM2 + C0) >> shft``.
    The combined scale is approximated with ``find_closest_shifted_int16``
    and C3 is the resulting coefficient shifted left by ``matmul_shift``,
    which is ``ceil(log2(in_ch_dim)) + 1`` clamped to [0, 15].

    Args:
        in_ch_dim: Reduction (input-channel) length. A multiple of 49 is
            padded to the matching multiple of 64 (windowed-attention hack).
        a_dq_xscale, a_dq_xzero_pt: Scale / zero point of the first operand.
        w_dq_xscale, w_dq_xzero_pt: Scale / zero point of the second operand.
        a_q_yscale, a_q_yzero_pt: Scale / zero point of the output.
            All zero points must provide ``.astype``.

    Returns:
        ``[C3, C2, C1, C0, right_shft_matmul, shft_final]``, each wrapped in
        ``np.int64``.

    Raises:
        SystemExit: if the C2 or C1 coefficient does not fit in int32.
    """
    INT32_MAX = 2147483647
    unsupported_msg = 'Current AIE uint16A_uint16A qdq implementation does not support ifm sum shift'

    a_dq_xzero_pt = a_dq_xzero_pt.astype(np.int64)
    w_dq_xzero_pt = w_dq_xzero_pt.astype(np.int64)
    a_q_yzero_pt = a_q_yzero_pt.astype(np.int64)

    if (in_ch_dim % 49 == 0):  # hacky way for padding windowed attention
        in_ch_dim = np.int64(np.ceil(in_ch_dim / 49) * 64)

    matmul_shift = np.int64(min(max(np.ceil(np.log2(in_ch_dim)) + 1 , 0), 15))

    # Fixed-point approximation of the combined scale.
    c2_coeff = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    [c2_coeff_prime, shft] = find_closest_shifted_int16(c2_coeff)
    c2_coeff_prime = np.int64(c2_coeff_prime)

    c3_coeff_scale = np.int64(-c2_coeff_prime * w_dq_xzero_pt)

    # A right shift of this coefficient ("ifm sum shift") is not supported by
    # the kernel, so an out-of-range value aborts.
    if np.abs(c3_coeff_scale) > INT32_MAX:
        sys.exit(unsupported_msg)

    # The range check must run on the wide int64 value: after an int32 cast
    # the value has already wrapped and the comparison can never be true.
    c1_wide = (-a_dq_xzero_pt) * c2_coeff_prime
    if np.abs(c1_wide) > INT32_MAX:
        sys.exit(unsupported_msg)

    C3 = (c2_coeff_prime << matmul_shift).astype(np.int32)
    C2 = c3_coeff_scale.astype(np.int32)
    C1 = c1_wide.astype(np.int32)

    C0 = ((a_q_yzero_pt << shft) + np.int64(np.int64(a_dq_xzero_pt) * np.int64(w_dq_xzero_pt) * np.int64(in_ch_dim) * c2_coeff_prime.astype(np.int64))).astype(np.int64)

    right_shft_matmul = matmul_shift
    shft_final = shft

    return [np.int64(C3), np.int64(C2), np.int64(C1), np.int64(C0), np.int64(right_shft_matmul), np.int64(shft_final)]

def dq_uint16A_mha_q_param_gen(qkt_in_ch_dim, qkt_a_dq_xscale, qkt_a_dq_xzero_pt, qkt_w_dq_xscale, qkt_w_dq_xzero_pt, qkt_a_q_yscale, qkt_a_q_yzero_pt,
                               smxbv_in_ch_dim, smxbv_a_dq_xscale, smxbv_a_dq_xzero_pt, smxbv_w_dq_xscale, smxbv_w_dq_xzero_pt, smxbv_a_q_yscale, smxbv_a_q_yzero_pt):
    """Generate the QDQ parameters for the two matmuls of an MHA block.

    Both parameter sets come from ``dq_uint16A_uint16A_matmul_q_param_gen``:
    one for the ``qkt_*`` arguments and one for the ``smxbv_*`` arguments.

    Args:
        qkt_*: Arguments of the first matmul, forwarded in the order
            (in_ch_dim, a scale, a zero point, w scale, w zero point,
            y scale, y zero point).
        smxbv_*: Same, for the second matmul.

    Returns:
        A flat 12-element list: the six values of the first matmul followed
        by the six values of the second, each group in exactly the positional
        order returned by ``dq_uint16A_uint16A_matmul_q_param_gen``.
        NOTE(review): the local names below (C2, C1, C3, C0) do not match the
        helper's own naming of its return slots (C3, C2, C1, C0); only the
        positions are meaningful -- confirm against the consumer.
    """
    # params for qkt matmul
    # Fix: the weight zero point is now forwarded. Previously
    # ``qkt_w_dq_xscale`` was passed twice and ``qkt_w_dq_xzero_pt`` ignored.
    [qkt_C2, qkt_C1, qkt_C3, qkt_C0, qkt_shft_matmul, qkt_shft_final] = dq_uint16A_uint16A_matmul_q_param_gen(
        qkt_in_ch_dim,
        qkt_a_dq_xscale, qkt_a_dq_xzero_pt,
        qkt_w_dq_xscale, qkt_w_dq_xzero_pt,
        qkt_a_q_yscale, qkt_a_q_yzero_pt)

    # params for SMxV matmul
    [smxbv_C2, smxbv_C1, smxbv_C3, smxbv_C0, smxbv_shft_matmul, smxbv_shft_final] = dq_uint16A_uint16A_matmul_q_param_gen(
        smxbv_in_ch_dim,
        smxbv_a_dq_xscale, smxbv_a_dq_xzero_pt,
        smxbv_w_dq_xscale, smxbv_w_dq_xzero_pt,
        smxbv_a_q_yscale, smxbv_a_q_yzero_pt)

    return [qkt_C2, qkt_C1, qkt_C3, qkt_C0, qkt_shft_matmul,qkt_shft_final, smxbv_C2, smxbv_C1, smxbv_C3, smxbv_C0, smxbv_shft_matmul,smxbv_shft_final]


def dq_uint16A_uint8W_matmul_q_param_gen(weights, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Generate integer parameters for a fused dequant -> matmul -> quant
    with constant weights and an intermediate matmul right shift.

    Parameter naming follows
    ``(C2*(gemm_result >> matmul_shft) + C1*IFM1_sum + C0) >> final_shft``.
    ``matmul_shft`` is ``ceil(log2(in_ch_dim)) - 7`` clamped to [0, 7] and C2
    is pre-shifted left by the same amount.

    Args:
        weights: Weight array; axis -2 is used as the input-channel axis and
            is summed over.
        a_dq_xscale, a_dq_xzero_pt: Scale / zero point of the activation.
        w_dq_xscale, w_dq_xzero_pt: Scale / zero point of the weights.
        a_q_yscale, a_q_yzero_pt: Scale / zero point of the output.
            All zero points must provide ``.astype``.

    Returns:
        ``[C2, C1, C0, right_shft_matmul, shft_final]``, each wrapped in
        ``np.int64``.

    Raises:
        SystemExit: if the IFM-sum coefficient does not fit in int32.
    """
    int32_max = 2147483647

    act_zp = a_dq_xzero_pt.astype(np.int64)
    wgt_zp = w_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)
    in_ch_dim = weights.shape[-2]

    if in_ch_dim % 49 == 0:
        # hacky padding for windowed attention: multiples of 49 -> multiples of 64
        in_ch_dim = np.int64(np.ceil(in_ch_dim / 49) * 64)

    raw_shift = np.ceil(np.log2(in_ch_dim)) - 7
    matmul_shift = np.int64(min(max(raw_shift, 0), 7))

    # Fixed-point approximation of the combined scale.
    combined_scale = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    gemm_coeff, shft = find_closest_shifted_int32(combined_scale)
    gemm_coeff = np.int64(gemm_coeff)

    # Constant term: activation zero point times the weight sums, plus the
    # up-shifted output zero point.
    weight_sums = np.sum(weights, axis=(-2), dtype=np.int64)
    const_term = np.int64((-act_zp) * gemm_coeff * weight_sums + (out_zp << shft))

    ifm_sum_coeff = np.int64(-gemm_coeff * wgt_zp)
    ifm_sum_offset = np.int32(-act_zp * in_ch_dim)

    # A right shift of the IFM-sum coefficient is not supported by the
    # kernel, so a value outside int32 aborts and the shift is always 0.
    if np.abs(ifm_sum_coeff) > int32_max:
        print(ifm_sum_coeff)
        sys.exit('Current AIE uint16A_uint8W qdq implementation does not support ifm sum shift')
    ifm_sum_shift = 0

    ifm_sum_coeff = (ifm_sum_coeff >> ifm_sum_shift).astype(np.int32)

    # NOTE(review): the int32 cast after the left shift is not range-checked
    # in this function -- confirm the helper's coefficient leaves headroom.
    C2 = (gemm_coeff << matmul_shift).astype(np.int32)
    C1 = ifm_sum_coeff.astype(np.int32)

    offset_wide = (ifm_sum_offset.astype(np.int64) << ifm_sum_shift).astype(np.int64)
    zp_cross_term = (ifm_sum_coeff.astype(np.int64) * offset_wide).astype(np.int64)
    C0 = zp_cross_term + const_term

    return [np.int64(C2), np.int64(C1), np.int64(C0), np.int64(matmul_shift), np.int64(shft)]

def dq_uint8A_uint8W_bias_matmul_q_param_gen(weights, bias, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt, b_dq_xscale, b_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Generate integer parameters for a fused dequant -> matmul + bias -> quant.

    Parameter naming follows ``(C2*gemm_result + C1*IFM1_sum + C0) >> shft``.
    The bias (minus its zero point) is folded into C0 using a second
    fixed-point coefficient for ``b_dq_xscale / a_q_yscale`` that is aligned
    to the shift of the main coefficient.

    Args:
        weights: Weight array; axis -2 is used as the input-channel axis and
            is summed over.
        bias: Bias array (must provide ``.astype``).
        a_dq_xscale, a_dq_xzero_pt: Scale / zero point of the activation.
        w_dq_xscale, w_dq_xzero_pt: Scale / zero point of the weights.
        b_dq_xscale, b_dq_xzero_pt: Scale / zero point of the bias.
        a_q_yscale, a_q_yzero_pt: Scale / zero point of the output.
            The a/w/y zero points must provide ``.astype``.

    Returns:
        ``[C2, C1, C0, shft_c2]``, each wrapped in ``np.int64``.

    Raises:
        SystemExit: if the IFM-sum coefficient does not fit in int32.
    """
    int32_max = 2147483647

    act_zp = a_dq_xzero_pt.astype(np.int64)
    wgt_zp = w_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)
    in_ch_dim = weights.shape[-2]

    print('Premodified input channel dim is: ' + str(in_ch_dim))
    if in_ch_dim % 49 == 0:
        # hacky padding for windowed attention: multiples of 49 -> multiples of 64
        in_ch_dim = np.int64(np.ceil(in_ch_dim / 49) * 64)
    print(' Modified input channel dim is: ' + str(in_ch_dim))

    weights_i64 = weights.astype(np.int64)
    bias_minus_zp = bias.astype(np.int64) - b_dq_xzero_pt

    # Fixed-point approximations of the matmul scale and of the bias scale.
    matmul_scale = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    bias_scale = float(b_dq_xscale / a_q_yscale)
    gemm_coeff, shft_c2 = find_closest_shifted_int32(matmul_scale)
    bias_coeff, shft_c4 = find_closest_shifted_int32(bias_scale)

    # Bring the bias coefficient onto the shift of the matmul coefficient.
    shift_gap = shft_c2 - shft_c4
    if shift_gap > 0:
        bias_coeff = (bias_coeff << np.abs(np.int64(shift_gap)))
    elif shift_gap < 0:
        bias_coeff = (bias_coeff >> np.abs(np.int64(shift_gap)))

    gemm_coeff = np.int64(gemm_coeff)

    # Constant term: zero-point-weighted weight sums, the up-shifted output
    # zero point and the rescaled bias.
    weight_sums = np.sum(weights_i64, axis=(-2), dtype=np.int64)
    const_term = np.int64(
        (-act_zp) * gemm_coeff * weight_sums + (out_zp << shft_c2) + (bias_minus_zp * bias_coeff)
    )

    ifm_sum_offset = np.int32(-act_zp * in_ch_dim)
    ifm_sum_coeff = np.int64(-gemm_coeff * wgt_zp)

    # A right shift of the IFM-sum coefficient is not supported by the
    # kernel, so a value outside int32 aborts and the shift is always 0.
    if np.abs(ifm_sum_coeff) > int32_max:
        print(ifm_sum_coeff)
        sys.exit('Current AIE uint8A_uint8W_bias qdq implementation does not support ifm sum shift')
    ifm_sum_shift = 0

    ifm_sum_coeff = (ifm_sum_coeff >> ifm_sum_shift).astype(np.int32)

    C2 = np.int64(gemm_coeff)
    C1 = np.int64(ifm_sum_coeff)

    offset_wide = (ifm_sum_offset.astype(np.int64) << ifm_sum_shift).astype(np.int64)
    zp_cross_term = (ifm_sum_coeff.astype(np.int64) * offset_wide).astype(np.int64)
    C0 = zp_cross_term + const_term

    return [np.int64(C2), np.int64(C1), np.int64(C0), np.int64(shft_c2)]

def dq_uint16A_uint8W_bias_matmul_q_param_gen(weights, bias, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt, b_dq_xscale, b_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Generate integer parameters for a fused dequant -> matmul + bias -> quant
    with an intermediate matmul right shift.

    Parameter naming follows
    ``(C2*(gemm_result >> matmul_shft) + C1*IFM1_sum + C0) >> final_shft``.
    ``matmul_shft`` is ``ceil(log2(in_ch_dim)) - 7`` clamped to [0, 7] and C2
    is pre-shifted left by the same amount. The squeezed bias (minus its zero
    point) is folded into C0 using a second fixed-point coefficient for
    ``b_dq_xscale / a_q_yscale`` aligned to the main coefficient's shift.

    Args:
        weights: Weight array; axis -2 is used as the input-channel axis and
            is summed over.
        bias: Bias array (must provide ``.astype``); it is squeezed.
        a_dq_xscale, a_dq_xzero_pt: Scale / zero point of the activation.
        w_dq_xscale, w_dq_xzero_pt: Scale / zero point of the weights.
        b_dq_xscale, b_dq_xzero_pt: Scale / zero point of the bias.
        a_q_yscale, a_q_yzero_pt: Scale / zero point of the output.
            The a/w/y zero points must provide ``.astype``.

    Returns:
        ``[C2, C1, C0, right_shft_matmul, shft_final]``, each wrapped in
        ``np.int64``.

    Raises:
        SystemExit: if the IFM-sum coefficient does not fit in int32.
    """
    int32_max = 2147483647

    act_zp = a_dq_xzero_pt.astype(np.int64)
    wgt_zp = w_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)
    in_ch_dim = weights.shape[-2]

    if in_ch_dim % 49 == 0:
        # hacky padding for windowed attention: multiples of 49 -> multiples of 64
        in_ch_dim = np.int64(np.ceil(in_ch_dim / 49) * 64)

    weights_i64 = weights.astype(np.int64)
    bias_minus_zp = bias.astype(np.int64).squeeze() - b_dq_xzero_pt

    raw_shift = np.ceil(np.log2(in_ch_dim)) - 7
    matmul_shift = np.int64(min(max(raw_shift, 0), 7))

    # Fixed-point approximations of the matmul scale and of the bias scale.
    matmul_scale = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    bias_scale = float(b_dq_xscale / a_q_yscale)
    gemm_coeff, shft_c2 = find_closest_shifted_int32(matmul_scale)
    bias_coeff, shft_c4 = find_closest_shifted_int32(bias_scale)

    # Bring the bias coefficient onto the shift of the matmul coefficient.
    shift_gap = shft_c2 - shft_c4
    if shift_gap > 0:
        bias_coeff = (bias_coeff << np.abs(np.int64(shift_gap)))
    elif shift_gap < 0:
        bias_coeff = (bias_coeff >> np.abs(np.int64(shift_gap)))

    gemm_coeff = np.int64(gemm_coeff)

    # Constant term: zero-point-weighted weight sums, the up-shifted output
    # zero point and the rescaled bias.
    weight_sums = np.sum(weights_i64, axis=(-2), dtype=np.int64)
    const_term = np.int64(
        (-act_zp) * gemm_coeff * weight_sums + (out_zp << shft_c2) + (bias_minus_zp * bias_coeff)
    )

    ifm_sum_coeff = np.int64(-gemm_coeff * wgt_zp)
    ifm_sum_offset = np.int32(-act_zp * in_ch_dim)

    # A right shift of the IFM-sum coefficient is not supported by the
    # kernel, so a value outside int32 aborts and the shift is always 0.
    if np.abs(ifm_sum_coeff) > int32_max:
        print(ifm_sum_coeff)
        sys.exit('Current AIE uint16A_uint8W qdq implementation does not support ifm sum shift')
    ifm_sum_shift = 0

    ifm_sum_coeff = (ifm_sum_coeff >> ifm_sum_shift).astype(np.int32)

    # NOTE(review): the int32 cast after the left shift is not range-checked
    # in this function -- confirm the helper's coefficient leaves headroom.
    C2 = (gemm_coeff << matmul_shift).astype(np.int32)
    C1 = ifm_sum_coeff.astype(np.int32)

    offset_wide = (ifm_sum_offset.astype(np.int64) << ifm_sum_shift).astype(np.int64)
    zp_cross_term = (ifm_sum_coeff.astype(np.int64) * offset_wide).astype(np.int64)
    C0 = zp_cross_term + const_term

    return [np.int64(C2), np.int64(C1), np.int64(C0), np.int64(matmul_shift), np.int64(shft_c2)]

def dq_uint16A_uint8W_conv_q_param_gen(weights, bias, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt, b_dq_xscale, b_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Generate integer parameters for a fused dequant -> conv + bias -> quant.

    Coefficients follow
    ``(C2*(conv >> conv_shft) + C1*ifm_sum + C0) >> final_shft``.
    ``conv_shft`` is ``ceil(log2(weights per output channel)) - 7`` clamped
    to [0, 7] and C2 is pre-shifted left by the same amount.

    Args:
        weights: 4-D weight array; axes 1, 2, 3 are summed per output
            channel and ``shape[-3] * shape[2] * shape[3]`` is the unrolled
            weight count.
        bias: Bias array; it is multiplied by the matmul coefficient
            directly (``b_dq_xscale`` / ``b_dq_xzero_pt`` are not used).
        a_dq_xscale, a_dq_xzero_pt: Scale / zero point of the activation.
        w_dq_xscale, w_dq_xzero_pt: Scale / zero point of the weights.
        b_dq_xscale, b_dq_xzero_pt: Accepted but unused.
        a_q_yscale, a_q_yzero_pt: Scale / zero point of the output.
            The a/w/y zero points must provide ``.astype``.

    Returns:
        ``[C2, C1, C0, right_shft_conv, shft_final]``, each wrapped in
        ``np.int64``.

    Raises:
        SystemExit: if the IFM-sum coefficient does not fit in int32.
    """
    int32_max = 2147483647

    act_zp = a_dq_xzero_pt.astype(np.int64)
    wgt_zp = w_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)
    weights_i64 = weights.astype(np.int64)
    bias_i64 = bias.astype(np.int64)

    kernel_shape = np.shape(weights_i64)
    in_channels = weights_i64.shape[-3]
    kernel_h = int(kernel_shape[2])
    kernel_w = int(kernel_shape[3])

    num_weights_unrolled = np.int64(kernel_h * kernel_w * in_channels)

    # hacky first-layer weight repacking: a 7x7x3 kernel is padded to 4*7*8
    # entries, the extra entries being counted as weight zero points.
    num_weight_zp_padded = 0
    if num_weights_unrolled == (7*7*3):
        print('Premodified conv weight count is : ', str(num_weights_unrolled))
        repacked_count = np.int64(4*7*8)
        print('Modified conv weight count is : ', str(repacked_count))
        num_weight_zp_padded = repacked_count - num_weights_unrolled
        print('num weight zero padded is: ', str(num_weight_zp_padded))
        num_weights_unrolled = repacked_count

    raw_shift = np.ceil(np.log2(num_weights_unrolled)) - 7
    conv_shift = np.int64(min(max(raw_shift, 0), 7))

    # Fixed-point approximation of the combined scale.
    combined_scale = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    gemm_coeff, shft_c2 = find_closest_shifted_int32(combined_scale)
    gemm_coeff = np.int64(gemm_coeff)

    # Constant term: zero-point-weighted weight sums (including the padded
    # zero-point entries), the scaled bias and the up-shifted output zero point.
    padded_weight_sums = (np.sum(weights_i64, axis=(1, 2, 3), dtype=np.int64) +
                          np.int64(num_weight_zp_padded*wgt_zp))
    const_term = np.int64(
        (-act_zp) * gemm_coeff * padded_weight_sums
        + gemm_coeff * bias_i64
        + (out_zp << shft_c2)
    )

    ifm_sum_coeff = np.int64(-gemm_coeff * wgt_zp)
    ifm_sum_offset = np.int32(-act_zp * num_weights_unrolled)

    # A right shift of the IFM-sum coefficient is not supported by the
    # kernel, so a value outside int32 aborts and the shift is always 0.
    if np.abs(ifm_sum_coeff) > int32_max:
        print(ifm_sum_coeff)
        sys.exit('Current AIE uint16A_uint8W qdq implementation does not support ifm sum shift')
    ifm_sum_shift = 0

    ifm_sum_coeff = (ifm_sum_coeff >> ifm_sum_shift).astype(np.int32)

    # NOTE(review): the int32 cast after the left shift is not range-checked
    # in this function -- confirm the helper's coefficient leaves headroom.
    C2 = (gemm_coeff << conv_shift).astype(np.int32)
    C1 = ifm_sum_coeff.astype(np.int32)

    offset_wide = (ifm_sum_offset.astype(np.int64) << ifm_sum_shift).astype(np.int64)
    zp_cross_term = (ifm_sum_coeff.astype(np.int64) * offset_wide).astype(np.int64)
    C0 = zp_cross_term + const_term

    return [np.int64(C2), np.int64(C1), np.int64(C0), np.int64(conv_shift), np.int64(shft_c2)]

def dq_uint8uint16A_uint8W_bias_matmul_q_param_gen(ifm, weights, bias, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt, b_dq_xscale, b_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Placeholder: not implemented. Ignores every argument and returns None."""
    return None

def srs_int32_even_fast(inp, shift):
    """Shift right by ``shift`` bits with round-half-to-even, saturating to int32.

    The rounding is applied to the magnitude of ``inp`` and the sign is
    restored afterwards. With ``shift == 0`` the input is only clipped and
    cast.

    Args:
        inp: NumPy integer scalar or array.
        shift: Non-negative number of bits to shift right.

    Returns:
        int32 value(s), clipped to [-2147483648, 2147483647].
    """
    int32_lo, int32_hi = -2147483648, 2147483647

    if shift == 0:
        return np.clip(inp, int32_lo, int32_hi).astype(np.int32)

    sign = np.sign(inp)
    magnitude = abs(inp)

    # Integer part and discarded fraction bits of magnitude / 2**shift.
    quotient = magnitude >> shift
    remainder = magnitude - (quotient << shift)

    # Fraction >= 0.5 exactly when the top discarded bit is set.
    top_bit_set = ((remainder >> (shift - 1)) != 0).astype(np.int64)
    quotient_is_even = (quotient % 2) == 0
    exactly_half = (remainder == (1 << shift - 1)).astype(np.int64)

    # Round up when the fraction is >= 0.5, except for an exact tie whose
    # integer part is already even.
    round_up = top_bit_set * (1 - exactly_half * quotient_is_even)

    rounded = sign * (quotient + round_up)
    return np.clip(rounded, int32_lo, int32_hi).astype(np.int32)

def srs_uint16_even_fast(inp, shift):
    """Shift right by ``shift`` bits with round-half-to-even, saturating to uint16.

    The rounding is applied to the magnitude of ``inp`` and the sign is
    restored afterwards, so negative results are clipped to 0. With
    ``shift == 0`` the input is only clipped and cast.

    Args:
        inp: NumPy integer scalar or array.
        shift: Non-negative number of bits to shift right.

    Returns:
        uint16 value(s), clipped to [0, 65535].
    """
    uint16_lo, uint16_hi = 0, 65535

    if shift == 0:
        return np.clip(inp, uint16_lo, uint16_hi).astype(np.uint16)

    sign = np.sign(inp)
    magnitude = abs(inp)

    # Integer part and discarded fraction bits of magnitude / 2**shift.
    quotient = magnitude >> shift
    remainder = magnitude - (quotient << shift)

    # Fraction >= 0.5 exactly when the top discarded bit is set.
    top_bit_set = ((remainder >> (shift - 1)) != 0).astype(np.int64)
    quotient_is_even = (quotient % 2) == 0
    exactly_half = (remainder == (1 << shift - 1)).astype(np.int64)

    # Round up when the fraction is >= 0.5, except for an exact tie whose
    # integer part is already even.
    round_up = top_bit_set * (1 - exactly_half * quotient_is_even)

    rounded = sign * (quotient + round_up)
    return np.clip(rounded, uint16_lo, uint16_hi).astype(np.uint16)

# SRS with round even
def right_broadcasting(arr, target):
    """Append trailing length-1 axes to ``arr`` until it has ``target.ndim`` axes.

    If ``arr`` already has at least as many axes as ``target`` its shape is
    left unchanged.
    """
    missing_axes = target.ndim - arr.ndim
    # A non-positive count multiplies the tuple down to (), adding nothing.
    padded_shape = arr.shape + (1,) * missing_axes
    return arr.reshape(padded_shape)
