import os
import argparse
from typing import List
import shutil
import subprocess
import sys
import json
import traceback
import re
import yaml
import glob
import ast
import numpy as np
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list
from concurrent.futures import ProcessPoolExecutor

# Absolute directory containing this script; also appended to sys.path
# (note: this runs after the module imports above, so it does not affect them).
CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(CURRDIR)

# Bytes held back from the program-memory budget; presumably overhead that the
# kernel size tables do not account for -- TODO confirm.
Unknown_pm_size = 400
# Largest estimated PM size (in bytes) that is accepted: a PM whose
# calc_pm_size() exceeds this is treated as overflowing. 16384 is presumably
# the total program-memory capacity -- TODO confirm.
threshold = 16384 - Unknown_pm_size


def get_kernel_calls(kernel, kernel_size_tree):
    """Flatten a kernel's call tree into a list of kernel names.

    The result starts with ``kernel`` itself, followed depth-first by every
    sub-kernel reachable through the ``'subkernels'`` mapping of
    ``kernel_size_tree`` (each sub-kernel maps to its own subtree).
    """
    calls = [kernel]
    children = kernel_size_tree['subkernels']
    if not children:
        return calls
    for child_name, child_tree in children.items():
        calls += get_kernel_calls(child_name, child_tree)
    return calls


def get_kernel_size(yaml_file_name):
    """Load a kernel-size YAML table from ``<repo>/OGOAT/Collaterals``.

    ``<repo>`` is the directory three levels above this script. The parsed
    top-level mapping is copied into a new dict, which is returned.
    """
    repo_root = os.path.abspath(__file__)
    for _ in range(3):
        repo_root = os.path.dirname(repo_root)
    collateral_path = os.path.join(repo_root, 'OGOAT/Collaterals', yaml_file_name)
    with open(collateral_path) as handle:
        loaded = yaml.safe_load(handle)
    sizes = {}
    sizes.update(loaded)
    return sizes


def calc_kernel_size(kernel, kernel_size_tree, current_subkernels_list):
    """Return the size attributable to ``kernel`` and its sub-kernel tree.

    The result is the node's own ``'size'`` plus, recursively, the size of
    every entry of its ``'subkernels'`` mapping. A kernel that already appears
    in ``current_subkernels_list`` contributes 0 together with its whole
    subtree, because it has been counted already.
    """
    if kernel in current_subkernels_list:
        return 0
    total = kernel_size_tree['size']
    for child, child_tree in (kernel_size_tree['subkernels'] or {}).items():
        total += calc_kernel_size(child, child_tree, current_subkernels_list)
    return total


def calc_pm_size(kernel_names):
    """Estimate the program-memory size of a PM built from ``kernel_names``.

    Args:
        kernel_names: iterable of kernel function names -- a list, or a dict
            whose keys are kernel names (the ``kernel_list`` format). Falsy
            and repeated names are ignored.

    Returns:
        The fixed cost of ``_waic_main_init``, ``_main``, ``BufferPort`` and
        ``super_kernel_loop`` plus the size of every requested kernel and its
        sub-kernels, with each (sub-)kernel counted at most once.
    """
    # The size table differs depending on whether the conv kernel is present.
    if 'run_conv_a16w8_qdq' in kernel_names:
        kernel_size_dict = get_kernel_size('kernels_size.yaml')
    else:
        kernel_size_dict = get_kernel_size('kernels_size_without_conv.yaml')

    base_size = (kernel_size_dict['_waic_main_init'] + kernel_size_dict['_main']
                 + kernel_size_dict['BufferPort'] + kernel_size_dict['super_kernel_loop'])

    seen_kernels = set()
    # Every kernel already reached through some call tree; calc_kernel_size
    # returns 0 for these (and their subtrees) so shared sub-kernels are not
    # counted twice.
    counted_calls = set()
    kernels_size = 0
    for node_kernel in kernel_names:
        if not node_kernel or node_kernel in seen_kernels:
            continue
        seen_kernels.add(node_kernel)
        size_tree = kernel_size_dict[node_kernel]
        kernels_size += calc_kernel_size(node_kernel, size_tree, counted_calls)
        counted_calls.update(get_kernel_calls(node_kernel, size_tree))
    return base_size + kernels_size

def _evict_next_kernel(candidates, target, skip, del_list, groups, del_kernels):
    """Remove the highest-priority evictable kernel (or its group) from ``target``.

    ``del_list`` is scanned in order. The first kernel in ``candidates`` (a
    name -> kernel_func_list index mapping) whose index equals the current
    ``del_list`` value and whose name is not in ``skip`` is chosen. If it
    belongs to one of ``groups`` the whole group is evicted, otherwise only
    the kernel. Each name actually deleted from ``target`` is appended to
    ``del_kernels``. Group members missing from ``target`` are ignored rather
    than raising KeyError.

    Returns:
        True if something was evicted, False if no candidate qualified.
    """
    for dk in del_list:
        # Snapshot the items: ``candidates`` may be the very dict we delete from.
        for name, index in list(candidates.items()):
            if index != dk or name in skip:
                continue
            group = next((g for g in groups if name in g), [name])
            for member in group:
                if member in target:
                    del target[member]
                    del_kernels.append(member)
            return True
    return False


def _pin_softmax_with_mha(del_kernels, mha_list):
    """Append ``run_softmax_qdq`` to ``del_kernels`` once every ``mha_list`` kernel is in it."""
    if set(mha_list).issubset(del_kernels) and 'run_softmax_qdq' not in del_kernels:
        del_kernels.append('run_softmax_qdq')


def validate_pm_size(kernel_dict):
    """Split PM entries whose estimated program memory exceeds ``threshold``.

    For every entry of ``kernel_dict`` that is a dict with a ``kernel_list``
    (kernel name -> index into ``kernel_func_list``), the size is estimated
    with :func:`calc_pm_size`. While it overflows, kernels are evicted from
    that entry's ``kernel_list`` in place, in the priority order of
    ``del_list``. Kernels in ``xint8_list`` or ``mha_list`` are always evicted
    together with the rest of their group. Once the entry fits, up to two new
    ``pm_<n>`` entries are derived from its original kernel list so that the
    evicted kernels are still built. A candidate whose kernels are all
    already in an existing entry is not added.

    Also resets the module-level ``pm_flag`` to 0.

    Args:
        kernel_dict: PM map as loaded from ``pm_kernel_map.json``; mutated in place.

    Returns:
        ``(kernel_dict, gen_pm_ids)``, where ``gen_pm_ids`` holds the ids of
        the entries that were reduced and of the entries that were created.

    Raises:
        RuntimeError: if a PM still overflows but nothing in it can be evicted
            (this used to loop forever).
    """
    pm_id_keys = list(kernel_dict.keys())
    gen_pm_ids = set()
    next_id = sum(1 for x in pm_id_keys if x.startswith('pm_'))
    # Indices above 30 (kernels appended to kernel_func_list later) are tried
    # first, newest first.
    newkernel_id_list = list(range(len(kernel_func_list), 30, -1))
    # Eviction priority by kernel_func_list index; indices not listed here are
    # never evicted.
    del_list = newkernel_id_list + [28, 29, 22, 26, 27, 16, 14, 25, 24, 7, 9, 23, 17, 18, 19, 20, 21, 13, 5, 12, 3, 8, 15, 2, 1, 0]
    # Kernels inside each of these groups are always moved together.
    mha_list = ['run_gemm_qdq_mini', 'run_bcast_add_mini', 'run_mini_mha_preprocess', 'run_presoftmax_dequant']
    xint8_list = ['run_conv_xint8', 'run_matadd']
    groups = (xint8_list, mha_list)
    global pm_flag
    pm_flag = 0

    for pm_id in pm_id_keys:
        entry = kernel_dict[pm_id]
        if type(entry) is not dict or "kernel_list" not in entry:
            continue
        kernel_names = entry['kernel_list']
        old_kernel_names = kernel_names.copy()
        pm_size = calc_pm_size(kernel_names)
        del_kernels = []
        while pm_size > threshold:
            gen_pm_ids.add(pm_id)
            print(f"Program Memory overflow by {pm_size - threshold} Bytes")
            print("Reformatting kernel_list of ", pm_id)
            if not _evict_next_kernel(kernel_names, kernel_names, (), del_list, groups, del_kernels):
                raise RuntimeError(
                    f"{pm_id} overflows by {pm_size - threshold} Bytes but has no "
                    f"evictable kernel left: {list(kernel_names)}")
            pm_size = calc_pm_size(kernel_names)
            if pm_size > threshold:
                continue

            # The PM fits now: derive up to two new PMs from its original
            # kernel list so the evicted kernels are still built somewhere.
            for _ in range(2):
                _pin_softmax_with_mha(del_kernels, mha_list)
                if set(del_kernels) == set(old_kernel_names):
                    break
                new_kernel_names = old_kernel_names.copy()
                new_pm_size = calc_pm_size(new_kernel_names)
                while new_pm_size > threshold:
                    _pin_softmax_with_mha(del_kernels, mha_list)
                    if not _evict_next_kernel(old_kernel_names, new_kernel_names, del_kernels,
                                              del_list, groups, del_kernels):
                        raise RuntimeError(
                            f"Cannot fit a PM split from {pm_id} under {threshold} Bytes: "
                            f"{list(new_kernel_names)}")
                    new_pm_size = calc_pm_size(new_kernel_names)

                # Compare kernel names with kernel names. The old check compared
                # names with (name, index) tuples, so it never matched.
                already_covered = any(
                    type(other) is dict and "kernel_list" in other
                    and set(new_kernel_names).issubset(other["kernel_list"])
                    for other in kernel_dict.values())
                if already_covered:
                    continue
                # Ids can have gaps; never overwrite an existing entry.
                while "pm_" + str(next_id) in kernel_dict:
                    next_id += 1
                new_id = "pm_" + str(next_id)
                gen_pm_ids.add(new_id)
                next_id += 1
                kernel_dict[new_id] = {
                    "kernel_list": new_kernel_names,
                    "kernel_include": list(entry['kernel_include']),
                    "kernel_directory": list(entry['kernel_directory']),
                }
    return (kernel_dict, gen_pm_ids)

def detect_kernel_change() -> List[str]:
    """List files under ``kernels/`` that differ from ``origin/main``.

    Three ``git diff --name-only`` queries are combined: commits on this
    branch since ``origin/main``, staged changes, and unstaged working-tree
    changes. Only paths of the form ``kernels/<rest>`` are kept, and each is
    reported as ``<rest>``.

    Returns:
        The sorted, de-duplicated ``<rest>`` paths, or ``[]`` if any git
        command exits with a non-zero status (the error is printed).
    """
    base_branch = "origin/main"
    prefix = 'kernels' + '/'
    git_queries = (
        # 1. Files changed in commits on current branch vs base_branch
        ['git', 'diff', '--name-only', f'{base_branch}...HEAD'],
        # 2. Staged but uncommitted changes (index vs HEAD)
        ['git', 'diff', '--name-only', '--cached'],
        # 3. Unstaged changes (working tree vs index)
        ['git', 'diff', '--name-only'],
    )
    changed_files = set()
    try:
        for query in git_queries:
            listing = subprocess.run(query, capture_output=True, text=True, check=True)
            changed_files.update(
                path[len(prefix):]
                for path in listing.stdout.splitlines()
                if path.startswith(prefix)
            )
    except subprocess.CalledProcessError as e:
        print("Error detecting changed files:", e)
        return []
    return sorted(changed_files)


def map_pm_changes(files_list: List[str]) -> List[str]:
    """Map changed kernel files to the PM ids that must be rebuilt.

    Args:
        files_list: paths relative to ``kernels/``. The first path component
            of each is taken as the kernel directory.

    Returns:
        Sorted ids of ``pm_kernel_map.json`` entries named ``pm_<digits>``
        whose ``kernel_directory`` contains any of the changed directories.
    """
    map_path = os.path.join(CURRDIR, '..', 'Collaterals', 'pm_kernel_map.json')
    with open(map_path, 'r') as f:
        pm_data = json.load(f)
    pm_key_pattern = re.compile(r'^pm_\d+$')
    print(f"Detected change in these files: {files_list}")
    changed_dirs = {path.partition('/')[0] for path in files_list}
    print(f"Detected change in these kernel directories: {changed_dirs}")
    matched_pms = {
        pm_key
        for pm_key, pm_info in pm_data.items()
        if isinstance(pm_info, dict)
        and pm_key_pattern.match(pm_key)
        and any(d in pm_info.get('kernel_directory', []) for d in changed_dirs)
    }
    print(f"PM bins need to be updated {matched_pms}")
    return sorted(matched_pms)


def get_kernels_list(kernel_dict):
    """Collect every kernel name used by any PM in ``kernel_dict``.

    Only values whose type is exactly ``dict`` and that carry a
    ``kernel_list`` are considered. Names are de-duplicated and returned in
    first-seen order.
    """
    ordered = {}
    for entry in kernel_dict.values():
        if type(entry) is dict and "kernel_list" in entry:
            ordered.update(dict.fromkeys(entry['kernel_list']))
    return list(ordered)


def new_kernel_map_generation(output_dir: str) -> set[str] | int:
    """Regenerate ``pm_kernel_map.json`` so that every kernel belongs to some PM.

    Steps:
      1. Merge every ``OGOAT/Collaterals/*_kernel_metadata.yaml`` file.
      2. Load ``<output_dir>/all_kernel_list.json`` and ``pm_kernel_map.json``.
         The map's last two entries are set aside and re-appended unchanged
         before writing (presumably non-PM metadata -- TODO confirm).
      3. Split overflowing PMs with :func:`validate_pm_size`.
      4. Pack kernels that are listed in ``all_kernel_list.json`` but absent
         from every PM into new ``pm_<n>`` entries. For each such kernel,
         every metadata ``kernel_path`` group that lists it (keys containing
         ``bfp`` are skipped) is merged in while :func:`calc_pm_size` stays
         within ``threshold``. The first group that does not fit closes the PM.
      5. Write ``<output_dir>/new_kernel_map.json`` and copy it over the
         ``pm_kernel_map.json`` read in step 2.

    Sets the module-level ``pm_flag`` to 1.

    Args:
        output_dir: directory holding ``all_kernel_list.json``; also receives
            ``new_kernel_map.json``.

    Returns:
        Ids of the PMs that were split or created, or ``1`` if any step raised
        (the error and traceback are printed). Kernels that cannot be placed
        are reported as such an error instead of looping forever. A kernel
        cannot be placed when it is missing from the metadata or from
        ``kernel_func_list``, or when its group does not fit in an empty PM.
    """
    global pm_flag
    try:
        parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        yaml_files = glob.glob(os.path.join(parent_dir, 'OGOAT/Collaterals/*_kernel_metadata.yaml'))
        main_kernel_dict = {}
        for file in yaml_files:
            with open(file) as f:
                main_kernel_dict.update(yaml.safe_load(f))
        pm_flag = 1

        with open(os.path.join(output_dir, 'all_kernel_list.json')) as f1:
            all_kernels = set(json.load(f1)['kernel_list'])

        kernel_path = os.path.join(CURRDIR, '..', 'Collaterals', 'pm_kernel_map.json')
        with open(kernel_path) as f:
            kernel_dict = json.load(f)
        item2 = kernel_dict.popitem()
        item1 = kernel_dict.popitem()

        (kernel_dict, new_pm_ids) = validate_pm_size(kernel_dict)
        kernels_list = get_kernels_list(kernel_dict)

        # sorted() makes the packing order, and so the generated map, reproducible.
        new_kernels_list = sorted(all_kernels - set(kernels_list))
        while new_kernels_list:
            next_id = sum(1 for x in kernel_dict if x.startswith('pm_'))
            # Ids may have gaps (empty PMs get deleted); never overwrite one.
            while "pm_" + str(next_id) in kernel_dict:
                next_id += 1
            new_id = "pm_" + str(next_id)
            new_pm_ids.add(new_id)
            kernel_dict[new_id] = {}
            new_kernel_names = {}
            old_kernel_names = {}  # last kernel_list that fit in the PM
            kernel_include_list = []
            break_flag = 0
            for kernel in new_kernels_list:
                if kernel not in kernels_list:
                    for key in main_kernel_dict:
                        if "kernel_path" not in main_kernel_dict[key] or "bfp" in key:
                            continue
                        group_info = main_kernel_dict[key]['kernel_path']
                        if kernel not in group_info['kernel_list']:
                            continue
                        for k in group_info['kernel_list']:
                            try:
                                new_kernel_names[k] = kernel_func_list.index(k)
                            except ValueError:
                                print(f"Error: '{k}' not found in the kernel func list!")
                        kernel_names = dict(sorted(new_kernel_names.items(), key=lambda item: item[1]))
                        kernel_dict[new_id]["kernel_list"] = kernel_names
                        # Same overflow rule as validate_pm_size: only > threshold overflows.
                        if calc_pm_size(kernel_names) <= threshold:
                            old_kernel_names = kernel_names.copy()
                            for new_kernel_include in group_info['kernel_include']:
                                if new_kernel_include not in kernel_include_list:
                                    kernel_include_list.append(new_kernel_include)
                        else:
                            # Roll back to the last fitting set and close this PM.
                            kernel_dict[new_id]["kernel_list"] = old_kernel_names
                            break_flag = 1
                            break
                kernels_list = get_kernels_list(kernel_dict)
                if break_flag:
                    break

            remaining = sorted(all_kernels - set(kernels_list))
            if len(remaining) >= len(new_kernels_list):
                raise RuntimeError(
                    f"Cannot place kernels {new_kernels_list} into a new PM: they are missing "
                    f"from the kernel metadata or kernel_func_list, or their kernel group does "
                    f"not fit in {threshold} Bytes")
            new_kernels_list = remaining
            kernel_dict[new_id]["kernel_include"] = kernel_include_list
            kernel_dict[new_id]["kernel_directory"] = []
            for s in kernel_include_list:
                if s != "super.hh":
                    new_dir = s.split('/')[0]
                    if new_dir not in kernel_dict[new_id]["kernel_directory"]:
                        kernel_dict[new_id]["kernel_directory"].append(new_dir)

        kernelmap = os.path.join(output_dir, 'new_kernel_map.json')
        kernel_dict.update({item1[0]: item1[1]})
        kernel_dict.update({item2[0]: item2[1]})
        with open(kernelmap, "w") as outfile:
            json.dump(kernel_dict, outfile, indent=4)
        # Write back to the file that was read, independent of the current
        # working directory (the old cwd-relative target could differ).
        shutil.copy(kernelmap, kernel_path)
        return new_pm_ids
    except Exception as e:
        print(f"Error in new_kernel_map_generation() {e}")
        print(traceback.format_exc())
        return 1


def extract_kernel_func_list_from_source(source_code: str) -> List[str]:
    """Return the string entries of the top-level ``kernel_func_list`` literal.

    ``source_code`` is parsed, never executed. The first top-level assignment
    to ``kernel_func_list`` (plain or annotated) whose value is a list or
    tuple literal is used, and its non-string elements are skipped.

    Returns:
        The strings in source order, or ``[]`` if no such assignment exists.

    Raises:
        SyntaxError: if ``source_code`` is not valid Python.
    """
    tree = ast.parse(source_code)
    for node in tree.body:
        if isinstance(node, ast.Assign):
            targets, value = node.targets, node.value
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            targets, value = [node.target], node.value
        else:
            continue
        if not any(isinstance(t, ast.Name) and t.id == 'kernel_func_list' for t in targets):
            continue
        if isinstance(value, (ast.List, ast.Tuple)):
            # ast.Str / .s were removed in Python 3.12; string literals are
            # ast.Constant nodes with a str value since Python 3.8.
            return [elt.value for elt in value.elts
                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str)]
    return []


def check_new_kernel_changes() -> bool:
    """Report whether ``kernel_func_list`` gained entries relative to ``origin/main``.

    The working-tree ``kernel_func_list.py`` is compared with the version at
    ``origin/main``, which is read with ``git show``. The result is printed.

    Returns:
        True if some entry exists locally but not on ``origin/main``. False
        otherwise, including when any step fails (the error and traceback are
        printed).
    """
    commit_ref = "origin/main"
    local_file = os.path.join(CURRDIR, "..", "src", "L1_fusion", "kernel_func_list.py")
    repo_relative = "OGOAT/src/L1_fusion/kernel_func_list.py"
    try:
        with open(local_file, 'r') as handle:
            local_entries = extract_kernel_func_list_from_source(handle.read())

        shown = subprocess.run(['git', 'show', f'{commit_ref}:{repo_relative}'],
                               capture_output=True, text=True)
        if shown.returncode != 0:
            raise FileNotFoundError(f"File {repo_relative} not found at commit {commit_ref}")
        base_entries = set(extract_kernel_func_list_from_source(shown.stdout))
        added = [name for name in local_entries if name not in base_entries]
        if not added:
            print("No new kernel functions added.")
            return False
        print("New kernel functions added:", added)
        return True
    except Exception as e:
        print(f"Error in check_new_kernel_changes() {e}")
        print(traceback.format_exc())
        return False

def check_kernel_removal() -> bool:
    """Report whether entries were dropped from ``kernel_func_list`` relative to ``origin/main``.

    The version at ``origin/main``, read with ``git show``, is compared with
    the working-tree ``kernel_func_list.py``. The result is printed.

    Returns:
        True if some entry on ``origin/main`` is missing locally. False
        otherwise, including when any step fails (the error and traceback are
        printed).
    """
    commit_ref = "origin/main"
    local_file = os.path.join(CURRDIR, "..", "src", "L1_fusion", "kernel_func_list.py")
    repo_relative = "OGOAT/src/L1_fusion/kernel_func_list.py"
    try:
        with open(local_file, 'r') as handle:
            local_source = handle.read()
        local_entries = set(extract_kernel_func_list_from_source(local_source))

        shown = subprocess.run(['git', 'show', f'{commit_ref}:{repo_relative}'],
                               capture_output=True, text=True)
        if shown.returncode != 0:
            raise FileNotFoundError(f"File {repo_relative} not found at commit {commit_ref}")
        base_entries = extract_kernel_func_list_from_source(shown.stdout)
        removed = [name for name in base_entries if name not in local_entries]
        if removed:
            print("Removed kernel functions are:", removed)
            return True
        print("No kernel functions are removed.")
        return False
    except Exception as e:
        print(f"Error in check_kernel_removal() {e}")
        print(traceback.format_exc())
        return False

def check_files_exist(file_prefixes: List[str], output_dir: str) -> bool:
    """Check that ``<output_dir>/<prefix>.bin`` exists for every prefix.

    Prints the comma-separated names of missing files, or a success message.
    Returns True only when nothing is missing.
    """
    expected = (f"{prefix}.bin" for prefix in file_prefixes)
    missing_files = [name for name in expected
                     if not os.path.isfile(os.path.join(output_dir, name))]
    if not missing_files:
        print("All bins files generated.")
        return True
    print("Missing bins:", ", ".join(missing_files))
    return False

def run_command(command):
    """Run ``command`` (an argv list) to completion and return its exit code.

    Output is not captured. Returning the status, rather than discarding it,
    lets callers such as ``ProcessPoolExecutor.map`` detect failed jobs.
    """
    return subprocess.run(command).returncode

def _bsub_command(log_file, script_path, script_args):
    """Return the LSF ``bsub`` argv that runs ``script_path`` with the current interpreter.

    The job is interactive (``-Is``), runs on the ``long`` queue with the
    resource requirements below, and writes its output to ``log_file``.
    """
    return (["bsub", "-Is", "-R", "rusage[mem=32768]", "-q", "long", "-R", "select[type=X86_64]",
             "-R", "select[osdistro=rhel && (osver=ws8)]", "-o", log_file]
            + [sys.executable, script_path] + script_args)


def _prune_removed_kernels():
    """Drop kernels missing from ``kernel_func_list`` out of ``pm_kernel_map.json``.

    The map is rewritten in place. Each PM's ``kernel_list`` is rebuilt from
    current ``kernel_func_list`` indices, sorted by index. For a PM that lost
    a kernel, ``kernel_include`` entries whose files no longer exist under
    ``kernels/`` are dropped and ``kernel_directory`` is rebuilt. PMs left
    without kernels are deleted. The map's last two entries are kept as-is.

    Returns:
        Ids of surviving PMs whose ``kernel_list`` changed.
    """
    changed_pm_ids = set()
    kernel_path = os.path.join(CURRDIR, '..', 'Collaterals', 'pm_kernel_map.json')
    with open(kernel_path) as f:
        kernel_dict = json.load(f)
    item2 = kernel_dict.popitem()
    item1 = kernel_dict.popitem()

    for key in list(kernel_dict.keys()):
        if "kernel_list" not in kernel_dict[key]:
            continue
        kernels_list = kernel_dict[key]['kernel_list']
        new_kernel_names = {}
        flag_rem = False
        for k in kernels_list:
            try:
                new_kernel_names[k] = kernel_func_list.index(k)
            except ValueError:
                print(f"Error: '{k}' not found in the kernel func list!")
                flag_rem = True
        if flag_rem:
            # Rebuild the list instead of calling list.remove() while iterating,
            # which skipped the element following each removed one.
            kernel_dict[key]['kernel_include'] = [
                k_include for k_include in kernel_dict[key]['kernel_include']
                if os.path.exists(os.path.join(CURRDIR, '..', '..', 'kernels', k_include))
            ]
            directories = []
            for s in kernel_dict[key]['kernel_include']:
                if s != "super.hh":
                    new_dir = s.split('/')[0]
                    if new_dir not in directories:
                        directories.append(new_dir)
            kernel_dict[key]["kernel_directory"] = directories

        if kernels_list != new_kernel_names:
            changed_pm_ids.add(key)
        kernel_dict[key]["kernel_list"] = dict(
            sorted(new_kernel_names.items(), key=lambda item: item[1]))
        if not kernel_dict[key]["kernel_list"]:
            del kernel_dict[key]
            # discard(): an already-empty PM was never added to the set.
            changed_pm_ids.discard(key)

    kernel_dict.update({item1[0]: item1[1]})
    kernel_dict.update({item2[0]: item2[1]})
    with open(kernel_path, "w") as outfile:
        json.dump(kernel_dict, outfile, indent=4)
    return changed_pm_ids


def main(args) -> int:
    """Regenerate PM metadata and build the affected PM binaries.

    Steps:
      1. Run ``gen_metadata_from_calltree.py`` through bsub.
      2. If kernels were removed from ``kernel_func_list``, prune them from
         ``pm_kernel_map.json``.
      3. Regenerate the kernel map with :func:`new_kernel_map_generation`.
      4. Collect the PMs affected by kernel source changes, pruning, or new
         kernels (or every PM when ``args.all_bins`` is set).
      5. Build each PM once with ``xclbin_build.py`` through bsub, at most
         6 jobs at a time.
      6. Check that a ``.bin`` exists per PM.

    Returns:
        0 on success, or the metadata job's non-zero exit code. Otherwise 1
        if that step raised or ``xclbin_build.py`` is missing, the error code
        from :func:`new_kernel_map_generation` if it failed, or 2 if any
        ``.bin`` file is missing.
    """
    output_dir = os.path.abspath(args.output_dir)
    overlay = args.overlay
    log_dir = os.path.join(output_dir, "logs")
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    # makedirs also creates output_dir if it does not exist yet.
    os.makedirs(log_dir)

    try:
        log_file = os.path.join(log_dir, "gen_metadata_from_calltree.log")
        script_path = os.path.join(CURRDIR, "gen_metadata_from_calltree.py")
        result = subprocess.run(_bsub_command(log_file, script_path, ["-d", output_dir, "-o", overlay]))
        if result.returncode != 0:
            return result.returncode
    except Exception as e:
        print(f"Error in gen_metadata_from_calltree {e}")
        print(traceback.format_exc())
        return 1

    changed_pm_ids = _prune_removed_kernels() if check_kernel_removal() else set()

    diff_files = detect_kernel_change()
    print(f"Detected change in these files: {diff_files}")
    new_pm_ids = new_kernel_map_generation(output_dir)
    if isinstance(new_pm_ids, int):
        print(f"Error occurred with code {new_pm_ids}")
        return new_pm_ids
    pm_id_keys = map_pm_changes(diff_files)
    pm_id_keys.extend(sorted(changed_pm_ids))

    if not new_pm_ids:
        print("List of pm bins with new kernel addition is empty")
    else:
        pm_id_keys.extend(sorted(new_pm_ids))
    if args.all_bins:
        kernel_path = os.path.join(CURRDIR, '..', 'Collaterals', 'pm_kernel_map.json')
        with open(kernel_path) as f:
            kernel_dict = json.load(f)
        # The last two entries of the map are not built; drop them.
        kernel_dict.popitem()
        kernel_dict.popitem()
        pm_id_keys = list(kernel_dict.keys())

    # A PM can be reported by several sources; build each one only once so
    # that parallel jobs never race on the same outputs.
    pm_id_keys = list(dict.fromkeys(pm_id_keys))
    print(f"Final list of pm bins to be generated : {pm_id_keys}")
    if not pm_id_keys:
        print("Got 0 pm id to generate!")
        return 0

    script_name = "xclbin_build.py"
    script_path = os.path.join(CURRDIR, '..', '..', 'dataflow', 'xclbin', script_name)
    if not os.path.exists(script_path):
        print(f"Error: The script '{script_name}' does not exist in the path '{script_path}'.")
        return 1
    file_path = os.path.join(CURRDIR, '..', 'Collaterals', 'pm_kernel_map.json')
    commands = []
    # One build job per PM, each with its own numbered output sub-directory.
    for count, pm_id in enumerate(pm_id_keys):
        output_subdir = os.path.join(output_dir, str(count))
        script_args = ["-f", file_path, "-d", output_subdir, "-o", overlay, "-b", pm_id, "--sub_dir"]
        log_file = os.path.join(log_dir, f"{pm_id}.log")
        commands.append(_bsub_command(log_file, script_path, script_args))

    with ProcessPoolExecutor(max_workers=6) as executor:
        exit_codes = list(executor.map(run_command, commands))
    print("All subprocesses have completed.")
    for pm_id, code in zip(pm_id_keys, exit_codes):
        # Whatever run_command returns; a falsy value (0 or None) means success.
        if code:
            print(f"Warning: build job for {pm_id} exited with code {code}")
    if not check_files_exist(pm_id_keys, output_dir):
        return 2
    return 0


if __name__ == '__main__':
    # Command-line entry point. main() reads only --overlay, --all_bins and
    # --output_dir; --kernel_names, --kernel_includes and --kernel_file are
    # accepted but not used by main().
    cli = argparse.ArgumentParser()
    cli.add_argument('-o', '--overlay', required=True, choices=['4x4', '8x4'],
                     help="Name of overlay to run")
    cli.add_argument('-k', '--kernel_names', required=False,
                     help="Comma-separated list of kernel names")
    cli.add_argument('-i', '--kernel_includes', required=False,
                     help="Comma-separated list of kernel includes")
    cli.add_argument('-f', '--kernel_file', required=False, help="kernel file")
    cli.add_argument('-b', '--all_bins', action="store_true", default=False, required=False,
                     help="generate all pm bins from the kernel map file")
    cli.add_argument('-d', '--output_dir', required=True, help="Output directory")
    args = cli.parse_args()
    sys.exit(main(args))
