# fmt: on
import glob
import os
import argparse
import logging
import onnx
import yaml
from OGOAT.src.L1_fusion.py_match.helpers.fusion_seq import (
    LeveledFusionSeq,
    FusionSeq,
    filter_by_opt_level,
)
from OGOAT.src.L1_fusion.py_match.helpers.fusion_configs import (
    FusionConfigs,
    FusionArguments,
)
from OGOAT.src.L1_fusion.kernel_metadata_loader import (
    KernelMetadataLoader,
)
from graph_fusion import Fusion_L1
from OGOAT.src.utils.context import Logger
from OGOAT.src.L1_fusion.L1_utils.safe_runner import SafeRunner


def _load_fusion_seq(fusion_args: FusionArguments, base_path: str) -> FusionSeq:
    """Load the fusion sequence selected by *fusion_args*.

    Precedence:
      1. ``fusion_args.fusion_seq_path`` -- loaded as a flat ``FusionSeq``;
         no optimization-level filtering is applied to it.
      2. ``fusion_seq_<target>.yml`` in *base_path* when ``fusion_args.target``
         is set, filtered by ``fusion_args.opt_level``.
      3. ``fusion_seq.yml`` in *base_path*, filtered by ``fusion_args.opt_level``.
    """
    if fusion_args.fusion_seq_path is not None:
        # NOTE(review): full_load (not safe_load) is used on a caller-supplied
        # path; keep this path trusted.
        with open(fusion_args.fusion_seq_path, "r") as f:
            return FusionSeq.from_dict(yaml.full_load(f))

    if fusion_args.target is not None:
        seq_file_name = f"fusion_seq_{fusion_args.target}.yml"
    else:
        seq_file_name = "fusion_seq.yml"
    with open(os.path.join(base_path, seq_file_name), "r") as f:
        leveled_fusion_seq = LeveledFusionSeq.from_dict(yaml.full_load(f))
    return filter_by_opt_level(leveled_fusion_seq, fusion_args.opt_level)


def _discard_patterns(patterns, names) -> None:
    """Remove every entry of *names* from *patterns* in place, skipping absent ones."""
    for name in names:
        # Guarded so a custom fusion sequence that lacks a pattern does not
        # abort the whole run with a bare ValueError from list.remove().
        if name in patterns:
            patterns.remove(name)


def main(logger: Logger, runner: SafeRunner, **args) -> onnx.ModelProto:
    """Run L1 graph fusion on an ONNX model and return the fused model.

    Args:
        logger: Logger forwarded to ``Fusion_L1``.
        runner: ``SafeRunner`` forwarded to ``Fusion_L1``.
        **args: Fusion options. Required keys: ``model_path``, ``model_name``,
            ``load_data``, ``int4_to_int8``, ``optimization_level``,
            ``shape_inference_outputs``, ``prebuilt_mladf_mha``,
            ``no_dtype_freeze``. Optional keys (default ``None`` unless noted):
            ``debug``, ``fusion_seq``, ``target``, ``output_dir``, ``fast_pm``,
            ``qdq_optimization`` (0), ``qdq_int16_cleanup`` (1),
            ``old_fusion_flow``, ``assign_pmid_before_partition`` (False).

    Returns:
        The model produced by ``Fusion_L1.run_fusion``.

    Raises:
        KeyError: if a required key is missing from *args*.
        TypeError / ValueError: if ``optimization_level`` (or another
            integer-like flag) is missing or not convertible with ``int()``.

    Side effects:
        Loads every ``Collaterals/*_kernel_metadata.yaml`` into
        ``KernelMetadataLoader``, writes ``used_fusion_seq.yml`` next to the
        model, and calls ``FusionConfigs.save_fusion_configs``.
    """
    fusion_args = FusionArguments(
        debug=args.get("debug"),
        model_path=args["model_path"],
        model_name=args["model_name"],
        external_data=args["load_data"],
        inits_int4_to_int8=bool(int(args["int4_to_int8"])),
        fusion_seq_path=args.get("fusion_seq"),
        target=args.get("target"),
        opt_level=int(args.get("optimization_level")),
        out_dir_path=args.get("output_dir"),
        fast_pm_enable=args.get("fast_pm"),
        qdq_optimization=bool(int(args.get("qdq_optimization", 0))),
        qdq_int16_cleanup=bool(int(args.get("qdq_int16_cleanup", 1))),
        old_fusion_flow=args.get("old_fusion_flow"),
        shape_inference_outputs=args["shape_inference_outputs"],
        prebuilt_mladf_mha=args["prebuilt_mladf_mha"],
        no_dtype_freeze=args["no_dtype_freeze"],
        assign_pmid_before_partition=args.get("assign_pmid_before_partition", False),
    )

    # Opset definitions live next to this script.
    base_path = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(base_path, "opset_def.yml"), "r") as f:
        module_opset = yaml.full_load(f)

    # Kernel metadata lives two directories up, under Collaterals/. Sorting
    # makes the load order independent of filesystem enumeration order.
    ogoat_dir = os.path.dirname(os.path.dirname(base_path))
    metadata_glob = os.path.join(ogoat_dir, "Collaterals", "*_kernel_metadata.yaml")
    for metadata_file in sorted(glob.glob(metadata_glob)):
        with open(metadata_file) as f:
            KernelMetadataLoader.load_dict(yaml.full_load(f))

    fusion_seq = _load_fusion_seq(fusion_args, base_path)
    prebuilt_mha = fusion_args.prebuilt_mladf_mha

    # At level 3, keep only one attention flavour: the prebuilt MLADF MHA path
    # or the generic Attention path, never both.
    if fusion_args.opt_level == 3:
        if prebuilt_mha:
            _discard_patterns(
                fusion_seq.patterns,
                ("Attention", "RTROptimize", "MHA_RTRCancellation", "Batching"),
            )
        else:
            _discard_patterns(
                fusion_seq.patterns, ("Attention_mladf", "MatMulTransposeActWgt")
            )

    # LinearSliceTranspose is only kept (for opt level >= 1) when
    # prebuilt_mladf_mha is enabled.
    if fusion_args.opt_level >= 1 and not prebuilt_mha:
        _discard_patterns(fusion_seq.patterns, ("LinearSliceTranspose",))

    # Enable the MMT output Reshape/Transpose/Reshape fusion for prebuilt MHA.
    # Done before saving so used_fusion_seq.yml records the configs actually used.
    if prebuilt_mha:
        fusion_seq.configs.MMT_configs["ENABLE_OUTPUT_RTR_FUSION"] = True

    # Record the final patterns/configs next to the model. os.path.join keeps
    # this in the current directory when model_path has no directory part.
    used_seq_path = os.path.join(
        os.path.dirname(fusion_args.model_path), "used_fusion_seq.yml"
    )
    fusion_seq.save_to_file(used_seq_path)
    FusionConfigs.save_fusion_configs(fusion_seq.configs)

    graph_surgery_seq = fusion_seq.patterns

    model = onnx.load_model(
        fusion_args.model_path, load_external_data=fusion_args.external_data
    )

    # WAIC_rutime calls onnx.checker.check_model(), which required the model
    # graph to have a name. Onnxruntime passes models without graph names in
    # VAIP. So fix this up right here if needed.
    if not model.graph.name:
        model.graph.name = "unnamed"

    fusion_obj = Fusion_L1(fusion_args, graph_surgery_seq, module_opset, logger, runner)
    return fusion_obj.run_fusion(model)


if __name__ == "__main__":
    # Command-line entry point: parse options, then run main() on the model.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--debug",
        help="Print lots of debugging statements",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        help="path to onnx model and output destination.Required Field",
    )
    parser.add_argument(
        "--model_name",
        default=None,
        help="Model name passed to the fusion flow. Defaults to the model file name without extension.",
    )
    parser.add_argument(
        "-ld",
        "--load_data",
        help="path to additional model data file for large models. Optional Field. Default value = 0",
        default="0",
    )
    parser.add_argument(
        "-to_i8",
        "--int4_to_int8",
        help="Convert int4 initializers to int8 (0 or 1). Optional Field. Default value = 0",
        default="0",
    )
    # NOTE(review): default level chosen for the CLI only; confirm against the
    # production caller of main().
    parser.add_argument(
        "--optimization_level",
        type=int,
        default=1,
        help="Fusion optimization level used to filter the fusion sequence. Default is 1.",
    )
    parser.add_argument(
        "--target",
        default=None,
        help="Use fusion_seq_<target>.yml instead of the default fusion_seq.yml.",
    )
    parser.add_argument(
        "--fusion_seq",
        default=None,
        help="Explicit fusion sequence YAML file; takes precedence over --target.",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        help="Output directory handed to SafeRunner and the fusion flow.",
    )
    parser.add_argument(
        "--prebuilt_mladf_mha",
        action="store_true",
        help="Use the prebuilt MLADF MHA fusion path.",
    )
    # NOTE(review): assumed to be a boolean switch; confirm the type the
    # production caller passes for shape_inference_outputs.
    parser.add_argument(
        "--shape_inference_outputs",
        action="store_true",
        help="Forwarded to FusionArguments.shape_inference_outputs.",
    )
    parser.add_argument(
        "--no_dtype_freeze",
        action="store_true",
        help="Forwarded to FusionArguments.no_dtype_freeze.",
    )
    parser.add_argument(
        "--qdq_optimization",
        type=int,
        choices=[0, 1],
        default=0,
        help="Enable QDQ optimization at end of L1 fusion. Default is 0 (disabled).",
    )
    parser.add_argument(
        "--qdq_int16_cleanup",
        type=int,
        choices=[0, 1],
        default=1,
        help="Enable int16 QDQ cleanup pass for mixed precision models. Removes int16 QDQ pairs adjacent to int8 QDQ. Default is 1 (enabled).",
    )

    args = parser.parse_args()
    if not args.model_path:
        parser.error(
            "Please pass path/to/onnx/model using -mp or --model_path flags.\npython3 parse_onnx_model.py --help\n\t\t\tfor further info."
        )
    logging.basicConfig(level=args.loglevel)
    logging.debug("Debug mode is enabled!")

    cli_args = dict(vars(args))
    if cli_args["model_name"] is None:
        cli_args["model_name"] = os.path.splitext(os.path.basename(args.model_path))[0]
    # main() reads a "debug" key, but -d/--debug stores into "loglevel".
    cli_args["debug"] = args.loglevel == logging.DEBUG

    logger = Logger.get_null_logger()
    # argparse.Namespace has no .get(); read the parsed attribute directly.
    runner = SafeRunner(logger=logger, output_dir_path=args.output_dir)
    main(logger, runner, **cli_args)
