# Initial test bench: builds a simple quantized ONNX model to be run on O-GOAT.

import os
import onnx
import numpy as np
from onnx import helper, TensorProto, numpy_helper


def matmulbias(node_list, initializer_list, node_name, act_in, act_out, dim, bias_flag, rng=None):
    """Append a QDQ-wrapped MatMul, optionally followed by a bias Add, to a graph being built.

    Nodes are appended to ``node_list`` and their constant inputs to
    ``initializer_list``; both lists are mutated in place and nothing is
    returned. All Quantize/DequantizeLinear nodes use the ``com.microsoft``
    domain; MatMul and Add use ``ai.onnx``.

    Emitted pattern::

        bias_flag != 1:
            act_in -> DQ -> MatMul(., DQ(W)) -> Q -> act_out
        bias_flag == 1:
            act_in -> DQ -> MatMul(., DQ(W)) -> Q -> DQ -> Add(., DQ(b)) -> Q -> act_out

    Args:
        node_list: list that receives the new ``NodeProto`` objects.
        initializer_list: list that receives the new initializer ``TensorProto`` objects.
        node_name: prefix for every node and tensor name created here; it must
            be unique within the graph.
        act_in: name of the quantized activation tensor consumed by the first
            DQ. Its zero point is created as uint16 here, so the tensor is
            expected to be uint16.
        act_out: name given to the final quantized output tensor.
        dim: ``(M, K, N)`` matmul dimensions. Only ``K = dim[1]`` and
            ``N = dim[2]`` are used here: the weight is ``(K, N)`` and the
            bias is ``(1, N)``.
        bias_flag: ``1`` adds the bias sub-graph; any other value omits it.
        rng: optional ``numpy.random.Generator`` used to draw the random weight
            and bias payloads, so a model can be regenerated reproducibly.
            ``None`` (the default) uses the legacy global ``np.random`` state,
            exactly as before.
    """
    # Shared per-tensor quantization parameters.
    # NOTE(review): 30767 may have been meant as 32767 (uint16 mid-range);
    # kept unchanged - confirm against the O-GOAT test plan.
    scale = 0.0005
    act_zero_point = 30767

    def scoped(suffix):
        # Every node / tensor created here is namespaced by node_name.
        return node_name + suffix

    def add_scalar(value, dtype, suffix):
        # Rank-0 initializer; used for every scale and zero point.
        initializer_list.append(numpy_helper.from_array(np.array(value).astype(dtype), scoped(suffix)))

    def add_node(op_type, inputs, outputs, suffix, domain='com.microsoft'):
        node_list.append(helper.make_node(op_type, inputs=inputs, outputs=outputs, name=scoped(suffix), domain=domain))

    with_bias = bias_flag == 1
    k_dim, n_dim = dim[1], dim[2]
    # Default to the legacy global RNG so existing output (and any
    # np.random.seed() set by the caller) is unchanged.
    randint = np.random.randint if rng is None else rng.integers

    # ---- nodes -------------------------------------------------------------
    # Dequantize activation and weight, then multiply in float.
    add_node('DequantizeLinear', [act_in, scoped('_act_dq_scale'), scoped('_act_dq_zero_point')],
             [scoped('_act_dq_out')], '_act_dq')
    add_node('DequantizeLinear', [scoped('_wgt_dq_x'), scoped('_wgt_dq_scale'), scoped('_wgt_dq_zero_point')],
             [scoped('_wgt_dq_out')], '_wgt_dq')
    add_node('MatMul', [scoped('_act_dq_out'), scoped('_wgt_dq_out')], [scoped('_out')], '', domain='ai.onnx')
    if with_bias:
        # Q/DQ round trip on the matmul result, then add the dequantized int32
        # bias and requantize the sum into act_out.
        add_node('QuantizeLinear', [scoped('_out'), scoped('_act_q_scale'), scoped('_act_q_zero_point')],
                 [scoped('_act_q_out')], '_act_q')
        add_node('DequantizeLinear', [scoped('_act_q_out'), scoped('_act_q_scale'), scoped('_act_q_zero_point')],
                 [scoped('_act_bia_dq_out')], '_act_qdq')
        add_node('DequantizeLinear', [scoped('_bias_q_out'), scoped('_bias_dq_scale'), scoped('_bias_dq_zero_point')],
                 [scoped('_bias_dq_out')], '_bias_dq')
        # NOTE: this node is an Add despite the '_mul' names; the names are kept
        # so the generated graph matches earlier versions of this script.
        add_node('Add', [scoped('_act_bia_dq_out'), scoped('_bias_dq_out')],
                 [scoped('_bias_mul_out')], '_mul', domain='ai.onnx')
        add_node('QuantizeLinear',
                 [scoped('_bias_mul_out'), scoped('_bias_mul_out_q_scale'), scoped('_bias_mul_out_q_zero_point')],
                 [act_out], '_bias_qdq')
    else:
        add_node('QuantizeLinear', [scoped('_out'), scoped('_act_q_scale'), scoped('_act_q_zero_point')],
                 [act_out], '_act_q')

    # ---- initializers (same order and random-draw order as before) ---------
    add_scalar(scale, 'float32', '_act_dq_scale')
    add_scalar(act_zero_point, 'uint16', '_act_dq_zero_point')
    # NOTE(review): weights drawn from [0, 123) with zero point 127 make every
    # dequantized weight negative; confirm this range is intended.
    initializer_list.append(numpy_helper.from_array(
        randint(0, 123, size=(k_dim, n_dim)).astype('uint8'), scoped('_wgt_dq_x')))
    add_scalar(scale, 'float32', '_wgt_dq_scale')
    add_scalar(127, 'uint8', '_wgt_dq_zero_point')
    add_scalar(scale, 'float32', '_act_q_scale')
    add_scalar(act_zero_point, 'uint16', '_act_q_zero_point')
    if with_bias:
        initializer_list.append(numpy_helper.from_array(
            randint(0, 123, size=(1, n_dim)).astype('int32'), scoped('_bias_q_out')))
        add_scalar(scale, 'float32', '_bias_dq_scale')
        add_scalar(0, 'int32', '_bias_dq_zero_point')
        add_scalar(scale, 'float32', '_bias_mul_out_q_scale')
        add_scalar(act_zero_point, 'uint16', '_bias_mul_out_q_zero_point')

# Back-to-back MatMul chain, one ((M, K, N), bias_flag) entry per layer.
# Each layer consumes the previous layer's quantized (M, N) output, so every
# layer's (M, K) must equal the previous layer's (M, N). Comment/uncomment
# entries to change the chain; the graph input and output shapes below are
# derived from this list, so they always match the active layers.
layer_configs = [
    ((256, 640, 1280), 1),      # wgt full & (streaming)
    # ((256, 1280, 1280), 0),   # wgt full & (streaming)
    # ((256, 1280, 2560), 1),   # wgt streaming
    # ((256, 2560, 256), 0),    # wgt full
]

if not layer_configs:
    raise ValueError("layer_configs must contain at least one layer")
for (prev_dim, _), (cur_dim, _) in zip(layer_configs, layer_configs[1:]):
    if tuple(cur_dim[:2]) != (prev_dim[0], prev_dim[2]):
        raise ValueError(f"layer {cur_dim} does not chain after layer {prev_dim}")

first_m, first_k, _ = layer_configs[0][0]
last_m, _, last_n = layer_configs[-1][0]

# Create input tensors: float (M, K) activations of the first layer.
inputs = [helper.make_tensor_value_info('input1', TensorProto.FLOAT, [first_m, first_k])]

# Create output tensor: float (M, N) result of the last active layer.
outputs = [helper.make_tensor_value_info('output', TensorProto.FLOAT, [last_m, last_n])]

# Create matmul nodes
node_list = []
initializer_list = []

# Quantize the float graph input to uint16 before the first matmul.
node_list.append(helper.make_node('QuantizeLinear', inputs=['input1', 'input1_q_scale', 'input1_q_zero_point'],
                                  outputs=['input1_q_out'], name='input1_quant', domain='com.microsoft'))

initializer_list.append(numpy_helper.from_array(np.array(0.0005).astype('float32'), 'input1_q_scale'))
initializer_list.append(numpy_helper.from_array(np.array(30767).astype('uint16'), 'input1_q_zero_point'))

# Chain the matmuls: layer n reads layer n-1's quantized output.
layer_in = 'input1_q_out'
for n, (dim, bias_flag) in enumerate(layer_configs):
    layer_out = 'matmul' + str(n) + '_bias_qdq_out'
    matmulbias(node_list, initializer_list, 'matmul' + str(n), layer_in, layer_out, dim, bias_flag)
    layer_in = layer_out

# Dequantize the last layer's output back to float for the graph output.
node_list.append(helper.make_node('DequantizeLinear', inputs=[layer_in, 'output_q_scale', 'output_q_zero_point'],
                                  outputs=['output'], name='output_quant', domain='com.microsoft'))

initializer_list.append(numpy_helper.from_array(np.array(0.0005).astype('float32'), 'output_q_scale'))
initializer_list.append(numpy_helper.from_array(np.array(30767).astype('uint16'), 'output_q_zero_point'))

# NOTE(review): the 'aimet_torch' domain is not used by any node built here.
opset_import_list = [
    helper.make_opsetid("", 18),
    helper.make_opsetid("aimet_torch", 1),
    helper.make_opsetid("com.microsoft", 1),
]

# NOTE(review): ONNX's versioning table pairs opset 18 with IR version 8;
# IR 6 is kept as-is - confirm the O-GOAT toolchain expects it.
ir_version_list = 6

# Create the graph
graph = helper.make_graph(
    node_list,
    'matmul_graph',
    inputs,
    outputs,
    initializer_list,
)

# Create the model
model = helper.make_model(graph, ir_version=ir_version_list, opset_imports=opset_import_list)

# Save the model
model_path = 'matmul_model.onnx'
print("Saving back to back matmul model as:", os.path.join(os.getcwd(), model_path))
onnx.save(model, model_path)

