import os
import glob
import yaml
import copy
import json
import numpy
import argparse
import logging
import onnx
import subprocess
import re
import sys
from collections import defaultdict
from onnx.helper import make_attribute
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list


def parse_calltree(lines, wrapper_size_dict, wrapper_levels_dict):
    """Build a nested size tree from indentation-structured call-tree lines.

    Nesting is derived from each line's leading-whitespace width; the
    stripped text of the line is the node name.

    Args:
        lines: Call-tree lines. A line becomes a child of the closest
            preceding line that has a strictly smaller indent.
        wrapper_size_dict: Maps node name to its size. Looked up for every
            name that is not already present in ``wrapper_levels_dict``.
        wrapper_levels_dict: Maps node name to an already-built
            ``{'size': ..., 'subkernels': {...}}`` entry. When a name is
            found here, a deep copy of that entry is inserted instead of
            creating a fresh node.

    Returns:
        dict: ``{name: {'size': ..., 'subkernels': {...}}}`` for the
        top-level lines; empty when ``lines`` is empty.

    Raises:
        KeyError: If a name is in neither ``wrapper_levels_dict`` nor
            ``wrapper_size_dict``.
    """
    root = {}
    if not lines:
        return root

    # Stack of (indent, children-dict); the -1 sentinel is never popped
    # because real indents are >= 0.
    stack = [(-1, root)]
    for line in lines:
        indent = len(line) - len(line.lstrip())
        key = line.strip()
        while stack and indent <= stack[-1][0]:
            stack.pop()
        parent_dict = stack[-1][1]
        if key in wrapper_levels_dict:
            parent_dict[key] = copy.deepcopy(wrapper_levels_dict[key])
        else:
            parent_dict[key] = {'size': wrapper_size_dict[key], 'subkernels': {}}
        stack.append((indent, parent_dict[key]['subkernels']))

    def parse_dict(wrapper, sub_dict, wrapper_list):
        # Depth-first walk that records every visited name. The first child
        # whose name was already visited is removed from its parent and the
        # scan of that parent's remaining children stops there (returning
        # right after the pop also avoids mutating the dict mid-iteration).
        wrapper_list.append(wrapper)
        if sub_dict["subkernels"]:
            for sub_kernel in sub_dict["subkernels"]:
                if sub_kernel in wrapper_list:
                    sub_dict["subkernels"].pop(sub_kernel)
                    return
                parse_dict(sub_kernel, sub_dict["subkernels"][sub_kernel], wrapper_list)

    # Keys in ``root`` are stripped names, so the first line must be stripped
    # as well before it is used for the lookup.
    root_name = lines[0].strip()
    parse_dict(root_name, root[root_name], [])
    return root


def extract_all_levels(d, result=None):
    """Flatten a nested dict into a single ``key -> value`` mapping.

    Every key found at any nesting depth is copied into the accumulator,
    pointing at the same value object as in ``d``. When a key occurs more
    than once, the value visited last wins.

    Args:
        d: The (possibly nested) dict to walk.
        result: Optional accumulator to fill; a new dict is created when
            omitted.

    Returns:
        dict: The accumulator (``result`` itself when one was passed in).
    """
    collected = {} if result is None else result
    for name in d:
        entry = d[name]
        collected[name] = entry
        if not isinstance(entry, dict):
            continue
        # Descend so that nested keys land in the same flat mapping.
        extract_all_levels(entry, collected)
    return collected


def remove_redundant(arr):
    """Drop repeated lines, keeping only the most deeply indented occurrence.

    Lines are considered duplicates when their stripped text is equal. For
    each group of duplicates the occurrence with the largest leading
    whitespace is kept (the earliest one on a tie); all others are removed.
    The relative order of the surviving lines is unchanged.

    Args:
        arr: List of lines; it is modified in place.

    Returns:
        list: The same ``arr`` object, after filtering.
    """
    # stripped name -> (indent, index) of the occurrence to keep
    deepest = {}
    for idx, line in enumerate(arr):
        name = line.strip()
        indent = len(line) - len(line.lstrip())
        # Strict '>' keeps the first occurrence when indents are equal.
        if name not in deepest or indent > deepest[name][0]:
            deepest[name] = (indent, idx)

    keep = {idx for _, idx in deepest.values()}
    # Filter by index in one pass: removing by value while indices shift
    # would delete the wrong lines.
    arr[:] = [line for idx, line in enumerate(arr) if idx in keep]
    return arr


def _section_until_blank(lines, start):
    """Return ``lines[start:]`` cut at the first empty line, or at the end of the list."""
    end = len(lines)
    for i in range(start, len(lines)):
        if not lines[i]:
            end = i
            break
    return lines[start:end]


def extract_final_dict(filename, kernel_list):
    """Parse a calltree file into a nested dictionary of sizes.

    Two sections of the file are read, each ending at the first empty line
    (or at end of file):

    * the call tree, starting at the line that is exactly
      ``_waic_main_init``;
    * the size table, starting on the line after the first line that begins
      with ``-----``. In each table row, whitespace-separated field 6 is
      used as the symbol name and field 4 as its integer size
      (NOTE(review): column meaning inferred from usage only - confirm
      against the calltree format).

    Args:
        filename: Path of the calltree file to read.
        kernel_list: Kernel names. A call-tree entry containing
            ``<kernel>R10KernelArgs`` is renamed to ``<kernel>`` and starts
            a new group of lines.

    Returns:
        dict: Sizes for the special entries found (``_waic_main_init``,
        ``_main``, ``BufferPort``, ``super_kernel_loop``, ``testPi``) plus,
        for every group root, ``{'size': ..., 'subkernels': ...}`` where
        ``subkernels`` is either a nested dict of the same shape or ``None``
        when there are no children.

    Raises:
        ValueError: If the ``_waic_main_init`` line or the ``-----``
            separator is missing from the file.
        KeyError: If a required symbol has no entry in the size table.
    """
    final_dict1 = {}
    with open(filename, 'r') as f:
        # Strip trailing whitespace/newlines; blank lines become ''.
        clean_content = [line.rstrip() for line in f]

    if '_waic_main_init' not in clean_content:
        raise ValueError(f"'_waic_main_init' line not found in {filename}")
    list1 = _section_until_blank(clean_content, clean_content.index('_waic_main_init'))

    start2 = next(
        (i + 1 for i, line in enumerate(clean_content) if line[:5] == '-----'),
        None,
    )
    if start2 is None:
        raise ValueError(f"'-----' separator line not found in {filename}")
    list2 = _section_until_blank(clean_content, start2)

    # First occurrence of a symbol in the table wins.
    subkernel_size_dict = {}
    for line in list2:
        split_line = line.split()
        if split_line[6] not in subkernel_size_dict:
            subkernel_size_dict[split_line[6]] = int(split_line[4])

    # Move the special top-level entries straight into the result and
    # remember them so they are left out of the per-kernel trees.
    remove_wrappers = set()
    for line in list1:
        wrapper = line.split()[0]
        if wrapper in ('_waic_main_init', '_main'):
            # These two are mandatory: a missing size raises KeyError.
            final_dict1[wrapper] = subkernel_size_dict.pop(wrapper)
            remove_wrappers.add(wrapper)
            continue
        if 'BufferPort' in wrapper and 'config' in wrapper:
            label = 'BufferPort'
        elif 'super_kernel_loop' in wrapper:
            label = 'super_kernel_loop'
        elif 'testPi' in wrapper:
            label = 'testPi'
        else:
            continue
        if wrapper in subkernel_size_dict:
            final_dict1[label] = subkernel_size_dict.pop(wrapper)
            remove_wrappers.add(wrapper)

    final_list1 = []
    for line in list1:
        if line.split()[0] in remove_wrappers:
            continue
        # Drop the first 8 characters of the line (presumably the base
        # indentation of the tree - TODO confirm), then remove the leading
        # '_Z', optional 'N' and digits from the name while keeping the
        # remaining indentation.
        line = line[8:]
        indent = len(line) - len(line.lstrip())
        new_line = re.sub(r'^_Z[N]?\d+', '', line.split()[0])
        final_list1.append(line[:indent] + new_line)

    # Kernel entry points are replaced by the bare kernel name (no indent).
    for kernel in kernel_list:
        marker = kernel + 'R10KernelArgs'
        for i, entry in enumerate(final_list1):
            if marker in entry:
                final_list1[i] = kernel

    # Match every tree entry to a size-table symbol by substring. A name
    # without 'R10KernelArgs' only matches a symbol carrying that marker
    # when the symbol contains '<name>R10KernelArgs'.
    final_subkernel_size_dict = {}
    for final_line in final_list1:
        name = final_line.strip()
        for key in subkernel_size_dict:
            if name in key:
                if 'R10KernelArgs' not in name and 'R10KernelArgs' in key:
                    if name + 'R10KernelArgs' in key:
                        final_subkernel_size_dict[name] = subkernel_size_dict[key]
                else:
                    final_subkernel_size_dict[name] = subkernel_size_dict[key]

    # Split the tree into groups, a new group starting at every line that
    # is exactly a kernel name. Slices are fresh lists, so the in-place
    # filtering done by remove_redundant does not touch final_list1.
    kernel_lines_map = {}
    start = 0
    for i in range(1, len(final_list1)):
        if final_list1[i] in kernel_list:
            kernel_lines_map[final_list1[start]] = remove_redundant(final_list1[start:i])
            start = i
    kernel_lines_map[final_list1[start]] = remove_redundant(final_list1[start:])

    # Subtrees built for earlier groups are reused (deep-copied) by later ones.
    wrapper_levels_dict = {}
    for group_lines in kernel_lines_map.values():
        kernel_pm_size_dict = parse_calltree(group_lines, final_subkernel_size_dict, wrapper_levels_dict)
        wrapper_levels_dict.update(extract_all_levels(kernel_pm_size_dict))
        final_dict1.update(kernel_pm_size_dict)

    def replace_empty_subkernels(d):
        # Recursively turn empty 'subkernels' dicts into None.
        if isinstance(d, dict):
            for key, value in d.items():
                if isinstance(value, dict):
                    subkernels = value.get('subkernels', {})
                    if subkernels == {}:
                        value['subkernels'] = None
                    else:
                        replace_empty_subkernels(subkernels)

    replace_empty_subkernels(final_dict1)
    return final_dict1

def _first_existing_calltree(col, row):
    """Return the first existing ``<i>_<j>.calltree`` path under the build work dir, or None."""
    for i in range(col):
        for j in range(row):
            calltree_file = f"dataflow/xclbin/Work/aie/{i}_{j}/Release/{i}_{j}.calltree"
            if os.path.exists(calltree_file):
                return calltree_file
    return None


def _build_and_locate_calltree(json_dict, kernel_list_path, overlay, output_dir):
    """Write the kernel-list JSON, run the xclbin build script and return a calltree path or None."""
    with open(kernel_list_path, 'w') as file:
        json.dump(json_dict, file, indent=4)
    # NOTE(review): the build's exit status is not checked; a calltree left
    # over from an earlier build would still be picked up below.
    subprocess.run(
        ["python", "dataflow/xclbin/xclbin_build.py", "-o", overlay, "-f", kernel_list_path, "-d", output_dir, "--change_dir"])
    dims = overlay.split('x')
    col = int(dims[0])
    row = int(dims[1])
    return _first_existing_calltree(col, row)


def main(args) -> int:
    """Build with all kernels, then without ``run_conv_a16w8_qdq``, and dump the sizes.

    Kernel names and include files are collected from every
    ``OGOAT/Collaterals/*_kernel_metadata.yaml`` entry that has a
    ``kernel_path`` and whose key does not contain ``bfp``. For each of the
    two builds the first calltree file found is parsed with
    ``extract_final_dict``. The results are written to ``kernels_size.yaml``
    and ``kernels_size_without_conv.yaml`` in the output directory and then
    copied into ``OGOAT/Collaterals/`` (relative to the current directory).

    Args:
        args: Parsed command-line namespace providing ``overlay`` (for
            example ``'4x4'``) and ``output_dir``.

    Returns:
        int: 0 on success, 10 when no calltree file is found or an output
        file cannot be copied.
    """
    import shutil  # local import: only needed for the final copy step

    kernel_dict = {}
    output_dir = os.path.abspath(args.output_dir)
    overlay = args.overlay
    parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    yaml_files = glob.glob(os.path.join(parent_dir, 'OGOAT/Collaterals/*_kernel_metadata.yaml'))
    os.makedirs(output_dir, exist_ok=True)

    for file in yaml_files:
        with open(file) as f:
            kd = yaml.safe_load(f)
        # An empty YAML document loads as None; there is nothing to merge.
        if kd:
            kernel_dict.update(kd)

    kernel_include_list = []
    kernel_list = []
    for key in kernel_dict:
        if "kernel_path" in kernel_dict[key] and "bfp" not in key:
            new_kernel_list = kernel_dict[key]['kernel_path']['kernel_list']
            new_kernel_include_list = kernel_dict[key]['kernel_path']['kernel_include']
            for new_kernel in new_kernel_list:
                if new_kernel not in kernel_list:
                    kernel_list.append(new_kernel)
            for new_kernel_include in new_kernel_include_list:
                # Only keep include files that exist under <parent_dir>/kernels.
                if os.path.exists(os.path.join(parent_dir, 'kernels', new_kernel_include)):
                    kernel_include_list.append(new_kernel_include)
    # De-duplicate while keeping first-seen order so the generated JSON is
    # identical from run to run (a plain set() has no stable order).
    kernel_include_list = ["super.hh"] + list(dict.fromkeys(kernel_include_list))

    # json_dict holds a reference to kernel_list, so removing a kernel below
    # is reflected in the JSON written for the second build.
    json_dict = {
        'kernel_list': kernel_list,
        'kernel_include': kernel_include_list,
        'group_norm_in_model': True,
        'disable_fast_pm': False,
    }
    kernel_list_path = os.path.join(output_dir, 'all_kernel_list.json')

    # First build: every kernel.
    calltree_file = _build_and_locate_calltree(json_dict, kernel_list_path, overlay, output_dir)
    if calltree_file is None:
        print("Calltree generation failed")
        return 10
    final_dict1 = extract_final_dict(calltree_file, kernel_list)

    # Second build: same list without the conv kernel (if it is present).
    if 'run_conv_a16w8_qdq' in kernel_list:
        kernel_list.remove('run_conv_a16w8_qdq')
    calltree_file = _build_and_locate_calltree(json_dict, kernel_list_path, overlay, output_dir)
    if calltree_file is None:
        print("Calltree generation failed")
        return 10
    final_dict2 = extract_final_dict(calltree_file, kernel_list)

    kernels_size_yaml = os.path.join(output_dir, "kernels_size.yaml")
    kernels_size_without_conv_yaml = os.path.join(output_dir, "kernels_size_without_conv.yaml")
    with open(kernels_size_yaml, "w") as f:
        yaml.dump(final_dict1, f, sort_keys=False)
    with open(kernels_size_without_conv_yaml, "w") as f:
        yaml.dump(final_dict2, f, sort_keys=False)

    source_dest_pairs = [
        (kernels_size_yaml, "OGOAT/Collaterals/kernels_size.yaml"),
        (kernels_size_without_conv_yaml, "OGOAT/Collaterals/kernels_size_without_conv.yaml")
    ]
    for src, dest in source_dest_pairs:
        if not os.path.exists(src):
            print(f"Error: Source file '{src}' does not exist.")
            return 10
        try:
            shutil.copy(src, dest)
        except shutil.SameFileError:
            # Output directory is the destination itself; file already in place.
            continue
        except OSError as err:
            print(f"Error: Could not copy '{src}' to '{dest}': {err}")
            return 10
    return 0


if __name__ == "__main__":
    # Command-line entry point: both options are mandatory and the exit
    # status of the process is whatever main() returns.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '-o', '--overlay',
        required=True,
        choices=['4x4', '8x4'],
        help="Name of overlay to run",
    )
    arg_parser.add_argument(
        '-d', '--output_dir',
        required=True,
        help="Output directory",
    )
    args = arg_parser.parse_args()
    exit_status = main(args)
    sys.exit(exit_status)
