# fmt: on
from typing import Any, Optional
from OGOAT.src.L1_fusion.py_match.checkers import (
    DTypes,
    opType,
)
from OGOAT.src.L1_fusion.py_match.helpers.batch_helper import BatchHelper
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BiasHelper
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import TransposeHelper
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    NoMatch,
    Nowhere,
    Tensor,
    Initializer,
)
from OGOAT.src.L1_fusion.L1_utils.utils import (
    onnxTensorProto_to_array,
    onnxTensorProto_from_array,
)

import numpy as np


def pack_n_bits_into_bytes(data: np.array, bits: int):
    """
    Pack uint8 values that each occupy ``bits`` bits into a dense byte stream.

    Only the ``bits`` least-significant bits of every element are kept, with
    1 < bits < 8. Elements are consumed in flattened (row-major) order and each
    one is written least-significant bit first (little endian), which is the
    layout used for MatMulNBits weights.
    No padding is added, therefore ``bits * data.size`` must be a multiple of 8.

    Returns a flat uint8 array holding the packed bytes.
    """
    assert data.dtype == np.uint8, "element type as to be a byte"
    assert bits < 8 and bits > 1, "invalid bits number requested"
    assert (data.size * bits) % 8 == 0, "Padding is needed to pack but not supported"

    # One column per bit position: entry [k, i] is bit i of element k. Reading
    # the matrix row by row therefore yields the bit stream element by element,
    # LSB first.
    bit_positions = np.array(range(bits), dtype=np.uint8)
    bit_matrix = (data.reshape(-1, 1) >> bit_positions) & 1

    return np.packbits(bit_matrix.ravel(), bitorder="little")


def unpack_bytes_into_n_bits(data: np.ndarray, bits: int) -> np.ndarray:
    """
    Unpack an array of bytes into values of ``bits`` bits each, where
    1 < bits < 8.

    The input must have dtype uint8 and its total number of bits must be a
    multiple of ``bits`` (i.e. it is already padded for the requested width),
    otherwise an assertion is raised. The input is read in flattened order and
    every value is decoded least-significant bit first (little endian).

    Returns a new, contiguous, flat uint8 array with one element per unpacked
    value. The result owns its memory and does not alias any temporary buffer.
    """
    assert data.dtype == np.uint8, "element type as to be a byte"
    unpack_bits = np.unpackbits(data, bitorder="little")  # get each element as 1 bit

    assert bits < 8 and bits > 1, "invalid bits number requested"
    assert (
        len(unpack_bits) % bits == 0
    ), f"cannot unpack array of shape {data.shape} and dtype {data.dtype} with number of bits {bits}"

    # Row k holds the bits of value k, LSB in column 0. Shifting column i by i
    # and OR-ing across the row rebuilds the value. The reduction allocates a
    # fresh array instead of accumulating in place into a strided view of
    # ``unpack_bits``.
    shifts = np.array(range(bits), dtype=np.uint8)
    return np.bitwise_or.reduce(unpack_bits.reshape(-1, bits) << shifts, axis=1)


class MatMulNBits(Matcher):
    """
    Rewrite a MatMulNBits node surrounded by DequantizeLinear / QuantizeLinear
    into a single "MatMul_qdq_<A>x<B>x<Y>" node.

    The weight (B) and its zero points are bit-packed uint8 initializers. They
    are unpacked, transposed and packed again; the scales are transposed as
    well, whether they are a plain initializer or the quantized input of a
    DequantizeLinear.
    """

    def match(self) -> None:
        """
        Check that the current node is a supported MatMulNBits pattern.

        Sets ``self.matmul``, ``self.attributes``, ``self.quantized_scales``,
        ``self.nb_bits`` and, when the scales are quantized, ``self.B_scales_dq``.
        Raises NoMatch when a requirement is not met.
        """
        self.matmul = self.n.require(opType.MatMulNBits).require_node()
        matmul = self.matmul

        # Activation comes out of a DequantizeLinear, result goes into a QuantizeLinear
        matmul("A").require(opType.DequantizeLinear)
        matmul("Y").require(opType.QuantizeLinear)

        # Only uint16 activations/outputs are accepted
        for zero_point_path in ("A.x_zero_point", "Y.y_zero_point"):
            matmul(zero_point_path).require(DTypes("uint16"))

        weight_is_constant = matmul("B").require_tensor().check_initializer()
        if not weight_is_constant:
            raise NoMatch("Expecting the weight input to be a constant")

        self.attributes = matmul.get_attributes()

        if not matmul("scales").check(opType.DequantizeLinear):
            # Scales are provided directly as a constant
            self.quantized_scales = False
            matmul("scales").require_initializer()
        else:
            # The scales are quantized data: they go through a DequantizeLinear
            # before reaching the MatMulNBits. This is the format under which
            # MSFT gives us this node.
            # scales_init -> DequantizeLinear -> MatMulNBits
            self.quantized_scales = True
            scales_dq = matmul("scales").require(opType.DequantizeLinear)
            self.B_scales_dq = scales_dq.require_node()
            self.B_scales_dq("x").require_initializer()

        # TODO: the bits attribute is optional, if not present we need to compute it ourselves
        if "bits" not in self.attributes:
            raise NoMatch("Require bits attribute")
        self.nb_bits = self.attributes["bits"]

        if self.nb_bits not in (2, 4):
            raise NoMatch("Unsupported number of bits")

        # TODO: add support for missing zero points or unpacked (not uint8) zero points
        matmul("zero_points").require_initializer()
        zero_points_dtype = matmul("zero_points").get_dtype()
        if zero_points_dtype != "uint8":
            raise NoMatch(
                "zero points for the second input is required packed with type uint8"
            )

        # A connected bias input is not handled by this conversion
        if type(matmul("bias")) is not Nowhere:
            raise NoMatch(
                "conversion from MatMulNBits with bias to a MatMul_qdq_bias not supported"
            )

    def transpose_packed_n_bits_initializer(
        self, tensor: Tensor, bits: int, permutation: Optional[list] = None
    ):
        """
        Create a transposed copy of a bit-packed uint8 initializer.

        The initializer data is unpacked into ``bits``-bit values along its last
        dimension, transposed with ``permutation`` (all dimensions reversed when
        it is None), packed again in flattened order and reshaped to the
        original packed shape permuted by ``permutation``.

        The result is registered in the model under the name
        ``<tensor name>_transposed`` and returned as an Initializer.
        """
        assert bits in [2, 4], "unsupported bit size for unpacking, tranpose, packing"
        packed_shape = tensor.get_shape()
        rank = len(packed_shape)

        # Without an explicit permutation, reverse the order of all dimensions
        if permutation is None:
            permutation = list(reversed(range(rank)))
        assert len(permutation) == rank
        target_shape = np.array(packed_shape)[permutation].tolist()

        # Fetch the raw bytes of the initializer
        source_name = tensor.get_name()
        source_proto = self._get_model_dict().get_initializer(source_name)
        packed_bytes, dtype = onnxTensorProto_to_array(source_proto)
        assert dtype == "uint8", "Weight can only be uint8"

        # Each byte of the last dimension expands into 8 // bits values
        unpacked_shape = packed_shape[:-1] + [(packed_shape[-1] * 8) // bits]
        values = unpack_bytes_into_n_bits(packed_bytes, bits).reshape(unpacked_shape)

        # Transpose at value granularity, then go back to packed bytes
        repacked = pack_n_bits_into_bytes(values.transpose(permutation), bits)
        repacked = repacked.reshape(target_shape)

        # Register the transposed data as a new initializer
        transposed_name = source_name + "_transposed"
        transposed_proto = onnxTensorProto_from_array(
            repacked, transposed_name, og_dtype=dtype
        )
        self._get_model_dict().add_initializer(transposed_proto)

        return Initializer(tensor._model_dict, tensor._walk_cfg, transposed_name)

    def modify(self) -> None:
        """
        Replace the matched MatMulNBits node by a "MatMul_qdq_<A>x<B>x<Y>" node
        in the "ai.onnx.contrib" domain, fed with the transposed weight, scales
        and zero points.
        """
        # Transpose the weights, their scales and their zero points
        transposed_weight = self.transpose_packed_n_bits_initializer(
            self.matmul("B"),
            self.nb_bits,
            # from [N, k_blocks, blob_size] to [k_blocks, blob_size, N]
            permutation=[1, 2, 0],
        )
        transposed_zero_points = self.transpose_packed_n_bits_initializer(
            self.matmul("zero_points"), self.nb_bits
        )

        # The scales live either on the DequantizeLinear input or directly on
        # the MatMulNBits input
        if self.quantized_scales:
            scales_owner, scales_port = self.B_scales_dq, "x"
        else:
            scales_owner, scales_port = self.matmul, "scales"
        transposed_scales = self.add_transposed_initializer(
            scales_owner(scales_port),
            scales_owner(scales_port).get_name() + "_transposed",
        )

        inputs = dict(
            A=self.matmul("A.x"),
            B=transposed_weight,
            A_scale=self.matmul("A.x_scale"),
            A_zero_point=self.matmul("A.x_zero_point"),
            B_scale_quant=transposed_scales,  # Transposed quantized scales
            B_scale_quant_scale=None,  # scale of the B scales
            B_scale_quant_zero_point=None,  # zp of the B scales
            B_zero_point=transposed_zero_points,
            Y_scale=self.matmul("Y.y_scale"),
            Y_zero_point=self.matmul("Y.y_zero_point"),
        )
        if self.quantized_scales:
            # Updating existing keys keeps their position in the input list
            inputs.update(
                B_scale_quant_scale=self.B_scales_dq("x_scale"),
                B_scale_quant_zero_point=self.B_scales_dq("x_zero_point"),
            )
        outputs = dict(Y=self.matmul("Y.y"))

        # The new op type encodes the activation, weight and output data types
        act_dtype = self.matmul("A.x_zero_point").get_dtype()
        weight_dtype = "uint" + str(self.nb_bits)
        out_dtype = self.matmul("Y.y_zero_point").get_dtype()
        new_op_type = "MatMul_qdq_" + act_dtype + "x" + weight_dtype + "x" + out_dtype

        # TODO: compute the number of batches
        num_batches = 1
        self.attributes["quantized_scales"] = self.quantized_scales

        self.remove_node(self.matmul)
        node_attributes = self.attributes | MatMul.get_matmul_attributes(
            num_batches=num_batches
        )
        self.add_node(
            type=new_op_type,
            inputs=inputs,
            outputs=outputs,
            domain="ai.onnx.contrib",
            attributes=node_attributes,
        )


class MatMul(Matcher, BiasHelper, TransposeHelper, QDQHelper):
    """
    Fuse a MatMul and its surrounding DequantizeLinear / QuantizeLinear nodes
    into a single "MatMul_qdq_*" node.

    This merges the following classic patterns:
    MatMulBias
    Matmul
    MatmulActAct
    """

    @staticmethod
    def get_matmul_attributes(
        has_bias: bool = False,
        num_batches: int = 1,
        disable_q: bool = False,
        is_actxact: bool = False,
    ) -> dict[str, Any]:
        """
        Build the attribute dictionary of a fused MatMul qdq node.

        "num_batches", "disable_q" (as an int) and "num_of_tensor_inputs"
        (2, or 3 with a bias) are always present. The bias related attributes
        and "actxact" are only added when the matching flag is set.
        """
        # FIXME: ideally we would have always the same attrs generated for a matmul qdq, only
        # their values would change
        bias_attributes = (
            {"trans_to_nchw": 0, "nchw_act": 1, "bias": 1} if has_bias else {}
        )
        actxact_attributes = {"actxact": 1} if is_actxact else {}

        return {
            "num_batches": num_batches,
            "disable_q": int(disable_q),
            "num_of_tensor_inputs": 2 + int(has_bias),
            **bias_attributes,
            **actxact_attributes,
        }

    def match(self) -> None:
        """
        Check that the current node is a MatMul with a DequantizeLinear on both
        inputs and a QuantizeLinear on its output, with supported data types.

        Sets ``self.in1_dq``, ``self.in2_dq``, ``self.out_q``,
        ``self.is_act_act``, ``self.bias_elem`` and ``self.out_type``.
        """
        root = self.n
        root.require(opType.MatMul)
        root("A").require(opType.DequantizeLinear)
        root("B").require(opType.DequantizeLinear)
        root("Y").require(opType.QuantizeLinear)

        # Dequantize nodes at both overall inputs, quantize node at the overall output
        self.in1_dq, self.in2_dq, self.out_q = root("A"), root("B"), root("Y")

        # act x act: no input is connected to an initializer
        self.is_act_act = not (
            self.in1_dq("x").check_initializer()
            or self.in2_dq("x").check_initializer()
        )

        # Check for a bias after the matmul.
        # MatMul actxact + bias is not supported so we do not fuse with the bias in that case.
        self.bias_elem = None
        if not self.is_act_act:
            _, self.bias_elem, self.out_q = self.get_bias_elem_in_and_out_tensor(
                self.out_q
            )
            if self.bias_elem is not None:
                self.require_qdq_equal_scale_zeropoint(root("Y.y"), root("Y"))

        # Data types at the overall inputs: act x act follows the activation
        # dtypes of the model, otherwise a fixed list of integer types is accepted
        fixed_dtypes = ("int4", "int8", "int16", "uint4", "uint8", "uint16")
        for input_dq in (self.in1_dq, self.in2_dq):
            input_dq("x_zero_point").require(
                DTypes(*root.get_model_activation_dtype_sorted_list())
                if self.is_act_act
                else DTypes(*fixed_dtypes)
            )

        # Data type at the overall output
        if not self.out_q.check(opType.QuantizeLinear):
            # No quantize node at the end: use the model activation dtype
            self.out_type = root.get_model_activation_dtype()
        else:
            self.out_q("y_zero_point").require(DTypes(*fixed_dtypes))
            self.out_type = self.out_q("y_zero_point").require_tensor().get_dtype()

    def modify(self) -> None:
        """
        Replace the matched MatMul by a fused node in the "ai.onnx.contrib"
        domain. Its type is "MatMul_qdq_", optionally followed by "actxact_"
        and "bias_", then "<A dtype>x<B dtype>x<output dtype>".
        """
        root = self.n
        has_bias = self.bias_elem is not None

        prefix = "MatMul_qdq_"
        prefix += "actxact_" if self.is_act_act else ""
        prefix += "bias_" if has_bias else ""
        new_type = "x".join(
            [
                prefix + self.in1_dq("x_zero_point").get_dtype(),
                self.in2_dq("x_zero_point").get_dtype(),
                self.out_type,
            ]
        )

        # Optional inputs start as None and are filled in below when available
        inputs = dict(
            A=self.in1_dq("x"),
            B=self.in2_dq("x"),
            Bias=None,
            A_scale=self.in1_dq("x_scale"),
            A_zero_point=self.in1_dq("x_zero_point"),
            B_scale=self.in2_dq("x_scale"),
            B_zero_point=self.in2_dq("x_zero_point"),
            Bias_scale=None,
            Bias_zero_point=None,
            Y_scale=None,
            Y_zero_point=None,
        )

        lhs_shape = self.n("A").get_shape()
        rhs_shape = self.n("B").get_shape()
        num_batches = BatchHelper.extract_matmul_batch_nb(lhs_shape, rhs_shape)

        # Quantization of the output is disabled only when a bias was fused and
        # the resulting output is not a QuantizeLinear
        disable_q = has_bias and not self.out_q.check(opType.QuantizeLinear)

        attributes: dict[str, Any] = MatMul.get_matmul_attributes(
            has_bias=has_bias,
            disable_q=disable_q,
            num_batches=num_batches,
            is_actxact=self.is_act_act,
        )

        if disable_q:
            outputs = dict(Y=self.out_q)
        else:
            inputs.update(
                Y_scale=self.out_q("y_scale"),
                Y_zero_point=self.out_q("y_zero_point"),
            )
            outputs = dict(Y=self.out_q("y"))

        if has_bias:
            inputs.update(
                Bias=self.bias_elem("x"),
                Bias_scale=self.bias_elem("x_scale"),
                Bias_zero_point=self.bias_elem("x_zero_point"),
            )

        self.remove_node(root)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )
