# fmt: on
import argparse
import json
import sys
from dataclasses import dataclass, fields
from typing import Any, Optional, Union

from dataclass_wizard import JSONWizard


@dataclass
class InputOutputDict(JSONWizard):
    """Descriptor of a single input or output of a pattern.

    Instances are built with ``InputOutputDict.from_dict`` (provided by
    JSONWizard) from each dict of a pattern's ``inputs``/``outputs`` list;
    see ``InputOutputList.from_string``. The number of fields declared here
    must match the key count validated in ``InputOutputList.from_string``.
    """

    # NOTE(review): several fields default to None although their annotations
    # are not Optional. The annotations are deliberately left untouched because
    # JSONWizard presumably relies on them while parsing — confirm before
    # changing them.
    param_name: str = ""
    type: str = ""
    shape: list[int | str] = None  # shape might be ['NA']
    dtype: str = ""
    dtype_bytes: int = None
    hw_dtype: str = None
    hw_dtype_bytes: int = None


@dataclass
class InputOutputList:
    """Ordered list of input/output descriptors of one pattern."""

    entries: list[InputOutputDict]

    @staticmethod
    def from_string(s: str) -> "InputOutputList":
        """Parse the ``inputs``/``outputs`` text of a pattern.

        Args:
            s: Text that evaluates to a list of dicts, one dict per
                input/output, each holding exactly the keys declared on
                ``InputOutputDict``.

        Returns:
            An ``InputOutputList`` whose entries keep the order of ``s``.

        Raises:
            ValueError: if an entry does not have exactly as many keys as
                ``InputOutputDict`` has fields.
        """
        # NOTE(review): eval() executes arbitrary code embedded in the pattern
        # file. If these files can ever come from an untrusted source, switch
        # to ast.literal_eval (sufficient as long as the text only contains
        # Python literals such as str/int/list/dict/None).
        parsed = eval(s)

        # Derived from the dataclass so the check cannot drift out of sync
        # with the field definitions (it used to be a hard-coded 7).
        expected = len(fields(InputOutputDict))

        entries: list[InputOutputDict] = []
        for entry in parsed:
            # Validate before parsing, and raise instead of assert so the
            # check is not stripped when Python runs with -O.
            if len(entry) != expected:
                raise ValueError(
                    f"Expected {expected} attributes in InputOutputDict, "
                    f"got {len(entry)}. Update the dataclass definition or this check."
                )
            entries.append(InputOutputDict.from_dict(entry))

        return InputOutputList(entries)


def compare_io_dict(
    d1: InputOutputDict, d2: InputOutputDict
) -> list[tuple[str, Any, Any]]:
    """Field-by-field comparison of two input/output descriptors.

    Returns:
        One ``(field_name, value_in_d1, value_in_d2)`` tuple per field whose
        values differ, in field declaration order; an empty list if all match.
    """
    compared_fields = (
        "param_name",
        "type",
        "shape",
        "dtype",
        "dtype_bytes",
        "hw_dtype",
        "hw_dtype_bytes",
    )
    return [
        (name, getattr(d1, name), getattr(d2, name))
        for name in compared_fields
        if getattr(d1, name) != getattr(d2, name)
    ]


@dataclass
class Pattern(JSONWizard):
    """One operator pattern as stored in a pattern JSON file.

    Instances are created with ``Pattern.from_dict`` (provided by JSONWizard)
    for every value of the file's top-level JSON object. The field
    annotations are kept verbatim because JSONWizard presumably relies on
    them while parsing.
    """

    op_type: str
    # TODO perhaps needs refinement (only comparing it as string is not precise enough)
    # Serialized lists of input/output descriptor dicts; they are parsed with
    # InputOutputList.from_string only when the raw strings differ.
    inputs: str
    outputs: str
    # Shapes may be a list of str, a list of int, or a plain string.
    in_act_shape: Union[list[str], str, list[int]]
    in_wgt_shape: Union[list[str], str, list[int]]
    in_wgt1_shape: Union[list[str], str, list[int]]
    out_act_shape: Union[list[str], str, list[int]]
    in_datatype: str
    wgt_datatype: str
    wgt1_datatype: str
    out_datatype: str
    in_bytes: int
    wgt_bytes: int
    wgt1_bytes: int
    out_bytes: int
    attributes: Union[dict[str, Any], str]
    qdq_symmetry: str
    coeff_shape: Union[list[int], str]
    in_act_residency: str
    out_act_residency: str
    Frequency: int
    # Node names covered by this pattern (indexed by PatternDict).
    nodenames: list[str]

    def sanitize(self):
        """Normalize absent weight descriptors in place.

        For each weight slot (``wgt`` and ``wgt1``), an empty list or None
        shape means the weight is absent. In that case the shape, byte count
        and datatype of that slot are all set to None, so both sides of a
        comparison describe a missing weight the same way. Note that this
        stores None in fields annotated as ``str``/``int``.
        """
        weight_slots = (
            ("in_wgt_shape", "wgt_bytes", "wgt_datatype"),
            ("in_wgt1_shape", "wgt1_bytes", "wgt1_datatype"),
        )
        for shape_attr, bytes_attr, dtype_attr in weight_slots:
            shape = getattr(self, shape_attr)
            if shape is not None and shape != []:
                continue  # the weight is present: leave it alone
            for attr in (shape_attr, bytes_attr, dtype_attr):
                setattr(self, attr, None)


def ignore_keys_in_dict(
    d: Union[dict, str], ignore_keys: Optional[list[str]] = None
) -> Union[dict, str]:
    """Return ``d`` without the keys listed in ``ignore_keys``.

    Non-dict values (e.g. the string "None") and ``ignore_keys=None`` return
    ``d`` itself, untouched. Otherwise a new dict is built that keeps the
    original key order.
    """
    if isinstance(d, dict) and ignore_keys is not None:
        return {key: d[key] for key in d if key not in ignore_keys}
    return d


def compare_dict(
    d1: Union[dict, str],
    d2: Union[dict, str],
    ignore_keys: Optional[list[str]] = None,
    cpp_fe_mode: bool = False,
) -> Optional[tuple[Any, Any]]:
    """Compare two attribute values (a dict, or a string such as "None").

    Args:
        d1: Reference value.
        d2: Value being checked.
        ignore_keys: Keys removed from both dicts before comparing.
        cpp_fe_mode: When True and both values are dicts, keys of ``d2`` that
            are absent from ``d1`` are ignored (the C++ front end may emit
            extra attributes).

    Returns:
        None when the values are considered equal, otherwise the filtered
        ``(d1, d2)`` pair.
    """
    ref = ignore_keys_in_dict(d1, ignore_keys)
    chk = ignore_keys_in_dict(d2, ignore_keys)

    # The string "None" on one side and an empty dict on the other both mean
    # "no attributes".
    if (ref == "None" and chk == {}) or (ref == {} and chk == "None"):
        return None

    if cpp_fe_mode and type(ref) is dict and type(chk) is dict:
        chk = {key: value for key, value in chk.items() if key in ref}

    return None if ref == chk else (ref, chk)


def compare_list(
    l1: Union[list, str], l2: Union[list, str]
) -> Optional[tuple[Any, Any]]:
    """Compare two shape-like values.

    ``[]``, ``"None"`` and ``["NA"]`` all mean "no shape" and are treated as
    equal to each other; anything else is compared with plain equality.

    Returns:
        None when the values are considered equal, otherwise ``(l1, l2)``.
    """
    placeholders = ([], "None", ["NA"])
    both_placeholders = l1 in placeholders and l2 in placeholders
    if both_placeholders or l1 == l2:
        return None
    return (l1, l2)


def _compare_io_strings(
    label: str, text1: str, text2: str
) -> list[tuple[str, Any, Any]]:
    """Diff two serialized input/output lists entry by entry.

    Args:
        label: Prefix for the reported field paths ("inputs" or "outputs").
        text1: Serialized list from the first pattern.
        text2: Serialized list from the second pattern.

    Returns:
        ``(path, value1, value2)`` tuples such as ``("inputs[0].dtype", a, b)``.
        If one list is longer, each surplus entry is reported as
        ``("<label>[idx]", entry, None)`` or ``("<label>[idx]", None, entry)``.
        Identical strings are not parsed and yield an empty list.
    """
    if text1 == text2:
        return []
    entries1 = InputOutputList.from_string(text1).entries
    entries2 = InputOutputList.from_string(text2).entries

    diffs: list[tuple[str, Any, Any]] = []
    for idx, (entry1, entry2) in enumerate(zip(entries1, entries2)):
        for field_name, value1, value2 in compare_io_dict(entry1, entry2):
            diffs.append((f"{label}[{idx}].{field_name}", value1, value2))

    # zip() stops at the shorter list. Without this loop, a different number
    # of inputs/outputs whose common prefix matches would go unreported.
    len1, len2 = len(entries1), len(entries2)
    for idx in range(min(len1, len2), max(len1, len2)):
        entry1 = entries1[idx] if idx < len1 else None
        entry2 = entries2[idx] if idx < len2 else None
        diffs.append((f"{label}[{idx}]", entry1, entry2))
    return diffs


def compare(
    pattern1: Pattern,
    pattern2: Pattern,
    ignore_attributes: list[str],
    cpp_fe_mode: bool,
) -> list[tuple[str, Any, Any]]:
    """Return every field-level difference between two patterns.

    Args:
        pattern1: Reference pattern.
        pattern2: Pattern being checked.
        ignore_attributes: Attribute keys excluded from the ``attributes``
            comparison.
        cpp_fe_mode: Forwarded to ``compare_dict`` (extra attributes present
            only in ``pattern2`` are ignored).

    Returns:
        ``(field, value1, value2)`` tuples in a fixed field order; an empty
        list means the patterns match. ``Frequency`` and ``nodenames`` are not
        compared.
    """
    differences: list[tuple[str, Any, Any]] = []

    def check_equal(*names: str) -> None:
        # Plain equality check for each named field.
        for name in names:
            value1 = getattr(pattern1, name)
            value2 = getattr(pattern2, name)
            if value1 != value2:
                differences.append((name, value1, value2))

    def check_shape(name: str) -> None:
        # Shape-like field where [], "None" and ["NA"] are interchangeable.
        comp = compare_list(getattr(pattern1, name), getattr(pattern2, name))
        if comp is not None:
            differences.append((name, comp[0], comp[1]))

    check_equal("op_type")
    differences.extend(
        _compare_io_strings("inputs", pattern1.inputs, pattern2.inputs)
    )
    differences.extend(
        _compare_io_strings("outputs", pattern1.outputs, pattern2.outputs)
    )
    check_shape("in_act_shape")
    # NOTE(review): these shapes use strict equality, unlike in_act_shape and
    # coeff_shape — confirm this asymmetry is intended.
    check_equal("in_wgt_shape", "in_wgt1_shape", "out_act_shape")
    check_equal("in_datatype", "wgt_datatype", "wgt1_datatype", "out_datatype")
    check_equal("in_bytes", "wgt_bytes", "wgt1_bytes", "out_bytes")

    comp = compare_dict(
        pattern1.attributes,
        pattern2.attributes,
        ignore_attributes,
        cpp_fe_mode,
    )
    if comp is not None:
        differences.append(("attributes", comp[0], comp[1]))

    check_equal("qdq_symmetry")
    check_shape("coeff_shape")
    check_equal("in_act_residency", "out_act_residency")
    return differences


class PatternDict:
    """Patterns keyed by their id in the JSON file, plus a node-name index."""

    def __init__(self, patterns: dict[str, Pattern]):
        self.patterns = patterns
        # Map every node name to the pattern listing it. If a node name
        # occurs in several patterns, the last one in iteration order wins.
        self.node_to_pattern: dict[str, Pattern] = {
            node: pattern
            for pattern in patterns.values()
            for node in pattern.nodenames
        }

    def get_nodes(self, op_type: str) -> list[str]:
        """Return the node names whose pattern has the given op_type."""
        return [
            node
            for node, pattern in self.node_to_pattern.items()
            if pattern.op_type == op_type
        ]

    def get_op_types(self) -> set[str]:
        """Return the set of distinct op_types across all patterns."""
        return {pattern.op_type for pattern in self.patterns.values()}

    def sanitize(self):
        """Call ``Pattern.sanitize`` on every pattern (mutates in place)."""
        for entry in self.patterns.values():
            entry.sanitize()


def read_json(file_name: str) -> PatternDict:
    """Load a pattern JSON file.

    Args:
        file_name: Path to a JSON file whose top-level object maps pattern ids
            to pattern dicts.

    Returns:
        A ``PatternDict`` holding one ``Pattern`` per top-level key.
    """
    # JSON text is UTF-8 (RFC 8259). Don't depend on the platform's locale
    # encoding, which differs e.g. on Windows.
    with open(file_name, encoding="utf-8") as f:
        raw = json.load(f)
    return PatternDict({key: Pattern.from_dict(value) for key, value in raw.items()})


def compare_and_print_dict_diff(dict1, dict2):
    """
    Describe the differences between two dictionaries as multi-line text.

    ``dict1`` is the reference ("ref") side and ``dict2`` the "check" side.
    The returned text has up to three sections, in this order:
    - keys present in dict1 but not in dict2,
    - keys present in dict2 but not in dict1,
    - keys present in both but with different values.
    Keys within each section are sorted, so identical differences always
    render to identical text. Nothing is printed. An empty string is
    returned when there is nothing to report.
    """
    keys1 = set(dict1.keys())
    keys2 = set(dict2.keys())

    # Set iteration order is arbitrary (and varies between runs for str keys
    # because of hash randomization), so sort for a deterministic report.
    # key=str keeps sorting safe even if key types were ever mixed.
    only_in_1 = sorted(keys1 - keys2, key=str)
    only_in_2 = sorted(keys2 - keys1, key=str)
    changed = sorted((k for k in keys1 & keys2 if dict1[k] != dict2[k]), key=str)

    res: list[str] = []
    if only_in_1:
        res.append("Keys present only in ref:")
        res.extend(f"  - '{key}': '{dict1[key]}'" for key in only_in_1)

    if only_in_2:
        res.append("Keys present only in check:")
        res.extend(f"  - '{key}': '{dict2[key]}'" for key in only_in_2)

    if changed:
        res.append("Keys with different values:")
        for key in changed:
            res.append(f"  - Key '{key}':")
            res.append(f"    - dict1 value: '{dict1[key]}'")
            res.append(f"    - dict2 value: '{dict2[key]}'")

    return "\n".join(res)


def get_diff_as_string(node: str, data: Optional[list[tuple[str, Any, Any]]]) -> str:
    """Format a list of differences as text, one entry per difference.

    Args:
        node: Name of the node the differences belong to. It is not part of
            the output; the parameter is kept for the existing call sites.
        data: ``(field_name, ref_value, chk_value)`` tuples, or None.

    Returns:
        Newline-joined descriptions: a key-by-key report when both values are
        dicts, a "present only in ..." line when exactly one side is "" or
        None, and otherwise the two quoted values in padded columns. Returns
        "" for None or an empty list.
    """
    if data is None:
        return ""

    blank = ["", None]  # values that count as "missing" on one side

    def describe(entry: tuple[str, Any, Any]) -> str:
        field_name, ref_value, chk_value = entry[0], entry[1], entry[2]
        if type(ref_value) is dict and type(chk_value) is dict:
            return f"{field_name}:\n" + compare_and_print_dict_diff(ref_value, chk_value)

        label = f"{field_name:<20}"
        ref_missing = ref_value in blank
        chk_missing = chk_value in blank
        if ref_missing and not chk_missing:
            return f"{label} present only in check: '{chk_value}'"
        if chk_missing and not ref_missing:
            return f"{label} present only in ref: '{ref_value}'"
        # Both present (or both missing): show the values side by side.
        return label + f"'{ref_value!s}'".ljust(40) + f"'{chk_value!s}'".ljust(40)

    return "\n".join(describe(entry) for entry in data)


def compare_opt_type(
    op_type: str,
    pd1: PatternDict,
    pd2: PatternDict,
    ignore_attributes: list[str],
    cpp_fe_mode: bool,
) -> bool:
    """Compare all nodes of one op_type between ref (pd1) and check (pd2).

    Prints a report to stdout: a header line, one line per node missing from
    the other side, and each distinct difference text followed by the nodes
    that share it.

    Args:
        op_type: Operator type whose nodes are compared.
        pd1: Reference patterns.
        pd2: Patterns being checked.
        ignore_attributes: Attribute keys excluded from the comparison.
        cpp_fe_mode: When True, ref nodes of "*noop", "Cast" or "Gather" op
            types that are absent from the check side are not counted as
            differences. Also forwarded to ``compare``.

    Returns:
        True when no difference was found for this op_type. Nodes of this
        op_type that appear only on the check side always make it False.
    """
    print(f"op_type: {op_type}")
    ref_nodes = set(pd1.get_nodes(op_type))
    chk_nodes = set(pd2.get_nodes(op_type))
    shared = ref_nodes & chk_nodes
    ref_only = ref_nodes - shared
    chk_only = chk_nodes - shared

    # The C++ front end may move such nodes into subgraphs, so their absence
    # from the check json is tolerated in cpp_fe_mode.
    may_be_missing = cpp_fe_mode and (
        op_type.endswith("noop") or op_type in ["Cast", "Gather"]
    )

    def diff_for(node: str) -> list[tuple[str, Any, Any]]:
        return compare(
            pd1.node_to_pattern[node],
            pd2.node_to_pattern[node],
            ignore_attributes,
            cpp_fe_mode,
        )

    # Nodes with identical difference text are grouped to keep the report short.
    groups: dict[str, list[str]] = {}

    def record(node: str, found: list[tuple[str, Any, Any]]) -> None:
        groups.setdefault(get_diff_as_string(node, found), []).append(node)

    all_equal = True

    for node in sorted(shared):
        found = diff_for(node)
        if found:
            all_equal = False
            record(node, found)

    for node in sorted(ref_only):
        if node in pd2.node_to_pattern:
            # Same node name, but a different op_type on the check side.
            record(node, diff_for(node))
            all_equal = False
        elif not may_be_missing:
            all_equal = False
            print(f"node={node} does not exist in CHECK model")

    if chk_only:
        all_equal = False
    for node in sorted(chk_only):
        if node in pd1.node_to_pattern:
            record(node, diff_for(node))
        else:
            print(f"node={node} does not exist in REF model")

    if all_equal:
        print("-> no differences found")

    for text, members in groups.items():
        if text:
            print("Found diff:", text, "for node.s :", *members, sep="\n")

    return all_equal


def main() -> int:
    """Command-line entry point: compare a reference and a check pattern file.

    Every op_type found in either file is compared with ``compare_opt_type``,
    and a final ``correct/all types: X/Y`` summary line is printed.

    Returns:
        0 if every op_type compared equal, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Compare two pattern JSON files (reference vs. check) per op_type."
    )
    # Both files are mandatory. Without required=True, a missing flag reached
    # open(None) and failed with an unhelpful TypeError.
    parser.add_argument("--ref", required=True, help="Reference pattern JSON file.")
    parser.add_argument(
        "--check",
        required=True,
        help="Pattern JSON file to check against --ref.",
    )
    parser.add_argument(
        "--ignore_attributes",
        nargs="*",
        default=[],
        help="Attribute keys excluded when comparing the 'attributes' field.",
    )
    parser.add_argument(
        "--cpp_fe",
        default=False,
        action="store_true",
        help=(
            "The check file comes from the C++ front end: ignore attributes "
            "present only in check, and tolerate missing *noop/Cast/Gather nodes."
        ),
    )
    args = parser.parse_args()

    ref = read_json(args.ref)
    chk = read_json(args.check)

    # Sanitize the patterns to make sure that equality will be present
    # for some specific cases / differences between the python and c++ F.E
    ref.sanitize()
    chk.sanitize()

    op_types = ref.get_op_types() | chk.get_op_types()
    count_equal = 0
    for op_type in sorted(op_types):
        # compare_opt_type returns a bool, counted as 0/1.
        count_equal += compare_opt_type(
            op_type, ref, chk, args.ignore_attributes, args.cpp_fe
        )

    print(f"correct/all types: {count_equal}/{len(op_types)}")
    return 0 if count_equal == len(op_types) else 1


if __name__ == "__main__":
    # Use main()'s result (0 = all op_types matched, 1 = differences found)
    # as the process exit status.
    sys.exit(main())
