# pylint: skip-file
'''
Fuse Resnet50 network and allocate L2 for each Node
'''
import os
import json
import sys
from typing import Optional, Collection
import onnx
from onnx import helper

from scheduler.common import LinearOpType
from utils.utils_common import iceil
from scheduler.conv.conv_config_builders import (
    ConvShape,
)

from tiler.conv_tiler import (
    generate_mappings,
)


# Directory that contains this script. fuse_graph() reads the ONNX model from
# it and writes the resulting tiling JSON into it.
CURRDIR = os.path.dirname(os.path.abspath(__file__))

# Running maximum updated by parse_conv_layer() / parse_gemm_layer() through a
# `global` statement; alloc_L2_fusion() uses it as the L2 activation buffer size.
max_overcompute_size = 0


def tag_conv_with_is_relu_inplace(graph):
    '''Annotate every Conv node with an integer "is_relu" attribute.

    A Conv whose output is consumed as the first input of a Relu gets
    is_relu = 1, every other Conv gets is_relu = 0. Conv nodes that already
    carry an "is_relu" attribute are left alone. The graph is modified
    in place and nothing is returned.
    '''
    def already_tagged(candidate):
        return any(attr.name == "is_relu" for attr in candidate.attribute)

    # Tensor name -> node that produces it
    producers = {}
    for candidate in graph.node:
        for tensor in candidate.output:
            producers[tensor] = candidate

    # First pass: Conv feeding a Relu
    for candidate in graph.node:
        if candidate.op_type != "Relu" or not candidate.input:
            continue
        source = producers.get(candidate.input[0])
        if source is None or source.op_type != "Conv":
            continue
        if not already_tagged(source):
            source.attribute.append(helper.make_attribute("is_relu", 1))

    # Second pass: every Conv that is still untagged has no Relu behind it
    for candidate in graph.node:
        if candidate.op_type == "Conv" and not already_tagged(candidate):
            candidate.attribute.append(helper.make_attribute("is_relu", 0))


def ceildiv(x: int, d: int) -> int:
    '''Return x / d rounded up to the next integer.'''
    quotient, remainder = divmod(x, d)
    if remainder:
        # A non-zero remainder means floor division rounded down; bump it.
        quotient += 1
    return quotient


def filter_compute_nodes(model: onnx.ModelProto) -> list[onnx.NodeProto]:
    '''Return, in graph order, the nodes whose op type is handled at L2 level.'''
    wanted_ops = {'Conv', 'Add', 'Gemm', 'MaxPool', 'GlobalAveragePool'}
    selected = []
    for candidate in model.graph.node:
        if candidate.op_type in wanted_ops:
            selected.append(candidate)
    return selected


def infer_tensor_shapes(model: onnx.ModelProto) -> dict:
    '''Collect a {tensor name: shape tuple} map for the graph.

    Shapes are taken from the graph inputs, value_info, outputs and
    initializers (in that order, later entries win). They are then
    propagated across QuantizeLinear / DequantizeLinear nodes so that a
    producer, its QuantizeLinear and the following DequantizeLinear all
    share one shape.

    model: object exposing ``graph.node``, ``graph.input``,
        ``graph.value_info``, ``graph.output`` and ``graph.initializer``.
    Returns a dict mapping tensor names to tuples of ``dim_value``.
    Raises AssertionError when a producer -> Q -> DQ chain has no known
        shape on any of its tensors.
    '''
    graph = model.graph
    # Consumers are indexed by their FIRST input only. Nodes without inputs
    # (e.g. Constant) have nothing to index and are skipped instead of
    # raising IndexError.
    input_to_node = {node.input[0]: node for node in graph.node if node.input}

    def maybe_next(node: Optional[onnx.NodeProto]) -> Optional[onnx.NodeProto]:
        '''Return the node consuming node.output[0], or None.'''
        if node is None or not node.output:
            return None
        return input_to_node.get(node.output[0])

    def dims_of(value_info) -> tuple:
        return tuple(dim.dim_value for dim in value_info.type.tensor_type.shape.dim)

    tensor_shapes = {}
    for value_info in (*graph.input, *graph.value_info, *graph.output):
        tensor_shapes[value_info.name] = dims_of(value_info)
    for init in graph.initializer:
        tensor_shapes[init.name] = tuple(init.dims)

    # Q and DQ nodes do not change the shape: copy input shape to output.
    for node in graph.node:
        if node.op_type in ('QuantizeLinear', 'DequantizeLinear'):
            input_tensor = node.input[0]
            if input_tensor in tensor_shapes:
                tensor_shapes[node.output[0]] = tensor_shapes[input_tensor]

    # For every producer -> QuantizeLinear -> DequantizeLinear chain, take the
    # first tensor of the chain whose shape is known and give that shape to
    # all tensors of the chain.
    for node0 in graph.node:
        node1 = maybe_next(node0)
        node2 = maybe_next(node1)
        if node1 is None or node2 is None:
            continue
        if node1.op_type != 'QuantizeLinear' or node2.op_type != 'DequantizeLinear':
            continue
        chain_tensors = (
            node0.output[0],
            node1.input[0],
            node1.output[0],
            node2.input[0],
            node2.output[0],
        )
        known = next((name for name in chain_tensors if name in tensor_shapes), None)
        if known is None:
            # Explicit raise so the check also works under ``python -O``.
            raise AssertionError(
                f"no shape known for any tensor of the Q/DQ chain after {node0.output[0]}"
            )
        shape = tensor_shapes[known]
        for name in chain_tensors:
            tensor_shapes[name] = shape
    return tensor_shapes


def infer_ofm_names(model: onnx.ModelProto) -> dict:
    '''Map each compute node name to the tensor holding its final output.

    Starting at the compute node, the chain of first-input consumers is
    followed through an optional Clip/Relu and an optional QuantizeLinear;
    the node reached must be a DequantizeLinear, whose output name is
    recorded.

    Returns a dict {compute node name: DequantizeLinear output name}.
    Raises KeyError when an output on the chain has no consumer and
        AssertionError when the chain does not end in DequantizeLinear.
    '''
    # Consumers are indexed by their FIRST input only. Nodes without inputs
    # (e.g. Constant) are skipped instead of raising IndexError.
    input_to_node = {node.input[0]: node for node in model.graph.node if node.input}
    nonlinear_functions = ('Clip', 'Relu')

    def next_node(node: onnx.NodeProto) -> onnx.NodeProto:
        return input_to_node[node.output[0]]

    ofm_names = {}
    for node in filter_compute_nodes(model):
        ofm_node = next_node(node)
        if ofm_node.op_type in nonlinear_functions:
            ofm_node = next_node(ofm_node)
        if ofm_node.op_type == 'QuantizeLinear':
            ofm_node = next_node(ofm_node)
        assert ofm_node.op_type == 'DequantizeLinear', (
            f"expected DequantizeLinear after {node.name}, found {ofm_node.op_type}"
        )
        ofm_names[node.name] = ofm_node.output[0]
    return ofm_names


def parse_conv_layer(node: onnx.NodeProto, tensor_shapes: dict, ofm_names: dict, allow_overcompute: bool) -> dict:
    '''Extract the convolution shape description from an ONNX Conv node.

    node: Conv node; must carry the "is_relu" attribute added by
        tag_conv_with_is_relu_inplace, plus "kernel_shape" and "pads".
        "strides", "dilations" and "group" fall back to the ONNX defaults
        (1) when the attribute is omitted.
    tensor_shapes: tensor name -> 4-element shape; unpacked as (N, C, Y, X).
    ofm_names: node name -> name of the tensor holding the layer output.
    allow_overcompute: forwarded to the tiler; when True the padded output
        size of the chosen mapping is stored as 'output_overcompute'.

    Side effects: prints progress and raises the module-level
    max_overcompute_size to the largest size seen so far.

    Returns the layer dict.
    Raises ValueError when the tiler finds no mapping and AssertionError for
    unsupported shapes (batch != 1, asymmetric padding, grouped but not
    depthwise convolution).
    '''
    global max_overcompute_size
    assert node.op_type == 'Conv'
    # Repeated-int attributes are read from ``ints``, scalar ones from ``i``.
    attrs = {a.name: a.ints if a.ints else a.i for a in node.attribute}
    input_shape = tensor_shapes[node.input[0]]
    output_shape = tensor_shapes[node.output[0]]
    Ni, Ci, Yi, Xi = input_shape
    No, Co, Yo, Xo = output_shape
    assert Ni == No == 1
    act_type = attrs['is_relu']
    Ky, Kx = attrs['kernel_shape']
    Sy, Sx = attrs.get('strides', (1, 1))
    # ONNX pads layout for 2-D: [y_begin, x_begin, y_end, x_end]
    PyB, PxB, PyE, PxE = attrs['pads']
    assert PyB == PyE
    assert PxB == PxE
    # Keep (y, x) order like kernel/stride; the previous code swapped them,
    # which is only harmless when both paddings are equal.
    Py, Px = PyB, PxB
    Dy, Dx = attrs.get('dilations', (1, 1))
    group = attrs.get('group', 1)
    is_depthwise = Ci == Co == group
    assert is_depthwise or (group == 1)
    op = 'dwc' if is_depthwise else 'conv_noqdq_a8w8'
    print(f"Calling tiler for layer {node.name} with shape {input_shape} -> {output_shape}, ")
    conv_shape = ConvShape(
        (Yi, Xi, max(Ci, 8)),  # the tiler is given at least 8 input channels
        (Yo, Xo, Co),
        (Ky, Kx),
        (Sy, Sx),
        (Py, Px),
        LinearOpType.conv_A8W8_noqdq, 0)
    mappings = generate_mappings(conv_shape, allow_overcompute)
    if len(mappings) == 0:
        raise ValueError("No valid mappings found for the given shape.")
    mapping = mappings[0]
    Yo_overcompute, Xo_overcompute, Co_overcompute = mapping.ofm_pad
    print(f"overcompute for layer {node.name} = {Yo_overcompute} x {Xo_overcompute} x {Co_overcompute} = {Yo_overcompute*Xo_overcompute*Co_overcompute} Bytes")
    if allow_overcompute:
        max_overcompute_size = max(max_overcompute_size, Yo_overcompute*Xo_overcompute*Co_overcompute)
    else:
        max_overcompute_size = max(max_overcompute_size, Yo*Xo*Co)
    layer = {
        'name': node.name,
        'input_name': node.input[0],
        'output_name': ofm_names[node.name],
        'op': op,
        'input': (Ci, Yi, Xi),
        'output': (Co, Yo, Xo),
        'kernel': (Ky, Kx),
        'stride': (Sy, Sx),
        'pad': (Py, Px),
        'dilation': (Dy, Dx),
        'act_type': act_type
    }
    if allow_overcompute:
        layer['output_overcompute'] = (Co_overcompute, Yo_overcompute, Xo_overcompute)
    return layer


def parse_maxpool_layer(node: onnx.NodeProto, tensor_shapes: dict, ofm_names: dict, allow_overcompute: bool) -> dict:
    '''Extract the MaxPool shape description from an ONNX MaxPool node.

    node: MaxPool node with "kernel_shape" and "pads" attributes.
        "strides" (default 1) and "ceil_mode" (default 0) fall back to the
        ONNX defaults when the attribute is omitted.
    tensor_shapes: tensor name -> 4-element shape; unpacked as (N, C, Y, X).
    ofm_names: node name -> name of the tensor holding the layer output.
    allow_overcompute: only printed; it does not influence the result.

    Returns the layer dict.
    Raises AssertionError for batch != 1 or asymmetric padding.
    '''
    print(f"allow_overcompute = {allow_overcompute}")
    assert node.op_type == 'MaxPool'
    # Repeated-int attributes are read from ``ints``, scalar ones from ``i``.
    attrs = {a.name: a.ints if a.ints else a.i for a in node.attribute}
    input_shape = tensor_shapes[node.input[0]]
    output_shape = tensor_shapes[node.output[0]]
    Ni, Ci, Yi, Xi = input_shape
    No, Co, Yo, Xo = output_shape
    assert Ni == No == 1
    Ky, Kx = attrs['kernel_shape']
    Sy, Sx = attrs.get('strides', (1, 1))
    # ONNX pads layout for 2-D: [y_begin, x_begin, y_end, x_end]
    PyB, PxB, PyE, PxE = attrs['pads']
    ceil_mode = attrs.get('ceil_mode', 0)
    assert PyB == PyE
    assert PxB == PxE
    # Keep (y, x) order like kernel/stride; the previous code swapped them,
    # which is only harmless when both paddings are equal.
    Py, Px = PyB, PxB

    layer = {
        'name': node.name,
        'input_name': node.input[0],
        'output_name': ofm_names[node.name],
        'op': 'maxpool_noqdq_a8',
        'input': (Ci, Yi, Xi),
        'output': (Co, Yo, Xo),
        'kernel': (Ky, Kx),
        'stride': (Sy, Sx),
        'pad': (Py, Px),
        'ceil_mode': ceil_mode,
    }
    return layer


def parse_add_layer(node: onnx.NodeProto, tensor_shapes: dict, ofm_names: dict, allow_overcompute: bool) -> dict:
    '''Build the layer description of an element-wise Add node.

    The shapes of the first input and of the output are unpacked as
    (N, C, Y, X) and must match in every dimension. With allow_overcompute
    the input/output overcompute entries are initialised to (0, 0, 0).
    '''
    print(f"allow_overcompute = {allow_overcompute}")
    assert node.op_type == 'Add'
    in_dims = tensor_shapes[node.input[0]]
    out_dims = tensor_shapes[node.output[0]]
    Ni, Ci, Yi, Xi = in_dims
    No, Co, Yo, Xo = out_dims
    # Element-wise add: input and output must agree dimension by dimension
    assert (Ni, Ci, Yi, Xi) == (No, Co, Yo, Xo)
    description = {
        'name': node.name,
        'input1_name': node.input[0],
        'input2_name': node.input[1],
        'output_name': ofm_names[node.name],
        'op': 'add_noqdq_a8',
        'input': (Ci, Yi, Xi),
        'output': (Co, Yo, Xo),
        'allow_overcompute': allow_overcompute,
    }
    if not allow_overcompute:
        return description
    for key in ('input_overcompute', 'output_overcompute'):
        description[key] = (0, 0, 0)
    return description


def parse_gemm_layer(node: onnx.NodeProto, tensor_shapes: dict, ofm_names: dict, allow_overcompute: bool) -> dict:
    '''Describe a Gemm node as a 1x1 convolution layer.

    The 2-D input/output shapes are handed to the tiler as a 1x1, stride-1,
    unpadded convolution. The layer's input name is a fixed tensor name
    rather than node.input[0], and ofm_names is not consulted.

    Side effects: prints progress and raises the module-level
    max_overcompute_size to the largest size seen so far.
    Raises ValueError when the tiler returns no mapping.
    '''
    global max_overcompute_size
    del ofm_names  # accepted only to match the other parse_* signatures
    print(f"allow_overcompute = {allow_overcompute}")
    assert node.op_type == 'Gemm'
    input_shape = tensor_shapes[node.input[0]]
    output_shape = tensor_shapes[node.output[0]]
    Yi, Xi = input_shape
    Yo, Xo = output_shape
    Ci = Co = 1
    kernel = (1, 1)
    stride = (1, 1)
    pad = (0, 0)
    print(f"Calling tiler for layer {node.name} with shape {input_shape} -> {output_shape}, ")
    conv_shape = ConvShape(
        (Yi, Xi, max(Ci, 8)),
        (Yo, Xo, Co),
        kernel,
        stride,
        pad,
        LinearOpType.conv_A8W8_noqdq, 0)
    mappings = generate_mappings(conv_shape, allow_overcompute)
    if len(mappings) == 0:
        raise ValueError("No valid mappings found for the given shape.")
    Yo_overcompute, Xo_overcompute, Co_overcompute = mappings[0].ofm_pad
    padded_size = Yo_overcompute * Xo_overcompute * Co_overcompute
    print(f"overcompute for layer {node.name} = {Yo_overcompute} x {Xo_overcompute} x {Co_overcompute} = {padded_size} Bytes")
    candidate_size = padded_size if allow_overcompute else Yo * Xo * Co
    max_overcompute_size = max(max_overcompute_size, candidate_size)
    layer = {
        'name': node.name,
        'input_name': "/avgpool/GlobalAveragePool_output_0_DequantizeLinear_Output",
        'output_name': node.output[0],
        'op': 'conv_noqdq_a8w8',
        'input': (Xi, Yi, Ci),
        'output': (iceil(Xo, 64), Yo, Co),
        'kernel': kernel,
        'stride': stride,
        'pad': pad,
        'dilation': (1, 1),
        'act_type': 0
    }
    if allow_overcompute:
        layer['output_overcompute'] = (Co_overcompute, Yo_overcompute, Xo_overcompute)
    return layer


def parse_gap_layer(node: onnx.NodeProto, tensor_shapes: dict, ofm_names: dict, allow_overcompute: bool) -> dict:
    '''Build the layer description of a GlobalAveragePool node.

    Input and output shapes are unpacked as (N, C, Y, X); the batch size
    must be 1. allow_overcompute is accepted but ignored.
    '''
    del allow_overcompute  # accepted only to match the other parse_* signatures
    assert node.op_type == 'GlobalAveragePool'
    source_name = node.input[0]
    in_dims = tensor_shapes[source_name]
    out_dims = tensor_shapes[node.output[0]]
    Ni, Ci, Yi, Xi = in_dims
    No, Co, Yo, Xo = out_dims
    assert Ni == No == 1

    return {
        'name': node.name,
        'input_name': node.input[0],
        'output_name': ofm_names[node.name],
        'op': 'gap',
        'input': (Ci, Yi, Xi),
        'output': (Co, Yo, Xo),
    }


def parse_layers(model: onnx.ModelProto, allow_overcompute: bool) -> list[dict]:
    '''Turn every compute node of the graph into a layer description.

    Each node is dispatched to the parse_* function for its op type. With
    allow_overcompute the overcompute size of every layer output is tracked
    by output tensor name, so that an Add layer can inherit the (identical)
    overcompute size of its two inputs.
    '''
    tensor_shapes = infer_tensor_shapes(model)
    ofm_names = infer_ofm_names(model)
    parsers = {
        'Conv': parse_conv_layer,
        'Add': parse_add_layer,
        'Gemm': parse_gemm_layer,
        'MaxPool': parse_maxpool_layer,
        'GlobalAveragePool': parse_gap_layer,
    }
    # Ops whose recorded size comes from 'output_overcompute'
    overcompute_ops = ('conv_noqdq_a8w8', 'add_noqdq_a8')
    size_by_ofm = {}
    parsed = []
    for position, node in enumerate(filter_compute_nodes(model)):
        parse = parsers[node.op_type]
        layer = parse(node, tensor_shapes, ofm_names, allow_overcompute)
        if allow_overcompute:
            ofm = ofm_names[node.name]
            size_key = 'output_overcompute' if layer['op'] in overcompute_ops else 'output'
            size_by_ofm[ofm] = layer[size_key]
            if layer['op'] == 'add_noqdq_a8':
                # Both operands must have been produced with the same size
                lhs_size = size_by_ofm[node.input[0]]
                rhs_size = size_by_ofm[node.input[1]]
                assert lhs_size == rhs_size, f"[ERR] Overcompute inputs for Add-{position} are different"
                layer['input_overcompute'] = layer['output_overcompute'] = lhs_size
                size_by_ofm[ofm] = layer['output_overcompute']
        parsed.append(layer)
    return parsed


def chain_dwc_conv_1x1(layers: list[dict]) -> list[dict]:
    '''Merge each depthwise conv with a directly following 1x1 conv.

    A 'dwc' layer with fewer than 800 input channels that is immediately
    followed by a 1x1, stride-1, unpadded 'conv_noqdq_a8w8' layer is replaced
    by a single 'dwc_conv_1x1' layer; all other layers are passed through
    unchanged. Returns the new layer list.
    '''
    channel_limit = 800
    total = len(layers)

    def pointwise_follows(pos: int) -> bool:
        if pos + 1 >= total:
            return False
        follower = layers[pos + 1]
        return (
            follower['op'] == 'conv_noqdq_a8w8'
            and follower['kernel'] == (1, 1)
            and follower['stride'] == (1, 1)
            and follower['pad'] == (0, 0)
        )

    result = []
    pos = 0
    while pos < total:
        current = layers[pos]
        trailing_1x1 = pointwise_follows(pos)
        channels, _, _ = current['input']
        if current['op'] == 'dwc' and trailing_1x1 and channels < channel_limit:
            follower = layers[pos + 1]
            result.append({
                'name': current['name'] + ':' + follower['name'],
                'input_name': current['input_name'],
                'output_name': follower['output_name'],
                'op': 'dwc_conv_1x1',
                'input': current['input'],
                'output': follower['output'],
                'kernel': current['kernel'],
                'stride': current['stride'],
                'pad': current['pad'],
                'dilation': current['dilation'],
            })
            pos += 2
            continue
        result.append(current)
        pos += 1
    return result


class BufferManager:
    '''Pool of fixed-address L2 buffers that are handed out and reused.

    Every buffer handle is a one-entry dict ``{act_tile: address}``. The
    manager also remembers the handle of the most recent layer output.
    '''
    def __init__(self, act_tile: int, addresses: list[int]) -> None:
        '''Create the pool; all given addresses start out free.'''
        self.act_tile = act_tile
        # address -> True while the buffer is free, False while handed out
        self.buffers = {address: True for address in addresses}
        # Handle stored by set_last_output_addr(); None until first set
        self.prev_layer_out_addr = None

    def get_free_buffer(self) -> dict[int, int]:
        '''Reserve the first free buffer and return it as {act_tile: address}.

        Raises AssertionError when every buffer is in use.
        '''
        for address, is_free in self.buffers.items():
            if is_free:
                self.buffers[address] = False
                return {self.act_tile: address}
        # Explicit raise (not ``assert False``) so exhaustion is also
        # reported under ``python -O`` instead of silently returning None.
        raise AssertionError("no free L2 buffer available")

    def release_buffer(self, address: dict[int, int]) -> None:
        '''Mark the buffer behind a handle as free again.

        Handles whose address is not part of the pool are ignored.
        '''
        address_value = address[self.act_tile]
        if address_value in self.buffers:
            # Mark the buffer as available
            self.buffers[address_value] = True

    def set_last_output_addr(self, addr: dict[int, int]) -> None:
        '''Remember the handle of the most recent layer output.'''
        self.prev_layer_out_addr = addr

    def get_last_output_addr(self) -> dict[int, int]:
        '''Return the handle stored by set_last_output_addr().

        Raises AssertionError when no (non-empty) handle has been stored.
        '''
        if not self.prev_layer_out_addr:
            raise AssertionError("no previous layer output address recorded")
        return self.prev_layer_out_addr


def alloc_L2_fusion(layers: list[dict]) -> list[dict]:
    '''
    Find the run of layers that can be fused in L2 and assign L2 buffer
    addresses to every layer of that run.

    The layer list is matched against fixed patterns: a leading conv at
    index 0, a maxpool at index 1, any number of "structure 1" /
    "structure 2" residual blocks (see the nested helpers), a gap layer at
    the second-to-last index and a final conv at the last index.

    layers: layer dicts as produced by parse_layers().
    Returns shallow copies of the layers. Every copy carries
    'enable_L2_fusion'; fused layers additionally get their input/output,
    weight and parameter addresses plus the load/store-from-DDR flags.
    Raises AssertionError when the layer list does not follow the expected
    patterns, when no fusable start is found, or when the pool of three
    activation buffers is exhausted.

    The activation buffer size is the module-level max_overcompute_size,
    so the parse step must have run before this function.
    '''
    structure_1_size = 5
    structure_2_size = 4
    aie_rows = 4
    C_gran = 64

    memtile_size = 3 * (2**20)

    prm_core_size = 1024
    L2_prm_buffer_size = prm_core_size * aie_rows
    L2_col1_prm_addr = 0
    L2_col3_prm_addr = 0

    print(f"Max overcompute size across all layers = {max_overcompute_size}")
    L2_act_buffer_size = max_overcompute_size

    # NOTE: This is an assumption that WGT will never be more than 64KB in core
    wgt_core_size = 64 * 1024
    L2_wgt_buffer_ping_size = wgt_core_size * aie_rows
    L2_col1_wgt_addr = [L2_prm_buffer_size, L2_prm_buffer_size+L2_wgt_buffer_ping_size]
    L2_col3_wgt_addr = [L2_prm_buffer_size, L2_prm_buffer_size+L2_wgt_buffer_ping_size]

    def buffer_size(shape: tuple[int, int, int]) -> int:
        '''Size of a (C, Y, X) tensor with C rounded up to C_gran.'''
        C, Y, X = shape
        return ceildiv(C, C_gran) * C_gran * Y * X

    def fits(shape: tuple[int, int, int]) -> bool:
        '''True when the channel-rounded tensor fits one activation buffer.'''
        return buffer_size(shape) <= L2_act_buffer_size

    def raw_size(shape: tuple[int, int, int]) -> int:
        '''Plain product of the three dimensions (no channel rounding).'''
        a, b, c = shape
        return a * b * c

    def is_structure_1(idx: int) -> bool:
        '''
        Structure 1: --> CONV --> CONV --> CONV --> ADD -->
                        |                           /
                        +--> CONV ----------------+
        '''
        # Not enough layers left: report "no match" instead of IndexError.
        if idx + structure_1_size > len(layers):
            return False
        layer0, layer1, layer2, layer3, layer4 = layers[idx:idx + structure_1_size]
        return (
            (layer0['op'] == 'conv_noqdq_a8w8') and
            (layer1['op'] == 'conv_noqdq_a8w8') and
            (layer2['op'] == 'conv_noqdq_a8w8') and
            (layer3['op'] == 'conv_noqdq_a8w8') and
            (layer4['op'] == 'add_noqdq_a8') and
            (layer0['output_name'] == layer2['input_name']) and
            (layer2['output_name'] == layer3['input_name']) and
            (layer3['output_name'] == layer4['input1_name']) and
            (layer0['input_name'] == layer1['input_name']) and
            (layer1['output_name'] == layer4['input2_name'])
        )

    def is_structure_2(idx: int) -> bool:
        '''
        Structure 2: --> CONV --> CONV --> CONV --> ADD -->
                        |                           /
                        +-------------------------+
        '''
        # Not enough layers left: report "no match" instead of IndexError.
        if idx + structure_2_size > len(layers):
            return False
        layer0, layer1, layer2, layer3 = layers[idx:idx + structure_2_size]
        return (
            (layer0['op'] == 'conv_noqdq_a8w8') and
            (layer1['op'] == 'conv_noqdq_a8w8') and
            (layer2['op'] == 'conv_noqdq_a8w8') and
            (layer3['op'] == 'add_noqdq_a8') and
            (layer0['output_name'] == layer1['input_name']) and
            (layer0['input_name'] == layer3['input2_name']) and
            (layer1['output_name'] == layer2['input_name']) and
            (layer2['output_name'] == layer3['input1_name'])
        )

    def can_fuse_structure_1(idx: int) -> bool:
        '''All tensors of a structure-1 block fit an activation buffer.'''
        block = layers[idx:idx + structure_1_size]
        return fits(block[0]['input']) and all(fits(layer['output']) for layer in block)

    def can_fuse_structure_2(idx: int) -> bool:
        '''All checked tensors of a structure-2 block fit an activation buffer.'''
        layer0, layer1, layer2, layer3 = layers[idx:idx + structure_2_size]
        return (
            fits(layer0['input']) and
            fits(layer0['output']) and
            fits(layer1['input']) and
            fits(layer2['input']) and
            fits(layer3['input'])
        )

    def fusion_start_idx() -> int:
        '''Index of the first layer (or block) whose tensors fit in L2.'''
        idx = 0
        while idx < len(layers) - 1:
            if idx == 0 and layers[idx]['op'] == 'conv_noqdq_a8w8':
                if fits(layers[idx]['output']):
                    return idx
                idx += 1
            elif idx == 1 and layers[idx]['op'] == 'maxpool_noqdq_a8':
                if raw_size(layers[idx]['output']) <= L2_act_buffer_size:
                    return idx
                idx += 1
            elif idx == len(layers) - 2 and layers[idx]['op'] == 'gap':
                if raw_size(layers[idx]['output']) <= L2_act_buffer_size:
                    return idx
                idx += 1
            elif is_structure_1(idx):
                if can_fuse_structure_1(idx):
                    return idx
                idx += structure_1_size
            elif is_structure_2(idx):
                if can_fuse_structure_2(idx):
                    return idx
                idx += structure_2_size
            else:
                raise AssertionError(f"unsupported layer pattern at index {idx}")
        raise AssertionError("no layer can be fused in L2")

    def fusion_end_idx(first_idx: int) -> int:
        '''Index at which the allocation loop below stops (inclusive).

        For structure blocks this is the index of the block's first layer.
        '''
        curr_idx = first_idx
        prev_idx = curr_idx
        while curr_idx < len(layers):
            if curr_idx == 0 and layers[curr_idx]['op'] == 'conv_noqdq_a8w8':
                if not fits(layers[curr_idx]['output']):
                    return curr_idx
                prev_idx = curr_idx
                curr_idx += 1
            elif curr_idx == 1 and layers[curr_idx]['op'] == 'maxpool_noqdq_a8':
                if not raw_size(layers[curr_idx]['output']) <= L2_act_buffer_size:
                    return curr_idx
                prev_idx = curr_idx
                curr_idx += 1
            elif curr_idx == len(layers) - 2 and layers[curr_idx]['op'] == 'gap':
                if not raw_size(layers[curr_idx]['output']) <= L2_act_buffer_size:
                    return curr_idx
                prev_idx = curr_idx
                curr_idx += 1
            elif curr_idx == len(layers) - 1 and layers[curr_idx]['op'] == 'conv_noqdq_a8w8':
                _, H, W = layers[curr_idx]['input']
                if H*W <= L2_act_buffer_size:
                    return curr_idx
                # The final conv does not fit: stop at the previous entry.
                # Without this return the loop never advanced and hung.
                return prev_idx
            elif is_structure_1(curr_idx):
                if not can_fuse_structure_1(curr_idx):
                    return prev_idx
                prev_idx = curr_idx
                curr_idx += structure_1_size
            elif is_structure_2(curr_idx):
                if not can_fuse_structure_2(curr_idx):
                    return prev_idx
                prev_idx = curr_idx
                curr_idx += structure_2_size
            else:
                return prev_idx
        raise AssertionError("ran past the last layer while searching the fusion end")

    start_idx = fusion_start_idx()
    end_idx = fusion_end_idx(start_idx)
    fused_layers = [layer.copy() for layer in layers]
    for layer in fused_layers:
        layer['enable_L2_fusion'] = False
    idx = start_idx

    L2_act_tile = 1
    # Three equally sized activation buffers placed back to back.
    # NOTE(review): nothing here checks 3 * L2_act_buffer_size against
    # memtile_size - confirm this is guaranteed elsewhere.
    L2_act_addr = [0, L2_act_buffer_size, 2*L2_act_buffer_size]

    # Preferred placement: second prm/wgt set right behind the first one.
    L2_col2_prm_addr = L2_col1_wgt_addr[1] + L2_wgt_buffer_ping_size
    L2_col2_prm_addr_offset = L2_col2_prm_addr + L2_prm_buffer_size
    L2_col2_wgt_addr = [L2_col2_prm_addr_offset, L2_col2_prm_addr_offset+L2_wgt_buffer_ping_size]
    if L2_col2_wgt_addr[-1] + L2_wgt_buffer_ping_size < memtile_size:
        # Entries are [tile, address] pairs
        L2_prm_addr = [
            [0, L2_col1_prm_addr],
            [0, L2_col2_prm_addr],
            [2, L2_col3_prm_addr],
        ]
        L2_wgt_addr = [
            [0, L2_col1_wgt_addr],
            [0, L2_col2_wgt_addr],
            [2, L2_col3_wgt_addr],
        ]
    else:
        # Fallback placement (was an ``except AssertionError`` branch).
        # NOTE(review): this branch yields dicts while the preferred one
        # yields lists of pairs - confirm consumers accept both forms.
        print("param/weights buffers cannot be placed in the col 1 moving to col 0")
        L2_col2_prm_addr = L2_col1_prm_addr + L2_prm_buffer_size
        L2_col2_prm_addr_offset = L2_col2_prm_addr + L2_prm_buffer_size
        L2_col2_wgt_addr = [L2_col2_prm_addr_offset, L2_col2_prm_addr_offset+L2_wgt_buffer_ping_size]
        L2_prm_addr = {
            0: L2_col1_prm_addr,
            1: L2_col2_prm_addr,
            2: L2_col3_prm_addr
        }
        L2_wgt_addr = {
            0: L2_col1_wgt_addr,
            1: L2_col2_wgt_addr,
            2: L2_col3_wgt_addr
        }
        if not L2_col2_wgt_addr[-1] + L2_wgt_buffer_ping_size < memtile_size:
            raise AssertionError("param/weights buffers do not fit in the memtile")
    buffer_mngr = BufferManager(L2_act_tile, L2_act_addr)

    def mark_fused(layer: dict, load_input: bool, store_output: bool) -> None:
        '''Set the fusion flags and shared wgt/prm addresses on one layer.'''
        layer['enable_L2_fusion'] = True
        layer['load_input_from_ddr'] = load_input
        layer['store_output_to_ddr'] = store_output
        layer['wgt_addr'] = L2_wgt_addr
        layer['prm_addr'] = L2_prm_addr

    while idx <= end_idx:
        if idx == 0 and fused_layers[idx]['op'] == 'conv_noqdq_a8w8':
            # First conv: its input comes from DDR into a temporary buffer.
            layer = fused_layers[idx]
            layer['output_addr'] = buffer_mngr.get_free_buffer()
            layer['input_addr'] = buffer_mngr.get_free_buffer()
            mark_fused(layer, True, False)
            buffer_mngr.set_last_output_addr(layer['output_addr'])
            buffer_mngr.release_buffer(layer['input_addr'])
            idx += 1
        elif ((idx == 1 and fused_layers[idx]['op'] == 'maxpool_noqdq_a8') or
              (idx == len(fused_layers)-2 and fused_layers[idx]['op'] == 'gap')):
            # Single-input layer fed by the previous layer's output.
            layer = fused_layers[idx]
            layer['output_addr'] = buffer_mngr.get_free_buffer()
            layer['input_addr'] = buffer_mngr.get_last_output_addr()
            mark_fused(layer, False, False)
            buffer_mngr.set_last_output_addr(layer['output_addr'])
            buffer_mngr.release_buffer(layer['input_addr'])
            idx += 1
        elif idx == len(fused_layers)-1 and fused_layers[idx]['op'] == 'conv_noqdq_a8w8':
            # Final conv: its output goes back to DDR.
            layer = fused_layers[idx]
            layer['input_addr'] = buffer_mngr.get_last_output_addr()
            layer['output_addr'] = buffer_mngr.get_free_buffer()
            mark_fused(layer, False, True)
            idx += 1
        elif is_structure_1(idx):
            layer0 = fused_layers[idx]
            layer1 = fused_layers[idx + 1]
            layer2 = fused_layers[idx + 2]
            layer3 = fused_layers[idx + 3]
            layer4 = fused_layers[idx + 4]
            for layer in (layer0, layer1, layer2, layer3, layer4):
                mark_fused(layer, False, False)
            # layer0 and layer1 both read the block input; it is released
            # only after both outputs have been reserved.
            layer0['input_addr'] = buffer_mngr.get_last_output_addr()
            layer1['input_addr'] = buffer_mngr.get_last_output_addr()
            layer0['output_addr'] = buffer_mngr.get_free_buffer()
            layer1['output_addr'] = buffer_mngr.get_free_buffer()
            buffer_mngr.release_buffer(layer0['input_addr'])
            layer2['input_addr'] = layer0['output_addr']
            layer2['output_addr'] = buffer_mngr.get_free_buffer()
            buffer_mngr.release_buffer(layer2['input_addr'])
            layer3['input_addr'] = layer2['output_addr']
            layer3['output_addr'] = buffer_mngr.get_free_buffer()
            buffer_mngr.release_buffer(layer3['input_addr'])
            layer4['input0_addr'] = layer3['output_addr']
            layer4['input1_addr'] = layer1['output_addr']
            layer4['output_addr'] = buffer_mngr.get_free_buffer()
            buffer_mngr.release_buffer(layer4['input0_addr'])
            buffer_mngr.release_buffer(layer4['input1_addr'])
            buffer_mngr.set_last_output_addr(layer4['output_addr'])
            idx += structure_1_size

        elif is_structure_2(idx):
            layer0 = fused_layers[idx]
            layer1 = fused_layers[idx + 1]
            layer2 = fused_layers[idx + 2]
            layer3 = fused_layers[idx + 3]
            for layer in (layer0, layer1, layer2, layer3):
                mark_fused(layer, False, False)
            # The block input stays reserved until the Add has consumed it
            # as its second operand.
            layer0['input_addr'] = buffer_mngr.get_last_output_addr()
            layer0['output_addr'] = buffer_mngr.get_free_buffer()
            layer1['input_addr'] = layer0['output_addr']
            layer1['output_addr'] = buffer_mngr.get_free_buffer()
            buffer_mngr.release_buffer(layer0['output_addr'])
            layer2['input_addr'] = layer1['output_addr']
            layer2['output_addr'] = buffer_mngr.get_free_buffer()
            buffer_mngr.release_buffer(layer2['input_addr'])
            layer3['input0_addr'] = layer2['output_addr']
            layer3['input1_addr'] = buffer_mngr.get_last_output_addr()
            layer3['output_addr'] = buffer_mngr.get_free_buffer()
            buffer_mngr.release_buffer(layer3['input0_addr'])
            buffer_mngr.release_buffer(layer3['input1_addr'])
            buffer_mngr.set_last_output_addr(layer3['output_addr'])
            idx += structure_2_size
        else:
            raise AssertionError(f"unsupported layer pattern at index {idx}")
    fused_layers[start_idx]['load_input_from_ddr'] = True
    # NOTE(review): when fusion ends on a structure block, end_idx is the
    # block's first layer, so the store flag lands there and not on the
    # block's Add - confirm this is intended.
    fused_layers[end_idx]['store_output_to_ddr'] = True
    return fused_layers


def save_tiling_json(layers: list[dict], filename: str) -> None:
    '''
    Write the layer descriptions to a JSON file, keyed by their
    position in the list (keys sorted, 4-space indent).
    '''
    numbered = {position: layer for position, layer in enumerate(layers)}
    with open(filename, 'w', encoding="utf-8") as handle:
        text = json.dumps(numbered, sort_keys=True, indent=4)
        handle.write(text)


def fuse_graph() -> None:
    '''
    Entry point: load the ONNX model next to this script, parse its
    compute layers, allocate L2 buffers and write the result as JSON.

    An optional single command-line argument of '1' enables overcompute;
    anything else (or no argument) disables it.
    '''
    argc = len(sys.argv)
    assert (argc <= 2), "Usage: python resnet50.py [0-disable overcompute | 1-enable overcompute]"
    allow_overcompute = (argc == 2) and (sys.argv[1] == '1')
    verbose = True
    model = onnx.load(os.path.join(CURRDIR, 'ResNet50_INT8_Model.onnx'))

    tag_conv_with_is_relu_inplace(model.graph)

    layers = parse_layers(model, allow_overcompute)
    # layers = chain_dwc_conv_1x1(layers)
    # layers = chain_conv_resize(layers)
    layers = alloc_L2_fusion(layers)
    if allow_overcompute:
        out_file = 'resnet_tiling_L2_fused_with_overcompute.json'
    else:
        out_file = 'resnet_tiling_L2_fused.json'
    save_tiling_json(layers, os.path.join(CURRDIR, out_file))

    if not verbose:
        return

    # Summary printing: one line per layer with the selected keys
    op_filter: Optional[Collection[str]] = None
    key_filter = (
        'input', 'output', 'kernel', 'stride', 'pad',
        'input_subv', 'output_subv', 'X_split', 'Co_split',
        'enable_L2_fusion', 'input_is_hwc', 'output_is_hwc', 'op',
    )
    allowed_ops = set(op_filter) if op_filter else None
    for layer in layers:
        if allowed_ops is not None and layer['op'] not in allowed_ops:
            continue
        for key in key_filter:
            if key in layer:
                print(f'{key}: {layer[key]}', end=' ')
        print()


# Run the whole parse / allocate / save flow only when this file is executed
# as a script, not when it is imported.
if __name__ == '__main__':
    fuse_graph()
