# fmt: on
"""
Script for comparing two onnx models.

If called with two onnx file names, the two contained onnx models are
syntactically compared (nodes, inputs, outputs and initializers have to have
the same names).
Warnings are printed if one element has additional fields/information which are
not contained in the corresponding element of the other model.
Errors are mismatches of elements (e.g. nodes with the same name have a
different number of inputs, a different ordering or different input names).

Self-identified elements (e.g. nodes, attributes, graph input/output) are not
checked for the same ordering, but compared to an element with the same
identifier (e.g. name, key, domain, ...).
The element types are categorized by the class `OnnxModelNamedList`.

If the script is called with two directories as parameters, both directories
are searched for files matching `--glob_pattern` (default `**/*_fused.onnx`);
files contained in both directories are compared.
"""

import argparse
import math
import sys
from enum import Enum
from functools import singledispatch
from pathlib import Path
from typing import Any, Union

from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.message import Message
from onnx import (
    AttributeProto,
    OperatorSetIdProto,
    ModelProto,
    NodeProto,
    StringStringEntryProto,
    TensorProto,
    TensorShapeProto,
    ValueInfoProto,
    load,
)


class Printer:
    """Print an indented comparison trace and collect warnings and errors.

    Every printed message is stored in ``path`` at the position given by its
    indentation level (entries at that level or deeper are dropped first), so
    joining ``path`` yields the location of the current message. Warnings and
    errors are recorded as such joined locations and can be printed again as
    a summary with ``print_warnings_errors``.
    """

    class Color(Enum):
        # The first tuple element only keeps the enum values distinct; the
        # second one is the ANSI escape sequence that is actually emitted.
        WARNING = 0, "\033[93m"
        ERROR = 1, "\033[91m"
        END_COLOR = 2, "\033[0m"

    def __init__(self):
        self.path: list[str] = []
        self.warnings: list[str] = []
        self.errors: list[str] = []

    @staticmethod
    def _colored(color: "Printer.Color", text: str) -> str:
        """Wrap *text* in the escape sequence of *color* followed by a reset."""
        return f"{color.value[1]}{text}{Printer.Color.END_COLOR.value[1]}"

    def _update_path(self, element: str, indentation: int) -> None:
        """Make *element* the path entry at depth *indentation*.

        Entries at that depth or deeper are discarded first. Going more than
        one level deeper than the current path is a programming error.
        """
        if indentation > len(self.path):
            assert False, "indentation increases by more than 1"
        else:
            del self.path[indentation:]
            self.path.append(element)

    def _record(
        self, color: "Printer.Color", sink: list[str], message: str, indentation: int
    ) -> None:
        """Print *message* in *color* and store its plain location in *sink*."""
        self.print(Printer._colored(color, message), indentation)
        # The path entry must hold the message without escape sequences so the
        # collected location strings stay readable.
        self.path[-1] = message
        sink.append("".join(self.path))

    def print(self, message: str, indentation: int = 0) -> None:
        """Record *message* in the path and print it, two spaces per level."""
        self._update_path(message, indentation)
        print(f"{'  ' * indentation}{message}")

    def print_warning(self, message: str, indentation: int = 0) -> None:
        """Print *message* as a warning and remember its location."""
        self._record(Printer.Color.WARNING, self.warnings, message, indentation)

    def print_error(self, message: str, indentation: int = 0) -> None:
        """Print *message* as an error and remember its location."""
        self._record(Printer.Color.ERROR, self.errors, message, indentation)

    def print_warnings_errors(self) -> None:
        """Print all collected warnings, then all collected errors, colored."""
        for entry in self.warnings:
            print(Printer._colored(Printer.Color.WARNING, entry))
        for entry in self.errors:
            print(Printer._colored(Printer.Color.ERROR, entry))


class NoDefaultPrinter(Printer):
    """Printer that suppresses the regular trace output.

    Messages still update ``path``, and the inherited ``print_warning`` /
    ``print_error`` still collect warnings and errors, but nothing is written
    to stdout until ``print_warnings_errors`` is called.
    """

    def print(self, message: str, indentation: int = 0) -> None:
        # Keep the location bookkeeping only; intentionally emit nothing.
        self._update_path(message, indentation)


class OnnxModelNamedList:
    """
    Classification of the element types found in repeated proto fields.

    named_list_types: maps proto classes to the name of a field of that proto
        message. That field acts as an identifier of the objects, so lists of
        these objects are not compared as a sequence but as a dictionary
        keyed by the identifier field.

    unnamed_list_types: element types of lists that are compared as a
        sequence. All known such types are listed explicitly, so that a proto
        class which did not occur in the test files (or is added later) is not
        silently compared the wrong way but triggers an assertion.
    """

    named_list_types = {
        OperatorSetIdProto: "domain",
        NodeProto: "name",
        TensorProto: "name",
        ValueInfoProto: "name",
        AttributeProto: "name",
        StringStringEntryProto: "key",
    }

    unnamed_list_types = [
        TensorShapeProto.Dimension,
        int,
        float,
        str,
        bytes,
    ]

    @staticmethod
    def named_list_to_dictionary(obj: list) -> Union[dict[str, Any], None]:
        """Index a list of named elements by their identifier field.

        Returns a ``{identifier: element}`` dict when all elements are of one
        of the ``named_list_types``, and None when they are all of one of the
        ``unnamed_list_types``. An empty list matches the first named type
        vacuously and therefore yields ``{}``. If several elements share an
        identifier, only the last of them is kept in the dict.

        Raises AssertionError for a list of any other element type.
        """
        assert type(obj) == list

        table = OnnxModelNamedList.named_list_types
        matched = next((cls for cls in table if list_of(obj, cls)), None)
        if matched is not None:
            id_field = table[matched]
            return {getattr(item, id_field): item for item in obj}

        if any(list_of(obj, cls) for cls in OnnxModelNamedList.unnamed_list_types):
            return None

        element_type = type(obj[0]) if obj else None
        assert False, f"unknown list type={type(obj)}, element type={element_type}"


def message_to_dict(message: Message) -> dict[str, Any]:
    """Return the fields listed by ``message.ListFields()`` as a dict.

    Keys are field names. Repeated fields are copied into plain ``list``
    objects and placed after all non-repeated fields; non-repeated values are
    stored as returned by protobuf. Per the protobuf API, ``ListFields()``
    omits unset fields and empty repeated fields.
    """
    singular: dict[str, Any] = {}
    repeated: dict[str, Any] = {}
    for descriptor, value in message.ListFields():
        if descriptor.label == FieldDescriptor.LABEL_REPEATED:
            repeated[descriptor.name] = list(value)
        else:
            singular[descriptor.name] = value
    singular.update(repeated)
    return singular


@singledispatch
def compare(
    message1, message2, printer: Printer, name: str, indentation: int = 0
) -> bool:
    """Recursively compare two values taken from the two models.

    This is the generic fallback of the single-dispatch function; the
    overloads registered through ``compare.register`` handle the supported
    value types. Reaching this fallback means the type is not supported yet.

    Returns True if the values are considered equal.
    """
    unsupported = type(message1)
    assert False, f"type {unsupported} not yet implemented"


@compare.register
def _(
    dict1: dict, dict2: dict, printer: Printer, name: str, indentation: int = 0
) -> bool:
    """Compare two dicts key by key.

    Keys present in only one dict are reported as a warning and make the
    result unequal. Values under shared keys are compared recursively in
    sorted key order. A value whose type is one of the named list types is
    labelled ``name[i]`` (``i`` = its position among the sorted shared keys);
    otherwise the key itself is used as label.
    """
    shared = dict1.keys() & dict2.keys()
    missing_in_2 = dict1.keys() - shared
    missing_in_1 = dict2.keys() - shared

    result = True
    if missing_in_2:
        printer.print_warning(
            f".{name}[]: fields only in model 1: {sorted(missing_in_2)} (warning)",
            indentation,
        )
        result = False
    if missing_in_1:
        printer.print_warning(
            f".{name}[]: fields only in model 2: {sorted(missing_in_1)} (warning)",
            indentation,
        )
        result = False

    named_types = OnnxModelNamedList.named_list_types
    for position, key in enumerate(sorted(shared)):
        left, right = dict1[key], dict2[key]
        child = f"{name}[{position}]" if type(left) in named_types else key
        # `&=` (not `and`) so every shared key is compared and reported.
        result &= compare(left, right, printer, child, indentation)
    return result


def list_of(obj: list, cls: type) -> bool:
    """Return True if *obj* is a ``list`` whose elements are all *cls* instances.

    An empty list counts as a list of any class. Objects that are not
    ``list`` instances (e.g. tuples or strings) yield False.
    """
    return isinstance(obj, list) and all(isinstance(item, cls) for item in obj)


@compare.register
def _(
    list1: list, list2: list, printer: Printer, name: str, indentation: int = 0
) -> bool:
    """Compare two lists taken from the same repeated proto field.

    Lists of named elements (see ``OnnxModelNamedList``) are compared as dicts
    keyed by their identifier field, i.e. independent of element order. All
    other lists are compared position by position; elements present in only
    one list are reported as errors.

    Named lists fall back to positional comparison when an identifier occurs
    more than once in a list (e.g. several nodes with an empty ``name``),
    because the dict would silently keep only the last of those elements.

    Returns True if the lists are considered equal.
    """
    dict1 = OnnxModelNamedList.named_list_to_dictionary(list1)
    dict2 = OnnxModelNamedList.named_list_to_dictionary(list2)
    # An empty list vacuously matches the first named type and becomes {},
    # so it may be classified differently from a non-empty counterpart.
    assert (dict1 is None) == (dict2 is None) or not list1 or not list2, (
        "list have different types"
    )

    # Compare by identifier only if no element was lost while building the dicts.
    keyed = (
        dict1 is not None
        and dict2 is not None
        and len(dict1) == len(list1)
        and len(dict2) == len(list2)
    )
    if keyed:
        return compare(dict1, dict2, printer, name, indentation)

    equal = True
    for i in range(max(len(list1), len(list2))):
        if i >= len(list1):
            printer.print_error(f".{name}[{i}] not in model 1", indentation)
            equal = False
        elif i >= len(list2):
            printer.print_error(f".{name}[{i}] not in model 2", indentation)
            equal = False
        else:
            equal &= compare(list1[i], list2[i], printer, f"{name}[{i}]", indentation)
    return equal


@compare.register
def _(
    message1: Message,
    message2: Message,
    printer: Printer,
    name: str,
    indentation: int = 0,
) -> bool:
    """Compare two protobuf messages of the same type field by field.

    Prints ``.name`` as a heading, converts both messages with
    ``message_to_dict`` and compares the resulting dicts one indentation
    level deeper. Returns True if all fields are considered equal.
    """
    assert type(message1) == type(message2)
    printer.print(f".{name}", indentation)
    fields1 = message_to_dict(message1)
    fields2 = message_to_dict(message2)
    # NOTE(review): the dict overload prefixes another ".", so its warnings
    # for this message read "..name[]: ..."; confirm whether that is intended.
    child_name = f".{name}"
    return compare(fields1, fields2, printer, child_name, indentation + 1)


@compare.register(bytes)
@compare.register(float)
@compare.register(int)
@compare.register(str)
def _(
    a: Union[bytes, float, int, str],
    b: Union[bytes, float, int, str],
    printer: Printer,
    name: str,
    indentation: int = 0,
) -> bool:
    """Compare two scalar field values and print the outcome.

    Equal values are printed as ``.name=value``; ``bytes`` values (e.g. raw
    tensor data) are compared but printed as ``<skipped>``. A mismatch is
    reported as an error.

    Two NaN floats are treated as equal: ``nan == nan`` is False in Python,
    which would otherwise flag every NaN in an unchanged model as a difference.

    Returns True if the values are considered equal.
    """
    equal = a == b or (
        isinstance(a, float)
        and isinstance(b, float)
        and math.isnan(a)
        and math.isnan(b)
    )
    if not equal:
        printer.print_error(f".{name} {a} != {b}  (error)", indentation)
    elif isinstance(a, bytes):
        printer.print(f".{name}=<skipped>", indentation)
    else:
        printer.print(f".{name}={a}", indentation)
    return equal


def is_onnx_file(path: Path) -> bool:
    """Return True if *path* is an existing regular file named ``*.onnx``."""
    if not path.is_file():
        return False
    return path.name.endswith(".onnx")


def compare_onnx_files(
    file1: Path, file2: Path, remove_attributes: list[str], printer: Printer
) -> bool:
    """Load two ONNX files, strip the given node attributes and compare them.

    Both models are loaded first, then attributes are removed from each with
    ``remove_attributes_from_model``, and finally the models are compared
    recursively with ``printer`` as output sink. Returns True if equal.
    """
    models = [load(path.absolute()) for path in (file1, file2)]
    for model in models:
        remove_attributes_from_model(model, remove_attributes)

    first, second = models
    return compare(first, second, printer, "model")


def compare_directories(
    directory1: Path, directory2: Path, remove_attributes: list[str], glob_pattern: str
) -> bool:
    """Compare the ONNX files matching *glob_pattern* in two directory trees.

    Files are matched by their path relative to their directory. A file that
    exists in only one tree is reported as an error and makes the result
    unequal. Files present in both trees are compared (in sorted order) with
    ``compare_onnx_files``; the detailed trace is suppressed and only the
    collected warnings/errors, plus a command to reproduce the full output,
    are printed for files that differ.

    Returns True if both trees contain the same files and all of them compare
    equal.
    """
    files1 = {file.relative_to(directory1) for file in directory1.glob(glob_pattern)}
    files2 = {file.relative_to(directory2) for file in directory2.glob(glob_pattern)}

    equal = True
    printer = Printer()
    for file2 in sorted(files2 - files1):
        printer.print_error(f"{directory1.joinpath(file2)} not found")
        equal = False
    for file1 in sorted(files1 - files2):
        printer.print_error(f"{directory2.joinpath(file1)} not found")
        equal = False

    # The repro command must ignore the same attributes as this run, otherwise
    # it would show differences this comparison deliberately skipped.
    repro_suffix = (
        " --remove_attributes " + " ".join(remove_attributes)
        if remove_attributes
        else ""
    )

    for file in sorted(files1 & files2):
        file1 = directory1.joinpath(file)
        file2 = directory2.joinpath(file)

        print(file)
        file_printer = NoDefaultPrinter()
        if not compare_onnx_files(file1, file2, remove_attributes, file_printer):
            file_printer.print_warnings_errors()
            print("for complete output run:")
            print(f"python {__file__} --paths {file1} {file2}{repro_suffix}")
            equal = False

    return equal


def remove_attributes_from_model(model: ModelProto, remove_attributes: list[str]):
    """Delete node attributes with the given names from *model*, in place.

    Nodes of the main graph and of all nested subgraphs (graph-valued
    attributes ``g`` / ``graphs``) are processed. ``compare`` descends into
    those subgraphs as well, so stripping only the top-level nodes would still
    report the ignored attributes there.

    Args:
        model: Model to modify.
        remove_attributes: Attribute names to drop; an empty list is a no-op.
    """
    names = set(remove_attributes)
    if names:
        _remove_attributes_from_graph(model.graph, names)


def _remove_attributes_from_graph(graph: Any, names: set[str]) -> None:
    """Recursively drop attributes named in *names* from the nodes of *graph*."""
    for node in graph.node:
        # Collect first so the repeated field is not mutated while iterating it.
        ignored_attr = [attr for attr in node.attribute if attr.name in names]
        for attr in ignored_attr:
            node.attribute.remove(attr)
        for attr in node.attribute:
            if attr.HasField("g"):
                _remove_attributes_from_graph(attr.g, names)
            for subgraph in attr.graphs:
                _remove_attributes_from_graph(subgraph, names)


def parse():
    """Parse the command line of this script.

    Returns an ``argparse.Namespace`` with ``paths`` (two ``Path`` objects,
    required), ``remove_attributes`` (attribute names, default empty list)
    and ``glob_pattern`` (default ``"**/*_fused.onnx"``).
    """
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument
    add("--paths", nargs=2, type=Path, required=True)
    add("--remove_attributes", nargs="*", default=[])
    add("--glob_pattern", default="**/*_fused.onnx")
    namespace = arg_parser.parse_args()
    return namespace


def main() -> int:
    """Run the comparison selected by the command line arguments.

    Returns the process exit code: 0 if the inputs compare equal, 1 if they
    differ, and 2 if the two paths are neither both ONNX files nor both
    directories.
    """
    args = parse()
    path1, path2 = args.paths

    if is_onnx_file(path1) and is_onnx_file(path2):
        printer = Printer()
        equal = compare_onnx_files(path1, path2, args.remove_attributes, printer)
        printer.print_warnings_errors()
        return 1 - equal

    if not (path1.is_dir() and path2.is_dir()):
        print("inputs are neither both onnx files nor both directories")
        return 2

    equal = compare_directories(
        path1, path2, args.remove_attributes, args.glob_pattern
    )
    return 1 - equal


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
