import os
import argparse
import json

# from ...OGOAT.src.L1_fusion.L1_utils.model_IR_utils import get_kernels_for_model

# Per-element storage size, in bytes, keyed by dtype name.
# Fractional values denote sub-byte elements (e.g. 1/2 byte for 4-bit types,
# 9/8 byte for "mx9"/"bfp16"). The empty-string key maps to 0 and serves as
# the size of an absent/unspecified dtype.
bytes_for = {
    # 9/8 byte per element
    "mx9": 9 / 8,
    "bfp16": 9 / 8,
    # floating point
    "float16": 2,
    "bfloat16": 2,
    "fp32": 4,
    "float32": 4,
    "double": 8,
    # 4-bit integers
    "int4": 1 / 2,
    "uint4": 1 / 2,
    # byte-aligned integers
    "int8": 1,
    "uint8": 1,
    "int16": 2,
    "uint16": 2,
    "int32": 4,
    "uint32": 4,
    "int64": 8,
    "uint64": 8,
    # no dtype
    "": 0,
}

def get_standard_onnx_op_attributes(op_type):
    """Return a preset attribute dictionary for the given ONNX operator name.

    The presets are hard-coded in this function. Operator names that have no
    preset (for example "Add", "Relu", "MatMul", "Reshape", or any name not
    listed below) yield an empty dictionary.

    Args:
        op_type: Operator name, e.g. ``"Conv"`` or ``"Softmax"``.

    Returns:
        A newly created ``dict`` of attribute name -> value. The dictionary
        and any lists inside it are fresh objects on every call, so callers
        may mutate the result freely.
    """
    # The table is intentionally built inside the function: every call gets
    # brand-new dict/list objects, so no state is shared between callers.
    presets = {
        "Conv": {
            "kernel_shape": [3, 3],
            "strides": [1, 1],
            "pads": [1, 1, 1, 1],
            "dilations": [1, 1],
            "group": 1,
        },
        "Gemm": {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0},
        # Normalization
        "BatchNormalization": {"epsilon": 1e-5, "momentum": 0.9},
        "GroupNormalization": {"num_groups": 32, "epsilon": 1e-5},
        "LayerNormalization": {"epsilon": 1e-5},
        "LpNormalization": {"p": 2, "axis": -1},
        # Pooling / resampling
        "MaxPool": {"kernel_shape": [2, 2], "strides": [2, 2], "pads": [0, 0, 0, 0]},
        "AveragePool": {"kernel_shape": [2, 2], "strides": [2, 2], "pads": [0, 0, 0, 0]},
        "Resize": {
            "mode": "nearest",
            "coordinate_transformation_mode": "half_pixel",
            "nearest_mode": "round_prefer_floor",
        },
        # Activations and element-wise ops that carry attributes
        "LeakyRelu": {"alpha": 0.01},
        "Softmax": {"axis": 1},
        "Dropout": {"ratio": 0.5},
        "Clip": {"min": 0.0, "max": 1.0},
        # Shape manipulation
        "Flatten": {"axis": 1},
        "Transpose": {"perm": [0, 2, 3, 1]},
        "Concat": {"axis": 1},
        "Split": {"axis": 1},
        "Squeeze": {"axes": [0]},
        "Unsqueeze": {"axes": [0]},
        "Pad": {"pads": [1, 1, 1, 1], "mode": "constant", "value": 0.0},
        "Tile": {"repeats": [1, 1, 1, 1]},
        # Reductions
        "ReduceMean": {"axes": [1], "keepdims": 1},
        "ReduceSum": {"axes": [1], "keepdims": 1},
        "ReduceMax": {"axes": [1], "keepdims": 1},
        "ReduceMin": {"axes": [1], "keepdims": 1},
        "ReduceProd": {"axes": [1], "keepdims": 1},
        "ArgMax": {"axis": 1, "keepdims": 1},
        "ArgMin": {"axis": 1, "keepdims": 1},
    }
    return presets.get(op_type, {})

class model_IR:
    """Description of a single operator node that can be dumped to a JSON IR file.

    Intended call order (the one used by ``main``):
    ``model_IR(...)`` -> ``add_scale_zp_fields()`` -> ``add_remaining_fields()``
    -> ``write_to_file()``.

    Attributes set by ``__init__``:
        op_type: Full operator name as given by the caller.
        orig_op_type: Part of ``op_type`` before its first underscore.
        inputs: List of tensor dictionaries with the keys ``"type"``,
            ``"shape"``, ``"dtype"`` and ``"dtype_bytes"``.
        outputs: Same structure as ``inputs``; ``"type"`` is always ``"act"``.
    """

    def __init__(self,
                 op_type="Add_qdq_BroadCast_uint16xuint16xuint16",
                 list_in_shapes=None,
                 list_in_dtypes=None,
                 list_in_types=None,
                 list_out_shapes=None,
                 list_out_dtypes=None
                 ):
        """Build the input/output tensor descriptions.

        Args:
            op_type: Operator name; the text before the first ``"_"`` becomes
                ``orig_op_type``.
            list_in_shapes: One shape per input. Defaults to
                ``[[1, 1280, 16, 16], [1, 1280, 1, 1]]``.
            list_in_dtypes: One dtype name per input (keys of ``bytes_for``).
                Defaults to ``["uint16", "uint16"]``.
            list_in_types: One kind label per input. Defaults to
                ``["act", "wgt"]``.
            list_out_shapes: One shape per output. Defaults to
                ``[[1, 1280, 16, 16]]``.
            list_out_dtypes: One dtype name per output. Defaults to
                ``["uint16"]``.

        Raises:
            IndexError: If a dtype/type list is shorter than the matching
                shape list.
            KeyError: If a dtype name is not a key of ``bytes_for``.
        """
        # Defaults are created per call. The shape lists are stored by
        # reference below, so list literals in the signature would be shared
        # (and mutable) across every instance built with the defaults.
        if list_in_shapes is None:
            list_in_shapes = [[1, 1280, 16, 16], [1, 1280, 1, 1]]
        if list_in_dtypes is None:
            list_in_dtypes = ["uint16", "uint16"]
        if list_in_types is None:
            list_in_types = ["act", "wgt"]
        if list_out_shapes is None:
            list_out_shapes = [[1, 1280, 16, 16]]
        if list_out_dtypes is None:
            list_out_dtypes = ["uint16"]

        self.op_type = op_type
        self.orig_op_type = op_type.split("_")[0]

        # The shape lists drive the iteration; the other lists are indexed in
        # lockstep (extra trailing entries in them are ignored).
        self.inputs = [
            self._tensor_entry(list_in_types[i], shape, list_in_dtypes[i])
            for i, shape in enumerate(list_in_shapes)
        ]
        self.outputs = [
            self._tensor_entry("act", shape, list_out_dtypes[i])
            for i, shape in enumerate(list_out_shapes)
        ]

    @staticmethod
    def _tensor_entry(kind, shape, dtype):
        """Return one tensor dictionary; raises ``KeyError`` for an unknown dtype."""
        return {
            "type": kind,
            "shape": shape,
            "dtype": dtype,
            "dtype_bytes": bytes_for[dtype],
        }

    def add_scale_zp_fields(self):
        """Append scale and zero-point entries to ``inputs`` for "qdq" operators.

        Does nothing unless the substring ``"qdq"`` occurs in ``op_type``.
        Otherwise, for every current input and then every output, two
        ``"wgt"`` entries with an empty shape are appended to ``self.inputs``:
        a ``"float32"`` scale followed by a zero-point that has the dtype of
        the tensor it belongs to.

        The method is not idempotent: calling it again appends further
        entries, including ones for the previously added scales/zero-points.
        """
        if "qdq" not in self.op_type:
            return

        # Snapshot before appending: only the tensors present on entry get a
        # scale/zero-point pair, inputs first and outputs afterwards.
        quantized_tensors = list(self.inputs) + list(self.outputs)
        for tensor in quantized_tensors:
            self.inputs.append(self._tensor_entry("wgt", [], "float32"))
            self.inputs.append(self._tensor_entry("wgt", [], tensor["dtype"]))

    def add_remaining_fields(self):
        """Derive the flat summary fields that ``write_to_file`` serializes.

        The first three entries of ``self.inputs`` are exposed as the
        activation (``in_*``), first weight (``wgt_*`` / ``in_wgt_shape``) and
        second weight (``wgt1_*`` / ``in_wgt1_shape``); the first output is
        exposed as ``out_*``. When the second or third input does not exist,
        an empty shape, the dtype ``""`` and ``bytes_for[""]`` bytes are used
        instead of failing.

        Also sets ``attributes`` (the three ``disable_*`` flags merged with
        ``get_standard_onnx_op_attributes(orig_op_type)``) and the constant
        fields ``qdq_symmetry``, ``coeff_shape``, ``in_act_residency``,
        ``out_act_residency``, ``Frequency`` and ``nodenames``.

        Raises:
            IndexError: If there is no input or no output at all.
        """
        def input_or_placeholder(index):
            # Operators with fewer than three inputs (e.g. non-qdq unary or
            # binary ops) have no tensor in the weight slots.
            if index < len(self.inputs):
                return self.inputs[index]
            return {"shape": [], "dtype": "", "dtype_bytes": bytes_for[""]}

        act = self.inputs[0]
        wgt = input_or_placeholder(1)
        wgt1 = input_or_placeholder(2)
        out = self.outputs[0]

        self.in_act_shape = act["shape"]
        self.in_wgt_shape = wgt["shape"]
        self.in_wgt1_shape = wgt1["shape"] or []
        self.out_act_shape = out["shape"]

        self.in_datatype = act["dtype"]
        self.wgt_datatype = wgt["dtype"]
        self.wgt1_datatype = wgt1["dtype"] or ""
        self.out_datatype = out["dtype"]

        self.in_bytes = act["dtype_bytes"]
        self.wgt_bytes = wgt["dtype_bytes"]
        self.wgt1_bytes = wgt1["dtype_bytes"]
        self.out_bytes = out["dtype_bytes"]

        self.attributes = {
            "disable_dq1": [0],
            "disable_dq0": [0],
            "disable_q": [0]
        }
        # Operator-specific presets are added on top of the disable_* flags.
        self.attributes.update(get_standard_onnx_op_attributes(self.orig_op_type))

        self.qdq_symmetry = "None"
        self.coeff_shape = "None"
        self.in_act_residency = "L3"
        self.out_act_residency = "L3"
        self.Frequency = 5
        # One node name per occurrence: "<op_type>_0" ... "<op_type>_<Frequency-1>".
        self.nodenames = [self.op_type + "_" + str(i) for i in range(self.Frequency)]

    def write_to_file(self):
        """Write the node description to ``<orig_op_type>_IR_unique_nodes.json``.

        The file is created (or overwritten) at that relative path and holds a
        single top-level key ``"<orig_op_type>_0"``. ``add_remaining_fields()``
        must have been called first; otherwise the summary attributes read
        here do not exist and an ``AttributeError`` is raised.
        """
        node_key = self.orig_op_type + "_0"
        model_structure = {
            node_key: {
                "op_type": self.op_type,
                "inputs": self.inputs,
                "outputs": self.outputs,
                "in_act_shape": self.in_act_shape,
                "in_wgt_shape": self.in_wgt_shape,
                "in_wgt1_shape": self.in_wgt1_shape,
                "out_act_shape": self.out_act_shape,
                "in_datatype": self.in_datatype,
                "wgt_datatype": self.wgt_datatype,
                "wgt1_datatype": self.wgt1_datatype,
                "out_datatype": self.out_datatype,
                "in_bytes": self.in_bytes,
                "wgt_bytes": self.wgt_bytes,
                "wgt1_bytes": self.wgt1_bytes,
                "out_bytes": self.out_bytes,
                "attributes": self.attributes,
                "qdq_symmetry": self.qdq_symmetry,
                "coeff_shape": self.coeff_shape,
                "in_act_residency": self.in_act_residency,
                "out_act_residency": self.out_act_residency,
                "Frequency": self.Frequency,
                "nodenames": self.nodenames
            }
        }

        file_name = self.orig_op_type + "_IR_unique_nodes" + ".json"
        with open(file_name, 'w', encoding="utf-8") as f:
            json.dump(model_structure, f, indent=4)
        
    # def get_kernels_list(self):
    #     #generate model kernel names
    #     model_IR_path = self.orig_op_type + "_IR_unique_nodes" + ".json"
    #     get_kernels_for_model(model_IR_path, "../../Collaterals/", print_to_file = True)

def main():
    """Command-line entry point: build a ``model_IR`` from options and write it out.

    Every option is optional. The list-valued options are given as JSON text
    and decoded with ``json.loads``; argparse applies the same decoding to
    their string defaults. After construction the model is completed with
    ``add_scale_zp_fields()`` and ``add_remaining_fields()`` and then
    serialized with ``write_to_file()``.
    """
    parser = argparse.ArgumentParser(description="Initialize model_IR class and print its structure.")
    parser.add_argument('--op_type', type=str, default="Add_qdq_BroadCast_uint16xuint16xuint16", help='Operation type')

    # (flag, default as JSON text, help text) for every JSON-decoded option,
    # registered in this order so the --help listing is unchanged.
    json_options = (
        ('--list_in_shapes', '[[1, 1280, 16, 16], [1, 1280, 1, 1]]', 'Input shapes'),
        ('--list_in_dtypes', '["uint16", "uint16"]', 'Input data types'),
        ('--list_in_types', '["act", "wgt"]', 'Input types'),
        ('--list_out_shapes', '[[1, 1280, 16, 16]]', 'Output shapes'),
        ('--list_out_dtypes', '["uint16"]', 'Output data types'),
    )
    for flag, default_text, help_text in json_options:
        parser.add_argument(flag, type=json.loads, default=default_text, help=help_text)

    cli = parser.parse_args()

    node = model_IR(
        op_type=cli.op_type,
        list_in_shapes=cli.list_in_shapes,
        list_in_dtypes=cli.list_in_dtypes,
        list_in_types=cli.list_in_types,
        list_out_shapes=cli.list_out_shapes,
        list_out_dtypes=cli.list_out_dtypes,
    )

    # Order matters: the summary fields are derived from the input list after
    # the scale/zero-point entries have been appended to it.
    node.add_scale_zp_fields()
    node.add_remaining_fields()
    node.write_to_file()
    # model.get_kernels_list()

# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
    # Sample commands (list-valued options are passed as JSON text):
    # python IR_builder.py --op_type "Silu_qdq_uint16xuint16" --list_in_shapes '[[1, 1280, 16, 16]]' --list_in_dtypes '["uint16"]' --list_in_types '["act"]' --list_out_shapes '[[1, 1280, 16, 16]]' --list_out_dtypes '["uint16"]'
    # python IR_builder.py --op_type "Add_qdq_BroadCast_uint16xuint16xuint16" --list_in_shapes '[[1, 1280, 16, 16], [1, 1280, 1, 1]]' --list_in_dtypes '["uint16", "uint16"]' --list_in_types '["act", "wgt"]' --list_out_shapes '[[1, 1280, 16, 16]]' --list_out_dtypes '["uint16"]'



