import onnx
import onnxruntime as ort
import numpy as np
from onnx import helper, shape_inference
from collections import OrderedDict
from onnx import numpy_helper, TensorProto
from onnx.helper import (
    make_model, make_node, make_graph,
    make_tensor_value_info, make_attribute)
import csv
import math
from timeit import default_timer as timer
from ml_dtypes import bfloat16
import pickle

# The result of this lookup is discarded. Presumably a sanity check that the
# ml_dtypes import above made the 'bfloat16' dtype name known to NumPy
# (the lookup raises if the name is unknown) -- TODO confirm intent.
np.dtype('bfloat16')

# Element-wise wrapper around the scalar math.erf so it can be applied to arrays.
vectorized_erf = np.vectorize(math.erf)

from onnxruntime_extensions import onnx_op, PyOp, get_library_path, PyCustomOpDef

import json
import sys


# Matmul uint16act X uint16 weights
@onnx_op(op_type="qdq_matmul_uint16_uint16_cstm", inputs=[PyCustomOpDef.dt_uint16, PyCustomOpDef.dt_uint16,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint16,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint16,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint16],
         outputs=[PyCustomOpDef.dt_uint16])
def qdq_matmul_uint16_uint16_cstm(activations, weights, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale,
                                  w_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Integer-only QDQ MatMul for uint16 activations and uint16 weights.

    Evaluates the expansion of

        y = zy + c2 * sum_k (a - za) * (w - zw),    c2 = sa * sw / sy

    as the sum of three integer terms that share one final right shift:

        c2' * (A @ W)                               (temp_out)
        -za * c2' * colsum(W) + (zy << shft_c2)     (c1_coeff)
        -c2' * zw * (rowsum(A) - za * K)            (new_term)

    where ``c2'`` / ``shft_c2`` come from ``find_closest_shifted_int16``
    (presumably an integer multiplier and shift approximating ``c2`` --
    helper defined elsewhere in the file) and ``K`` is ``weights.shape[-2]``.

    Args:
        activations: Quantized activation array (declared uint16).
        weights: Quantized weight array (declared uint16); ``shape[-2]`` is
            used as the reduction length.
        a_dq_xscale, a_dq_xzero_pt: Activation dequantization scale / zero point.
        w_dq_xscale, w_dq_xzero_pt: Weight dequantization scale / zero point.
        a_q_yscale, a_q_yzero_pt: Output quantization scale / zero point.

    Returns:
        The result of ``srs_uint16_even_fast`` applied to the accumulated
        integer sum (presumably a uint16 array -- helper not visible here).
    """
    a_dq_xzero_pt = a_dq_xzero_pt.astype(np.int64)
    w_dq_xzero_pt = w_dq_xzero_pt.astype(np.int64)

    a_q_yzero_pt = a_q_yzero_pt.astype(np.int64)
    print('matmul_uint16_uint16')
    weights_in_ch = np.int64(weights.shape[-2])

    # Right shift applied to the raw accumulator before it is multiplied by
    # c2'; it grows with log2 of the reduction length and is capped per
    # operand-width combination.  Given the declared input types, presumably
    # only the uint16 x uint16 branch is reached here.
    if (activations.dtype == 'uint8') and (weights.dtype == 'uint16'):
        matmul_shift = np.int64(min(max(25 + np.ceil(np.log2(weights_in_ch)) - 33, 0), 7))
    elif (activations.dtype == 'uint16') and (weights.dtype == 'uint16'):
        matmul_shift = np.int64(min(max(33 + np.ceil(np.log2(weights_in_ch)) - 33, 0), 15))
    elif (activations.dtype == 'uint16') and (weights.dtype == 'uint8'):
        matmul_shift = np.int64(min(max(25 + np.ceil(np.log2(weights_in_ch)) - 33, 0), 7))
    else:
        matmul_shift = 0

    activations = activations.astype(np.int64)
    weights = weights.astype(np.int64)

    c2_coeff = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)

    [c2_coeff_prime, shft_c2] = find_closest_shifted_int16(c2_coeff)

    c2_coeff_prime = np.int64(c2_coeff_prime)

    # c1_coeff depends only on weights and quantization parameters, so it can
    # be computed at compile time.
    c1_coeff = (-a_dq_xzero_pt) * c2_coeff_prime * np.sum(weights, axis=(-2), dtype=np.int64) + \
               (a_q_yzero_pt << shft_c2)

    c1_coeff = np.int64(c1_coeff)
    num_weights_unrolled = weights_in_ch

    c3_coeff_offset = np.int32(-a_dq_xzero_pt * num_weights_unrolled)
    c3_coeff_scale = np.int64(-c2_coeff_prime * w_dq_xzero_pt)

    # Right shift c3_coeff_scale so that it fits into int32; the shift is
    # compensated by left-shifting the second operand of new_term below.
    if np.abs(c3_coeff_scale) > 2147483647:  # Max int32 number
        c3_coeff_scale_shift = np.int64(np.ceil(np.log2(np.abs(c3_coeff_scale))) - 31)
    else:
        c3_coeff_scale_shift = 0

    c3_coeff_scale = (c3_coeff_scale >> c3_coeff_scale_shift).astype(np.int32)

    # The matrix product is computed exactly once (a previous revision
    # evaluated it twice and discarded the first result).
    int32_matmul = srs_int32_even_fast(np.matmul(activations, weights).astype(np.int64), matmul_shift)
    # int32 matmul typecast as int64 to ensure the output of int32xint32 is int64.
    # c2' is shifted left by matmul_shift to undo the accumulator pre-shift.
    temp_out = (c2_coeff_prime << matmul_shift).astype(np.int32) * ((int32_matmul.astype(np.int64)))
    # 2nd operand is typecast as int64 to ensure product is int64
    # Compensate for right shift of c3 coeff_scale in second operand, experiment shows for PSJ 2nd operand really only utilizes around 10 bits of 32 bits for PSJ
    new_term = c3_coeff_scale * ((((np.sum(activations, axis=(-1), dtype=np.int32)) + c3_coeff_offset.astype(
                                      np.int64)) << c3_coeff_scale_shift).astype(np.int64)).astype(np.int64)

    # When c1_coeff has one dimension fewer than the product, expand it with
    # right_broadcasting (helper defined elsewhere) and move the weight-column
    # axis to the last position.
    if len(c1_coeff.shape) == (len(temp_out.shape) - 1):
        c1_coeff = right_broadcasting(c1_coeff, temp_out)
        c1_coeff = np.swapaxes(c1_coeff, -2, -1)

    temp_out += c1_coeff
    temp_out += right_broadcasting(new_term, temp_out)
    temp_out = srs_uint16_even_fast(temp_out, shft_c2)
    output = np.reshape(temp_out, temp_out.shape)

    # Floating-point reference, kept for debugging:
    #return test_quant_linear_uint16(np.matmul(test_dequant_linear(activations, a_dq_xscale, a_dq_xzero_pt), test_dequant_linear(weights, w_dq_xscale, w_dq_xzero_pt)),a_q_yscale, a_q_yzero_pt)
    return output


# Matmul uint16act X uint8 weights
@onnx_op(op_type="qdq_matmul_uint16_uint8_cstm", inputs=[PyCustomOpDef.dt_uint16, PyCustomOpDef.dt_uint8,
                                                         PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint16,
                                                         PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8,
                                                         PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint16],
         outputs=[PyCustomOpDef.dt_uint16])
def qdq_matmul_uint16_uint8_cstm(activations, weights, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt,
                                 a_q_yscale, a_q_yzero_pt):
    """Fixed-point QDQ MatMul: uint16 activations, uint8 weights, uint16 result.

    The dequantize -> matmul -> quantize chain is expanded into three integer
    terms (scaled product, a per-column constant built from the weight column
    sums, and a per-row correction built from the activation row sums) that
    are summed and passed with the common shift to ``srs_uint16_even_fast``.

    Args:
        activations: Quantized activations (declared uint16).
        weights: Quantized weights (declared uint8); ``shape[-2]`` is the
            reduction length.
        a_dq_xscale, a_dq_xzero_pt: Activation scale / zero point.
        w_dq_xscale, w_dq_xzero_pt: Weight scale / zero point.
        a_q_yscale, a_q_yzero_pt: Output scale / zero point.

    Returns:
        Output of ``srs_uint16_even_fast`` reshaped to its own shape.
    """
    act_zp = a_dq_xzero_pt.astype(np.int64)
    wgt_zp = w_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)
    print('matmul_uint16_uint8')
    inner_dim = np.int64(weights.shape[-2])

    # Select the accumulator pre-shift parameters from the operand widths:
    # (constant added before the -33, upper cap of the shift).
    act_u8 = activations.dtype == 'uint8'
    act_u16 = activations.dtype == 'uint16'
    wgt_u8 = weights.dtype == 'uint8'
    wgt_u16 = weights.dtype == 'uint16'
    if act_u16 and wgt_u16:
        shift_params = (33, 15)
    elif (act_u8 and wgt_u16) or (act_u16 and wgt_u8):
        shift_params = (25, 7)
    else:
        shift_params = None

    if shift_params is None:
        acc_shift = 0
    else:
        base_bits, shift_cap = shift_params
        acc_shift = np.int64(min(max(base_bits + np.ceil(np.log2(inner_dim)) - 33, 0), shift_cap))

    acts = activations.astype(np.int64)
    wgts = weights.astype(np.int64)

    # Combined requantization factor and its integer multiplier / shift form
    # (find_closest_shifted_int32 is defined elsewhere in the file).
    requant_scale = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    mult, out_shift = find_closest_shifted_int32(requant_scale)
    mult = np.int64(mult)

    # Per-column constant; depends only on weights and quantization
    # parameters, so it could be precomputed at compile time.
    col_term = np.int64(
        (-act_zp) * mult * np.sum(wgts, axis=-2, dtype=np.int64) + (out_zp << out_shift))

    # Per-row correction factors for the weight zero point.
    row_offset = np.int32(-act_zp * inner_dim)
    row_scale = np.int64(-mult * wgt_zp)

    # Keep row_scale inside int32; the lost bits are restored by shifting the
    # row sums left by the same amount below.
    if np.abs(row_scale) > 2147483647:  # max int32
        row_scale_shift = np.int64(np.ceil(np.log2(np.abs(row_scale))) - 31)
    else:
        row_scale_shift = 0
    row_scale = (row_scale >> row_scale_shift).astype(np.int32)

    # Pre-shifted accumulator; the multiplier is shifted left to compensate.
    acc = srs_int32_even_fast(np.matmul(acts, wgts), acc_shift)
    result = (mult << acc_shift).astype(np.int32) * acc.astype(np.int64)

    row_sums = np.sum(acts, axis=-1, dtype=np.int32)
    row_term = row_scale * ((row_sums + row_offset.astype(np.int64)) << row_scale_shift).astype(np.int64)

    if col_term.ndim == result.ndim - 1:
        col_term = np.swapaxes(right_broadcasting(col_term, result), -2, -1)

    result += col_term
    result += right_broadcasting(row_term, result)
    result = srs_uint16_even_fast(result, out_shift)

    return np.reshape(result, result.shape)

@onnx_op(op_type="qdq_gelu_uint8_cstm",
         inputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8,
                 PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8],
         outputs=[PyCustomOpDef.dt_uint8])
def qdq_gelu_uint8_cstm(activations, a_dq_xscale, a_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """QDQ GELU on uint8 data.

    Dequantizes ``activations`` with ``test_dequant_linear``, applies
    ``gelu_lut_approx_bfloat16_6_384`` and requantizes the result with
    ``test_quant_linear_uint8`` (all three helpers are defined elsewhere).

    Args:
        activations: Quantized input (declared uint8).
        a_dq_xscale, a_dq_xzero_pt: Input scale / zero point.
        a_q_yscale, a_q_yzero_pt: Output scale / zero point.

    Returns:
        Whatever ``test_quant_linear_uint8`` returns for the GELU output.
    """
    print('gelu_uint8')
    return test_quant_linear_uint8(
        gelu_lut_approx_bfloat16_6_384(
            test_dequant_linear(activations, a_dq_xscale, a_dq_xzero_pt)),
        a_q_yscale,
        a_q_yzero_pt)

@onnx_op(op_type="qdq_matmul_uint8_uint8_cstm", inputs= [PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8],
                                                          outputs=[PyCustomOpDef.dt_uint8],
                                                          attrs={'coeff': PyCustomOpDef.dt_int64, 'shift': PyCustomOpDef.dt_int64})
def qdq_matmul_uint8_uint8_cstm(activations, weights, a_dq_xscale, a_dq_xzero_pt,  w_dq_xscale, w_dq_xzero_pt,
                                 a_q_yscale, a_q_yzero_pt, **kwargs):
    """Fixed-point QDQ MatMul: uint8 activations, uint8 weights, uint8 result.

    The dequantize -> matmul -> quantize chain is expanded into a scaled
    integer product, a per-column constant (weight column sums) and a per-row
    correction (activation row sums); their sum is passed with the common
    shift to ``srs_uint8_even_fast``.

    Args:
        activations: Quantized activations (declared uint8).
        weights: Quantized weights (declared uint8); ``shape[-2]`` is the
            reduction length.
        a_dq_xscale, a_dq_xzero_pt: Activation scale / zero point.
        w_dq_xscale, w_dq_xzero_pt: Weight scale / zero point.
        a_q_yscale, a_q_yzero_pt: Output scale / zero point.
        **kwargs: Receives the declared ``coeff`` / ``shift`` attributes;
            they are not read by this implementation.

    Returns:
        Output of ``srs_uint8_even_fast`` reshaped to its own shape.
    """
    print('matmul_uint8_uint8')

    act_zp = a_dq_xzero_pt.astype(np.int64)
    wgt_zp = w_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)
    inner_dim = np.int64(weights.shape[-2])

    # Accumulator pre-shift: only non-zero when at least one operand is
    # uint16; for any other dtype pairing (e.g. uint8 x uint8) it is 0.
    act_u8 = activations.dtype == 'uint8'
    act_u16 = activations.dtype == 'uint16'
    wgt_u8 = weights.dtype == 'uint8'
    wgt_u16 = weights.dtype == 'uint16'
    if act_u16 and wgt_u16:
        base_bits = 33
    elif (act_u8 and wgt_u16) or (act_u16 and wgt_u8):
        base_bits = 25
    else:
        base_bits = None

    if base_bits is None:
        acc_shift = 0
    else:
        acc_shift = np.int64(max(base_bits + np.ceil(np.log2(inner_dim)) - 33, 0))

    acts = activations.astype(np.int64)
    wgts = weights.astype(np.int64)

    # Combined requantization factor in integer multiplier / shift form
    # (find_closest_shifted_int32 is defined elsewhere in the file).
    requant_scale = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    mult, out_shift = find_closest_shifted_int32(requant_scale)
    mult = np.int64(mult)

    # Per-column constant; can be computed at compile time.
    col_term = np.int64(
        (-act_zp) * mult * np.sum(wgts, axis=-2, dtype=np.int64) + (out_zp << out_shift))

    # Per-row correction factors for the weight zero point.
    row_offset = np.int32(-act_zp * inner_dim)
    row_scale = np.int64(-mult * wgt_zp)

    # Right shift row_scale so it fits into int32; compensated on the row sums.
    if np.abs(row_scale) > 2147483647:  # max int32
        row_scale_shift = np.int64(np.ceil(np.log2(np.abs(row_scale))) - 31)
    else:
        row_scale_shift = 0
    row_scale = (row_scale >> row_scale_shift).astype(np.int32)

    # Pre-shifted accumulator; the multiplier is shifted left to compensate.
    acc = srs_int32_even_fast(np.matmul(acts, wgts), acc_shift)
    result = (mult << acc_shift).astype(np.int32) * acc.astype(np.int64)

    row_sums = np.sum(acts, axis=-1, dtype=np.int32)
    row_term = row_scale * ((row_sums + row_offset.astype(np.int64)) << row_scale_shift).astype(np.int64)

    if col_term.ndim == result.ndim - 1:
        col_term = np.swapaxes(right_broadcasting(col_term, result), -2, -1)

    result += col_term
    result += right_broadcasting(row_term, result)
    result = srs_uint8_even_fast(result, out_shift)

    return np.reshape(result, result.shape)
####$$$$$$$$$$$$$$$$$$$$$$$$$#############
@onnx_op(op_type="qdq_matmul_uint8_uint8_bias_cstm", inputs= [PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8,
                                                          PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8],
                                                          outputs=[PyCustomOpDef.dt_uint8])
def qdq_matmul_uint8_uint8_bias_cstm(activations, weights, bias, a_dq_xscale, a_dq_xzero_pt,  w_dq_xscale, w_dq_xzero_pt,
                                     b_dq_xscale, b_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Fixed-point QDQ MatMul with a quantized bias (all tensors uint8).

    Same three-term integer expansion as the bias-free uint8 MatMul, with the
    bias folded into the per-column constant: ``(bias - b_zp)`` is multiplied
    by the integer form of ``b_scale / y_scale`` after that multiplier has
    been re-aligned to the shift of the main requantization factor.

    Args:
        activations: Quantized activations (declared uint8).
        weights: Quantized weights (declared uint8); ``shape[-2]`` is the
            reduction length.
        bias: Quantized bias (declared uint8).
        a_dq_xscale, a_dq_xzero_pt: Activation scale / zero point.
        w_dq_xscale, w_dq_xzero_pt: Weight scale / zero point.
        b_dq_xscale, b_dq_xzero_pt: Bias scale / zero point.
        a_q_yscale, a_q_yzero_pt: Output scale / zero point.

    Returns:
        Output of ``srs_uint8_even_fast`` reshaped to its own shape.
    """
    # Floating-point reference, kept for debugging:
    #return test_quant_linear_uint8(np.matmul(test_dequant_linear(activations,a_dq_xscale,a_dq_xzero_pt),test_dequant_linear(weights,w_dq_xscale,w_dq_xzero_pt))+test_dequant_linear(bias,b_dq_xscale,b_dq_xzero_pt), a_q_yscale, a_q_yzero_pt)
    act_zp = a_dq_xzero_pt.astype(np.int64)
    wgt_zp = w_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)
    print('matmul_uint8_uint8_bias')
    inner_dim = np.int64(weights.shape[-2])

    # Accumulator pre-shift: only non-zero when at least one operand is
    # uint16; for any other dtype pairing (e.g. uint8 x uint8) it is 0.
    act_u8 = activations.dtype == 'uint8'
    act_u16 = activations.dtype == 'uint16'
    wgt_u8 = weights.dtype == 'uint8'
    wgt_u16 = weights.dtype == 'uint16'
    if act_u16 and wgt_u16:
        base_bits = 33
    elif (act_u8 and wgt_u16) or (act_u16 and wgt_u8):
        base_bits = 25
    else:
        base_bits = None

    if base_bits is None:
        acc_shift = 0
    else:
        acc_shift = np.int64(max(base_bits + np.ceil(np.log2(inner_dim)) - 33, 0))

    acts = activations.astype(np.int64)
    wgts = weights.astype(np.int64)
    bias_centered = bias.astype(np.int64) - b_dq_xzero_pt

    # Integer multiplier / shift forms of the product and bias scale factors
    # (find_closest_shifted_int32 is defined elsewhere in the file).
    requant_scale = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    bias_scale = float(b_dq_xscale / a_q_yscale)
    mult, out_shift = find_closest_shifted_int32(requant_scale)
    bias_mult, bias_shift = find_closest_shifted_int32(bias_scale)

    # Bring the bias multiplier onto the same shift as the main multiplier so
    # both contributions can share the final right shift.
    if out_shift != bias_shift:
        shift_gap = out_shift - bias_shift
        gap_bits = np.abs(np.int64(shift_gap))
        if shift_gap > 0:
            bias_mult = bias_mult << gap_bits
        elif shift_gap < 0:
            bias_mult = bias_mult >> gap_bits

    mult = np.int64(mult)

    # Per-column constant (includes the bias); can be computed at compile time.
    col_term = np.int64(
        (-act_zp) * mult * np.sum(wgts, axis=-2, dtype=np.int64)
        + (out_zp << out_shift) + (bias_centered * bias_mult))

    # Per-row correction factors for the weight zero point.
    row_offset = np.int32(-act_zp * inner_dim)
    row_scale = np.int64(-mult * wgt_zp)

    # Right shift row_scale so it fits into int32; compensated on the row sums.
    if np.abs(row_scale) > 2147483647:  # max int32
        row_scale_shift = np.int64(np.ceil(np.log2(np.abs(row_scale))) - 31)
    else:
        row_scale_shift = 0
    row_scale = (row_scale >> row_scale_shift).astype(np.int32)

    # Pre-shifted accumulator; the multiplier is shifted left to compensate.
    acc = srs_int32_even_fast(np.matmul(acts, wgts), acc_shift)
    result = (mult << acc_shift).astype(np.int32) * acc.astype(np.int64)

    row_sums = np.sum(acts, axis=-1, dtype=np.int32)
    row_term = row_scale * ((row_sums + row_offset.astype(np.int64)) << row_scale_shift).astype(np.int64)

    if col_term.ndim == result.ndim - 1:
        col_term = np.swapaxes(right_broadcasting(col_term, result), -2, -1)

    result += col_term
    result += right_broadcasting(row_term, result)
    result = srs_uint8_even_fast(result, out_shift)

    return np.reshape(result, result.shape)



####$$$$$$$$$$$$$$$$$$$$$$############

# Function below is a fixed point representation of conv qdq kernel with zero valued weight zero point
# and non-zero valued activation zero point
@onnx_op(op_type="qdq_conv2d_weightsZPeq0_cstm",
         inputs=[PyCustomOpDef.dt_int8, PyCustomOpDef.dt_int8, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_float,
                 PyCustomOpDef.dt_int8, PyCustomOpDef.dt_float, PyCustomOpDef.dt_int8, PyCustomOpDef.dt_float,
                 PyCustomOpDef.dt_int32, PyCustomOpDef.dt_float, PyCustomOpDef.dt_int8],
         outputs=[PyCustomOpDef.dt_int8],
         attrs={'pads_h_beg': PyCustomOpDef.dt_int64, 'pads_w_beg': PyCustomOpDef.dt_int64,
                'pads_h_end': PyCustomOpDef.dt_int64, 'pads_w_end': PyCustomOpDef.dt_int64,
                'strides_h': PyCustomOpDef.dt_int64, 'strides_w': PyCustomOpDef.dt_int64})
def qdq_conv2d_weightsZPeq0_cstm(activations, weights, bias, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale, w_dq_xzero_pt,
                                 b_dq_xscale, b_dq_xzero_pt, a_q_yscale, a_q_yzero_pt, **kwargs):
    """Fixed-point QDQ 2-D convolution (int8) assuming a zero weight zero point.

    Each output element is

        srs_int8_even(c2' * dot(acts_patch, kernel) + c1[out_ch], shft_c2)

    with ``c1 = -za * c2' * sum(kernel) + c2' * bias + (zy << shft_c2)`` and
    ``c2' / shft_c2`` the integer form of ``sa * sw / sy`` returned by
    ``find_closest_shifted_int32`` (helper defined elsewhere).  The padded
    border is filled with the activation zero point.

    ``w_dq_xzero_pt``, ``b_dq_xscale`` and ``b_dq_xzero_pt`` are accepted but
    not used.  The activation tensor is reshaped to
    ``(weights.shape[1], H, W)``, so only inputs whose batch * channel count
    equals ``weights.shape[1]`` are supported.

    Args:
        activations: 4-D quantized input; dims 2 and 3 are height and width.
        weights: 4-D quantized kernel ``(out_ch, in_ch, ky, kx)``.
        bias: Integer bias, one value per output channel.
        a_dq_xscale, a_dq_xzero_pt: Activation scale / zero point.
        w_dq_xscale: Weight scale.
        a_q_yscale, a_q_yzero_pt: Output scale / zero point.
        **kwargs: ``pads_h_beg``, ``pads_h_end``, ``pads_w_beg``,
            ``pads_w_end``, ``strides_h``, ``strides_w``.

    Returns:
        int8 array of shape ``(1, out_ch, out_height, out_width)``.
    """
    print('Conv working')
    pad_height_begin = kwargs.get('pads_h_beg')
    pad_height_end = kwargs.get('pads_h_end')

    pad_width_begin = kwargs.get('pads_w_beg')
    pad_width_end = kwargs.get('pads_w_end')

    stride_height = kwargs.get('strides_h')
    stride_width = kwargs.get('strides_w')

    weights_shape = np.shape(weights)
    weights_out_ch = int(weights_shape[0])
    weights_in_ch = int(weights_shape[1])
    weights_ky = int(weights_shape[2])
    weights_kx = int(weights_shape[3])

    acts_shape = np.shape(activations)
    acts_height = int(acts_shape[2])
    acts_width = int(acts_shape[3])

    # Drop the batch dimension (see docstring for the supported layout).
    activations = np.reshape(activations, (weights_in_ch, acts_height, acts_width))

    padded_acts_height = int(acts_height + pad_height_begin + pad_height_end)
    padded_acts_width = int(acts_width + pad_width_begin + pad_width_end)

    activations = activations.astype(np.int64)
    weights = weights.astype(np.int64)
    bias = bias.astype(np.int64)
    a_dq_xzero_pt = a_dq_xzero_pt.astype(np.int64)
    a_q_yzero_pt = a_q_yzero_pt.astype(np.int64)

    # Pad with the activation zero point so padded positions dequantize to 0.
    input_act_padded = np.ones((weights_in_ch, padded_acts_height, padded_acts_width), dtype=np.int64) * a_dq_xzero_pt
    input_act_padded[:, pad_height_begin:pad_height_begin + acts_height,
                     pad_width_begin:pad_width_begin + acts_width] = activations

    out_height = int((padded_acts_height - weights_ky) / stride_height + 1)
    out_width = int((padded_acts_width - weights_kx) / stride_width + 1)

    output = np.zeros((weights_out_ch, out_height, out_width), dtype=np.int8)

    c2_coeff = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)

    [c2_coeff_prime, shft_c2] = find_closest_shifted_int32(c2_coeff)
    c2_coeff_prime = np.int64(c2_coeff_prime)

    # Per-output-channel constant: activation zero-point correction, bias and
    # output zero point, all expressed at the shft_c2 fixed-point scale.
    c1_coeff = (-a_dq_xzero_pt) * c2_coeff_prime * np.sum(weights, axis=(1, 2, 3)) + c2_coeff_prime * bias + (
            a_q_yzero_pt << shft_c2)
    c1_coeff = np.array(c1_coeff, dtype=np.int64)

    num_weights_unrolled = int(weights_ky * weights_kx * weights_in_ch)

    # Top-left corners of every kernel placement; the position of a start
    # coordinate in its range is the corresponding output index.
    row_starts = range(0, padded_acts_height - weights_ky + 1, stride_height)
    col_starts = range(0, padded_acts_width - weights_kx + 1, stride_width)

    for out_ch in range(weights_out_ch):
        # Invariant across all spatial positions of this output channel.
        curr_weights = weights[out_ch, :, :, :].reshape(num_weights_unrolled)
        curr_c1 = c1_coeff[out_ch]

        for out_h_idx, in_height in enumerate(row_starts):
            for out_w_idx, in_width in enumerate(col_starts):
                curr_acts = input_act_padded[:, in_height:(in_height + weights_ky),
                                             in_width:(in_width + weights_kx)].reshape(num_weights_unrolled)
                output[out_ch, out_h_idx, out_w_idx] = srs_int8_even(
                    c2_coeff_prime * np.dot(curr_acts, curr_weights) + curr_c1, shft_c2)

    output = np.reshape(output, (1, weights_out_ch, out_height, out_width))

    return output


# Function below is a fixed point representation of add qdq kernel
@onnx_op(op_type="qdq_add_cstm_int8", inputs=[PyCustomOpDef.dt_int8, PyCustomOpDef.dt_int8,
                                              PyCustomOpDef.dt_float, PyCustomOpDef.dt_int8,
                                              PyCustomOpDef.dt_float, PyCustomOpDef.dt_int8,
                                              PyCustomOpDef.dt_float, PyCustomOpDef.dt_int8],
         outputs=[PyCustomOpDef.dt_int8])
def qdq_add_cstm_int8(in0, in1, in0_dq_xscale, in0_dq_xzero_pt, in1_dq_xscale, in1_dq_xzero_pt, a_q_yscale,
                      a_q_yzero_pt):
    """Fixed-point QDQ element-wise Add for int8 tensors.

    Each zero-point-corrected input is multiplied by the integer form of
    ``input_scale / output_scale`` (from ``find_closest_shifted_int32``,
    defined elsewhere) and shifted right by its own shift; the two results
    and the output zero point are summed and handed to ``srs_int8_even``
    with a shift of 0.

    Args:
        in0, in1: Quantized operands (declared int8).
        in0_dq_xscale, in0_dq_xzero_pt: Scale / zero point of ``in0``.
        in1_dq_xscale, in1_dq_xzero_pt: Scale / zero point of ``in1``.
        a_q_yscale, a_q_yzero_pt: Output scale / zero point.

    Returns:
        Whatever ``srs_int8_even`` returns for the accumulated sum.
    """
    print('Addint8')

    lhs = in0.astype(np.int64)
    rhs = in1.astype(np.int64)
    lhs_zp = in0_dq_xzero_pt.astype(np.int64)
    rhs_zp = in1_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)

    lhs_ratio = np.float32(in0_dq_xscale / a_q_yscale)
    rhs_ratio = np.float32(in1_dq_xscale / a_q_yscale)

    lhs_mult, lhs_shift = find_closest_shifted_int32(lhs_ratio)
    rhs_mult, rhs_shift = find_closest_shifted_int32(rhs_ratio)

    # NOTE(review): unlike the uint8 / uint16 Add variants, each term is
    # shifted down on its own before the sum, so low-order bits are dropped
    # per term rather than once on the total.
    lhs_scaled = ((lhs - lhs_zp) * np.int64(lhs_mult)) >> lhs_shift
    rhs_scaled = ((rhs - rhs_zp) * np.int64(rhs_mult)) >> rhs_shift

    return srs_int8_even(lhs_scaled + rhs_scaled + out_zp, np.int64(0))

# Function below is a fixed point representation of add qdq kernel
@onnx_op(op_type="qdq_add_cstm_uint16", inputs=[PyCustomOpDef.dt_uint16, PyCustomOpDef.dt_uint16,
                                                PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint16,
                                                PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint16,
                                                PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint16],
         outputs=[PyCustomOpDef.dt_uint16])
def qdq_add_cstm_uint16(in0, in1, in0_dq_xscale, in0_dq_xzero_pt, in1_dq_xscale, in1_dq_xzero_pt, a_q_yscale,
                        a_q_yzero_pt):
    """Fixed-point QDQ element-wise Add for uint16 tensors.

    Each zero-point-corrected input is multiplied by the integer form of
    ``input_scale / output_scale`` (from ``find_closest_shifted_int32``,
    defined elsewhere).  The term with the smaller shift is shifted left so
    both share the larger shift, the output zero point is added at that same
    scale, and the total is passed with the common shift to
    ``srs_uint16_even_fast``.

    Args:
        in0, in1: Quantized operands (declared uint16).
        in0_dq_xscale, in0_dq_xzero_pt: Scale / zero point of ``in0``.
        in1_dq_xscale, in1_dq_xzero_pt: Scale / zero point of ``in1``.
        a_q_yscale, a_q_yzero_pt: Output scale / zero point.

    Returns:
        Whatever ``srs_uint16_even_fast`` returns for the accumulated sum.
    """
    print('Adduint16')

    lhs = in0.astype(np.int64)
    rhs = in1.astype(np.int64)
    lhs_zp = in0_dq_xzero_pt.astype(np.int64)
    rhs_zp = in1_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)

    lhs_ratio = np.float32(in0_dq_xscale / a_q_yscale)
    rhs_ratio = np.float32(in1_dq_xscale / a_q_yscale)

    lhs_mult, lhs_shift = find_closest_shifted_int32(lhs_ratio)
    rhs_mult, rhs_shift = find_closest_shifted_int32(rhs_ratio)

    common_shift = max(lhs_shift, rhs_shift)

    lhs_term = (lhs - lhs_zp) * np.int64(lhs_mult)
    rhs_term = (rhs - rhs_zp) * np.int64(rhs_mult)

    # Align both products to the larger of the two shifts.
    if lhs_shift < rhs_shift:
        lhs_term = np.int64(lhs_term << np.int64(rhs_shift - lhs_shift))
    else:
        rhs_term = np.int64(rhs_term << np.int64(lhs_shift - rhs_shift))

    total = lhs_term + rhs_term + np.int64(out_zp << common_shift)
    return srs_uint16_even_fast(total, np.int64(common_shift))


@onnx_op(op_type="qdq_add_cstm_uint8", inputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8,
                                               PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8,
                                               PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8,
                                               PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8],
         outputs=[PyCustomOpDef.dt_uint8])
def qdq_add_cstm_uint8(in0, in1, in0_dq_xscale, in0_dq_xzero_pt, in1_dq_xscale, in1_dq_xzero_pt, a_q_yscale,
                       a_q_yzero_pt):
    """Fixed-point QDQ element-wise Add for uint8 tensors.

    Each zero-point-corrected input is multiplied by the integer form of
    ``input_scale / output_scale`` (from ``find_closest_shifted_int32``,
    defined elsewhere).  The term with the smaller shift is shifted left so
    both share the larger shift, the output zero point is added at that same
    scale, and the total is passed with the common shift to
    ``srs_uint8_even_fast``.

    Args:
        in0, in1: Quantized operands (declared uint8).
        in0_dq_xscale, in0_dq_xzero_pt: Scale / zero point of ``in0``.
        in1_dq_xscale, in1_dq_xzero_pt: Scale / zero point of ``in1``.
        a_q_yscale, a_q_yzero_pt: Output scale / zero point.

    Returns:
        Whatever ``srs_uint8_even_fast`` returns for the accumulated sum.
    """
    print('Adduint8')

    lhs = in0.astype(np.int64)
    rhs = in1.astype(np.int64)
    lhs_zp = in0_dq_xzero_pt.astype(np.int64)
    rhs_zp = in1_dq_xzero_pt.astype(np.int64)
    out_zp = a_q_yzero_pt.astype(np.int64)

    lhs_ratio = np.float32(in0_dq_xscale / a_q_yscale)
    rhs_ratio = np.float32(in1_dq_xscale / a_q_yscale)

    lhs_mult, lhs_shift = find_closest_shifted_int32(lhs_ratio)
    rhs_mult, rhs_shift = find_closest_shifted_int32(rhs_ratio)

    common_shift = max(lhs_shift, rhs_shift)

    lhs_term = (lhs - lhs_zp) * np.int64(lhs_mult)
    rhs_term = (rhs - rhs_zp) * np.int64(rhs_mult)

    # Align both products to the larger of the two shifts.
    if lhs_shift < rhs_shift:
        lhs_term = np.int64(lhs_term << np.int64(rhs_shift - lhs_shift))
    else:
        rhs_term = np.int64(rhs_term << np.int64(lhs_shift - rhs_shift))

    total = lhs_term + rhs_term + np.int64(out_zp << common_shift)
    return srs_uint8_even_fast(total, np.int64(common_shift))

#############################


# Function below is a fixed point representation of conv qdq kernel with non-zero valued weight zero point
# and non-zero valued activation zero point
@onnx_op(op_type="qdq_conv2d_weightsZPeq0_uint8_cstm",
         inputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_float,
                 PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_float,
                 PyCustomOpDef.dt_int32, PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8],
         outputs=[PyCustomOpDef.dt_uint8],
         attrs={'pads_h_beg': PyCustomOpDef.dt_int64, 'pads_w_beg': PyCustomOpDef.dt_int64,
                'pads_h_end': PyCustomOpDef.dt_int64, 'pads_w_end': PyCustomOpDef.dt_int64,
                'strides_h': PyCustomOpDef.dt_int64, 'strides_w': PyCustomOpDef.dt_int64})
def qdq_conv2d_weightsZPeq0_uint8_cstm(activations, weights, bias, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale,
                                       w_dq_xzero_pt, b_dq_xscale, b_dq_xzero_pt, a_q_yscale, a_q_yzero_pt, **kwargs):
    """Integer reference for a quantized 2-D convolution with uint8 activations and output.

    The convolution is evaluated entirely in integer arithmetic as
    ``c2 * sum((a - a_zp) * (w - w_zp)) + c2 * bias + (y_zp << shift)`` followed by
    ``srs_uint8_even_fast(..., shift)``, where ``c2, shift`` come from
    ``find_closest_shifted_int32(a_scale * w_scale / y_scale)``.
    NOTE(review): ``find_closest_shifted_int32`` and ``srs_uint8_even_fast`` are defined
    elsewhere; presumably they yield an integer multiplier/shift pair and a rounding
    right-shift with uint8 saturation -- confirm against their definitions.

    Args:
        activations: 4-D input indexed as (batch, channel, height, width); it is reshaped to
            (channel, height, width), so only a batch of one is supported.
        weights: 4-D kernel indexed as (out_ch, in_ch, ky, kx).
        bias: per-output-channel integer bias, added after multiplication by ``c2``.
        a_dq_xscale, a_dq_xzero_pt: activation scale / zero point.
        w_dq_xscale, w_dq_xzero_pt: weight scale / zero point.
        b_dq_xscale, b_dq_xzero_pt: accepted for signature compatibility; not used here.
        a_q_yscale, a_q_yzero_pt: output scale / zero point.
        **kwargs: ``pads_h_beg``, ``pads_h_end``, ``pads_w_beg``, ``pads_w_end``,
            ``strides_h``, ``strides_w``.

    Returns:
        Array of shape (1, out_ch, out_height, out_width) produced by ``srs_uint8_even_fast``.
    """
    print('Convint8')
    pad_height_begin = kwargs.get('pads_h_beg')
    pad_height_end = kwargs.get('pads_h_end')

    pad_width_begin = kwargs.get('pads_w_beg')
    pad_width_end = kwargs.get('pads_w_end')

    stride_height = kwargs.get('strides_h')
    stride_width = kwargs.get('strides_w')

    weights_shape = np.shape(weights)
    weights_out_ch = int(weights_shape[0])
    weights_in_ch = int(weights_shape[1])
    weights_ky = int(weights_shape[2])
    weights_kx = int(weights_shape[3])

    acts_shape = np.shape(activations)
    acts_in_ch = int(acts_shape[1])
    acts_height = int(acts_shape[2])
    acts_width = int(acts_shape[3])

    # Drop the batch axis; this reshape only succeeds for a batch of one.
    activations = np.reshape(activations, (acts_in_ch, acts_height, acts_width))

    padded_acts_height = int(acts_height + pad_height_begin + pad_height_end)
    padded_acts_width = int(acts_width + pad_width_begin + pad_width_end)

    activations = activations.astype(np.int32)
    weights = weights.astype(np.int32)
    bias = bias.astype(np.int64)
    # int32 copy is used to fill the padded activation buffer; the int64 copy is used for the
    # coefficient math so that zero_point * multiplier cannot wrap around in 32 bits.
    act_zp_i32 = a_dq_xzero_pt.astype(np.int32)
    act_zp_i64 = a_dq_xzero_pt.astype(np.int64)
    w_dq_xzero_pt = w_dq_xzero_pt.astype(np.int64)
    a_q_yzero_pt = a_q_yzero_pt.astype(np.int64)

    # Padding is filled with the activation zero point, i.e. a dequantized value of 0.
    input_act_padded = np.ones((acts_in_ch, padded_acts_height, padded_acts_width), dtype=np.int32) * act_zp_i32
    input_act_padded[:, pad_height_begin:pad_height_begin + acts_height,
                     pad_width_begin:pad_width_begin + acts_width] = activations

    out_height = int((padded_acts_height - weights_ky) / stride_height + 1)
    out_width = int((padded_acts_width - weights_kx) / stride_width + 1)

    c2_coeff = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    c2_coeff_prime, shft_c2 = find_closest_shifted_int32(c2_coeff)
    c2_coeff_prime = np.int64(c2_coeff_prime)

    # c1_coeff depends only on constants (weights, bias, zero points), so it can be
    # computed at compile time.
    c1_coeff = (-act_zp_i64) * c2_coeff_prime * np.sum(weights, axis=(1, 2, 3), dtype=np.int64) \
        + c2_coeff_prime * bias + (a_q_yzero_pt << shft_c2)
    c1_coeff = np.array(c1_coeff, dtype=np.int64)

    num_weights_unrolled = np.int64(weights_ky * weights_kx * weights_in_ch)

    # Weight zero-point correction: -c2 * w_zp * (sum(a) - a_zp * K).
    c3_coeff_offset = np.int64(-act_zp_i64 * num_weights_unrolled)
    c3_coeff_scale = np.int64(-c2_coeff_prime * w_dq_xzero_pt)
    curr_weights = weights.reshape(weights_out_ch, num_weights_unrolled)

    if weights_in_ch > 1:
        # im2col: one column per output pixel. The row count must be in_ch * ky * kx
        # (it was in_ch * ky * ky before, which broke non-square kernels).
        curr_acts = np.zeros((num_weights_unrolled, out_height * out_width), dtype=np.int32)
        for in_height in range(0, padded_acts_height - weights_ky + 1, stride_height):
            out_h_idx = in_height // stride_height
            for in_width in range(0, padded_acts_width - weights_kx + 1, stride_width):
                out_w_idx = in_width // stride_width
                window = input_act_padded[:, in_height:(in_height + weights_ky),
                                          in_width:(in_width + weights_kx)]
                curr_acts[:, out_h_idx * out_width + out_w_idx] = window.reshape(num_weights_unrolled)

        temp_out = c2_coeff_prime * (np.matmul(curr_weights, curr_acts).astype(np.int64))
        temp_out += c3_coeff_scale * (np.sum(curr_acts, axis=0) + c3_coeff_offset)
    else:  # dw conv, revisit to pass group through arg
        # Output channel i reads input channel i.
        temp_out = np.zeros((weights_out_ch, out_height, out_width), dtype=np.int64)
        for out_ch_idx in range(0, weights_out_ch):
            for in_height in range(0, padded_acts_height - weights_ky + 1, stride_height):
                out_h_idx = in_height // stride_height
                for in_width in range(0, padded_acts_width - weights_kx + 1, stride_width):
                    out_w_idx = in_width // stride_width
                    curr_acts = input_act_padded[out_ch_idx, in_height:(in_height + weights_ky),
                                                 in_width:(in_width + weights_kx)].reshape(num_weights_unrolled)
                    temp_out[out_ch_idx][out_h_idx][out_w_idx] = c2_coeff_prime * np.dot(
                        curr_acts, curr_weights[out_ch_idx, :])
                    temp_out[out_ch_idx][out_h_idx][out_w_idx] += c3_coeff_scale * (
                        np.sum(curr_acts, axis=0) + c3_coeff_offset)

    output = np.reshape(temp_out, (weights_out_ch, out_height, out_width))

    # Add the per-output-channel constant term.
    for out_ch in range(weights_out_ch):
        output[out_ch, :, :] += c1_coeff[out_ch]

    output = srs_uint8_even_fast(output, shft_c2)

    return np.reshape(output, (1, weights_out_ch, out_height, out_width))


# Function below is a fixed point representation of conv qdq kernel with non-zero valued weight zero point
# and non-zero valued activation zero point
@onnx_op(op_type="qdq_conv2d_weightsZPeq0_uint16_cstm",
         inputs=[PyCustomOpDef.dt_uint16, PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_float,
                 PyCustomOpDef.dt_uint16, PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_float,
                 PyCustomOpDef.dt_int32, PyCustomOpDef.dt_float, PyCustomOpDef.dt_uint16],
         outputs=[PyCustomOpDef.dt_uint16],
         attrs={'pads_h_beg': PyCustomOpDef.dt_int64, 'pads_w_beg': PyCustomOpDef.dt_int64,
                'pads_h_end': PyCustomOpDef.dt_int64, 'pads_w_end': PyCustomOpDef.dt_int64,
                'strides_h': PyCustomOpDef.dt_int64, 'strides_w': PyCustomOpDef.dt_int64})
def qdq_conv2d_weightsZPeq0_uint16_cstm(activations, weights, bias, a_dq_xscale, a_dq_xzero_pt, w_dq_xscale,
                                        w_dq_xzero_pt, b_dq_xscale, b_dq_xzero_pt, a_q_yscale, a_q_yzero_pt, **kwargs):
    """Integer reference for a quantized 2-D convolution with uint16 activations and output.

    Same integer formulation as the uint8 variant, except that in the multi-input-channel
    path the raw accumulator is first reduced with ``srs_int32_even_fast(acc, matmul_shift)``
    and the multiplier is pre-shifted left by ``matmul_shift`` to compensate. The depthwise
    path does not apply ``matmul_shift``.
    NOTE(review): ``find_closest_shifted_int32``, ``srs_int32_even_fast`` and
    ``srs_uint16_even_fast`` are defined elsewhere; presumably a multiplier/shift search and
    rounding right-shifts with saturation -- confirm against their definitions.

    Args:
        activations: 4-D input indexed as (batch, channel, height, width); it is reshaped to
            (channel, height, width), so only a batch of one is supported.
        weights: 4-D kernel indexed as (out_ch, in_ch, ky, kx).
        bias: per-output-channel integer bias, added after multiplication by the multiplier.
        a_dq_xscale, a_dq_xzero_pt: activation scale / zero point.
        w_dq_xscale, w_dq_xzero_pt: weight scale / zero point.
        b_dq_xscale, b_dq_xzero_pt: accepted for signature compatibility; not used here.
        a_q_yscale, a_q_yzero_pt: output scale / zero point.
        **kwargs: ``pads_h_beg``, ``pads_h_end``, ``pads_w_beg``, ``pads_w_end``,
            ``strides_h``, ``strides_w``.

    Returns:
        Array of shape (1, out_ch, out_height, out_width) produced by ``srs_uint16_even_fast``.
    """
    print('Convint16')
    pad_height_begin = kwargs.get('pads_h_beg')
    pad_height_end = kwargs.get('pads_h_end')

    pad_width_begin = kwargs.get('pads_w_beg')
    pad_width_end = kwargs.get('pads_w_end')

    stride_height = kwargs.get('strides_h')
    stride_width = kwargs.get('strides_w')

    weights_shape = np.shape(weights)
    weights_out_ch = int(weights_shape[0])
    weights_in_ch = int(weights_shape[1])
    weights_ky = int(weights_shape[2])
    weights_kx = int(weights_shape[3])

    # Pre-shift applied to the accumulator, chosen from the ORIGINAL operand dtypes
    # (this must run before the int64 casts below).
    log2_taps = np.ceil(np.log2(weights_in_ch * weights_ky * weights_kx))
    if (activations.dtype == 'uint8') and (weights.dtype == 'uint16'):
        matmul_shift = np.int64(min(max(25 + log2_taps - 33, 0), 7))
    elif (activations.dtype == 'uint16') and (weights.dtype == 'uint16'):
        matmul_shift = np.int64(min(max(33 + log2_taps - 33, 0), 15))
    elif (activations.dtype == 'uint16') and (weights.dtype == 'uint8'):
        matmul_shift = np.int64(min(max(25 + log2_taps - 33, 0), 7))
    else:
        matmul_shift = 0

    print(matmul_shift)

    acts_shape = np.shape(activations)
    acts_in_ch = int(acts_shape[1])
    acts_height = int(acts_shape[2])
    acts_width = int(acts_shape[3])

    # Drop the batch axis; this reshape only succeeds for a batch of one.
    activations = np.reshape(activations, (acts_in_ch, acts_height, acts_width))

    padded_acts_height = int(acts_height + pad_height_begin + pad_height_end)
    padded_acts_width = int(acts_width + pad_width_begin + pad_width_end)

    activations = activations.astype(np.int64)
    weights = weights.astype(np.int64)
    bias = bias.astype(np.int64)
    a_dq_xzero_pt = a_dq_xzero_pt.astype(np.int64)
    w_dq_xzero_pt = w_dq_xzero_pt.astype(np.int64)
    a_q_yzero_pt = a_q_yzero_pt.astype(np.int64)

    # Padding is filled with the activation zero point, i.e. a dequantized value of 0.
    input_act_padded = np.ones((acts_in_ch, padded_acts_height, padded_acts_width), dtype=np.int64) * a_dq_xzero_pt
    input_act_padded[:, pad_height_begin:pad_height_begin + acts_height,
                     pad_width_begin:pad_width_begin + acts_width] = activations

    out_height = int((padded_acts_height - weights_ky) / stride_height + 1)
    out_width = int((padded_acts_width - weights_kx) / stride_width + 1)

    c2_coeff = float((a_dq_xscale * w_dq_xscale) / a_q_yscale)
    c2_coeff_prime, shft_c2 = find_closest_shifted_int32(c2_coeff)
    c2_coeff_prime = np.int64(c2_coeff_prime)

    # c1_coeff depends only on constants (weights, bias, zero points), so it can be
    # computed at compile time.
    c1_coeff = (-a_dq_xzero_pt) * c2_coeff_prime * np.sum(weights, axis=(1, 2, 3), dtype=np.int64) \
        + c2_coeff_prime * bias + (a_q_yzero_pt << shft_c2)
    c1_coeff = np.array(c1_coeff, dtype=np.int64)

    num_weights_unrolled = np.int64(weights_ky * weights_kx * weights_in_ch)

    # Weight zero-point correction: -c2 * w_zp * (sum(a) - a_zp * K).
    c3_coeff_offset = np.int64(-a_dq_xzero_pt * num_weights_unrolled)
    c3_coeff_scale = np.int64(-c2_coeff_prime * w_dq_xzero_pt)
    curr_weights = weights.reshape(weights_out_ch, num_weights_unrolled)

    if weights_in_ch > 1:
        # im2col: one column per output pixel. The row count must be in_ch * ky * kx
        # (it was in_ch * ky * ky before, which broke non-square kernels).
        curr_acts = np.zeros((num_weights_unrolled, out_height * out_width), dtype=np.int64)
        for in_height in range(0, padded_acts_height - weights_ky + 1, stride_height):
            out_h_idx = in_height // stride_height
            for in_width in range(0, padded_acts_width - weights_kx + 1, stride_width):
                out_w_idx = in_width // stride_width
                window = input_act_padded[:, in_height:(in_height + weights_ky),
                                          in_width:(in_width + weights_kx)]
                curr_acts[:, out_h_idx * out_width + out_w_idx] = window.reshape(num_weights_unrolled)

        # Compute the accumulator once; it feeds both the debug print and the real path.
        raw_acc = np.matmul(curr_weights, curr_acts)
        print(np.max(np.abs(raw_acc)))
        int32_matmul = srs_int32_even_fast(raw_acc.astype(np.int64), matmul_shift)
        temp_out = (c2_coeff_prime << matmul_shift) * int32_matmul.astype(np.int64)
        temp_out += c3_coeff_scale * (np.sum(curr_acts, axis=0) + c3_coeff_offset)

    else:  # dw conv, revisit to pass group through arg
        # Output channel i reads input channel i.
        temp_out = np.zeros((weights_out_ch, out_height, out_width), dtype=np.int64)
        for out_ch_idx in range(0, weights_out_ch):
            for in_height in range(0, padded_acts_height - weights_ky + 1, stride_height):
                out_h_idx = in_height // stride_height
                for in_width in range(0, padded_acts_width - weights_kx + 1, stride_width):
                    out_w_idx = in_width // stride_width
                    curr_acts = input_act_padded[out_ch_idx, in_height:(in_height + weights_ky),
                                                 in_width:(in_width + weights_kx)].reshape(num_weights_unrolled)
                    temp_out[out_ch_idx][out_h_idx][out_w_idx] = c2_coeff_prime * np.dot(
                        curr_acts, curr_weights[out_ch_idx, :])
                    temp_out[out_ch_idx][out_h_idx][out_w_idx] += c3_coeff_scale * (
                        np.sum(curr_acts, axis=0) + c3_coeff_offset)

    output = np.reshape(temp_out, (weights_out_ch, out_height, out_width))

    # Add the per-output-channel constant term.
    for out_ch in range(weights_out_ch):
        output[out_ch, :, :] += c1_coeff[out_ch]

    output = srs_uint16_even_fast(output, shft_c2)

    return np.reshape(output, (1, weights_out_ch, out_height, out_width))


@onnx_op(
    op_type="qdq_gelu_uint16_cstm",
    inputs=[
        PyCustomOpDef.dt_uint16,  # activations
        PyCustomOpDef.dt_float,   # a_dq_xscale
        PyCustomOpDef.dt_uint16,  # a_dq_xzero_pt
        PyCustomOpDef.dt_float,   # a_q_yscale
        PyCustomOpDef.dt_uint16,  # a_q_yzero_pt
    ],
    outputs=[PyCustomOpDef.dt_uint16],
)
def qdq_gelu_uint16_cstm(activations, a_dq_xscale, a_dq_xzero_pt, a_q_yscale, a_q_yzero_pt):
    """Custom op: dequantize, apply the bfloat16 LUT GELU (range +/-6, 384 segments), requantize.

    The three stages are delegated to ``test_dequant_linear``,
    ``gelu_lut_approx_bfloat16_6_384`` and ``test_quant_linear_uint16`` respectively.
    """
    print('gelu_uint16')
    real_valued = test_dequant_linear(activations, a_dq_xscale, a_dq_xzero_pt)
    activated = gelu_lut_approx_bfloat16_6_384(real_valued)
    return test_quant_linear_uint16(activated, a_q_yscale, a_q_yzero_pt)


@onnx_op(
    op_type="instance_norm_int16_cstm",
    inputs=[
        PyCustomOpDef.dt_int16, PyCustomOpDef.dt_int8, PyCustomOpDef.dt_int32,  # x, gamma, bias
        PyCustomOpDef.dt_float, PyCustomOpDef.dt_int16,                         # activation scale / zp
        PyCustomOpDef.dt_float, PyCustomOpDef.dt_int8,                          # gamma scale / zp
        PyCustomOpDef.dt_float, PyCustomOpDef.dt_int32,                         # bias scale / zp
        PyCustomOpDef.dt_float, PyCustomOpDef.dt_int16,                         # output scale / zp
    ],
    outputs=[PyCustomOpDef.dt_int16],
    attrs={'epsilon': PyCustomOpDef.dt_float, 'is_relu': PyCustomOpDef.dt_int64},
)
def instance_norm_int16_cstm(x, gamma, bias, scale_activations, a_zp, scale_gamma, gamma_zp,
                             scale_bias, bias_zp, a_q_yscale, a_q_yzero_pt, **kwargs):
    """Custom op: instance normalization with int16 output.

    Reads the ``is_relu`` attribute (coerced to bool) and the ``epsilon`` attribute
    (coerced to float32), then forwards everything to
    ``_instancenorm_bf16_act_kernelref_int16_change``.
    """
    apply_relu = bool(kwargs.get('is_relu'))
    eps = np.float32(kwargs.get('epsilon'))

    return _instancenorm_bf16_act_kernelref_int16_change(
        apply_relu, x, gamma, bias, eps,
        scale_activations, a_zp,
        scale_gamma, gamma_zp,
        scale_bias, bias_zp,
        a_q_yscale, a_q_yzero_pt)

@onnx_op(
    op_type="instance_norm_int8_cstm",
    inputs=[
        PyCustomOpDef.dt_int16, PyCustomOpDef.dt_int8, PyCustomOpDef.dt_int32,  # x, gamma, bias
        PyCustomOpDef.dt_float, PyCustomOpDef.dt_int16,                         # activation scale / zp
        PyCustomOpDef.dt_float, PyCustomOpDef.dt_int8,                          # gamma scale / zp
        PyCustomOpDef.dt_float, PyCustomOpDef.dt_int32,                         # bias scale / zp
        PyCustomOpDef.dt_float, PyCustomOpDef.dt_int8,                          # output scale / zp
    ],
    attrs={'epsilon': PyCustomOpDef.dt_float, 'is_relu': PyCustomOpDef.dt_int64},
    outputs=[PyCustomOpDef.dt_int8],
)
def instance_norm_int8_cstm(x, gamma, bias, scale_activations, a_zp, scale_gamma, gamma_zp,
                             scale_bias, bias_zp, a_q_yscale, a_q_yzero_pt, **kwargs):
    """Custom op: instance normalization with int8 output.

    The ``epsilon`` and ``is_relu`` attributes are passed through unchanged to
    ``_instancenorm_bf16_act_kernelref_int8_change``.
    """
    eps = kwargs.get('epsilon')
    relu_flag = kwargs.get('is_relu')
    print('instnorm')
    return _instancenorm_bf16_act_kernelref_int8_change(
        relu_flag, x, gamma, bias, eps,
        scale_activations, a_zp,
        scale_gamma, gamma_zp,
        scale_bias, bias_zp,
        a_q_yscale, a_q_yzero_pt)


def gelu_lut_approx_fp32_6_384(val):
    """Piecewise-linear GELU over [-6, 6] using 384 segments with float32 slope/offset tables.

    Inputs above 6 are returned unchanged and inputs below -6 yield 0; inside the range the
    result is ``slope[k] * val + offset[k]`` for the segment ``k`` containing ``val``.
    The tables are rebuilt from the exact erf-based GELU on every call.
    """
    lo, hi, n_seg = -6, 6, 384
    raw = val
    clipped = np.clip(raw, lo, hi)

    # Exact GELU sampled at the n_seg + 1 segment boundaries.
    knots = np.linspace(lo, hi, num=n_seg + 1)
    gelu_at_knots = knots * 0.5 * (1 + vectorized_erf(knots / (2 ** 0.5)))

    seg_width = (hi - lo) / n_seg
    slopes = (np.diff(gelu_at_knots) / seg_width).astype(np.float32)
    offsets = (gelu_at_knots[:-1] - slopes * knots[:-1]).astype(np.float32)

    # Segment index of each (clipped) input, kept inside [0, n_seg - 1].
    seg = np.floor((clipped - lo) / seg_width)
    seg = np.int64(np.minimum(np.maximum(seg, 0), n_seg - 1))

    passthrough = raw * (raw > hi)
    in_range = (raw <= hi) * (raw >= lo)
    return passthrough + in_range * (slopes[seg] * clipped + offsets[seg])


def gelu_lut_approx_bfloat16_6_384(val):
    """Piecewise-linear GELU over [-6, 6] using 384 segments with bfloat16 slope/offset tables.

    Inputs above 6 are returned unchanged and inputs below -6 yield 0; inside the range the
    result is ``slope[k] * val + offset[k]`` for the segment ``k`` containing ``val``.
    The tables are rebuilt from the exact erf-based GELU on every call.
    """
    lo, hi, n_seg = -6, 6, 384
    raw = val
    clipped = np.clip(raw, lo, hi)

    # Exact GELU sampled at the n_seg + 1 segment boundaries.
    knots = np.linspace(lo, hi, num=n_seg + 1)
    gelu_at_knots = knots * 0.5 * (1 + vectorized_erf(knots / (2 ** 0.5)))

    seg_width = (hi - lo) / n_seg
    # Offsets are derived from the already-rounded bfloat16 slopes.
    slopes = bfloat16(np.diff(gelu_at_knots) / seg_width)
    offsets = bfloat16(gelu_at_knots[:-1] - slopes * knots[:-1])

    # Segment index of each (clipped) input, kept inside [0, n_seg - 1].
    seg = np.floor((clipped - lo) / seg_width)
    seg = np.int64(np.minimum(np.maximum(seg, 0), n_seg - 1))

    passthrough = raw * (raw > hi)
    in_range = (raw <= hi) * (raw >= lo)
    return passthrough + in_range * (slopes[seg] * clipped + offsets[seg])


def gelu_lut_approx_fp32_3_192(val):
    """Piecewise-linear GELU over [-3, 3] using 192 segments.

    Inputs above 3 are returned unchanged and inputs below -3 yield 0; inside the range the
    result is ``slope[k] * val + offset[k]`` for the segment ``k`` containing ``val``.
    Unlike the 6/384 fp32 variant, the slope/offset tables are not cast after computation.
    """
    lo, hi, n_seg = -3, 3, 192
    raw = val
    clipped = np.clip(raw, lo, hi)

    # Exact GELU sampled at the n_seg + 1 segment boundaries.
    knots = np.linspace(lo, hi, num=n_seg + 1)
    gelu_at_knots = knots * 0.5 * (1 + vectorized_erf(knots / (2 ** 0.5)))

    seg_width = (hi - lo) / n_seg
    slopes = np.diff(gelu_at_knots) / seg_width
    offsets = gelu_at_knots[:-1] - slopes * knots[:-1]

    # Segment index of each (clipped) input, kept inside [0, n_seg - 1].
    seg = np.floor((clipped - lo) / seg_width)
    seg = np.int64(np.minimum(np.maximum(seg, 0), n_seg - 1))

    passthrough = raw * (raw > hi)
    in_range = (raw <= hi) * (raw >= lo)
    return passthrough + in_range * (slopes[seg] * clipped + offsets[seg])


def gelu_lut_approx_bfloat16_3_192(val):
    """Piecewise-linear GELU over [-3, 3] using 192 segments with bfloat16 slope/offset tables.

    Inputs above 3 are returned unchanged and inputs below -3 yield 0; inside the range the
    result is ``slope[k] * val + offset[k]`` for the segment ``k`` containing ``val``.
    The tables are rebuilt from the exact erf-based GELU on every call.
    """
    lo, hi, n_seg = -3, 3, 192
    raw = val
    clipped = np.clip(raw, lo, hi)

    # Exact GELU sampled at the n_seg + 1 segment boundaries.
    knots = np.linspace(lo, hi, num=n_seg + 1)
    gelu_at_knots = knots * 0.5 * (1 + vectorized_erf(knots / (2 ** 0.5)))

    seg_width = (hi - lo) / n_seg
    # Offsets are derived from the already-rounded bfloat16 slopes.
    slopes = bfloat16(np.diff(gelu_at_knots) / seg_width)
    offsets = bfloat16(gelu_at_knots[:-1] - slopes * knots[:-1])

    # Segment index of each (clipped) input, kept inside [0, n_seg - 1].
    seg = np.floor((clipped - lo) / seg_width)
    seg = np.int64(np.minimum(np.maximum(seg, 0), n_seg - 1))

    passthrough = raw * (raw > hi)
    in_range = (raw <= hi) * (raw >= lo)
    return passthrough + in_range * (slopes[seg] * clipped + offsets[seg])


def _instancenorm_int64_sum_sq(is_relu, x, gamma, bias, epsilon, scale_activations, a_zp, scale_gamma, gamma_zp,
                               scale_bias, bias_zp, a_q_yscale, a_q_yzero_pt):
    """Instance norm on quantized input using integer sums (int64 sum of squares), int16 output.

    With ``x`` centered on its zero point and ``N`` the number of elements per
    (batch, channel) slice (product of all axes from 2 onward), the result is::

        gamma_f * (N*x - sum(x)) / sqrt(N*sum(x^2) - sum(x)^2 + N^2*eps/scale_act^2) + bias_f

    where ``gamma_f = (gamma - gamma_zp) * scale_gamma`` and
    ``bias_f = (bias - bias_zp) * scale_bias``; an optional ReLU follows and the value is
    quantized by ``test_quant_linear_int16`` (defined elsewhere in this file).

    Args:
        is_relu: when truthy, negative results are zeroed before quantization.
        x: quantized activations; statistics are taken over axes 2 and up.
        gamma, bias: quantized per-channel scale and shift.
        epsilon: variance stabilizer, expressed in the dequantized domain.
        scale_activations, a_zp: activation scale / zero point.
        scale_gamma, gamma_zp: gamma scale / zero point.
        scale_bias, bias_zp: bias scale / zero point.
        a_q_yscale, a_q_yzero_pt: output scale / zero point.
    """
    dims_x = len(x.shape)
    axis = tuple(range(2, dims_x))
    axis_indx = np.array(range(2, dims_x), dtype=np.int32)
    num_mean_var_samples = np.prod((np.array(np.shape(x)))[axis_indx], dtype=np.int32)
    N = num_mean_var_samples

    # Widen BEFORE removing the zero point: subtracting in the input's narrow integer dtype
    # (e.g. int16 - int16) silently wraps around for large |x - a_zp|.
    x = x.astype(np.int32)
    x = x - a_zp
    sum_x = np.sum(x.astype(np.int32), axis=axis, keepdims=True, dtype=np.int32)

    sum_x_sq = np.sum(x.astype(np.int64) * x.astype(np.int64), axis=axis, keepdims=True, dtype=np.int64)

    # Shape gamma/bias as (C, 1, ..., 1) so they broadcast over the normalized axes.
    dim_ones = (1,) * (dims_x - 2)
    gamma = gamma.reshape(-1, *dim_ones)
    gamma = gamma.astype(np.int32)
    gamma -= gamma_zp

    bias = bias.astype(np.int32)
    bias -= bias_zp
    bias = bias.reshape(-1, *dim_ones)

    # Numerator and denominator are both scaled by N so that no division is needed
    # before the final ratio.
    numerator = (N.astype(np.int32) * x.astype(np.int32) - sum_x.astype(np.int32)).astype(np.int32)  # int32

    denom_in_sqrt_no_eps = ((N.astype(np.int64) * sum_x_sq.astype(np.int64)) - (
            sum_x.astype(np.int64) * sum_x.astype(np.int64))).astype(np.int64)

    # epsilon is rescaled by N^2 / scale_act^2 to match the integer-domain variance term.
    denom = np.sqrt(
        denom_in_sqrt_no_eps.astype(np.float32) + ((N ** 2).astype(np.float32) * epsilon /
                                                   (scale_activations ** 2).astype(np.float32))).astype(np.float32)

    pre_relu = ((gamma.astype(np.float32) * scale_gamma).astype(np.float32) * (numerator / denom).astype(np.float32) +
                bias.astype(np.float32) * scale_bias.astype(np.float32)).astype(np.float32)
    if is_relu:
        post_relu = pre_relu * (pre_relu >= 0)
    else:
        post_relu = pre_relu

    return test_quant_linear_int16(post_relu, a_q_yscale, a_q_yzero_pt)


def _instancenorm_bf16_act_int16(is_relu, x, gamma, bias, epsilon, scale_activations, a_zp, scale_gamma, gamma_zp, scale_bias,
                           bias_zp, a_q_yscale, a_q_yzero_pt):
    """Instance norm where the zero-point-centered input is rounded to bfloat16 first; int16 output.

    Sums and sums of squares over axes 2 and up are accumulated in float32, the normalized
    value is formed as ``(N*x - sum_x) / sqrt(N*sum_x_sq - sum_x^2 + N^2*eps/scale_act^2)``,
    scaled and shifted by the dequantized gamma/bias, optionally passed through ReLU, and
    quantized by ``test_quant_linear_int16`` (defined elsewhere in this file).
    """
    rank = len(x.shape)
    reduce_axes = tuple(range(2, rank))
    reduce_idx = np.array(range(2, rank), dtype=np.int32)
    N = np.prod((np.array(np.shape(x)))[reduce_idx], dtype=np.int32)

    # Center on the zero point in int32, then round the activations to bfloat16.
    centered = bfloat16(x.astype(np.int32) - a_zp)
    centered_f32 = centered.astype(np.float32)

    sum_x = np.sum(centered, axis=reduce_axes, keepdims=True, dtype=np.float32)
    sum_x_sq = np.sum(centered_f32 * centered_f32, axis=reduce_axes, keepdims=True, dtype=np.float32)

    # gamma / bias as zero-point-free int32 columns of shape (C, 1, ..., 1).
    broadcast_tail = (1,) * (rank - 2)
    gamma_q = gamma.reshape(-1, *broadcast_tail).astype(np.int32)
    gamma_q -= gamma_zp
    bias_q = bias.reshape(-1, *broadcast_tail).astype(np.int32)
    bias_q -= bias_zp

    numerator = (N.astype(np.int32) * centered_f32 - sum_x).astype(np.float32)  # int32

    var_term = ((N.astype(np.int64) * sum_x_sq) - (sum_x * sum_x)).astype(np.float32)
    eps_term = (N ** 2).astype(np.float32) * epsilon / (scale_activations ** 2).astype(np.float32)
    denom = np.sqrt(var_term + eps_term).astype(np.float32)

    gamma_f = (gamma_q.astype(np.float32) * scale_gamma).astype(np.float32)
    normalized = (numerator / denom).astype(np.float32)
    bias_f = bias_q.astype(np.float32) * scale_bias.astype(np.float32)
    pre_relu = (gamma_f * normalized + bias_f).astype(np.float32)

    post_relu = pre_relu * (pre_relu >= 0) if is_relu else pre_relu

    return test_quant_linear_int16(post_relu, a_q_yscale, a_q_yzero_pt)

def _instancenorm_bf16_act_kernelref_int16(is_relu, x, gamma, bias, epsilon, scale_activations, a_zp, scale_gamma, gamma_zp, scale_bias,
                           bias_zp, a_q_yscale, a_q_yzero_pt):
    """Instance norm with integer statistics and bfloat16 normalization arithmetic; int16 output.

    Mean and variance over axes 2 and up come from int32 / int64 sums converted to float32.
    The normalization itself runs on bfloat16 operands, with gamma and bias pre-divided by
    the output scale. The result is rounded via ``test_quant_linear_int16`` (unit scale,
    zero offset), the output zero point is added in int32, an optional ReLU is applied, and
    the value is rounded once more with ``test_quant_linear_int16``.
    NOTE(review): ``test_dequant_linear`` / ``test_quant_linear_int16`` are defined elsewhere.
    """
    rank = len(x.shape)
    reduce_axes = tuple(range(2, rank))
    reduce_idx = np.array(range(2, rank), dtype=np.int32)
    sample_count = np.prod((np.array(np.shape(x)))[reduce_idx], dtype=np.int32)
    N = sample_count.astype(np.float32)
    inv_N = np.float32(1.0 / N)

    gamma = test_dequant_linear(gamma, scale_gamma, gamma_zp)
    bias = test_dequant_linear(bias, scale_bias, bias_zp)

    # Integer-domain statistics of the zero-point-centered input.
    centered = x.astype(np.int32) - a_zp
    centered_bf16 = bfloat16(centered)
    centered_i64 = centered.astype(np.int64)
    sum_x = np.sum(centered, axis=reduce_axes, keepdims=True, dtype=np.int32).astype(np.float32)
    sum_x_sq = np.sum(centered_i64 * centered_i64, axis=reduce_axes, keepdims=True,
                      dtype=np.int64).astype(np.float32)

    mean_f32 = (inv_N * sum_x).astype(np.float32)
    mean_bf16 = bfloat16(mean_f32)
    var_f32 = (inv_N * sum_x_sq - (mean_f32 * mean_f32)).astype(np.float32)

    # Fold the output scale into gamma / bias and shape them as (C, 1, ..., 1).
    broadcast_tail = (1,) * (rank - 2)
    gamma = gamma.reshape(-1, *broadcast_tail)
    bias = bias.reshape(-1, *broadcast_tail)
    scale_bf16 = bfloat16(gamma / a_q_yscale)
    bias_bf16 = bfloat16(bias / a_q_yscale)

    # epsilon is moved into the quantized domain (divided by scale_act^2) and rounded to bfloat16.
    eps_term = (bfloat16(epsilon / (scale_activations ** 2))).astype(np.float32)
    inv_sqrt_bf16 = bfloat16(1 / np.sqrt(var_f32 + eps_term))

    x_term = (centered_bf16 * inv_sqrt_bf16).astype(np.float32)
    mean_term = (mean_bf16 * inv_sqrt_bf16).astype(np.float32)
    pre_relu = bfloat16(scale_bf16 * (x_term - mean_term)) + bias_bf16.astype(np.float32)

    # Round to int16 first, then add the output zero point in int32.
    rounded = test_quant_linear_int16(pre_relu, np.float32(1.0), np.float32(0.0))
    pre_relu = (rounded.astype(np.int32) + a_q_yzero_pt).astype(np.float32)

    post_relu = pre_relu * (pre_relu >= 0) if is_relu else pre_relu
    return test_quant_linear_int16(post_relu, np.float32(1.0), np.float32(0.0))

def _instancenorm_bf16_act_kernelref_int16_change(is_relu, x, gamma, bias, epsilon, scale_activations, a_zp, scale_gamma, gamma_zp, scale_bias,
                           bias_zp, a_q_yscale, a_q_yzero_pt):
    """Instance norm computed from bfloat16 inputs with float32 accumulation; int16 output.

    The input and its zero point are converted to bfloat16 before subtraction, ``1/N`` is
    held in bfloat16, and the variance uses the bfloat16-rounded mean and sum of squares.
    The output zero point (as bfloat16) is added before the optional ReLU, and the result
    is rounded once with ``test_quant_linear_int16`` (unit scale, zero offset).
    NOTE(review): ``test_dequant_linear`` / ``test_quant_linear_int16`` are defined elsewhere.
    """
    rank = len(x.shape)
    reduce_axes = tuple(range(2, rank))
    reduce_idx = np.array(range(2, rank), dtype=np.int32)
    sample_count = np.prod((np.array(np.shape(x)))[reduce_idx], dtype=np.int32)
    N = sample_count.astype(np.float32)
    inv_N = bfloat16(1.0 / N)

    gamma = test_dequant_linear(gamma, scale_gamma, gamma_zp)
    bias = test_dequant_linear(bias, scale_bias, bias_zp)

    # Zero-point removal is done in bfloat16.
    centered_bf16 = bfloat16(bfloat16(x) - bfloat16(a_zp))
    sum_x = np.sum(centered_bf16, axis=reduce_axes, keepdims=True, dtype=np.float32)
    sum_x_sq = np.sum(centered_bf16 * centered_bf16, axis=reduce_axes, keepdims=True, dtype=np.float32)

    mean_f32 = (inv_N * sum_x).astype(np.float32)
    mean_bf16 = bfloat16(mean_f32)
    var_f32 = (inv_N * bfloat16(sum_x_sq) - (mean_bf16 * mean_bf16)).astype(np.float32)

    # Fold the output scale into gamma / bias and shape them as (C, 1, ..., 1).
    broadcast_tail = (1,) * (rank - 2)
    gamma = gamma.reshape(-1, *broadcast_tail)
    bias = bias.reshape(-1, *broadcast_tail)
    scale_bf16 = bfloat16(gamma / a_q_yscale)
    bias_bf16 = bfloat16(bias / a_q_yscale)

    # epsilon is moved into the quantized domain (divided by scale_act^2) and rounded to bfloat16.
    eps_term = (bfloat16(epsilon / (scale_activations ** 2))).astype(np.float32)
    inv_sqrt_bf16 = bfloat16(1 / np.sqrt(var_f32 + eps_term))

    x_term = (centered_bf16 * inv_sqrt_bf16).astype(np.float32)
    mean_term = (mean_bf16 * inv_sqrt_bf16).astype(np.float32)
    scaled = bfloat16(scale_bf16 * (x_term - mean_term))
    pre_relu = (scaled + bfloat16(a_q_yzero_pt)) + bias_bf16.astype(np.float32)

    post_relu = pre_relu * (pre_relu >= 0) if is_relu else pre_relu

    return test_quant_linear_int16(post_relu, np.float32(1.0), np.float32(0.0))

def _instancenorm_bf16_act_kernelref_int8(is_relu, x, gamma, bias, epsilon, scale_activations, a_zp, scale_gamma, gamma_zp, scale_bias,
                           bias_zp, a_q_yscale, a_q_yzero_pt):
    """Instance norm with integer statistics and bfloat16 normalization arithmetic; int8 output.

    Mean and variance over axes 2 and up come from int32 / int64 sums converted to float32.
    The normalization itself runs on bfloat16 operands, with gamma and bias pre-divided by
    the output scale. The intermediate result is rounded via ``test_quant_linear_int16``
    (unit scale, zero offset), the output zero point is added in int32, an optional ReLU is
    applied, and the final rounding uses ``test_quant_linear_int8``.
    NOTE(review): the ``test_*`` helpers are defined elsewhere in this file.
    """
    rank = len(x.shape)
    reduce_axes = tuple(range(2, rank))
    reduce_idx = np.array(range(2, rank), dtype=np.int32)
    sample_count = np.prod((np.array(np.shape(x)))[reduce_idx], dtype=np.int32)
    N = sample_count.astype(np.float32)
    inv_N = np.float32(1.0 / N)

    gamma = test_dequant_linear(gamma, scale_gamma, gamma_zp)
    bias = test_dequant_linear(bias, scale_bias, bias_zp)

    # Integer-domain statistics of the zero-point-centered input.
    centered = x.astype(np.int32) - a_zp
    centered_bf16 = bfloat16(centered)
    centered_i64 = centered.astype(np.int64)
    sum_x = np.sum(centered, axis=reduce_axes, keepdims=True, dtype=np.int32).astype(np.float32)
    sum_x_sq = np.sum(centered_i64 * centered_i64, axis=reduce_axes, keepdims=True,
                      dtype=np.int64).astype(np.float32)

    mean_f32 = (inv_N * sum_x).astype(np.float32)
    mean_bf16 = bfloat16(mean_f32)
    var_f32 = (inv_N * sum_x_sq - (mean_f32 * mean_f32)).astype(np.float32)

    # Fold the output scale into gamma / bias and shape them as (C, 1, ..., 1).
    broadcast_tail = (1,) * (rank - 2)
    gamma = gamma.reshape(-1, *broadcast_tail)
    bias = bias.reshape(-1, *broadcast_tail)
    scale_bf16 = bfloat16(gamma / a_q_yscale)
    bias_bf16 = bfloat16(bias / a_q_yscale)

    # epsilon is moved into the quantized domain (divided by scale_act^2) and rounded to bfloat16.
    eps_term = (bfloat16(epsilon / (scale_activations ** 2))).astype(np.float32)
    inv_sqrt_bf16 = bfloat16(1 / np.sqrt(var_f32 + eps_term))

    x_term = (centered_bf16 * inv_sqrt_bf16).astype(np.float32)
    mean_term = (mean_bf16 * inv_sqrt_bf16).astype(np.float32)
    pre_relu = bfloat16(scale_bf16 * (x_term - mean_term)) + bias_bf16.astype(np.float32)

    # Round through the int16 helper first, then add the output zero point in int32.
    rounded = test_quant_linear_int16(pre_relu, np.float32(1.0), np.float32(0.0))
    pre_relu = (rounded.astype(np.int32) + a_q_yzero_pt).astype(np.float32)

    post_relu = pre_relu * (pre_relu >= 0) if is_relu else pre_relu

    return test_quant_linear_int8(post_relu, np.float32(1.0), np.float32(0.0))

def _instancenorm_bf16_act_kernelref_int8_change(is_relu, x, gamma, bias, epsilon, scale_activations, a_zp, scale_gamma, gamma_zp, scale_bias,
                           bias_zp, a_q_yscale, a_q_yzero_pt):
    """Instance norm computed from bfloat16 inputs with float32 accumulation; int8 output.

    The input and its zero point are converted to bfloat16 before subtraction, ``1/N`` is
    held in bfloat16, and the variance uses the bfloat16-rounded mean and sum of squares.
    The output zero point (as bfloat16) is added before the optional ReLU, and the result
    is rounded once with ``test_quant_linear_int8`` (unit scale, zero offset).
    NOTE(review): ``test_dequant_linear`` / ``test_quant_linear_int8`` are defined elsewhere.
    """
    rank = len(x.shape)
    reduce_axes = tuple(range(2, rank))
    reduce_idx = np.array(range(2, rank), dtype=np.int32)
    sample_count = np.prod((np.array(np.shape(x)))[reduce_idx], dtype=np.int32)
    N = sample_count.astype(np.float32)
    inv_N = bfloat16(1.0 / N)

    gamma = test_dequant_linear(gamma, scale_gamma, gamma_zp)
    bias = test_dequant_linear(bias, scale_bias, bias_zp)

    # Zero-point removal is done in bfloat16.
    centered_bf16 = bfloat16(bfloat16(x) - bfloat16(a_zp))
    sum_x = np.sum(centered_bf16, axis=reduce_axes, keepdims=True, dtype=np.float32)
    sum_x_sq = np.sum(centered_bf16 * centered_bf16, axis=reduce_axes, keepdims=True, dtype=np.float32)

    mean_f32 = (inv_N * sum_x).astype(np.float32)
    mean_bf16 = bfloat16(mean_f32)
    var_f32 = (inv_N * bfloat16(sum_x_sq) - (mean_bf16 * mean_bf16)).astype(np.float32)

    # Fold the output scale into gamma / bias and shape them as (C, 1, ..., 1).
    broadcast_tail = (1,) * (rank - 2)
    gamma = gamma.reshape(-1, *broadcast_tail)
    bias = bias.reshape(-1, *broadcast_tail)
    scale_bf16 = bfloat16(gamma / a_q_yscale)
    bias_bf16 = bfloat16(bias / a_q_yscale)

    # epsilon is moved into the quantized domain (divided by scale_act^2) and rounded to bfloat16.
    eps_term = (bfloat16(epsilon / (scale_activations ** 2))).astype(np.float32)
    inv_sqrt_bf16 = bfloat16(1 / np.sqrt(var_f32 + eps_term))

    x_term = (centered_bf16 * inv_sqrt_bf16).astype(np.float32)
    mean_term = (mean_bf16 * inv_sqrt_bf16).astype(np.float32)
    scaled = bfloat16(scale_bf16 * (x_term - mean_term))
    pre_relu = (scaled + bfloat16(a_q_yzero_pt)) + bias_bf16.astype(np.float32)

    post_relu = pre_relu * (pre_relu >= 0) if is_relu else pre_relu

    return test_quant_linear_int8(post_relu, np.float32(1.0), np.float32(0.0))

def _instancenorm_bf16_act_int8(is_relu, x, gamma, bias, epsilon, scale_activations, a_zp, scale_gamma, gamma_zp, scale_bias,
                           bias_zp, a_q_yscale, a_q_yzero_pt):
    """Reference model of a quantized InstanceNorm (optionally fused with ReLU) with int8 output.

    The statistics are reduced over every axis from 2 onwards of ``x``. The
    normalisation is evaluated in the "multiplied by N" form

        (N * x - sum(x)) / sqrt(N * sum(x^2) - sum(x)^2 + N^2 * epsilon / scale_activations^2)

    which is algebraically ``(x - mean) / sqrt(var + epsilon / scale_activations^2)``
    on the zero-point-corrected integer activations.

    Args:
        is_relu: When truthy, negative pre-activation values are zeroed before quantization.
        x: Quantized activations (needs ``.astype`` / ``.shape``).
            NOTE(review): presumably laid out as (batch, channel, spatial...) - confirm with caller.
        gamma, bias: Quantized per-channel scale / offset; reshaped to (-1, 1, ..., 1).
        epsilon: InstanceNorm epsilon in real (de-quantized) units.
        scale_activations, a_zp: Quantization scale and zero point of ``x``.
        scale_gamma, gamma_zp: Quantization scale and zero point of ``gamma``.
        scale_bias, bias_zp: Quantization scale and zero point of ``bias``.
        a_q_yscale, a_q_yzero_pt: Output quantization scale and zero point.

    Returns:
        The result of ``test_quant_linear_int8`` applied to the (optionally rectified) output.
    """
    dims_x = len(x.shape)
    # Axes 2..rank-1 are the ones reduced for the statistics.
    axis = tuple(range(2, dims_x))
    axis_indx = np.array(range(2, dims_x), dtype=np.int32)
    # Number of samples contributing to each mean/variance (product of the reduced axis sizes).
    num_mean_var_samples = np.prod((np.array(np.shape(x)))[axis_indx], dtype=np.int32)
    N = num_mean_var_samples
    # Widen to int32 before removing the zero point, then convert to bfloat16 (ml_dtypes).
    x = x.astype(np.int32)
    x = x - a_zp
    x = bfloat16(x)
    # Both sums are accumulated in float32.
    sum_x = np.sum(x, axis=axis, keepdims=True, dtype=np.float32)

    sum_x_sq = np.sum(x.astype(np.float32) * x.astype(np.float32), axis=axis, keepdims=True, dtype=np.float32)
    # Shape gamma/bias as (-1, 1, ..., 1) so they broadcast along axis 1 of x.
    dim_ones = (1,) * (dims_x - 2)
    gamma = gamma.reshape(-1, *dim_ones)
    bias = bias.reshape(-1, *dim_ones)
    gamma = gamma.astype(np.int32)

    # Remove the zero points in the integer domain (in-place on the int32 copies).
    gamma -= gamma_zp

    bias = bias.astype(np.int32)

    bias -= bias_zp

    # NOTE(review): bias was already reshaped above; this second reshape looks redundant.
    bias = bias.reshape(-1, *dim_ones)

    # N * x - sum(x), i.e. N * (x - mean); evaluated and stored in float32.
    numerator = (N.astype(np.int32) * x.astype(np.float32) - sum_x).astype(np.float32)  # float32

    # N * sum(x^2) - sum(x)^2, i.e. N^2 * variance, before epsilon is added.
    denom_in_sqrt_no_eps = ((N.astype(np.int64) * sum_x_sq) - (sum_x * sum_x)).astype(np.float32)

    # epsilon is rescaled by N^2 / scale_activations^2 to match the scaled, quantized-domain variance.
    denom = np.sqrt(
        denom_in_sqrt_no_eps.astype(np.float32) + ((N ** 2).astype(np.float32) * epsilon /
                                                   (scale_activations ** 2).astype(np.float32))).astype(
        np.float32)

    # De-quantized gamma * normalised activation + de-quantized bias.
    pre_relu = ((gamma.astype(np.float32) * scale_gamma).astype(np.float32) * (numerator / denom).astype(np.float32) +
                bias.astype(np.float32) * scale_bias.astype(np.float32)).astype(np.float32)

    if is_relu:
        # Multiplying by the boolean mask zeroes the negative entries.
        post_relu = pre_relu * (pre_relu >= 0)
    else:
        post_relu = pre_relu

    return test_quant_linear_int8(post_relu, a_q_yscale, a_q_yzero_pt)


def _instancenorm_int64_sum_sq_int8(is_relu, x, gamma, bias, epsilon, scale_activations, a_zp, scale_gamma, gamma_zp,
                               scale_bias, bias_zp, a_q_yscale, a_q_yzero_pt):
  """Reference model of a quantized InstanceNorm (optionally fused with ReLU) using integer statistics.

  The sum of the activations is accumulated in int32 and the sum of squares in
  int64; the numerator ``N * x - sum(x)`` is int32 and the variance term
  ``N * sum(x^2) - sum(x)^2`` is int64. Only the square root, the division and
  the gamma/bias application are carried out in float32.

  Args:
      is_relu: When truthy, negative pre-activation values are zeroed before quantization.
      x: Quantized activations; axes 2 and above are reduced.
          NOTE(review): presumably laid out as (batch, channel, spatial...) - confirm with caller.
      gamma, bias: Quantized per-channel scale / offset; reshaped to (-1, 1, ..., 1).
      epsilon: InstanceNorm epsilon in real (de-quantized) units.
      scale_activations, a_zp: Quantization scale and zero point of ``x``.
      scale_gamma, gamma_zp: Quantization scale and zero point of ``gamma``.
      scale_bias, bias_zp: Quantization scale and zero point of ``bias``.
      a_q_yscale, a_q_yzero_pt: Output quantization scale and zero point.

  Returns:
      The result of ``test_quant_linear_int8`` applied to the (optionally rectified) output.
  """
  dims_x = len(x.shape)
  # Axes 2..rank-1 are the ones reduced for the statistics.
  axis = tuple(range(2, dims_x))
  axis_indx = np.array(range(2, dims_x), dtype=np.int32)
  # Number of samples contributing to each mean/variance (product of the reduced axis sizes).
  num_mean_var_samples = np.prod((np.array(np.shape(x)))[axis_indx], dtype=np.int32)
  N = num_mean_var_samples
  # NOTE(review): x is not widened before the subtraction here, so the result dtype
  # follows the dtypes of x and a_zp - confirm the caller cannot wrap around.
  x = x - a_zp
  sum_x = np.sum(x.astype(np.int32), axis=axis, keepdims=True, dtype=np.int32)

  # Squares are formed and accumulated in int64.
  sum_x_sq = np.sum(x.astype(np.int64) * x.astype(np.int64), axis=axis, keepdims=True, dtype=np.int64)
  # Shape gamma/bias as (-1, 1, ..., 1) so they broadcast along axis 1 of x.
  dim_ones = (1,) * (dims_x - 2)
  gamma = gamma.reshape(-1, *dim_ones)
  gamma = gamma.astype(np.int32)

  # Remove the zero points in the integer domain (in-place on the int32 copies).
  gamma -= gamma_zp

  bias = bias.astype(np.int32)

  bias -= bias_zp

  bias = bias.reshape(-1, *dim_ones)
  # N * x - sum(x), i.e. N * (x - mean), kept in int32.
  numerator = (N.astype(np.int32) * x.astype(np.int32) - sum_x.astype(np.int32)).astype(np.int32)  # int32

  # N * sum(x^2) - sum(x)^2, i.e. N^2 * variance, kept in int64.
  denom_in_sqrt_no_eps = ((N.astype(np.int64) * sum_x_sq.astype(np.int64)) - (
            sum_x.astype(np.int64) * sum_x.astype(np.int64))).astype(np.int64)


  # epsilon is rescaled by N^2 / scale_activations^2 to match the scaled, quantized-domain variance.
  denom = np.sqrt(denom_in_sqrt_no_eps.astype(np.float32) + ((N ** 2).astype(np.float32) * epsilon /
                                                             (scale_activations ** 2).astype(np.float32))).astype(np.float32)

  # De-quantized gamma * normalised activation + de-quantized bias.
  pre_relu = ((gamma.astype(np.float32) * scale_gamma).astype(np.float32) * (numerator / denom).astype(np.float32) +
              bias.astype(np.float32) * scale_bias.astype(np.float32)).astype(np.float32)
  if is_relu:
    # Multiplying by the boolean mask zeroes the negative entries.
    post_relu = pre_relu * (pre_relu >= 0)
  else:
    post_relu = pre_relu

  return test_quant_linear_int8(post_relu, a_q_yscale, a_q_yzero_pt)



def remove_node(model__, node_name):
    """Remove every node called ``node_name`` from ``model__.graph.node``.

    The matching nodes are collected first and removed afterwards: removing
    elements from the container while iterating over it makes the iterator
    skip the element that follows each removal, so adjacent nodes sharing the
    same name would be left behind.

    Args:
        model__: Object exposing ``graph.node`` (a container supporting
            iteration and ``remove``).
        node_name: Name of the node(s) to delete.

    Returns:
        The same ``model__`` object, modified in place.
    """
    doomed = [node for node in model__.graph.node if node.name == node_name]
    for node in doomed:
        model__.graph.node.remove(node)
    return model__


def change_node_input(model__, node_name, input_indx, input_name):
    """Rewire input slot ``input_indx`` of every node named ``node_name``.

    Args:
        model__: Object exposing ``graph.node``; modified in place.
        node_name: Name of the node(s) whose input is replaced.
        input_indx: Position inside ``node.input`` to overwrite.
        input_name: New tensor name stored at that position.

    Returns:
        The same ``model__`` object.
    """
    matches = (candidate for candidate in model__.graph.node if candidate.name == node_name)
    for target in matches:
        target.input[input_indx] = input_name
    return model__


def test_dequant_linear(inp, scale, zero_pt):
    inp = inp.astype(np.float32)
    zero_pt = zero_pt.astype(np.float32)

    return scale * (inp - zero_pt)


def test_quant_linear(inp, scale, zero_pt):
    inp = inp.astype(np.float32)
    zero_pt = zero_pt.astype(np.float32)
    res = inp / scale
    res = np.round(res)
    res = np.clip(res + zero_pt, -128, 127)

    return res.astype(np.int8)


# Reshape helper: append trailing singleton axes to `arr` until it has as many dimensions as `target`
def right_broadcasting(arr, target):
    """Return ``arr`` reshaped with trailing length-1 axes up to ``target.ndim`` dimensions.

    When ``target`` has no more dimensions than ``arr`` no axis is appended.
    """
    extra_axes = target.ndim - arr.ndim
    padded_shape = arr.shape + (1,) * extra_axes
    return arr.reshape(padded_shape)


def test_quant_linear_uint8(inp, scale, zero_pt):
    inp = inp.astype(np.float32)
    zero_pt = zero_pt.astype(np.float32)
    res = inp / scale
    res = np.round(res)
    res = np.clip(res + zero_pt, 0, 255)

    return res.astype(np.uint8)


def test_quant_linear_uint16(inp, scale, zero_pt):
    inp = inp.astype(np.float32)
    zero_pt = zero_pt.astype(np.float32)
    res = inp / scale
    res = np.round(res)
    res = np.clip(res + zero_pt, 0, 65535)

    return res.astype(np.uint16)


def test_quant_linear_int16(inp, scale, zero_pt):
    inp = inp.astype(np.float32)
    zero_pt = zero_pt.astype(np.float32)
    res = inp / scale
    res = np.round(res)
    res = np.clip(res + zero_pt, -32768, 32767)

    return res.astype(np.int16)

def test_quant_linear_int8(inp, scale, zero_pt):
  inp = inp.astype(np.float32)
  zero_pt = zero_pt.astype(np.float32)
  res = inp / scale
  res = np.round(res)
  res = np.clip(res + zero_pt, -128, 127)

  return res.astype(np.int8)
# SRS with round even
def srs_int8_even(inp, shift):
    """Scalar shift-round-saturate to int8 with round-half-to-even.

    The magnitude of ``inp`` is divided by ``2**shift``, exact ties are rounded
    to the nearest even integer, the sign is re-applied and the result is
    saturated to [-128, 127].

    Args:
        inp: Scalar integer to scale down.
        shift: Number of bits to shift right.

    Returns:
        The rounded, saturated value as ``np.int8``.
    """
    sign = np.sign(inp)
    magnitude = abs(inp)
    quotient = magnitude >> shift
    remainder = magnitude - (quotient << shift)

    # The leading fractional bit is clear when the remainder is below one half.
    if (remainder >> (shift - 1)) == 0:
        rounded = quotient
    elif remainder == (1 << shift - 1) and quotient % 2 == 0:
        # Exact tie on an even quotient: stay on the even value.
        rounded = quotient
    else:
        rounded = quotient + 1

    saturated = np.clip(sign * rounded, -128, 127)
    return saturated.astype(np.int8)


def srs_uint8_even(inp, shift):
    """Scalar shift-round-saturate to uint8 with round-half-to-even.

    The magnitude of ``inp`` is divided by ``2**shift``, exact ties are rounded
    to the nearest even integer, the sign is re-applied and the result is
    saturated to [0, 255].

    Args:
        inp: Scalar integer to scale down.
        shift: Number of bits to shift right.

    Returns:
        The rounded, saturated value as ``np.uint8``.
    """
    sign = np.sign(inp)
    magnitude = abs(inp)
    quotient = magnitude >> shift
    remainder = magnitude - (quotient << shift)

    # The leading fractional bit is clear when the remainder is below one half.
    if (remainder >> (shift - 1)) == 0:
        rounded = quotient
    elif remainder == (1 << shift - 1) and quotient % 2 == 0:
        # Exact tie on an even quotient: stay on the even value.
        rounded = quotient
    else:
        rounded = quotient + 1

    saturated = np.clip(sign * rounded, 0, 255)
    return saturated.astype(np.uint8)


def srs_uint8_even_fast(inp, shift):
    """Vectorized shift-round-saturate to uint8 with round-half-to-even.

    Element-wise equivalent of the scalar rule

        floor, frac = divmod(abs(inp), 2**shift)
        result = floor + 1  if frac > half, or frac == half and floor is odd
        result = floor      otherwise

    followed by re-applying the sign and saturating to [0, 255].

    The previous mask-based formulation stored ``inp_floor % 2`` under the name
    "even" although it is 1 for odd values, so exact ties were rounded towards
    the odd neighbour. Ties now go to the even neighbour, as the function name
    and the scalar reference implementation require.

    Args:
        inp: NumPy integer array (or NumPy integer scalar).
        shift: Non-negative scalar number of bits to shift right.

    Returns:
        ``np.uint8`` result with the shape of ``inp``.
    """
    if shift == 0:
        # No fractional bits to round; also avoids the negative shift count in `1 << (shift - 1)`.
        return np.clip(inp, 0, 255).astype(np.uint8)

    sign_inp = np.sign(inp)
    magnitude = abs(inp)
    inp_floor = magnitude >> shift
    inp_frac = magnitude - (inp_floor << shift)
    half = 1 << (shift - 1)

    floor_is_odd = (inp_floor % 2) == 1
    round_up = (inp_frac > half) | ((inp_frac == half) & floor_is_odd)

    round_res = inp_floor + round_up.astype(np.int64)
    round_res = np.clip(sign_inp * round_res, 0, 255)

    return round_res.astype(np.uint8)


def extract_instnorm_eps(node):
    """Return the ``epsilon`` attribute of an InstanceNormalization node as ``np.float32``.

    ONNX does not require default-valued attributes to be stored on the node,
    so a missing ``epsilon`` means the operator default (1e-5) applies. That
    default is returned instead of the implicit ``None`` the function used to
    fall through to.

    Args:
        node: Object exposing ``attribute``, an iterable of items with
            ``name`` and ``f`` fields.

    Returns:
        The epsilon value as ``np.float32``.
    """
    for attr in node.attribute:
        if attr.name == 'epsilon':
            return np.float32(attr.f)
    # Attribute omitted: fall back to the ONNX InstanceNormalization default.
    return np.float32(1e-5)


def srs_uint16_even_fast(inp, shift):
    """Vectorized shift-round-saturate to uint16 with round-half-to-even.

    Element-wise equivalent of the scalar rule

        floor, frac = divmod(abs(inp), 2**shift)
        result = floor + 1  if frac > half, or frac == half and floor is odd
        result = floor      otherwise

    followed by re-applying the sign and saturating to [0, 65535].

    The previous mask-based formulation stored ``inp_floor % 2`` under the name
    "even" although it is 1 for odd values, so exact ties were rounded towards
    the odd neighbour. Ties now go to the even neighbour, as the function name
    requires.

    Args:
        inp: NumPy integer array (or NumPy integer scalar).
        shift: Non-negative scalar number of bits to shift right.

    Returns:
        ``np.uint16`` result with the shape of ``inp``.
    """
    if shift == 0:
        # No fractional bits to round; also avoids the negative shift count in `1 << (shift - 1)`.
        return np.clip(inp, 0, 65535).astype(np.uint16)

    sign_inp = np.sign(inp)
    magnitude = abs(inp)
    inp_floor = magnitude >> shift
    inp_frac = magnitude - (inp_floor << shift)
    half = 1 << (shift - 1)

    floor_is_odd = (inp_floor % 2) == 1
    round_up = (inp_frac > half) | ((inp_frac == half) & floor_is_odd)

    round_res = inp_floor + round_up.astype(np.int64)
    round_res = np.clip(sign_inp * round_res, 0, 65535)

    return round_res.astype(np.uint16)


def srs_int32_even_fast(inp, shift):
    """Vectorized shift-round-saturate to int32 with round-half-to-even.

    Element-wise equivalent of the scalar rule

        floor, frac = divmod(abs(inp), 2**shift)
        result = floor + 1  if frac > half, or frac == half and floor is odd
        result = floor      otherwise

    followed by re-applying the sign and saturating to the int32 range.

    The previous mask-based formulation stored ``inp_floor % 2`` under the name
    "even" although it is 1 for odd values, so exact ties were rounded towards
    the odd neighbour. Ties now go to the even neighbour, as the function name
    requires.

    Args:
        inp: NumPy integer array (or NumPy integer scalar).
        shift: Non-negative scalar number of bits to shift right.

    Returns:
        ``np.int32`` result with the shape of ``inp``.
    """
    int32_min, int32_max = -2147483648, 2147483647

    if shift == 0:
        # No fractional bits to round: only saturate.
        return np.clip(inp, int32_min, int32_max).astype(np.int32)

    sign_inp = np.sign(inp)
    magnitude = abs(inp)
    inp_floor = magnitude >> shift
    inp_frac = magnitude - (inp_floor << shift)
    half = 1 << (shift - 1)

    floor_is_odd = (inp_floor % 2) == 1
    round_up = (inp_frac > half) | ((inp_frac == half) & floor_is_odd)

    round_res = inp_floor + round_up.astype(np.int64)
    round_res = np.clip(sign_inp * round_res, int32_min, int32_max)

    return round_res.astype(np.int32)


# Shift and round without saturation to model zero term after rounding
def sr_int64_even(inp, shift):
    """Scalar shift-and-round (round-half-to-even) without saturation.

    The magnitude of ``inp`` is divided by ``2**shift``, exact ties are rounded
    to the nearest even integer and the sign is re-applied. A message is
    printed whenever an exact tie is encountered.

    Args:
        inp: Scalar integer to scale down.
        shift: Number of bits to shift right.

    Returns:
        The rounded value as ``np.int64``.
    """
    sign = np.sign(inp)
    magnitude = abs(inp)
    quotient = magnitude >> shift
    remainder = magnitude - (quotient << shift)

    # The leading fractional bit is clear when the remainder is below one half.
    if (remainder >> (shift - 1)) == 0:
        rounded = quotient
    elif remainder != (1 << shift - 1):
        # Strictly above one half: round away from zero.
        rounded = quotient + 1
    else:
        print('Fractional part is 0.5')
        rounded = quotient if quotient % 2 == 0 else quotient + 1

    signed = sign * rounded
    return signed.astype(np.int64)


def find_closest_shifted_int16(float_val):
    """Approximate a positive value as ``integer / 2**shift`` with ``integer <= 32767``.

    ``float_val`` is doubled repeatedly while it stays within the int16 range;
    for every shift the nearest integer is taken and the candidate with the
    smallest relative error wins (the smallest shift wins on equal error).

    Args:
        float_val: Value to approximate; must satisfy ``0 < float_val <= 32767``.

    Returns:
        ``[best_int, best_shift]`` such that ``best_int / 2**best_shift``
        approximates ``float_val``.

    Raises:
        ValueError: If ``float_val`` is not in ``(0, 32767]`` (this includes NaN
            and infinity). Such inputs previously produced type objects as the
            result, a division by zero or a non-terminating loop.
    """
    INT16_MAX = 32767
    if not (0 < float_val <= INT16_MAX):
        raise ValueError(
            f"float_val must be in (0, {INT16_MAX}] to be representable, got {float_val!r}")

    best_rel_err = float('inf')
    best_int = 0
    best_shift_val = 0

    curr_float_val = float_val
    shift_val = 0
    while curr_float_val <= INT16_MAX:
        closest_curr_int = round(curr_float_val)
        cur_rel_err = abs(float_val - closest_curr_int / (2 ** shift_val)) / float_val

        # Strict comparison keeps the smallest shift among equally good candidates.
        if cur_rel_err < best_rel_err:
            best_rel_err = cur_rel_err
            best_shift_val = shift_val
            best_int = closest_curr_int

        curr_float_val *= 2
        shift_val += 1

    return [best_int, best_shift_val]


def find_closest_shifted_int32(float_val):
    """Approximate a positive value as ``integer / 2**shift`` with ``integer <= 16777216``.

    ``float_val`` is doubled repeatedly while it stays within the limit; for
    every shift the nearest integer is taken and the candidate with the
    smallest relative error wins (the smallest shift wins on equal error).

    Args:
        float_val: Value to approximate; must satisfy ``0 < float_val <= 16777216``.

    Returns:
        ``[best_int, best_shift]`` such that ``best_int / 2**best_shift``
        approximates ``float_val``.

    Raises:
        ValueError: If ``float_val`` is not in ``(0, 16777216]`` (this includes
            NaN and infinity). Such inputs previously produced a type object as
            the result, a division by zero or a non-terminating loop.
    """
    # The search limit is 2**24 (16777216), not the full int32 maximum 2147483647.
    INT32_MAX = 16777216
    if not (0 < float_val <= INT32_MAX):
        raise ValueError(
            f"float_val must be in (0, {INT32_MAX}] to be representable, got {float_val!r}")

    best_rel_err = float('inf')
    best_int = 0
    best_shift_val = 0

    curr_float_val = float_val
    shift_val = 0
    while curr_float_val <= INT32_MAX:
        closest_curr_int = round(curr_float_val)
        cur_rel_err = abs(float_val - closest_curr_int / (2 ** shift_val)) / float_val

        # Strict comparison keeps the smallest shift among equally good candidates.
        if cur_rel_err < best_rel_err:
            best_rel_err = cur_rel_err
            best_shift_val = shift_val
            best_int = closest_curr_int

        curr_float_val *= 2
        shift_val += 1
    return [best_int, best_shift_val]


def get_act_shapes_using_onnx_runtime_custom_ops(model_path, load_data, provider = 'CPUExecutionProvider', print_to_file = False):
  """Collect the shape of every tensor produced by the model by actually running it.

  Every node output that is not already a graph output is appended to the
  graph outputs, the model is executed once with random inputs through
  onnxruntime (with the onnxruntime-extensions custom-op library registered
  and graph optimizations disabled), and the shapes of all results are
  recorded.

  Args:
      model_path: Path of the ONNX model file.
      load_data: Forwarded to ``onnx.load_model`` as ``load_external_data``.
      provider: onnxruntime execution provider name.
      print_to_file: When true, the shape dictionary is also written as JSON
          next to the model (``<model_path without ".onnx...">_act_shapes.json``).

  Returns:
      dict mapping tensor name -> list of ints (graph inputs and all node outputs).
  """
  model = onnx.load_model(model_path, load_external_data=load_data)
  graph = model.graph

  # dictionary with all act tensor shapes
  node_act_signals_shapes = {}
  rand_inputs = {}

  for _input in graph.input:
    input_shape = [int(d.dim_value) for d in _input.type.tensor_type.shape.dim]
    node_act_signals_shapes[_input.name] = input_shape

    # Only float32 and int64 inputs get random data; other element types are not fed.
    d_type = _input.type.tensor_type.elem_type
    if d_type == TensorProto.FLOAT:
      rand_inputs[_input.name] = np.random.randn(*input_shape).astype(np.float32)
    elif d_type == TensorProto.INT64:
      rand_inputs[_input.name] = np.random.randn(*input_shape).astype(np.int64)

  # Create a first session only to learn which tensors are already graph outputs.
  so = ort.SessionOptions()
  so.register_custom_ops_library(get_library_path())
  so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

  ort_session = ort.InferenceSession(model.SerializeToString(), so, providers=[provider])
  existing_outputs = {x.name for x in ort_session.get_outputs()}
  del ort_session

  # Add the remaining node outputs as global outputs.
  for node in model.graph.node:
    for output in node.output:
      if output not in existing_outputs:
        model.graph.output.extend([onnx.ValueInfoProto(name=output)])

  # Create the session on the extended model and run it.
  ort_session1 = ort.InferenceSession(model.SerializeToString(), so, providers=[provider])
  outputs = [x.name for x in ort_session1.get_outputs()]
  ort_outs_mod = ort_session1.run(outputs, rand_inputs)

  for out_signal_name, tensor in zip(outputs, ort_outs_mod):
    node_act_signals_shapes[out_signal_name] = [int(i) for i in tensor.shape]

  print("Got shapes for ", len(node_act_signals_shapes), " nodes output activations using onnx runtime")

  if print_to_file:
    # json.dump writes straight to the file handle; sys.stdout is left untouched so
    # that any redirection installed by the caller is not clobbered.
    with open(model_path.split(".onnx")[0] + "_act_shapes.json", "w") as f:
      json.dump(node_act_signals_shapes, f, indent = 4)

  return node_act_signals_shapes


# Script entry point: runs the modified model (original name + '_allqdqconv_add_mod.onnx') and the
# original model on the same two input files, with every node output exposed as a graph output,
# pickles both sets of intermediate outputs and prints the L2 norm of the difference between the
# first outputs of the two runs.
if __name__ == "__main__":
  # --- model selection (commented-out lines are alternative models) ---
  # model_name = 'PSH_v1.0.quant.onnx_allqdqconv_add_mod.onnx'
  #model_name_orig = 'PSF_v1.0.quant.onnx'#'C4_v_0_1.onnx'
  #model_name_orig = 'C4_v_0_1_quantized_mix4conv_noacc_convert_qdq.onnx'  # 'C4_v_0_1.onnx'
  model_name_orig = 'Model_PSF_v1.0.4_4_1_24.onnx'
  #model_name_orig = 'Model_PSL_v1.0.2_4_1_24.onnx'
  #model_name_orig = 'Model_PSH_v1.1.0_4_1_24.onnx'#'Model_PSK_v1.0.2_4_1_24.onnx'
  #model_name_orig = 'Model_PSJ_v1.1_4_1_24.onnx'#'PSJ_v1.0.quant.onnx'#
  #model_name_orig = 'Model_PSJ_v1.0_4_1_24.onnx'
  #model_name_orig = 'Model_PSI_v1.1_4_1_24.onnx'
  #model_name_orig = 'PSO1_4_1_24.onnx'
  # model_name_orig = 'C4_v_0_1_quantized_S16S8-int16s-AAWS-ADAROUND.onnx'
  #model_name = 'Model_PSF_v1.0.4_4_1_24.onnx_biasadd_qdq_removal.onnx_allqdqconv_add_mod.onnx'#model_name_orig + '_allqdqconv_add_mod.onnx'
  model_name = model_name_orig + '_allqdqconv_add_mod.onnx'
  #model_name = 'Model_PSF_v1.0.4_4_1_24.onnx_biasadd_qdq_removal.onnx_allqdqconv_add_mod.onnx'#model_name_orig + '_allqdqconv_add_mod.onnx'
  model = onnx.load(model_name)
  # onnx.checker.check_model(model, full_check=False)
  # Session options shared by all sessions below: custom-op library registered, optimizations off.
  so = ort.SessionOptions()
  so.register_custom_ops_library(get_library_path())
  so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
  # ort_input0 = np.fromfile('C4_test_input_247.bin', dtype='f').reshape((1, 64, 128, 8))

  #ort_input0 = np.load('PSJ_emb_000000000139.npy')
  #ort_input1 = np.load('PSJ_mask_000000000139.npy')


  # --- input selection: raw float32 files reshaped to the expected input shapes ---
  #ort_input0 = np.fromfile('C4_test_input_0_generator_concat_0.bin', dtype='f').reshape((1, 64, 128, 8))
  #ort_input0 = np.load('PSI_000000001268.npy')
  ort_input0 = np.fromfile('attention_mask_0_PSF.raw', dtype='f').reshape((1, 512))
  ort_input1 = np.fromfile('embeddings_0_PSF.raw', dtype='f').reshape((1, 512, 768))
  #ort_input0 = np.fromfile('PSL_attention_mask_0.raw', dtype='f').reshape((1, 512))
  #ort_input1 = np.fromfile('PSL_embeddings_0.raw', dtype='f').reshape((1, 512, 768))
  #ort_input1 = np.fromfile('PSK_embeddings_0.raw', dtype='f').reshape((1, 512, 768))
  #ort_input0 = np.fromfile('PSK_attention_mask_0.raw', dtype='f').reshape((1, 512))
  #ort_input0 = np.fromfile('PSH_attention_mask_0.raw', dtype='f').reshape((1, 512))
  #ort_input1 = np.fromfile('PSH_embeddings_0.raw', dtype='f').reshape((1, 512, 768))
  # The bare string below is disabled code (alternative input loaders); it is never executed.
  '''
  ort_input0 = np.fromfile('PSK_embeddings_0.raw', dtype='f').reshape((1, 512, 768))
  ort_input1 = np.fromfile('PSK_attention_mask_0.raw', dtype='f').reshape((1, 512)

  ort_input0 = np.fromfile('C4_test_input_0_generator_concat_0.bin', dtype='f').reshape((1, 64, 128, 8))


  ort_input0 = np.fromfile('attention_mask_0_PSF.raw', dtype='f').reshape((1, 512))
  ort_input1 = np.fromfile('embeddings_0_PSF.raw', dtype='f').reshape((1, 512, 768))

  with open('PSO0_data.pb', 'rb') as fin:
    tensor = onnx.TensorProto()
    tensor.ParseFromString(fin.read())
    ort_input0 = tensor.raw_data
  ort_input0 = np.frombuffer(ort_input0, dtype=np.float32).reshape( 3, 768, 1152)

  ort_input0 = np.fromfile('PSL_attention_mask_0.raw', dtype='f').reshape((1, 512))
  ort_input1 = np.fromfile('PSL_embeddings_0.raw', dtype='f').reshape((1, 512, 768))

  ort_input0 = np.load('PSJ_emb_000000000139.npy')
  ort_input1 = np.load('PSJ_mask_000000000139.npy')

  ort_input0 = np.fromfile('PSA_001.raw', dtype='f').reshape((1, 3,1280, 1280))


  ort_input0 = np.load('PSI_000000001268.npy')

  ort_input0 = np.fromfile('PSH_attention_mask_0.raw', dtype='f').reshape((1, 512))
  ort_input1 = np.fromfile('PSH_embeddings_0.raw', dtype='f').reshape((1, 512, 768))
  '''

  # --- modified model: first session only lists the graph outputs that already exist ---
  ort_session = ort.InferenceSession(model.SerializeToString(), so, providers=['CPUExecutionProvider'])
  outputs = [x.name for x in ort_session.get_outputs()]
  #del ort_session
  #ort_outs_mod = ort_session.run(outputs,
  #                               {'/lang_encoder/embeddings/Add_1_output_0': ort_input0, 'attention_mask': ort_input1})
  #ort_outs_mod = ort_session.run(outputs, { 'attention_mask': ort_input0 , '/tulrv6/embeddings/Add_2_output_0': ort_input1} )
  #ort_outs_mod = ort_session.run(outputs, { 'attention_mask': ort_input0 , 'embeddings': ort_input1} )
  #ort_outs_mod = ort_session.run(outputs, { 'input_image': ort_input0} )
  #ort_outs_mod = ort_session.run(outputs, { 'generator/concat:0': ort_input0} )
  #ort_outs_mod_dict = OrderedDict(zip(outputs, ort_outs_mod))
  del ort_session


  # Expose every remaining node output as a graph output so intermediate tensors are returned.
  for node in model.graph.node:
      for output in node.output:
          if output not in outputs:
              model.graph.output.extend([onnx.ValueInfoProto(name=output)])

  # Second session on the extended model: run it and pickle all outputs keyed by tensor name.
  ort_session = ort.InferenceSession(model.SerializeToString(), so, providers=['CPUExecutionProvider'])
  outputs = [x.name for x in ort_session.get_outputs()]
  #ort_outs_mod = ort_session.run(outputs, { 'generator/concat:0': ort_input0} )
  #ort_outs_mod = ort_session.run(outputs,
  #                               {'/lang_encoder/embeddings/Add_1_output_0': ort_input0, 'attention_mask': ort_input1})

  ort_outs_mod = ort_session.run(outputs, {'attention_mask': ort_input0, '/tulrv6/embeddings/Add_2_output_0': ort_input1})
  ort_outs_mod_dict = OrderedDict(zip(outputs, ort_outs_mod))
  with open(model_name+'_intermediate_outputs.pickle', 'wb') as handle:
      pickle.dump(ort_outs_mod_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

  del ort_session

  # --- original model: same procedure as for the modified model ---
  #model_orig = onnx.load('C4.onnx')
  model_orig = onnx.load(model_name_orig)
  ort_session_orig = ort.InferenceSession(model_orig.SerializeToString(), so, providers=['CPUExecutionProvider'])

  outputs_orig = [x.name for x in ort_session_orig.get_outputs()]

  #del ort_session_orig
  #ort_session_orig = ort.InferenceSession(model_orig.SerializeToString(), so, providers=['CPUExecutionProvider'])
  #ort_outs_orig = ort_session_orig.run(outputs_orig,  {'/lang_encoder/embeddings/Add_1_output_0': ort_input0, 'attention_mask': ort_input1}  )
  #ort_outs_orig = ort_session_orig.run(outputs_orig,  { 'generator/concat:0': ort_input0}  )
  #ort_outs_orig = ort_session_orig.run(outputs_orig, {'attention_mask': ort_input0,
  #                                                             '/tulrv6/embeddings/Add_2_output_0': ort_input1})

  #ort_outs_orig = ort_session_orig.run(outputs_orig, { 'attention_mask': ort_input0 , 'embeddings': ort_input1} )

  #ort_outs_orig = ort_session_orig.run(outputs_orig, { 'input_image': ort_input0} )

  del ort_session_orig
  # Expose every remaining node output of the original model as a graph output.
  for node in model_orig.graph.node:
      for output in node.output:
          if output not in outputs_orig:
              model_orig.graph.output.extend([onnx.ValueInfoProto(name=output)])

  ort_session_orig = ort.InferenceSession(model_orig.SerializeToString(), providers=["CPUExecutionProvider"], sess_options=so)
  outputs_orig = [x.name for x in ort_session_orig.get_outputs()]
  #ort_outs_orig = ort_session_orig.run(outputs_orig, {'/lang_encoder/embeddings/Add_1_output_0': ort_input0, 'attention_mask': ort_input1})
  #ort_outs_orig = ort_session_orig.run(outputs_orig,  { 'generator/concat:0': ort_input0}  )

  ort_outs_orig = ort_session_orig.run(outputs_orig, {'attention_mask': ort_input0,
                                                              '/tulrv6/embeddings/Add_2_output_0': ort_input1})

  ort_outs_orig_dict = OrderedDict(zip(outputs_orig, ort_outs_orig))
  with open(model_name_orig+'_intermediate_outputs.pickle', 'wb') as handle:
      pickle.dump(ort_outs_orig_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

  #loaded_interm_out_dict = pickle.load( open( model_name_orig+'_intermediate_outputs.pickle', "rb" ) )
  #print(ort_outs_orig_dict.keys())

  # --- comparison: L2 norm of the difference between the first output of each run ---
  #print(len(ort_outs_orig))
  print(np.linalg.norm(ort_outs_mod[0].astype(np.float32) - ort_outs_orig[0].astype(np.float32)))
  #print(loaded_interm_out_dict['logits_QuantizeLinear_Output'])
  #print(np.linalg.norm(ort_outs_mod[1].astype(np.float32) - ort_outs_orig[1].astype(np.float32)))
  #print(np.linalg.norm(ort_outs_mod[2].astype(np.float32) - ort_outs_orig[2].astype(np.float32)))
  #print(np.linalg.norm(ort_outs_mod[3].astype(np.float32) - ort_outs_orig[3].astype(np.float32)))
  #print(np.linalg.norm(ort_outs_mod[4].astype(np.float32) - ort_outs_orig[4].astype(np.float32)))

  #print(np.max(np.abs(ort_outs_orig_dict['logits_QuantizeLinear_Output'].astype(np.float32)-ort_outs_mod_dict['logits_QuantizeLinear_Output'].astype(np.float32))))
  #print(np.linalg.norm(ort_outs_orig_dict['logits'].astype(np.float32)-ort_outs_mod_dict['logits'].astype(np.float32)))
  #print(ort_outs_orig_dict['output_QuantizeLinear_Output'].astype(np.float32)-ort_outs_mod_dict['output_QuantizeLinear_Output'].astype(np.float32))
  #print(np.linalg.norm(ort_outs_orig_dict['output'].astype(np.float32)-ort_outs_mod_dict['output'].astype(np.float32)))
  #print(np.linalg.norm(ort_outs_orig_dict['generator/Tanh:0'].astype(np.float32)-ort_outs_mod_dict['generator/Tanh:0'].astype(np.float32)))
  # print(ort_outs_mod[0][0,:,0,0])
  # print(ort_outs_orig[0][0,:,0,0])
  #print(np.linalg.norm(ort_outs_orig_dict['output'].astype(np.float32)-ort_outs_mod_dict['output'].astype(np.float32)))
  # The bare string below is disabled code (per-output comparisons); it is never executed.
  '''
  print(np.linalg.norm(ort_outs_mod[1].astype(np.float32) - ort_outs_orig[1].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[2].astype(np.float32) - ort_outs_orig[2].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[3].astype(np.float32) - ort_outs_orig[3].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[4].astype(np.float32) - ort_outs_orig[4].astype(np.float32)))

  print(np.linalg.norm(ort_outs_mod[5].astype(np.float32) - ort_outs_orig[5].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[6].astype(np.float32) - ort_outs_orig[6].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[7].astype(np.float32) - ort_outs_orig[7].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[8].astype(np.float32) - ort_outs_orig[8].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[9].astype(np.float32) - ort_outs_orig[9].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[10].astype(np.float32) - ort_outs_orig[10].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[11].astype(np.float32) - ort_outs_orig[11].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[12].astype(np.float32) - ort_outs_orig[12].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[13].astype(np.float32) - ort_outs_orig[13].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[14].astype(np.float32) - ort_outs_orig[14].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[15].astype(np.float32) - ort_outs_orig[15].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[16].astype(np.float32) - ort_outs_orig[16].astype(np.float32)))
  print(np.linalg.norm(ort_outs_mod[17].astype(np.float32) - ort_outs_orig[17].astype(np.float32)))
  '''

  del ort_session_orig
