import json
import numpy as np
import os
import argparse

# Path of the "fused_nodes_inputs.json" file that lives in the same
# directory as this module.
fused_nodes_inputs = os.path.join(
    os.path.dirname(__file__),
    "fused_nodes_inputs.json",
)


def _act_ifm_ofm(node):
    """Return ``(ifm, ofm)`` name lists for a single IR node entry.

    ``ifm`` holds the ``name`` of every input whose ``type`` is ``"act"``;
    ``ofm`` is a one-element list holding the node's ``out_act_signal_name``.
    """
    ifm = [inp["name"] for inp in node["inputs"] if inp["type"] == "act"]
    return ifm, [node["out_act_signal_name"]]


def _find_tiling_json(ir_json):
    """Return the path of the ``*tilings.json`` file next to *ir_json*.

    Raises:
        FileNotFoundError: if that directory contains no matching file.
    """
    # abspath first: dirname("ir.json") is "" and os.listdir("") fails.
    dir_name = os.path.dirname(os.path.abspath(ir_json))
    # Sorted so the pick is deterministic when several files match
    # (os.listdir order is filesystem-dependent); the last match wins.
    matches = sorted(f for f in os.listdir(dir_name) if "tilings.json" in f)
    if not matches:
        raise FileNotFoundError(
            f"no file matching '*tilings.json' found in {dir_name!r}")
    return os.path.join(dir_name, matches[-1])


def _signal_listed(entries, name):
    """Return True if any element of *entries* refers to signal *name*.

    Elements may be plain strings or dicts carrying a ``"name"`` key (the
    form the other modes read from a node's ``inputs``).
    """
    return any(
        (entry.get("name") == name) if isinstance(entry, dict) else (entry == name)
        for entry in entries)


def extract_ifm_ofm_channels(ir_json, node_json, all_nodes = "unique"):
    """Map IR node names to their input/output activation signal names.

    Args:
        ir_json: path to the IR JSON, a dict keyed by node name. Node entries
            are read for ``inputs`` and, depending on the mode,
            ``out_act_signal_name`` or ``outputs``.
        node_json: path to a JSON dict whose values carry a ``nodenames``
            list. It is always loaded but only consulted by ``"unique"``.
        all_nodes: selection mode:

            * ``"unique"``: the first node of every ``node_json`` entry.
            * ``"fused_unique"``: the first node of every tiling entry whose
              ``layer_info.op_type`` differs from ``orig_op_type``.
            * ``"fused"``: every node of those tiling entries. All nodes of
              one entry share a single dict aggregating the group's
              ``ifm``/``ofm`` lists.
            * ``"all"``: every node in the IR.
            * anything else: a path to a text file with one signal name per
              line. Each node whose ``inputs``/``outputs`` list that signal
              gets it appended to its ``ifm``/``ofm``. Nodes matching nothing
              are omitted.

            The two fused modes read the ``*tilings.json`` file located in
            the same directory as *ir_json*.

    Returns:
        ``{node_name: {"ifm": [...], "ofm": [...]}}``.

    Raises:
        FileNotFoundError: a fused mode is requested and no
            ``*tilings.json`` file exists next to *ir_json*.
    """
    with open(ir_json) as f:
        json_dict = json.load(f)
    with open(node_json) as f:
        node_dict = json.load(f)

    graph_ifm_ofm_dict = {}
    if all_nodes == "unique":
        for values in node_dict.values():
            single_node = values['nodenames'][0]
            ifm_list, ofm_list = _act_ifm_ofm(json_dict[single_node])
            graph_ifm_ofm_dict[single_node] = {'ifm': ifm_list, 'ofm': ofm_list}
    elif all_nodes in ("fused_unique", "fused"):
        # The tiling file is only needed by the fused modes; loading it
        # unconditionally made every other mode fail when it was absent.
        with open(_find_tiling_json(ir_json)) as f:
            tiling_dict = json.load(f)
        for values in tiling_dict.values():
            # Not every tiling entry describes a layer; skip those in both
            # fused modes (previously only "fused" did).
            if 'layer_info' not in values:
                continue
            layer_info = values['layer_info']
            # An unchanged op type means the layer was not fused.
            if layer_info['op_type'] == layer_info['orig_op_type']:
                continue
            if all_nodes == "fused_unique":
                single_node = layer_info['nodenames'][0]
                ifm_list, ofm_list = _act_ifm_ofm(json_dict[single_node])
                graph_ifm_ofm_dict[single_node] = {'ifm': ifm_list, 'ofm': ofm_list}
            else:
                # Every node of the fused group points at the same dict,
                # which accumulates the whole group's signals.
                group_dict = {'ifm': [], 'ofm': []}
                for single_node in layer_info['nodenames']:
                    ifm_list, ofm_list = _act_ifm_ofm(json_dict[single_node])
                    group_dict['ifm'].extend(ifm_list)
                    group_dict['ofm'].extend(ofm_list)
                    graph_ifm_ofm_dict[single_node] = group_dict
    elif all_nodes == "all":
        for key, values in json_dict.items():
            ifm_list, ofm_list = _act_ifm_ofm(values)
            graph_ifm_ofm_dict[key] = {'ifm': ifm_list, 'ofm': ofm_list}
    else:
        # `with` closes the channels file (it used to be leaked); rstrip
        # also removes '\r' from CRLF files.
        with open(all_nodes, 'r') as channels_file:
            channel_names = [line.rstrip('\r\n') for line in channels_file]
        for line in channel_names:
            if not line:
                continue
            for key, values in json_dict.items():
                # Inputs are dicts elsewhere in this module; a bare
                # `line in values['inputs']` never matched them.
                ifm_list = [line] if _signal_listed(values['inputs'], line) else []
                ofm_list = [line] if _signal_listed(values['outputs'], line) else []
                if key in graph_ifm_ofm_dict:
                    graph_ifm_ofm_dict[key]['ifm'].extend(ifm_list)
                    graph_ifm_ofm_dict[key]['ofm'].extend(ofm_list)
                elif ifm_list or ofm_list:
                    graph_ifm_ofm_dict[key] = {'ifm': ifm_list, 'ofm': ofm_list}

    return graph_ifm_ofm_dict

def extract_const_channels(ir_json):
    with open(ir_json) as f:
        json_dict = json.load(f)
    weights_qdq_dict = {}
    for key, value in json_dict.items():
        node_wgt_qdq = {}
        op_type = value['op_type']
        if "runtime" in op_type or "noop" in op_type:
            continue
        idx = 0
        for v in value['inputs']:
            if v['type'] == "const":
                if not "param_name" in v and not "name" in v:
                    continue
                name = ""
                if "param_name" in v:
                    if not "scale" in v['param_name'] and not "zero_point" in v['param_name']:
                        if v['name'].endswith("scale"):
                            name = "new_scale_" + str(idx)
                            idx = idx + 1
                        elif v['name'].endswith("zero_point"):
                            name = "new_zero_point_" + str(idx)
                            idx = idx + 1
                        else:
                            name = v['param_name']
                    else:
                        name = v['param_name']
                elif v['name'].endswith("scale"):
                    name = "new_scale_" + str(idx)
                    idx = idx + 1
                elif v['name'].endswith("zero_point"):
                    name = "new_zero_point_" + str(idx)
                    idx = idx + 1
                else:
                    continue
                node_wgt_qdq[name] = [v['name']]
        weights_qdq_dict[key] = node_wgt_qdq

    return weights_qdq_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract IFM/OFM activation channels and constant "
                    "channels from an IR JSON and print them as JSON.")
    parser.add_argument("--ir_json", required=True)
    parser.add_argument("--node_json", required=True)
    # Optional; the default is the mode the script has always used.
    parser.add_argument(
        "--all_nodes", default="unique",
        help='"unique", "fused_unique", "fused", "all", or a path to a file '
             'listing one signal name per line')

    args = parser.parse_args()
    ir_json    = os.path.abspath(args.ir_json)
    node_json  = os.path.abspath(args.node_json)

    # Use the resolved paths (they were computed but never passed): the
    # tiling lookup takes dirname(ir_json), which is "" for a bare filename.
    ifm_ofm = extract_ifm_ofm_channels(ir_json, node_json, args.all_nodes)
    const_channels = extract_const_channels(ir_json)
    # Emit the results; previously both were computed and discarded.
    print(json.dumps({"ifm_ofm": ifm_ofm, "const": const_channels}, indent=2))
