# fmt: on
import os
import numpy as np
import onnx
import onnxruntime as ort
from tools.utils import replace_slashes

# Default directory the get_*() lookup helpers below search for generated
# artifacts: a "WAIC_Outputs" folder under the process's current working
# directory, resolved once when this module is imported.
curr_dir = os.getcwd()
results_dir = os.path.join(curr_dir, "WAIC_Outputs")


def get_fused_model(model_path, output_dir=results_dir):
    """Return the path of the fused ONNX model in *output_dir*, or None.

    Only file names ending in "_fused.onnx" qualify. A sibling
    "_fused.onnx.data" external-data file is not the model and is skipped.

    When several candidates exist, "<stem>_fused.onnx" is preferred, where
    <stem> is the basename of *model_path* without its extension. Otherwise
    the alphabetically first candidate is returned, so the result does not
    depend on the arbitrary ordering of os.listdir.

    Prints "Fused model not found" and returns None when nothing matches.
    """
    model_stem = os.path.splitext(os.path.basename(model_path))[0]
    candidates = sorted(
        name for name in os.listdir(output_dir) if name.endswith("_fused.onnx")
    )
    if not candidates:
        print("Fused model not found")
        return None
    preferred = model_stem + "_fused.onnx"
    chosen = preferred if preferred in candidates else candidates[0]
    return os.path.join(output_dir, chosen)


def get_ir_json(model_path, output_dir=results_dir):
    """Return the path of the fused-graph IR JSON file in *output_dir*, or None.

    The file name must end with "_fused_IR.json". A suffix match is used
    rather than a substring match, so names that merely contain the marker
    (e.g. "x_fused_IR.json.bak") are not picked.

    When several candidates exist, "<stem>_fused_IR.json" is preferred, where
    <stem> is the basename of *model_path* without its extension. Otherwise
    the alphabetically first candidate is returned, so the result does not
    depend on os.listdir ordering.

    Prints "IR JSON not found" and returns None when nothing matches.
    """
    model_stem = os.path.splitext(os.path.basename(model_path))[0]
    candidates = sorted(
        name for name in os.listdir(output_dir) if name.endswith("_fused_IR.json")
    )
    if not candidates:
        print("IR JSON not found")
        return None
    preferred = model_stem + "_fused_IR.json"
    chosen = preferred if preferred in candidates else candidates[0]
    return os.path.join(output_dir, chosen)


def get_node_json(model_path, output_dir=results_dir):
    """Return the path of the unique-nodes IR JSON file in *output_dir*, or None.

    The file name must end with "_fused_IR_unique_nodes.json". A suffix
    match is used rather than a substring match, so names that merely
    contain the marker are not picked.

    When several candidates exist, "<stem>_fused_IR_unique_nodes.json" is
    preferred, where <stem> is the basename of *model_path* without its
    extension. Otherwise the alphabetically first candidate is returned.

    Prints "Unique nodes JSON not found" and returns None when nothing matches.
    """
    suffix = "_fused_IR_unique_nodes.json"
    model_stem = os.path.splitext(os.path.basename(model_path))[0]
    candidates = sorted(
        name for name in os.listdir(output_dir) if name.endswith(suffix)
    )
    if not candidates:
        print("Unique nodes JSON not found")
        return None
    preferred = model_stem + suffix
    chosen = preferred if preferred in candidates else candidates[0]
    return os.path.join(output_dir, chosen)


def list_files(directory):
    """Return the names of the regular, non-hidden entries of *directory*.

    Entries that are directories (DirEntry.is_dir follows symlinks) and
    entries whose name starts with "." are left out. The order is whatever
    os.scandir yields; callers sort the result themselves.
    """
    with os.scandir(directory) as entries:
        return [
            entry.name
            for entry in entries
            if not entry.is_dir() and not entry.name.startswith(".")
        ]


def get_outs_from_graph(ort_session, outputs, data_list, input_nodes):
    """Run *ort_session* for the requested *outputs* and return its result.

    The i-th name in *input_nodes* is fed the i-th element of *data_list*.
    Any elements of *data_list* past len(input_nodes) are not fed.
    """
    feed = {name: data_list[pos] for pos, name in enumerate(input_nodes)}
    return ort_session.run(outputs, feed)


def _data_file_index(file_name):
    """Return the integer sample index encoded in a data file name.

    The index is taken from the last "_"-separated token with its extension
    stripped ("x_3.raw" -> 3). If that token is not numeric, the token before
    it is used instead ("x_3_nchw.raw" -> 3). A name with neither form
    raises ValueError/IndexError.
    """
    tokens = file_name.split("_")
    last = tokens[-1].split(".")[0]
    return int(last) if last.isdigit() else int(tokens[-2])


def _nth_data_file(base_dir, node_name, idx):
    """Return the path of the idx-th file (ordered by sample index) for *node_name*."""
    # Each node's dumps live in a directory named after the node, passed
    # through replace_slashes (presumably to make it a valid single path
    # component; confirm against tools.utils).
    node_dir = os.path.join(base_dir, replace_slashes(node_name))
    ordered = sorted(list_files(node_dir), key=_data_file_index)
    return os.path.join(node_dir, ordered[idx])


def get_data_from_idx(model_name, input_nodes, output_nodes, idx):
    """Locate the *idx*-th stored data file for every input and output node.

    Data files are looked up in per-node directories next to *model_name*.
    Within each directory, files are ordered by the sample index encoded in
    their names (see _data_file_index).

    Returns:
        (input_paths, output_paths): one path per entry of *input_nodes* /
        *output_nodes*, in the same order.

    Raises:
        IndexError: if a node directory holds fewer than idx + 1 files.
    """
    base_dir_path = os.path.dirname(model_name)
    input_paths = [_nth_data_file(base_dir_path, node, idx) for node in input_nodes]
    output_paths = [_nth_data_file(base_dir_path, node, idx) for node in output_nodes]
    return input_paths, output_paths


# ONNX TensorProto element-type codes mapped to numpy dtype names. Codes that
# numpy cannot represent directly (e.g. STRING=8, BFLOAT16=16) are omitted.
_ONNX_ELEM_TYPE_TO_DTYPE = {
    1: "float32",  # FLOAT
    2: "uint8",  # UINT8
    3: "int8",  # INT8
    4: "uint16",  # UINT16
    5: "int16",  # INT16
    6: "int32",  # INT32
    7: "int64",  # INT64
    9: "bool",  # BOOL
    10: "float16",  # FLOAT16
    11: "float64",  # DOUBLE
    12: "uint32",  # UINT32
    13: "uint64",  # UINT64
}


def get_tensor_type(tensor_type):
    """Return the numpy dtype name for ONNX element-type code *tensor_type*.

    Unsupported codes print "Unsupported tensor type" and return "".
    """
    dtype = _ONNX_ELEM_TYPE_TO_DTYPE.get(tensor_type)
    if dtype is None:
        print("Unsupported tensor type")
        return ""
    return dtype


def _load_tensor(path, shape, elem_type):
    """Load one tensor file.

    A ".npy" file is returned exactly as stored: dtype and shape come from
    the file and *shape* is NOT applied. Any other file is read as headerless
    raw data of the dtype get_tensor_type maps *elem_type* to, then reshaped
    to *shape*.
    """
    if os.path.splitext(path)[1] == ".npy":
        return np.load(path)
    # The dtype is only needed for raw files. Resolving it here avoids a
    # spurious "Unsupported tensor type" message for .npy files.
    dtype = get_tensor_type(elem_type)
    return np.fromfile(path, dtype=dtype).reshape(shape)


def get_and_reshape_tensors(
    input_paths, output_paths, input_shapes, output_shapes, input_types, output_types
):
    """Load the input and output tensors listed in *input_paths*/*output_paths*.

    The i-th path is paired with the i-th shape and ONNX element type of the
    same group; see _load_tensor for how each file is read.

    Returns:
        (input_tensors, output_tensors): numpy arrays, in path order.

    Raises:
        AssertionError: if a path list and its shape list differ in length.
        IndexError: if a type list is shorter than its path list.
    """
    assert len(input_paths) == len(input_shapes), (
        f"{len(input_paths)} input files but {len(input_shapes)} input shapes"
    )
    assert len(output_paths) == len(output_shapes), (
        f"{len(output_paths)} output files but {len(output_shapes)} output shapes"
    )
    input_tensors = [
        _load_tensor(path, input_shapes[i], input_types[i])
        for i, path in enumerate(input_paths)
    ]
    output_tensors = [
        _load_tensor(path, output_shapes[i], output_types[i])
        for i, path in enumerate(output_paths)
    ]
    return input_tensors, output_tensors


def get_input_nodes(model):
    """Return the names of *model*'s graph inputs, in graph order."""
    return [graph_input.name for graph_input in model.graph.input]

def get_output_nodes(model):
    """Return the names of *model*'s graph outputs, in graph order."""
    return [graph_output.name for graph_output in model.graph.output]


def transpose(data):
    """Permute the axes of *data* according to its rank; None if unhandled.

    - rank 4: if dims 2 and 3 are equal, axes become (0, 2, 3, 1). Otherwise,
      if dims 1 and 2 are equal, axes become (0, 3, 1, 2). Otherwise None.
    - rank 3: the last two axes are swapped, (0, 2, 1).
    - rank 2: permutation (0, 1).
    - any other rank: None.
    """
    rank = len(data.shape)
    if rank == 2:
        # NOTE(review): (0, 1) is the identity permutation, so 2-D data is
        # returned unchanged (as a view). Confirm whether (1, 0) was intended.
        return np.transpose(data, (0, 1))
    if rank == 3:
        return np.transpose(data, (0, 2, 1))
    if rank == 4:
        _, dim1, dim2, dim3 = data.shape
        if dim2 == dim3:
            return np.transpose(data, (0, 2, 3, 1))
        if dim1 == dim2:
            return np.transpose(data, (0, 3, 1, 2))
    return None


def _align_to_shapes(expected_shapes, slots, label):
    """Transpose each tensor in *slots* whose shape differs from *expected_shapes*.

    The returned list has the same length as *slots*, so positional indices
    are preserved even when an error is reported. Slot i < len(expected_shapes)
    is compared against expected_shapes[i]. Slots past that are expected to
    be None; any that are not are reported and passed through unchanged.
    *label* ("input" or "output") is only used in the printed messages.
    """
    if len(slots) < len(expected_shapes):
        print(
            f"ERROR: Expected {len(expected_shapes)} {label} tensors "
            f"but got {len(slots)}"
        )
    aligned = []
    for idx, tensor in enumerate(slots):
        if idx >= len(expected_shapes):
            if tensor is not None:
                print(f"ERROR: There is {label} data which shape info is not parsed")
            aligned.append(tensor)
        elif tensor is None:
            # Checked before touching .shape, which would raise on None.
            print(f"ERROR: {label.capitalize()} data value is None")
            aligned.append(None)
        elif list(tensor.shape) != expected_shapes[idx]:
            aligned.append(transpose(tensor))
        else:
            aligned.append(tensor)
    return aligned


def modify_input_tensors_for_fused_graph(model_name, data_list, max_input_count=1000):
    """Adapt *data_list* to the I/O shapes declared by the model at *model_name*.

    *data_list* uses a slot layout: indices [0, max_input_count) hold the
    model inputs (None in unused slots), and indices from max_input_count
    onward hold the outputs. Every tensor whose shape differs from the
    model's declared shape is passed through transpose(). transpose() may
    return None for shapes it does not handle; that None is stored as-is.

    Problems (None where a tensor is expected, unexpected extra tensors,
    missing tensors) are printed. The returned list keeps the same slot
    layout and the same length as *data_list*, so later indexing stays
    aligned.
    """
    _, input_shapes, _, _, output_shapes, _ = get_io_shapes_for_model(model_name)
    data_mod = _align_to_shapes(input_shapes, data_list[:max_input_count], "input")
    data_mod.extend(
        _align_to_shapes(output_shapes, data_list[max_input_count:], "output")
    )
    return data_mod


def _resolve_dims(value_info, model_name):
    """Return the dimensions of graph value *value_info* as a list of ints.

    A dimension whose dim_value is 0 (presumably a symbolic/unset dim) is
    replaced by a placeholder chosen from the first tag found in *model_name*:
    "PSU1" -> 2048, "PSU0" -> 64, "HFDS" -> 4096. If no tag is present, the
    0 is kept.
    """
    placeholder = next(
        (
            size
            for tag, size in (("PSU1", 2048), ("PSU0", 64), ("HFDS", 4096))
            if tag in model_name
        ),
        None,
    )
    dims = []
    for dim in value_info.type.tensor_type.shape.dim:
        if dim.dim_value == 0 and placeholder is not None:
            dims.append(placeholder)
        else:
            dims.append(dim.dim_value)
    return dims


def get_io_shapes_for_model(model_name):
    """Read names, shapes and ONNX element types of a model's graph I/O.

    Returns:
        (input_nodes, input_shapes, input_types,
         output_nodes, output_shapes, output_types)
        Names are strings, shapes are lists of ints (see _resolve_dims), and
        types are ONNX element-type codes. All are in graph order.

    For models whose path contains "PSR_v1.1.onnx", every output shape is
    reported as [1, 64, 64, 4] instead of being read from the graph.
    """
    # NOTE(review): the tag checks below run against the absolute path, so a
    # directory name containing a tag also triggers them; confirm intended.
    model_name = os.path.abspath(model_name)
    # Only graph I/O metadata is read, so skip loading external weight files
    # (e.g. "*.onnx.data"), which can be large.
    model_orig = onnx.load(model_name, load_external_data=False)
    graph_inputs = model_orig.graph.input
    graph_outputs = model_orig.graph.output

    input_nodes = [vi.name for vi in graph_inputs]
    input_shapes = [_resolve_dims(vi, model_name) for vi in graph_inputs]
    input_types = [vi.type.tensor_type.elem_type for vi in graph_inputs]

    output_nodes = [vi.name for vi in graph_outputs]
    if "PSR_v1.1.onnx" in model_name:
        # Hard-coded override for this model: a fresh list per output.
        output_shapes = [[1, 64, 64, 4] for _ in graph_outputs]
    else:
        output_shapes = [_resolve_dims(vi, model_name) for vi in graph_outputs]
    output_types = [vi.type.tensor_type.elem_type for vi in graph_outputs]

    return (
        input_nodes,
        input_shapes,
        input_types,
        output_nodes,
        output_shapes,
        output_types,
    )


def get_input_for_model(model_name, data_idx, max_input_count=1000):
    """Load stored sample *data_idx* for *model_name* in the slot layout.

    Returns:
        data_list: input tensors at indices 0..n_inputs-1, None padding up to
            index *max_input_count*, then the output tensors starting at
            index *max_input_count*.
        file_name_list: basenames of the input data files followed by those
            of the output data files.

    Raises:
        ValueError: if the model has more inputs than *max_input_count*
            slots. The outputs would otherwise start at the wrong index.
    """
    (
        input_nodes,
        input_shapes,
        input_types,
        output_nodes,
        output_shapes,
        output_types,
    ) = get_io_shapes_for_model(model_name)
    if len(input_nodes) > max_input_count:
        raise ValueError(
            f"Model has {len(input_nodes)} inputs, more than the "
            f"{max_input_count} input slots available"
        )

    input_paths, output_paths = get_data_from_idx(
        model_name, input_nodes, output_nodes, data_idx
    )
    input_tensors, output_tensors = get_and_reshape_tensors(
        input_paths,
        output_paths,
        input_shapes,
        output_shapes,
        input_types,
        output_types,
    )

    data_list = list(input_tensors)
    # Pad unused input slots so the outputs always start at max_input_count.
    data_list.extend([None] * (max_input_count - len(data_list)))
    data_list.extend(output_tensors)

    file_name_list = [os.path.basename(p) for p in input_paths + output_paths]
    return data_list, file_name_list
