# fmt: on
from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    Node,
    NoMatch,
    Inputs,
    Outputs,
)
from typing import Optional
import numpy as np


class RemoveSliceConcatRuntime(Matcher):
    """Bypass a runtime Split whose outputs are all re-joined by one Concat.

    The pattern matched is a ``Split_runtime`` on axis 0 whose outputs all
    lead to the same ``Concat_runtime`` on axis 0, in the same order. In that
    case the concat result holds the same data as the split input, so the
    pair can be replaced by a direct connection (or a ``Reshape_noop`` when
    the recorded shapes differ).
    """

    def match(self) -> None:
        """Check that the current node starts a removable split/concat pair.

        On success ``self.split_node`` and ``self.concat_node`` are set for
        use by :meth:`modify`.

        Raises:
            NoMatch: if the split or concat axis is not 0, if the split has
                no outputs, if the split outputs do not all reach the same
                concat, if the concat has a different number of inputs than
                the split has outputs, or if the ordering differs.
        """
        # Find a split -> [no-ops] -> concat chain
        self.split_node: Node = self.n.require(opType.Split_runtime).require_node()

        if self.split_node.get_attribute_value("axis") != 0:
            raise NoMatch("Split axis should be on the outermost dimension")

        split_outputs: Outputs = self.split_node.get_outputs()
        self.concat_node: Optional[Node] = None

        # Make sure that the split is connected to only one concat
        for split_output in split_outputs:
            concat: Node = split_output.require(opType.Concat_runtime).require_node()
            if self.concat_node is None:
                self.concat_node = concat
            elif concat != self.concat_node:
                raise NoMatch("Split outputs should be connected to the same concat")

        # A split without any output never enters the loop above; reject it
        # explicitly instead of failing with an AttributeError on None below.
        if self.concat_node is None:
            raise NoMatch("Split should have at least one output connected to a concat")

        if self.concat_node.get_attribute_value("axis") != 0:
            raise NoMatch("Concat axis should be on the outermost dimension")

        concat_inputs: Inputs = self.concat_node.get_inputs()
        if len(concat_inputs) != len(split_outputs):
            raise NoMatch("All output of the split should be connected to the concat")

        # Check that each output of the split is connected to the inputs of
        # the concat in the same order. `skip()` presumably walks past the
        # intermediate no-op nodes -- TODO confirm against the Tensor API.
        for idx, split_output in enumerate(split_outputs):
            if split_output.skip() != concat_inputs[idx]:
                raise NoMatch(
                    "Split order and concat order should match, in order to avoid data transposes"
                )

    def modify(self) -> None:
        """Rewire the graph so the concat result comes from the split input.

        If the split input and the concat result have the same shape they are
        connected directly; otherwise a ``Reshape_noop`` node is inserted
        between them. The split node is removed afterwards.
        """
        split_input = self.split_node("input")
        concat_result = self.concat_node("concat_result")
        output_shape = concat_result.get_shape()

        if split_input.get_shape() == output_shape:
            self.connect(concat_result, split_input)
        else:
            # Connect the two tensors with a reshape since the shape has
            # changed. The shape initializer is only created here so that no
            # unused initializer is left behind in the direct-connect case.
            shape_name = (
                self.n.get_name() + "_shape_" + "x".join(map(str, output_shape))
            )
            shape = self.add_initializer(
                shape_name, np.array(output_shape).astype(np.int64)
            )
            self.add_node(
                "Reshape_noop",
                "ai.onnx.contrib",
                {"data": split_input, "shape": shape},
                {"reshaped": concat_result},
                new_name=self.n.get_name() + "_reshape",
            )

        # NOTE(review): only the split is removed explicitly; presumably the
        # now-disconnected concat and no-ops are cleaned up elsewhere -- confirm.
        self.remove_node(self.split_node)
