# fmt: on
import numpy as np
from OGOAT.src.L1_fusion.py_match.checkers import DTypeAny, DTypes, opType
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BiasHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    NoMatch,
    Tensor,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import (
    TransposeHelper,
)


class ConvTransposeToConv(Matcher, BiasHelper, TransposeHelper):
    """Rewrite a quantized ConvTranspose whose ``kernel_shape`` equals its ``strides``.

    The matched node is replaced by a 1x1 ``Conv_qdq[_bias]_<in>x<w>x<out>``
    node whose weight is the ConvTranspose weight folded into
    ``kernel_h * kernel_w * C_out`` output channels. Its result is then
    rearranged back to the original output with a
    ``Reshape -> Transpose(perm=[0, 1, 3, 2, 4, 5]) -> Reshape`` chain
    (a depth-to-space style rearrangement).

    NOTE(review): match() does not check ``group``, ``dilations`` or
    ``output_padding``. The rewrite looks valid only for group == 1,
    unit dilations and zero output padding -- confirm before relying on
    it for other configurations.
    """

    # Activation / output quantization dtypes accepted by the rewrite.
    _QUANT_ACT_DTYPES = ("int4", "int8", "int16", "uint4", "uint8", "uint16")

    def match(self) -> None:
        """Accept a QDQ ConvTranspose with a 2-D kernel equal to its strides.

        Sets, for use by :meth:`modify`:

        * ``self.input`` / ``self.output`` -- as returned by
          ``require_nchw_conversion``.
        * ``self.has_bias`` -- True when the ConvTranspose has a ``B`` input.
        * ``self.has_add_bias`` / ``self.bias_elem`` -- the bias source. When
          ``B`` is absent these come from ``get_bias_elem_in_and_out_tensor``
          (presumably a bias Add after the node; see BiasHelper), which may
          also replace ``self.output``.
        * ``self.kernel_shape`` / ``self.strides`` -- raw node attributes.

        Raises:
            NoMatch: if a structural/dtype requirement fails, if
                ``kernel_shape`` is missing or not 2-D, or if it differs from
                ``strides``.
        """
        n = self.n
        n.require(opType.ConvTranspose)

        self.has_bias = n("B").check_tensor()
        self.input = n("X")
        self.output = n("Y")
        self.input, self.output = self.require_nchw_conversion(self.input, self.output)

        if self.has_bias:
            n("B").require(opType.DequantizeLinear)
            n("B.x_zero_point").require(DTypeAny())
            self.bias_elem = n("B")
            # Set explicitly so modify() never relies on `or` short-circuiting
            # to avoid reading an attribute that was never assigned.
            self.has_add_bias = False
        else:
            n("B").require_nowhere()
            self.has_add_bias, self.bias_elem, self.output = (
                self.get_bias_elem_in_and_out_tensor(self.output)
            )

        n("W").require(opType.DequantizeLinear)
        self.input("x_zero_point").require(DTypes(*self._QUANT_ACT_DTYPES))
        n("Y.y_zero_point").require(DTypes(*self._QUANT_ACT_DTYPES))

        self.kernel_shape = n.get_attribute_value("kernel_shape")
        self.strides = n.get_attribute_value("strides")

        # modify() indexes kernel_shape[0] and kernel_shape[1]; reject anything
        # else here instead of crashing later with KeyError/IndexError.
        if self.kernel_shape is None or len(self.kernel_shape) != 2:
            raise NoMatch(
                f"ConvTranspose: only 2-D kernels are supported, got kernel_shape={self.kernel_shape}"
            )

        if self.kernel_shape != self.strides:
            raise NoMatch(
                f"ConvTranspose: Kernel {self.kernel_shape} and Stride {self.strides} not equal"
            )

    def modify(self) -> None:
        """Replace the matched ConvTranspose with a 1x1 Conv_qdq plus rearrangement.

        Graph mutations happen in this order:

        1. add the folded weight initializer (and the expanded bias initializer);
        2. add the Reshape -> Transpose -> Reshape chain that writes the
           node's final output tensor (``self.output("y")``);
        3. remove the ConvTranspose node;
        4. add the ``Conv_qdq[_bias]_...`` node writing ``<name>_old_output``.
        """
        n = self.n
        has_bias = self.has_bias or self.has_add_bias
        conv_t_attributes = n.get_attributes()
        kernel_shape = conv_t_attributes["kernel_shape"]
        kernel_h, kernel_w = int(kernel_shape[0]), int(kernel_shape[1])

        # Op type encodes <input dtype>x<weight dtype>x<output dtype>.
        bias_infix = "bias_" if has_bias else ""
        new_type = (
            f"Conv_qdq_{bias_infix}"
            f"{self.input('x_zero_point').get_dtype()}"
            f"x{n('W.x_zero_point').get_dtype()}"
            f"x{n('Y.y_zero_point').get_dtype()}"
        )

        weight_init = self._reshape_weight_to_1x1()

        inputs = {
            "A": self.input("x"),
            "B": weight_init,
            "Bias": None,
            "A_scale": self.input("x_scale"),
            "A_zero_point": self.input("x_zero_point"),
            "B_scale": n("W.x_scale"),
            "B_zero_point": n("W.x_zero_point"),
            "Bias_scale": None,
            "Bias_zero_point": None,
            "Y_scale": n("Y.y_scale"),
            "Y_zero_point": n("Y.y_zero_point"),
        }
        attributes = {"num_of_tensor_inputs": 2, "trans_to_nchw": 0}

        if has_bias:
            inputs["Bias"] = self._expand_bias(kernel_h * kernel_w)
            attributes |= {"nchw_act": 1, "bias": 1, "num_of_tensor_inputs": 3}
            inputs["Bias_scale"] = self.bias_elem("x_scale")
            inputs["Bias_zero_point"] = self.bias_elem("x_zero_point")
            # With a bias the quantization params come from self.output, which
            # may have been replaced by get_bias_elem_in_and_out_tensor().
            inputs["Y_scale"] = self.output("y_scale")
            inputs["Y_zero_point"] = self.output("y_zero_point")

        old_output = self._add_depth_to_space_nodes(
            self.output("y"), kernel_h, kernel_w
        )

        pad_vec = n.get_attribute_value("pads")
        if pad_vec is None:
            # Default: no padding at begin and end of each spatial dimension.
            pad_vec = [0] * (2 * len(kernel_shape))

        self.remove_node(n)

        # NOTE(review): ConvTranspose `pads` crop the output, whereas Conv
        # `pads` pad the input; forwarding them unchanged to a 1x1 Conv is
        # presumably only correct for all-zero pads -- confirm.
        attributes |= self._conv_pad_attributes(pad_vec)
        # Copy the ConvTranspose attributes over the computed ones, then reset
        # the fields that must describe the new 1x1 / stride-1 Conv.
        attributes |= conv_t_attributes
        attributes["trans_to_nchw"] = 0
        attributes["kernel_shape"] = [1, 1]
        attributes["strides"] = [1, 1]

        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs={"Y": old_output},
            attributes=attributes,
        )

    def _reshape_weight_to_1x1(self):
        """Fold the ConvTranspose weight into a new 1x1-Conv weight initializer.

        Summary:
            Transpose the weight array with perm ``(0, 2, 3, 1)`` and
            reshape it to ``(1, 1, d0, d1 * d2 * d3)``.

        Assumed layout (not checked here):
            The weight is expected in the ONNX ConvTranspose layout
            ``(C_in, C_out/group, kH, kW)``. Under that assumption, the new
            last axis is ordered ``(kH, kW, C_out)`` with ``C_out`` varying
            fastest. TODO confirm.

        Returns:
            The initializer added under the name ``<W.x name>_trans_reshaped``.
        """
        weight_tensor = self.n("W.x")
        weight_value = weight_tensor.require_initializer().get_value_as_array()
        transposed = weight_value.transpose(0, 2, 3, 1)
        d0, d1, d2, d3 = transposed.shape
        folded = transposed.reshape(1, 1, int(d0), int(d1 * d2 * d3))
        return self.add_initializer(
            weight_tensor.get_name() + "_trans_reshaped",
            folded,
            weight_tensor.get_dtype(),
        )

    def _expand_bias(self, repeats: int):
        """Tile the flattened bias ``repeats`` times to match the folded weight.

        Returns the initializer added under the name ``<bias x name>_expanded``.
        """
        bias_tensor = self.bias_elem("x")
        bias_value = bias_tensor.require_initializer().get_value_as_array().reshape(-1)
        expanded = np.tile(bias_value, int(repeats))
        return self.add_initializer(
            bias_tensor.get_name() + "_expanded", expanded, bias_tensor.get_dtype()
        )

    def _add_depth_to_space_nodes(self, output_tensor, kernel_h: int, kernel_w: int):
        """Add the Reshape -> Transpose -> Reshape chain that writes ``output_tensor``.

        ``output_tensor`` is indexed as ``[N, H, W, C]`` (channels last).

        Returns:
            A new intermediate tensor ``<name>_old_output`` of shape
            ``[N, H // kernel_h, W // kernel_w, C * kernel_h * kernel_w]``,
            which the 1x1 Conv must produce.
        """
        n = self.n
        name = n.get_name()
        model = n._model_dict
        out_dtype = output_tensor.get_dtype()
        out_shape = output_tensor.get_shape()

        batch = int(out_shape[0])
        in_h = int(out_shape[1]) // kernel_h
        in_w = int(out_shape[2]) // kernel_w
        channels = int(out_shape[-1])

        old_output = Tensor(model, WalkCfgPlain(), f"{name}_old_output")
        old_output.set_shape(
            [batch, in_h, in_w, channels * kernel_h * kernel_w], out_dtype
        )

        # NOTE(review): the Reshape shape initializers are created with the
        # activation dtype (out_dtype). Standard ONNX Reshape expects an int64
        # shape tensor; confirm this is what the contrib Reshape expects.

        # [N, H', W', kh*kw*C] -> [N, H', W', kh, kw, C]
        reshape_0_shape = [batch, in_h, in_w, kernel_h, kernel_w, channels]
        reshape_0_init = self.add_initializer(
            f"{name}_reshape_0_shape", np.array(reshape_0_shape), out_dtype
        )
        reshape_0_output = Tensor(model, WalkCfgPlain(), f"{name}_reshape_0_output")
        reshape_0_output.set_shape(reshape_0_shape, out_dtype)
        self.add_node(
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs={"data": old_output, "shape": reshape_0_init},
            outputs={"reshaped": reshape_0_output},
            attributes={},
            new_name=f"{name}_Reshaped_0",
        )

        # [N, H', W', kh, kw, C] -> [N, H', kh, W', kw, C]
        perm = [0, 1, 3, 2, 4, 5]
        transpose_output = Tensor(model, WalkCfgPlain(), f"{name}_transpose_output")
        transpose_output.set_shape([reshape_0_shape[axis] for axis in perm], out_dtype)
        self.add_node(
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs={"data": reshape_0_output},
            outputs={"transposed": transpose_output},
            attributes={"perm": perm},
            new_name=f"{name}_Transpose",
        )

        # [N, H', kh, W', kw, C] -> original output shape [N, H'*kh, W'*kw, C]
        reshape_1_init = self.add_initializer(
            f"{name}_reshape_1_shape", np.array(out_shape), out_dtype
        )
        self.add_node(
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs={"data": transpose_output, "shape": reshape_1_init},
            outputs={"reshaped": output_tensor},
            attributes={},
            new_name=f"{name}_Reshaped_1",
        )
        return old_output

    @staticmethod
    def _conv_pad_attributes(pad_vec) -> dict:
        """Map an ONNX-style ``pads`` vector to the Conv_qdq pad/stride attributes.

        A 2-element vector is treated as 1-D ``[h_begin, h_end]``. Any other
        length is read as 2-D ``[h_begin, w_begin, h_end, w_end]``, using the
        first four entries.
        """
        if len(pad_vec) == 2:
            # Conv1d
            return {
                "pads_h_beg": pad_vec[0],
                "pads_h_end": pad_vec[1],
                "strides_h": 1,
            }
        # Conv2d
        return {
            "pads_h_beg": pad_vec[0],
            "pads_w_beg": pad_vec[1],
            "pads_h_end": pad_vec[2],
            "pads_w_end": pad_vec[3],
            "strides_h": 1,
            "strides_w": 1,
        }
