from pathlib import Path
import shutil
import onnx
import sys
from onnx import ModelProto, TensorProto

import onnxruntime as ort
from dataclasses import dataclass
from dataclass_wizard import YAMLWizard
from typing import Any

import numpy as np
import json

import argparse
import logging

import os
from OGOAT.src.L1_fusion.L1_utils.ops_definition_utils import (
    OnnxOpsWrapper,
    dtype_to_ops_type,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    MatcherError,
    Node,
    OutputTensor,
    WalkCfgPlain,
)
from L1_utils.utils import (
    get_shape_params_from_model,
    get_fixed_shapes_from_params,
    save_model,
    remove_model,
    model_dict
)
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict
from OGOAT.src.L1_fusion.py_match.skip import WalkCfgSkipNoop

@dataclass
class DefaultShapeParamsValue(YAMLWizard):
    """
    Default shape parameters and values for the model.

    Instances are loaded from a YAML file via ``from_yaml_file`` (provided by
    ``YAMLWizard``).

    shape_params -- dictionary, key: shape parameter name, value: shape parameter value
    graph_input_values -- dictionary, key: graph input name, value: fixed value for graph input
    """

    # Symbolic dimension name -> concrete size used in place of that symbol.
    shape_params: dict[str, int]
    # Graph input name -> value fed to that input during shape inference.
    graph_input_values: dict[str, Any]


def create_default_output_dir() -> str:
    """
    Create an empty ``tensor_shapes_output`` directory and return its path.

    The directory is located under the fourth parent of this file
    (``Path(__file__).resolve().parents[3]``). Anything already present at
    that location is removed first, so every call starts from a clean
    directory.

    Returns:
        The absolute path of the freshly created directory, as a string.
    """
    base_dir = Path(__file__).resolve().parents[3]
    target_dir = base_dir.joinpath("tensor_shapes_output").resolve()
    print("INFO: All execution data will be saved into: ", target_dir)

    # Wipe leftovers from a previous run before recreating the directory.
    already_there = os.path.exists(target_dir)
    if already_there:
        shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=True)

    return str(target_dir)


def get_act_shapes_using_onnx_runtime_large_model(
    model: ModelProto,
    model_path: str,
    input_shapes_dict: dict[str, list[int]],
    input_value_dict: dict[str, Any],
    provider="CPUExecutionProvider",
    print_to_file=False,
    global_out_max=3000,
):
    """
    Collect shapes and dtypes of activation tensors by running onnxruntime.

    Node outputs that are neither graph outputs nor covered by a fully static
    ``value_info`` entry are temporarily exposed as extra graph outputs, in
    batches, so that their runtime shape and dtype can be read from the
    inference results. Each batch works on a fresh copy of the model loaded
    from ``model_path`` and saved to a temporary ``*_ort_temp.onnx`` file.

    Args:
        model: Model whose graph inputs, outputs, nodes and value_info are
            inspected.
        model_path: Path of the saved model; reloaded (without external data)
            for every batch and used to derive the temporary file name.
        input_shapes_dict: Graph input name -> fixed shape. When empty, the
            ``dim_value`` entries stored in the model are used instead. When
            non-empty it must contain every graph input.
        input_value_dict: Graph input name -> value that replaces the default
            all-ones feed (reshaped to the input's shape). Names that are not
            graph inputs are ignored.
        provider: onnxruntime execution provider name.
        print_to_file: When true, dump the collected shapes to
            ``<model_path without extension>_act_shapes.json``.
        global_out_max: Number of extra outputs after which a batch is run.

    Returns:
        Tuple ``(shapes, dtypes)`` of dictionaries keyed by tensor name;
        shapes are lists of ints, dtypes are numpy dtypes.

    Raises:
        KeyError: If ``input_shapes_dict`` is non-empty but lacks a graph input.
    """
    graph = model.graph
    # dictionary with all act tensor shapes
    node_act_signals_shapes = {}
    # dictionary with all act tensor dtypes
    node_act_signals_dtypes = {}
    # feed dictionary handed to onnxruntime (all-ones unless overridden below)
    feed_inputs = {}

    for graph_input in graph.input:
        if input_shapes_dict:
            input_shape = [int(d) for d in input_shapes_dict[graph_input.name]]
        else:
            input_shape = [
                int(d.dim_value) for d in graph_input.type.tensor_type.shape.dim
            ]
        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
            graph_input.type.tensor_type.elem_type
        )
        node_act_signals_shapes[graph_input.name] = input_shape
        node_act_signals_dtypes[graph_input.name] = np_dtype
        feed_inputs[graph_input.name] = np.ones(input_shape, dtype=np_dtype)

    for input_name, input_value in input_value_dict.items():
        if input_name in feed_inputs:
            feed_inputs[input_name] = np.reshape(
                np.array(input_value, dtype=node_act_signals_dtypes[input_name]),
                feed_inputs[input_name].shape,
            )

    ## create ort session options with memory optimization
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    # Memory optimization settings
    so.enable_mem_pattern = False
    so.enable_cpu_mem_arena = False
    so.enable_mem_reuse = False

    # Sets give O(1) membership tests in the node loop below.
    graph_output_names = {output.name for output in graph.output}
    # tensors whose value_info already carries a fully static shape
    static_value_info_names = {
        value_info.name
        for value_info in graph.value_info
        if len(value_info.type.tensor_type.shape.dim) > 0
        and all(
            dim.HasField("dim_value") for dim in value_info.type.tensor_type.shape.dim
        )
    }

    temp_model_path = os.path.splitext(model_path)[0] + "_ort_temp.onnx"

    def run_batch_inference(batch_model, total: int):
        save_model(batch_model, temp_model_path, False)
        ort_batch_session = None
        ort_outs_mod = None
        try:
            ## create ort session and run model with memory optimization
            ort_batch_session = ort.InferenceSession(
                temp_model_path, so, providers=[provider]
            )
            output_names = [x.name for x in ort_batch_session.get_outputs()]
            ort_outs_mod = ort_batch_session.run(output_names, feed_inputs)

            # Record shape/dtype right away so the tensors can be released.
            for out_signal_name, tensor in zip(output_names, ort_outs_mod):
                node_act_signals_shapes[out_signal_name] = [
                    int(dim) for dim in tensor.shape
                ]
                node_act_signals_dtypes[out_signal_name] = tensor.dtype
        finally:
            # Drop the results and the session before deleting the file, and
            # delete the file even when inference failed.
            ort_outs_mod = None
            ort_batch_session = None
            if os.path.exists(temp_model_path):
                remove_model(temp_model_path, False)

        print("Got shapes for ", total, " nodes output activations using onnx runtime")

    batch_limit = int(global_out_max)
    # add remaining node outputs as global outputs
    model_temp = onnx.load_model(model_path, load_external_data=False)
    pending = 0
    total = 0
    for node in graph.node:
        for output in node.output:
            # An empty name marks an omitted optional output; it cannot be
            # exposed as a graph output.
            if not output:
                continue
            if output in graph_output_names or output in static_value_info_names:
                continue
            model_temp.graph.output.extend([onnx.ValueInfoProto(name=output)])
            pending += 1
            total += 1
        # ">=" rather than "==": a node with several outputs can step over the
        # limit, which would otherwise disable batching altogether.
        if pending > 0 and pending >= batch_limit:
            run_batch_inference(model_temp, total)
            model_temp = onnx.load_model(model_path, load_external_data=False)
            pending = 0
    # Final run covers the remaining outputs (and the original graph outputs).
    run_batch_inference(model_temp, total)

    if print_to_file:
        with open(os.path.splitext(model_path)[0] + "_act_shapes.json", "w") as f:
            json.dump(node_act_signals_shapes, f, indent=4)

    return node_act_signals_shapes, node_act_signals_dtypes


def update_op_set(model: onnx.ModelProto):
    """
    Declare an opset for the "com.microsoft" domain when the model needs one.

    If any node in the graph belongs to the "com.microsoft" domain and the
    model does not already import an opset for that domain, an entry with
    version 1 is appended to ``model.opset_import`` (ORT only supports
    version 1). A model that already declares the domain is left untouched,
    so no duplicate entry is created.

    Args:
        model: Model to update in place.
    """
    ms_domain = "com.microsoft"

    uses_ms_domain = any(node.domain == ms_domain for node in model.graph.node)
    if not uses_ms_domain:
        return

    already_declared = any(
        op_set.domain == ms_domain for op_set in model.opset_import
    )
    if already_declared:
        return

    print(
        'warning: op set version for "com.microsoft" is missing; setting version to 1'
    )
    model.opset_import.append(onnx.OperatorSetIdProto(domain=ms_domain, version=1))

def update_model_value_info(
    model_,
    all_act_signal_shapes,
    all_act_signal_dtypes,
    value_infor_count_orig,
):
    """
    Write the collected shapes into the model's value_info and graph outputs.

    Args:
        model_: Model updated in place.
        all_act_signal_shapes: Tensor name -> shape (sequence of ints).
        all_act_signal_dtypes: Tensor name -> numpy dtype; used only for
            entries that have no dims declared yet.
        value_infor_count_orig: Number of value_info entries in the original
            model; only used in the summary print.
    """

    def update_value_info(node_value_info: onnx.ValueInfoProto):
        name = node_value_info.name
        if name not in all_act_signal_shapes:
            return
        shape = all_act_signal_shapes[name]
        dims = node_value_info.type.tensor_type.shape.dim

        if not dims:
            # No dims declared: build a complete tensor type (dtype + shape).
            dtype_onnx = onnx.helper.np_dtype_to_tensor_dtype(
                all_act_signal_dtypes[name]
            )
            tensor_type_proto = onnx.helper.make_tensor_type_proto(dtype_onnx, shape)
            node_value_info.type.CopyFrom(tensor_type_proto)
            return

        if len(dims) != len(shape):
            # Declared rank differs from the collected one: rebuild the dim
            # list instead of indexing past the end of either sequence.
            del dims[:]
            for _ in shape:
                dims.add()
        for dim, size in zip(dims, shape):
            dim.dim_value = size

    # update value info of intermediate node outputs, which were already there in original value_info
    value_info_cnt = len(model_.graph.value_info)
    for node_value_info in model_.graph.value_info:
        update_value_info(node_value_info)

    # update value info of real graph outputs
    for node_value_info in model_.graph.output:
        update_value_info(node_value_info)

    print(
        "Tensors in value_info in original model:",
        value_infor_count_orig,
        ", Tensors in value info after:",
        value_info_cnt,
    )



def add_model_info(main_params):
    """
    Fix the model's dynamic shapes and store all activation shapes in it.

    The model at ``main_params["model_path"]`` is loaded, unnamed nodes get a
    name, the "com.microsoft" opset is declared if needed, graph input/output
    shapes are fixed from the provided parameters, and the result is saved as
    ``<output_dir>/<model base name>_mod.onnx``. onnxruntime is then used to
    collect the shape and dtype of the activations, which are written into
    ``value_info`` before the model is saved again to the same path.

    Args:
        main_params: Dictionary of settings. Required keys: ``model_path``,
            ``load_data``, ``output_dir`` ("." means the model's directory),
            ``input_names``, ``input_dims``, ``in_shape_params`` (JSON string),
            ``fixed_input_values`` (JSON string). Optional keys:
            ``default_shape_params_values`` (path to a YAML file) and
            ``shape_inference_outputs`` (batch size for shape inference; the
            inference function's own default is used when absent).
    """

    ### user configurations
    model_path = os.path.normpath(main_params["model_path"])
    load_data = int(main_params["load_data"])
    output_dir = main_params["output_dir"]
    if output_dir == ".":
        output_dir = os.path.dirname(model_path)
    # Make sure a user-supplied output directory exists before saving into it
    # (an empty string means the current directory).
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    #   get_shapes_method = str(main_params["shape_infer_method"]) #onnx_runtime" # options: "onnx_tool", "onnx_shape_infer", "onnx_runtime", "onnx_runtime_large_model", "onnx_runtime_custom_ops"

    #   # datatype for Linear layers, Options: ["int8", "mx9", "fp32", "bfloat16", "bfp16"]
    #   low_precision_act_dtype = str(main_params["low_precision_act_dtype"])#"uint16" for win24 # "bfp16" for sdxl turbo
    #   # datatype for non linear layers
    #   high_precision_act_dtype = str(main_params["high_precision_act_dtype"]) #"uint16" for win24 # "bfloat16" for sdxl turbo

    #   low_precision_wgt_dtype = str(main_params["low_precision_wgt_dtype"]) #"uint8" for win24 # "bfp16" for sdxl turbo
    #   # datatype for non linear layers
    #   high_precision_wgt_dtype = str(main_params["high_precision_wgt_dtype"]) #"uint16" for win24 # "bfloat16" for sdxl turbo

    base_name = os.path.splitext(os.path.basename(model_path))[0]
    # load original model.onnx
    model_ = onnx.load_model(model_path, load_external_data=load_data)
    mod_model_path = os.path.join(output_dir, base_name + "_mod.onnx")
    add_node_names(model_)
    update_op_set(model_)

    # check find shape parameters in model
    default_shape_params: dict[str, int] = {}
    default_graph_input_values: dict[str, Any] = {}
    if main_params.get("default_shape_params_values"):
        default_shape_params_values = DefaultShapeParamsValue.from_yaml_file(
            main_params["default_shape_params_values"]
        )
        # use the default shape params for the shape params needed by the model
        default_shape_params = {
            shape_param: default_shape_params_values.shape_params[shape_param]
            for shape_param in get_shape_params_from_model(model_)
            if shape_param in default_shape_params_values.shape_params
        }
        # use the default input values for the graph inputs in the model
        graph_input_names = [input_.name for input_ in model_.graph.input]
        default_graph_input_values = {
            graph_input_name: default_shape_params_values.graph_input_values[
                graph_input_name
            ]
            for graph_input_name in graph_input_names
            if graph_input_name in default_shape_params_values.graph_input_values
        }
        find_dynamic_shape_params_in_model(
            model_, default_shape_params, default_graph_input_values
        )

    # dynamic shapes dictionary creation
    input_shape_dict = {}
    output_shape_dict = {}

    if main_params["input_dims"]:
        input_names = list(main_params["input_names"])
        input_dims = list(main_params["input_dims"])
        # each entry of input_dims is a comma separated list of ints
        for idx, input_name in enumerate(input_names):
            input_shape_dict[input_name] = [
                int(dim) for dim in input_dims[idx].split(",")
            ]

    if main_params["in_shape_params"] != "{}" or default_shape_params:
        shape_params = json.loads(main_params["in_shape_params"])
        # explicitly passed shape parameters win over the defaults
        updated_shape_params = default_shape_params | shape_params
        print("Using default shape parameters:", updated_shape_params)

        input_shape_dict, output_shape_dict = get_fixed_shapes_from_params(
            model_path, updated_shape_params
        )

    # explicitly passed input values win over the defaults
    input_value_dict = default_graph_input_values | json.loads(
        main_params["fixed_input_values"]
    )
    if input_value_dict:
        print("Using default graph input values:", input_value_dict)

    print("Provided fixed input shapes for the model", input_shape_dict)
    print("Provided fixed output shapes for the model", output_shape_dict)

    ## update global inputs shapes if provided and save the model
    if input_shape_dict:
        for graph_input in model_.graph.input:
            for idx, dim in enumerate(graph_input.type.tensor_type.shape.dim):
                dim.dim_value = input_shape_dict[graph_input.name][idx]

    ## update global outputs shapes if provided and save the model
    if output_shape_dict:
        for graph_output in model_.graph.output:
            for idx, dim in enumerate(graph_output.type.tensor_type.shape.dim):
                dim.dim_value = output_shape_dict[graph_output.name][idx]

    # save _mod.onnx file
    save_model(model_, mod_model_path, load_data == 1)

    # The batch size is optional: callers (including this file's CLI) may not
    # provide it, in which case the inference function's default applies.
    shape_inference_kwargs = {}
    shape_inference_outputs = main_params.get("shape_inference_outputs")
    if shape_inference_outputs is not None:
        shape_inference_kwargs["global_out_max"] = shape_inference_outputs

    # act shapes except dq and q nodes
    all_act_signal_shapes, all_act_signal_dtypes = (
        get_act_shapes_using_onnx_runtime_large_model(
            model_,
            mod_model_path,
            input_shape_dict,
            input_value_dict,
            print_to_file=False,
            **shape_inference_kwargs,
        )
    )
    value_info_dict = {
        value_info.name: value_info for value_info in model_.graph.value_info
    }
    # initializer shapes and dtypes, will be useful for dq nodes
    # model_initializer_shapes, model_initializer_dtypes = get_initializer_shapes(model_path, load_data=load_data, print_to_file=False)

    node_cnt = len(model_.graph.node)
    qdq_node_cnt = sum(
        1
        for node in model_.graph.node
        if node.op_type in ("DequantizeLinear", "QuantizeLinear")
    )

    ## add all tensors to value info, which are not in value_info but in all_act_signal_shapes (mostly Quant, Dequant nodes)
    value_infor_count_orig = len(value_info_dict)
    for tensor_name, shape in all_act_signal_shapes.items():
        if tensor_name not in value_info_dict:
            dtype_onnx = onnx.helper.np_dtype_to_tensor_dtype(
                all_act_signal_dtypes[tensor_name]
            )
            new_value_info = onnx.helper.make_tensor_value_info(
                tensor_name, dtype_onnx, shape
            )
            model_.graph.value_info.append(new_value_info)

    print("Total nodes:", node_cnt, ", QDQ nodes:", qdq_node_cnt)

    update_model_value_info(
        model_,
        all_act_signal_shapes,
        all_act_signal_dtypes,
        value_infor_count_orig,
    )

    # overwrite the _mod.onnx file with the fully annotated model
    save_model(model_, mod_model_path, load_data == 1)

def add_node_names(model):
    """
    Give every unnamed node in ``model.graph`` a generated name.

    The generated name is the node's op type followed by the node's position
    in the graph's node list. Nodes that already have a name are left as is.
    """
    for position, graph_node in enumerate(model.graph.node):
        if graph_node.name != "":
            continue
        graph_node.name = f"{graph_node.op_type}{position}"

def find_dynamic_shape_params_in_model(model: onnx.ModelProto, default_shape_params: dict, default_graph_input_values: dict):
    """
    Adjust default shape parameters and graph input values for GroupQueryAttention.

    Only the first ``GroupQueryAttention`` node of the graph is inspected. Its
    query shape gives the sequence length; when that differs from the current
    total sequence length, the values associated with the node's
    ``seqlens_k`` and ``total_sequence_length`` inputs (and the
    ``total_sequence_length`` / ``past_sequence_length`` shape parameters, if
    present) are updated in place so that onnxruntime sees consistent values.

    Args:
        model: Model to inspect.
        default_shape_params: Shape parameter name -> value; updated in place.
        default_graph_input_values: Graph input name -> value; updated in place.

    Raises:
        RuntimeError: If the query or key tensor has an unsupported rank.
        KeyError: If the sequence length is symbolic, unknown to
            ``default_shape_params``, and no "seq_len" default exists.
    """
    gqa_proto = next(
        (node for node in model.graph.node if node.op_type == "GroupQueryAttention"),
        None,
    )
    if gqa_proto is None:
        return

    # Build the matcher helpers outside the try block so that a failure here
    # is reported as such instead of being hidden by the fallback below.
    onnx_ops_wrapper = OnnxOpsWrapper()
    _model_dict = ModelDict(model, onnx_ops_wrapper)
    n = Node(_model_dict, WalkCfgPlain(), gqa_proto.name)
    n = n.with_walk_cfg(WalkCfgSkipNoop())
    try:
        query_shape = n("query").get_shape()
    except Exception:
        # Shape not directly available: search for it, or give up.
        if n("query").check_input_tensor():
            query_shape = n("query").search_for_shape().get_shape()
        else:
            return
    if len(query_shape) < 2:
        return

    if len(query_shape) == 3:
        # (B, S, num_heads*D)
        sequence_length = query_shape[1]
    elif len(query_shape) == 4:
        # (B, num_heads, S, D)
        sequence_length = query_shape[2]
    else:
        raise RuntimeError("Unsupported query tensor shape for GroupQueryAttention")

    if isinstance(sequence_length, str):
        # symbolic dimension: resolve it through the default shape parameters
        if sequence_length in default_shape_params:
            sequence_length = default_shape_params[sequence_length]
        else:
            default_shape_params[sequence_length] = default_shape_params["seq_len"]
            sequence_length = default_shape_params["seq_len"]

    name_total_seq_len = n("total_sequence_length").get_name()  # 'total_seq_len'
    total_sequence_length = 64  # default value
    if name_total_seq_len in default_graph_input_values:
        total_sequence_length = default_graph_input_values[name_total_seq_len]

    # graph input: seqlen_k, total_sequence_length
    if sequence_length == total_sequence_length:
        return

    # decoder mode, need to ensure seqlens_k >= sequence_length - 1
    name_seqlen = n("seqlens_k").get_name()  # past_seq_len
    if name_seqlen in default_graph_input_values:
        curr_value = default_graph_input_values[name_seqlen]
        if isinstance(curr_value, list):
            update_valid_input_value(curr_value, sequence_length - 1)
        else:
            default_graph_input_values[name_seqlen] = sequence_length - 1
    else:
        # Any input parameter not present in default_shape_params_values should also be added and assigned a proper value, e.g.
        # GQA of PSU3_ver_0.1-qdq-20251007.onnx has the parameter seqlens_k instead of past_seq_len
        default_graph_input_values[name_seqlen] = sequence_length - 1

    if name_total_seq_len in default_graph_input_values:
        curr_total_seqlen = default_graph_input_values[name_total_seq_len]
        if isinstance(curr_total_seqlen, list):
            total_sequence_length = update_valid_input_value(
                curr_total_seqlen, sequence_length
            )
        else:
            default_graph_input_values[name_total_seq_len] = sequence_length
    else:
        default_graph_input_values[name_total_seq_len] = sequence_length

    # NOTE(review): in the two non-list branches above the local
    # total_sequence_length keeps its previous value while the graph input
    # value is set to sequence_length -- confirm this is intended.
    if "total_sequence_length" in default_shape_params:
        default_shape_params["total_sequence_length"] = total_sequence_length

    key_input = n("key")
    # as total_sequence_length = past_sequence_length + kv_sequence_length. In kv packed mode,
    # kv_sequence_length is 0, so past_sequence_length = total_sequence_length
    if "past_sequence_length" in default_shape_params:
        if key_input.get_name() == "":
            default_shape_params["past_sequence_length"] = total_sequence_length
        else:
            key_shape = key_input.get_shape()
            if len(key_shape) == 3:
                # (B, S, num_heads*D)
                kv_sequence_length = key_shape[1]
            elif len(key_shape) == 4:
                # (B, num_heads, S, D)
                kv_sequence_length = key_shape[2]
            else:
                raise RuntimeError("Unsupported key tensor shape for GroupQueryAttention")
            # NOTE(review): assumes kv_sequence_length is an int here; a
            # symbolic (str) dimension would make the subtraction fail.
            default_shape_params["past_sequence_length"] = total_sequence_length - kv_sequence_length
                

def update_valid_input_value(nested_list: list, valid_value: int) -> int:
    """
    Raise every scalar in a (possibly nested) list to at least ``valid_value``.

    in onnxruntime, for GroupQueryAttention, there is the formula:

        const size_t total_seqlen = static_cast<size_t>(seqlens_k[batch_index]) + 1;
        const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length;  // Assume no padding sequence length

    past_seqlen must be larger or equal to 0, so seqlens_k must be larger or
    equal to sequence_length - 1 for every batch index. The input value is
    therefore updated so that shape inference in onnxruntime runs with valid
    input values. More constraints may be added here in future if some dynamic
    input values are not valid and cause shape inference to fail.

    Args:
        nested_list: List (of lists ...) of numbers; modified in place. Every
            element at every nesting level is checked, not only singleton
            lists, so multi-element (batch > 1) values are handled as well.
        valid_value: Smallest value allowed for any element.

    Returns:
        The largest element after the update, or 0 when ``nested_list`` is
        not a list or contains no scalar at all. For a singly nested
        one-element list this is that element's (possibly raised) value.
    """
    if not isinstance(nested_list, list):
        return 0

    largest = None
    pending = [nested_list]
    while pending:
        current = pending.pop()
        for index, item in enumerate(current):
            if isinstance(item, list):
                pending.append(item)
                continue
            if item < valid_value:
                current[index] = valid_value
                item = valid_value
            if largest is None or item > largest:
                largest = item

    return 0 if largest is None else largest

# Command line entry point: parse the options, then run add_model_info with
# them. Every option ends up as a key of the main_params dictionary.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--debug",
        help="Print lots of debugging statements",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        help="path to onnx model and output destination.Required Field",
    )
    parser.add_argument(
        "-ld",
        "--load_data",
        # The value is parsed with int(); 1 enables external data handling.
        help="Set to 1 to load and save external model data for large models. Optional Field. Default value = 0",
        default="0",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Output directory for generated files",
        default=".",
        # default=create_default_output_dir(),
    )

    #   parser.add_argument("-method", "--shape_infer_method", help="method to get tensor shapes from the model. Optional Field. Default value = 'onnx_runtime'", default="onnx_runtime")
    #   parser.add_argument("-act_dtype_low", "--low_precision_act_dtype", help="low precision activation dtype for tensor datatype assignment. Optional Field. Default value = 'uint16'", default="uint16")
    #   parser.add_argument("-act_dtype_high", "--high_precision_act_dtype", help="high precision activation dtype for tensor datatype assignment. Optional Field. Default value = 'uint16'", default="uint16")
    #   parser.add_argument("-wgt_dtype_low", "--low_precision_wgt_dtype", help="low precision weights dtype for tensor datatype assignment. Optional Field. Default value = 'uint8'", default="uint8")
    #   parser.add_argument("-wgt_dtype_high", "--high_precision_wgt_dtype", help="high precision weights dtype for tensor datatype assignment. Optional Field. Default value = 'uint16'", default="uint16")

    parser.add_argument(
        "-in",
        "--input_names",
        required=False,
        nargs="+",
        help="Names of inputs if model has dynamic shape inputs. Optional Field. Default value = ''",
        default="",
    )
    parser.add_argument(
        "-dims",
        "--input_dims",
        required=False,
        nargs="+",
        help="Shapes of inputs if model has dynamic shape inputs. Optional Field. Default value = ''",
        default="",
    )

    parser.add_argument(
        "-shape_params",
        "--in_shape_params",
        required=False,
        type=str,
        help="Dynamic shape parameters for inputs as a JSON string. Optional Field. Default value = '{}'",
        default="{}",
    )
    parser.add_argument(
        "--fixed_input_values",
        required=False,
        type=str,
        help="Fixed input values to the neural network. JSON syntax: input name -> value. Optional Field. Default value = '{}'",
        default="{}",
    )
    # add_model_info reads the two keys below; without these options the
    # "shape_inference_outputs" lookup fails when running from the CLI.
    parser.add_argument(
        "--shape_inference_outputs",
        required=False,
        type=int,
        help="Number of intermediate tensors exposed as extra graph outputs per onnxruntime shape inference run. Optional Field. Default value = 3000",
        default=3000,
    )
    parser.add_argument(
        "--default_shape_params_values",
        required=False,
        type=str,
        help="Path to a YAML file with default shape parameters and graph input values. Optional Field. Default value = None",
        default=None,
    )

    args = parser.parse_args()
    if not args.model_path:
        parser.error(
            "Please pass path/to/onnx/model using -mp or --model_path flags.\npython3 parse_onnx_model.py --help\n\t\t\tfor further info."
        )
    logging.basicConfig(level=args.loglevel)
    logging.debug("Debug mode is enabled!")
    main_params = vars(args)

    add_model_info(main_params)
