# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.

from dataclasses import dataclass, field

from dataclass_wizard import JSONWizard
import yaml

from OGOAT.src.L1_fusion.py_match.helpers.fusion_configs import FusionConfigs
from OGOAT.src.L1_fusion.py_match.fusion_ops import yaml_to_py
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    MatcherOrCategory,
)


@dataclass
class LeveledPattern(JSONWizard):
    """
    A fusion pattern name tagged with an optimization level.

    In a leveled fusion sequence file each entry is written as a single-entry
    mapping ``{level: name}``.
    """

    level: int
    name: str

    @staticmethod
    def from_dict(data: dict):
        """
        Parse a single-entry ``{level: name}`` mapping into a LeveledPattern.

        Raises:
            ValueError: if ``data`` is not a mapping with exactly one entry, if
                its key is not an integer level, or if its value is not a
                string pattern name.
        """
        # A non-dict entry (e.g. a bare pattern name copied from a non-leveled
        # file) used to pass the length check when it was a 1-char string and
        # then crash with AttributeError on .items().
        if not isinstance(data, dict) or len(data) != 1:
            raise ValueError(
                "Invalid fusion sequence file: each pattern must have one optimization level."
            )
        ((level, name),) = data.items()
        # Levels are compared against the integer opt_level when filtering and
        # names are stripped with str.rstrip, so reject other types here with a
        # clear message instead of an obscure TypeError/AttributeError later.
        if not isinstance(level, int) or not isinstance(name, str):
            raise ValueError(
                f"Invalid fusion sequence file: expected '<int level>: <pattern name>', got {level!r}: {name!r}."
            )
        return LeveledPattern(level=level, name=name)

    @property
    def name_plain(self) -> str:
        """
        Return the name without the star suffix, which is used to ask for
        dumping the graph before and after the pattern for debugging.

        Note that every trailing '*' is removed, not just one.
        """
        return self.name.rstrip("*")


@dataclass
class LeveledConfigs(JSONWizard):
    """
    The ``configs`` section of a leveled fusion sequence file.

    Its fields are copied into a FusionConfigs by filter_by_opt_level. Only
    ``enable_batch_operator`` is level-dependent: just the entries whose level
    is <= the requested optimization level are kept.
    """

    extend_qdq: bool = False
    keep_border_qdq: bool = False
    batch_by_out_tensor: bool = False
    # Batch operators, each tagged with the level at which it is enabled.
    enable_batch_operator: list[LeveledPattern] = field(default_factory=list)
    MMT_configs: dict = field(default_factory=dict)

    @staticmethod
    def from_dict(data: dict):
        """
        Build LeveledConfigs from the parsed ``configs`` mapping.

        Missing keys take the same defaults as the dataclass fields. A missing
        or null section, or a null ``enable_batch_operator`` (a YAML key
        written with no value), is treated as empty.

        Raises:
            ValueError: if an ``enable_batch_operator`` entry is not a valid
                ``{level: name}`` mapping.
        """
        # ``configs:`` with no value loads as None; treat it like {}.
        data = data or {}
        enable_batch_operator = [
            LeveledPattern.from_dict(item)
            # ``or []`` also covers an explicit null value, which .get's
            # default does not.
            for item in data.get("enable_batch_operator") or []
        ]
        configs = LeveledConfigs(
            extend_qdq=data.get("extend_qdq", False),
            keep_border_qdq=data.get("keep_border_qdq", False),
            batch_by_out_tensor=data.get("batch_by_out_tensor", False),
            enable_batch_operator=enable_batch_operator,
            MMT_configs=data.get("MMT_configs", dict()),
        )
        return configs


@dataclass
class LeveledFusionSeq(JSONWizard):
    """
    A fusion sequence whose patterns (and batch operators in ``configs``) are
    tagged with optimization levels. filter_by_opt_level flattens it into a
    FusionSeq for one level.
    """

    configs: LeveledConfigs
    patterns: list[LeveledPattern]

    @staticmethod
    def from_dict(data: dict) -> "LeveledFusionSeq":
        """
        Build and validate a LeveledFusionSeq from a parsed YAML mapping.

        Missing or null ``patterns`` / ``configs`` sections are treated as
        empty.

        Raises:
            ValueError: if an entry is not a valid ``{level: name}`` mapping,
                or a pattern depends on a pattern that is not enabled at its
                own level or a lower one.
            KeyError: if a pattern name (without its '*' suffix) is not a key
                of yaml_to_py.
        """
        patterns = [
            LeveledPattern.from_dict(item) for item in data.get("patterns") or []
        ]
        LeveledFusionSeq.validate_pattern_levels(patterns)

        configs = LeveledConfigs.from_dict(data.get("configs") or {})
        return LeveledFusionSeq(configs, patterns)

    @staticmethod
    def validate_pattern_levels(patterns: list[LeveledPattern]):
        """
        Check that no pattern depends on a pattern that is not enabled at the
        same or a lower optimization level.

        The dependencies of a pattern are the matchers returned by
        get_dep_matchers for its yaml_to_py entry, mapped back to the
        yaml_to_py pattern that owns each matcher. Matchers owned by no pattern
        are ignored. A dependency whose pattern does not appear in
        ``patterns`` at all is also reported, since it is never enabled.

        Raises:
            ValueError: on the first pattern (in list order) with an unmet
                dependency.
            KeyError: if a pattern name (without its '*' suffix) is not a key
                of yaml_to_py.
        """
        # Reverse index: matcher class name -> yaml_to_py pattern owning it.
        matcher_name_to_pattern: dict[str, str] = dict()
        for pattern, matcher_or_category in yaml_to_py.items():
            for matcher in matcher_or_category.get_matchers():
                matcher_name_to_pattern[matcher.get_matcher_class_name()] = pattern

        # Group pattern names by level. A dict (rather than a fixed list of 4
        # slots) accepts any integer level: levels >= 4 used to raise
        # IndexError and negative ones silently aliased other slots.
        names_by_level: dict[int, set[str]] = {}
        for leveled_pattern in patterns:
            names_by_level.setdefault(leveled_pattern.level, set()).add(
                leveled_pattern.name_plain
            )

        # enabled_up_to[lvl] = names of every pattern whose level is <= lvl.
        enabled_up_to: dict[int, set[str]] = {}
        running: set[str] = set()
        for level in sorted(names_by_level):
            running = running | names_by_level[level]
            enabled_up_to[level] = running

        for leveled_pattern in patterns:
            dep_matchers: list[Matcher] = get_dep_matchers(
                yaml_to_py[leveled_pattern.name_plain]
            )
            dep_patterns: set[str] = set(
                matcher_name_to_pattern[matcher.get_matcher_class_name()]
                for matcher in dep_matchers
                if matcher.get_matcher_class_name() in matcher_name_to_pattern
            )

            invalid_patterns: set[str] = (
                dep_patterns - enabled_up_to[leveled_pattern.level]
            )
            if invalid_patterns:
                raise ValueError(
                    f"Invalid fusion sequence file: '{leveled_pattern.name_plain}' with a lower optimization level depends on {invalid_patterns} with a higher optimization level."
                )


def get_dep_matchers(
    matcher_or_category: MatcherOrCategory, is_top_level=True
) -> list[Matcher]:
    """
    Collect, transitively, the matchers that ``matcher_or_category`` depends on.

    For a Matcher, its ``dependencies`` are traversed. The matcher itself is
    included only when it is reached as a dependency (``is_top_level`` is
    False), not when it is the starting point. For anything else (a category),
    each item of ``get_matchers()`` is traversed as a non-top-level entry, so
    Matcher items are included.

    The result is in depth-first pre-order. It may contain duplicates when
    several paths reach the same matcher. There is no cycle detection: a
    cyclic dependency graph recurses until RecursionError.
    """
    result: list[Matcher] = []
    if isinstance(matcher_or_category, Matcher):
        children = matcher_or_category.dependencies
        if not is_top_level:
            result.append(matcher_or_category)
    else:
        children = matcher_or_category.get_matchers()
    # Extend in place: the previous sum(lists, start) re-copied the accumulated
    # list for every child, which is quadratic in the result size.
    for child in children:
        result.extend(get_dep_matchers(child, False))
    return result


@dataclass
class FusionSeq(JSONWizard):
    """
    A flat (non-leveled) fusion sequence: its configs plus the ordered list of
    pattern names.
    """

    configs: FusionConfigs
    patterns: list[str]
    # Optimization level passed by filter_by_opt_level; None by default.
    opt_level: int = None

    @staticmethod
    def from_dict(data: dict) -> "FusionSeq":
        """
        Reject leveled input, then defer to JSONWizard's generic loader.

        Raises:
            ValueError: if ``data`` is a leveled fusion sequence.
        """
        FusionSeq.validate_leveled_patterns(data)
        base = super(FusionSeq, FusionSeq)
        return base.from_dict(data)

    @staticmethod
    def validate_leveled_patterns(data: dict):
        """
        Raise ValueError if any entry of ``configs.enable_batch_operator`` or
        ``patterns`` is a mapping (i.e. a ``{level: name}`` leveled entry).
        """
        configs = data.get("configs", {})
        sections = (
            configs.get("enable_batch_operator", []),
            data.get("patterns", []),
        )
        # Short-circuits on the first mapping found, scanning batch operators
        # before patterns.
        is_leveled = any(
            isinstance(entry, dict) for section in sections for entry in section
        )
        if is_leveled:
            raise ValueError(
                "Invalid fusion sequence file: the --fusion_seq option requires a non-leveled version of the file."
            )

    def save_to_file(self, filename: str):
        """Write this sequence, serialized with to_dict(), as YAML to ``filename``."""
        with open(filename, "w") as stream:
            yaml.dump(self.to_dict(), stream)


def filter_by_opt_level(leveled_seq: LeveledFusionSeq, opt_level: int) -> "FusionSeq":
    """
    Flatten a leveled fusion sequence into a FusionSeq for ``opt_level``.

    Patterns and batch operators whose level is <= ``opt_level`` are kept in
    their original order. Their names are kept verbatim, including any '*'
    suffix. The other leveled config fields are copied over unchanged, and
    ``opt_level`` is recorded on the result.
    """

    def names_up_to_level(entries):
        # Names of the entries enabled at ``opt_level``, order preserved.
        return [entry.name for entry in entries if entry.level <= opt_level]

    src = leveled_seq.configs
    configs = FusionConfigs(
        src.extend_qdq,
        src.keep_border_qdq,
        src.batch_by_out_tensor,
        names_up_to_level(src.enable_batch_operator),
        src.MMT_configs,
    )
    return FusionSeq(configs, names_up_to_level(leveled_seq.patterns), opt_level)
