# fmt: on
import os
from pathlib import Path

import numpy as np
import onnx
import onnxruntime as ort

from OGOAT.src.L1_fusion.L1_utils.utils import (
    construct_constant_dict,
    construct_nodes_dict,
    remove_model,
    save_model,
)
from OGOAT.src.L1_fusion.L1_utils.ops_definition_utils import OnnxOpsWrapper
from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.fusion_configs import FusionArguments
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    MatcherError,
    Node,
    OutputTensor,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.helpers.common_type import NamedArray
from OGOAT.src.utils.context import Logger


def _subgraph_input_names(node: "onnx.NodeProto") -> set[str]:
    """
    Return every tensor name read by nodes inside the subgraph attributes of
    ``node`` (recursively, including nested subgraphs).

    Control-flow nodes (If, Loop, Scan) may use outer-scope tensors inside
    their body graphs without listing them in ``node.input``. Treating all
    names read inside the bodies as inputs of ``node`` is a conservative way
    to keep the producers of such implicit inputs alive.
    """
    names: set[str] = set()
    pending = [node]
    while pending:
        current = pending.pop()
        for attr in current.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                graphs = [attr.g]
            elif attr.type == onnx.AttributeProto.GRAPHS:
                graphs = list(attr.graphs)
            else:
                continue
            for graph in graphs:
                for sub_node in graph.node:
                    names.update(sub_node.input)
                    pending.append(sub_node)  # descend into nested subgraphs
    return names


def _delete_unused(container, live_names: set[str]) -> None:
    """
    Delete, in place, every entry of ``container`` whose ``name`` is not in
    ``live_names``.

    Deletion by index (back to front) avoids ``remove()``, which compares
    whole protobuf messages (e.g. large initializer tensors) for equality.
    """
    for idx in reversed(range(len(container))):
        if container[idx].name not in live_names:
            del container[idx]


def remove_dead_subgraphs_from_model(model: "onnx.ModelProto") -> int:
    """
    Remove all nodes that don't contribute to any graph output.

    A backward traversal starting at the graph outputs marks live nodes and
    the tensors they read. Graph inputs and initializers that are not live
    are removed as well. Works directly on onnx.ModelProto without requiring
    Node wrappers; ``model`` is modified in place.

    Nodes are tracked by their position in ``graph.node``, not by
    ``node.name``, because ONNX node names are optional and need not be
    unique. Empty output names (omitted optional outputs) never make a node
    live. Tensors read inside subgraph attributes (If/Loop/Scan bodies) are
    treated as inputs of the owning node.

    Args:
        model: ONNX model to clean up

    Returns:
        Number of dead nodes removed
    """
    graph = model.graph
    nodes = graph.node

    # Tensor name -> positions of the nodes producing it.
    producers: dict[str, list[int]] = {}
    for idx, node in enumerate(nodes):
        for out in node.output:
            if out:  # "" marks an omitted optional output
                producers.setdefault(out, []).append(idx)

    # Worklist traversal from the graph outputs towards the graph inputs.
    live_tensors: set[str] = {output.name for output in graph.output}
    live_nodes: set[int] = set()
    pending = list(live_tensors)
    while pending:
        tensor = pending.pop()
        for idx in producers.get(tensor, ()):
            if idx in live_nodes:
                continue
            live_nodes.add(idx)
            node = nodes[idx]
            for inp in (*node.input, *_subgraph_input_names(node)):
                if inp not in live_tensors:
                    live_tensors.add(inp)
                    pending.append(inp)

    # Delete dead nodes back to front so the remaining indices stay valid.
    removed = 0
    for idx in reversed(range(len(nodes))):
        if idx not in live_nodes:
            del nodes[idx]
            removed += 1

    # Remove unused graph inputs and initializers (not in live_tensors).
    _delete_unused(graph.input, live_tensors)
    _delete_unused(graph.initializer, live_tensors)

    return removed


class ConstEval:
    """
    Collect and evaluate all nodes that depend only on constant values or known
    shapes.

    Intended call order (as used by replace_const): collect(), evaluate(),
    exclude_nodes(), update(). The constant nodes are evaluated with ONNX
    Runtime on a model extracted from the original one; update() then replaces
    them by initializers holding the computed values.
    """

    def __init__(
        self,
        model: onnx.ModelProto,
        onnx_ops: OnnxOpsWrapper,
        fusionArgs: FusionArguments,
    ) -> None:
        self._model_dict = ModelDict(model, onnx_ops)
        self.fusion_args = fusionArgs

        walk_cfg_plain = WalkCfgPlain()
        self._nodes = [
            Node(self._model_dict, walk_cfg_plain, node_name)
            for node_name in self._model_dict.get_node_names()
        ]

        # Values fed to the graph inputs of the evaluation model, keyed by
        # tensor name (placeholders for tensors whose shape is known).
        self._input_values: dict[str, np.ndarray] = {}
        # Nodes found to be constant, keyed by node name.
        self._const_nodes: dict[str, Node] = {}
        # Outputs of constant nodes, keyed by tensor name.
        self._const_outputs: dict[str, OutputTensor] = {}
        # Pre-computed values, keyed by tensor name.
        self._new_const_values: dict[str, np.ndarray] = {}
        # Set by evaluate() when at least one evaluation batch failed.
        self._evaluation_incomplete = False

    def _inputs_are_const(self, node: Node) -> bool:
        """
        Return True if every input of ``node`` is an initializer or an output
        already found to be constant.
        Note: Constant operator nodes do not have any inputs, so those are
        constant.
        """
        return all(
            in_tensor.check_initializer()
            or in_tensor.get_name() in self._const_outputs
            for in_tensor in node.get_inputs()
        )

    def _iter_new_const_readers(self, node: Node):
        """
        Register the outputs of ``node`` in _const_outputs and yield the nodes
        reading them.

        Lazy generator: an output is registered only when iteration reaches
        it, so _add_const_node() visits the graph in the same depth-first
        order as a recursive traversal would. Outputs already registered are
        skipped; readers for which require_node() raises MatcherError are
        ignored.
        """
        for out_tensor in node.get_outputs():
            out_name = out_tensor.get_name()
            if out_name in self._const_outputs:
                continue  # already added -> nothing to do
            self._const_outputs[out_name] = out_tensor
            for reader in out_tensor.get_readers():
                try:
                    reader_node = reader.require_node()
                except MatcherError:
                    continue
                yield reader_node

    def _add_const_node(self, node: Node) -> None:
        """
        Add a node to _const_nodes if it is not yet in, mark all its outputs
        as constant and transitively add every reader whose inputs thereby
        all become constant.

        An explicit stack of reader iterators replaces recursion, so long
        chains of constant nodes cannot exceed Python's recursion limit.
        """
        if node.get_name() in self._const_nodes:
            return  # already added -> nothing to do
        self._const_nodes[node.get_name()] = node
        stack = [self._iter_new_const_readers(node)]
        while stack:
            reader = next(stack[-1], None)
            if reader is None:
                stack.pop()  # all outputs of this node processed
                continue
            if reader.get_name() in self._const_nodes:
                continue
            if not self._inputs_are_const(reader):
                continue  # some input is not const -> reader is not const
            self._const_nodes[reader.get_name()] = reader
            stack.append(self._iter_new_const_readers(reader))

    def _check_const_node(self, node: Node) -> None:
        """
        Check if a node is constant (all inputs are initializers or known
        constant outputs) and add it if it is constant.
        """
        if self._inputs_are_const(node):
            self._add_const_node(node)

    def _find_const_nodes(self) -> None:
        """
        Find all nodes that are constant and add those.
        """
        for node in self._nodes:
            self._check_const_node(node)

    @staticmethod
    def _placeholder_value(shape, dtype_raw) -> np.ndarray:
        """
        Return an all-zero array with the given shape and ONNX element type
        (converted with onnx.helper.tensor_dtype_to_np_dtype).

        Used as input value for nodes whose result depends only on the input
        shape. The array is created directly in the target dtype, so no
        float64 temporary of the full shape is allocated.
        """
        return np.zeros(shape, dtype=onnx.helper.tensor_dtype_to_np_dtype(dtype_raw))

    def _find_known_shapes(self) -> None:
        """
        Find nodes with shape operator that have a known shape on the input.
        Add those nodes to _const_nodes and a value for the data input to
        _input_values.
        """
        for node in self._nodes:
            try:
                node.require(opType.Shape)
                data_tensor = node("data").require_tensor()
                dtype = data_tensor.get_dtype_raw()
                shape = data_tensor.get_shape()
            except MatcherError:
                continue
            self._add_const_node(node)
            self._input_values[data_tensor.get_name()] = self._placeholder_value(
                shape, dtype
            )

    def _find_softmax_nodes(self) -> None:
        """
        Find Softmax nodes whose input has dimension 1 along their axis.

        Softmax over a single element yields 1 for every finite input, so
        such a node is added to _const_nodes and an all-zero value with the
        input's shape and dtype is registered in _input_values.
        The axis defaults to -1 when the attribute is absent.
        NOTE(review): -1 is the ONNX default from Softmax-13 on; older
        opsets default to axis=1 and flatten the trailing dimensions — confirm
        the models handled here use opset >= 13.
        """
        for node in self._nodes:
            try:
                node.require(opType.Softmax)
                axis_attr = node.get_attribute_value("axis")
                axis_val = -1 if axis_attr is None else int(axis_attr)
                data_tensor = node("input").require_tensor()
                input_shape = data_tensor.get_shape()
                dtype = data_tensor.get_dtype_raw()
            except MatcherError:
                continue

            # Skip (instead of crashing with IndexError) if the axis does not
            # fit the known rank.
            if not -len(input_shape) <= axis_val < len(input_shape):
                continue
            if input_shape[axis_val] == 1:
                self._add_const_node(node)
                self._input_values[data_tensor.get_name()] = (
                    self._placeholder_value(input_shape, dtype)
                )

    def collect(self) -> None:
        """
        Collect constant nodes and known shapes.
        """
        self._find_known_shapes()
        self._find_softmax_nodes()
        self._find_const_nodes()

    def count(self) -> int:
        """
        Return number of nodes currently in _const_nodes (candidates after
        collect(); nodes to be removed after exclude_nodes()/update()).
        """
        return len(self._const_nodes)

    def _remove_dead_subgraphs(self) -> None:
        """
        Remove all nodes that don't contribute to any graph output (plus
        unused graph inputs/initializers) from the wrapped model via
        remove_dead_subgraphs_from_model(), then call
        ModelDict.remove_unused_ini_nodes().
        """
        dead_count = remove_dead_subgraphs_from_model(self._model_dict._model)

        if dead_count > 0:
            print(
                f"Dead code elimination: Removing {dead_count} nodes that don't contribute to outputs"
            )

        # Clean up unused initializers
        self._model_dict.remove_unused_ini_nodes()

    def _is_excluded_node(self, node: Node) -> bool:
        """
        Return True if ``node`` must stay in the model although it is constant:
        - Nodes producing a graph output, because graph outputs apparently
          cannot be provided from an initializer (although technically allowed
          by ONNX spec).
        - DequantizeLinear nodes, because the fusion patterns depend on those
          nodes to be present even if those read an initializer.
        - Transpose nodes, because those are needed by several patterns.
        """
        return any(
            outp.check_graph_output() for outp in node.get_outputs()
        ) or node.check(opType.DequantizeLinear | opType.Transpose)

    def _feeds_const_node(self, node: Node) -> bool:
        """
        Return True if any output of ``node`` is read by a node in _const_nodes.
        """
        for outp in node.get_outputs():
            for reader in outp.get_readers():
                if reader.check_node() and (
                    reader.require_node().get_name() in self._const_nodes
                ):
                    return True
        return False

    def _select_nodes_to_evaluate(self) -> dict[str, Node]:
        """
        Return the constant nodes that go into the evaluation model.

        Excluded nodes (see _is_excluded_node) are left out unless one of
        their outputs is read by another constant node, because that reader
        needs the value during evaluation.
        """
        return {
            name: node
            for name, node in self._const_nodes.items()
            if not self._is_excluded_node(node) or self._feeds_const_node(node)
        }

    def _remove_model_quietly(self, model_path: Path) -> None:
        """
        Best-effort deletion of a model file written by evaluate().

        Passes fusion_args.external_data to remove_model(). A failure is
        reported but does not abort constant evaluation.
        """
        try:
            remove_model(str(model_path), self.fusion_args.external_data)
        except Exception as e:
            print(f"Warning: failed to remove {model_path}: {e}")

    def _evaluate_batch(
        self,
        batch_start: int,
        batch_outputs: list[str],
        output_infos: dict[str, onnx.ValueInfoProto],
        const_eval_model_path: Path,
        const_temp_model_path: Path,
        session_options: ort.SessionOptions,
    ) -> None:
        """
        Evaluate one batch of outputs of the evaluation model and store the
        results in _new_const_values.

        The batch model is the evaluation model re-loaded from
        ``const_eval_model_path`` (without external data), with its graph
        outputs restricted to ``batch_outputs`` and the nodes not contributing
        to them removed. It is saved to ``const_temp_model_path`` and run with
        ONNX Runtime on CPU. If running fails, the error is reported, the
        batch is skipped and _evaluation_incomplete is set. The temporary
        model is removed afterwards unless fusion_args.debug is set.
        """
        batch_end = batch_start + len(batch_outputs)
        print(
            f"Processing batch: outputs {batch_start+1} to {batch_end} ({len(batch_outputs)} outputs)"
        )

        # Extract minimal subgraph for this batch
        print("Extracting batch model...")
        batch_model = onnx.load_model(const_eval_model_path, load_external_data=False)

        # Set only batch outputs
        del batch_model.graph.output[:]
        for output_name in batch_outputs:
            if output_name in output_infos:
                batch_model.graph.output.append(output_infos[output_name])

        dead_count = remove_dead_subgraphs_from_model(batch_model)
        if dead_count > 0:
            print(f"  Removed {dead_count} dead nodes from batch model")
        save_model(batch_model, str(const_temp_model_path), False)

        print(
            f"  Batch model has {len(batch_model.graph.node)} nodes, {len(batch_model.graph.output)} outputs"
        )

        try:
            ort_session = ort.InferenceSession(
                str(const_temp_model_path),
                session_options,
                providers=["CPUExecutionProvider"],
            )

            inputs = {
                inp.name: self._input_values[inp.name]
                for inp in ort_session.get_inputs()
            }
            outputs = [outp.name for outp in ort_session.get_outputs()]

            print(f"  Running ORT inference on {len(outputs)} outputs...")
            out_data = ort_session.run(outputs, inputs)
            self._new_const_values.update(zip(outputs, out_data))

            # Release ORT resources before the temporary model is removed.
            del out_data
            del ort_session

            print(
                f"Batch processed successfully ({len(self._new_const_values)} total outputs)"
            )
        except Exception as e:
            # Best effort: skip this batch; update() keeps the nodes whose
            # values are missing.
            self._evaluation_incomplete = True
            print(f"Batch failed with error: {e}")
            print(f"Skipping batch {batch_start+1}-{batch_end}")
        finally:
            # Clean up temp model after each batch, also after a failure.
            if not self.fusion_args.debug:
                self._remove_model_quietly(const_temp_model_path)

    def evaluate(self) -> None:
        """
        Evaluate all nodes that have a constant value.
        Store newly found constant values in _new_const_values.

        The nodes selected by _select_nodes_to_evaluate() are extracted into a
        separate model saved as "<model_name>_const_eval.onnx" in
        fusion_args.out_dir_path. Its graph outputs are evaluated with ONNX
        Runtime (all graph optimizations disabled) in batches of
        fusion_args.shape_inference_outputs outputs; a missing or
        non-positive batch size evaluates all outputs in a single batch.
        Failed batches are skipped, so _new_const_values may be incomplete.
        Model files are removed afterwards unless fusion_args.debug is set.
        """
        self._new_const_values = {}
        self._evaluation_incomplete = False

        # early exit if no nodes to evaluate
        if not self._const_nodes:
            return

        filtered_const_nodes = self._select_nodes_to_evaluate()
        if not filtered_const_nodes:
            return

        excluded_count = len(self._const_nodes) - len(filtered_const_nodes)
        print(
            f"Excluded {excluded_count} nodes (DequantizeLinear/Transpose/graph outputs) from constant evaluation"
        )

        const_eval_model = self._model_dict.extract_filtered_model(
            filtered_const_nodes.keys()
        )
        const_eval_proto = const_eval_model._model

        out_dir = Path(self.fusion_args.out_dir_path)
        model_name = self.fusion_args.model_name
        const_eval_model_path = out_dir / (model_name + "_const_eval.onnx")
        const_temp_model_path = out_dir / (model_name + "_const_eval_temp.onnx")

        # Save the full model; every batch model is re-loaded from this file.
        save_model(
            const_eval_proto,
            str(const_eval_model_path),
            self.fusion_args.external_data,
        )

        so = ort.SessionOptions()
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

        all_outputs = [output.name for output in const_eval_proto.graph.output]
        output_infos = {output.name: output for output in const_eval_proto.graph.output}
        print(f"Constant evaluation: Processing {len(all_outputs)} outputs in batches")

        batch_size = self.fusion_args.shape_inference_outputs
        if not batch_size or batch_size < 1:
            # A missing, zero or negative batch size would make range() raise
            # or yield no batches at all; evaluate everything at once instead.
            batch_size = max(len(all_outputs), 1)

        try:
            for batch_start in range(0, len(all_outputs), batch_size):
                self._evaluate_batch(
                    batch_start,
                    all_outputs[batch_start : batch_start + batch_size],
                    output_infos,
                    const_eval_model_path,
                    const_temp_model_path,
                    so,
                )
        finally:
            # Clean up full model after all batches (also if a batch raised).
            if not self.fusion_args.debug:
                self._remove_model_quietly(const_eval_model_path)

        print(
            f"Constant evaluation completed: {len(self._new_const_values)} values computed"
        )

    def exclude_nodes(self) -> None:
        """
        Exclude certain nodes from this optimization (see _is_excluded_node):
        nodes producing a graph output, DequantizeLinear nodes and Transpose
        nodes.

        Note: This is called after evaluate(), so it removes nodes from
        _const_nodes to prevent them from being removed in update(), and
        drops their outputs from _new_const_values. The actual exclusion from
        the evaluation model happens in evaluate() before creating the model.
        """
        exclude_node_names = [
            node_name
            for node_name, node in self._const_nodes.items()
            if self._is_excluded_node(node)
        ]
        for node_name in exclude_node_names:
            # Remove from _const_nodes so they won't be removed in update()
            node = self._const_nodes.pop(node_name)
            # Remove outputs from _new_const_values (if they were computed)
            for outp in node.get_outputs():
                self._new_const_values.pop(outp.get_name(), None)

    def _output_still_read(self, outp: OutputTensor) -> bool:
        """
        Return True if ``outp`` is a graph output or is read by a node that is
        not in _const_nodes (i.e. a node that is not removed by update()).
        """
        if outp.check_graph_output():
            return True
        for reader in outp.get_readers():
            if reader.check_node() and (
                reader.require_node().get_name() not in self._const_nodes
            ):
                return True
        return False

    def _retain_unevaluated_nodes(self) -> None:
        """
        Drop nodes from _const_nodes whose still-read outputs have no
        pre-computed value (e.g. because their evaluation batch failed).

        Removing such a node would leave its readers with a dangling input.
        A retained node keeps producing all its outputs, so those are dropped
        from _new_const_values. Because a retained node stays in the model,
        its constant inputs become required too; the check is therefore
        repeated until nothing changes.
        """
        retained = 0
        changed = True
        while changed:
            changed = False
            for node_name, node in list(self._const_nodes.items()):
                missing = any(
                    outp.get_name()  # skip omitted optional outputs
                    and outp.get_name() not in self._new_const_values
                    and self._output_still_read(outp)
                    for outp in node.get_outputs()
                )
                if not missing:
                    continue
                for outp in node.get_outputs():
                    self._new_const_values.pop(outp.get_name(), None)
                del self._const_nodes[node_name]
                retained += 1
                changed = True
        if retained:
            print(
                f"Keeping {retained} constant nodes whose values could not be pre-computed"
            )

    def update(self) -> None:
        """
        Update model to use pre-computed constants.

        Remove nodes that compute constants and are not required any more,
        add initializers providing the pre-computed constants, then call
        ModelDict.remove_unconnected() and remove_unused_ini_nodes() as a
        cleanup pass. If evaluate() reported a failed batch, nodes whose
        values are missing are kept first (see _retain_unevaluated_nodes).
        """
        if self._evaluation_incomplete:
            self._retain_unevaluated_nodes()
        # remove nodes that compute known constant values
        for const_node_name in self._const_nodes:
            self._model_dict.remove_node(const_node_name)
        # add initializers for constants
        for ini_name, value in self._new_const_values.items():
            self._model_dict.add_initializer(
                onnx.numpy_helper.from_array(value, ini_name)
            )
        # Cleanup pass - remove unconnected nodes and initializers
        self._model_dict.remove_unconnected()
        self._model_dict.remove_unused_ini_nodes()

def remove_constants(model):
    """
    Replace the constant-producing nodes reported by construct_constant_dict()
    with initializers holding the same values.

    Each entry maps a tensor name to (node name, value, dtype). The named node
    is removed from ``model.graph.node``, and an initializer with that tensor
    name is appended to ``model.graph.initializer``. ``model`` is modified in
    place, and a summary line is printed.
    """
    constants: dict[str, NamedArray] = construct_constant_dict(model)
    nodes_by_name: dict[str, onnx.NodeProto] = construct_nodes_dict(model)
    for tensor_name, (node_name, value, dtype) in constants.items():
        model.graph.node.remove(nodes_by_name[node_name])
        initializer = onnx.numpy_helper.from_array(
            np.array(value, dtype), tensor_name
        )
        model.graph.initializer.append(initializer)
    replaced_count = len(constants)
    print(
        f"Constant evaluation replaced {replaced_count} constant nodes with initializers."
    )


def replace_const(
    model: onnx.ModelProto,
    onnx_ops: OnnxOpsWrapper,
    fusionArgs: FusionArguments,
    logger: Logger,
) -> None:
    """
    Pre-compute constant sub-computations of ``model`` and store the results
    as initializers.

    Nodes depending only on constants, initializers and known shapes are
    found and evaluated by ConstEval. Those not excluded are replaced by
    initializers holding the computed values. The passed model is updated
    in place.
    """
    evaluator = ConstEval(model, onnx_ops, fusionArgs)
    evaluator.collect()
    candidate_count = evaluator.count()
    logger.info(f"Constant evaluation found {candidate_count} candidate nodes.")
    # Order matters: exclude_nodes() prunes results produced by evaluate(),
    # and update() applies what is left.
    for step in (evaluator.evaluate, evaluator.exclude_nodes, evaluator.update):
        step()
    logger.info(f"Constant evaluation pre-computed {evaluator.count()} nodes.")
