'''
Unit tests for the conv dataflow: mapping generation and L2 dataflow compilation.
'''


import os
import unittest

from dmacompiler import memory_tile
from scheduler.conv.conv_common import (
    LinearOpType,
)
from scheduler.conv.conv_config_builders import (
    ConvShape,
)
from scheduler.conv.conv_L2_schedule import compile_L2_dataflow
from tiler.conv_tiler import generate_conv_mappings
from utils.utils_common import (
    L2Alloc,
)

# Convolution layer shapes from ResNet-50 that the test below compiles.
# Every layer uses the same op type (conv_A8W8_noqdq) and the same trailing
# argument (0); only the geometry differs, so it is listed once per layer.
# Geometry tuple order mirrors ConvShape's leading positional parameters —
# presumably (input HWC, output HWC, kernel, stride, padding); TODO confirm
# against scheduler.conv.conv_config_builders.ConvShape.
RESNET50_SHAPES = [
    ConvShape(in_dims, out_dims, kernel, stride, padding, LinearOpType.conv_A8W8_noqdq, 0)
    for in_dims, out_dims, kernel, stride, padding in (
        # 224x224 -> 112x112 stem layer
        ((224, 224, 4), (112, 112, 64), (7, 7), (2, 2), (3, 3)),

        # Layers operating at 56x56
        ((56, 56, 64), (56, 56, 64), (1, 1), (1, 1), (0, 0)),
        ((56, 56, 64), (56, 56, 64), (3, 3), (1, 1), (1, 1)),
        ((56, 56, 64), (56, 56, 256), (1, 1), (1, 1), (0, 0)),
        ((56, 56, 256), (56, 56, 64), (1, 1), (1, 1), (0, 0)),
        ((56, 56, 256), (56, 56, 128), (1, 1), (1, 1), (0, 0)),

        # Stride-2 layers 56x56 -> 28x28, then layers at 28x28
        ((56, 56, 128), (28, 28, 128), (3, 3), (2, 2), (1, 1)),
        ((56, 56, 256), (28, 28, 512), (1, 1), (2, 2), (0, 0)),
        ((28, 28, 128), (28, 28, 512), (1, 1), (1, 1), (0, 0)),
        ((28, 28, 512), (28, 28, 128), (1, 1), (1, 1), (0, 0)),
        ((28, 28, 128), (28, 28, 128), (3, 3), (1, 1), (1, 1)),
        ((28, 28, 512), (28, 28, 256), (1, 1), (1, 1), (0, 0)),

        # Stride-2 layers 28x28 -> 14x14, then layers at 14x14
        ((28, 28, 256), (14, 14, 256), (3, 3), (2, 2), (1, 1)),
        ((28, 28, 512), (14, 14, 1024), (1, 1), (2, 2), (0, 0)),
        ((14, 14, 256), (14, 14, 1024), (1, 1), (1, 1), (0, 0)),
        ((14, 14, 1024), (14, 14, 256), (1, 1), (1, 1), (0, 0)),
        ((14, 14, 256), (14, 14, 256), (3, 3), (1, 1), (1, 1)),
        ((14, 14, 1024), (14, 14, 512), (1, 1), (1, 1), (0, 0)),

        # Stride-2 layers 14x14 -> 7x7, then layers at 7x7
        ((14, 14, 512), (7, 7, 512), (3, 3), (2, 2), (1, 1)),
        ((14, 14, 1024), (7, 7, 2048), (1, 1), (2, 2), (0, 0)),
        ((7, 7, 512), (7, 7, 2048), (1, 1), (1, 1), (0, 0)),
        ((7, 7, 2048), (7, 7, 512), (1, 1), (1, 1), (0, 0)),
        ((7, 7, 512), (7, 7, 512), (3, 3), (1, 1), (1, 1)),
    )
]


class TestConv(unittest.TestCase):
    '''Test conv dataflow mapping and compilation'''

    # Output files expected in the current working directory once
    # compile_L2_dataflow has run; the test removes them when it finishes.
    _GENERATED_FILES = ('dma.hpp', 'graph.hpp', 'super.cc', 'super.hh')

    @staticmethod
    def _remove_if_exists(path: str) -> None:
        '''Delete ``path``; a file that was never created is not an error.'''
        try:
            os.remove(path)
        except FileNotFoundError:
            # Deliberate best-effort cleanup: if compilation failed before this
            # file was written, that failure is already reported by the test.
            pass

    def test_resnet50_conv_shapes(self) -> None:
        '''Compile all possible mappings of all Resnet50 layers.

        Every mapping returned by ``generate_conv_mappings`` for every shape in
        ``RESNET50_SHAPES`` is compiled against one fixed L2 allocation. Each
        (layer, mapping) pair runs in its own ``subTest`` so the failing pair is
        identified in the report. Generated output files are removed even when
        a compilation raises.
        '''
        # Register cleanup before compiling so artifacts never leak into the
        # working directory, even if a compilation fails part-way through.
        for path in self._GENERATED_FILES:
            self.addCleanup(self._remove_if_exists, path)

        wgt_relative_base_addr = 2 * 2**20  # 2 MiB
        wgt_reserved_size = 64 * 1024 * 4  # 64 KiB * 4 (unicast)
        # Weight ping/pong buffers sit back to back, followed by the param buffer.
        wgt_ping_addr = wgt_relative_base_addr
        wgt_pong_addr = wgt_relative_base_addr + wgt_reserved_size
        prm_addr = wgt_pong_addr + wgt_reserved_size
        # NOTE(review): the roles of the three leading (tile, address) pairs and
        # the trailing boolean flags follow L2Alloc's positional signature,
        # which is not visible here — confirm before changing them.
        params = L2Alloc(
            (memory_tile(1), 0 * 2**20),
            (memory_tile(0), 0),
            (memory_tile(1), 1 * 2**20),
            [
                [memory_tile(0), wgt_ping_addr, wgt_pong_addr],
                [memory_tile(1), wgt_ping_addr, wgt_pong_addr],
                [memory_tile(2), wgt_ping_addr, wgt_pong_addr],
            ],
            [
                [memory_tile(0), prm_addr],
                [memory_tile(1), prm_addr],
                [memory_tile(2), prm_addr],
            ],
            True, True, True,
        )
        kernel_names: dict[str, int] = {'run_conv_noqdq_a8w8': 0}
        kernel_includes: list[str] = ['super.hh', 'conv/conv_noqdq_a8w8_wrapper.cc']
        num_mappings = 0
        for layer_idx, shape in enumerate(RESNET50_SHAPES):
            mappings = generate_conv_mappings(shape, False)
            for mapping_idx, mapping in enumerate(mappings):
                # subTest labels any failure with the offending layer/mapping.
                with self.subTest(layer=layer_idx, mapping=mapping_idx):
                    compile_L2_dataflow(
                        shape,
                        mapping,
                        params,
                        kernel_names,
                        kernel_includes,
                    )
                    num_mappings += 1

        # Guard against a vacuous pass when the tiler yields no mappings at all.
        self.assertGreater(num_mappings, 0, 'no conv mappings were compiled')
        for path in self._GENERATED_FILES:
            self.assertTrue(
                os.path.exists(path),
                f'expected compiler output {path!r} was not generated',
            )
        print(f'Completed compilation of {num_mappings} mappings!')


if __name__ == '__main__':
    unittest.main()
