import logging
from typing import Any

from typing import Optional


from OGOAT.src.L1_fusion.py_match.checkers import (
    AttrValue,
    CategoryCheck,
    FusedWithQDQNode,
    opType,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    MatcherError,
    Tensor,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.basic.binary_op import BinaryOp
from OGOAT.src.L1_fusion.py_match.basic.concat import Concat
from OGOAT.src.L1_fusion.py_match.basic.categories import linear_category
from OGOAT.src.L1_fusion.py_match.basic.unary import Neg
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.helpers.elementwise_helper import (
    can_multidirectional_broadcasting,
)
from OGOAT.src.L1_fusion.py_match.helpers.perm_helper import (
    PermutationHelper,
    compute_permutation,
)


class RoPE(Matcher, QDQHelper, PermutationHelper):
    """Fuse a QDQ-fused ``Add(Mul, Mul)`` RoPE subgraph into a single node.

    The shape checked by :meth:`match` is::

        C = Add(Mul(x, cos), Mul(Concat(Neg(Slice(x)), Slice(x)), sin))

    The ``B`` operand of the A-branch ``Mul`` is treated as ``cos`` and the
    ``B`` operand of the B-branch ``Mul`` as ``sin``.  Each branch may contain
    at most one ``Transpose`` between the shared input ``x`` and the ``Add``,
    and either both branches contain one or neither does.

    :meth:`modify` replaces the ``Add`` with a
    ``RoPE_[actxact_]qdq_<in dtype>x<out dtype>`` node in the
    ``ai.onnx.contrib`` domain.
    """

    dependencies = [Concat(), BinaryOp(), Neg()]

    def match(self) -> None:
        """Validate the subgraph rooted at the ``Add`` node ``self.n``.

        On success the following attributes are set for :meth:`modify`:
        ``mul_A``/``cos``, ``mul_B``/``sin``, ``sin_const``, ``cos_const``,
        ``input``, ``input_scale``, ``input_zero_point``, ``output``,
        ``output_scale``, ``output_zero_point``, ``sin_input_transpose``,
        ``cos_input_transpose`` and ``output_transpose``.

        Raises:
            MatcherError: if any structural, quantization-parameter or shape
                requirement of the pattern is not met.
        """
        n = self.n = self.n.with_walk_cfg(WalkCfgPlain())

        # Root of the pattern: a QDQ-fused binary op that originally was an Add.
        n.require(CategoryCheck(BinaryOp()))
        n.require(FusedWithQDQNode())
        n.require(AttrValue("orig_type", "Add"))

        # A input: optionally a Transpose between the Add and the Mul.
        had_transpose_A = False
        if n("A").check(opType.Transpose):
            self.mul_A = n("A.data").require_node()
            had_transpose_A = True
        else:
            self.mul_A = n("A").require_node()

        self.mul_A.require(CategoryCheck(BinaryOp()))
        self.mul_A.require(FusedWithQDQNode())
        self.mul_A.require(AttrValue("orig_type", "Mul"))

        # TODO: We might be able to remove more logic for finding transposes
        # as they should be fused away at this point now

        # check if self.mul_A is the only reader of that input
        if self.mul_A("A").get_reader() == self.mul_A:
            # The input is private to this branch, so it can only be connected
            # to the shared input through a Transpose in front of the Mul.
            if self.mul_A("A").check(opType.Transpose):
                last_input_A = self.mul_A("A.data")
                last_input_A_scale = self.mul_A("A_scale")
                last_input_A_zero_point = self.mul_A("A_zero_point")
            else:
                raise MatcherError("no transpose found")
            if had_transpose_A:
                raise MatcherError("two transpose in branch A")
            had_transpose_A = True
        else:
            # The input has other readers: it is taken as the shared input.
            last_input_A = self.mul_A("A")
            last_input_A_scale = self.mul_A("A_scale")
            last_input_A_zero_point = self.mul_A("A_zero_point")

        # Quantization parameters must be consistent across the Mul -> Add edge.
        self.require_tensor_equal_value(n("A_scale"), self.mul_A("C_scale"))
        self.require_tensor_equal_value(n("A_zero_point"), self.mul_A("C_zero_point"))

        # B input
        n("B").require(CategoryCheck(BinaryOp()))
        self.mul_B = n("B").require(FusedWithQDQNode()).require_node()

        self.mul_B.require(AttrValue("orig_type", "Mul"))

        self.require_tensor_equal_value(n("B_scale"), self.mul_B("C_scale"))
        self.require_tensor_equal_value(n("B_zero_point"), self.mul_B("C_zero_point"))

        # Optionally a Transpose between the Concat and the B-branch Mul.
        had_transpose_B = False
        if self.mul_B("A").check(opType.Transpose):
            concat = self.mul_B("A.data").require_node()
            had_transpose_B = True
        else:
            concat = self.mul_B("A").require_node()

        self.require_tensor_equal_value(
            concat("concat_result_scale"), self.mul_B("A_scale")
        )
        self.require_tensor_equal_value(
            concat("concat_result_zero_point"), self.mul_B("A_zero_point")
        )

        # The Concat must be QDQ-fused, take exactly two inputs, and its first
        # input must come from a Neg.
        concat.require(CategoryCheck(Concat())).require_node()
        concat.require(FusedWithQDQNode())
        concat.require(AttrValue("num_inputs", 2))
        neg = concat("input_0").require(CategoryCheck(Neg())).require_node()

        # Both Concat inputs must ultimately be fed by the same node.
        if (
            concat("input_1.data").require_node().get_name()
            != neg("X.data").require_node().get_name()
        ):
            raise MatcherError("does not lead to the same node")

        self.require_tensor_equal_value(neg("Y_scale"), concat("input_0_scale"))
        self.require_tensor_equal_value(
            neg("Y_zero_point"), concat("input_0_zero_point")
        )

        # Both the Neg and the second Concat input must be produced by
        # QDQ-fused Slice nodes.
        slice_concat = concat("input_1").require_node()
        if not slice_concat.get_op_type().startswith("Slice_qdq_"):
            raise MatcherError("expected Slice conncected to Concat")
        slice_neg = neg("X").require_node()
        if not slice_neg.get_op_type().startswith("Slice_qdq_"):
            raise MatcherError("expected Slice connected to Neg")

        self.require_tensor_equal_value(neg("X_scale"), concat("input_1_scale"))
        self.require_tensor_equal_value(
            neg("X_zero_point"), concat("input_1_zero_point")
        )

        if slice_concat("data").require_node() != slice_neg("data").require_node():
            raise MatcherError("does not lead to the same node")

        data_concat = slice_concat("data").require_tensor()

        # Check that the only readers of this node are slice_concat and slice_neg.
        # The order in which the readers are reported carries no meaning for
        # the pattern, so both orders are accepted.
        readers = data_concat.get_readers()
        only_slices_read_data = (
            readers == [slice_concat, slice_neg]
            or readers == [slice_neg, slice_concat]
        )
        if only_slices_read_data:
            # The sliced tensor is private to this branch, so it can only be
            # connected to the shared input through a Transpose.
            if slice_concat("data").check(opType.Transpose):
                last_input_B = slice_concat("data.data")
            else:
                raise MatcherError("no transpose found")

            if had_transpose_B:
                raise MatcherError("two transpose in branch B")
            had_transpose_B = True
        else:
            last_input_B = slice_concat("data")

        # Both branches must start from the very same input ...
        if last_input_A != last_input_B:
            raise MatcherError("Inputs to both branches are not the same")

        if last_input_A.get_shape() != last_input_B.get_shape():
            raise MatcherError("Input shapes to both branches are not the same")

        # ... and must agree on whether a Transpose is involved.
        if had_transpose_A != had_transpose_B:
            raise MatcherError("not both branches have a transpose")

        self.require_tensor_equal_value(self.mul_A("A_scale"), neg("X_scale"))
        self.require_tensor_equal_value(self.mul_A("A_zero_point"), neg("X_zero_point"))

        # The B-branch Mul carries sin, the A-branch Mul carries cos.
        self.sin = self.mul_B
        self.cos = self.mul_A
        self.sin_const = self.sin("B").check_initializer()
        self.cos_const = self.cos("B").check_initializer()

        self.output = n("C")
        self.output_scale = n("C_scale")
        self.output_zero_point = n("C_zero_point")

        self.input = last_input_A
        self.input_scale = last_input_A_scale
        self.input_zero_point = last_input_A_zero_point

        sin_multicast = can_multidirectional_broadcasting(
            self.input.get_shape(), self.sin("B").get_shape()
        )
        cos_multicast = can_multidirectional_broadcasting(
            self.input.get_shape(), self.cos("B").get_shape()
        )
        # Permutations stay None when the operand already broadcasts against
        # the input; otherwise a permutation must exist for the match to hold.
        self.sin_input_transpose: Optional[list[int]] = None
        self.cos_input_transpose: Optional[list[int]] = None
        if not sin_multicast:
            self.sin_input_transpose = compute_permutation(
                self.input.get_shape(), self.sin("B").get_shape()
            )
            if self.sin_input_transpose is None:
                raise MatcherError("cannot compute sin permutation")
        if not cos_multicast:
            self.cos_input_transpose = compute_permutation(
                self.input.get_shape(), self.cos("B").get_shape()
            )
            if self.cos_input_transpose is None:
                raise MatcherError("cannot compute cos permutation")

        # If the Add output has a different shape than the input, it must be a
        # permutation of it; modify() then re-creates that Transpose.
        self.output_transpose: Optional[list[int]] = None
        if self.input.get_shape() != self.output.get_shape():
            self.output_transpose = compute_permutation(
                self.output.get_shape(), self.input.get_shape()
            )
            if self.output_transpose is None:
                raise MatcherError("input and output shapes are not the same")

    def _rope_operand(self, branch: Any, perm: Optional[list[int]]) -> Any:
        """Return the ``B`` operand of ``branch``, transposed by ``perm`` if set.

        When ``perm`` is not ``None`` and the operand is an initializer, a
        transposed copy named ``<operand name>_transposed`` is added and
        returned.  When the operand is not an initializer an error is logged
        and the operand is returned unchanged.
        """
        operand = branch("B")
        if perm is None:
            return operand
        if operand.check_initializer():
            return self.add_transposed_initializer(
                operand.require_initializer(),
                operand.get_name() + "_transposed",
                perm,
            )
        # NOTE(review): the fused node is still created with the untransposed
        # operand in this case; consider rejecting the pattern in match().
        logging.error("transpose of non initialzer in RoPE not implemented yet")
        return operand

    def modify(self) -> None:
        """Replace the matched ``Add`` node with a single fused RoPE node.

        Uses the attributes stored by :meth:`match`.  If the output shape is a
        permutation of the input shape, a ``Transpose`` node is added that
        produces the original output tensor from the fused node's result.
        """
        n = self.n

        # "actxact" marks the variant in which sin/cos are not both initializers.
        sin_cos_const = bool(self.sin_const and self.cos_const)
        type_prefix = "RoPE_" if sin_cos_const else "RoPE_actxact_"
        attributes: dict[str, Any] = {
            "sin_cos_const": sin_cos_const,
            "num_of_tensor_inputs": 3,
        }

        new_type = (
            type_prefix
            + "qdq_"
            + self.input_zero_point.get_dtype()
            + "x"
            + self.output_zero_point.get_dtype()
        )

        # transpose sin and cos if needed
        sin_input = self._rope_operand(self.sin, self.sin_input_transpose)
        cos_input = self._rope_operand(self.cos, self.cos_input_transpose)

        # transpose output if needed
        if self.output_transpose is not None:
            # The fused node writes to a new tensor with the input's shape; the
            # Transpose then produces the original output tensor from it.
            input_tensor = Tensor(
                n._model_dict, n._walk_cfg, n.get_name() + "_transposed_output"
            )
            input_tensor.set_shape(self.input.get_shape(), self.output.get_dtype())
            self.add_node(
                type="Transpose",
                domain="",
                inputs={"data": input_tensor},
                outputs={"transposed": self.output},
                attributes={"perm": self.output_transpose},
            )
            self.output = input_tensor

        inputs = {
            "data": self.input,
            "sin": sin_input,
            "cos": cos_input,
            "data_scale": self.input_scale,
            "data_zero_point": self.input_zero_point,
            "sin_scale": self.sin("B_scale"),
            "sin_zero_point": self.sin("B_zero_point"),
            "cos_scale": self.cos("B_scale"),
            "cos_zero_point": self.cos("B_zero_point"),
            "C_scale": self.output_scale,
            "C_zero_point": self.output_zero_point,
        }
        outputs = {"C": self.output}
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )


class LinearPlusRoPE(Matcher):
    """Fuse a biased QDQ MatMul followed by a fused RoPE node into one node.

    The resulting node has type
    ``MatMul_qdq_bias_RoPE_<A dtype>x<B dtype>x<RoPE C dtype>`` in the
    ``ai.onnx.contrib`` domain.  It takes the linear node's inputs prefixed
    with ``linear_`` and the RoPE node's inputs (except ``data``) prefixed
    with ``rope_``, and produces the RoPE output as ``rope_C``.
    """

    dependencies = [linear_category, RoPE()]

    def match(self):
        """Require a ``MatMul_qdq_bias*`` node whose ``Y`` leads to a RoPE node.

        Raises:
            MatcherError: if the node is not in the linear category, is not a
                biased MatMul, or its ``Y`` output does not lead to a RoPE node.
        """
        linear_candidate = self.n.require(CategoryCheck(linear_category))
        self.linear = linear_candidate.require_node()

        linear_op_type = self.linear.get_op_type()
        if not linear_op_type.startswith("MatMul_qdq_bias"):
            raise MatcherError("only Matmul with bias is supported")

        # NOTE(review): neither the number of readers of ``Y`` nor which RoPE
        # input ``Y`` feeds is checked here - confirm this is guaranteed.
        rope_candidate = self.linear("Y").require(CategoryCheck(RoPE()))
        self.rope = rope_candidate.require_node()

    def modify(self):
        """Remove the linear node and add the combined MatMul+RoPE node."""
        dtype_parts = [
            self.linear("A_zero_point").get_dtype(),
            self.linear("B_zero_point").get_dtype(),
            self.rope("C_zero_point").get_dtype(),
        ]
        fused_type = "MatMul_qdq_bias_RoPE_" + "x".join(dtype_parts)

        # MatMul inputs, forwarded as ``linear_<port>``
        linear_ports = (
            "A",
            "B",
            "Bias",
            "A_scale",
            "A_zero_point",
            "B_scale",
            "B_zero_point",
            "Bias_scale",
            "Bias_zero_point",
            "Y_scale",
            "Y_zero_point",
        )
        # RoPE inputs, forwarded as ``rope_<port>``; ``data`` is left out
        rope_ports = (
            "sin",
            "cos",
            "data_scale",
            "data_zero_point",
            "sin_scale",
            "sin_zero_point",
            "cos_scale",
            "cos_zero_point",
            "C_scale",
            "C_zero_point",
        )

        fused_inputs = {f"linear_{port}": self.linear(port) for port in linear_ports}
        fused_inputs.update({f"rope_{port}": self.rope(port) for port in rope_ports})

        # Rope output
        fused_outputs = {"rope_C": self.rope("C")}

        linear_attrs = self.linear.get_attributes()
        rope_attrs = self.rope.get_attributes()

        # NOTE(review): only the linear node is removed explicitly although the
        # new node takes over the RoPE output - confirm the RoPE node is
        # cleaned up elsewhere.
        self.remove_node(self.linear)
        self.add_node(
            type=fused_type,
            domain="ai.onnx.contrib",
            inputs=fused_inputs,
            outputs=fused_outputs,
            # on key clashes the RoPE attributes win
            attributes=linear_attrs | rope_attrs,
        )
