# fmt: on
"""
Cut a piece of the graph out of an ONNX model.

typical usage:

$ python OGOAT/misc_tools/cut_graph.py --input Model_PSR_v1.1.onnx print_op_types
...
Softmax
...
$ python OGOAT/misc_tools/cut_graph.py --input Model_PSR_v1.1.onnx print_nodes --op_type Softmax
...
Softmax:
  down_blocks.0.attentions.0.transformer_blocks.0.attn1.softmax_1.0
  down_blocks.0.attentions.0.transformer_blocks.0.attn1.softmax_1.1
...
$ python OGOAT/misc_tools/cut_graph.py --input Model_PSR_v1.1.onnx \\
         cut_subgraph --node_name down_blocks.0.attentions.0.transformer_blocks.0.attn1.softmax_1.0 \\
         --levels_up 3 --levels_down 1 --output /tmp/out.onnx
"""

import argparse
import dataclasses
import heapq
import os
from collections import deque
from typing import Iterable

import numpy as np
import onnx
import onnx.checker
import onnx.helper
import onnxruntime as ort

from OGOAT.src.L1_fusion.L1_utils.utils import save_model
from OGOAT.src.L1_fusion.topo_sort import graph_sort


def parse_args() -> argparse.Namespace:
    """
    Parse the command line.

    Global options (--check, --extern_data, --input) come before the action;
    the action (print_op_types, print_nodes, cut_subgraph) selects the
    sub-parser, which stores the function to run in `func`.
    """
    parser = argparse.ArgumentParser(
        description="cut a piece of the graph out of an ONNX model"
    )
    parser.add_argument("--check", "-c", action="store_true", help="check ONNX model")
    parser.add_argument(
        "--extern_data",
        "-e",
        action="store_true",
        help="load external data (might be needed for --shape_infer)",
    )
    parser.add_argument("--input", "-i", required=True, help="name of input ONNX model")
    # an explicit dest makes a missing action a proper usage error (without it,
    # some Python versions raise a TypeError while formatting the message)
    subparsers = parser.add_subparsers(dest="action", required=True, help="action")

    print_op_types_parser = subparsers.add_parser(
        "print_op_types", help="print operator types in model"
    )
    print_op_types_parser.set_defaults(func=print_op_types)

    print_nodes_parser = subparsers.add_parser(
        "print_nodes", help="print nodes in model sorted by op_type"
    )
    print_nodes_parser.set_defaults(func=print_nodes)
    print_nodes_parser.add_argument("--op_type", "-t", help="show only this op type")
    print_nodes_parser.add_argument(
        "--max_nodes",
        "-m",
        type=int,
        help="maximum number of nodes to print per op type",
    )

    cut_subgraph_parser = subparsers.add_parser(
        "cut_subgraph", help="cut subgraph out of whole model"
    )
    cut_subgraph_parser.set_defaults(func=cut_subgraph)
    cut_subgraph_parser.add_argument(
        "--node_name",
        "-n",
        nargs="+",
        help="name of nodes to keep in cut subgraph (keep all nodes by default)",
    )
    cut_subgraph_parser.add_argument(
        "--levels_down",
        "-d",
        type=int,
        default=0,
        help="number of levels to keep below the selected nodes",
    )
    cut_subgraph_parser.add_argument(
        "--levels_up",
        "-u",
        type=int,
        default=0,
        help="number of levels to keep above the selected nodes",
    )
    # NOTE: "-e" here (after the action) means --erase_node_name, while "-e"
    # before the action means the global --extern_data
    cut_subgraph_parser.add_argument(
        "--erase_node_name",
        "-e",
        nargs="+",
        help="name of nodes to erase in cut subgraph (applied after -n, -u, -d)",
    )
    cut_subgraph_parser.add_argument(
        "--output", "-o", required=True, help="name of output ONNX model"
    )
    cut_subgraph_parser.add_argument(
        "--recompute_output_info",
        "-r",
        action="store_true",
        help="recompute output tensor info (to fix broken models that miss it)",
    )
    cut_subgraph_parser.add_argument(
        "--shape_infer",
        "-s",
        action="store_true",
        help="run shape inference (needed for models without shape annotations)",
    )
    cut_subgraph_parser.add_argument(
        "--topological_sort",
        "-t",
        action="store_true",
        help="re-sort the model topologically (to fix broken models that are not sorted)",
    )
    cut_subgraph_parser.add_argument(
        "--temp_file",
        default="tmp.onnx",
        help="use the specified temporary file for working with models with external data",
    )

    args = parser.parse_args()
    return args


def print_op_types(model: onnx.ModelProto, args: argparse.Namespace) -> None:
    """Print every distinct operator type of the model's graph, one per line, sorted."""
    distinct_types = {graph_node.op_type for graph_node in model.graph.node}
    for type_name in sorted(distinct_types):
        print(type_name)


def print_nodes(model: onnx.ModelProto, args: argparse.Namespace) -> None:
    """
    Print node names grouped by operator type (types sorted).

    args.op_type -- if set, print only this operator type
    args.max_nodes -- if set, take at most this many nodes per type (the first
                      ones in graph order), which are then printed sorted
    """
    names_per_type: dict[str, list[str]] = {}
    for graph_node in model.graph.node:
        if graph_node.op_type not in names_per_type:
            names_per_type[graph_node.op_type] = []
        names_per_type[graph_node.op_type].append(graph_node.name)

    wanted_type = args.op_type
    limit = args.max_nodes
    for type_name in sorted(names_per_type):
        if wanted_type is not None and type_name != wanted_type:
            continue
        print(f"{type_name}:")
        selected = names_per_type[type_name]
        if limit is not None:
            selected = selected[:limit]
        for name in sorted(selected):
            print(f"  {name}")


class LookUp:
    """
    Look-up tables for nodes.

    nodes_by_name -- key: node name, value: ONNX node
    readers -- key: input tensor name, value: names of the nodes reading the tensor
    writer -- key: output tensor name, value: name of the node writing the tensor
    value_info_by_name -- key: tensor name, value: entry of graph.value_info

    Nodes with a missing (empty) or duplicate name are renamed in place in the
    model first, so that every node can be identified by its name.
    """

    def __init__(self, model: onnx.ModelProto) -> None:
        # node_name -> ONNX node
        self.nodes_by_name: dict[str, onnx.NodeProto] = {}
        # tensor name -> node name list
        self.readers: dict[str, list[str]] = {}
        # tensor name -> node name
        self.writer: dict[str, str] = {}
        # tensor name -> ONNX value info
        self.value_info_by_name: dict[str, onnx.ValueInfoProto] = {}

        # all tables below key nodes by name, so names must be unique and
        # non-empty (otherwise unnamed nodes would all collide on "")
        self._patch_node_names(model)

        for node in model.graph.node:
            self.nodes_by_name[node.name] = node
            for input_ in node.input:
                if input_ == "":
                    continue  # unused optional input
                self.readers.setdefault(input_, []).append(node.name)
            for output in node.output:
                if output in self.writer:
                    raise RuntimeError(
                        f"multiple writers for {output}: {self.writer[output]}, {node.name}"
                    )
                self.writer[output] = node.name
        for value_info in model.graph.value_info:
            self.value_info_by_name[value_info.name] = value_info

    @staticmethod
    def _patch_node_names(model: onnx.ModelProto) -> None:
        """
        Rename, in place, every node whose name is empty or repeats the name of
        an earlier node. New names have the form "<op_type>_<counter>" and never
        collide with any name already present in the model.
        """
        taken = {node.name for node in model.graph.node if node.name}
        seen: set[str] = set()
        counter = 0
        renamed = 0
        for node in model.graph.node:
            if node.name and node.name not in seen:
                seen.add(node.name)  # first occurrence keeps its name
                continue
            candidate = f"{node.op_type}_{counter}"
            while candidate in taken:
                counter += 1
                candidate = f"{node.op_type}_{counter}"
            counter += 1
            node.name = candidate
            taken.add(candidate)
            seen.add(candidate)
            renamed += 1
        if renamed:
            print(f"warning: renamed {renamed} nodes with missing or duplicate names")

    def find_nodes(
        self, node_names: list[str], levels_down: int, levels_up: int
    ) -> list[str]:
        """
        Find names of nodes to keep around the start nodes in node_names.

        Downward search: every node reachable from a start node within
        levels_down levels gets its shortest distance below the start nodes as
        level (start nodes: 0).
        Upward search: from every node found so far, follow the producers of
        its inputs; a producer of a node at level L gets level L - 1, and the
        search continues from nodes whose level is above -levels_up. A node
        reachable in several ways keeps its highest level, so the result does
        not depend on traversal order.

        Return sorted list of node names.

        Raises RuntimeError if a start node does not exist.
        """
        for node_name in node_names:
            if node_name not in self.nodes_by_name:
                raise RuntimeError(f"node {node_name} not found")

        # first search <levels_down> downwards (plain BFS: first visit = shortest)
        # node name -> level relative to start nodes (positive: down)
        down_found: dict[str, int] = {name: 0 for name in node_names}
        down_todo: deque[tuple[str, int]] = deque(down_found.items())
        while down_todo:
            cur_name, cur_level = down_todo.popleft()
            if cur_level >= levels_down:
                continue
            for output in self.nodes_by_name[cur_name].output:
                for next_name in self.readers.get(output, []):
                    if next_name not in down_found:
                        down_found[next_name] = cur_level + 1
                        down_todo.append((next_name, cur_level + 1))

        # now search upwards to <levels_up> further up from start node level;
        # highest level first, so every node is expanded once with its best level
        up_level: dict[str, int] = dict(down_found)
        heap = [(-level, name) for name, level in up_level.items()]
        heapq.heapify(heap)
        while heap:
            neg_level, cur_name = heapq.heappop(heap)
            cur_level = -neg_level
            if cur_level != up_level[cur_name]:
                continue  # stale entry: node was reached later with a higher level
            if cur_level <= -levels_up:
                continue
            for input_ in self.nodes_by_name[cur_name].input:
                if input_ == "":
                    continue  # unused optional input
                next_name = self.writer.get(input_)
                if next_name is None:
                    continue  # not produced by a node (graph input / initializer)
                next_level = cur_level - 1
                if next_name not in up_level or up_level[next_name] < next_level:
                    up_level[next_name] = next_level
                    heapq.heappush(heap, (-next_level, next_name))
        return sorted(up_level)

    def get_used_tensors(self, node_names: Iterable[str]) -> list[str]:
        """
        Return sorted list of names of tensors that are used (read or written)
        by at least one of the nodes. Empty (unused optional) inputs are skipped.
        """
        used_tensors: set[str] = set()
        for node_name in node_names:
            node = self.nodes_by_name[node_name]
            used_tensors.update(input_ for input_ in node.input if input_ != "")
            used_tensors.update(node.output)
        return sorted(used_tensors)


@dataclasses.dataclass()
class TensorInfo:
    """Name, static shape and ONNX element type of a tensor."""

    # tensor name
    name: str
    # dimension sizes
    shape: list[int]
    # ONNX tensor element type (TensorProto.DataType value)
    dtype: int


class ShapeInfer:
    """
    Infer concrete shapes and dtypes of tensors of an ONNX model by running it.

    A clone of the model gets every tensor in `tensors` added as a graph
    output. The clone is run once with onnxruntime (CPU, graph optimizations
    disabled) on seeded random input data. The shape and dtype of each output
    are recorded in `info` (key: tensor name). Graph inputs are recorded from
    their declarations, with symbolic/unknown dimensions replaced by 1.

    The work is done eagerly in the constructor. Data-dependent shapes reflect
    the random inputs of this one run.
    """

    def __init__(
        self,
        model: onnx.ModelProto,
        tensors: Iterable[str],
        extern_data: bool,
        temp_file: str,
    ) -> None:
        self.model = model
        self.tensors = tensors
        self.extern_data = extern_data
        self.temp_file = temp_file
        # tensor name -> shape/dtype (graph inputs and inferred tensors)
        self.info: dict[str, TensorInfo] = {}

        self._record_input_infos()
        ort_session = self._create_session(self._clone_with_outputs())
        self._record_output_infos(ort_session)

    @staticmethod
    def _dim_size(dim: onnx.TensorShapeProto.Dimension) -> int:
        """Return the static size of a dimension, or 1 if it is symbolic/unset."""
        # dim_value reads as 0 for symbolic (dim_param) or unset dims, which
        # would feed empty tensors to the inference run
        return dim.dim_value if dim.HasField("dim_value") else 1

    def _record_input_infos(self) -> None:
        """Store shape and dtype of every graph input from its declaration."""
        for input_ in self.model.graph.input:
            tensor_type = input_.type.tensor_type
            self.info[input_.name] = TensorInfo(
                name=input_.name,
                shape=[self._dim_size(d) for d in tensor_type.shape.dim],
                dtype=tensor_type.elem_type,
            )

    def _clone_with_outputs(self) -> onnx.ModelProto:
        """Return a copy of the model with all requested tensors added as outputs."""
        existing_inputs = {i.name for i in self.model.graph.input}
        existing_outputs = {o.name for o in self.model.graph.output}
        if self.extern_data:
            # round-trip through the temp file so the external data is copied too
            save_model(self.model, self.temp_file, external_data=True)
            cloned_model = onnx.load(self.temp_file, load_external_data=True)
        else:
            cloned_model = onnx.load_from_string(self.model.SerializeToString())
        # sorted: deterministic output order (set iteration order is not)
        to_add = sorted(set(self.tensors) - existing_inputs - existing_outputs)
        cloned_model.graph.output.extend(
            [onnx.ValueInfoProto(name=tensor) for tensor in to_add]
        )
        return cloned_model

    def _create_session(self, cloned_model: onnx.ModelProto) -> ort.InferenceSession:
        """Create a CPU onnxruntime session for the cloned model."""
        so = ort.SessionOptions()
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        if self.extern_data:
            # NOTE(review): the temp file (and any external data save_model
            # writes next to it) is left on disk afterwards
            save_model(cloned_model, self.temp_file, external_data=True)
            model_for_ort = self.temp_file
        else:
            model_for_ort = cloned_model.SerializeToString()
        return ort.InferenceSession(
            model_for_ort, so, providers=["CPUExecutionProvider"]
        )

    def _record_output_infos(self, ort_session: ort.InferenceSession) -> None:
        """Run the session on random inputs and store shape/dtype of every output."""
        # fixed seed: repeated runs give the same data-dependent shapes
        rng = np.random.default_rng(0)

        def mk_val(info: TensorInfo) -> np.ndarray:
            shape = info.shape if info.shape else [1]
            dtype = onnx.helper.tensor_dtype_to_np_dtype(info.dtype)
            return rng.standard_normal(shape).astype(dtype)

        outputs = [x.name for x in ort_session.get_outputs()]
        rand_inputs = {
            i.name: mk_val(self.info[i.name]) for i in ort_session.get_inputs()
        }
        out_data = ort_session.run(outputs, rand_inputs)
        for out_name, out_val in zip(outputs, out_data):
            self.info[out_name] = TensorInfo(
                name=out_name,
                shape=list(out_val.shape),
                dtype=onnx.helper.np_dtype_to_tensor_dtype(out_val.dtype),
            )


class ModelCutter:
    """
    Reduce a model (in place) to a certain set of nodes and tensors.

    Graph inputs/outputs are added for tensors that cross the cut. Their
    shape/dtype comes from the model's value_info when it has a fully static
    shape. Otherwise it comes from the shape inference results.
    """

    def __init__(
        self,
        model: onnx.ModelProto,
        look_up: LookUp,
        shape_infer_info: dict[str, TensorInfo],
    ) -> None:
        self.model = model
        self.look_up = look_up
        self.shape_infer_info = shape_infer_info

    @staticmethod
    def remove_not_kept(elems, names_to_keep) -> None:
        """
        Delete, in place, every element of the repeated proto field `elems`
        whose `.name` is not in `names_to_keep`. Remaining order is preserved.
        """
        keep = set(names_to_keep)  # O(1) membership instead of list scans
        # delete by index from the back so pending indices stay valid; this
        # avoids remove(), which searches the container by message equality
        for idx in reversed(range(len(elems))):
            if elems[idx].name not in keep:
                del elems[idx]

    def cut(self, keep_nodes: Iterable[str], keep_tensors: Iterable[str]):
        """
        Keep only the nodes in keep_nodes and the initializers, value infos,
        graph inputs and graph outputs in keep_tensors. Then add graph inputs
        for tensors that are read but no longer produced. Add graph outputs for
        tensors that are produced but no longer read.
        """
        # used several times below; also makes one-shot iterables safe
        keep_tensors = set(keep_tensors)
        graph = self.model.graph
        self.remove_not_kept(graph.node, keep_nodes)
        for field in (graph.initializer, graph.value_info, graph.input, graph.output):
            self.remove_not_kept(field, keep_tensors)
        new_inputs, new_outputs = self.get_new_inputs_and_outputs()
        # create new inputs and outputs
        for new_input in new_inputs:
            self.set_value_info_proto(graph.input.add(), self.get_tensor_info(new_input))
        for new_output in new_outputs:
            self.set_value_info_proto(
                graph.output.add(), self.get_tensor_info(new_output)
            )

    def get_new_inputs_and_outputs(self) -> tuple[list[str], list[str]]:
        """
        Return (new inputs, new outputs), both sorted.

        New inputs are tensors read by a node or listed as graph output but not
        produced by a node, an initializer or a graph input. New outputs are
        tensors produced but never read.
        """
        tensors_read: set[str] = set()
        tensors_written: set[str] = set()
        for node in self.model.graph.node:
            tensors_read.update(input_ for input_ in node.input if input_ != "")
            tensors_written.update(node.output)
        tensors_written.update(init.name for init in self.model.graph.initializer)
        tensors_written.update(input_.name for input_ in self.model.graph.input)
        tensors_read.update(output.name for output in self.model.graph.output)
        new_inputs = tensors_read - tensors_written
        new_outputs = tensors_written - tensors_read
        return sorted(new_inputs), sorted(new_outputs)

    def get_tensor_info(self, tensor_name: str) -> TensorInfo:
        """
        Return shape/dtype of a tensor.

        The model's value_info is used if its shape is fully static. Otherwise
        the shape inference result is used if available. Failing that, a
        non-static value_info is used as before (unknown dims read as 0).

        Raises RuntimeError if neither source knows the tensor.
        """
        vi = self.look_up.value_info_by_name.get(tensor_name)
        sii = self.shape_infer_info.get(tensor_name)
        if vi is not None:
            tt = vi.type.tensor_type
            static_shape = tt.HasField("shape") and all(
                d.HasField("dim_value") for d in tt.shape.dim
            )
            if static_shape or sii is None:
                return TensorInfo(
                    name=tensor_name,
                    shape=[d.dim_value for d in tt.shape.dim],
                    dtype=tt.elem_type,
                )
        if sii is not None:
            return sii
        raise RuntimeError(
            f"unknown shape for {tensor_name}, need shape annotations in model or run --shape_infer"
        )

    def recompute_output_info(self) -> None:
        """Overwrite name, dtype and shape of every graph output."""
        for o in self.model.graph.output:
            self.set_value_info_proto(o, self.get_tensor_info(o.name))

    @staticmethod
    def set_value_info_proto(
        value_info_proto: onnx.ValueInfoProto, tensor_info: TensorInfo
    ) -> None:
        """Set name, element type and static shape of a (tensor) value info."""
        value_info_proto.name = tensor_info.name
        tensor_type_proto = value_info_proto.type.tensor_type
        tensor_type_proto.elem_type = tensor_info.dtype
        shape_proto = tensor_type_proto.shape
        # mark the shape present, so rank-0 tensors read as scalars, not unknown rank
        shape_proto.SetInParent()
        # drop existing dims: recompute_output_info passes outputs that may
        # already carry a shape, which would otherwise be appended to
        shape_proto.ClearField("dim")
        for n in tensor_info.shape:
            shape_proto.dim.add().dim_value = n


def cut_subgraph(model: onnx.ModelProto, args: argparse.Namespace) -> None:
    """
    Cut the part of `model` selected by the command line arguments (modifying
    the model in place) and save it to args.output.

    Nodes are selected around args.node_name (all nodes if not given) using
    args.levels_down / args.levels_up. Nodes in args.erase_node_name are then
    dropped from the selection.
    """
    look_up = LookUp(model)

    start_nodes = args.node_name or list(look_up.nodes_by_name.keys())
    keep_nodes = look_up.find_nodes(start_nodes, args.levels_down, args.levels_up)

    if args.erase_node_name:
        to_erase = set(args.erase_node_name)
        selected = set(keep_nodes)
        not_found = sorted(to_erase - selected)
        if not_found:
            print(f"warning: nodes to erase not found: {repr(not_found)}")
        keep_nodes = list(selected - to_erase)

    keep_tensors = look_up.get_used_tensors(keep_nodes)

    shape_infer_info: dict[str, TensorInfo] = (
        ShapeInfer(model, keep_tensors, args.extern_data, args.temp_file).info
        if args.shape_infer
        else {}
    )

    cutter = ModelCutter(model, look_up, shape_infer_info)
    if args.recompute_output_info:
        cutter.recompute_output_info()
    cutter.cut(keep_nodes, keep_tensors)

    if args.topological_sort:
        graph_sort(model)
    if args.check:
        onnx.checker.check_model(model, full_check=True)
    save_model(model, args.output, external_data=args.extern_data)


def main():
    """Parse arguments, load the input model (optionally checking it) and run the chosen action."""
    args = parse_args()
    model = onnx.load_model(args.input, load_external_data=args.extern_data)
    if args.check:
        onnx.checker.check_model(model, full_check=True)
    action = args.func
    action(model, args)


if __name__ == "__main__":
    main()
