from onnx import defs
import onnxruntime.capi._pybind_state as rt
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class OperatorInfo:
    """Identifies one operator schema by its name and its domain."""

    # operator type name as reported by the schema (e.g. "Concat")
    name: str
    # domain string as reported by the schema
    domain: str


class OnnxOpsWrapper:
    """
    Name-indexed lookup tables for operator input/output names and formal
    parameters.

    The tables are filled from every ONNX operator schema (one entry per
    (name, domain), using the schema returned by ``defs.get_schema``) plus
    every onnxruntime operator schema whose name is not already known. The
    latter are converted with ``onnx_rt_schema_to_onnx_schema`` and also
    registered in ``onnx.defs``.

    All tables are keyed by operator name only; the domain is recorded in
    ``OperatorInfo`` but is not part of the lookup key.
    """

    # FIXME workaround: number of enumerated variadic input names for Concat
    _CONCAT_INPUT_COUNT = 512

    def __init__(self):
        self._operator_infos: list[OperatorInfo] = []
        # ordered, duplicate-free operator names plus a set for O(1) membership
        self._operator_names: list[str] = []
        self._operator_name_set: set[str] = set()
        self._input_names: dict[str, list[str]] = {}
        self._output_names: dict[str, list[str]] = {}
        self._input_parameters: dict[str, list[defs.OpSchema.FormalParameter]] = {}
        self._output_parameters: dict[str, list[defs.OpSchema.FormalParameter]] = {}

        # get_all_schemas_with_history() yields one schema per operator
        # version; visit each (name, domain) only once and look it up once.
        seen: set[tuple[str, str]] = set()
        for versioned in defs.get_all_schemas_with_history():
            key = (versioned.name, versioned.domain)
            if key in seen:
                continue
            seen.add(key)
            self._add_schema(defs.get_schema(versioned.name, versioned.domain))

        # add contrib ops from onnxruntime (skip names onnx already defines)
        for rt_schema in rt.get_all_operator_schema():
            if rt_schema.name in self._operator_name_set:
                continue
            self.register_schema(onnx_rt_schema_to_onnx_schema(rt_schema))

    def _add_schema(self, schema: defs.OpSchema) -> None:
        """Record *schema* in the lookup tables (does not touch onnx.defs)."""
        name = schema.name
        self._operator_infos.append(OperatorInfo(name, schema.domain))
        if name not in self._operator_name_set:
            self._operator_name_set.add(name)
            self._operator_names.append(name)
        inputs = list(schema.inputs)
        outputs = list(schema.outputs)
        self._input_names[name] = [prm.name for prm in inputs]
        self._output_names[name] = [prm.name for prm in outputs]
        self._input_parameters[name] = inputs
        self._output_parameters[name] = outputs

    def get_operator_names(self) -> list[str]:
        """
        Returns a list of operator names (each name appears once)
        """
        return self._operator_names

    def get_input_names(self, op_type: str) -> list[str]:
        """
        Returns a list of input names of the corresponding operator type.
        Raises KeyError for an unknown operator type.
        """
        # FIXME workaround for enumerating variadic inputs of concat
        if op_type == "Concat":
            return [f"input{i}" for i in range(self._CONCAT_INPUT_COUNT)]

        return self._input_names[op_type]

    def get_input_idx(self, op_type: str, input_name: str) -> int:
        """
        Returns the index for a given input name of the corresponding operator type.
        Raises ValueError if the operator has no input with that name.
        """
        return self.get_input_names(op_type).index(input_name)

    def get_input_parameters(self, op_type: str) -> list[defs.OpSchema.FormalParameter]:
        """
        Returns the list of input parameters of the corresponding operator type
        """
        return self._input_parameters[op_type]

    def get_output_names(self, op_type: str) -> list[str]:
        """
        Returns a list of output names of the corresponding operator type
        """
        return self._output_names[op_type]

    def get_output_idx(self, op_type: str, output_name: str) -> int:
        """
        Returns the index for a given output name of the corresponding operator type.
        Raises ValueError if the operator has no output with that name.
        """
        return self.get_output_names(op_type).index(output_name)

    def get_output_parameters(
        self, op_type: str
    ) -> list[defs.OpSchema.FormalParameter]:
        """
        Returns the list of output parameters of the corresponding operator type
        """
        return self._output_parameters[op_type]

    def get_prm_idx_by_name(
        self, prms: Sequence[defs.OpSchema.FormalParameter], prm_name: str
    ) -> Optional[int]:
        """
        Search a sequence of formal operator parameters of an ONNX operator schema
        for a certain formal parameter name.
        Return the index of this parameter if found.
        Return None if there is no parameter with this name.

        Elements of a variadic parameter whose name ends in "s" (e.g. "inputs")
        are addressed as the singular stem followed by a decimal element number
        (e.g. "input0", "input12"); their index is the variadic parameter's
        position plus that number.
        """
        for idx, prm in enumerate(prms):
            if prm.option in (
                defs.OpSchema.FormalParameterOption.Single,
                defs.OpSchema.FormalParameterOption.Optional,
            ):
                if prm.name == prm_name:
                    return idx
                continue
            if prm.option == defs.OpSchema.FormalParameterOption.Variadic:
                # mimic what "input_" for "Concat" has in YML by now
                if not prm.name.endswith("s"):
                    return None
                stem = prm.name[:-1]
                number = prm_name[len(stem):]
                # accept any number of ASCII digits, not just a single one
                if (
                    prm_name.startswith(stem)
                    and number.isascii()
                    and number.isdigit()
                ):
                    return idx + int(number)
                return None  # no other params can follow after variadic one
            assert (
                False
            ), f"get_prm_idx_by_name does not implement prm.option {prm.option} yet"
        return None  # param not found

    def get_input_prm_idx_by_name(self, op_type: str, prm_name: str) -> Optional[int]:
        """Index of input *prm_name* of *op_type* (see get_prm_idx_by_name)."""
        return self.get_prm_idx_by_name(self.get_input_parameters(op_type), prm_name)

    def get_output_prm_idx_by_name(self, op_type: str, prm_name: str) -> Optional[int]:
        """Index of output *prm_name* of *op_type* (see get_prm_idx_by_name)."""
        return self.get_prm_idx_by_name(self.get_output_parameters(op_type), prm_name)

    def register_schema(self, schema: defs.OpSchema) -> None:
        """
        Add *schema* to the lookup tables and to the global onnx.defs registry.

        Raises AssertionError if an operator with the same name is already
        known to this wrapper.
        """
        assert (
            schema.name not in self._operator_name_set
        ), f"OpSchema(name={schema.name}) already registered"
        self._add_schema(schema)

        # onnx.defs manages a global list of schemas shared by all wrapper
        # instances, so the schema may already be there (e.g. registered by an
        # earlier OnnxOpsWrapper); only register it if it is missing.
        if defs.has(op_type=schema.name, domain=schema.domain):
            return
        defs.register_schema(schema)


def dtype_to_ops_type(type: str, shape: list[int]) -> str:
    """
    Convert a dtype name plus a shape into an OpSchema type string.

    Element type names that ONNX spells differently ("float32" -> "float",
    "float64" -> "double") are translated; other names pass through unchanged.
    A non-empty shape yields ``tensor(<elem>)``; an empty shape (scalar) yields
    the bare element type name.

    Args:
        type: dtype name such as "float32" or "int64" (presumably numpy-style
            names — confirm against callers).
        shape: the shape; must not be None.

    Returns:
        The type string, e.g. "tensor(float)".

    Raises:
        AssertionError: if ``shape`` is None.
    """
    # ONNX type strings use "float"/"double" where numpy uses float32/float64
    onnx_elem_names = {"float32": "float", "float64": "double"}
    type = onnx_elem_names.get(type, type)

    # NOTE I haven't found in the onnx documentation how to convert the TensorProto to type string for OpSchemas
    assert shape is not None, "shape is None"

    if len(shape) != 0:
        type = f"tensor({type})"

    return type


def onnx_rt_schema_to_onnx_schema(schema) -> defs.OpSchema:
    """
    Convert an OpSchema from onnxruntime to an OpSchema from onnx.

    Carried over:
    - the name, domain, since_version and doc string;
    - every formal input/output parameter, including its name, type string,
      description, option (single/optional/variadic) and homogeneity;
    - the type constraints;
    - every attribute, including its name, type, description and whether it
      is required.

    Attribute default values are not converted.

    Args:
        schema: an operator schema object as returned by
            ``onnxruntime.capi._pybind_state.get_all_operator_schema()``.

    Returns:
        The equivalent ``onnx.defs.OpSchema``. It is not registered anywhere.
    """

    def convert_parameter(prm) -> defs.OpSchema.FormalParameter:
        # Keep the option: without it every optional/variadic parameter would
        # be declared as a single, mandatory one.
        return defs.OpSchema.FormalParameter(
            name=prm.name,
            type_str=prm.typeStr,
            description=prm.description,
            # presumably onnxruntime binds its own enum type; go through the
            # integer value to obtain the onnx enum member
            param_option=defs.OpSchema.FormalParameterOption(int(prm.option)),
            is_homogeneous=prm.isHomogeneous,
        )

    inputs: list[defs.OpSchema.FormalParameter] = [
        convert_parameter(prm) for prm in schema.inputs
    ]
    outputs: list[defs.OpSchema.FormalParameter] = [
        convert_parameter(prm) for prm in schema.outputs
    ]

    type_constraints: list[tuple[str, list[str], str]] = [
        (
            constraint.type_param_str,
            constraint.allowed_type_strs,
            constraint.description,
        )
        for constraint in schema.type_constraints
    ]

    attributes: list[defs.OpSchema.Attribute] = [
        defs.OpSchema.Attribute(
            name=attr_name,
            type=defs.OpSchema.AttrType(int(attr.type)),
            description=attr.description,
            # the onnx Attribute constructor otherwise defaults to required=True
            required=attr.required,
        )
        for attr_name, attr in schema.attributes.items()
    ]

    return defs.OpSchema(
        name=schema.name,
        domain=schema.domain,
        since_version=schema.since_version,
        # NOTE(review): doc may come back as None (e.g. empty doc string);
        # the onnx constructor expects a str
        doc=schema.doc or "",
        inputs=inputs,
        outputs=outputs,
        type_constraints=type_constraints,
        attributes=attributes,
    )
