# fmt: on
import numpy as np

from OGOAT.src.L1_fusion.py_match.basic.conv import conv_category
from OGOAT.src.L1_fusion.py_match.checkers import CategoryCheck
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BiasHelper
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import (
    TransposeHelper,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, NoMatch


class GroupConvToConv(Matcher, BiasHelper, TransposeHelper):
    """
    Converts a group conv to a standard conv by expanding the weights.

    The grouped weight is expanded into a dense weight that is block-diagonal
    over groups:
    - Output channel ``o`` of group ``g`` keeps its original kernels in the
      input-channel slice owned by ``g``.
    - Every other entry is filled with the weight zero point, i.e. the
      quantized representation of 0.

    The node is then re-emitted with ``group = 1``.
    """

    dependencies = [conv_category]

    def match(self) -> None:
        """
        Match conv-category nodes with a ``group`` other than 1 and a
        constant (initializer) weight ``B``.

        Raises:
            NoMatch: if ``group`` is missing or equal to 1. The ``require``
                and ``require_initializer`` checks presumably also reject
                non-conv nodes and non-constant weights (unverified).
        """
        n = self.n
        n.require(CategoryCheck(conv_category))
        self.group = n.get_attributes().get("group")
        if self.group is None or self.group == 1:
            raise NoMatch(
                "GroupConvToConv does not support group == 1, found group: {}".format(
                    self.group
                )
            )
        self.B = n("B").require_initializer()

    @staticmethod
    def _expand_group_weights(
        weight_t: np.ndarray, group: int, zero_point
    ) -> np.ndarray:
        """
        Expand a grouped weight into a dense, block-diagonal weight.

        Args:
            weight_t: grouped weight laid out as (Co, Ci_per_group, *kernel).
            group: number of groups; must be >= 1 and divide ``Co``.
            zero_point: weight zero point used to fill the off-group
                entries. It may be a scalar (or single-element array) or one
                value per output channel (``Co`` values).

        Returns:
            Array of shape (Co, Ci_per_group * group, *kernel) with the same
            dtype as ``weight_t``.

        Raises:
            ValueError: on an invalid ``group`` or a zero point of
                unsupported size.
        """
        co, ci_per_group = weight_t.shape[:2]
        if group < 1 or co % group != 0:
            raise ValueError(
                f"GroupConvToConv: output channels ({co}) must be divisible "
                f"by a positive group count, got group={group}"
            )
        co_per_group = co // group

        fill = np.asarray(zero_point)
        if fill.size == 1:
            fill = fill.reshape(())
        elif fill.size == co:
            # Per-output-channel zero point: align it with axis 0 so that
            # each output row is padded with its own zero point.
            fill = fill.reshape((co,) + (1,) * (weight_t.ndim - 1))
        else:
            raise ValueError(
                f"GroupConvToConv: weight zero point must be a scalar or have "
                f"{co} (per-output-channel) values, got {fill.size}"
            )

        expanded = np.full(
            (co, ci_per_group * group) + weight_t.shape[2:],
            fill,
            dtype=weight_t.dtype,
        )
        # Copy each group's output rows into that group's input-channel slice.
        for g in range(group):
            out_rows = slice(g * co_per_group, (g + 1) * co_per_group)
            in_cols = slice(g * ci_per_group, (g + 1) * ci_per_group)
            expanded[out_rows, in_cols] = weight_t[out_rows]
        return expanded

    def modify(self) -> None:
        """
        Replace the grouped conv with an equivalent conv that has
        ``group = 1``.

        Raises:
            ValueError: if the weight shape is inconsistent with ``group`` or
                with the channel counts of ``A`` / ``Y``, or if the weight
                zero point is neither a scalar nor per-output-channel.
        """
        n = self.n
        # The initializer is stored with its axes reversed relative to
        # (Co, Ci_per_group, kernel...). For example, a (16, 1, 3, 3) weight
        # is stored as (3, 3, 1, 16); `.T` reverses all axes.
        weight_t = self.B.get_value_as_array().T
        co, ci_per_group = weight_t.shape[:2]

        # Activations are channels-last (N, H, W, C), so the channel count is
        # the last entry of each shape. The checks below turn a shape
        # mismatch into a clear error instead of a numpy broadcast failure.
        channels_in = n("A").require_tensor().get_shape()[-1]
        channels_out = n("Y").require_tensor().get_shape()[-1]
        if channels_in != ci_per_group * self.group:
            raise ValueError(
                f"GroupConvToConv: input channels ({channels_in}) do not match "
                f"weight input channels per group ({ci_per_group}) x group "
                f"({self.group})"
            )
        if channels_out != co:
            raise ValueError(
                f"GroupConvToConv: output channels ({channels_out}) do not "
                f"match weight output channels ({co})"
            )

        zero_point = n("B_zero_point").require_initializer().get_value()
        # `.T` again to restore the stored axis order of the initializer.
        expanded_weights = self._expand_group_weights(
            weight_t, self.group, zero_point
        ).T

        inputs = n.get_inputs_dict()
        outputs = n.get_outputs_dict()
        inputs["B"] = self.add_initializer(
            self.B.get_name() + "_group", expanded_weights
        )

        copy_attributes = n.get_attributes()
        copy_attributes["group"] = 1
        op_type = n.get_op_type()
        self.remove_node(n)
        self.add_node(
            type=op_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes,
        )
