"""
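This module determines which operators of an ONNX graph can be placed entirely in L2 memory.
It schedules tensor allocations (recording tensors that must be spilled) and writes the
resulting per-operator tiling information to JSON.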
"""

import argparse
import json
import logging
import os
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import onnx

# pylint: disable=import-error,redefined-outer-name,no-name-in-module
from graph.allocation_types import (
    AllocationConfig,
    AllocationResult,
    AllocationStrategy,
    MemoryBlock,
    MemoryConfig,
    TensorAllocation,
)
from graph.common import remove_nodes_by_op_type_chained
from graph.dag import IGNORED_OPS, construct_op_dag, scoped_id, scoped_tensor
from graph.graph_ops import GraphOps
from graph.L2_fusion_tiling import Tiler, TilingSchema
from graph.memory_configs import COL1_MEMORY_CONFIG, END_MEMORY_CONFIG, START_END_MEMORY_CONFIG
from graph.tensor_types import Operation, Tensor, TensorLifetime, TensorLocation
from graph.utilities import logger
from graph.tensor_memory_allocator import TensorMemoryAllocator

# Absolute directory containing this file; build_parser() uses it as the base for default CLI paths.
CURRDIR = os.path.dirname(os.path.abspath(__file__))


class GraphMemoryScheduler(GraphOps):
    """Schedules memory for the tensors of an operator DAG.

    Tensors are placed, in execution order, into the memory blocks managed by a
    ``TensorMemoryAllocator``. Tensors that cannot be placed are recorded as
    spilled.
    """

    def __init__(
        self,
        alloc_config: AllocationConfig,
        memory_config: List[MemoryBlock],
        execution_order: List[Operation] | None = None,
    ):
        """Create a scheduler.

        Args:
            alloc_config: Allocation configuration forwarded to the allocator.
            memory_config: Memory blocks the allocator may place tensors in.
            execution_order: Optional operator order forwarded to ``GraphOps``.
        """
        super().__init__(execution_order=execution_order)
        self.allocator = TensorMemoryAllocator(alloc_config, memory_config)

    def schedule_memory(
        self,
        enable_spilling: bool = False,
        is_nonwaic: bool = False,
    ) -> Dict[int, List[Tuple[AllocationResult, TensorAllocation]]]:
        """Schedule memory allocation for the entire graph.

        Args:
            enable_spilling: If False, dry-run passes first determine which
                tensors would have to be spilled. Those tensors are then kept out
                of memory from the start, and the final pass raises if it would
                still need to spill. If True, spilling is done on demand during
                the final pass.
            is_nonwaic: If True, tensors are ordered by lifetime start only.
                Otherwise they are ordered by (start, end, whether the tensor is
                an output of the operator at its start step).

        Returns:
            Mapping from execution step to the allocation events recorded at
            that step, as ``(AllocationResult, TensorAllocation)`` pairs.

        Raises:
            RuntimeError: If spilling is disabled and the final pass cannot
                place a tensor, or if the spill dry run stops making progress.
        """
        lifetimes = self._compute_tensor_lifetimes()
        tensor_usage = self._compute_tensor_usage()
        allocation_results: Dict[
            int, List[Tuple[AllocationResult, TensorAllocation]]
        ] = defaultdict(list)

        # set tensor sizes to zero for tensors that are streamed or reside in control block
        for tensor_id, tensor in self.tensors.items():
            if (
                (tensor_usage[tensor_id] == 1 or tensor.is_constant)
                and not tensor.is_model_io
            ):
                # object.__setattr__ bypasses normal attribute assignment;
                # presumably Tensor is a frozen dataclass.
                object.__setattr__(tensor, "size", 0)

        for tensor in self.tensors.values():
            logger.debug(
                "%s Tensor %s: shape=%s, size=%s, bin=%s, usage=%s",
                ("Skipping" if tensor.size == 0 else "Allocating"),
                tensor.id, tensor.shape, tensor.size, tensor.bin, tensor_usage[tensor.id]
            )

        # tensors with lifetimes
        to_allocate = [
            TensorAllocation(
                self.tensors[tensor_id],
                TensorLifetime(start, end),
                TensorLocation.UNKNOWN,
            )
            for tensor_id, (start, end) in lifetimes.items()
        ]
        execution_order = self.get_execution_order()

        # Note: The sorting should prioritize inputs before outputs for noop-like ops which require
        # output memory addresses to be aliased to input addresses. For L2, for now, we relax the
        # constraint to prevent spilling in YoloV3. This should be fixed later.
        if is_nonwaic:
            to_allocate.sort(key=lambda t: (t.range.start, ))
        else:
            to_allocate.sort(key=lambda t: (t.range.start, t.range.end, t.tensor.id in execution_order[t.range.start].outputs))

        # we currently don't support fill/spill from memory at will
        # so we will traverse the graph, and find the tensors that will be spilled
        # during the allocation, we won't allocate them to begin with
        spilled_tensors = (
            set()
            if enable_spilling
            else self._find_spilled_tensors(to_allocate, execution_order)
        )
        logger.debug("Spilled tensors %s", spilled_tensors)

        # Final pass: allocate every tensor in order and record the events.
        for allocation in to_allocate:
            start, tensor_id = allocation.range.start, allocation.tensor.id
            current_op = execution_order[start]
            logger.debug(
                "\n ======= Layer %s, Operator %s ======= \nCurrent allocations %s",
                start,
                current_op.id,
                self.allocator.allocations.keys(),
            )

            for dealloc in self.allocator.free_expired_allocations(start):
                allocation_results[start].append(
                    (AllocationResult.DEALLOCATED, dealloc.copy())
                )

            if tensor_id in spilled_tensors:
                self.allocator.mark_allocation_as_spilled(allocation)
                logger.debug(
                    "spilled %s (marked) with size %s, stats %s",
                    tensor_id, allocation.tensor.size, self.allocator.get_memory_usage()
                )
                allocation_results[start].append(
                    (AllocationResult.SPILLED, allocation.copy())
                )
                continue

            if self.allocator.allocate_in_place(allocation, current_op):
                logger.debug(
                    "allocated %s (in place) with size %s, stats %s",
                    tensor_id, allocation.tensor.size, self.allocator.get_memory_usage()
                )
                allocation_results[start].append(
                    (AllocationResult.ALLOCATED_IN_PLACE, allocation.copy())
                )
            elif self.allocator.allocate(allocation):
                logger.debug(
                    "allocated %s with size %s, stats %s",
                    tensor_id, allocation.tensor.size, self.allocator.get_memory_usage()
                )
                allocation_results[start].append(
                    (AllocationResult.ALLOCATED, allocation.copy())
                )
            else:
                if not enable_spilling:
                    raise RuntimeError("Must not spill")
                success, spills = self.allocator.allocate_with_spilling(allocation)
                if success:
                    logger.debug(
                        "allocated %s (after spilling) with size %s, stats %s",
                        tensor_id, allocation.tensor.size, self.allocator.get_memory_usage()
                    )
                    for spill in spills:
                        allocation_results[start].append(
                            (AllocationResult.SPILLED, spill.copy())
                        )
                    allocation_results[start].append(
                        (AllocationResult.ALLOCATED_WITH_SPILLING, allocation.copy())
                    )
                else:
                    self.allocator.mark_allocation_as_spilled(allocation)
                    logger.debug("spilled %s to secondary storage", tensor_id)
                    allocation_results[start].append(
                        (AllocationResult.SPILLED, allocation.copy())
                    )

        self.allocator.convert_to_aligned(allocation_results.values())
        return allocation_results

    def _find_spilled_tensors(
        self,
        to_allocate: List[TensorAllocation],
        execution_order: List[Operation],
    ) -> set:
        """Dry-run allocation passes and return the ids of tensors to spill.

        Each pass allocates copies of ``to_allocate`` (so the real allocations
        are untouched), skipping tensors already known to spill, and then resets
        the allocator. Passes repeat until one completes without any spill.

        Args:
            to_allocate: Allocations in the order the final pass will use.
            execution_order: Operators indexed by execution step.

        Returns:
            The set of tensor ids that must not be allocated.

        Raises:
            RuntimeError: If a pass spills but adds no new tensor to the spill
                set. Repeating such a pass would reproduce the same state, so
                without this check the loop would never terminate.
        """
        spilled_tensors = set()
        while True:
            has_spill = False
            known_spills = len(spilled_tensors)
            alloc_copy = [alloc.copy() for alloc in to_allocate]
            logger.debug("\n\n>> Starting iteration\n\n")

            for allocation in alloc_copy:
                start, tensor_id = allocation.range.start, allocation.tensor.id
                logger.debug(
                    "\n ======= Layer %s, Operator %s ======= \nCurrent allocations %s",
                    start,
                    # Index the same order used for sorting and allocation below;
                    # self.execution_order may be unset when the order is derived.
                    execution_order[start].id,
                    self.allocator.allocations.keys(),
                )
                self.allocator.free_expired_allocations(start)

                if tensor_id in spilled_tensors:
                    continue

                if self.allocator.allocate_in_place(allocation, execution_order[start]):
                    logger.debug(
                        "allocated %s (in place) with size %s, stats %s",
                        tensor_id, allocation.tensor.size, self.allocator.get_memory_usage()
                    )
                    continue
                if self.allocator.allocate(allocation):
                    logger.debug(
                        "allocated %s with size %s, stats %s",
                        tensor_id, allocation.tensor.size, self.allocator.get_memory_usage()
                    )
                    continue

                success, spills = self.allocator.allocate_with_spilling(allocation)
                if success:
                    logger.debug(
                        "allocated %s (after spilling) with size %s, stats %s",
                        tensor_id, allocation.tensor.size, self.allocator.get_memory_usage()
                    )
                    spilled_tensors.update(spill.tensor.id for spill in spills)
                else:
                    self.allocator.mark_allocation_as_spilled(allocation)
                    spilled_tensors.add(tensor_id)
                    logger.debug("spilled %s to secondary storage", tensor_id)
                has_spill = True

            # Reset allocator state so the next pass (or the final pass) starts clean.
            self.allocator.__reinit__()  # only supported for L2 allocation
            if not has_spill:
                return spilled_tensors
            if len(spilled_tensors) == known_spills:
                raise RuntimeError(
                    "Spill detection made no progress: a pass required spilling "
                    "but no new tensor was marked as spilled"
                )

    def get_allocation_summary(
        self,
        allocation_results: Dict[int, List[Tuple[AllocationResult, TensorAllocation]]],
    ) -> Dict[str, int]:
        """Summarize allocation events.

        Args:
            allocation_results: Output of ``schedule_memory``.

        Returns:
            The number of steps with events and the count of each
            allocation/spill result kind.
        """
        counts = Counter(
            result for events in allocation_results.values() for result, _ in events
        )
        return {
            "total_layers_allocated": len(allocation_results.keys()),
            "tensors_allocated": counts[AllocationResult.ALLOCATED],
            "tensors_allocated_with_spilling": counts[
                AllocationResult.ALLOCATED_WITH_SPILLING
            ],
            "tensors_allocated_in_place": counts[
                AllocationResult.ALLOCATED_IN_PLACE
            ],
            "tensors_spilled": counts[AllocationResult.SPILLED],
        }


class GraphMemorySchedulerAndTiler(GraphMemoryScheduler):
    """Schedules memory for the DAG and writes per-operator tiling information."""

    def __init__(self, alloc_config: AllocationConfig, memory_config: MemoryConfig):
        """Create a scheduler that also knows the memory configuration for tiling.

        Args:
            alloc_config: Allocation configuration forwarded to the allocator.
            memory_config: Memory configuration. The allocator receives copies
                of ``memory_config.memory``; the original object is kept for
                the tiler.
        """
        super().__init__(alloc_config, [block.copy() for block in memory_config.memory])
        self.memory_config = memory_config

    def write_tiling(
        self,
        model: onnx.ModelProto,
        tiling_json_path: str,
        allocation_results: Dict[int, List[Tuple[AllocationResult, TensorAllocation]]],
    ) -> None:
        """Write tiling information to a JSON file.

        One entry is written per execution step. Entries are keyed by the step
        index as a string and written in ascending step order.

        Args:
            model: ONNX model passed to the tiler.
            tiling_json_path: Destination JSON path (overwritten).
            allocation_results: Output of ``schedule_memory``.

        Raises:
            ValueError: If ``allocation_results`` has no entry for an execution step.
        """
        tiler = Tiler(model, self.memory_config)
        tiling: TilingSchema = {}

        for step, op in enumerate(self.get_execution_order()):
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            # Indexing a defaultdict (as returned by schedule_memory) would
            # otherwise silently tile an empty event list.
            if step not in allocation_results:
                raise ValueError(
                    f"No allocation results for execution step {step} (operator {op.id})"
                )
            tiling[str(step)] = tiler.get_operator_tiling(op, allocation_results[step])

        with open(tiling_json_path, "w", encoding="utf-8") as f:
            json.dump(
                {key: tiling[key].model_dump() for key in sorted(tiling, key=int)},
                f,
                indent=4,
                sort_keys=False,
            )


def get_memory_config_for_model(model_name: str) -> MemoryConfig:
    """Try to find an optimal memory config for the given ONNX model"""
    mapping = {
        "ResNet50_INT8_Model.onnx": COL1_MEMORY_CONFIG,
        "YoloV3_INT8_Model.onnx": START_END_MEMORY_CONFIG,
        "YoloV3_INT8_Model_sub_graph.onnx": END_MEMORY_CONFIG,
        "YoloV3_INT8_Model_cleaned_subgraph.onnx": END_MEMORY_CONFIG
    }
    return mapping.get(model_name, END_MEMORY_CONFIG)


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Default file paths are resolved relative to ``CURRDIR``.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Generate tensor allocation and spilling information for L2 memory."
    )
    parser.add_argument(
        "-m",
        "--model_path",
        default=os.path.join(CURRDIR, "ResNet50_INT8_Model.onnx"),
        help="Path to input ONNX model",
    )
    parser.add_argument(
        "-d",
        "--model_dtype",
        default="TensorProto.INT8",
        help="Data-type for all tensors",
    )
    parser.add_argument(
        "-j",
        "--fusion_json_path",
        default=os.path.join(CURRDIR, "L2_fusion_tiling.json"),
        help="Path to L2 fusion tiling JSON",
    )
    parser.add_argument(
        "--write_allocation_events",
        action="store_true",
        help="Write L2 fusion allocation events",
    )
    parser.add_argument(
        "--allocation_events_path",
        default=os.path.join(CURRDIR, "L2_fusion_allocation_events.json"),
        help="Path to allocation events json",
    )
    parser.add_argument(
        "-v",
        "--verbosity",
        default="INFO",
        # Every choice must name a level constant of the ``logging`` module,
        # since main() resolves it with getattr(logging, args.verbosity).
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Set the logging level",
    )
    return parser


def main(argv: Any = None) -> None:
    """Command-line entry point.

    Runs the full pipeline:
    1. Load the ONNX model with ``IGNORED_OPS`` nodes removed.
    2. Build the operator DAG and schedule tensor memory.
    3. Optionally write the allocation events, then write the tiling JSON.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # define logger verbosity level
    logger.setLevel(getattr(logging, args.verbosity))
    logger.debug("Provided command line arguments %s", vars(args))

    # Load ONNX model
    logger.debug("Loading ONNX model")
    model = remove_nodes_by_op_type_chained(
        onnx_model_path=args.model_path,
        ops_to_remove=IGNORED_OPS,
    )

    # Construct a directed acyclic graph
    logger.debug("Constructing a directed acyclic graph from the ONNX model")
    dag = construct_op_dag(model)

    # Create a graph memory scheduler to schedule memory by traversing the DAG
    scheduler = GraphMemorySchedulerAndTiler(
        AllocationConfig(AllocationStrategy.BEST_FIT),
        get_memory_config_for_model(Path(args.model_path).name),
    )
    for t in dag["tensors"]:
        scheduler.add_tensor(
            Tensor(
                t["scoped_name"],
                tuple(t["shape"]),
                args.model_dtype,
                is_constant=t.get("kind") == "init",
            )
        )
    for op in dag["ops"]:
        # Qualify the op's id and its tensor names with the op's scope.
        scope = op["scope"]
        scheduler.add_operation(
            Operation(
                scoped_id(scope, op["name"]),
                op["op_type"],
                [scoped_tensor(scope, name) for name in op["inputs"]],
                [scoped_tensor(scope, name) for name in op["outputs"]],
            )
        )

    # schedule tensors following the DAG topologically
    logger.debug("Scheduling memory")
    allocation_results = scheduler.schedule_memory()
    logger.info(scheduler.get_allocation_summary(allocation_results))

    # Write allocation events to a json
    if args.write_allocation_events:
        # step -> {tensor id: AllocationResult}; if a tensor has several events
        # in one step, the last one recorded wins.
        events = {
            step: {alloc.tensor.id: result for result, alloc in step_events}
            for step, step_events in allocation_results.items()
        }
        with open(args.allocation_events_path, "w", encoding="utf-8") as f:
            json.dump({**dag, "events": events}, f, indent=2)

    # Write tiling to a json
    logger.debug("Writing tiling information")
    scheduler.write_tiling(model, args.fusion_json_path, allocation_results)


# Run the command-line pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
