"""
This module processes an ONNX graph and builds a collapsed, op-only DAG from it.
"""

from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import onnx
from onnx import (
    AttributeProto,
    GraphProto,
    ModelProto,
    NodeProto,
    TensorProto,
    ValueInfoProto,
)

# pylint: disable-next=import-error,no-name-in-module
from graph.utilities import logger  # type: ignore

# Type alias for ONNX attributes
# Runtime operations use scalar and list attributes for configuration
OnnxAttributes = dict[str, int | float | str | list[int] | list[float] | list[str]]

IGNORED_OPS = {
    "QuantizeLinear",
    "DequantizeLinear",
    "Relu",
    "LeakyRelu",
    "Shape",
    "Cast",
    "Identity",
    "Flatten",
}

Dim = Union[int, str, None]
Shape = List[Dim]


def _safe_node_id(node: NodeProto, idx: int) -> str:
    """Return an identifier for *node* that contains neither ``:`` nor ``/``.

    The stripped node name is used when it is non-empty.  Nodes without a
    name -- or whose name consists only of whitespace -- fall back to
    ``<op_type>_<idx>`` so that the identifier is never an empty string.

    Args:
        node: The ONNX node to identify.
        idx: Position of the node inside its graph, used for the fallback id.

    Returns:
        The identifier with every ``:`` and ``/`` replaced by ``_``.
    """
    stripped = node.name.strip() if node.name else ""
    # A blank name would otherwise give every such node the same empty id.
    base = stripped or f"{node.op_type}_{idx}"
    return base.replace(":", "_").replace("/", "_")


def _collect_graph_ios(
    g: GraphProto,
) -> Tuple[Set[str], Set[str], Dict[str, TensorProto]]:
    """Gather the input names, output names and initializers of a graph.

    Args:
        g: Graph to inspect.

    Returns:
        A triple ``(input_names, output_names, initializers)`` where the first
        two are sets of names taken from ``g.input`` / ``g.output`` and the
        last maps each initializer name to its ``TensorProto``.
    """
    input_names: Set[str] = set()
    for value_info in g.input:
        input_names.add(value_info.name)

    output_names: Set[str] = set()
    for value_info in g.output:
        output_names.add(value_info.name)

    initializers: Dict[str, TensorProto] = {}
    for tensor in g.initializer:
        # A later initializer with the same name replaces an earlier one.
        initializers[tensor.name] = tensor

    return input_names, output_names, initializers


def _iter_subgraphs(node: NodeProto) -> Iterator[Tuple[GraphProto, str]]:
    """Yield each subgraph stored in the attributes of *node* with a tag.

    A ``GRAPH`` attribute yields its single graph tagged
    ``<op_type>/<attr_name>``; a ``GRAPHS`` attribute yields every contained
    graph tagged ``<op_type>/<attr_name>[<index>]``.  Attributes of any other
    type yield nothing.
    """
    for attr in node.attribute:
        tag = f"{node.op_type}/{attr.name}"
        if attr.type == AttributeProto.GRAPHS:
            for position, subgraph in enumerate(attr.graphs):
                yield subgraph, f"{tag}[{position}]"
        elif attr.type == AttributeProto.GRAPH:
            yield attr.g, tag


def _dtype_to_str(elem_type: Any) -> str:
    """Return a printable name for an ONNX tensor element type.

    ``onnx.helper.tensor_dtype_to_string`` is tried first.  If that call
    raises for any reason, a small local table is consulted instead; element
    types missing from the table are rendered as ``onnx_dtype_<elem_type>``.

    The table is assembled with ``getattr`` so that enum members which the
    installed ``TensorProto`` does not define are skipped instead of raising
    ``AttributeError`` from inside the fallback path.

    NOTE(review): the helper and the fallback table are not guaranteed to use
    the same naming scheme for a given type -- confirm what consumers of the
    ``dtype`` field expect.

    Args:
        elem_type: Element type code, as read from ``TensorProto.data_type``
            or from a tensor type's ``elem_type``.

    Returns:
        The name of the element type.
    """
    # Prefer ONNX helper if available
    try:
        return onnx.helper.tensor_dtype_to_string(elem_type)
    except Exception:  # pylint: disable=broad-except
        # Fallback minimal map, keyed by TensorProto member name so that
        # members unknown to the installed onnx can be skipped.
        fallback_names = (
            ("FLOAT", "float32"),
            ("UINT8", "uint8"),
            ("INT8", "int8"),
            ("UINT16", "uint16"),
            ("INT16", "int16"),
            ("INT32", "int32"),
            ("INT64", "int64"),
            ("STRING", "string"),
            ("BOOL", "bool"),
            ("FLOAT16", "float16"),
            ("DOUBLE", "float64"),
            ("UINT32", "uint32"),
            ("UINT64", "uint64"),
            ("COMPLEX64", "complex64"),
            ("COMPLEX128", "complex128"),
            ("BFLOAT16", "bfloat16"),
            ("FLOAT8E4M3FN", "fp8_e4m3fn"),
            ("FLOAT8E5M2", "fp8_e5m2"),
        )
        known = {}
        for member, label in fallback_names:
            code = getattr(TensorProto, member, None)
            if code is not None:
                known[code] = label
        return known.get(elem_type, f"onnx_dtype_{elem_type}")


def _shape_from_tensor_type(tensor_type: Any) -> Shape:
    """Convert the shape of a tensor type message into a Python list.

    Each dimension becomes an ``int`` when ``dim_value`` is set, a ``str``
    when ``dim_param`` is set, and ``None`` otherwise.  A tensor type without
    a ``shape`` field yields an empty list.
    """
    if tensor_type.HasField("shape"):
        return [
            int(dim.dim_value)
            if dim.HasField("dim_value")
            else str(dim.dim_param)
            if dim.HasField("dim_param")
            else None
            for dim in tensor_type.shape.dim
        ]
    return []


def _extract_vi_dtype_shape(vi: ValueInfoProto) -> Tuple[Optional[str], Shape]:
    """Read the dtype name and shape recorded on a value-info entry.

    Returns:
        ``(dtype, shape)``.  When *vi* has no ``type`` or that type is not a
        tensor type the result is ``(None, [])``.  ``dtype`` is also ``None``
        when the tensor type has no ``elem_type`` set.
    """
    is_tensor = vi.HasField("type") and vi.type.HasField("tensor_type")
    if not is_tensor:
        return None, []

    tensor_type = vi.type.tensor_type
    if tensor_type.HasField("elem_type"):
        dtype = _dtype_to_str(tensor_type.elem_type)
    else:
        dtype = None
    return dtype, _shape_from_tensor_type(tensor_type)


def _extract_init_dtype_shape(t: TensorProto) -> Tuple[Optional[str], Shape]:
    """Read the dtype name and the concrete shape of an initializer tensor.

    Returns:
        ``(dtype, shape)`` where ``dtype`` is derived from ``t.data_type`` and
        ``shape`` lists every entry of ``t.dims`` as an ``int``.
    """
    dtype_name = _dtype_to_str(t.data_type)
    dims: Shape = list(map(int, t.dims))
    return dtype_name, dims


def scoped_tensor(scope: str, name: str) -> str:
    """Return the scoped tensor name, i.e. ``<scope>:<name>``."""
    return "{}:{}".format(scope, name)


def scoped_id(scope: str, idx: str) -> str:
    """Return the scoped identifier, i.e. ``<scope>:<idx>``."""
    return "{}:{}".format(scope, idx)


def descoped_id(name: str) -> str:
    """Return everything after the first ``:`` of *name*.

    A name without any ``:`` yields the empty string.

    NOTE(review): only the text up to the first colon is dropped, so a scope
    that itself contains ``:`` is not removed completely -- confirm callers
    only pass names with a colon-free scope.
    """
    _scope, _colon, remainder = name.partition(":")
    return remainder


def _extract_attributes(node: NodeProto) -> OnnxAttributes:
    """Collect the scalar and list attributes of *node* as plain Python values.

    INT, FLOAT and STRING attributes become ``int``, ``float`` and ``str``;
    INTS, FLOATS and STRINGS become lists.  Byte strings are decoded as UTF-8.
    TENSOR, GRAPH and GRAPHS attributes are skipped with a debug log entry and
    any other attribute type is skipped with a warning.

    Args:
        node: ONNX NodeProto whose attributes are read.

    Returns:
        Dictionary mapping attribute name to the converted value.  Attributes
        of unsupported types do not appear in it.
    """

    def _as_text(raw: Any) -> str:
        # ONNX string payloads normally arrive as bytes.
        return raw.decode('utf-8') if isinstance(raw, bytes) else str(raw)

    converters = {
        AttributeProto.INT: lambda a: int(a.i),
        AttributeProto.FLOAT: lambda a: float(a.f),
        AttributeProto.STRING: lambda a: _as_text(a.s),
        AttributeProto.INTS: lambda a: list(a.ints),
        AttributeProto.FLOATS: lambda a: list(a.floats),
        AttributeProto.STRINGS: lambda a: [_as_text(item) for item in a.strings],
    }
    # Handled elsewhere / not needed for runtime ops.
    silently_skipped = (
        AttributeProto.TENSOR,
        AttributeProto.GRAPH,
        AttributeProto.GRAPHS,
    )

    extracted: OnnxAttributes = {}
    for attr in node.attribute:
        convert = converters.get(attr.type)
        if convert is not None:
            extracted[attr.name] = convert(attr)
        elif attr.type in silently_skipped:
            logger.debug(
                "Skipping attribute '%s' of type %s in node %s (not needed for runtime operations)",
                attr.name, attr.type, node.name or node.op_type
            )
        else:
            # Unknown attribute type: report it, but keep going.
            logger.warning(
                "Unknown attribute type %s for attribute '%s' in node %s. Skipping.",
                attr.type, attr.name, node.name or node.op_type
            )
    return extracted


class CollapsedOpDagBuilder:
    """
    Build a collapsed op-only DAG and collect tensors with shapes/dtypes.

    Nodes whose ``op_type`` is in the ignored set produce no op entry; the
    first output of such a node is treated as an alias of its first input, so
    downstream consumers are wired to whatever produced that input.

    - Nodes (ops): {id, name, op_type, inputs, outputs, scope, attributes}
    - Edges: op_id -> op_id (producer to consumer within the same graph scope)
    - Tensors: {scoped_name, name, scope, kind, dtype, shape}
    """

    def __init__(self, ignored_ops: Set[str]) -> None:
        """Create an empty builder.

        Args:
            ignored_ops: Op types to collapse out of the DAG.  The set is
                stored by reference, not copied.
        """
        # op_id -> op record
        self.ops: Dict[str, Dict] = {}
        # (producer op_id, consumer op_id) pairs
        self.edges: Set[Tuple[str, str]] = set()
        # scoped tensor name -> tensor record
        self.tensors: Dict[str, Dict] = {}
        # Unscoped names of the root graph's inputs / outputs
        self.model_inputs: List[str] = []
        self.model_outputs: List[str] = []
        self.IGNORED_OPS = ignored_ops

    @staticmethod
    def _escalate_kind(old: Optional[str], new: str) -> str:
        """Return whichever of *old* / *new* has the higher priority.

        Priority: output > input > init > intermediate.  An *old* value that
        is ``None`` or not one of those four always loses; on a tie *new* is
        returned.
        """
        order = {"output": 3, "input": 2, "init": 1, "intermediate": 0}
        old_priority = order.get(old, -1)
        return new if order[new] >= old_priority else old or new

    def _upsert_tensor(
        self,
        scope: str,
        name: str,
        kind: str,
        dtype: Optional[str] = None,
        shape: Optional[List[Optional[Any]]] = None,
    ) -> None:
        """Insert a tensor record or merge new information into an existing one.

        Empty names are ignored.  For an existing record the kind is escalated
        via ``_escalate_kind``; ``dtype`` is filled in only if it was ``None``
        and ``shape`` only if the stored shape was ``None`` or empty.

        Args:
            scope: Graph scope the tensor lives in.
            name: Unscoped tensor name.
            kind: One of ``output``, ``input``, ``init``, ``intermediate``.
            dtype: Element type name, if known.
            shape: Shape, if known; it is copied before being stored.
        """
        if not name:
            return
        sid = scoped_tensor(scope, name)
        entry = self.tensors.get(sid)
        if entry is None:
            self.tensors[sid] = {
                "scoped_name": sid,
                "name": name,
                "scope": scope,
                "kind": kind,
                "dtype": dtype,
                "shape": list(shape) if shape is not None else None,
            }
            return

        entry["kind"] = self._escalate_kind(entry.get("kind"), kind)
        # Prefer known dtype/shape if we didn't have them yet
        if entry.get("dtype") is None and dtype is not None:
            entry["dtype"] = dtype
        if not entry.get("shape") and shape is not None:
            entry["shape"] = list(shape)

    def _remove_redundant_tensors(self, ops: Any) -> None:
        """Drop unreferenced tensors and validate the shapes of the rest.

        A tensor is kept when it is an input or output of one of *ops* (within
        that op's scope) or a root-scope model input/output.

        Args:
            ops: Iterable of op records with ``inputs``, ``outputs``, ``scope``.

        Raises:
            RuntimeError: If a kept tensor's shape is not a list made up only
                of non-negative ints.
        """
        referenced: Set[str] = set()
        for op in ops:
            op_scope = op["scope"]
            referenced.update(scoped_tensor(op_scope, t) for t in op["inputs"])
            referenced.update(scoped_tensor(op_scope, t) for t in op["outputs"])

        # Also preserve model inputs/outputs even if they're not referenced by operations
        # (e.g., initializers that are exposed as outputs, or unused inputs)
        referenced.update(scoped_tensor("root", name) for name in self.model_inputs)
        referenced.update(scoped_tensor("root", name) for name in self.model_outputs)

        for sid in list(self.tensors):
            if sid not in referenced:
                del self.tensors[sid]
                continue

            # Shapes must be known otherwise it doesn't make sense.
            # NOTE(review): an empty list passes this check, so a tensor for
            # which no shape information was found is accepted like a scalar.
            shape = self.tensors[sid]["shape"]
            is_static = isinstance(shape, list) and all(
                isinstance(v, int) and v >= 0 for v in shape
            )
            if not is_static:
                raise RuntimeError(
                    f"L2 fusion requires static shapes but recieved {shape} for tensor {sid}"
                )

    def _remove_redundant_edges(self, ops: Any) -> None:
        """Remove, in place, every edge with an endpoint that is not in *ops*.

        Args:
            ops: Iterable of op records; only their ``id`` is read.
        """
        known_ids = {op["id"] for op in ops}
        dangling = {
            (src, dst)
            for src, dst in self.edges
            if src not in known_ids or dst not in known_ids
        }
        self.edges.difference_update(dangling)

    def _infer_shapes(self, model: ModelProto, infer_shapes: bool) -> ModelProto:
        """Optionally run ONNX shape inference and validate the graph's IO.

        If inference is requested and fails, the original model is used.

        Args:
            model: Model to process.
            infer_shapes: Whether to run ``onnx.shape_inference.infer_shapes``.

        Returns:
            The inferred model, or *model* itself.

        Raises:
            ValueError: If the resulting graph has no inputs or no outputs.
        """
        if infer_shapes:
            try:
                m = onnx.shape_inference.infer_shapes(
                    model, check_type=True, strict_mode=False
                )
            except Exception:  # pylint: disable=broad-except
                logger.debug("Shape inference failed, falling back to original model")
                m = model
        else:
            m = model

        if not m.graph.output or not m.graph.input:
            raise ValueError("Model must have at least one input and one output")
        return m

    def _process_graph(self, g: GraphProto, scope: str) -> None:
        """Add the ops, edges and tensors of graph *g* to the builder state.

        Subgraphs found in node attributes are processed recursively under a
        derived scope.

        Args:
            g: Graph to walk.
            scope: Scope name for this graph; ``"root"`` for the top level.
        """
        graph_inputs, graph_outputs, inits_by_name = _collect_graph_ios(g)

        # Keep top-level IO names
        if scope == "root":
            self.model_inputs = sorted(graph_inputs)
            self.model_outputs = sorted(graph_outputs)

        # Build a quick lookup for value_info (dtype/shape)
        vi_map: Dict[str, Tuple[Optional[str], Shape]] = {}
        for vi in list(g.input) + list(g.output) + list(g.value_info):
            dtype, shape = _extract_vi_dtype_shape(vi)
            vi_map[vi.name] = (dtype, shape)

        # Initializers are also valid tensors in this graph
        for name, init in inits_by_name.items():
            dtype, shape = _extract_init_dtype_shape(init)
            self._upsert_tensor(scope, name, kind="init", dtype=dtype, shape=shape)

        # Ignored nodes become aliases: first output -> first input.
        alias_of: Dict[str, str] = {}
        for node in g.node:
            if node.op_type in self.IGNORED_OPS and node.input and node.output:
                alias_of[node.output[0]] = node.input[0]

        def _resolve(t: str) -> str:
            # Follow the alias chain; ``seen`` guards against cycles.
            seen: Set[str] = set()
            while t in alias_of and t not in seen:
                seen.add(t)
                t = alias_of[t]
            return t

        # First pass: create op nodes and register producers
        producer_of: Dict[str, str] = {}  # tensor_name -> producing op_id
        for idx, node in enumerate(g.node):
            if node.op_type in self.IGNORED_OPS:
                continue

            op_id = f"{scope}/{_safe_node_id(node, idx)}"
            raw_inputs = [x for x in node.input if x]
            inputs = [_resolve(x) for x in raw_inputs]
            outputs = [y for y in node.output if y]

            self.ops[op_id] = {
                "id": op_id,
                "name": node.name or "",
                "op_type": node.op_type,
                "inputs": inputs,
                "outputs": outputs,
                "scope": scope,
                "attributes": _extract_attributes(node),
            }

            # Register outputs as produced tensors and upsert their metadata
            for t in outputs:
                producer_of[t] = op_id
                # Kind: graph output or intermediate
                kind = "output" if t in graph_outputs else "intermediate"
                # Without value_info the tensor gets dtype None and an empty
                # shape; shape inference may fill many but not all.
                dtype, shape = vi_map.get(t, (None, []))
                self._upsert_tensor(scope, t, kind=kind, dtype=dtype, shape=shape)

        # Second pass: wire producer->consumer edges and upsert input tensors.
        # NOTE(review): ignored nodes are not skipped here.  The edges they add
        # point at ids that are absent from ``self.ops`` and are dropped by
        # ``_remove_redundant_edges``; tensors left unreferenced are dropped by
        # ``_remove_redundant_tensors``.
        for idx, node in enumerate(g.node):
            consumer_id = f"{scope}/{_safe_node_id(node, idx)}"
            for t_in in node.input:
                if not t_in:
                    continue

                t = _resolve(t_in)
                prod_id = producer_of.get(t)
                if prod_id and prod_id != consumer_id:
                    self.edges.add((prod_id, consumer_id))

                # Upsert input tensor metadata
                if t in inits_by_name:
                    # Already inserted as 'init'; re-upserting with the same
                    # kind leaves the stored record unchanged.
                    dtype, shape = _extract_init_dtype_shape(inits_by_name[t])
                    self._upsert_tensor(scope, t, kind="init", dtype=dtype, shape=shape)
                else:
                    # Graph input or intermediate
                    kind = "input" if t in graph_inputs else "intermediate"
                    dtype, shape = vi_map.get(t, (None, []))
                    self._upsert_tensor(scope, t, kind=kind, dtype=dtype, shape=shape)

        # Recurse into control-flow subgraphs.
        # NOTE(review): the sub-scope is built from the op type and attribute
        # name only, so two nodes of the same op type in one graph share a
        # sub-scope -- confirm this is acceptable.
        for node in g.node:
            for subgraph, tag in _iter_subgraphs(node):
                sub_scope = f"{scope}/{node.op_type}:{tag}"
                self._process_graph(subgraph, sub_scope)

        # Ensure all model inputs and outputs are in tensors, even if unreferenced
        # This handles: unused inputs, initializers exposed as outputs
        if scope == "root":
            for inp_name in graph_inputs:
                if scoped_tensor(scope, inp_name) not in self.tensors:
                    dtype, shape = vi_map.get(inp_name, (None, []))
                    self._upsert_tensor(scope, inp_name, kind="input", dtype=dtype, shape=shape)

            for out_name in graph_outputs:
                if scoped_tensor(scope, out_name) not in self.tensors:
                    if out_name in inits_by_name:
                        dtype, shape = _extract_init_dtype_shape(inits_by_name[out_name])
                        self._upsert_tensor(scope, out_name, kind="init", dtype=dtype, shape=shape)
                    else:
                        dtype, shape = vi_map.get(out_name, (None, []))
                        self._upsert_tensor(scope, out_name, kind="output", dtype=dtype, shape=shape)

    def build(self, model: ModelProto, infer_shapes: bool = True) -> Dict:
        """Build DAG from given ONNX model, optionally with shape inference.

        Any state left over from a previous call is discarded first.

        Args:
            model: The ONNX model to convert.
            infer_shapes: Run ONNX shape inference before walking the graph.

        Returns:
            A dict with the keys ``ops`` (list of op records), ``edges``
            (sorted list of ``{"src", "dst"}`` dicts), ``tensors`` (list of
            tensor records), ``model_inputs`` and ``model_outputs`` (sets of
            ``root:<name>`` ids).

        Raises:
            ValueError: If the model graph has no inputs or no outputs.
            RuntimeError: If a retained tensor does not have a static shape.
        """
        # Optionally run shape inference to populate more value_info shapes
        m = self._infer_shapes(model, infer_shapes)

        # Start traversal
        self.ops.clear()
        self.edges.clear()
        self.tensors.clear()
        self.model_inputs.clear()
        self.model_outputs.clear()

        scope = "root"
        self._process_graph(m.graph, scope=scope)
        logger.info(
            "ONNX graph consists of the following L2/dataflow operations: %s",
            {op["op_type"] for op in self.ops.values()}
        )
        self._remove_redundant_tensors(self.ops.values())
        self._remove_redundant_edges(self.ops.values())
        # NOTE(review): model_inputs / model_outputs are returned as sets, so
        # their iteration order is unspecified and they are not directly JSON
        # serialisable -- confirm whether callers would prefer sorted lists.
        return {
            "ops": list(self.ops.values()),
            "edges": [{"src": s, "dst": d} for (s, d) in sorted(self.edges)],
            "tensors": list(self.tensors.values()),
            "model_inputs": {scoped_id(scope, t) for t in self.model_inputs},
            "model_outputs": {scoped_id(scope, t) for t in self.model_outputs},
        }


def construct_op_dag(model: onnx.ModelProto, ignored_ops: Set[str] = IGNORED_OPS) -> Dict:
    """
    Construct a collapsed op-only DAG from an ONNX model.

    Parameters
    ----------
    model : onnx.ModelProto
        The loaded ONNX model object.
    ignored_ops : Set[str]
        Op types that are collapsed out of the DAG.  Defaults to the
        module-level ``IGNORED_OPS``.

    Returns
    -------
    dict
        {
          "ops": [
            {
              "id": "root/conv1",
              "name": "conv1",
              "op_type": "Conv",
              "inputs": ["X", "W", "B"],
              "outputs": ["Y"],
              "scope": "root",
              "attributes": {"kernel_shape": [3, 3], ...}
            },
            ...
          ],
          "edges": [
            {"src": "root/conv1", "dst": "root/add1"},
            ...
          ],
          "tensors": [
            {
              "scoped_name": "root:Y",
              "name": "Y",
              "scope": "root",
              "kind": "intermediate|input|output|init",
              "dtype": <element type name or None>,
              "shape": [1, 64, 56, 56]   # non-negative ints only
            },
            ...
          ],
          "model_inputs": {"root:input_0", ...},    # set of scoped ids
          "model_outputs": {"root:output_0", ...}   # set of scoped ids
        }

    Raises
    ------
    ValueError
        If the model graph has no inputs or no outputs.
    RuntimeError
        If a tensor kept in the result does not have a static shape.
    """
    return CollapsedOpDagBuilder(ignored_ops).build(model)


def dag_ops_to_dot(dag: Dict, filename: str) -> None:
    """
    Construct a dot file for the operators contained in the DAG

    Model inputs and outputs are emitted as boxes, ops as ellipses and every
    DAG edge as ``"src" -> "dst";``.  Inputs and outputs are written in sorted
    order so that the file content is deterministic, and double quotes inside
    identifiers or labels are escaped.  The graph is written exactly once.

    Parameters
    ----------
    dag : Dict
        The dag constructed using the `build` function; the keys ``ops``,
        ``edges``, ``model_inputs`` and ``model_outputs`` are read.
    filename: str
        The name for the output dotviz file
    """

    def _quoted(text: Any) -> str:
        # Inside a double-quoted DOT string only the quote itself needs escaping.
        return str(text).replace('"', '\\"')

    dot_lines = ["digraph ONNXGraph {", "rankdir=LR;"]

    # Input nodes
    for input_name in sorted(dag["model_inputs"]):
        ident = _quoted(input_name)
        dot_lines.append(
            f'"{ident}" '
            f'[shape=box, style=filled, color=lightblue, label="{ident}\\nInput"];'
        )

    # Output nodes
    for output_name in sorted(dag["model_outputs"]):
        ident = _quoted(output_name)
        dot_lines.append(
            f'"{ident}" '
            f'[shape=box, style=filled, color=lightgreen, label="{ident}\\nOutput"];'
        )

    # Operation nodes
    for op in dag["ops"]:
        op_id = _quoted(op["id"])
        # "\\n" is a literal backslash-n, which DOT renders as a line break.
        label = f'{_quoted(op["name"])}\\n{_quoted(op["op_type"])}'
        dot_lines.append(
            f'"{op_id}" [shape=ellipse, style=filled, color=lightgray, label="{label}"];'
        )

    # Edges
    for edge in dag["edges"]:
        src, dst = _quoted(edge["src"]), _quoted(edge["dst"])
        dot_lines.append(f'"{src}" -> "{dst}";')

    dot_lines.append("}")

    with open(filename, "w", encoding="utf-8") as f:
        f.write("\n".join(dot_lines))
