import argparse
import json
import os
import re
import shutil
import subprocess
import sys
import traceback
from typing import List, Optional
# Absolute directory of this script. It is used below to locate the Makefile,
# main.cpp and (by default) the PM kernel map.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Extend the import path so the project modules imported below resolve when
# this file is run directly as a script (not as part of a package).
sys.path.append(CURRDIR)
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..'))

import dummy_dataflow
from dataflow_common import (
    overlay_stack_size,
    overlay_heap_size,
    clean_overlay,
    build_txn_aiert
)
from dmacompiler import \
    set_dev_gen, DevGen, config 
# Global dmacompiler configuration, applied once at import time:
# target the AIE2P device generation and turn on the ENABLE_BUSY_POLL option.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

# Ordered list of kernel function names; main() maps CLI kernel names to their index in it.
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list


def _paths_under(git_output: str, prefix: str) -> List[str]:
    """Return the NUL-separated paths in *git_output* that start with *prefix*, with *prefix* stripped."""
    return [path[len(prefix):] for path in git_output.split('\0') if path.startswith(prefix)]


def detect_kernel_change(base_branch: str = "origin/main", dir_path: str = 'kernels') -> List[str]:
    """List files under *dir_path* that differ from *base_branch*.

    Three sources of change are combined:
      1. commits on the current branch since it diverged from *base_branch*
         (``git diff base...HEAD``),
      2. staged but uncommitted changes (index vs HEAD),
      3. unstaged changes (working tree vs index).

    Paths are as reported by ``git diff --name-only``, i.e. relative to the
    repository root. Only paths inside ``<dir_path>/`` are kept, and that
    prefix is removed (``'kernels/mha/a.cc'`` -> ``'mha/a.cc'``).

    Args:
        base_branch: Revision to compare the current branch against.
        dir_path: Repository-relative directory to filter on.

    Returns:
        Sorted, de-duplicated list of paths relative to *dir_path*. Returns an
        empty list (after printing the error) if any git invocation fails or
        git cannot be started.
    """
    # Require a real path-component boundary so e.g. 'kernels_old/x' is excluded.
    prefix = dir_path.rstrip('/') + '/'
    # -z gives NUL-terminated, unquoted paths, so names with unusual
    # characters are not C-quoted by git.
    git_commands = [
        # 1. Files changed in commits on current branch vs base_branch
        ['git', 'diff', '--name-only', '-z', f'{base_branch}...HEAD'],
        # 2. Staged but uncommitted changes (index vs HEAD)
        ['git', 'diff', '--name-only', '-z', '--cached'],
        # 3. Unstaged changes (working tree vs index)
        ['git', 'diff', '--name-only', '-z'],
    ]
    changed_files = set()
    try:
        for cmd in git_commands:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True)
            changed_files.update(_paths_under(result.stdout, prefix))
    except subprocess.CalledProcessError as e:
        print("Error detecting changed files:", e)
        if e.stderr:
            print(e.stderr.strip())
        return []
    except OSError as e:
        # e.g. git is not installed / not on PATH
        print("Error detecting changed files:", e)
        return []
    return sorted(changed_files)

def map_pm_changes(files_list: List[str], map_path: Optional[str] = None) -> List[str]:
    """Return the PM bin keys whose kernel directories contain a changed file.

    Args:
        files_list: Changed file paths relative to the kernels directory (as
            returned by ``detect_kernel_change``, e.g. ``'mha/foo.cc'``). The
            first path component is taken as the kernel directory name.
        map_path: PM-to-kernel map JSON file. Defaults to
            ``OGOAT/Collaterals/pm_kernel_map.json`` two levels above this script.

    Returns:
        Keys of the form ``pm_<N>`` whose ``kernel_directory`` entry contains
        one of the changed directory names. The keys are sorted as strings,
        so ``'pm_10'`` precedes ``'pm_2'``.
    """
    if map_path is None:
        map_path = os.path.join(CURRDIR, '..', '..', 'OGOAT', 'Collaterals', 'pm_kernel_map.json')
    with open(map_path, 'r', encoding='utf-8') as f:
        pm_data = json.load(f)
    pm_key_pattern = re.compile(r'pm_\d+')
    print(f"Detected change in these files {files_list}")
    changed_dirs = {changed_file.partition('/')[0] for changed_file in files_list}
    print(f"Detected change in these kernels {changed_dirs}")
    matched_pms = set()
    for pm_key, pm_info in pm_data.items():
        # Only 'pm_<N>' entries describe PM bins; ignore any other metadata keys.
        if not (isinstance(pm_info, dict) and pm_key_pattern.fullmatch(pm_key)):
            continue
        kernel_directory = pm_info.get('kernel_directory', [])
        if isinstance(kernel_directory, str):
            # With a bare string, `in` would be a substring test ("mha" in "mha_qdq").
            kernel_directory = [kernel_directory]
        if any(changed_dir in kernel_directory for changed_dir in changed_dirs):
            matched_pms.add(pm_key)
    print(f"PM bins need to be updated {matched_pms}")
    return sorted(matched_pms)

def verbose_run(command: List[str]) -> int:
    """Echo *command*, run it (without a shell) and return its exit status.

    Args:
        command: Program and arguments as a list.

    Returns:
        The process return code: 0 on success, non-zero on failure (an error
        line is printed in that case).

    Raises:
        OSError: if the program cannot be started (e.g. not on PATH). The
            error and traceback are printed before the exception is re-raised.
    """
    print(command)
    try:
        # No check=True: failure is reported through the return code, so
        # CalledProcessError can never be raised here.
        result = subprocess.run(command)
    except OSError:
        print(f"Error running command: {command}")
        print(traceback.format_exc())
        raise
    if result.returncode != 0:
        print(f"Error running command: {command}")
    return result.returncode


def build_xclbin(kernel_names: List[str], kernel_includes: List[str], output_dir: str, overlay: str, disable_fast_pm=False, fast_pm_bin_name='pm_0.bin', pm_id=0, change_dir=True, sub_dir=False):
    """Compile the dummy dataflow for the given kernels and build the xclbin with make.

    Unless *disable_fast_pm* is set, this also combines the PM bin, builds the
    transaction binary, and copies ``<fast_pm_bin_name>`` and
    ``txn_<stem>.bin`` to the output location.

    Args:
        kernel_names: Kernel function names (a list, or the name->index dict
            built by ``main``); forwarded to ``dummy_dataflow.compile_dataflow``.
        kernel_includes: Kernel include paths, in the same order as *kernel_names*.
        output_dir: Output directory; created if missing.
        overlay: ``'8x4'`` selects 8 AIE columns, anything else 4.
        disable_fast_pm: Skip the fast-PM steps and set ``config.ENABLE_FAST_PM = False``.
        fast_pm_bin_name: File name of the combined PM bin.
        pm_id: Forwarded to the testbench build as ``-DPM_ID``.
        change_dir: Build inside *output_dir* (True) or inside this script's
            directory (False).
        sub_dir: Treat *output_dir* as a throw-away build directory. Bins are
            copied to its parent, and it is deleted after a successful build.
            This only takes effect together with *change_dir*.

    Errors are printed (with a traceback) rather than raised. The caller's
    working directory is restored in every case.
    """
    prev_cwd = os.getcwd()
    build_dir = None
    built_ok = False
    try:
        root_dir = os.path.dirname(os.path.dirname(CURRDIR))
        output_dir_abs_path = os.path.abspath(output_dir)
        os.makedirs(output_dir_abs_path, exist_ok=True)
        stale_xclbin = os.path.join(CURRDIR, 'out.xclbin')
        if os.path.exists(stale_xclbin):
            os.remove(stale_xclbin)

        if change_dir:
            build_dir = output_dir_abs_path
            if sub_dir:
                # Build dir is temporary: deliver bins to its parent instead.
                output_dir_abs_path = os.path.join(output_dir_abs_path, '..')
        else:
            build_dir = CURRDIR
        os.chdir(build_dir)
        clean_overlay()
        if disable_fast_pm:
            # NOTE(review): process-global flag; this function never resets it to True.
            config.ENABLE_FAST_PM = False
        dummy_dataflow.compile_dataflow(kernel_names, kernel_includes, overlay, disable_fastPM=disable_fast_pm)
        stack_size = overlay_stack_size()
        heap_size = overlay_heap_size()
        xclbin_name = 'out.xclbin'
        aie_cols = 8 if overlay == '8x4' else 4
        enable_conv_directive = int('run_conv_a16w8_qdq' in kernel_names)
        make_cmd = [
            'make', '-C', CURRDIR,
            f'ROOT_DIR={root_dir}',
            f'OUT_DIR={build_dir}',
            f'STACKSIZE={stack_size}',
            f'HEAPSIZE={heap_size}',
            f'FILENAME={xclbin_name}',
            f'HAS_CONV={enable_conv_directive}',
            f'DISABLE_FAST_PM={1 if disable_fast_pm else 0}',
        ]
        result = verbose_run(make_cmd)
        print("returncode is: ", result)
        if result:
            # Non-zero exit status: don't package bins from a failed build.
            raise RuntimeError(f"make failed with return code {result}")

        # FAST PM changes
        if not disable_fast_pm:
            from pm_load_bin import combine_bins
            combine_bins('Work', aie_cols, fast_pm_bin_name)
            ctrl_pkt_bin = os.path.join("Work", "aie", "0_0", "elf_ctrl_pkt.bin")
            elf_file_offset = os.path.getsize(ctrl_pkt_bin)
            testbench = os.path.join(CURRDIR, 'main.cpp')
            build_txn_aiert(testbench, [f"-DNCOLS={aie_cols}", f"-DBINOFFSET={elf_file_offset}", "-D_FAST_PM_=1", "-I.", f"-DPM_ID={pm_id}"])
            if build_dir == output_dir_abs_path:
                # Building in place: keep delivered bins apart from build artifacts.
                output_dir_abs_path = os.path.join(output_dir_abs_path, "bins_dir")
                os.makedirs(output_dir_abs_path, exist_ok=True)
            shutil.copy(os.path.join(build_dir, fast_pm_bin_name),
                        os.path.join(output_dir_abs_path, fast_pm_bin_name))
            txn_bin_name = 'txn_' + fast_pm_bin_name.split('.')[0] + '.bin'
            shutil.copy(os.path.join(build_dir, 'txn_pm.bin'),
                        os.path.join(output_dir_abs_path, txn_bin_name))
            print(f"Bin files copied to {output_dir_abs_path}")
        built_ok = True
    except Exception:
        print("Error in build_xclbin:")
        print(traceback.format_exc())
    finally:
        # Always restore the caller's cwd, otherwise relative paths in later
        # calls (e.g. main's per-PM loop) resolve against the build dir.
        print("changing directory")
        os.chdir(prev_cwd)
        print("current directory is: ", prev_cwd)

    # Only remove the build dir when it is the output sub-directory; with
    # change_dir=False it is CURRDIR, i.e. this script's own directory.
    if built_ok and sub_dir and change_dir:
        try:
            shutil.rmtree(build_dir)
        except OSError:
            print("Error in build_xclbin:")
            print(traceback.format_exc())

def main(args):
    """Run the build described by the parsed command-line *args*.

    Modes:
      * no ``--kernel_file``: build one xclbin from ``--kernel_names`` and
        ``--kernel_includes``;
      * kernel file with a top-level ``kernel_list``: build one xclbin;
      * otherwise, a kernel file keyed by ``pm_<N>``: build one fast-PM bin
        per entry. ``--build_bins`` can restrict which entries are built.

    A ``disable_fast_pm`` entry in the kernel file overrides the
    ``--disable_fast_pm`` flag. When the entry is absent, the flag is kept.
    Errors are printed rather than raised.
    """
    try:
        output_dir = args.output_dir
        change_dir = args.change_dir
        sub_dir = args.sub_dir
        # CLI value; may be overridden by the kernel file below.
        disable_fast_pm = args.disable_fast_pm
        if not args.kernel_file:
            if not args.kernel_names or not args.kernel_includes:
                raise Exception("kernel_names and kernel_includes are required")
            # Map each requested kernel name to its index in kernel_func_list;
            # unknown names are reported and left out of the build.
            kernel_names = {}
            for name in args.kernel_names.split(','):
                try:
                    kernel_names[name] = kernel_func_list.index(name)
                except ValueError:
                    print(f"Error: '{name}' not found in the kernel func list!")
            kernel_includes = args.kernel_includes.split(',')
            build_xclbin(kernel_names, kernel_includes, output_dir, args.overlay, disable_fast_pm=disable_fast_pm, change_dir=change_dir, sub_dir=sub_dir)
            print("build_xclbin done")
            return

        with open(args.kernel_file, encoding='utf-8') as f:
            kernel_dict = json.load(f)
        # Override the CLI flag only when the kernel file actually sets it.
        disable_fast_pm = kernel_dict.get('disable_fast_pm', disable_fast_pm)

        if "kernel_list" in kernel_dict:
            # Single-xclbin kernel file.
            build_xclbin(kernel_dict['kernel_list'], kernel_dict['kernel_include'], output_dir, args.overlay, disable_fast_pm=disable_fast_pm, change_dir=change_dir, sub_dir=sub_dir)
            return

        # Multi-PM kernel file: one entry per PM key.
        pm_id_keys = list(kernel_dict)
        # A bare `-b` yields [] (falsy), so test against None to detect the flag.
        if args.build_bins is not None:
            print("Build bin flag given")
            if not args.build_bins:
                print("Generating all the pm bins")
            else:
                pm_id_keys = list(args.build_bins)
                print(f"Generating these bins: {pm_id_keys}")
        for pm_key in pm_id_keys:
            if pm_key not in kernel_dict:
                # A typo in -b should not abort the remaining builds.
                print(f"Error: '{pm_key}' not found in {args.kernel_file}, skipping")
                continue
            pm_entry = kernel_dict[pm_key]
            # Skip entries that are not PM descriptions (e.g. the 'disable_fast_pm' flag).
            if not (isinstance(pm_entry, dict) and "kernel_list" in pm_entry):
                continue
            pm_id = int(pm_key.split('_')[1])
            build_xclbin(pm_entry['kernel_list'], pm_entry['kernel_include'], output_dir, args.overlay, disable_fast_pm=disable_fast_pm, fast_pm_bin_name=pm_key + '.bin', pm_id=pm_id, change_dir=change_dir, sub_dir=sub_dir)
            print("build_xclbin done")
    except Exception:
        print("Error in main:")
        print(traceback.format_exc())

def xclbin_build_main(argv: Optional[List[str]] = None):
    """Parse command-line arguments and run :func:`main`.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Build xclbin",
                                     usage='use "%(prog)s --help" for more info',
                                     formatter_class=argparse.RawTextHelpFormatter)
    # Only --overlay and --output_dir are required; kernels come either from
    # -k/-i or from a -f kernel file.
    ## NOTE: Keep the order of kernel_names and kernel_includes the same
    # ex: python xclbin_build.py -o 4x4 -i super.hh,mha_qdq/wrapper_qkt_sm_4x4_i16i16.cc -k run_qtxk,run_sfmx -d mha
    parser.add_argument('-o', '--overlay', required=True, choices=['4x4', '8x4'], help="Name of overlay to run")
    parser.add_argument('-k', '--kernel_names', required=False, help="Comma-separated list of kernel names")
    parser.add_argument('-i', '--kernel_includes', required=False, help="Comma-separated list of kernel includes")
    parser.add_argument('-f', '--kernel_file', required=False, help="kernel file")
    parser.add_argument('-d', '--output_dir', required=True, help="Output directory")
    # store_false: passing the flag turns change_dir OFF.
    parser.add_argument('--change_dir', action="store_false", default=True, required=False,
                        help="Build in this script's directory instead of the output folder\n"
                             "(by default the Work directory is created in the output folder)")
    parser.add_argument('--sub_dir', action="store_true", default=False, required=False, help="build-dir is a sub-directory of the output folder")
    parser.add_argument('--disable_fast_pm', action="store_true", default=False, help="Disable fast pm mode : [WARN : Overridden if kernel list json has disable_fast_pm]")
    parser.add_argument(
        '-b', '--build_bins',
        nargs='*',
        default=None,
        help="Optional list of binaries to build, or use as a flag with no args to build all"
    )
    args = parser.parse_args(argv)
    main(args)


if __name__ == '__main__':
    xclbin_build_main()
