import onnx
import onnxruntime as ort
import os
import re
import numpy as np
import math
import json

def remove_leading_0s(hexstr, num_hex = 8):
    """Strip leading '0' characters from a hex digit string.

    Only the first ``num_hex`` characters may be stripped; any zeros after
    that position are kept. The result is never empty: the literal
    ``'00000000'``, and any string that would otherwise strip down to
    nothing, becomes ``'0'`` so that a zero value still prints one digit.

    Args:
        hexstr: string of hex digits.
        num_hex: maximum number of leading characters eligible for stripping.

    Returns:
        The stripped string (at least one character).
    """
    if hexstr == '00000000':
        return '0'
    # Negative widths strip nothing, matching the old range()-based loop.
    limit = max(0, num_hex)
    stripped = hexstr[:limit].lstrip('0') + hexstr[limit:]
    # An all-zero field previously came back as '' (or raised IndexError when
    # shorter than num_hex); emit a single '0' instead.
    return stripped or '0'

def write_2_hex(array, num_bits_per_row, file_name):
    """Write a flat array to a text file as rows of hex digits.

    Each line covers ``num_bits_per_row // 8`` consecutive elements. Within a
    line the elements are emitted in reverse order (the highest-indexed
    element of the row first), each as the hex of its raw bytes
    (``tobytes().hex()``). Leading zeros of every line are then removed with
    ``remove_leading_0s``. The last line holds the remaining elements when
    ``len(array)`` is not a multiple of the row width. An empty array yields an
    empty file. The number of lines is printed to stdout.

    Args:
        array: indexable sequence of numpy scalars (e.g. an ``np.int8``
            vector). The zero-stripping window is ``2 * len(row)`` hex digits,
            i.e. it assumes one byte per element.
        num_bits_per_row: bits per output line; must be at least 8.
        file_name: destination path (overwritten).

    Raises:
        ValueError: if ``num_bits_per_row`` is smaller than 8.
    """
    bits_per_hex = 4
    hex_per_byte = 2
    num_vals_per_row = int(num_bits_per_row // (bits_per_hex * hex_per_byte))
    if num_vals_per_row < 1:
        raise ValueError(
            "num_bits_per_row must be at least 8, got %r" % (num_bits_per_row,))

    num_vals = len(array)
    num_rows_to_write = -(-num_vals // num_vals_per_row)  # ceil division
    with open(file_name, "w") as f:
        print(num_rows_to_write)
        for start in range(0, num_vals, num_vals_per_row):
            row = array[start:start + num_vals_per_row]
            # Highest index first, so the row reads as one wide word.
            hex_str = ''.join(val.tobytes().hex() for val in row[::-1])
            hex_str = remove_leading_0s(hex_str, len(row) * hex_per_byte)
            f.write(hex_str + '\n')

def reshape_act(act_tensor):
    """Flatten a 4-D activation into a channel-blocked int8 vector.

    The tensor is transposed with axes ``(0, 2, 3, 1)`` and its first axis is
    dropped, so the first axis must have size 1. The resulting
    ``(H, W, C)`` array is zero-padded:

    * channels up to a multiple of the group size, which is 4 when there
      are exactly 3 channels and 8 otherwise;
    * for 3-channel tensors only, an odd width gets one extra zero column.

    The output is ordered by row, then channel group, then column, with the
    group's consecutive channel values for each (row, group, column). The
    padded shape is printed to stdout.

    Args:
        act_tensor: rank-4 array-like (presumably NCHW with N == 1, given the
            transpose; TODO confirm with callers).

    Returns:
        1-D ``np.int8`` array of length ``H * W_padded * C_padded``.
    """
    act = np.transpose(np.asarray(act_tensor), (0, 2, 3, 1))
    # Drop the leading axis; reshape raises if its size is not 1.
    act = act.reshape(act.shape[1:])
    num_h, num_w, num_ch = act.shape

    padded_w = num_w
    if num_ch == 3:
        ch_group = 4
        if num_w % 2 == 1:
            padded_w += 1
    else:
        ch_group = 8

    num_padded_ch = -(-num_ch // ch_group) * ch_group  # ceil to a multiple
    new_shape = np.array([num_h, padded_w, num_padded_ch], dtype=np.int64)
    print(new_shape)

    # Keep the source dtype instead of float64 so the final int8 cast of
    # integer inputs is a defined wraparound, not a float->int conversion.
    act_padded = np.zeros((num_h, padded_w, num_padded_ch), dtype=act.dtype)
    act_padded[:, :num_w, :num_ch] = act

    ch_loop = num_padded_ch // ch_group
    # (H, W, C) -> (H, W, groups, ch_group) -> (H, groups, W, ch_group): the
    # same element order as looping over h, then channel group, then w.
    blocked = act_padded.reshape(num_h, padded_w, ch_loop, ch_group)
    return blocked.transpose(0, 2, 1, 3).reshape(-1).astype(np.int8)

def replace_slashes(line):
    """Turn a tensor name into a string usable as a file name.

    Every character other than a word character (letter, digit, underscore),
    whitespace, or ``-`` is replaced by ``_``. This covers ``/`` as well as
    ``.``, ``:`` and other punctuation.
    """
    unsafe_chars = re.compile(r'[^\w\s-]')
    return unsafe_chars.sub('_', line)

def save_edges(edges_list, ort_outs_orig_dict,
               top_folder_name, folder = "results", txt_output = False):
    """Dump the captured tensors for a list of names into one folder.

    Args:
        edges_list: iterable of tensor names. Newline characters are removed
            before lookup, so lines read from a text file can be passed as-is.
        ort_outs_orig_dict: mapping of tensor name -> array. A name missing
            from it raises ``KeyError``.
        top_folder_name: parent directory.
        folder: subdirectory of ``top_folder_name`` that receives the files
            (created if needed).
        txt_output: if true, write ``reshape_act`` output as hex text to
            ``<name>.txt``; otherwise write raw bytes to ``<name>.bin``.
            ``<name>`` is the tensor name passed through ``replace_slashes``.
    """
    out_dir = os.path.join(top_folder_name, folder)
    print("Dumping data in " + out_dir)
    os.makedirs(out_dir, exist_ok=True)
    for raw_name in edges_list:
        tensor_name = raw_name.replace('\n', '')
        tensor = ort_outs_orig_dict[tensor_name]
        safe_name = replace_slashes(tensor_name)
        if not txt_output:
            tensor.tofile(os.path.join(out_dir, safe_name + '.bin'))
            continue
        write_2_hex(reshape_act(tensor), 32,
                    os.path.join(out_dir, safe_name + '.txt'))


def save_ifm_ofm(model_orig, ifm_ofm_dict, ort_outs_orig_dict,
                 top_folder_name, prefix = "", txt_output = False):
    """Dump input/output feature-map tensors to ``<top_folder_name>/<prefix>``.

    With ``prefix == "ort"``, only the tensors named by
    ``model_orig.graph.output`` are dumped and ``ifm_ofm_dict`` is ignored.
    Otherwise, for every value of ``ifm_ofm_dict``, the names in its
    ``'ifm'`` list and then its ``'ofm'`` list are dumped. ``model_orig`` is
    not touched in that case. Names absent from ``ort_outs_orig_dict`` are
    skipped silently.

    Each tensor is written either as raw bytes to ``<name>.bin`` or, with
    ``txt_output``, passed through ``reshape_act`` and written as hex rows to
    ``<name>.txt``. ``<name>`` is the tensor name passed through
    ``replace_slashes``.
    """
    folder_name = os.path.join(top_folder_name, prefix)
    os.makedirs(folder_name, exist_ok=True)

    def _dump(tensor_name):
        # Write one tensor if it was captured; silently skip it otherwise.
        if tensor_name not in ort_outs_orig_dict:
            return
        output = ort_outs_orig_dict[tensor_name]
        file_name = replace_slashes(tensor_name)
        if txt_output:
            tensor_write_loc = os.path.join(folder_name, file_name + '.txt')
            write_2_hex(reshape_act(output), 32, tensor_write_loc)
        else:
            tensor_write_loc = os.path.join(folder_name, file_name + '.bin')
            output.tofile(tensor_write_loc)

    if prefix == "ort":
        for out in model_orig.graph.output:
            _dump(out.name)
        return

    for value in ifm_ofm_dict.values():
        for tensor_name in value['ifm']:
            _dump(tensor_name)
        for tensor_name in value['ofm']:
            _dump(tensor_name)

def save_wgt_bias_from_pickle(weights_qdq_dict, ini_dict, top_folder_name):
    """Dump per-node constants and quantization parameters from ``ini_dict``.

    For every ``key -> params`` entry of ``weights_qdq_dict``, the folder
    ``<top_folder_name>/<replace_slashes(key)>`` is created. Each
    ``param_name -> tensor_names`` pair in ``params`` (except ``"dummy"``) is
    handled as follows:

    * ``param_name`` containing ``_scale`` or ``_zero_point``: the value of
      ``ini_dict[tensor_names[0]]`` is recorded in ``graph_params.json`` under
      ``"qdq"``. A one-element tensor is stored as a scalar, anything else as
      returned by ``tolist()``. Names missing from ``ini_dict`` are reported
      on stdout and skipped.
    * any other ``param_name``: ``ini_dict[tensor_names[0]]`` is written raw
      to ``<param_name>.bin``. Entries that do not have exactly one tensor
      name are reported on stdout and skipped.

    ``graph_params.json`` is written for every node, even when its ``"qdq"``
    list is empty.
    """
    for key, value in weights_qdq_dict.items():
        folder_name = os.path.join(top_folder_name, replace_slashes(key))
        os.makedirs(folder_name, exist_ok=True)
        qdq_params = []
        for k, val in value.items():
            if k == "dummy":
                continue
            if "_scale" in k or "_zero_point" in k:
                if val[0] not in ini_dict:
                    print("Wrong input for node " + key)
                    continue
                tmp = ini_dict[val[0]].tolist()
                # Unwrap one-element tensors to a scalar. Empty tensors stay an
                # empty list (indexing [0] on them used to raise IndexError).
                if isinstance(tmp, list) and len(tmp) == 1:
                    tmp = tmp[0]
                qdq_params.append({k: tmp})
            else:
                if len(val) != 1:
                    print("Error: Multiple tensor names for one param_name: ", val)
                    continue
                const_out_name = os.path.join(folder_name, k)
                ini_dict[val[0]].tofile(const_out_name + '.bin')

        with open(os.path.join(folder_name, 'graph_params.json'), 'w') as f:
            json.dump({'qdq': qdq_params}, f, indent=2)

def save_wgt_bias(weights_qdq_dict, ort_outs_orig_dict, top_folder_name):
    """Dump per-node constants and quantization parameters from ORT results.

    For every ``key -> params`` entry of ``weights_qdq_dict``, the folder
    ``<top_folder_name>/<replace_slashes(key)>`` is created. Each
    ``param_name -> tensor_names`` pair in ``params`` (except ``"dummy"``) is
    handled as follows:

    * ``param_name`` containing ``_scale`` or ``_zero_point``: the value of
      ``ort_outs_orig_dict[tensor_names[0]]`` is recorded in
      ``graph_params.json`` under ``"qdq"``. A one-element tensor is stored as
      a scalar, anything else as returned by ``tolist()``. A missing name
      raises ``KeyError``.
    * any other ``param_name``: the tensor is written raw to
      ``<param_name>.bin``. Entries that do not have exactly one tensor name
      are reported on stdout and skipped.

    ``graph_params.json`` is written for every node, even when its ``"qdq"``
    list is empty.
    """
    for key, value in weights_qdq_dict.items():
        folder_name = os.path.join(top_folder_name, replace_slashes(key))
        os.makedirs(folder_name, exist_ok=True)
        qdq_params = []
        for k, val in value.items():
            if k == "dummy":
                continue
            if "_scale" in k or "_zero_point" in k:
                tmp = ort_outs_orig_dict[val[0]].tolist()
                # Unwrap one-element tensors to a scalar. Empty tensors stay an
                # empty list (indexing [0] on them used to raise IndexError).
                if isinstance(tmp, list) and len(tmp) == 1:
                    tmp = tmp[0]
                qdq_params.append({k: tmp})
            else:
                if len(val) != 1:
                    print("Error: Multiple tensor names for one param_name: ", val)
                    continue
                const_out_name = os.path.join(folder_name, k)
                ort_outs_orig_dict[val[0]].tofile(const_out_name + '.bin')

        with open(os.path.join(folder_name, 'graph_params.json'), 'w') as f:
            json.dump({'qdq': qdq_params}, f, indent=2)

def append_outputs_to_model(model_orig):
    """Add the graph inputs and every Q/DQ node output to ``graph.output``.

    Mutates ``model_orig`` in place by appending name-only
    ``onnx.ValueInfoProto`` entries (no type information). Graph inputs come
    first, then the first output of each ``QuantizeLinear`` /
    ``DequantizeLinear`` node in node order. Names already present in
    ``graph.output`` are not appended again, so repeated calls or tensors that
    are already model outputs do not create duplicate outputs.
    """
    outputs = model_orig.graph.output
    existing = {out.name for out in outputs}

    def _add(name):
        # Append `name` once; skip names that are already graph outputs.
        if name in existing:
            return
        existing.add(name)
        outputs.extend([onnx.ValueInfoProto(name=name)])

    for inp in model_orig.graph.input:
        _add(inp.name)
    for node in model_orig.graph.node:
        if node.op_type in ("QuantizeLinear", "DequantizeLinear"):
            _add(node.output[0])

def append_outputs_to_fused_model(model_orig, _dict):
    """Add every tensor named in ``_dict`` that exists in the model to ``graph.output``.

    ``_dict`` maps keys to inner mappings whose values are lists of tensor
    names; newline characters are removed from each name. A name is
    considered present if it (with or without newlines) is an input or output
    of some node, as collected by ``get_all_model_inputs``.

    Present names are appended to ``model_orig.graph.output`` as name-only
    ``onnx.ValueInfoProto`` entries, once each. Names already listed there,
    such as a tensor that is the ofm of one node and the ifm of the next, are
    not appended twice. Absent names are reported on stdout.
    """
    # Set for O(1) membership; get_all_model_inputs returns a list.
    all_inputs = set(get_all_model_inputs(model_orig))
    outputs = model_orig.graph.output
    existing = {out.name for out in outputs}
    for value in _dict.values():
        for names_list in value.values():
            for ln in names_list:
                line = ln.replace('\n', '')
                # Extend outputs to get intermediate results.
                if ln in all_inputs or line in all_inputs:
                    if line not in existing:
                        existing.add(line)
                        outputs.extend([onnx.ValueInfoProto(name=line)])
                else:
                    print('\x1b[0;37;41m' +
                          line + " is missing from model" '\x1b[0m')

def get_all_model_inputs(onnx_model):
    """Return the distinct tensor names used by the graph's nodes.

    Despite the name, both node inputs and node outputs are included. The
    result is a list built from a set, so it has no duplicates and no
    meaningful order.
    """
    tensor_names = {
        name
        for node in onnx_model.graph.node
        for name in (*node.input, *node.output)
    }
    return list(tensor_names)
