"""
L3 memory allocation scheduling and per-operator tiling for static ONNX computational graphs
"""

import argparse
import logging
import os
from typing import List
import json

import onnx

# pylint: disable=import-error,redefined-outer-name,no-name-in-module
from graph.allocation_types import (
    AllocDict,
    AllocationConfig,
    AllocationStrategy,
    MemoryBlock,
    AllocationAlignment
)
from graph.dag import construct_op_dag, scoped_id, scoped_tensor
from graph.L3_fusion_tiling import TilingSchema, Tiler
from graph.tensor_types import Operation, Tensor
from graph.L2_fusion import GraphMemoryScheduler
from graph.utilities import logger


CURRDIR = os.path.dirname(os.path.abspath(__file__))


class L3GraphMemorySchedulerAndTiler(GraphMemoryScheduler):
    """Graph memory scheduler that can additionally export operator tiling.

    Scheduling itself is inherited from ``GraphMemoryScheduler``; this class
    adds :meth:`write_tiling`, which asks a ``Tiler`` for the tiling of every
    operator in the execution order and dumps the result to a JSON file.
    """

    def __init__(self, alloc_config: AllocationConfig, memory_config: List[MemoryBlock]):
        super().__init__(alloc_config, memory_config)
        # Also keep the memory block layout directly on this instance.
        self.memory_config = memory_config

    def write_tiling(
        self,
        model: onnx.ModelProto,
        tiling_json_path: str,
        allocation_results: AllocDict,
    ) -> None:
        """Serialise the tiling of each scheduled operator to a JSON file.

        Args:
            model: ONNX model used to construct the ``Tiler``.
            tiling_json_path: Output path; the file is opened in ``"w"`` mode,
                so existing content is replaced.
            allocation_results: Allocation results indexed by integer
                execution step (e.g. the value returned by ``schedule_memory``).

        The JSON object maps each step index, as a string ("0", "1", ...), to
        the ``model_dump()`` of that step's tiling, in ascending step order
        and indented by 4 spaces.
        """
        tiler = Tiler(model)
        per_step: TilingSchema = {}

        # Phase 1: query the tiler for every operator, in execution order.
        for index, operation in enumerate(self.get_execution_order()):
            per_step[str(index)] = tiler.get_operator_tiling(
                operation, allocation_results[index]
            )

        # Phase 2: dump to plain data. enumerate() inserted the keys in
        # ascending numeric order, so dict iteration order is already sorted.
        payload = {step: schema.model_dump() for step, schema in per_step.items()}

        with open(tiling_json_path, "w", encoding="utf-8") as out_file:
            json.dump(payload, out_file, indent=4, sort_keys=False)


def main() -> None:
    """Command-line entry point: schedule L3 memory for an ONNX model and write its tiling.

    Steps:
      1. Parse the command-line arguments (model path, tensor dtype, output
         tiling JSON path, log verbosity).
      2. Load the ONNX model and build its operator DAG.
      3. Register every DAG tensor and operation with an
         ``L3GraphMemorySchedulerAndTiler``. The scheduler uses best-fit
         allocation with 4 KiB alignment over a single 4 GiB memory block.
      4. Schedule memory, log the allocation summary, and write the per-step
         tiling JSON.

    All state is local to this function, so module-level names do not shadow
    the parameters of ``write_tiling``.
    """
    parser = argparse.ArgumentParser(
        description="Generate allocation and spilling information for L3 memory."
    )
    parser.add_argument(
        "-m",
        "--model_path",
        default=os.path.join(CURRDIR, "ResNet50_INT8_Model.onnx"),
        help="Path to input ONNX model",
    )
    parser.add_argument(
        "-d",
        "--model_dtype",
        default="TensorProto.INT8",
        help="Data-type for all tensors",
    )
    parser.add_argument(
        "-j",
        "--fusion_json_path",
        default=os.path.join(CURRDIR, "L3_fusion_tiling.json"),
        help="Path to L3 fusion tiling JSON",
    )
    parser.add_argument(
        "-v",
        "--verbosity",
        default="INFO",
        choices=["DEBUG", "INFO"],
        help="Set the logging level",
    )
    args = parser.parse_args()

    # Map the verbosity name ("DEBUG"/"INFO") to the logging module constant.
    logger.setLevel(getattr(logging, args.verbosity))
    logger.debug("Provided command line arguments %s", vars(args))

    model = onnx.load(args.model_path)
    logger.debug("Loaded model from %s", args.model_path)

    logger.debug("Constructing a directed acyclic graph from the ONNX model")
    dag = construct_op_dag(model, ignored_ops=set())

    # A single 4 GiB memory block; best-fit placement with 4 KiB alignment.
    gib = 1024 * 1024 * 1024
    scheduler = L3GraphMemorySchedulerAndTiler(
        AllocationConfig(AllocationStrategy.BEST_FIT, AllocationAlignment.KB_4),
        [MemoryBlock(0, 4 * gib, True)],
    )

    for tensor in dag["tensors"]:
        scheduler.add_tensor(
            Tensor(tensor["scoped_name"], tuple(tensor["shape"]), args.model_dtype)
        )

    for op in dag["ops"]:
        # Tensor names in an op are relative to the op's scope; qualify them
        # so they match the scoped names registered above.
        scope = op["scope"]
        scheduler.add_operation(
            Operation(
                scoped_id(scope, op["name"]),
                op["op_type"],
                [scoped_tensor(scope, name) for name in op["inputs"]],
                [scoped_tensor(scope, name) for name in op["outputs"]],
            )
        )

    # Schedule tensors following the DAG topologically.
    logger.debug("Scheduling memory")
    allocation_results = scheduler.schedule_memory()
    logger.info(scheduler.get_allocation_summary(allocation_results))

    logger.debug("Writing tiling information")
    scheduler.write_tiling(model, args.fusion_json_path, allocation_results)


if __name__ == "__main__":
    main()
