"""
Different allocation strategies for static ONNX computational graphs
"""

from collections import defaultdict
from enum import Enum
import logging
import os
from pathlib import Path
from typing import Tuple, List, Dict, Any
from dataclasses import replace
import json

import onnx
from pydantic import Field

# pylint: disable=import-error,redefined-outer-name,no-name-in-module
from graph.allocation_types import (
    AllocDict,
    Alloc,
    AllocationConfig,
    AllocationResult,
    AllocationStrategy,
    MemoryBlock,
    AllocationAlignment,
    MemoryConfig,
    is_non_overlapping,
    TensorAllocation,
    TensorLifetime,
    TensorLocation
)
from graph.dag import IGNORED_OPS, construct_op_dag, scoped_id, scoped_tensor, descoped_id
from graph.multibin_allocator import MultiBinGraphMemoryScheduler
from graph.typed_parser import TypedBaseModel, build_typed_parser
from graph.tensor_types import Operation, Tensor, XrtId
from graph.L2_fusion import GraphMemoryScheduler, get_memory_config_for_model, GraphOps
from graph.tensor_memory_allocator import TensorMemoryAllocator
from graph.L2_fusion_tiling import Tiler as L2Tiler, GenericUnsupportedOp
from graph.L3_fusion_tiling import Tiler as L3Tiler, GenericOperator, as_base
from graph.utilities import logger, load_model_without_data
from graph.allocation_types import SubgraphAllocationMode
from graph.continuous_memory_allocator import ContinuousMemoryAllocator
from utils.utils_common import is_save_subgraph_json


# Absolute directory of this file; used to build the default model and JSON paths in ``Command``.
CURRDIR = os.path.dirname(os.path.abspath(__file__))


class MultiLevelGraphMemoryScheduler:
    """Pair of graph memory schedulers forming a two-level spill hierarchy.

    Every tensor and operation registered here is mirrored, as an independent
    copy, into both the level-0 and the level-1 scheduler. Tensors that level 0
    reports as SPILLED are the ones level 1 is asked to place.
    """

    class Level(Enum):
        """Identifies which underlying scheduler an update targets."""
        L0 = 0
        L1 = 1

    def __init__(
        self,
        level0_scheduler: GraphMemoryScheduler,
        level1_scheduler: MultiBinGraphMemoryScheduler,
    ):
        self.level0_scheduler = level0_scheduler
        self.level1_scheduler = level1_scheduler

    def add_tensor(self, tensor: Tensor):
        """Register *tensor* with both levels (each level gets its own copy)."""
        for level_scheduler in (self.level0_scheduler, self.level1_scheduler):
            level_scheduler.add_tensor(tensor.copy())

    def update_tensor(self, tensor_id: str, level: Level, **kwargs):
        """Forward an attribute update for *tensor_id* to the scheduler at *level*.

        Values other than ``Level.L0`` / ``Level.L1`` are silently ignored.
        """
        if level == self.Level.L0:
            target = self.level0_scheduler
        elif level == self.Level.L1:
            target = self.level1_scheduler
        else:
            return
        target.update_tensor(tensor_id, **kwargs)

    def add_operation(self, operation: Operation):
        """Register *operation* with both levels (each level gets its own copy)."""
        for level_scheduler in (self.level0_scheduler, self.level1_scheduler):
            level_scheduler.add_operation(operation.copy())

    def schedule_memory(self, skip_level0: bool = False, is_nonwaic: bool = False) -> Tuple[AllocDict, AllocDict]:
        """Run level-0 scheduling, then place the spilled tensors on level 1.

        Args:
            skip_level0: Rewrite every level-0 result as SPILLED, so that every
                tensor level 0 did not merely deallocate must land on level 1.
            is_nonwaic: Forwarded to the level-0 scheduler.

        Returns:
            ``(level0_allocs, level1_allocs)``.
        """
        level0_allocs = self.level0_scheduler.schedule_memory(is_nonwaic=is_nonwaic)

        spilled_from_l0 = set()
        if skip_level0:
            # Force every result off level 0 (results are rewritten in place).
            for op_results in level0_allocs.values():
                for idx, (outcome, alloc) in enumerate(op_results):
                    if outcome != AllocationResult.DEALLOCATED:
                        spilled_from_l0.add(alloc.tensor.id)
                    op_results[idx] = (AllocationResult.SPILLED, alloc)

        # Tensors that level 1 must place: everything level 0 spilled.
        needs_l1 = {
            alloc.tensor.id
            for op_results in level0_allocs.values()
            for outcome, alloc in op_results
            if outcome == AllocationResult.SPILLED
        }
        if not skip_level0:
            spilled_from_l0 |= needs_l1

        # Only spilled tensors and model I/O keep their size on level 1;
        # every other tensor is shrunk to zero bytes there.
        for tensor in self.level1_scheduler.tensors.values():
            if tensor.id in needs_l1 or tensor.is_model_io:
                logger.debug("Tensor %s to be allocated on L3", tensor.id)
            else:
                object.__setattr__(tensor, "size", 0)
        level1_allocs = self.level1_scheduler.schedule_memory()

        # Level 1 must never spill, and must place everything spilled from level 0.
        placed_on_l1 = set()
        for op_results in level1_allocs.values():
            for outcome, alloc in op_results:
                assert outcome != AllocationResult.SPILLED
                if outcome != AllocationResult.DEALLOCATED:
                    placed_on_l1.add(alloc.tensor.id)
        assert spilled_from_l0.issubset(placed_on_l1), "Some spilled tensors were not allocated in level 1"

        # Both levels must agree on the operator execution order.
        assert self.level0_scheduler.get_execution_order() == self.level1_scheduler.get_execution_order()
        return level0_allocs, level1_allocs

    def print_allocation_summary(self, level0_allocs: AllocDict, level1_allocs: AllocDict):
        """Log the allocation summary of each level at INFO level."""
        for label, level_scheduler, allocs in (
            ("Level 0:", self.level0_scheduler, level0_allocs),
            ("Level 1:", self.level1_scheduler, level1_allocs),
        ):
            logger.info(label)
            logger.info(level_scheduler.get_allocation_summary(allocs))


class MultiLevelTiler:
    """Turns multi-level allocation results into a per-operator tiling JSON."""

    def __init__(
        self,
        scheduler: MultiLevelGraphMemoryScheduler,
        level0_memory: MemoryConfig
    ):
        self.scheduler = scheduler
        self.level0_memory_config = level0_memory

    def write_tiling(
        self,
        model: onnx.ModelProto,
        tiling_json_path: str,
        allocation_results: Tuple[AllocDict, AllocDict],
        is_nonwaic: bool = False
    ) -> None:
        """Compute the tiling of every scheduled operator and dump it as JSON.

        Operators are visited in the level-0 scheduler's execution order and
        identified by their position in it. Each entry holds the L2 tiling
        fields. When L2 fusion is not enabled for the operator, the entry also
        holds its L3 tiling fields. Entries whose L2 tiling is a
        ``GenericUnsupportedOp`` never carry L3 fields.

        Args:
            model: ONNX model handed to both tilers.
            tiling_json_path: Output file, written with 4-space indentation.
            allocation_results: ``(level0_allocs, level1_allocs)`` from scheduling.
            is_nonwaic: Forwarded to the L2 tiler.
        """
        l2_tiler = L2Tiler(model, self.level0_memory_config)
        l3_tiler = L3Tiler(model)
        l2_allocs, l3_allocs = allocation_results

        per_op = {}
        for idx, operator in enumerate(self.scheduler.level0_scheduler.get_execution_order()):
            l2_tiling = l2_tiler.get_operator_tiling(operator, l2_allocs[idx], is_nonwaic=is_nonwaic)

            l3_fields = {}
            if not l2_tiling.enable_L2_fusion:
                l3_fields = l3_tiler.get_operator_tiling(operator, l3_allocs[idx]).to_l3()
            if isinstance(l2_tiling, GenericUnsupportedOp):
                # Unsupported operators carry L2 information only.
                l3_fields = {}

            per_op[str(idx)] = {**l2_tiling.model_dump(), **l3_fields}

        # Emit entries in ascending numeric operator order.
        ordered = {key: per_op[key] for key in sorted(per_op, key=int)}
        with open(tiling_json_path, "w", encoding="utf-8") as handle:
            json.dump(ordered, handle, indent=4, sort_keys=False)


class Verbosity(str, Enum):
    """Logging verbosity selectable on the command line.

    Each value is also the name of a ``logging`` module level constant; callers
    resolve it with ``getattr(logging, <member>)``.
    """
    INFO = "INFO"
    DEBUG = "DEBUG"


class Command(TypedBaseModel):
    """Command-line arguments for the L2/L3 allocation and tiling tool.

    Field order is kept as-is because the model's field order is observable
    (presumably also the order in which ``build_typed_parser`` adds options —
    confirm). The ``alias`` values are presumably the short option names.
    """
    # ONNX model to allocate; defaults to a ResNet50 INT8 model next to this file.
    model_path: str = Field(
        default=os.path.join(CURRDIR, "ResNet50_INT8_Model.onnx"),
        description="Path to input ONNX model",
        alias="m",
    )
    # Path of the tiling JSON produced by ``main`` (and read/rewritten in subgraph mode).
    fusion_json_path: str = Field(
        default=os.path.join(CURRDIR, "L2L3_fusion_tiling.json"),
        description="Path to L2L3 fusion tiling JSON",
        alias="j",
    )
    # When False, ``main`` schedules with ``skip_level0=True`` so all tensors go to L3.
    both_l2l3: bool = Field(default=False, description="Allocate both on L2 and L3")
    # Forwarded to ``Tensor(is_channel_multiple_of_64=...)``.
    c64: bool = Field(
        default=False,
        description="Round-up C dimension of tensor to nearest multiple of 64",
    )
    # Selects ``IGNORED_OPS`` when building the DAG; rejected in subgraph mode.
    l1_fuse: bool = Field(default=False, description="Attempt L1 fusion")
    # Level name looked up on the ``logging`` module.
    verbosity: Verbosity = Field(
        default=Verbosity.INFO, description="Set the logging level", alias="v"
    )
    # Forwarded to L2 scheduling and L2 tiling.
    is_nonwaic: bool = Field(default=False, description="Has this graph been NOT generated by WAIC? Applies to QDQ models such as YoloV3.")


def create_level3_memconfig(start: int = 0):
    """Describe L3 memory as one 3 GiB ``MemoryBlock`` beginning at *start*.

    Returns:
        A one-element list holding ``MemoryBlock(start, 3 GiB, True)``. The
        trailing ``True`` is presumably the block's "free" flag; confirm
        against ``MemoryBlock``.
    """
    gib = 1 << 30
    level3_size = 3 * gib
    block = MemoryBlock(start, level3_size, True)
    return [block]


def create_level3_scheduler(execution_order: List[Operation] | None = None) -> MultiBinGraphMemoryScheduler:
    """Build the default two-bin (IFM + OFM) L3 scheduler.

    Both bins use ``AllocationStrategy.BEST_FIT`` with
    ``AllocationAlignment.KB_4``. Each bin receives its own copies of the
    ``create_level3_memconfig()`` blocks.

    Args:
        execution_order: Optional fixed operator order forwarded to
            ``MultiBinGraphMemoryScheduler.build``.
    """
    memory_layout = create_level3_memconfig()

    def _bin(xrt_id):
        # Copy the blocks per bin so the two bins never share block objects.
        cfg = AllocationConfig(AllocationStrategy.BEST_FIT, AllocationAlignment.KB_4, xrt_id)
        return cfg, [blk.copy() for blk in memory_layout]

    bins = [_bin(XrtId.IFM), _bin(XrtId.OFM)]
    return MultiBinGraphMemoryScheduler.build(bins, execution_order=execution_order)


def _register_dag_with_scheduler(
    scheduler: MultiLevelGraphMemoryScheduler,
    dag: Dict[str, Any],
    c64: bool,
) -> None:
    """Add every DAG tensor and operator to *scheduler* and set per-level placement hints.

    L3 (level 1): every tensor defaults to the OFM bin. Model inputs are moved
    to IFM and model outputs to OFM, and both are flagged as model I/O.
    L2 (level 0): model inputs and outputs are flagged as model I/O.
    """
    for t in dag["tensors"]:
        scheduler.add_tensor(Tensor(t["scoped_name"], tuple(t["shape"]), t["dtype"],
                                    is_constant=t.get("kind") == "init",
                                    is_channel_multiple_of_64=c64))

    level1 = scheduler.Level.L1
    # for L3, all tensors including scratch go to OFM bin by default
    for t in dag["tensors"]:
        scheduler.update_tensor(t["scoped_name"], level1, bin=XrtId.OFM)
    # for L3, the inputs are in IFM bin
    for name in dag["model_inputs"]:
        scheduler.update_tensor(name, level1, bin=XrtId.IFM, is_model_io=True)
    # for L3, the outputs are in OFM bin (applied after inputs, so a tensor that
    # is both a model input and output ends up in OFM)
    for name in dag["model_outputs"]:
        scheduler.update_tensor(name, level1, bin=XrtId.OFM, is_model_io=True)

    # for L2, the inputs/outputs must be allocated
    for name in dag["model_inputs"] | dag["model_outputs"]:
        scheduler.update_tensor(name, scheduler.Level.L0, is_model_io=True)

    for op in dag["ops"]:
        scope = op["scope"]
        scheduler.add_operation(
            Operation(
                scoped_id(scope, op["name"]), op["op_type"],
                [scoped_tensor(scope, item) for item in op["inputs"]],
                [scoped_tensor(scope, item) for item in op["outputs"]],
            )
        )


def main(args: Command, execution_order: List[Operation] | None = None) -> None:
    """Allocate L2/L3 memory for an ONNX model and write the fusion tiling JSON.

    Args:
        args: Parsed command-line options (see ``Command``).
        execution_order: Optional fixed operator order forwarded to both the
            L2 and the L3 scheduler. ``None`` leaves the choice to the schedulers.
    """
    # define logger verbosity level
    logger.setLevel(getattr(logging, args.verbosity))
    logger.info("Provided command line arguments %s", vars(args))

    # load onnx model
    model = load_model_without_data(args.model_path)
    logger.debug("Loaded model from %s", args.model_path)

    # Construct a directed acyclic graph
    logger.debug("Constructing a directed acyclic graph from the ONNX model")
    dag = construct_op_dag(model, ignored_ops=IGNORED_OPS if args.l1_fuse else set())

    # L2 (level 0) scheduler; the memory layout is chosen from the model file name.
    level2_memory_config = get_memory_config_for_model(Path(args.model_path).name)
    level2_scheduler = GraphMemoryScheduler(
        AllocationConfig(AllocationStrategy.BEST_FIT),
        [block.copy() for block in level2_memory_config.memory],
        execution_order=execution_order,
    )

    logger.debug("Creating L3 scheduler")
    level3_scheduler = create_level3_scheduler(execution_order=execution_order)

    logger.debug("Creating multi-level scheduler")
    scheduler = MultiLevelGraphMemoryScheduler(level2_scheduler, level3_scheduler)
    _register_dag_with_scheduler(scheduler, dag, args.c64)

    # Unless --both_l2l3 is set, level 0 is bypassed and everything spills to L3.
    logger.debug("Scheduling memory")
    allocation_results = scheduler.schedule_memory(skip_level0=not args.both_l2l3, is_nonwaic=args.is_nonwaic)
    scheduler.print_allocation_summary(*allocation_results)

    # Write tiling to a json
    logger.debug("Writing tiling information")
    tiler = MultiLevelTiler(scheduler, level2_memory_config)
    tiler.write_tiling(model, args.fusion_json_path, allocation_results, is_nonwaic=args.is_nonwaic)


class CommandSubgraph(Command):
    """Arguments for re-allocating L3 memory of a model partitioned into subgraphs.

    Extends ``Command`` with the partition description and an existing fusion
    tiling JSON whose per-operator ``"L3"`` entries get rewritten.
    """
    # Existing fusion tiling JSON, keyed by operator index (string). Each entry
    # must contain a "name" field (used to map operator names back to indices).
    fusion_json: Dict[str, Any] = Field(description="Fusion tiling JSON")
    # Subgraph id -> names of the model operators belonging to that subgraph.
    subgraph_ops: Dict[int, List[str]] = Field(
        description="Operations in each partitioned subgraph")
    # Subgraph id -> (input tensor names, output tensor names).
    subgraph_ios: Dict[int, Tuple[List[str], List[str]]] = Field(
        description="Input and output tensors in each partitioned subgraph")
    # LOCAL (default) or GLOBAL; any other mode is handled by the continuous path
    # in update_allocations.
    allocation_mode: SubgraphAllocationMode = Field(
        default=SubgraphAllocationMode.LOCAL, description="Track allocations across subgraphs"
    )


def update_allocations(args: CommandSubgraph) -> None | Dict[int, List[str]]:
    """Re-run L3 allocation for a partitioned model and rewrite its fusion JSON.

    Loads the model, builds its operator DAG, and indexes operators and
    tensors by name. It then dispatches on ``args.allocation_mode``:

    * ``GLOBAL`` goes to ``update_allocations_global``;
    * ``LOCAL`` goes to ``update_allocations_local``;
    * any other mode goes to ``update_allocations_continuous``.

    Returns:
        ``None`` for the global and local modes. For the continuous mode, the
        per-subgraph operator execution order.

    Raises:
        ValueError: if L2 allocation (``both_l2l3``) or L1 fusion (``l1_fuse``)
            is requested; neither is supported for subgraphs.
    """
    logger.setLevel(getattr(logging, args.verbosity))

    # Guard clauses: reject options with no subgraph-aware implementation.
    if args.both_l2l3:
        raise ValueError("L2 allocation is not supported for subgraphs")
    if args.l1_fuse:
        raise ValueError("L1 fusion is not supported for subgraphs")

    model = load_model_without_data(args.model_path)
    logger.debug("Loaded model from %s", args.model_path)

    logger.debug("Constructing a directed acyclic graph from the ONNX model")
    ignored = IGNORED_OPS if args.l1_fuse else set()
    dag = construct_op_dag(model, ignored_ops=ignored)

    # Name-keyed indexes shared by every allocation mode.
    ops_by_name = {entry["name"]: entry for entry in dag["ops"]}
    tensors_by_name = {entry["name"]: entry for entry in dag["tensors"]}
    # Operator name -> its key (operator index) in the fusion JSON.
    name_to_fusion_id = {entry["name"]: key for key, entry in args.fusion_json.items()}

    mode = args.allocation_mode
    if mode == SubgraphAllocationMode.GLOBAL:
        return update_allocations_global(args, ops_by_name)
    if mode == SubgraphAllocationMode.LOCAL:
        return update_allocations_local(args, ops_by_name, tensors_by_name, dag, name_to_fusion_id)
    return update_allocations_continuous(args, ops_by_name, tensors_by_name, name_to_fusion_id)


def update_allocations_local(
    args: CommandSubgraph,
    ops: Dict[str, Any],
    tensors: Dict[str, Any],
    dag: Dict[str, Any],
    op_id_map: Dict[str, Any],
) -> None:
    """Allocate L3 memory independently for each subgraph and patch the fusion JSON.

    Every subgraph gets a fresh scheduler from ``create_level3_scheduler()``,
    so allocations are not tracked across subgraphs. Tensors default to the
    OFM bin. A subgraph input goes to IFM only when it is also a model input;
    otherwise it goes to OFM. Subgraph inputs and outputs are flagged as model
    I/O. After scheduling, each operator's ``"L3"`` entry in
    ``args.fusion_json`` (mutated in place) is replaced with
    ``GenericOperator.update_l3(...)``. Once all subgraphs are done, the JSON
    is written to ``args.fusion_json_path``.

    Args:
        args: Subgraph command (partition, fusion JSON, output path, ``c64``).
        ops: DAG operator dicts keyed by operator name.
        tensors: DAG tensor dicts keyed by tensor name.
        dag: The full operator DAG (only ``dag["model_inputs"]`` is read here).
        op_id_map: Operator name -> key of that operator in the fusion JSON.

    Raises:
        ValueError: if a subgraph lists an operator that is not in *ops*.
        KeyError: if a subgraph input has no allocation in the scheduling result.
    """
    subgraph_ops, subgraph_ios, fusion_json = (
        args.subgraph_ops,
        args.subgraph_ios,
        args.fusion_json,
    )

    # Subgraph level allocation
    for subgraph_id in subgraph_ops:
        scheduler = create_level3_scheduler()

        # pylint: disable=C3001, W0640
        for op_id in subgraph_ops[subgraph_id]:

            # Every operator listed for the subgraph must exist in the model DAG
            if op_id not in ops:
                raise ValueError(
                    f"Subgraph {subgraph_id} contains operators not found in model: {op_id}"
                )

            # Add the operators to the scheduler (tensor names scoped like the op)
            op = ops[op_id]
            st = lambda items: [scoped_tensor(op["scope"], item) for item in items]  # noqa: E731
            scheduler.add_operation(
                Operation(
                    scoped_id(op["scope"], op["name"]), op["op_type"],
                    st(op["inputs"]), st(op["outputs"])
                )
            )

            # add tensors to scheduler; every tensor starts in the OFM bin
            for tensor_id in op["inputs"] + op["outputs"]:
                t = tensors[tensor_id]
                scheduler.add_tensor(Tensor(t["scoped_name"], tuple(t["shape"]), t["dtype"],
                                            bin=XrtId.OFM, is_channel_multiple_of_64=args.c64,
                                            is_constant=t.get("kind") == "init"))

        # global ONNX inputs go to ifm bin; other subgraph inputs stay in OFM
        inps, outs = subgraph_ios[subgraph_id]
        for inp in inps:
            scheduler.update_tensor(
                tensors[inp]["scoped_name"],
                bin=XrtId.IFM if inp in dag["model_inputs"] else XrtId.OFM,
                is_channel_multiple_of_64=args.c64,
                is_model_io=True,
            )

        # subgraph outputs live in the OFM bin
        for out in outs:
            scheduler.update_tensor(
                tensors[out]["scoped_name"],
                bin=XrtId.OFM,
                is_channel_multiple_of_64=args.c64,
                is_model_io=True,
            )

        # schedule tensors following the DAG topologically; keep only live
        # placements (not DEALLOCATED / SPILLED), keyed by scoped tensor id
        logger.debug("Scheduling memory")
        allocation_results = scheduler.schedule_memory()
        updated_addr = {}
        for allocs in allocation_results.values():
            for result, allocation in allocs:
                if result not in [AllocationResult.DEALLOCATED, AllocationResult.SPILLED]:
                    updated_addr[allocation.tensor.id] = (result, allocation)

        # update fusion json (in place) with each operator's new L3 addresses
        for op_id in subgraph_ops[subgraph_id]:
            fusion_op_id = op_id_map[op_id]
            op_scoped_id = scoped_id(ops[op_id]["scope"], ops[op_id]["name"])
            op = scheduler.get_operation_by_name(op_scoped_id)
            fusion_json[str(fusion_op_id)]["L3"] = GenericOperator.update_l3(op, updated_addr)

        # find addresses for all inputs, grouped by bin.
        # NOTE(review): the overlap check below is disabled, so input_addr is
        # unused; the loop still raises KeyError for any unallocated input.
        input_addr = defaultdict(list)
        for inp in inps:
            alloc = updated_addr[tensors[inp]["scoped_name"]][1]
            input_addr[alloc.tensor.bin.value].append((alloc.block.start, alloc.block.start + alloc.tensor.size))

        # assert all(is_non_overlapping(sorted(v)) for v in input_addr.values())

    # write updated fusion json
    Path(args.fusion_json_path).write_text(json.dumps(fusion_json, indent=2), encoding="utf-8")


def update_allocations_global(args: CommandSubgraph, ops: Dict[str, Any]) -> None:
    """Allocate L3 memory for all subgraphs with one model-wide scheduler.

    Each partitioned subgraph is contracted into a single super-operator whose
    inputs/outputs come from ``args.subgraph_ios``. Every model operator that
    belongs to no subgraph is kept as a singleton. A topological order of this
    contracted graph is expanded back into a linear order of the original
    operators, with each subgraph replaced by its own topological order.
    ``main`` is then run with that fixed execution order.

    Args:
        args: Subgraph command; ``subgraph_ops`` / ``subgraph_ios`` describe the partition.
        ops: DAG operator dicts keyed by operator name.

    Raises:
        AssertionError: if the contracted graph has a cycle, or if the expanded
            order's length differs from the number of operators.
    """
    subgraph_ops, subgraph_ios = args.subgraph_ops, args.subgraph_ios

    # Operators not assigned to any subgraph. Iterate ``ops`` (DAG order) instead
    # of a set difference: iteration order of a set of strings depends on hash
    # randomization, which would make the GraphOps insertion order (and any
    # insertion-order tie-breaking in the topological sort) vary between runs.
    contracted_ops = set().union(*subgraph_ops.values())
    missing_ops = [name for name in ops if name not in contracted_ops]
    missing_op_set = set(missing_ops)

    contracted_graph = GraphOps()

    # add subgraphs: model each subgraph as a contracted super op with input and output tensors
    for subgraph_id in subgraph_ops:
        inputs, outputs = subgraph_ios[subgraph_id]
        contracted_graph.add_operation(
            Operation(
                id=str(subgraph_id),
                type="Subgraph",
                inputs=inputs,
                outputs=outputs,
            )
        )

    # add missing ops: model each missing op as a singleton super-op
    for name in missing_ops:
        contracted_graph.add_operation(Operation.from_dict(**ops[name]))

    # compute topological order of the contracted graph
    execution_order = contracted_graph.get_execution_order(recompute=True)
    assert len(execution_order) == len(contracted_graph.operations), "Graph has at least one cycle"

    # Expand super-ops back into a linear execution order of all operators
    global_order: List[Operation] = []
    for super_op in execution_order:
        if super_op.id in missing_op_set:
            global_order.append(Operation.from_dict_scope(**ops[super_op.id]))
            continue
        # topological order within the subgraph
        local_ops = GraphOps()
        for sub_op_id in subgraph_ops[int(super_op.id)]:
            local_ops.add_operation(Operation.from_dict_scope(**ops[sub_op_id]))
        global_order += local_ops.get_execution_order(recompute=True)
    assert len(global_order) == len(ops), "Some operators are missing in the global execution order"

    # Allocate as per the execution order
    main(as_base(args, Command), execution_order=global_order)


def create_level3_scheduler_continuous(continous_allocator: ContinuousMemoryAllocator) -> MultiBinGraphMemoryScheduler:
    """Build an L3 scheduler whose OFM bin starts from pre-placed allocations.

    The IFM bin gets a fresh ``ContinuousMemoryAllocator`` over a new
    ``create_level3_memconfig()`` layout. The OFM bin is a
    ``TensorMemoryAllocator`` seeded from *continous_allocator*. Its memory
    blocks are passed as-is (not copied), and every existing allocation is
    marked ``is_deallocatable=False``.

    Side effect: if ``AIE4_ALLOCATOR_MODE`` is ``"CONTINUOUS"``, sets
    ``AIE4_FORCE_ALLOCATOR_MODE_L3=CONTINUOUS`` in ``os.environ``.
    """
    # Debug dump of the seed state.
    for blk in continous_allocator.memory_blocks:
        logger.debug("block: free:%s, start:%s, size:%s, tensor_id:%s", blk.is_free, blk.start, blk.size, blk.tensor_id)
    for entry in continous_allocator.allocations.values():
        logger.debug("alloc: loc:%s tensor:%s, block:%s", entry.location, entry.tensor, entry.block)

    if os.getenv("AIE4_ALLOCATOR_MODE") == "CONTINUOUS":
        os.environ["AIE4_FORCE_ALLOCATOR_MODE_L3"] = "CONTINUOUS"

    ifm_allocator = ContinuousMemoryAllocator(
        AllocationConfig(AllocationStrategy.BEST_FIT, AllocationAlignment.KB_4, XrtId.IFM),
        create_level3_memconfig(),
    )

    ofm_config = AllocationConfig(AllocationStrategy.BEST_FIT, AllocationAlignment.KB_4, XrtId.OFM)
    ofm_blocks = continous_allocator.memory_blocks
    # Pin the pre-placed allocations so the scheduler cannot free them.
    pinned = {name: replace(alloc, is_deallocatable=False) for name, alloc in continous_allocator.allocations.items()}
    ofm_allocator = TensorMemoryAllocator.from_allocations(ofm_config, ofm_blocks, pinned)

    return MultiBinGraphMemoryScheduler({XrtId.IFM: ifm_allocator, XrtId.OFM: ofm_allocator})


def validate_addr(tensors: Dict[str, Any], io: List[str], updated_addr: Dict[str, Alloc]):
    """Sanity-check the L3 placement of a subgraph's boundary tensors.

    For every tensor name in *io*, the allocation is looked up in
    *updated_addr* by the tensor's scoped name. Each value is a
    ``(result, allocation)`` pair. The ``(start, start + size)`` range is then
    collected per XRT bin. The checks are:

    * all tensors live in exactly one bin (an empty *io* therefore fails);
    * the ranges are non-overlapping according to ``is_non_overlapping``;
    * at least one range starts at address 0.

    Raises:
        KeyError: if a name in *io* is missing from *tensors* or *updated_addr*.
        AssertionError: if any check fails; the message names the failing check.
    """
    addrs = defaultdict(list)
    for tensor_id in io:
        alloc = updated_addr[tensors[tensor_id]["scoped_name"]][1]
        start = alloc.block.start
        addrs[alloc.tensor.bin.value].append((start, start + alloc.tensor.size))

    assert len(addrs) == 1, f"Expected all IO tensors in a single bin, found bins {list(addrs)}"

    # Exactly one bin remains, so a single pass covers both remaining checks.
    for bin_id, ranges in addrs.items():
        ordered = sorted(ranges)
        assert is_non_overlapping(ordered), f"Overlapping IO tensor ranges in bin {bin_id}: {ordered}"
        assert any(lo == 0 for lo, _ in ranges), f"No IO tensor in bin {bin_id} starts at address 0: {ordered}"


def update_allocations_continuous(
    args: CommandSubgraph,
    ops: Dict[str, Any],
    tensors: Dict[str, Any],
    op_id_map: Dict[str, Any],
) -> Dict[int, List[str]]:
    """Allocate L3 memory per subgraph with its outputs pre-placed from address 0.

    For every subgraph:

    1. Its output tensors are placed first, in the OFM bin, with a
       ``ContinuousMemoryAllocator``. The lowest allocated address must be 0.
    2. A scheduler is built from that allocator
       (``create_level3_scheduler_continuous``), and the subgraph's operators
       and tensors are added. Subgraph inputs go to the IFM bin, everything
       else to OFM, and inputs and outputs are flagged as model I/O.
    3. After scheduling, the scheduler's allocation of each pre-placed output
       is checked against the pre-placement. Every operator's ``"L3"`` entry in
       ``args.fusion_json`` (mutated in place) is updated, and the input and
       output placements are validated with ``validate_addr``.
    4. If ``is_save_subgraph_json()`` is true, this subgraph's slice of the
       fusion JSON is written next to ``args.fusion_json_path``.

    Finally, the whole fusion JSON is written to ``args.fusion_json_path``.

    Args:
        args: Subgraph command (partition, fusion JSON, output path, ``c64``).
        ops: DAG operator dicts keyed by operator name.
        tensors: DAG tensor dicts keyed by tensor name.
        op_id_map: Operator name -> key of that operator in the fusion JSON.

    Returns:
        Subgraph id -> operator ids in the scheduler's execution order, converted
        with ``descoped_id``.

    Raises:
        ValueError: if a subgraph lists an operator that is not in *ops*.
        AssertionError: if any placement sanity check fails.
    """
    subgraph_ops, subgraph_ios, fusion_json = (
        args.subgraph_ops,
        args.subgraph_ios,
        args.fusion_json,
    )

    topo_order: Dict[int, List[str]] = {}

    # Subgraph level allocation
    for subgraph_id in subgraph_ops:
        logger.info("Allocating for subgraph %s", subgraph_id)

        # allocate outputs first, in the OFM bin, with a dedicated allocator
        continous_allocator = ContinuousMemoryAllocator(
            AllocationConfig(
                AllocationStrategy.BEST_FIT, AllocationAlignment.KB_4, XrtId.OFM
            ),
            [memory_config.copy() for memory_config in create_level3_memconfig()],
        )
        out_allocs = {}
        for out in subgraph_ios[subgraph_id][1]:
            t = tensors[out]
            tensor = Tensor(t["scoped_name"], tuple(t["shape"]), t["dtype"],
                            bin=XrtId.OFM, is_channel_multiple_of_64=args.c64,
                            is_constant=t.get("kind") == "init")
            to_alloc = TensorAllocation(
                tensor, TensorLifetime(0, 0), TensorLocation.UNKNOWN)
            continous_allocator.allocate(to_alloc)
            # snapshot of the placement, compared with the scheduler's result later
            out_allocs[tensor.id] = to_alloc.copy()

        # the pre-placed outputs must begin at address 0
        min_addr, _ = continous_allocator.get_allocated_region_bounds()
        assert min_addr == 0

        # create a scheduler seeded with the pre-placed outputs
        scheduler = create_level3_scheduler_continuous(continous_allocator)

        # structure for dumping subgraph data: <fusion stem>_subgraph_<id>.json
        subgraph_json_path = Path(args.fusion_json_path).parent / f"{Path(args.fusion_json_path).stem}_subgraph_{subgraph_id}.json"
        subgraph_json = {}

        # pylint: disable=C3001, W0640
        for op_id in subgraph_ops[subgraph_id]:

            # Every operator listed for the subgraph must exist in the model DAG
            if op_id not in ops:
                raise ValueError(
                    f"Subgraph {subgraph_id} contains operators not found in model: {op_id}"
                )

            # Add the operators (with attributes) to the scheduler
            op = ops[op_id]
            st = lambda items: [scoped_tensor(op["scope"], item) for item in items]  # noqa: E731
            scheduler.add_operation(
                Operation(
                    scoped_id(op["scope"], op["name"]), op["op_type"],
                    st(op["inputs"]), st(op["outputs"]),
                    attributes=op["attributes"]
                )
            )

            # By default every tensor is placed in the OFM bin
            for tensor_id in op["inputs"] + op["outputs"]:
                t = tensors[tensor_id]
                scheduler.add_tensor(Tensor(t["scoped_name"], tuple(t["shape"]), t["dtype"],
                                            bin=XrtId.OFM, is_channel_multiple_of_64=args.c64,
                                            is_constant=t.get("kind") == "init"))

        # Subgraph inputs go to ifm bin
        inps, outs = subgraph_ios[subgraph_id]
        for inp in inps:
            scheduler.update_tensor(
                tensors[inp]["scoped_name"],
                bin=XrtId.IFM,
                is_channel_multiple_of_64=args.c64,
                is_model_io=True,
            )

        # Subgraph outputs were already placed by continous_allocator above
        for out in outs:
            scheduler.update_tensor(
                tensors[out]["scoped_name"],
                bin=XrtId.OFM,
                is_channel_multiple_of_64=args.c64,
                is_model_io=True,
            )

        # schedule tensors following the DAG topologically; keep only live
        # placements (not DEALLOCATED / SPILLED), keyed by scoped tensor id
        logger.debug("Scheduling memory")
        allocation_results = scheduler.schedule_memory(enable_noop_optim=True, enable_runtime_optim=True)
        updated_addr = {}
        for allocs in allocation_results.values():
            for result, allocation in allocs:
                if result not in [AllocationResult.DEALLOCATED, AllocationResult.SPILLED]:
                    updated_addr[allocation.tensor.id] = (result, allocation)

        # check that the scheduler kept each pre-placed output as placed
        # (start address is not compared; see the commented assert below)
        for tensor_id, new_alloc in out_allocs.items():
            assert tensor_id in updated_addr
            assert new_alloc.tensor.size != 0
            old_alloc = updated_addr[tensor_id]
            assert old_alloc[0] == AllocationResult.ALLOCATED
            assert old_alloc[1].tensor.id == new_alloc.tensor.id
            assert old_alloc[1].tensor.shape == new_alloc.tensor.shape
            assert old_alloc[1].tensor.dtype == new_alloc.tensor.dtype
            assert old_alloc[1].tensor.bin == new_alloc.tensor.bin
            assert old_alloc[1].tensor.size == new_alloc.tensor.size
            assert old_alloc[1].tensor.is_channel_multiple_of_64 == new_alloc.tensor.is_channel_multiple_of_64
            # assert old_alloc[1].tensor.is_model_io != new_alloc.tensor.is_model_io
            assert old_alloc[1].block.is_free == new_alloc.block.is_free
            assert old_alloc[1].block.tensor_id == new_alloc.block.tensor_id
            assert old_alloc[1].block.size == new_alloc.block.size
            # assert old_alloc[1].block.start == new_alloc.block.start # due to alignment

        # update fusion json in place; subgraph_json shares the same entry objects
        for op_id in subgraph_ops[subgraph_id]:
            fusion_op_id = op_id_map[op_id]
            op_scoped_id = scoped_id(ops[op_id]["scope"], ops[op_id]["name"])
            op = scheduler.get_operation_by_name(op_scoped_id)
            fusion_json[str(fusion_op_id)]["L3"] = GenericOperator.update_l3(op, updated_addr).copy()
            subgraph_json[str(fusion_op_id)] = fusion_json[str(fusion_op_id)]

        # validate the placement of the subgraph's inputs and outputs
        validate_addr(tensors, inps, updated_addr)
        validate_addr(tensors, outs, updated_addr)
        if is_save_subgraph_json():
            subgraph_json_path.write_text(json.dumps(subgraph_json, indent=2), encoding="utf-8")

        topo_order[subgraph_id] = [descoped_id(op.id) for op in scheduler.get_execution_order()]

    # write updated fusion json
    Path(args.fusion_json_path).write_text(json.dumps(fusion_json, indent=2), encoding="utf-8")
    return topo_order


if __name__ == "__main__":
    # CLI entry point: parse flags into a validated Command and run the full L2/L3 flow.
    cli_parser = build_typed_parser(
        Command,
        description="Generate tensor allocation and spilling information for L2/L3 memory.",
    )
    cli_namespace = cli_parser.parse_args()
    command = Command.model_validate(vars(cli_namespace), strict=False)
    main(command)
