import os
import argparse
import json
import yaml
import traceback
import time
from dataclasses import dataclass, field
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing
import shutil
import sys

from OGOAT.src.Tiler.tiling_result import TilingResult
import OGOAT.src.Tiler.tiler_BFM as Tiler_BFM
from OGOAT.src.utils.context import Context
from layer import Layer, UnsupportedDataTypeError
from device import Device
from .kernel import Kernel, KernInfoFromYaml

from tiling_opt import TilingOpt
from config_loader import waic_config

# Directory three levels above this file; the Collaterals YAML files below
# are looked up relative to it.
_this_file = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(_this_file)))

# True when a "gcc" executable can be located on PATH.
gcc_available = shutil.which("gcc") is not None

# First sys.path entry ending in "dataflow"; None when no entry matches.
dataflow_path = next(
    (entry for entry in sys.path if entry.endswith("dataflow")),
    None,
)


def _load_collateral(relative_path):
    """Parse the YAML file at ``relative_path`` under ``parent_dir``."""
    with open(os.path.join(parent_dir, relative_path)) as stream:
        return yaml.safe_load(stream)


overlays_dict = _load_collateral("Collaterals/overlays.yaml")
devices_dict = _load_collateral("Collaterals/device.yaml")


# Top-level keys of overlays.yaml are used as overlay names; the keys of each
# overlay entry are used as op type names.
supported_overlays = [*overlays_dict.keys()]
supported_op_types = []
for _overlay_ops in overlays_dict.values():
    supported_op_types.extend(_overlay_ops.keys())
supported_q_dq_types = ["Quant", "Dequant", "Transpose", "Slice", "Concat"]
# supported_modes = [mode for x in overlays_dict.values() for y in x.values() for mode in y.keys()]
# Top-level keys of device.yaml are used as device names.
supported_devices = [*devices_dict.keys()]

def init_args_Tiler(args, test_dir, model_name):
    """Build the Tiler argument dict for one model.

    The result is a shallow copy of ``vars(args)`` (``args`` itself is not
    modified) with these keys added or overwritten:

    * ``ir_json``     - ``test_dir`` joined with ``model_name`` where ".onnx"
      is replaced by "_mod_nhwc_fused_IR_unique_nodes.json".
    * ``kernel_list`` - same, using "_mod_nhwc_fused_IR_kernel_list.json".
    * ``tiler_bfm``   - copied from ``args.tiler_bfm``.
    * ``mode``        - ``args.tiler_bfm_mode`` when ``args.tiler_bfm`` is
      truthy, otherwise ``None`` (``tiler_bfm_mode`` is not read in that case).

    The model name and the resulting ``ir_json`` path are printed.

    Args:
        args: Object whose attributes hold the parsed options (anything that
            ``vars()`` accepts, e.g. an ``argparse.Namespace``).
        test_dir: Directory that holds the generated IR files.
        model_name: Model file name; every ".onnx" occurrence is substituted.

    Returns:
        dict: The Tiler argument dictionary.
    """

    def _ir_artifact(suffix):
        # str.replace substitutes every ".onnx" occurrence, not only a suffix.
        return os.path.join(test_dir, model_name.replace(".onnx", suffix))

    tiler_args = dict(vars(args))
    tiler_args["ir_json"] = _ir_artifact("_mod_nhwc_fused_IR_unique_nodes.json")
    tiler_args["kernel_list"] = _ir_artifact("_mod_nhwc_fused_IR_kernel_list.json")
    tiler_args["tiler_bfm"] = args.tiler_bfm
    if args.tiler_bfm:
        tiler_args["mode"] = args.tiler_bfm_mode
    else:
        tiler_args["mode"] = None

    print(model_name)
    print(tiler_args["ir_json"])
    return tiler_args

def _prepare_layer_output_dir(args, output_dir, layer_id):
    """Create the per-layer output directory and stage the generic testbench.

    ``main_common.cpp`` is copied from ``dataflow_path`` into the directory
    when either (running on Windows without gcc) or (gcc is available and
    ``waic_config.mode`` is 'release'), and additionally ``args["build_txn"]``
    is "all" or an int. ``args["build_txn"]`` is only read when the first
    condition holds.

    Returns:
        str: Path of the per-layer directory ``<output_dir>/<layer_id>``.
    """
    sub_dir = os.path.join(output_dir, layer_id)
    os.makedirs(sub_dir, exist_ok=True)

    needs_generic_testbench = (os.name == 'nt' and not gcc_available) or (
        gcc_available and waic_config.mode == 'release'
    )
    if needs_generic_testbench and (
        args["build_txn"] == "all" or isinstance(args["build_txn"], int)
    ):
        generic_testbench_path = os.path.join(dataflow_path, "main_common.cpp")
        shutil.copy(generic_testbench_path, sub_dir)
    return sub_dir


def run_Tiler_single_op(args, all_kern_info, layer_dict, layer_id, device, output_dir):
    """Run the tiler for one layer and report where its tiling JSON lives.

    Args:
        args: dict of tiler options. Keys read here: "tiler_bfm", "mode",
            "kernel_list", "overlay", "build_txn", and optionally "debug"
            and "mode_select" ("<modeid>,<sched>", only used for MatMul).
        all_kern_info: Kernel information object forwarded to ``Kernel``.
        layer_dict: Layer description used to construct a ``Layer``.
        layer_id: Identifier of the layer; also the name of its output
            sub-directory and JSON file.
        device: Device object forwarded to ``TilingOpt``.
        output_dir: Directory under which ``<layer_id>/`` is created.

    Returns:
        The shape of the result depends on the path taken, so callers must
        handle every one of these:

        * ``["Fail"]`` - the layer could not be constructed or tiling raised.
        * ``(layer_id, json_out)`` - the tiler BFM path was used.
        * ``[layer_id, json_out, "Pass"]`` - the cost-model path succeeded.
        * ``None`` - the op type or the overlay is not handled here.
    """
    try:
        layer = Layer(layer_dict)
    except UnsupportedDataTypeError as e:
        print(f"INFO: Tiler encountered unsupported datatype : {e}")
        return ["Fail"]
    except Exception:
        # Only Exception: KeyboardInterrupt/SystemExit must still propagate.
        print(f"Tiler failed to run layer id: {layer_id}")
        print(traceback.format_exc())
        return ["Fail"]

    if layer.op_type in Tiler_BFM.supported_op and args["tiler_bfm"]:
        print(f"Running tiler bfm for {layer.op_type}, mode: {args['mode']}")
        # The BFM gets its own copy of the options so it cannot alter ours.
        Tiler_BFM.main(args.copy())
        sub_dir = _prepare_layer_output_dir(args, output_dir, layer_id)
        json_out = os.path.join(sub_dir, layer_id + ".json")
        return layer_id, json_out

    is_supported_op_type = layer.orig_op_type in supported_op_types and 'noop' not in layer.op_type
    is_fused_op = '_' in layer.op_type
    if not (is_supported_op_type and (is_fused_op or layer.orig_op_type in supported_q_dq_types)):
        return None

    try:
        kernel = Kernel(layer, args['kernel_list'], all_kern_info)
    except Exception:
        # Best effort: tiling proceeds without kernel information.
        kernel = None

    if args["overlay"] not in supported_overlays:
        return None

    try:
        print(
            f"Running layer id: {layer_id} -- {layer.orig_op_type}"
        )

        sub_dir = _prepare_layer_output_dir(args, output_dir, layer_id)
        context = Context(debug=args.get("debug", False), output_dir=sub_dir)

        tiling_optimizer = TilingOpt(
            layer, device, args["overlay"], kernel, layer_id, context
        )

        if layer.orig_op_type == "MatMul" and args.get("mode_select", False):
            modeid, sched = args["mode_select"].split(",")
            opt_tiling = tiling_optimizer.find_optimal_tiling(
                int(modeid), int(sched)
            )
        else:
            opt_tiling = tiling_optimizer.find_optimal_tiling()
        json_out = os.path.join(sub_dir, layer_id + ".json")

        # FIXME, only dataflow operator has tiling output of TilingResult type,
        # need to handle other operators
        if isinstance(opt_tiling, TilingResult):
            opt_tiling.dump(sub_dir, json_out)
        else:
            with open(json_out, "w") as f:
                json.dump(opt_tiling, f, indent=2)
        return [layer_id, json_out, "Pass"]
    except Exception:
        print(f"Tiler failed to run layer id: {layer_id}")
        print(traceback.format_exc())
        return ["Fail"]


def _default_bypass_tiler() -> List[str]:
    """Return a fresh list holding the default ``bypass_tiler`` op names."""
    return [
        "Conv",
        "Concat",
        "Transpose",
        "Slice",
        "Slice_qdq",
        "Slice_neg",
        "Resize",
        "MHA",
        "DepthToSpace",
        "Quant",
        "Dequant",
        "BilinearResize",
        "MaxPool",
    ]


@dataclass
class TilerConfig:
    """Tiler configuration values.

    Each instance receives its own copy of the default list, so mutating one
    instance's ``bypass_tiler`` does not affect another.
    """

    # Op type names; presumably the ops for which the tiler is bypassed -
    # nothing in this file reads the field, confirm against its users.
    bypass_tiler: List[str] = field(default_factory=_default_bypass_tiler)


def _as_tiling_entry(result):
    """Reduce a ``run_Tiler_single_op`` result to ``(layer_id, json_out)``.

    ``run_Tiler_single_op`` returns ``(layer_id, json_out)``,
    ``[layer_id, json_out, "Pass"]``, ``["Fail"]`` or ``None`` depending on
    the path it took. Only the first two carry a tiling file; for anything
    else ``None`` is returned.
    """
    if isinstance(result, (list, tuple)) and len(result) >= 2:
        return result[0], result[1]
    return None


def main(args):
    """Tile every layer of an IR JSON file and merge the per-layer results.

    Args:
        args: dict of options. Keys read here: "ir_json" (path to the unique
            layers JSON), "device", "output_dir" (falsy means "next to
            ir_json"), "multiprocess" and optionally "j" (worker count).
            The dict is also forwarded unchanged to ``run_Tiler_single_op``.

    Raises:
        Exception: If ``args["device"]`` is not in ``supported_devices``.

    Side effects:
        Writes ``<output_dir>/<name>_tilings.json`` containing, per layer id,
        the content of that layer's tiling JSON. Layers that fail or are not
        handled by the tiler are left out of the merged file.
    """
    with open(args["ir_json"]) as f:
        mdict = json.load(f)

    if args["device"] not in supported_devices:
        raise Exception("invalid device selection")
    device = Device(args["device"])

    output_dir = args["output_dir"] or os.path.dirname(
        os.path.abspath(args["ir_json"])
    )

    # Drops the last 26 characters of the IR file stem before appending the
    # suffix. NOTE(review): presumably strips a fixed IR-name suffix such as
    # "nhwc_fused_IR_unique_nodes" - confirm against the IR naming scheme.
    output_name = (
        os.path.splitext(os.path.basename(args["ir_json"]))[0][:-26] + "_tilings.json"
    )
    output_path = os.path.join(output_dir, output_name)

    all_kern_info = KernInfoFromYaml()

    results = []
    if args["multiprocess"]:
        # Despite the option name, the work is spread over threads.
        max_workers = args['j'] if 'j' in args else multiprocessing.cpu_count()
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            submitted = {
                layer_id: executor.submit(
                    run_Tiler_single_op,
                    args,
                    all_kern_info,
                    layer_dict,
                    layer_id,
                    device,
                    output_dir,
                )
                for layer_id, layer_dict in mdict.items()
            }
            # Collect in submission order so the merged file is laid out the
            # same way on every run and matches the sequential branch.
            for layer_id, future in submitted.items():
                try:
                    results.append(future.result())
                except Exception:
                    # One failing layer must not abort the rest, but it must
                    # not disappear silently either.
                    print(f"Tiler failed to run layer id: {layer_id}")
                    print(traceback.format_exc())
    else:
        for layer_id, layer_dict in mdict.items():
            results.append(
                run_Tiler_single_op(
                    args, all_kern_info, layer_dict, layer_id, device, output_dir
                )
            )

    # Map layer id -> path of its tiling JSON, skipping failed/unhandled layers.
    all_tilings = {}
    for result in results:
        entry = _as_tiling_entry(result)
        if entry is None:
            continue
        tiled_layer_id, json_out = entry
        all_tilings[tiled_layer_id] = json_out

    # Load the tiling file of every layer and dump them together to output_path.
    merged_data = {}
    for key, file_path in all_tilings.items():
        if os.path.exists(file_path):
            with open(file_path) as file:
                merged_data[key] = json.load(file)
        else:
            print(f"File not found: {file_path}")
    with open(output_path, "w") as f:
        json.dump(merged_data, f, indent=2)


def str_to_bool(value):
    """Convert a command-line token to a bool, for use as argparse ``type=``.

    ``type=bool`` treats every non-empty string (including "False") as True;
    this converter interprets the text instead.

    Args:
        value: A bool (returned unchanged) or a string. Case and surrounding
            whitespace are ignored.

    Returns:
        bool: True for "true", "t", "yes", "y", "on", "1"; False for
        "false", "f", "no", "n", "off", "0" and the empty string.

    Raises:
        argparse.ArgumentTypeError: For any other text.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "on", "1"):
        return True
    if token in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main()
    # as a plain dict.
    # NOTE(review): run_Tiler_single_op reads args['mode'] on the tiler BFM
    # path, but no option here provides a "mode" key - confirm how --tiler_bfm
    # is meant to be used from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tiler_bfm",
        help="use tiler bfm for gemm instead of cost model.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-ir",
        "--ir_json",
        help="path to unique layers JSON file. *Required Field*",
        # main() opens this path unconditionally, so enforce what the help
        # text already promises.
        required=True,
    )
    parser.add_argument(
        "-kernel",
        "--kernel_list",
        help="path to kernel list for the given kernel",
        default="",
    )
    parser.add_argument(
        "-d",
        "--device",
        help="Name of device to run; e.g., strix or phoenix etc.",
    )
    parser.add_argument(
        "-o", "--overlay", help="Name of overlay to run; e.g., 4x4 or 4x2 etc."
    )
    parser.add_argument(
        "-ms",
        "--mode_select",
        help="Selectively generate tiling for a given mode and schedule",
    )
    parser.add_argument(
        "-j",
        type=int,
        help="Number of workers for parallel Tiler, Scheduler execution",
        default=multiprocessing.cpu_count(),
    )
    parser.add_argument(
        "-multiprocess",
        # Not type=bool: bool("False") is True, which made it impossible to
        # switch this off with "-multiprocess False".
        type=str_to_bool,
        help="multiprocess Tiler execution, default True",
        default=True,
    )
    parser.add_argument(
        "-build_txn",
        help="build txn",
        default=None
    )
    parser.add_argument(
        "--output_dir",
        help="output_dir",
        default=""
    )
    args = parser.parse_args()
    main(vars(args))
