# fmt: on
import json
import sys
import os
import argparse


# Keys read from every entry of the unique IR json, in CSV column order.
# The commented-out names are columns that are currently disabled.
value_keys = [
    "op_type",
    "in_act_shape",
    "in_wgt_shape",
    #"in_wgt1_shape",
    "out_act_shape",
    #"in_datatype",
    #"wgt_datatype",
    #"wgt1_datatype",
    #"out_datatype",
    "perm",
]


def extract_data(raw_data) -> list[tuple[list[str], int]]:
    """Collect the values listed in ``value_keys`` for every entry.

    Args:
        raw_data: Mapping whose values are dict-like entries (the script
            passes the parsed unique IR json here).

    Returns:
        One ``(values, num_nodes)`` tuple per entry, in iteration order.
        ``values`` holds the stringified fields of one CSV row and
        ``num_nodes`` is the length of the entry's ``"nodenames"`` list
        (0 when the key is absent).

    For entries whose ``op_type`` contains ``"MatMul"``, the ``"perm"``
    column is filled from the entry's ``"attributes"``: every attribute
    whose name contains ``"perm"`` contributes one value, de-duplicated on
    the last character of the attribute name.  Missing keys are rendered
    as the string ``"None"``.

    NOTE(review): a MatMul row therefore gets zero, one or several perm
    fields while the CSV header has a single "perm" column -- confirm this
    is what the consumers of the CSV expect.
    """
    extracted_values: list[tuple[list[str], int]] = []

    for item in raw_data.values():
        # A null/missing op_type must not break the substring test.
        is_matmul = "MatMul" in (item.get("op_type") or "")

        # Aggregate all of the needed values in a list; this corresponds
        # to one row of the final csv output.
        values: list[str] = []
        for value_key in value_keys:
            if value_key == "perm" and is_matmul:
                # Rebuild the perm mapping from scratch for each entry so an
                # entry without attributes neither crashes nor inherits the
                # perm values of a previously processed entry.
                attributes = item.get("attributes") or {}
                perm = {
                    name[-1]: value
                    for name, value in attributes.items()
                    if "perm" in name
                }
                values.extend(str(value) for value in perm.values())
            else:
                values.append(str(item.get(value_key)))

        num_nodes = len(item.get("nodenames", []))
        extracted_values.append((values, num_nodes))

    return extracted_values


def convert_to_csv_string_repr(row_data: list[str]) -> str:
    """Return the fields of one row joined into a single ';'-separated string."""
    separator = ";"
    return separator.join(row_data)


def print_data(model_name: str, extracted_data: list[tuple[list[str], int]]) -> None:
    """Print the extracted rows as ';'-separated CSV, merging duplicate rows.

    Args:
        model_name: Value printed in the first column of every data row.
        extracted_data: ``(row_fields, num_nodes)`` tuples to report.

    Rows whose fields are identical are printed once, with the sum of
    their node counts in the last column.  Rows keep the order in which
    they first appear.
    """
    # Each row is reduced to its csv string so identical rows share one
    # dictionary key; the value accumulates the node count for that row.
    node_totals: dict[str, int] = {}
    for fields, node_count in extracted_data:
        row_key = convert_to_csv_string_repr(fields)
        node_totals[row_key] = node_totals.get(row_key, 0) + node_count

    # Header line, followed by one line per distinct row.
    columns = ";".join(value_keys)
    print(f"model;{columns};number of nodes")
    for row_key, total in node_totals.items():
        print(";".join((model_name, row_key, str(total))))


def extract_and_count_op_type(model_name: str, data) -> None:
    """Print, per op type, the total node count and the number of entries.

    Args:
        model_name: Value printed in the first column of every data row.
        data: Mapping whose values are dict-like entries; ``"op_type"``
            and ``"nodenames"`` are read from each entry.

    The output is ';'-separated CSV with the header
    ``model;op_type;nb_nodes;nb_unique``.  ``nb_nodes`` is the summed
    length of the ``"nodenames"`` lists and ``nb_unique`` is the number of
    entries seen for that op type.  Op types are printed in the order in
    which they first appear.
    """
    op_type_node_count: dict[str, int] = {}
    op_type_unique_count: dict[str, int] = {}

    for item in data.values():
        # Stringify so an entry without "op_type" is reported as "None"
        # instead of making the ";".join(...) below raise TypeError.
        op_type = str(item.get("op_type"))
        num_nodes = len(item.get("nodenames", []))
        op_type_node_count[op_type] = op_type_node_count.get(op_type, 0) + num_nodes
        op_type_unique_count[op_type] = op_type_unique_count.get(op_type, 0) + 1

    print("model;op_type;nb_nodes;nb_unique")
    for op_type, nb_nodes in op_type_node_count.items():
        print(
            ";".join(
                [
                    model_name,
                    op_type,
                    str(nb_nodes),
                    str(op_type_unique_count[op_type]),
                ]
            )
        )


def parse_args() -> argparse.Namespace:
    """Parse the command line of the script and return the resulting namespace.

    Recognised options:
        --shape       boolean switch (False unless given)
        --op_type     boolean switch (False unless given)
        --input       optional string value (None unless given)
        --model_name  optional string value (None unless given)
    """
    cli = argparse.ArgumentParser(
        description="Tool to extract and print various information about a model from the unique IR json file"
    )

    # Boolean switches, registered in the order they appear in --help.
    switch_help = {
        "--shape": "extract and print in csv format the shape of the input/output/weight of each op types along with their number of occurrences in the model.",
        "--op_type": "extract and print in csv format the op type of the nodes in the model along with their occurence",
    }
    for switch, help_text in switch_help.items():
        cli.add_argument(switch, action="store_true", help=help_text)

    # Options that take a value.
    option_help = {
        "--input": "Path to the unique IR json file used to extract the data",
        "--model_name": "Specify the name of the model from which we are extracting the data",
    }
    for option, help_text in option_help.items():
        cli.add_argument(option, help=help_text)

    return cli.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # File name suffix that identifies a unique IR json file.
    unique_ir_suffix = "_mod_nhwc_fused_IR_unique_nodes.json"

    # Make sure the input is the unique IR json file.  This is an explicit
    # check rather than an assert so it still runs under `python -O`; a
    # missing --input is treated like a wrong file instead of crashing in
    # os.path.basename(None).
    unique_ir_json_path: str = args.input or ""
    unique_ir_json_filename: str = os.path.basename(unique_ir_json_path)
    if not unique_ir_json_filename.endswith(unique_ir_suffix):
        sys.exit("Wrong input file. Unique IR json should be passed")

    # Use the given model name, or derive it from the unique IR json file
    # name by dropping the suffix.
    if args.model_name:
        model_name = args.model_name
    else:
        model_name = unique_ir_json_filename.removesuffix(unique_ir_suffix)

    # Load the json file
    with open(unique_ir_json_path, "r", encoding="utf-8") as fd:
        data = json.load(fd)

    # Extract the shape of the input/output/weight of the various nodes.
    # Count their occurrences and print them in csv format
    if args.shape:
        extracted_data = extract_data(data)
        print_data(model_name, extracted_data)

    # Extract and print in csv format the occurrence of each op type
    if args.op_type:
        extract_and_count_op_type(model_name, data)
