"""
Graph operations for managing computational graphs.

This module provides the GraphOps class which represents a computational graph
with operations and tensors, supporting:
- Tensor and operation management
- Topological sorting for execution order
- Tensor lifetime computation
- Tensor usage analysis
"""

from collections import Counter, deque
from typing import Any, Dict, List, Tuple

from graph.tensor_types import Operation, Tensor
from graph.utilities import logger


class GraphOps:
    """Manages a computational graph with operations and tensors.

    This class provides operations for building and analyzing computational graphs,
    including tensor lifetime computation, topological sorting for execution order,
    and tensor usage analysis.

    Attributes:
        operations: List of operations in the graph, in insertion order
        tensors: Dictionary mapping tensor IDs to Tensor objects
        execution_order: Execution order of the operations. Either the list given
            at construction (stored as-is), or a cached topological sort that
            ``get_execution_order`` computes lazily; ``add_operation`` clears
            that cache so it never goes stale
        operation_ids: Dictionary mapping operation IDs to their indices in ``operations``
        with_given_execution_order: Whether execution order was provided at construction

    Examples:
        >>> from graph.tensor_types import Tensor, Operation, XrtId
        >>> graph = GraphOps()
        >>> graph.tensors
        {}
        >>> len(graph.operations)
        0

        >>> # Add tensors
        >>> input_tensor = Tensor(id="input", shape=(10, 10), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
        >>> output_tensor = Tensor(id="output", shape=(10, 10), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
        >>> graph.add_tensor(input_tensor)
        >>> graph.add_tensor(output_tensor)
        >>> len(graph.tensors)
        2

        >>> # Add operation
        >>> op = Operation(id="conv1", type="Conv", inputs=["input"], outputs=["output"])
        >>> idx = graph.add_operation(op)
        >>> idx
        0
        >>> len(graph.operations)
        1
    """

    def __init__(self, execution_order: List[Operation] | None = None):
        """Create an empty graph, optionally with a fixed execution order.

        Args:
            execution_order: If given, this list is used verbatim (not copied) as
                the execution order and topological sorting is never performed.
        """
        self.operations: List[Operation] = []
        self.tensors: Dict[str, Tensor] = {}
        self.execution_order: List[Operation] = []
        self.operation_ids: Dict[str, int] = {}
        self.with_given_execution_order = False

        if execution_order is not None:
            self.execution_order = execution_order
            self.with_given_execution_order = True

    def add_tensor(self, tensor: Tensor):
        """Add a tensor to the graph.

        A tensor with the same ID that is already registered is replaced.

        Args:
            tensor: Tensor to add to the graph

        Examples:
            >>> from graph.tensor_types import Tensor, XrtId
            >>> graph = GraphOps()
            >>> tensor = Tensor(id="t1", shape=(10, 10), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> graph.add_tensor(tensor)
            >>> "t1" in graph.tensors
            True
            >>> graph.get_tensor("t1").id
            't1'
        """
        self.tensors[tensor.id] = tensor

    def update_tensor(self, tensor_id: str, **kwargs: Any):
        """Update attributes of an existing tensor.

        The stored tensor is replaced by the object returned from
        ``Tensor.update(**kwargs)``.

        Args:
            tensor_id: ID of the tensor to update
            **kwargs: Attributes to update

        Raises:
            AssertionError: If tensor ID not found

        Examples:
            >>> from graph.tensor_types import Tensor, XrtId
            >>> graph = GraphOps()
            >>> tensor = Tensor(id="t1", shape=(10, 10), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> graph.add_tensor(tensor)
            >>> graph.update_tensor("t1", bin=XrtId.IFM)
            >>> graph.get_tensor("t1").bin
            <XrtId.IFM: 1>
        """
        assert tensor_id in self.tensors, f"Tensor ID {tensor_id} not found"
        self.tensors[tensor_id] = self.tensors[tensor_id].update(**kwargs)

    def add_operation(self, operation: Operation) -> int:
        """Add an operation to the graph.

        Adding an operation invalidates any cached topological execution order,
        so the next ``get_execution_order()`` call includes the new operation.
        An execution order supplied at construction is left untouched.

        Args:
            operation: Operation to add to the graph

        Returns:
            Index of the added operation

        Raises:
            AssertionError: If operation ID already exists

        Examples:
            >>> from graph.tensor_types import Operation
            >>> graph = GraphOps()
            >>> op1 = Operation(id="conv1", type="Conv", inputs=["in"], outputs=["out"])
            >>> idx = graph.add_operation(op1)
            >>> idx
            0
            >>> op2 = Operation(id="add1", type="Add", inputs=["a", "b"], outputs=["c"])
            >>> idx = graph.add_operation(op2)
            >>> idx
            1
            >>> len(graph.operations)
            2
        """
        assert operation.id not in self.operation_ids, "Operation ID already exists"
        self.operations.append(operation)
        index = len(self.operations) - 1
        self.operation_ids[operation.id] = index

        if not self.with_given_execution_order:
            # A previously computed topological order does not contain this
            # operation; drop it so get_execution_order() recomputes.
            self.execution_order = []

        return index

    def get_operation(self, op_id: int) -> Operation:
        """Get an operation by its index.

        Args:
            op_id: Index of the operation (0-based; negative indices are rejected)

        Returns:
            Operation at the specified index

        Raises:
            AssertionError: If operation ID out of range

        Examples:
            >>> from graph.tensor_types import Operation
            >>> graph = GraphOps()
            >>> op = Operation(id="conv1", type="Conv", inputs=["in"], outputs=["out"])
            >>> graph.add_operation(op)
            0
            >>> retrieved = graph.get_operation(0)
            >>> retrieved.id
            'conv1'
            >>> retrieved.type
            'Conv'
        """
        # Reject negative indices explicitly: Python list indexing would
        # otherwise silently wrap them around to the end of the list.
        assert 0 <= op_id < len(self.operations), "Operation ID out of range"
        return self.operations[op_id]

    def get_operation_by_name(self, op_name: str) -> Operation:
        """Get an operation by its name.

        Args:
            op_name: Name/ID of the operation

        Returns:
            Operation with the specified name

        Raises:
            AssertionError: If operation name not found

        Examples:
            >>> from graph.tensor_types import Operation
            >>> graph = GraphOps()
            >>> op = Operation(id="conv1", type="Conv", inputs=["in"], outputs=["out"])
            >>> graph.add_operation(op)
            0
            >>> retrieved = graph.get_operation_by_name("conv1")
            >>> retrieved.id
            'conv1'
            >>> retrieved.type
            'Conv'
        """
        op_id = self.operation_ids.get(op_name)
        assert op_id is not None, "Operation name not found"
        return self.operations[op_id]

    def get_tensor(self, tensor_id: str) -> Tensor:
        """Get a tensor by its ID.

        Args:
            tensor_id: ID of the tensor

        Returns:
            Tensor with the specified ID

        Raises:
            AssertionError: If tensor ID not found

        Examples:
            >>> from graph.tensor_types import Tensor, XrtId
            >>> graph = GraphOps()
            >>> tensor = Tensor(id="t1", shape=(10, 10), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT)
            >>> graph.add_tensor(tensor)
            >>> retrieved = graph.get_tensor("t1")
            >>> retrieved.id
            't1'
            >>> retrieved.shape
            (10, 10)
        """
        assert tensor_id in self.tensors, "Tensor ID not found"
        return self.tensors[tensor_id]

    def _compute_tensor_lifetimes(self) -> Dict[str, Tuple[int, int]]:
        """Compute when each tensor is first created and last used.

        A tensor's lifetime spans from the first execution step at which it
        is touched (as an input or an output) to the last such step. Writes
        extend the lifetime just like reads do, so a tensor that is written
        again at a later step stays live until that write.

        Returns:
            Dictionary mapping tensor IDs to (start_step, end_step) tuples,
            where steps are indices into ``get_execution_order()``

        Examples:
            >>> from graph.tensor_types import Tensor, Operation, XrtId
            >>> graph = GraphOps()
            >>> # Add tensors
            >>> graph.add_tensor(Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT))
            >>> graph.add_tensor(Tensor(id="t2", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT))
            >>> graph.add_tensor(Tensor(id="t3", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT))
            >>> # Add operations: t1 -> op1 -> t2 -> op2 -> t3
            >>> graph.add_operation(Operation(id="op1", type="Conv", inputs=["t1"], outputs=["t2"]))
            0
            >>> graph.add_operation(Operation(id="op2", type="Add", inputs=["t2"], outputs=["t3"]))
            1
            >>> lifetimes = graph._compute_tensor_lifetimes()
            >>> lifetimes["t1"]
            (0, 0)
            >>> lifetimes["t2"]
            (0, 1)
            >>> lifetimes["t3"]
            (1, 1)

            >>> # A tensor written by two operations lives until the last write
            >>> g = GraphOps()
            >>> g.add_operation(Operation(id="w1", type="Conv", inputs=["a"], outputs=["buf"]))
            0
            >>> g.add_operation(Operation(id="w2", type="Conv", inputs=["b"], outputs=["buf"]))
            1
            >>> g._compute_tensor_lifetimes()["buf"]
            (1, 1)[0] if False else g._compute_tensor_lifetimes()["buf"]
            (0, 1)
        """
        first_seen: Dict[str, int] = {}
        last_seen: Dict[str, int] = {}

        # Build execution order based on dependencies
        execution_order = self.get_execution_order()

        for step, op in enumerate(execution_order):
            # Outputs are visited before inputs so that dictionary insertion
            # order matches "created, then used" within a step.
            for tensor_id in (*op.outputs, *op.inputs):
                first_seen.setdefault(tensor_id, step)
                # Steps increase monotonically, so the latest touch is the end.
                last_seen[tensor_id] = step

        lifetimes = {tid: (start, last_seen[tid]) for tid, start in first_seen.items()}

        logger.debug("\n== Tensor lifetime information ==\n")
        for a, b, c in sorted((v[0], v[1], k) for k, v in lifetimes.items()):
            logger.debug("tensor %s => lifetime starts at layer %s, and ends at %s", c, a, b)

        return lifetimes

    def _topological_sort(self) -> None:
        """Sort operations in topological order based on dependencies.

        Uses Kahn's algorithm to perform topological sort. Updates the
        execution_order attribute with the sorted operations.

        An operation depends on the producer of each of its input tensors.
        Inputs with no producer in the graph (graph inputs) add no dependency,
        and an operation consuming its own output does not depend on itself.
        If several operations produce the same tensor, the one added last is
        treated as its producer. Among operations that are ready at the same
        time, the one added earlier runs first (FIFO queue).

        Raises:
            AssertionError: If graph has at least one cycle

        Examples:
            >>> from graph.tensor_types import Tensor, Operation, XrtId
            >>> graph = GraphOps()
            >>> # Add tensors
            >>> graph.add_tensor(Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT))
            >>> graph.add_tensor(Tensor(id="t2", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT))
            >>> graph.add_tensor(Tensor(id="t3", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT))
            >>> # Add operations in arbitrary order
            >>> graph.add_operation(Operation(id="op2", type="Add", inputs=["t2"], outputs=["t3"]))
            0
            >>> graph.add_operation(Operation(id="op1", type="Conv", inputs=["t1"], outputs=["t2"]))
            1
            >>> graph._topological_sort()
            >>> # Execution order should be op1 -> op2
            >>> graph.execution_order[0].id
            'op1'
            >>> graph.execution_order[1].id
            'op2'
        """
        in_degree: Dict[str, int] = {op.id: 0 for op in self.operations}
        successors: Dict[str, List[str]] = {op.id: [] for op in self.operations}
        op_map = {op.id: op for op in self.operations}

        # Map each tensor to the operation that produces it: O(total outputs)
        tensor_to_producer: Dict[str, str] = {}
        for op in self.operations:
            for output_tensor in op.outputs:
                tensor_to_producer[output_tensor] = op.id

        # Build dependency edges producer -> consumer. Repeated inputs add
        # repeated edges; they are decremented symmetrically below.
        for op in self.operations:
            for input_tensor in op.inputs:
                producer_op_id = tensor_to_producer.get(input_tensor)
                if producer_op_id is not None and producer_op_id != op.id:
                    successors[producer_op_id].append(op.id)
                    in_degree[op.id] += 1

        # Kahn's algorithm; deque gives O(1) pops from the front.
        ready = deque(op_id for op_id, degree in in_degree.items() if degree == 0)
        result: List[Operation] = []

        while ready:
            current = ready.popleft()
            result.append(op_map[current])

            for neighbor in successors[current]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    ready.append(neighbor)

        # Operations left with a non-zero in-degree lie on (or behind) a cycle.
        assert len(result) == len(self.operations), "Graph has at least one cycle"
        self.execution_order = result

    def _compute_tensor_usage(self) -> Dict[str, int]:
        """Compute the number of times each tensor is used as input or output.

        Returns:
            Dictionary mapping tensor IDs to usage counts. The returned object
            is a ``collections.Counter``, so unknown IDs read as 0.

        Examples:
            >>> from graph.tensor_types import Operation
            >>> graph = GraphOps()
            >>> graph.add_operation(Operation(id="op1", type="Conv", inputs=["t1"], outputs=["t2"]))
            0
            >>> graph.add_operation(Operation(id="op2", type="Add", inputs=["t2", "t3"], outputs=["t4"]))
            1
            >>> usage = graph._compute_tensor_usage()
            >>> usage["t1"]
            1
            >>> usage["t2"]
            2
            >>> usage["t3"]
            1
            >>> usage["t4"]
            1
        """
        degree: Dict[str, int] = Counter()

        for op in self.operations:
            degree.update(op.inputs)
            degree.update(op.outputs)

        return degree

    def get_execution_order(self, recompute: bool = False) -> List[Operation]:
        """Get the execution order of operations.

        If execution order was provided at construction, returns that order.
        Otherwise, computes a topological sort when no order is cached (on the
        first call, or after ``add_operation`` cleared the cache) or when
        recompute=True.

        Args:
            recompute: Whether to recompute the execution order

        Returns:
            List of operations in execution order

        Raises:
            AssertionError: If the order must be computed and the graph has a cycle

        Examples:
            >>> from graph.tensor_types import Tensor, Operation, XrtId
            >>> graph = GraphOps()
            >>> graph.add_tensor(Tensor(id="t1", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT))
            >>> graph.add_tensor(Tensor(id="t2", shape=(10,), dtype="TensorProto.FLOAT", bin=XrtId.DEFAULT))
            >>> graph.add_operation(Operation(id="op1", type="Conv", inputs=["t1"], outputs=["t2"]))
            0
            >>> order = graph.get_execution_order()
            >>> len(order)
            1
            >>> order[0].id
            'op1'
            >>> # Operations added later are picked up without recompute=True
            >>> graph.add_operation(Operation(id="op2", type="Relu", inputs=["t2"], outputs=["t3"]))
            1
            >>> [op.id for op in graph.get_execution_order()]
            ['op1', 'op2']
        """
        if self.with_given_execution_order:
            return self.execution_order

        if not self.execution_order or recompute:
            self._topological_sort()

        return self.execution_order
