# fmt: on
import math
import numpy as np
from OGOAT.src.L1_fusion.py_match.basic.conv_transpose_to_conv import (
    ConvTransposeToConv,
)
from OGOAT.src.L1_fusion.py_match.checkers import DTypeAny, DTypes, opType
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BiasHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Category,
    Matcher,
    NoMatch,
    Tensor,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import (
    TransposeHelper,
)


class Conv(Matcher, BiasHelper, TransposeHelper):
    """
    Fuse a quantized (QDQ) ``Conv`` into a single ``Conv_qdq_*`` contrib op.

    Variants:
        ConvWBias   -- Conv with an explicit bias input ``B`` (either an
                       initializer or the output of a DequantizeLinear)
        ConvWoBias  -- Conv without ``B``; ``get_bias_elem_in_and_out_tensor``
                       may still provide a bias element (``has_add_bias``)
        FIXME: ConvWoBias_T (currently disabled)

    When ``kernel_shape == strides`` on a 2-D conv (other than 1x1) the conv
    is lowered to a 1x1 / stride-1 conv ("ConvToMatMul"): the activation is
    rearranged by a Reshape -> Transpose -> Reshape chain and the weight
    initializer is reshaped to match.
    """

    def match(self) -> None:
        """
        Check that ``self.n`` is a quantized Conv this fusion supports.

        Records on ``self``: ``input``/``output`` (possibly moved across NCHW
        transposes by ``require_nchw_conversion``), bias info (``has_bias``,
        ``is_bias_init``, ``has_add_bias``, ``bias_elem``), ``stride_vec``,
        ``kernel_vec`` (both as plain lists) and ``eq_kernel_strides``.

        Raises:
            NoMatch: if the kernel/stride combination is not supported. The
                ``require*`` calls may also reject the match.
        """
        n = self.n
        n.require(opType.Conv)
        self.has_bias = n("B").check_tensor()
        self.is_bias_init = n("B").check_initializer()
        # Only the no-bias branch below computes this; default it so modify()
        # never depends on short-circuit evaluation to avoid an AttributeError.
        self.has_add_bias = False

        self.input = n("X")
        self.output = n("Y")
        # Check for NCHW transpose conversion around the input and output nodes
        self.input, self.output = self.require_nchw_conversion(
            self.input, self.output
        )

        if self.has_bias:
            if not self.is_bias_init:
                # A non-constant bias must itself be dequantized.
                n("B").require(opType.DequantizeLinear)
                n("B.x_zero_point").require(DTypeAny())
            self.bias_elem = n("B")
        else:
            n("B").require_nowhere()
            self.has_add_bias, self.bias_elem, self.output = (
                self.get_bias_elem_in_and_out_tensor(self.output)
            )

        n("W").require(opType.DequantizeLinear)

        n("W.x_zero_point").require(DTypeAny())
        self.input("x_zero_point").require(
            DTypes("int4", "int8", "int16", "uint4", "uint8", "uint16")
        )
        n("Y.y_zero_point").require(
            DTypes("int4", "int8", "int16", "uint4", "uint8", "uint16")
        )

        stride_vec = n.get_attribute_value("strides")
        if stride_vec is None:
            # ONNX default: stride 1 along each spatial axis (the first two
            # input dimensions are batch and channel, not spatial).
            stride_vec = [1] * (len(n("X").require_tensor().get_shape()) - 2)
        # Normalize to plain lists so the list comparisons below (and in
        # modify) do not depend on the container type the accessors return.
        self.stride_vec = list(stride_vec)

        kernel_vec = n.get_attribute_value("kernel_shape")
        if kernel_vec is None:
            # Per the ONNX Conv spec, W is (M x C/group x k1 x k2 ...).
            kernel_vec = n("W").require_tensor().get_shape()[2:]
        self.kernel_vec = list(kernel_vec)

        # Kernel/stride combinations accepted by the generic Conv path.
        supported_kernels = [[1, 1], [2, 2], [3, 3], [5, 5], [7, 7], [9, 9]]
        supported_strides = [[1, 1], [2, 2], [4, 4]]
        # kernel == stride (non-overlapping windows) is accepted regardless of
        # the lists above; modify() lowers the 2-D case to a 1x1 conv
        # ("ConvToMatMul").
        self.eq_kernel_strides = self.kernel_vec == self.stride_vec
        if (
            self.kernel_vec not in supported_kernels
            or self.stride_vec not in supported_strides
        ) and not self.eq_kernel_strides:
            raise NoMatch(
                f"Kernel {self.kernel_vec} or Stride {self.stride_vec} not supported"
            )

    def modify(self) -> None:
        """
        Replace the matched Conv with one ``Conv_qdq_[bias_]<A>x<W>x<Y>`` node.

        ``<A>``, ``<W>`` and ``<Y>`` are the dtypes of the activation, weight
        and output zero points. The original Conv node is removed. In the
        ConvToMatMul case, Reshape/Transpose nodes are inserted to feed the
        new node's ``A`` input.
        """
        n = self.n
        has_any_bias = self.has_bias or self.has_add_bias

        dtypes = (
            self.input("x_zero_point").get_dtype()
            + "x"
            + n("W.x_zero_point").get_dtype()
        )
        new_type = "Conv_qdq_bias_" if has_any_bias else "Conv_qdq_"
        new_type += dtypes + "x" + n("Y.y_zero_point").get_dtype()

        inputs = {
            "A": self.input("x"),
            "B": n("W.x"),
            "Bias": None,
            "A_scale": self.input("x_scale"),
            "A_zero_point": self.input("x_zero_point"),
            "B_scale": n("W.x_scale"),
            "B_zero_point": n("W.x_zero_point"),
            "Bias_scale": None,
            "Bias_zero_point": None,
            "Y_scale": n("Y.y_scale"),
            "Y_zero_point": n("Y.y_zero_point"),
        }

        attributes = n.get_attributes()
        attributes["num_of_tensor_inputs"] = 2

        # kernel == stride on a 2-D conv, excluding the trivial 1x1 case, is
        # lowered to a 1x1 conv over a space-to-depth rearranged activation.
        # (Given kernel == stride, "kernel is [1, 1]" implies stride is too.)
        is_conv_to_matmul = (
            self.eq_kernel_strides
            and len(self.kernel_vec) == 2
            and self.kernel_vec != [1, 1]
        )

        if n("W.x").check_initializer():
            weight = self.add_transposed_initializer(
                n("W.x").require_initializer(), n("W.x").get_name() + "_trans"
            )
            if is_conv_to_matmul:
                weight = self._reshape_weight_for_matmul(weight)
            inputs["B"] = weight
        # NOTE(review): when W.x is not an initializer the weight is passed
        # through as-is, even for ConvToMatMul; confirm this cannot occur.

        if has_any_bias:
            attributes |= {"trans_to_nchw": 0, "nchw_act": 1, "bias": 1}
            attributes["num_of_tensor_inputs"] = 3
            if self.is_bias_init:
                inputs["Bias"] = self.bias_elem
            else:
                inputs["Bias"] = self.bias_elem("x")
                inputs["Bias_scale"] = self.bias_elem("x_scale")
                inputs["Bias_zero_point"] = self.bias_elem("x_zero_point")
            inputs["Y_scale"] = self.output("y_scale")
            inputs["Y_zero_point"] = self.output("y_zero_point")

        outputs = {"Y": self.output("y")}

        pad_vec = n.get_attribute_value("pads")
        if pad_vec is None:
            # ONNX default: no padding at the begin and end of each *spatial*
            # axis, i.e. 2 * (rank - 2) zeros. Batch/channel axes have no pads;
            # counting them would send Conv1d into the Conv2d branch below.
            spatial_rank = len(n("X").require_tensor().get_shape()) - 2
            pad_vec = [0] * (2 * spatial_rank)
        self.remove_node(n)

        if len(pad_vec) == 2:
            # Conv1d
            attributes |= {
                "pads_h_beg": pad_vec[0],
                "pads_h_end": pad_vec[1],
                "strides_h": self.stride_vec[0],
            }
        else:
            # Conv2d: ONNX pads order is [h_begin, w_begin, h_end, w_end]
            attributes |= {
                "pads_h_beg": pad_vec[0],
                "pads_w_beg": pad_vec[1],
                "pads_h_end": pad_vec[2],
                "pads_w_end": pad_vec[3],
                "strides_h": self.stride_vec[0],
                "strides_w": self.stride_vec[1],
            }

        attributes["trans_to_nchw"] = 0

        if is_conv_to_matmul:
            inputs["A"] = self._add_space_to_depth(n.get_name())
            # The lowered op is a plain 1x1 / stride-1 conv.
            attributes["kernel_shape"] = [1, 1]
            attributes["strides"] = [1, 1]
            # NOTE(review): strides_h/strides_w and pads_* above still hold the
            # original conv's values; confirm the Conv_qdq kernel does not use
            # them for this lowering (or reset them here).

        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )

    def _reshape_weight_for_matmul(self, weight):
        """
        Reshape the transposed weight initializer for the ConvToMatMul lowering.

        Assumes (TODO confirm against ``add_transposed_initializer``) the
        transposed weight is laid out kH x kW x C_in x C_out. The result is
        then 1 x 1 x (kH*kW*C_in) x C_out, whose input-channel order matches
        the activation produced by ``_add_space_to_depth``.

        Returns:
            The newly added ``<name>_reshaped`` initializer.
        """
        kernel_h, kernel_w = self.kernel_vec
        shape = weight.get_shape()
        new_shape = (
            shape[0] // kernel_h,
            shape[1] // kernel_w,
            math.prod(shape[:3]),
            shape[-1],
        )
        reshaped_value = weight.get_value_as_array().reshape(
            tuple(map(int, new_shape))
        )
        return self.add_initializer(
            weight.get_name() + "_reshaped",
            reshaped_value,
            weight.get_dtype(),
        )

    def _add_reshape(self, data, shape, dtype, dtype_raw, name_prefix, idx) -> Tensor:
        """
        Add a contrib ``Reshape`` of ``data`` to ``shape``; return its output.

        Creates the ``<prefix>_<idx>_shape`` initializer, the
        ``<prefix>_reshape_<idx>_output`` tensor and the
        ``<prefix>_Reshape_<idx>`` node.
        """
        shape_init = self.add_initializer(
            f"{name_prefix}_{idx}_shape",
            np.array(shape, dtype),
            dtype_raw,
        )
        reshaped = Tensor(
            self.n._model_dict,
            WalkCfgPlain(),
            f"{name_prefix}_reshape_{idx}_output",
        )
        reshaped.set_shape(list(map(int, shape)), dtype)
        self.add_node(
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs={"data": data, "shape": shape_init},
            outputs={"reshaped": reshaped},
            attributes={},
            new_name=f"{name_prefix}_Reshape_{idx}",
        )
        return reshaped

    def _add_space_to_depth(self, name_prefix: str) -> Tensor:
        """
        Insert the Reshape -> Transpose -> Reshape chain for ConvToMatMul.

        The chain rearranges the activation and returns the final tensor.
        Assumes (TODO confirm) the activation is 1 x H x W x C, i.e. NHWC
        with batch 1. The batch dimension is not carried through.

            [1, H, W, C]
            -> Reshape   [H/kH, kH, W/kW, kW, C]
            -> Transpose [H/kH, W/kW, kH, kW, C]   (perm 0, 2, 1, 3, 4)
            -> Reshape   [1, H/kH, W/kW, kH*kW*C]
        """
        act = self.input.input()
        act_shape = act.get_shape()
        dtype = act.get_dtype()
        # NOTE(review): the shape initializers use the activation's dtype rather
        # than int64, so dims not representable in it (e.g. kH*kW*C > 255 for
        # uint8) would overflow. Confirm what the contrib Reshape expects.
        dtype_raw = act.get_dtype_raw()
        kernel_h, kernel_w = self.kernel_vec

        reshape_0_shape = [
            act_shape[1] // kernel_h,
            kernel_h,
            act_shape[2] // kernel_w,  # W axis (previously reused H by mistake)
            kernel_w,
            act_shape[-1],
        ]
        reshape_0_out = self._add_reshape(
            act, reshape_0_shape, dtype, dtype_raw, name_prefix, 0
        )

        perm = [0, 2, 1, 3, 4]
        transpose_shape = [reshape_0_shape[axis] for axis in perm]
        transpose_out = Tensor(
            self.n._model_dict,
            WalkCfgPlain(),
            f"{name_prefix}_transpose_output",
        )
        transpose_out.set_shape(list(map(int, transpose_shape)), dtype)
        self.add_node(
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs={"data": reshape_0_out},
            outputs={"transposed": transpose_out},
            attributes={"perm": perm},
            new_name=f"{name_prefix}_Transpose",
        )

        reshape_1_shape = [
            1,
            transpose_shape[0],
            transpose_shape[1],
            math.prod(transpose_shape[2:]),
        ]
        return self._add_reshape(
            transpose_out, reshape_1_shape, dtype, dtype_raw, name_prefix, 1
        )


# Category bundling the convolution matchers defined/imported in this module:
# ConvTransposeToConv followed by Conv.
conv_category = Category(
    [
        ConvTransposeToConv(),
        Conv(),
    ]
)
