# Initial test bench: builds a simple ONNX model (two back-to-back GEMMs) and saves it so it can be run on O-GOAT

import os
import onnx
import numpy as np
from onnx import helper, TensorProto

# Dimensions of the two chained matrix multiplications:
#   intermediate[M, N] = input1[M, K] x input2[K, N]
#   output[M, P]       = intermediate[M, N] x input3[N, P]
# Keeping them in one place guarantees the tensor shapes below stay consistent.
GEMM_M, GEMM_K, GEMM_N, GEMM_P = 2048, 2048, 1024, 1000

# File the serialized model is written to (relative to the working directory).
MODEL_FILENAME = 'gemm_model.onnx'

# Create input tensors
input1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [GEMM_M, GEMM_K])
input2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [GEMM_K, GEMM_N])
input3 = helper.make_tensor_value_info('input3', TensorProto.FLOAT, [GEMM_N, GEMM_P])

# Create output tensor
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [GEMM_M, GEMM_P])

# Create GEMM nodes.
# The first argument of make_node is the ONNX operator type, which must be the
# registered operator 'Gemm'; the per-node identifiers go in `name` instead.
gemm1 = helper.make_node(
    'Gemm',
    inputs=['input1', 'input2'],
    outputs=['intermediate'],
    name='Gemm1',
)

gemm2 = helper.make_node(
    'Gemm',
    inputs=['intermediate', 'input3'],
    outputs=['output'],
    name='Gemm2',
)

# Create the graph: gemm1 feeds gemm2 through the 'intermediate' tensor.
graph = helper.make_graph(
    [gemm1, gemm2],
    'gemm_graph',
    [input1, input2, input3],
    [output],
)

# Create the model
model = helper.make_model(graph)

# Fail early if the model is not valid ONNX (e.g. an unknown operator type)
# rather than writing a file that a consumer cannot load.
onnx.checker.check_model(model)

# Save the model
model_path = os.path.join(os.getcwd(), MODEL_FILENAME)
print("Saving back to back gemm model as:", model_path)
onnx.save(model, model_path)

