# Initial test bench: builds a simple quantized (QDQ) MatMul ONNX model to be run on O-GOAT.
# Usage: python <this_script> x y z dtype_ifm dtype_wgt dtype_ofm actxact

import os
import onnx
import numpy as np
from onnx import helper, TensorProto, numpy_helper
import sys
import re

# Command-line arguments (all required):
#   x y z                         -> MatMul shapes: input1 is [x, y], the second
#                                    operand (weight or input2) is [y, z]
#   dtype_ifm dtype_wgt dtype_ofm -> numpy dtype names used for the zero points /
#                                    weight values of the input, weight and output
#   actxact                       -> 'yes' makes the second MatMul operand a graph
#                                    input; anything else uses a constant weight
# Fail with a usage line instead of an opaque unpacking/IndexError crash.
if len(sys.argv) < 8:
    sys.exit("usage: %s x y z dtype_ifm dtype_wgt dtype_ofm actxact" % sys.argv[0])
x, y, z = [int(i) for i in sys.argv[1:4]]
dtype_ifm, dtype_wgt, dtype_ofm = sys.argv[4:7]
actxact = sys.argv[7]
print(x, y, z, dtype_ifm, dtype_wgt, dtype_ofm, actxact)

# Inclusive integer range [min_wgt, max_wgt] of the weight dtype, derived from
# the trailing bit width in its name ('int8' -> [-128, 127], 'uint8' -> [0, 255]).
# A leading 'u' selects the unsigned range; anything else is treated as signed.
bit_match = re.findall(r'\d+$', dtype_wgt)
if not bit_match:
    raise ValueError("dtype_wgt must end in a bit width (e.g. 'int8'), got %r" % dtype_wgt)
bits = int(bit_match[0])
if dtype_wgt[0] == 'u':
    min_wgt = 0
    max_wgt = 2**bits - 1
else:
    min_wgt = -2**(bits - 1)
    max_wgt = 2**(bits - 1) - 1

# Graph inputs: the first MatMul operand is always fed at runtime; in
# activation x activation mode the second operand ([y, z]) is a graph input too.
inputs = [helper.make_tensor_value_info('input1', TensorProto.FLOAT, [x, y])]
if actxact == 'yes':
    inputs.append(helper.make_tensor_value_info('input2', TensorProto.FLOAT, [y, z]))

# Graph output: the dequantized MatMul result, shape [x, z].
outputs = [helper.make_tensor_value_info('output', TensorProto.FLOAT, [x, z])]

# Graph nodes: a single MatMul wrapped in QuantizeLinear/DequantizeLinear pairs.
#   'input1' -> Q -> DQ -> first MatMul operand
#   second MatMul operand:
#     actxact == 'yes': 'input2' -> Q -> DQ
#     otherwise:        constant 'gemm1_wgt_dq_x' initializer -> DQ
#   MatMul -> Q -> DQ -> 'output'
# All Q/DQ nodes use the com.microsoft domain; MatMul is tagged 'ai.onnx'.
qdq_domain = 'com.microsoft'

node_list = [
    helper.make_node('QuantizeLinear',
                     inputs=['input1', 'input1_q_scale', 'input1_q_zero_point'],
                     outputs=['input1_q_out'], name='input1_quant', domain=qdq_domain),
    helper.make_node('DequantizeLinear',
                     inputs=['input1_q_out', 'gemm1_act_dq_scale', 'gemm1_act_dq_zero_point'],
                     outputs=['gemm1_act_dq_out'], name='gemm1_act_dq', domain=qdq_domain),
]

if actxact == 'yes':
    # Runtime second operand: quantize it first, then feed the DQ below.
    node_list.append(helper.make_node('QuantizeLinear',
                                      inputs=['input2', 'input2_q_scale', 'input2_q_zero_point'],
                                      outputs=['input2_q_out'], name='input2_quant',
                                      domain=qdq_domain))
    wgt_dq_input = 'input2_q_out'
else:
    # Constant second operand: the quantized weights are stored as an initializer.
    wgt_dq_input = 'gemm1_wgt_dq_x'

node_list.append(helper.make_node('DequantizeLinear',
                                  inputs=[wgt_dq_input, 'gemm1_wgt_dq_scale', 'gemm1_wgt_dq_zero_point'],
                                  outputs=['gemm1_wgt_dq_out'], name='gemm1_wgt_dq',
                                  domain=qdq_domain))

node_list.extend([
    helper.make_node('MatMul', inputs=['gemm1_act_dq_out', 'gemm1_wgt_dq_out'],
                     outputs=['gemm1_out'], name='gemm1', domain='ai.onnx'),
    helper.make_node('QuantizeLinear',
                     inputs=['gemm1_out', 'gemm1_act_q_scale', 'gemm1_act_q_zero_point'],
                     outputs=['gemm1_act_q_out'], name='gemm1_act_q', domain=qdq_domain),
    helper.make_node('DequantizeLinear',
                     inputs=['gemm1_act_q_out', 'output_q_scale', 'output_q_zero_point'],
                     outputs=['output'], name='output_q_', domain=qdq_domain),
])

# Initializers: a float32 scale and a zero point for every Q/DQ node, plus the
# constant weight tensor when the second MatMul operand is not a graph input.
# Every scale is 0.0005 and every zero point is 127, cast to the dtype of the
# tensor being (de)quantized.
def _const(value, dtype, name):
    """Return ``value`` cast to ``dtype`` as a named ONNX initializer tensor."""
    return numpy_helper.from_array(np.asarray(value).astype(dtype), name)


initializer_list = [
    _const(0.0005, 'float32', 'input1_q_scale'),
    _const(127, dtype_ifm, 'input1_q_zero_point'),
    _const(0.0005, 'float32', 'gemm1_act_dq_scale'),
    _const(127, dtype_ifm, 'gemm1_act_dq_zero_point'),
]
if actxact == 'yes':
    # NOTE(review): input2 is quantized with a dtype_ifm zero point but dequantized
    # by 'gemm1_wgt_dq' with a dtype_wgt zero point; a Q->DQ pair needs matching
    # types, so this mode presumably requires dtype_ifm == dtype_wgt — confirm.
    initializer_list += [
        _const(0.0005, 'float32', 'input2_q_scale'),
        _const(127, dtype_ifm, 'input2_q_zero_point'),
    ]
else:
    # np.random.randint excludes `high`; +1 makes max_wgt itself reachable so the
    # random weights cover the full inclusive [min_wgt, max_wgt] range.
    initializer_list.append(
        _const(np.random.randint(min_wgt, max_wgt + 1, size=(y, z)), dtype_wgt, 'gemm1_wgt_dq_x'))
initializer_list += [
    _const(0.0005, 'float32', 'gemm1_wgt_dq_scale'),
    _const(127, dtype_wgt, 'gemm1_wgt_dq_zero_point'),
    _const(0.0005, 'float32', 'gemm1_act_q_scale'),
    _const(127, dtype_ofm, 'gemm1_act_q_zero_point'),
    _const(0.0005, 'float32', 'output_q_scale'),
    _const(127, dtype_ofm, 'output_q_zero_point'),
]


# Opset imports: the default ONNX domain at opset 18 plus the custom domains.
# 'aimet_torch' is declared even though no node in this graph uses it.
opset_import_list = [
    helper.make_opsetid("", 18),
    helper.make_opsetid("aimet_torch", 1),
    helper.make_opsetid("com.microsoft", 1),
]

# NOTE(review): ONNX pairs opset 18 with IR version 8; IR version 6 is kept
# here, presumably because the downstream O-GOAT flow expects it — confirm.
ir_version_list = 6

# Create the graph and wrap it in a model.
graph = helper.make_graph(
    node_list,
    'gemm_graph',
    inputs,
    outputs,
    initializer_list,
)
model = helper.make_model(graph, ir_version=ir_version_list, opset_imports=opset_import_list)

# Save into the current working directory. One path variable feeds both the
# log line and the save, so they cannot drift apart.
model_path = 'gemm_model.onnx'
print("Saving back to back gemm model as:", os.path.abspath(model_path))
onnx.save(model, model_path)

