# fmt: on
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
from typing import Optional

from OGOAT.src.L1_fusion.py_match.checkers import opType, OpType
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    MatcherError,
    Node,
    OutputTensor,
)


@dataclass
class StatEvalState:
    """
    Snapshot of one step of a static evaluation.

    Pairs an output tensor with the static value that tensor is known to
    carry, so that the tensor can be replaced by that value later on.
    """

    # output tensor whose runtime value is replaced by ``value``
    out_tensor: OutputTensor
    # statically known value of ``out_tensor``
    value: np.ndarray


class StatEvalBase(ABC):
    """
    Base class for statically evaluating a single operator.

    The constructor receives the static value of an output tensor and looks at
    the node following that tensor. Subclasses implement ``match`` to decide
    whether they can evaluate that node. If they can, they compute the value
    of the node's output tensor, which is then available via ``result``.
    """

    def __init__(self, orig: StatEvalState) -> None:
        """
        Evaluate the operator following the output tensor if possible statically.
        Store output tensor and value after operator if possible.
        Raise MatcherError if no evaluation possible.
        """
        self._orig = orig

        # the node following orig.out_tensor, i.e. the operator to evaluate
        self._node = self._orig.out_tensor.require_node()
        self._value = self._orig.value

        # set by match() when the evaluation succeeds
        self._new_out_tensor: Optional[OutputTensor] = None
        self._new_value: Optional[np.ndarray] = None

        self.match()

        # guard against a match() implementation that returns without
        # producing a result
        if self._new_out_tensor is None:
            raise MatcherError("missing next_out")
        if self._new_value is None:
            raise MatcherError("missing new_value")

    def get_const_input(self, input_name: str) -> np.ndarray:
        """
        Return the constant value connected to input ``input_name`` of the node.

        The value comes either from an initializer or from a Constant op.
        For a Constant op, the value of its first attribute is used.
        Raises MatcherError if the Constant op has no attributes. If the input
        is neither an initializer nor a Constant op, the ``require*`` helpers
        reject it (presumably with MatcherError -- verify in nodes_tensors).
        """
        in_tensor = self._node.get_connection(input_name)
        # first possibility for constant: initializer
        if in_tensor.check_initializer():
            return in_tensor.require_initializer().get_value_as_array()
        # second possibility for constant: constant op
        const_node = in_tensor.require_node()
        const_node.require(opType.Constant)
        attrs = list(const_node.get_attributes().values())
        if not attrs:
            raise MatcherError(f"const node {const_node} does not have attrs")
        # Constant op has exactly one attr, don't care which one
        return np.array(attrs[0])

    def get_int_attr(self, attr_name: str, default: Optional[int] = None) -> int:
        """
        Return the integer attribute ``attr_name`` of the node.

        If the attribute is missing, return ``default`` when one is given.
        Otherwise raise MatcherError. Also raise MatcherError if the
        attribute value is not an int.
        """
        try:
            value = self._node.get_attribute_value(attr_name)
        except KeyError as err:
            if default is not None:
                return default
            raise MatcherError(f"{self._node}: {attr_name} missing") from err
        if not isinstance(value, int):
            raise MatcherError(f"{self._node}: {attr_name} is not int")
        return value

    @abstractmethod
    def match(self) -> None:
        """
        Check if the class can evaluate the operator.
        Raise MatcherError if it cannot be handled.
        Otherwise set next output tensor and new value.
        """

    def require_enter_via(self, input_name: str) -> None:
        """
        Require that the current node got entered (from _orig.out_tensor)
        via the passed input.
        Raise MatcherError otherwise.
        """
        input_ = self._node.get_connection(input_name)
        if not input_ == self._orig.out_tensor:
            raise MatcherError(
                f"{self._node} has been entered via {self._orig.out_tensor},"
                f" expected {input_name} {input_}"
            )

    @property
    def result(self) -> StatEvalState:
        """
        Return the output tensor of the evaluated node and its static value.

        Both are non-None here because ``__init__`` raises otherwise.
        """
        return StatEvalState(self._new_out_tensor, self._new_value)


class StatEvalBinOp(StatEvalBase, ABC):
    """
    Static evaluation of a binary operator with inputs ``A``, ``B`` and
    output ``C``.

    The node must have been entered via one of its inputs, and the other
    input must be a constant. Subclasses set ``op_type`` and implement
    ``compute``.
    """

    op_type: str = "set_by_subclass"

    @abstractmethod
    def compute(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return the element-wise result of the operator applied to a and b."""

    def match(self) -> None:
        """
        Evaluate the binary node if its other input is constant.

        Operand order is preserved: if the node was entered via ``B``, the
        constant is passed as ``a``. Raise MatcherError if the node does not
        fit or if numpy cannot compute the result.
        """
        self._node.require(OpType(self.op_type))
        try:
            self.require_enter_via("A")
            enter_via_a = True
        except MatcherError:
            self.require_enter_via("B")
            enter_via_a = False
        self._new_out_tensor = self._node("C").require_tensor()
        const = self.get_const_input("B" if enter_via_a else "A")
        args = (self._value, const) if enter_via_a else (const, self._value)
        try:
            self._new_value = self.compute(*args)
        except (ValueError, TypeError, ArithmeticError) as err:
            # ValueError: shapes do not broadcast (or rejected by compute);
            # TypeError: no ufunc loop for these dtypes;
            # ArithmeticError: e.g. FloatingPointError under a raising errstate.
            # Report as MatcherError so callers simply skip this node.
            raise MatcherError(f"{self._node}: could not compute: {err}") from err


class StatEvalBinAdd(StatEvalBinOp):
    """Static evaluation of an Add node: element-wise sum with broadcasting."""

    op_type: str = "Add"

    def compute(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return a + b element-wise."""
        total = a + b
        return total


class StatEvalBinDiv(StatEvalBinOp):
    """
    Static evaluation of a Div node.

    When both operands are integers (signed or unsigned), the quotient is
    truncated toward zero. This matches C semantics and ONNX Div on integer
    tensors. All other operands use true division.
    """

    op_type: str = "Div"

    def compute(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """
        Return a / b element-wise, with broadcasting.

        Raise ValueError on integer division by zero, so that the node is not
        folded instead of silently producing 0.
        """
        a = np.asarray(a)
        b = np.asarray(b)
        if a.dtype.kind in "iu" and b.dtype.kind in "iu":
            if np.any(b == 0):
                raise ValueError("integer division by zero")
            quotient = np.floor_divide(a, b)
            # floor and truncation differ only when the division is inexact
            # and the operands have opposite signs; then truncation is one
            # larger than the floor.
            needs_fix = (np.remainder(a, b) != 0) & ((a < 0) != (b < 0))
            return quotient + needs_fix.astype(quotient.dtype)
        return np.divide(a, b)


class StatEvalBinMul(StatEvalBinOp):
    """Static evaluation of a Mul node: element-wise product with broadcasting."""

    op_type: str = "Mul"

    def compute(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return a * b element-wise."""
        product = a * b
        return product


class StatEvalBinSub(StatEvalBinOp):
    """Static evaluation of a Sub node: element-wise difference with broadcasting."""

    op_type: str = "Sub"

    def compute(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return a - b element-wise (operand order matters)."""
        difference = a - b
        return difference


class StatEvalGather(StatEvalBase):
    """
    Static evaluation of a Gather node entered via its ``data`` input, with
    constant ``indices``.

    Uses Gather (not GatherElements) semantics. For ``data`` of rank r and
    ``indices`` of rank q, the result has rank q + (r - 1) and equals
    ``np.take(data, indices, axis=axis)``. In particular, a scalar index
    removes the gathered axis. Negative indices count from the end.
    """

    def match(self) -> None:
        """
        Require a Gather node entered via ``data``, with an int ``axis``
        attribute and constant ``indices``, and compute the gathered value.

        Raise MatcherError if the node does not fit or numpy rejects the
        gather (index out of range, non-integer indices, invalid axis).
        """
        self._node.require(opType.Gather)
        self.require_enter_via("data")
        self._new_out_tensor = self._node("output").require_tensor()
        axis = self.get_int_attr("axis")
        indices = np.asarray(self.get_const_input("indices"))
        try:
            gathered = np.take(self._value, indices, axis=axis)
        except (ValueError, IndexError, TypeError) as err:
            # IndexError: index out of range; TypeError: indices not castable
            # to integers; ValueError/AxisError: axis invalid for data's rank
            raise MatcherError(f"{self._node}: could not gather: {err}") from err
        # np.take may return a numpy scalar for 0-d results; keep an ndarray
        self._new_value = np.asarray(gathered)


class StaticEvaluator:
    """
    Fold the chain of nodes downstream of an output tensor as far as possible.

    Starting from ``out_tensor``, whose value is known to be ``value``, each
    following operator is evaluated statically as long as one of the supported
    evaluators (Add, Div, Mul, Sub, Gather) accepts it. The final result is the
    output tensor of the last folded node and its value. If nothing could be
    folded, these are exactly the tensor and value passed in.
    """

    def __init__(self, out_tensor: OutputTensor, value: np.ndarray) -> None:
        evaluators = (
            StatEvalBinAdd,
            StatEvalBinDiv,
            StatEvalBinMul,
            StatEvalBinSub,
            StatEvalGather,
        )
        self._remove_nodes: list[Node] = []
        current = StatEvalState(out_tensor, np.array(value))
        advanced = True
        # keep sweeping over all evaluators until a full sweep folds nothing
        while advanced:
            advanced = False
            for evaluator_cls in evaluators:
                try:
                    step = evaluator_cls(current)
                except MatcherError:
                    pass
                else:
                    current = step.result
                    # origin of the new output tensor is the node just folded
                    self._remove_nodes.append(current.out_tensor.get_origin())
                    advanced = True
        self._result = current

    @property
    def out_tensor(self) -> OutputTensor:
        """
        Output tensor of the last statically evaluated node (or the input
        tensor if no node could be evaluated).
        """
        return self._result.out_tensor

    @property
    def remove_nodes(self) -> list[Node]:
        """
        Nodes that were evaluated statically and can be removed, in the
        order they were folded.
        """
        return self._remove_nodes

    @property
    def value(self) -> np.ndarray:
        """
        Static value carried by ``self.out_tensor``.
        """
        return self._result.value
